blob: 22f7a93762a17155b7cc38bd9eb6edb6ed0b2278 [file] [log] [blame]
// RUN: %eopt %s --enzyme-wrap="infn=reduce outfn= argTys=enzyme_active,enzyme_const retTys=enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops | FileCheck %s
// Product reduction: starting from 1.0, multiplies a running accumulator by %x
// once per iteration of [0, %ub), so the returned value is %x raised to the
// power %ub (1.0 when %ub is 0). The test above differentiates this function
// in ReverseModeCombined with %x active and %ub constant.
func.func @reduce(%x: f32, %ub: index) -> (f32) {
  %lb = arith.constant 0 : index
  %step = arith.constant 1 : index
  // Initial product set to 1.0, the multiplicative identity.
  %prod_0 = arith.constant 1.0 : f32
  // iter_args binds initial values to the loop's region arguments.
  %prod = scf.for %iv = %lb to %ub step %step
      iter_args(%prod_iter = %prod_0) -> (f32) {
    %prod_next = arith.mulf %prod_iter, %x : f32
    // Yield the running product to the next iteration's %prod_iter, or to
    // %prod after the final iteration.
    scf.yield %prod_next : f32
  }
  return %prod : f32
}
// CHECK: func.func @reduce(%arg0: f32, %arg1: index, %arg2: f32) -> f32 {
// CHECK-NEXT: %cst = arith.constant 1.000000e+00 : f32
// CHECK-NEXT: %c1 = arith.constant 1 : index
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %0 = tensor.empty(%arg1) : tensor<?xf32>
// CHECK-NEXT: %1:2 = scf.for %arg3 = %c0 to %arg1 step %c1 iter_args(%arg4 = %cst, %arg5 = %0) -> (f32, tensor<?xf32>) {
// CHECK-NEXT: %inserted = tensor.insert %arg4 into %arg5[%arg3] : tensor<?xf32>
// CHECK-NEXT: %4 = arith.mulf %arg4, %arg0 : f32
// CHECK-NEXT: scf.yield %4, %inserted : f32, tensor<?xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %2 = arith.addf %arg2, %cst_0 : f32
// CHECK-NEXT: %3:3 = scf.for %arg3 = %c0 to %arg1 step %c1 iter_args(%arg4 = %2, %arg5 = %arg1, %arg6 = %cst_0) -> (f32, index, f32) {
// CHECK-NEXT: %extracted = tensor.extract %1#1[%arg5] : tensor<?xf32>
// CHECK-NEXT: %4 = arith.mulf %arg4, %arg0 : f32
// CHECK-NEXT: %5 = arith.addf %4, %cst_0 : f32
// CHECK-NEXT: %6 = arith.mulf %arg4, %extracted : f32
// CHECK-NEXT: %7 = arith.addf %arg6, %6 : f32
// CHECK-NEXT: %8 = arith.subi %arg5, %c1 : index
// CHECK-NEXT: scf.yield %5, %8, %7 : f32, index, f32
// CHECK-NEXT: }
// CHECK-NEXT: return %3#2 : f32
// CHECK-NEXT: }