mlir: enable mincut support for scf.for (#2359)
* mlir: fix mincut for unranked types
* tests
* mlir: enable mincut support for scf.for
* better test
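Summary: the scf.for reverse-mode rewrite now runs the mincut cache reduction (mlir::enzyme::minCutCache) only when the loop carries the `enzyme.enable_mincut` unit attribute, and the mincut utility no longer assumes every candidate value has a RankedTensorType (unranked or non-tensor values now get rank 0, matching the existing INT64_MAX size fallback). A minimal sketch of an opted-in loop, mirroring the new scf_for_mincut.mlir test below (the enclosing func.func and the %lb/%ub/%step index constants are elided here):

    %sum = scf.for %iv = %lb to %ub step %step
        iter_args(%sum_iter = %arg0) -> (f32) {
      %sum_next = arith.mulf %sum_iter, %sum_iter : f32
      %cos_next = math.cos %sum_next : f32
      scf.yield %cos_next : f32
    } {enzyme.enable_mincut}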
diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp
index 0f3c3d9..60d3192 100644
--- a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp
+++ b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp
@@ -97,16 +97,20 @@
}
}
- SmallVector<CacheInfo> caches;
- caches.reserve(cachesMap.size());
- for (auto &&[_, info] : cachesMap) {
- caches.push_back(info);
- }
+ SmallVector<CacheInfo> caches =
+ llvm::map_to_vector(cachesMap, [](auto p) { return std::get<1>(p); });
// nothing to do
if (updatedGradients.empty() && caches.empty())
return success();
+ if (forOp->hasAttr("enzyme.enable_mincut")) {
+ mlir::enzyme::minCutCache(forOp.getBody(), otherForOp.getBody(), caches,
+ rewriter);
+ if (caches.empty())
+ return success();
+ }
+
for (auto &it : *body) {
Operation *op = &it;
diff --git a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp
index 23becff..6525890 100644
--- a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp
+++ b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp
@@ -385,9 +385,14 @@
return T ? T.getApproxSize() : INT64_MAX;
};
+ auto computeRankOfType = [](Value val) -> int64_t {
+ auto TT = dyn_cast<RankedTensorType>(val.getType());
+ return TT ? TT.getRank() : 0;
+ };
+
Value picked = newCache;
int64_t curSize = computeSizeOfType(picked),
- curRank = cast<RankedTensorType>(picked.getType()).getRank();
+ curRank = computeRankOfType(picked);
while (!worklist.empty()) {
Value candidate = worklist.pop_back_val();
@@ -403,7 +408,7 @@
continue; // TODO: support this
int64_t newSize = computeSizeOfType(candidate),
- newRank = cast<RankedTensorType>(candidate.getType()).getRank();
+ newRank = computeRankOfType(candidate);
if (newSize < curSize || (newSize == curSize && newRank < curRank) ||
candidate.getDefiningOp<enzyme::PopOp>() != nullptr) {
curSize = newSize;
diff --git a/enzyme/test/MLIR/ReverseMode/pow.mlir b/enzyme/test/MLIR/ReverseMode/pow.mlir
index b8a07fb..a07497b 100644
--- a/enzyme/test/MLIR/ReverseMode/pow.mlir
+++ b/enzyme/test/MLIR/ReverseMode/pow.mlir
@@ -32,12 +32,12 @@
// CHECK-NEXT: %[[fwd:.+]] = arith.mulf %[[r_it]], %[[x]] : f64
// CHECK-NEXT: scf.yield %[[fwd]], %[[cache_new]] : f64, tensor<10xf64>
// CHECK-NEXT: }
-// CHECK-NEXT: %{{.+}} = scf.for %[[div:.+]] = %c0 to %c10 step %c1 iter_args(%[[dr_it:.+]] = %[[dr]], %[[rev_idx:.+]] = %c9, %[[dx0:.+]] = %[[zero]]) -> (f64, index, f64) {
+// CHECK-NEXT: %{{.+}} = scf.for %[[div:.+]] = %c0 to %c10 step %c1 iter_args(%[[dr_it:.+]] = %[[dr]], %[[dx0:.+]] = %[[zero]], %[[rev_idx:.+]] = %c9) -> (f64, f64, index) {
// CHECK-NEXT: %[[r_cached:.+]] = tensor.extract %1#1[%[[rev_idx]]] : tensor<10xf64>
// CHECK-NEXT: %[[dr_next:.+]] = arith.mulf %[[dr_it]], %[[x]] : f64
// CHECK-NEXT: %[[dx_next:.+]] = arith.mulf %[[dr_it]], %[[r_cached]] : f64
// CHECK-NEXT: %[[dx1:.+]] = arith.addf %[[dx0]], %[[dx_next]]
// CHECK-NEXT: %[[new_rev_idx:.+]] = arith.subi %[[rev_idx]], %c1 : index
-// CHECK-NEXT: scf.yield %[[dr_next]], %[[new_rev_idx]], %[[dx1]] : f64, index, f64
+// CHECK-NEXT: scf.yield %[[dr_next]], %[[dx1]], %[[new_rev_idx]] : f64, f64, index
// CHECK-NEXT: }
-// CHECK-NEXT: return %2#2 : f64
+// CHECK-NEXT: return %2#1 : f64
diff --git a/enzyme/test/MLIR/ReverseMode/scf_for.mlir b/enzyme/test/MLIR/ReverseMode/scf_for.mlir
index 22f7a93..5576c10 100644
--- a/enzyme/test/MLIR/ReverseMode/scf_for.mlir
+++ b/enzyme/test/MLIR/ReverseMode/scf_for.mlir
@@ -29,14 +29,14 @@
// CHECK-NEXT: scf.yield %4, %inserted : f32, tensor<?xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %2 = arith.addf %arg2, %cst_0 : f32
-// CHECK-NEXT: %3:3 = scf.for %arg3 = %c0 to %arg1 step %c1 iter_args(%arg4 = %2, %arg5 = %arg1, %arg6 = %cst_0) -> (f32, index, f32) {
-// CHECK-NEXT: %extracted = tensor.extract %1#1[%arg5] : tensor<?xf32>
+// CHECK-NEXT: %3:4 = scf.for %arg3 = %c0 to %arg1 step %c1 iter_args(%arg4 = %2, %arg5 = %cst_0, %arg6 = %arg1, %arg7 = %cst_0) -> (f32, f32, index, f32) {
+// CHECK-NEXT: %extracted = tensor.extract %1#1[%arg6] : tensor<?xf32>
// CHECK-NEXT: %4 = arith.mulf %arg4, %arg0 : f32
// CHECK-NEXT: %5 = arith.addf %4, %cst_0 : f32
// CHECK-NEXT: %6 = arith.mulf %arg4, %extracted : f32
-// CHECK-NEXT: %7 = arith.addf %arg6, %6 : f32
-// CHECK-NEXT: %8 = arith.subi %arg5, %c1 : index
-// CHECK-NEXT: scf.yield %5, %8, %7 : f32, index, f32
+// CHECK-NEXT: %7 = arith.addf %arg5, %6 : f32
+// CHECK-NEXT: %8 = arith.subi %arg6, %c1 : index
+// CHECK-NEXT: scf.yield %5, %7, %8, %7 : f32, f32, index, f32
// CHECK-NEXT: }
-// CHECK-NEXT: return %3#2 : f32
+// CHECK-NEXT: return %3#1 : f32
// CHECK-NEXT: }
diff --git a/enzyme/test/MLIR/ReverseMode/scf_for_mincut.mlir b/enzyme/test/MLIR/ReverseMode/scf_for_mincut.mlir
new file mode 100644
index 0000000..37f2369
--- /dev/null
+++ b/enzyme/test/MLIR/ReverseMode/scf_for_mincut.mlir
@@ -0,0 +1,45 @@
+// RUN: %eopt %s --enzyme-wrap="infn=main outfn= argTys=enzyme_active retTys=enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops --enzyme-simplify-math --canonicalize | FileCheck %s
+
+module {
+ func.func @main(%arg0: f32) -> (f32) {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 10 : index
+ %step = arith.constant 1 : index
+
+ %sum = scf.for %iv = %lb to %ub step %step
+ iter_args(%sum_iter = %arg0) -> (f32) {
+ %sum_next = arith.mulf %sum_iter, %sum_iter : f32
+ %cos_next = math.cos %sum_next : f32
+ scf.yield %cos_next : f32
+ } {enzyme.enable_mincut}
+
+ return %sum : f32
+ }
+}
+
+// CHECK: func.func @main(%arg0: f32, %arg1: f32) -> f32 {
+// CHECK-NEXT: %c9 = arith.constant 9 : index
+// CHECK-NEXT: %c1 = arith.constant 1 : index
+// CHECK-NEXT: %c10 = arith.constant 10 : index
+// CHECK-NEXT: %c0 = arith.constant 0 : index
+// CHECK-NEXT: %[[v0:.+]] = tensor.empty() : tensor<10xf32>
+// CHECK-NEXT: %[[for:.+]]:2 = scf.for %arg2 = %c0 to %c10 step %c1 iter_args(%arg3 = %arg0, %arg4 = %[[v0]]) -> (f32, tensor<10xf32>) {
+// CHECK-NEXT: %[[cache:.+]] = tensor.insert %arg3 into %arg4[%arg2] : tensor<10xf32>
+// CHECK-NEXT: %[[v3:.+]] = arith.mulf %arg3, %arg3 : f32
+// CHECK-NEXT: %[[v4:.+]] = math.cos %[[v3]] : f32
+// CHECK-NEXT: scf.yield %[[v4]], %[[cache]] : f32, tensor<10xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[revFor:.+]]:2 = scf.for %arg2 = %c0 to %c10 step %c1 iter_args(%arg3 = %arg1, %arg4 = %c9) -> (f32, index) {
+// CHECK-NEXT: %[[cache:.+]] = tensor.extract %[[for]]#1[%arg4] : tensor<10xf32>
+// CHECK-NEXT: %[[v3:.+]] = arith.mulf %[[cache]], %[[cache]] : f32
+// CHECK-NEXT: %[[v4:.+]] = math.sin %[[v3]] : f32
+// CHECK-NEXT: %[[v5:.+]] = arith.negf %[[v4]] : f32
+// CHECK-NEXT: %[[v6:.+]] = arith.mulf %arg3, %[[v5]] : f32
+// CHECK-NEXT: %[[v7:.+]] = arith.mulf %[[v6]], %[[cache]] : f32
+// CHECK-NEXT: %[[v8:.+]] = arith.mulf %[[v6]], %[[cache]] : f32
+// CHECK-NEXT: %[[v9:.+]] = arith.addf %[[v7]], %[[v8]] : f32
+// CHECK-NEXT: %[[v10:.+]] = arith.subi %arg4, %c1 : index
+// CHECK-NEXT: scf.yield %[[v9]], %[[v10]] : f32, index
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[revFor]]#0 : f32
+// CHECK-NEXT: }
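As the CHECK lines above show, with mincut enabled only the loop-carried scalar %arg3 is cached into the tensor<10xf32> during the forward sweep; the reverse loop recomputes the square from the cached value and forms the cosine's adjoint via math.sin and arith.negf rather than caching additional intermediates.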