[MLIR] autodiff_region op (#2426)

* make it AutomaticAllocationScope, remove AffineScope
* support outlining multiple autodiff_regions from the same function
diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
index b169c0a..80772f9 100644
--- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
+++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
@@ -145,6 +145,26 @@
   let hasCanonicalizer = 1;
 }
 
+def AutoDiffRegionOp : Enzyme_Op<"autodiff_region", [AutomaticAllocationScope]> {
+  let summary = "Perform reverse mode AD on a child region";
+  let arguments = (ins Variadic<AnyType>:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr<I64Attr, "1">:$width, DefaultValuedAttr<BoolAttr, "false">:$strong_zero, OptionalAttr<StrAttr>:$fn);
+  let regions = (region AnyRegion:$body);
+  let results = (outs Variadic<AnyType>:$outputs);
+
+  let assemblyFormat = [{
+    `(` $inputs `)` $body attr-dict-with-keyword `:` functional-type($inputs, results)
+  }];
+}
+
+def YieldOp : Enzyme_Op<"yield", [Pure, ReturnLike, Terminator,
+    HasParent<"AutoDiffRegionOp">]> {
+  let summary = "Yield values at the end of an autodiff_region op";
+  let arguments = (ins Variadic<AnyType>:$operands);
+  let assemblyFormat = [{
+    attr-dict ($operands^ `:` type($operands))?
+  }];
+}
+
 def BatchOp : Enzyme_Op<"batch",
     [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
   let summary = "Perform reverse mode AD on a funcop";
diff --git a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt
index d9b8564..720769b 100644
--- a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt
+++ b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt
@@ -8,6 +8,7 @@
     ProbProgMLIRPass.cpp
     EnzymeBatchPass.cpp
     EnzymeWrapPass.cpp
+    InlineEnzymeRegions.cpp
     PrintActivityAnalysis.cpp
     PrintAliasAnalysis.cpp
     EnzymeToMemRef.cpp
diff --git a/enzyme/Enzyme/MLIR/Passes/InlineEnzymeRegions.cpp b/enzyme/Enzyme/MLIR/Passes/InlineEnzymeRegions.cpp
new file mode 100644
index 0000000..84bdcec
--- /dev/null
+++ b/enzyme/Enzyme/MLIR/Passes/InlineEnzymeRegions.cpp
@@ -0,0 +1,268 @@
+//===- InlineEnzymeRegions.cpp - Inline/outline enzyme.autodiff -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements inlining and outlining passes that convert
+// between enzyme.autodiff and enzyme.autodiff_region ops.
+//
+//===----------------------------------------------------------------------===//
+#include "Dialect/Ops.h"
+#include "Passes/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace enzyme {
+#define GEN_PASS_DEF_INLINEENZYMEINTOREGIONPASS
+#define GEN_PASS_DEF_OUTLINEENZYMEFROMREGIONPASS
+#include "Passes/Passes.h.inc"
+} // namespace enzyme
+} // namespace mlir
+
+namespace {
+constexpr static llvm::StringLiteral kFnAttrsName = "fn_attrs";
+
+static StringRef getFunctionTypeAttrName(Operation *operation) {
+  return llvm::TypeSwitch<Operation *, StringRef>(operation)
+      .Case<func::FuncOp, LLVM::LLVMFuncOp>(
+          [](auto op) { return op.getFunctionTypeAttrName(); })
+      .Default([](Operation *) {
+        llvm_unreachable("expected op with a function type");
+        return "";
+      });
+}
+
+static StringRef getArgAttrsAttrName(Operation *operation) {
+  return llvm::TypeSwitch<Operation *, StringRef>(operation)
+      .Case<func::FuncOp, LLVM::LLVMFuncOp>(
+          [](auto op) { return op.getArgAttrsAttrName(); })
+      .Default([](Operation *) {
+        llvm_unreachable("expected op with arg attrs");
+        return "";
+      });
+}
+
+static void serializeFunctionAttributes(Operation *fn,
+                                        enzyme::AutoDiffRegionOp regionOp) {
+  SmallVector<NamedAttribute> fnAttrs;
+  fnAttrs.reserve(fn->getAttrDictionary().size());
+  for (auto attr : fn->getAttrs()) {
+    // Don't store the function type because it may change when outlining
+    if (attr.getName() == getFunctionTypeAttrName(fn))
+      continue;
+    fnAttrs.push_back(attr);
+  }
+
+  regionOp->setAttr(kFnAttrsName,
+                    DictionaryAttr::getWithSorted(fn->getContext(), fnAttrs));
+}
+
+static void deserializeFunctionAttributes(enzyme::AutoDiffRegionOp op,
+                                          Operation *outlinedFunc,
+                                          unsigned addedArgCount) {
+  if (!op->hasAttrOfType<DictionaryAttr>(kFnAttrsName))
+    return;
+
+  MLIRContext *ctx = op->getContext();
+  SmallVector<NamedAttribute> fnAttrs;
+  for (auto attr : op->getAttrOfType<DictionaryAttr>(kFnAttrsName)) {
+    // New arguments are potentially added when outlining due to references to
+    // values outside the region. Insert an empty arg attr for each newly
+    // added argument.
+    if (attr.getName() == getArgAttrsAttrName(outlinedFunc)) {
+      SmallVector<Attribute> argAttrs(
+          cast<ArrayAttr>(attr.getValue()).getAsRange<DictionaryAttr>());
+      for (unsigned i = 0; i < addedArgCount; ++i)
+        argAttrs.push_back(DictionaryAttr::getWithSorted(ctx, {}));
+      fnAttrs.push_back(
+          NamedAttribute(attr.getName(), ArrayAttr::get(ctx, argAttrs)));
+    } else
+      fnAttrs.push_back(attr);
+  }
+  outlinedFunc->setAttrs(fnAttrs);
+}
+
+// Rewrites enzyme.autodiff as enzyme.autodiff_region with the callee inlined.
+struct InlineEnzymeAutoDiff : public OpRewritePattern<enzyme::AutoDiffOp> {
+  using OpRewritePattern<enzyme::AutoDiffOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(enzyme::AutoDiffOp op,
+                                PatternRewriter &rewriter) const override {
+    SymbolTableCollection symbolTable;
+    auto fn = dyn_cast_or_null<FunctionOpInterface>(
+        symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr()));
+    // cast<> would assert on an unresolved or non-function symbol, and an
+    // external declaration has no body to clone: leave such ops untouched.
+    if (!fn || fn.isExternal())
+      return rewriter.notifyMatchFailure(op, "callee has no body to inline");
+    // Use a StringAttr rather than a SymbolRefAttr so the function can get
+    // symbol-DCE'd
+    auto fnAttr = StringAttr::get(op.getContext(), op.getFn());
+    auto regionOp = rewriter.replaceOpWithNewOp<enzyme::AutoDiffRegionOp>(
+        op, op.getResultTypes(), op.getInputs(), op.getActivity(),
+        op.getRetActivity(), op.getWidth(), op.getStrongZero(), fnAttr);
+    serializeFunctionAttributes(fn, regionOp);
+    rewriter.cloneRegionBefore(fn.getFunctionBody(), regionOp.getBody(),
+                               regionOp.getBody().begin());
+    // Early-inc iteration lets each top-level terminator be erased in place.
+    for (Operation &bodyOp :
+         llvm::make_early_inc_range(regionOp.getBody().getOps())) {
+      if (!bodyOp.hasTrait<OpTrait::ReturnLike>())
+        continue;
+      rewriter.setInsertionPoint(&bodyOp);
+      enzyme::YieldOp::create(rewriter, bodyOp.getLoc(), bodyOp.getOperands());
+      rewriter.eraseOp(&bodyOp);
+    }
+    return success();
+  }
+};
+
+// Based on
+// https://github.com/llvm/llvm-project/blob/665da0a1649814471739c41a702e0e9447316b20/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+static FailureOr<func::FuncOp>
+outlineAutoDiffFunc(enzyme::AutoDiffRegionOp op, StringRef funcName,
+                    SmallVectorImpl<Value> &inputs,
+                    SmallVectorImpl<enzyme::Activity> &argActivities,
+                    OpBuilder &builder) {
+  Region &autodiffRegion = op.getBody();
+  SmallVector<Type> argTypes(autodiffRegion.getArgumentTypes()), resultTypes;
+  SmallVector<Location> argLocs(autodiffRegion.getNumArguments(), op.getLoc());
+  // Infer the result types from an enzyme.yield op
+  bool found = false;
+  autodiffRegion.walk([&](enzyme::YieldOp yieldOp) {
+    found = true;
+    llvm::append_range(resultTypes, yieldOp.getOperandTypes());
+    return WalkResult::interrupt();
+  });
+  if (!found)
+    return op.emitError()
+           << "enzyme.yield was not found in enzyme.autodiff_region";
+
+  llvm::SetVector<Value> freeValues;
+  getUsedValuesDefinedAbove(autodiffRegion, freeValues);
+
+  for (Value value : freeValues) {
+    inputs.push_back(value);
+    argTypes.push_back(value.getType());
+    argLocs.push_back(value.getLoc());
+    argActivities.push_back(enzyme::Activity::enzyme_const);
+  }
+  auto fnType = builder.getFunctionType(argTypes, resultTypes);
+
+  // FIXME: making this location the location of the
+  // enzyme.autodiff_region op causes translation to LLVM IR to fail due
+  // to some issue with the dbg info.
+  Location loc = UnknownLoc::get(op.getContext());
+  auto outlinedFunc = func::FuncOp::create(builder, loc, funcName, fnType);
+  Region &outlinedBody = outlinedFunc.getBody();
+  deserializeFunctionAttributes(op, outlinedFunc, freeValues.size());
+
+  // Copy over the function body.
+  IRMapping map;
+  Block *entryBlock = builder.createBlock(&outlinedBody, outlinedBody.begin(),
+                                          argTypes, argLocs);
+  unsigned originalArgCount = autodiffRegion.getNumArguments();
+  for (const auto &arg : autodiffRegion.getArguments())
+    map.map(arg, entryBlock->getArgument(arg.getArgNumber()));
+  for (const auto &operand : enumerate(freeValues))
+    map.map(operand.value(),
+            entryBlock->getArgument(originalArgCount + operand.index()));
+  autodiffRegion.cloneInto(&outlinedBody, map);
+
+  // Replace the terminators with returns
+  for (Block &block : autodiffRegion) {
+    Block *clonedBlock = map.lookup(&block);
+    auto terminator = dyn_cast<enzyme::YieldOp>(clonedBlock->getTerminator());
+    if (!terminator)
+      continue;
+    OpBuilder replacer(terminator);
+    func::ReturnOp::create(replacer, terminator->getLoc(),
+                           terminator->getOperands());
+    terminator->erase();
+  }
+
+  // cloneInto results in two blocks, the actual outlined entry block and the
+  // cloned autodiff_region entry block. Splice the cloned entry block into
+  // the actual entry block, then erase the cloned autodiff_region entry.
+  Block *clonedEntry = map.lookup(&autodiffRegion.front());
+  entryBlock->getOperations().splice(entryBlock->getOperations().end(),
+                                     clonedEntry->getOperations());
+  clonedEntry->erase();
+  return outlinedFunc;
+}
+
+LogicalResult outlineEnzymeAutoDiffRegion(enzyme::AutoDiffRegionOp op,
+                                          StringRef defaultFuncName,
+                                          OpBuilder &builder) {
+  StringRef funcName = op.getFn().value_or(defaultFuncName);
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  builder.setInsertionPointAfter(op->getParentOfType<SymbolOpInterface>());
+
+  SmallVector<enzyme::Activity> argActivities =
+      llvm::map_to_vector(op.getActivity().getAsRange<enzyme::ActivityAttr>(),
+                          [](auto attr) { return attr.getValue(); });
+  SmallVector<Value> inputs(op.getInputs());
+
+  FailureOr<func::FuncOp> outlinedFunc =
+      outlineAutoDiffFunc(op, funcName, inputs, argActivities, builder);
+  if (failed(outlinedFunc))
+    return failure();
+  builder.setInsertionPoint(op);
+  ArrayAttr argActivityAttr = builder.getArrayAttr(llvm::map_to_vector(
+      argActivities, [&op](enzyme::Activity actv) -> Attribute {
+        return enzyme::ActivityAttr::get(op.getContext(), actv);
+      }));
+  auto newOp = enzyme::AutoDiffOp::create(
+      builder, op.getLoc(), op.getResultTypes(), outlinedFunc->getName(),
+      inputs, argActivityAttr, op.getRetActivity(), op.getWidth(),
+      op.getStrongZero());
+  op.replaceAllUsesWith(newOp.getResults());
+  op.erase();
+  return success();
+}
+
+struct InlineEnzymeIntoRegion
+    : public enzyme::impl::InlineEnzymeIntoRegionPassBase<
+          InlineEnzymeIntoRegion> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    patterns.insert<InlineEnzymeAutoDiff>(&getContext());
+
+    GreedyRewriteConfig config;
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
+  }
+};
+
+struct OutlineEnzymeFromRegion
+    : public enzyme::impl::OutlineEnzymeFromRegionPassBase<
+          OutlineEnzymeFromRegion> {
+  void runOnOperation() override {
+    SmallVector<enzyme::AutoDiffRegionOp> toOutline;
+    getOperation()->walk(
+        [&](enzyme::AutoDiffRegionOp op) { toOutline.push_back(op); });
+
+    OpBuilder builder(getOperation());
+    unsigned increment = 0;
+    for (auto regionOp : toOutline) {
+      auto symbol = regionOp->getParentOfType<SymbolOpInterface>();
+      std::string defaultName =
+          (Twine(symbol.getName(), "_to_diff") + Twine(increment)).str();
+      if (failed(outlineEnzymeAutoDiffRegion(regionOp, defaultName, builder)))
+        return signalPassFailure();
+
+      ++increment;
+    }
+  }
+};
+
+} // namespace
diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td
index 3976c9f..33d05fb 100644
--- a/enzyme/Enzyme/MLIR/Passes/Passes.td
+++ b/enzyme/Enzyme/MLIR/Passes/Passes.td
@@ -185,6 +185,17 @@
   let summary = "Print the results of alias analysis";
 }
 
+def InlineEnzymeIntoRegionPass : Pass<"inline-enzyme-regions"> {
+  let summary = "Inline enzyme.autodiff ops into enzyme.autodiff_region ops.";
+}
+
+def OutlineEnzymeFromRegionPass : Pass<"outline-enzyme-regions"> {
+  let summary = "Outlines enzyme.autodiff_region ops into enzyme.autodiff ops.";
+  let dependentDialects = [
+    "func::FuncDialect"
+  ];
+}
+
 def EnzymeOpsToMemRefPass : Pass<"convert-enzyme-to-memref"> {
   let summary = "Lower custom Enzyme ops to the MemRef dialect";
   let dependentDialects = [
diff --git a/enzyme/test/MLIR/Passes/region_inline.mlir b/enzyme/test/MLIR/Passes/region_inline.mlir
new file mode 100644
index 0000000..ce2421c
--- /dev/null
+++ b/enzyme/test/MLIR/Passes/region_inline.mlir
@@ -0,0 +1,54 @@
+// RUN: %eopt --inline-enzyme-regions --split-input-file %s | FileCheck %s
+
+func.func @square(%x: f64) -> f64 {
+  %next = arith.mulf %x, %x : f64
+  return %next : f64
+}
+
+func.func @dsquare(%x: f64, %dr: f64) -> f64 {
+  %r = enzyme.autodiff @square(%x, %dr)
+    {
+      activity=[#enzyme<activity enzyme_active>],
+      ret_activity=[#enzyme<activity enzyme_activenoneed>]
+    } : (f64, f64) -> f64
+  return %r : f64
+}
+
+// CHECK:  func.func @dsquare(%arg0: f64, %arg1: f64) -> f64 {
+// CHECK-NEXT:    %0 = enzyme.autodiff_region(%arg0, %arg1) {
+// CHECK-NEXT:    ^bb0(%arg2: f64):
+// CHECK-NEXT:      %1 = arith.mulf %arg2, %arg2 : f64
+// CHECK-NEXT:      enzyme.yield %1 : f64
+// CHECK-NEXT:    } attributes {activity = [#enzyme<activity enzyme_active>], fn = "square", fn_attrs = {sym_name = "square"}, ret_activity = [#enzyme<activity enzyme_activenoneed>]} : (f64, f64) -> f64
+// CHECK-NEXT:    return %0 : f64
+// CHECK-NEXT:  }
+
+// -----
+
+llvm.func internal @_Z6squarePfS_(%arg0: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg1: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.writeonly}) attributes {dso_local, frame_pointer = #llvm.framePointerKind<all>, memory_effects = #llvm.memory_effects<other = none, argMem = readwrite, inaccessibleMem = none>, no_unwind, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "sm_86"]], target_cpu = "sm_86", target_features = #llvm.target_features<["+ptx88", "+sm_86"]>, will_return, sym_visibility = "private"} {
+  %0 = llvm.mlir.constant(5.600000e+00 : f64) : f64
+  %1 = nvvm.read.ptx.sreg.tid.x : i32
+  %2 = llvm.zext nneg %1 : i32 to i64
+  %3 = llvm.getelementptr inbounds|nuw %arg0[%2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+  %4 = llvm.load %3 {alignment = 4 : i64} : !llvm.ptr -> f32
+  %5 = llvm.fpext %4 : f32 to f64
+  %6 = llvm.fmul %5, %0 {fastmathFlags = #llvm.fastmath<contract>} : f64
+  %7 = llvm.fptrunc %6 : f64 to f32
+  %8 = llvm.getelementptr inbounds|nuw %arg1[%2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+  llvm.store %7, %8 {alignment = 4 : i64} : f32, !llvm.ptr
+  llvm.return
+}
+
+llvm.func internal @d_Z6squarePfS_(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr, %arg3: !llvm.ptr) {
+  enzyme.autodiff @_Z6squarePfS_(%arg0, %arg1, %arg2, %arg3)
+    {
+      activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_dup>],
+      ret_activity=[]
+    } : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
+  llvm.return
+}
+
+// Make sure that function attributes are preserved
+// CHECK: llvm.func internal @d_Z6squarePfS_
+// CHECK:    enzyme.autodiff_region(%arg0, %arg1, %arg2, %arg3) {
+// CHECK:    } attributes {activity = [#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_dup>], fn = "_Z6squarePfS_", fn_attrs = {CConv = #llvm.cconv<ccc>, arg_attrs = [{llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.writeonly}], dso_local, frame_pointer = #llvm.framePointerKind<all>, linkage = #llvm.linkage<internal>, memory_effects = #llvm.memory_effects<other = none, argMem = readwrite, inaccessibleMem = none>, no_unwind, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "sm_86"]], sym_name = "_Z6squarePfS_", sym_visibility = "private", target_cpu = "sm_86", target_features = #llvm.target_features<["+ptx88", "+sm_86"]>, unnamed_addr = 0 : i64, visibility_ = 0 : i64, will_return}, ret_activity = []} : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> (
diff --git a/enzyme/test/MLIR/Passes/region_outline.mlir b/enzyme/test/MLIR/Passes/region_outline.mlir
new file mode 100644
index 0000000..5ce1878
--- /dev/null
+++ b/enzyme/test/MLIR/Passes/region_outline.mlir
@@ -0,0 +1,100 @@
+// RUN: %eopt --outline-enzyme-regions --split-input-file %s | FileCheck %s
+
+func.func @to_outline(%26: !llvm.ptr, %27: !llvm.ptr, %28: !llvm.ptr, %29: !llvm.ptr) {
+  %cst = arith.constant 5.6 : f64
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  scf.parallel (%arg2, %arg3, %arg4) = (%c0, %c0, %c0) to (%c1, %c1, %c1) step (%c1, %c1, %c1) {
+    scf.parallel (%arg5, %arg6, %arg7) = (%c0, %c0, %c0) to (%c4, %c1, %c1) step (%c1, %c1, %c1) {
+      memref.alloca_scope  {
+        scf.execute_region {
+          enzyme.autodiff_region(%26, %27, %28, %29) {
+          ^bb0(%arg8: !llvm.ptr, %arg9: !llvm.ptr):
+            %63 = arith.index_castui %arg5 : index to i64
+            %64 = llvm.getelementptr inbounds|nuw %arg8[%63] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+            %65 = llvm.load %64 invariant {alignment = 4 : i64} : !llvm.ptr -> f32
+            %66 = arith.extf %65 : f32 to f64
+            %67 = arith.mulf %66, %cst {fastmathFlags = #llvm.fastmath<contract>} : f64
+            %68 = arith.truncf %67 : f64 to f32
+            %69 = llvm.getelementptr inbounds|nuw %arg9[%63] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+            llvm.store %68, %69 {alignment = 4 : i64} : f32, !llvm.ptr
+            enzyme.yield
+          } attributes {activity = [#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_dup>], fn = "outlined_func", ret_activity = []} : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
+          scf.yield
+        }
+      }
+      scf.reduce 
+    }
+    scf.reduce 
+  }
+  return
+}
+
+// CHECK: func.func @outlined_func(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: index, %arg3: f64) {
+// CHECK-NEXT:    %0 = arith.index_castui %arg2 : index to i64
+// CHECK-NEXT:    %1 = llvm.getelementptr inbounds|nuw %arg0[%0] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK-NEXT:    %2 = llvm.load %1 invariant {alignment = 4 : i64} : !llvm.ptr -> f32
+// CHECK-NEXT:    %3 = arith.extf %2 : f32 to f64
+// CHECK-NEXT:    %4 = arith.mulf %3, %arg3 {fastmathFlags = #llvm.fastmath<contract>} : f64
+// CHECK-NEXT:    %5 = arith.truncf %4 : f64 to f32
+// CHECK-NEXT:    %6 = llvm.getelementptr inbounds|nuw %arg1[%0] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK-NEXT:    llvm.store %5, %6 {alignment = 4 : i64} : f32, !llvm.ptr
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// -----
+
+llvm.func internal @d_Z6squarePfS_(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr, %arg3: !llvm.ptr) {
+  %0 = llvm.mlir.constant(5.600000e+00 : f64) : f64
+  enzyme.autodiff_region(%arg0, %arg1, %arg2, %arg3) {
+  ^bb0(%arg4: !llvm.ptr, %arg5: !llvm.ptr):
+    %1 = nvvm.read.ptx.sreg.tid.x : i32
+    %2 = llvm.zext nneg %1 : i32 to i64
+    %3 = llvm.getelementptr inbounds|nuw %arg4[%2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    %4 = llvm.load %3 {alignment = 4 : i64} : !llvm.ptr -> f32
+    %5 = llvm.fpext %4 : f32 to f64
+    %6 = llvm.fmul %5, %0 {fastmathFlags = #llvm.fastmath<contract>} : f64
+    %7 = llvm.fptrunc %6 : f64 to f32
+    %8 = llvm.getelementptr inbounds|nuw %arg5[%2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    llvm.store %7, %8 {alignment = 4 : i64} : f32, !llvm.ptr
+    enzyme.yield
+  } attributes {activity = [#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_dup>], fn = "_Z6squarePfS_", fn_attrs = {CConv = #llvm.cconv<ccc>, arg_attrs = [{llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.writeonly}], dso_local, frame_pointer = #llvm.framePointerKind<all>, linkage = #llvm.linkage<internal>, memory_effects = #llvm.memory_effects<other = none, argMem = readwrite, inaccessibleMem = none>, no_unwind, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "sm_86"]], sym_name = "_Z6squarePfS_", sym_visibility = "private", target_cpu = "sm_86", target_features = #llvm.target_features<["+ptx88", "+sm_86"]>, unnamed_addr = 0 : i64, visibility_ = 0 : i64, will_return}, ret_activity = []} : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
+  llvm.return
+}
+
+// Attributes should be passed back to the outlined function
+// CHECK: func.func private @_Z6squarePfS_(%arg0: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg1: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.writeonly}, %arg2: f64) attributes {CConv = #llvm.cconv<ccc>, dso_local, frame_pointer = #llvm.framePointerKind<all>, linkage = #llvm.linkage<internal>, memory_effects = #llvm.memory_effects<other = none, argMem = readwrite, inaccessibleMem = none>, no_unwind, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "sm_86"]], target_cpu = "sm_86", target_features = #llvm.target_features<["+ptx88", "+sm_86"]>, unnamed_addr = 0 : i64, visibility_ = 0 : i64, will_return} {
+
+// -----
+
+func.func @outline_multi(%x: f64, %dr: f64) -> (f64, f64) {
+  %r0 = enzyme.autodiff_region(%x, %dr) {
+  ^bb0(%arg0: f64):
+    %sq = arith.mulf %arg0, %arg0 : f64
+    enzyme.yield %sq : f64
+  } attributes {activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_activenoneed>]} : (f64, f64) -> f64
+
+  %r1 = enzyme.autodiff_region(%x, %dr) {
+  ^bb0(%arg0: f64):
+    %add = arith.addf %arg0, %arg0 : f64
+    enzyme.yield %add : f64
+  } attributes {activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_activenoneed>]} : (f64, f64) -> f64
+  return %r0, %r1 : f64, f64
+}
+
+// CHECK: func.func @outline_multi(%arg0: f64, %arg1: f64) -> (f64, f64) {
+// CHECK-NEXT:    %0 = enzyme.autodiff @outline_multi_to_diff0(%arg0, %arg1) {activity = [#enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>]} : (f64, f64) -> f64
+// CHECK-NEXT:    %1 = enzyme.autodiff @outline_multi_to_diff1(%arg0, %arg1) {activity = [#enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>]} : (f64, f64) -> f64
+// CHECK-NEXT:    return %0, %1 : f64, f64
+// CHECK-NEXT:  }
+
+// CHECK: func.func @outline_multi_to_diff1(%arg0: f64) -> f64 {
+// CHECK-NEXT:    %0 = arith.addf %arg0, %arg0 : f64
+// CHECK-NEXT:    return %0 : f64
+// CHECK-NEXT:  }
+
+// CHECK: func.func @outline_multi_to_diff0(%arg0: f64) -> f64 {
+// CHECK-NEXT:    %0 = arith.mulf %arg0, %arg0 : f64
+// CHECK-NEXT:    return %0 : f64
+// CHECK-NEXT:  }