mlir: enable mincut support for scf.for (#2359)

* mlir: fix mincut for unranked types

* tests

* mlir: enable mincut support for scf.for (usage sketch below)

* better test
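
The min-cut recomputation rewrite is opt-in: it only runs on loops whose scf.for op carries the enzyme.enable_mincut unit attribute (see the check added in SCFAutoDiffOpInterfaceImpl.cpp below). A minimal sketch of how a loop would be annotated; the function and value names here are illustrative, adapted from the new scf_for_mincut.mlir test:

    func.func @square_cos_loop(%arg0: f32) -> f32 {
      %lb = arith.constant 0 : index
      %ub = arith.constant 10 : index
      %step = arith.constant 1 : index
      // The unit attribute on scf.for opts this loop into mincut caching.
      %res = scf.for %iv = %lb to %ub step %step
          iter_args(%it = %arg0) -> (f32) {
        %sq = arith.mulf %it, %it : f32
        %c = math.cos %sq : f32
        scf.yield %c : f32
      } {enzyme.enable_mincut}
      return %res : f32
    }

With the attribute present, the new test checks that only the loop-carried iterate is cached in the forward loop and that the squared value is recomputed in the reverse loop; loops without the attribute keep the existing caching behavior.
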
diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp
index 0f3c3d9..60d3192 100644
--- a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp
+++ b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp
@@ -97,16 +97,20 @@
       }
     }
 
-    SmallVector<CacheInfo> caches;
-    caches.reserve(cachesMap.size());
-    for (auto &&[_, info] : cachesMap) {
-      caches.push_back(info);
-    }
+    SmallVector<CacheInfo> caches =
+        llvm::map_to_vector(cachesMap, [](auto p) { return std::get<1>(p); });
 
     // nothing to do
     if (updatedGradients.empty() && caches.empty())
       return success();
 
+    if (forOp->hasAttr("enzyme.enable_mincut")) {
+      mlir::enzyme::minCutCache(forOp.getBody(), otherForOp.getBody(), caches,
+                                rewriter);
+      if (caches.empty())
+        return success();
+    }
+
     for (auto &it : *body) {
       Operation *op = &it;
 
diff --git a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp
index 23becff..6525890 100644
--- a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp
+++ b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp
@@ -385,9 +385,14 @@
       return T ? T.getApproxSize() : INT64_MAX;
     };
 
+    auto computeRankOfType = [](Value val) -> int64_t {
+      auto TT = dyn_cast<RankedTensorType>(val.getType());
+      return TT ? TT.getRank() : 0;
+    };
+
     Value picked = newCache;
     int64_t curSize = computeSizeOfType(picked),
-            curRank = cast<RankedTensorType>(picked.getType()).getRank();
+            curRank = computeRankOfType(picked);
 
     while (!worklist.empty()) {
       Value candidate = worklist.pop_back_val();
@@ -403,7 +408,7 @@
         continue; // TODO: support this
 
       int64_t newSize = computeSizeOfType(candidate),
-              newRank = cast<RankedTensorType>(candidate.getType()).getRank();
+              newRank = computeRankOfType(candidate);
       if (newSize < curSize || (newSize == curSize && newRank < curRank) ||
           candidate.getDefiningOp<enzyme::PopOp>() != nullptr) {
         curSize = newSize;
diff --git a/enzyme/test/MLIR/ReverseMode/pow.mlir b/enzyme/test/MLIR/ReverseMode/pow.mlir
index b8a07fb..a07497b 100644
--- a/enzyme/test/MLIR/ReverseMode/pow.mlir
+++ b/enzyme/test/MLIR/ReverseMode/pow.mlir
@@ -32,12 +32,12 @@
 // CHECK-NEXT:      %[[fwd:.+]] = arith.mulf %[[r_it]], %[[x]] : f64
 // CHECK-NEXT:      scf.yield %[[fwd]], %[[cache_new]] : f64, tensor<10xf64>
 // CHECK-NEXT:    }
-// CHECK-NEXT:    %{{.+}} = scf.for %[[div:.+]] = %c0 to %c10 step %c1 iter_args(%[[dr_it:.+]] = %[[dr]], %[[rev_idx:.+]] = %c9, %[[dx0:.+]] = %[[zero]]) -> (f64, index, f64) {
+// CHECK-NEXT:    %{{.+}} = scf.for %[[div:.+]] = %c0 to %c10 step %c1 iter_args(%[[dr_it:.+]] = %[[dr]], %[[dx0:.+]] = %[[zero]], %[[rev_idx:.+]] = %c9) -> (f64, f64, index) {
 // CHECK-NEXT:      %[[r_cached:.+]] = tensor.extract %1#1[%[[rev_idx]]] : tensor<10xf64>
 // CHECK-NEXT:      %[[dr_next:.+]] = arith.mulf %[[dr_it]], %[[x]] : f64
 // CHECK-NEXT:      %[[dx_next:.+]] = arith.mulf %[[dr_it]], %[[r_cached]] : f64
 // CHECK-NEXT:      %[[dx1:.+]] = arith.addf %[[dx0]], %[[dx_next]]
 // CHECK-NEXT:      %[[new_rev_idx:.+]] = arith.subi %[[rev_idx]], %c1 : index
-// CHECK-NEXT:      scf.yield %[[dr_next]], %[[new_rev_idx]], %[[dx1]] : f64, index, f64
+// CHECK-NEXT:      scf.yield %[[dr_next]], %[[dx1]], %[[new_rev_idx]] : f64, f64, index
 // CHECK-NEXT:    }
-// CHECK-NEXT:    return %2#2 : f64
+// CHECK-NEXT:    return %2#1 : f64
diff --git a/enzyme/test/MLIR/ReverseMode/scf_for.mlir b/enzyme/test/MLIR/ReverseMode/scf_for.mlir
index 22f7a93..5576c10 100644
--- a/enzyme/test/MLIR/ReverseMode/scf_for.mlir
+++ b/enzyme/test/MLIR/ReverseMode/scf_for.mlir
@@ -29,14 +29,14 @@
 // CHECK-NEXT:      scf.yield %4, %inserted : f32, tensor<?xf32>
 // CHECK-NEXT:    }
 // CHECK-NEXT:    %2 = arith.addf %arg2, %cst_0 : f32
-// CHECK-NEXT:    %3:3 = scf.for %arg3 = %c0 to %arg1 step %c1 iter_args(%arg4 = %2, %arg5 = %arg1, %arg6 = %cst_0) -> (f32, index, f32) {
-// CHECK-NEXT:      %extracted = tensor.extract %1#1[%arg5] : tensor<?xf32>
+// CHECK-NEXT:    %3:4 = scf.for %arg3 = %c0 to %arg1 step %c1 iter_args(%arg4 = %2, %arg5 = %cst_0, %arg6 = %arg1, %arg7 = %cst_0) -> (f32, f32, index, f32) {
+// CHECK-NEXT:      %extracted = tensor.extract %1#1[%arg6] : tensor<?xf32>
 // CHECK-NEXT:      %4 = arith.mulf %arg4, %arg0 : f32
 // CHECK-NEXT:      %5 = arith.addf %4, %cst_0 : f32
 // CHECK-NEXT:      %6 = arith.mulf %arg4, %extracted : f32
-// CHECK-NEXT:      %7 = arith.addf %arg6, %6 : f32
-// CHECK-NEXT:      %8 = arith.subi %arg5, %c1 : index
-// CHECK-NEXT:      scf.yield %5, %8, %7 : f32, index, f32
+// CHECK-NEXT:      %7 = arith.addf %arg5, %6 : f32
+// CHECK-NEXT:      %8 = arith.subi %arg6, %c1 : index
+// CHECK-NEXT:      scf.yield %5, %7, %8, %7 : f32, f32, index, f32
 // CHECK-NEXT:    }
-// CHECK-NEXT:    return %3#2 : f32
+// CHECK-NEXT:    return %3#1 : f32
 // CHECK-NEXT:  }
diff --git a/enzyme/test/MLIR/ReverseMode/scf_for_mincut.mlir b/enzyme/test/MLIR/ReverseMode/scf_for_mincut.mlir
new file mode 100644
index 0000000..37f2369
--- /dev/null
+++ b/enzyme/test/MLIR/ReverseMode/scf_for_mincut.mlir
@@ -0,0 +1,45 @@
+// RUN: %eopt %s --enzyme-wrap="infn=main outfn= argTys=enzyme_active retTys=enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops --enzyme-simplify-math --canonicalize | FileCheck %s
+
+module {
+  func.func @main(%arg0: f32) -> (f32) {
+    %lb = arith.constant 0 : index
+    %ub = arith.constant 10 : index
+    %step = arith.constant 1 : index
+
+    %sum = scf.for %iv = %lb to %ub step %step
+        iter_args(%sum_iter = %arg0) -> (f32) {
+      %sum_next = arith.mulf %sum_iter, %sum_iter : f32
+      %cos_next = math.cos %sum_next : f32
+      scf.yield %cos_next : f32
+    } {enzyme.enable_mincut}
+
+    return %sum : f32
+  }
+}
+
+// CHECK:  func.func @main(%arg0: f32, %arg1: f32) -> f32 {
+// CHECK-NEXT:    %c9 = arith.constant 9 : index
+// CHECK-NEXT:    %c1 = arith.constant 1 : index
+// CHECK-NEXT:    %c10 = arith.constant 10 : index
+// CHECK-NEXT:    %c0 = arith.constant 0 : index
+// CHECK-NEXT:    %[[v0:.+]] = tensor.empty() : tensor<10xf32>
+// CHECK-NEXT:    %[[for:.+]]:2 = scf.for %arg2 = %c0 to %c10 step %c1 iter_args(%arg3 = %arg0, %arg4 = %[[v0]]) -> (f32, tensor<10xf32>) {
+// CHECK-NEXT:      %[[cache:.+]] = tensor.insert %arg3 into %arg4[%arg2] : tensor<10xf32>
+// CHECK-NEXT:      %[[v3:.+]] = arith.mulf %arg3, %arg3 : f32
+// CHECK-NEXT:      %[[v4:.+]] = math.cos %[[v3]] : f32
+// CHECK-NEXT:      scf.yield %[[v4]], %[[cache]] : f32, tensor<10xf32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[revFor:.+]]:2 = scf.for %arg2 = %c0 to %c10 step %c1 iter_args(%arg3 = %arg1, %arg4 = %c9) -> (f32, index) {
+// CHECK-NEXT:      %[[cache:.+]] = tensor.extract %[[for]]#1[%arg4] : tensor<10xf32>
+// CHECK-NEXT:      %[[v3:.+]] = arith.mulf %[[cache]], %[[cache]] : f32
+// CHECK-NEXT:      %[[v4:.+]] = math.sin %[[v3]] : f32
+// CHECK-NEXT:      %[[v5:.+]] = arith.negf %[[v4]] : f32
+// CHECK-NEXT:      %[[v6:.+]] = arith.mulf %arg3, %[[v5]] : f32
+// CHECK-NEXT:      %[[v7:.+]] = arith.mulf %[[v6]], %[[cache]] : f32
+// CHECK-NEXT:      %[[v8:.+]] = arith.mulf %[[v6]], %[[cache]] : f32
+// CHECK-NEXT:      %[[v9:.+]] = arith.addf %[[v7]], %[[v8]] : f32
+// CHECK-NEXT:      %[[v10:.+]] = arith.subi %arg4, %c1 : index
+// CHECK-NEXT:      scf.yield %[[v9]], %[[v10]] : f32, index
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return %[[revFor]]#0 : f32
+// CHECK-NEXT:  }