Skip to content

Commit 3d2911d

Browse files
committed
test
1 parent 123a985 commit 3d2911d

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

enzyme/test/MLIR/ForwardMode/test_vector.mlir renamed to enzyme/test/MLIR/ForwardMode/batched_scalar.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@ module {
1111
}
1212
}
1313

14-
// CHECK: func.func @dsq(%[[arg0:.+]]: tensor<2xf64>, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> {
15-
// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xf64>
14+
// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> {
15+
// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (f64, tensor<2xf64>) -> tensor<2xf64>
1616
// CHECK-NEXT: return %[[i0]] : tensor<2xf64>
1717
// CHECK-NEXT: }
1818
// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> {
19-
// CHECK-NEXT: %[[s0:.+]] = tensor.splat %[[arg0]] : f64 -> tensor<2xf64>
19+
// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{width = 2 : i64}> : f64 -> tensor<2xf64>
2020
// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2xf64>
21-
// CHECK-NEXT: %[[s1:.+]] = tensor.splat %[[arg0]] : f64 -> tensor<2xf64>
21+
// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{width = 2 : i64}> : f64 -> tensor<2xf64>
2222
// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64>
2323
// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64>
2424
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<2xf64>
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: %eopt --enzyme %s | FileCheck %s
2+
3+
module {
4+
func.func @square(%x : tensor<10xf64>) -> tensor<10xf64>{
5+
%y = arith.mulf %x, %x : tensor<10xf64>
6+
return %y : tensor<10xf64>
7+
}
8+
func.func @dsq(%x : tensor<10xf64>, %dx : tensor<2x10xf64>) -> tensor<2x10xf64> {
9+
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<2x10xf64>)
10+
return %r : tensor<2x10xf64>
11+
}
12+
}
13+
14+
// CHECK: func.func @dsq(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> {
15+
// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64>
16+
// CHECK-NEXT: return %[[i0]] : tensor<2x10xf64>
17+
// CHECK-NEXT: }
18+
// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> {
19+
// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%%[[arg0]]) <{width = 2 : i64}> : (tensor<10xf64>) -> tensor<2x10xf64>
20+
// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2x10xf64>
21+
// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%%[[arg0]]) <{width = 2 : i64}> : (tensor<10xf64>) -> tensor<2x10xf64>
22+
// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2x10xf64>
23+
// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2x10xf64>
24+
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<10xf64>
25+
// CHECK-NEXT: return %[[i2]] : tensor<2x10xf64>
26+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)