[MLIR] autodiff_region op (#2426)
* make it AutomaticAllocationScope, remove AffineScope
* support outlining multiple autodiff_regions from the same function
diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
index b169c0a..80772f9 100644
--- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
+++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
@@ -145,6 +145,26 @@
let hasCanonicalizer = 1;
}
+// Region-based counterpart of enzyme.autodiff: the code to differentiate is
+// held inline in `body` instead of being referenced through a symbol.
+// `fn` optionally records a function name; the outlining pass uses it as the
+// name of the function it creates from the region.
+def AutoDiffRegionOp : Enzyme_Op<"autodiff_region", [AutomaticAllocationScope]> {
+ let summary = "Perform reverse mode AD on a child region";
+ let arguments = (ins Variadic<AnyType>:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr<I64Attr, "1">:$width, DefaultValuedAttr<BoolAttr, "false">:$strong_zero, OptionalAttr<StrAttr>:$fn);
+ let regions = (region AnyRegion:$body);
+ let results = (outs Variadic<AnyType>:$outputs);
+
+ let assemblyFormat = [{
+ `(` $inputs `)` $body attr-dict-with-keyword `:` functional-type($inputs, results)
+ }];
+}
+
+// Terminator for the body of an enzyme.autodiff_region op. It is only valid
+// as a direct child of AutoDiffRegionOp and carries the values yielded by the
+// region.
+def YieldOp : Enzyme_Op<"yield", [Pure, ReturnLike, Terminator,
+ HasParent<"AutoDiffRegionOp">]> {
+ let summary = "Yield values at the end of an autodiff_region op";
+ let arguments = (ins Variadic<AnyType>:$operands);
+ let assemblyFormat = [{
+ attr-dict ($operands^ `:` type($operands))?
+ }];
+}
+
def BatchOp : Enzyme_Op<"batch",
[DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Perform reverse mode AD on a funcop";
diff --git a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt
index d9b8564..720769b 100644
--- a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt
+++ b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt
@@ -8,6 +8,7 @@
ProbProgMLIRPass.cpp
EnzymeBatchPass.cpp
EnzymeWrapPass.cpp
+ InlineEnzymeRegions.cpp
PrintActivityAnalysis.cpp
PrintAliasAnalysis.cpp
EnzymeToMemRef.cpp
diff --git a/enzyme/Enzyme/MLIR/Passes/InlineEnzymeRegions.cpp b/enzyme/Enzyme/MLIR/Passes/InlineEnzymeRegions.cpp
new file mode 100644
index 0000000..84bdcec
--- /dev/null
+++ b/enzyme/Enzyme/MLIR/Passes/InlineEnzymeRegions.cpp
@@ -0,0 +1,268 @@
+//===- InlineEnzymeRegions.cpp - Inline/outline enzyme.autodiff -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements inlining and outlining passes that convert
+// between enzyme.autodiff and enzyme.autodiff_region ops.
+//
+//===----------------------------------------------------------------------===//
+#include "Dialect/Ops.h"
+#include "Passes/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace enzyme {
+#define GEN_PASS_DEF_INLINEENZYMEINTOREGIONPASS
+#define GEN_PASS_DEF_OUTLINEENZYMEFROMREGIONPASS
+#include "Passes/Passes.h.inc"
+} // namespace enzyme
+} // namespace mlir
+
+namespace {
+constexpr static llvm::StringLiteral kFnAttrsName = "fn_attrs";
+
+/// Returns the name of the attribute holding the function type of
+/// `operation`. Only func.func and llvm.func are supported; any other op is
+/// treated as unreachable.
+static StringRef getFunctionTypeAttrName(Operation *operation) {
+  if (auto funcOp = dyn_cast<func::FuncOp>(operation))
+    return funcOp.getFunctionTypeAttrName();
+  if (auto llvmFuncOp = dyn_cast<LLVM::LLVMFuncOp>(operation))
+    return llvmFuncOp.getFunctionTypeAttrName();
+  llvm_unreachable("expected op with a function type");
+}
+
+/// Returns the name of the attribute holding the per-argument attributes of
+/// `operation`. Only func.func and llvm.func are supported; any other op is
+/// treated as unreachable.
+static StringRef getArgAttrsAttrName(Operation *operation) {
+  if (auto funcOp = dyn_cast<func::FuncOp>(operation))
+    return funcOp.getArgAttrsAttrName();
+  if (auto llvmFuncOp = dyn_cast<LLVM::LLVMFuncOp>(operation))
+    return llvmFuncOp.getArgAttrsAttrName();
+  llvm_unreachable("expected op with arg attrs");
+}
+
+/// Saves the attributes of `fn` on `regionOp` as a dictionary attribute named
+/// `fn_attrs`, so that they can be restored when the region is outlined again.
+///
+/// The function-type attribute is left out because outlining may append
+/// arguments and therefore change the type.
+static void serializeFunctionAttributes(Operation *fn,
+                                        enzyme::AutoDiffRegionOp regionOp) {
+  // Loop-invariant: resolve the attribute name once, not once per attribute.
+  StringRef functionTypeAttrName = getFunctionTypeAttrName(fn);
+
+  ArrayRef<NamedAttribute> allAttrs = fn->getAttrs();
+  SmallVector<NamedAttribute> fnAttrs;
+  fnAttrs.reserve(allAttrs.size());
+  for (NamedAttribute attr : allAttrs) {
+    if (attr.getName() == functionTypeAttrName)
+      continue;
+    fnAttrs.push_back(attr);
+  }
+
+  // Removing an entry from the op's attribute list preserves its ordering,
+  // which is what getWithSorted relies on.
+  regionOp->setAttr(kFnAttrsName,
+                    DictionaryAttr::getWithSorted(fn->getContext(), fnAttrs));
+}
+
+/// Restores the attributes saved in the `fn_attrs` dictionary of `op` onto
+/// `outlinedFunc`. Does nothing when `op` carries no such dictionary.
+///
+/// Outlining may append arguments for values referenced from outside the
+/// region; `addedArgCount` empty dictionaries are appended to the saved
+/// argument attributes so that every argument has an entry.
+static void deserializeFunctionAttributes(enzyme::AutoDiffRegionOp op,
+                                          Operation *outlinedFunc,
+                                          unsigned addedArgCount) {
+  auto savedAttrs = op->getAttrOfType<DictionaryAttr>(kFnAttrsName);
+  if (!savedAttrs)
+    return;
+
+  MLIRContext *ctx = op->getContext();
+  SmallVector<NamedAttribute> restored;
+  for (NamedAttribute saved : savedAttrs) {
+    if (saved.getName() == getArgAttrsAttrName(outlinedFunc)) {
+      // Pad the saved per-argument attributes with one empty dictionary per
+      // newly added argument.
+      SmallVector<Attribute> argAttrs(
+          cast<ArrayAttr>(saved.getValue()).getAsRange<DictionaryAttr>());
+      argAttrs.append(addedArgCount, DictionaryAttr::getWithSorted(ctx, {}));
+      restored.emplace_back(saved.getName(), ArrayAttr::get(ctx, argAttrs));
+      continue;
+    }
+    restored.push_back(saved);
+  }
+  outlinedFunc->setAttrs(restored);
+}
+
+/// Rewrites an enzyme.autodiff op into an enzyme.autodiff_region op whose body
+/// is a copy of the callee's body.
+///
+/// The callee's attributes (minus its function type) are saved on the new op
+/// as `fn_attrs`, and every return-like op directly inside the copied body is
+/// replaced by an enzyme.yield with the same operands. The pattern does not
+/// match when the callee cannot be resolved to a function with a body.
+struct InlineEnzymeAutoDiff : public OpRewritePattern<enzyme::AutoDiffOp> {
+  using OpRewritePattern<enzyme::AutoDiffOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(enzyme::AutoDiffOp op,
+                                PatternRewriter &rewriter) const override {
+    SymbolTableCollection symbolTable;
+    // The symbol may be missing or may not name a function; report a match
+    // failure instead of asserting inside cast<>.
+    auto fn = dyn_cast_or_null<FunctionOpInterface>(
+        symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr()));
+    if (!fn)
+      return rewriter.notifyMatchFailure(op, "callee is not a function");
+    // A declaration has no body to copy; inlining it would leave the new op
+    // with an empty region.
+    if (fn.isExternal())
+      return rewriter.notifyMatchFailure(op, "callee has no body");
+
+    // Use a StringAttr rather than a SymbolRefAttr so the function can get
+    // symbol-DCE'd
+    auto fnAttr = StringAttr::get(op.getContext(), op.getFn());
+    auto regionOp = rewriter.replaceOpWithNewOp<enzyme::AutoDiffRegionOp>(
+        op, op.getResultTypes(), op.getInputs(), op.getActivity(),
+        op.getRetActivity(), op.getWidth(), op.getStrongZero(), fnAttr);
+    serializeFunctionAttributes(fn, regionOp);
+    rewriter.cloneRegionBefore(fn.getFunctionBody(), regionOp.getBody(),
+                               regionOp.getBody().begin());
+
+    // Insert an enzyme.yield in front of each return-like op. The originals
+    // are erased afterwards so the op list is not modified under the
+    // iterator.
+    SmallVector<Operation *> toErase;
+    for (Operation &bodyOp : regionOp.getBody().getOps()) {
+      if (bodyOp.hasTrait<OpTrait::ReturnLike>()) {
+        PatternRewriter::InsertionGuard insertionGuard(rewriter);
+        rewriter.setInsertionPoint(&bodyOp);
+        enzyme::YieldOp::create(rewriter, bodyOp.getLoc(),
+                                bodyOp.getOperands());
+        toErase.push_back(&bodyOp);
+      }
+    }
+
+    for (Operation *opToErase : toErase)
+      rewriter.eraseOp(opToErase);
+
+    return success();
+  }
+};
+
+// Based on
+// https://github.com/llvm/llvm-project/blob/665da0a1649814471739c41a702e0e9447316b20/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+//
+/// Creates a func.func named `funcName` at the builder's insertion point whose
+/// body is a copy of the body of `op`, with enzyme.yield terminators replaced
+/// by func.return.
+///
+/// Values used inside the region but defined outside of it become extra
+/// trailing function arguments. For each of them the value is appended to
+/// `inputs` and an enzyme_const activity is appended to `argActivities`, so
+/// the caller can build a matching enzyme.autodiff op.
+///
+/// Fails with a diagnostic when no block of the region ends in enzyme.yield.
+static FailureOr<func::FuncOp>
+outlineAutoDiffFunc(enzyme::AutoDiffRegionOp op, StringRef funcName,
+                    SmallVectorImpl<Value> &inputs,
+                    SmallVectorImpl<enzyme::Activity> &argActivities,
+                    OpBuilder &builder) {
+  Region &autodiffRegion = op.getBody();
+  SmallVector<Type> argTypes(autodiffRegion.getArgumentTypes()), resultTypes;
+  SmallVector<Location> argLocs(autodiffRegion.getNumArguments(), op.getLoc());
+  // Infer the result types from an enzyme.yield terminating one of the
+  // region's own blocks. enzyme.yield must be a direct child of its
+  // autodiff_region, so only the last op of each block needs to be inspected;
+  // a region-wide walk would reach the yield of a nested autodiff_region
+  // first.
+  bool found = false;
+  for (Block &block : autodiffRegion) {
+    if (block.empty())
+      continue;
+    if (auto yieldOp = dyn_cast<enzyme::YieldOp>(&block.back())) {
+      found = true;
+      llvm::append_range(resultTypes, yieldOp.getOperandTypes());
+      break;
+    }
+  }
+  if (!found)
+    return op.emitError()
+           << "enzyme.yield was not found in enzyme.autodiff_region";
+
+  llvm::SetVector<Value> freeValues;
+  getUsedValuesDefinedAbove(autodiffRegion, freeValues);
+
+  for (Value value : freeValues) {
+    inputs.push_back(value);
+    argTypes.push_back(value.getType());
+    argLocs.push_back(value.getLoc());
+    argActivities.push_back(enzyme::Activity::enzyme_const);
+  }
+  auto fnType = builder.getFunctionType(argTypes, resultTypes);
+
+  // FIXME: making this location the location of the
+  // enzyme.autodiff_region op causes translation to LLVM IR to fail due
+  // to some issue with the dbg info.
+  Location loc = UnknownLoc::get(op.getContext());
+  auto outlinedFunc = func::FuncOp::create(builder, loc, funcName, fnType);
+  Region &outlinedBody = outlinedFunc.getBody();
+  deserializeFunctionAttributes(op, outlinedFunc, freeValues.size());
+
+  // Copy over the function body. Region arguments and captured values are
+  // mapped to the new entry block's arguments before cloning so that the
+  // cloned ops refer to them.
+  IRMapping map;
+  Block *entryBlock = builder.createBlock(&outlinedBody, outlinedBody.begin(),
+                                          argTypes, argLocs);
+  unsigned originalArgCount = autodiffRegion.getNumArguments();
+  for (const auto &arg : autodiffRegion.getArguments())
+    map.map(arg, entryBlock->getArgument(arg.getArgNumber()));
+  for (const auto &operand : enumerate(freeValues))
+    map.map(operand.value(),
+            entryBlock->getArgument(originalArgCount + operand.index()));
+  autodiffRegion.cloneInto(&outlinedBody, map);
+
+  // Replace the terminators with returns
+  for (Block &block : autodiffRegion) {
+    Block *clonedBlock = map.lookup(&block);
+    auto terminator = dyn_cast<enzyme::YieldOp>(clonedBlock->getTerminator());
+    if (!terminator)
+      continue;
+    OpBuilder replacer(terminator);
+    func::ReturnOp::create(replacer, terminator->getLoc(),
+                           terminator->getOperands());
+    terminator->erase();
+  }
+
+  // cloneInto results in two blocks, the actual outlined entry block and the
+  // cloned autodiff_region entry block. Splice the cloned entry block into
+  // the actual entry block, then erase the cloned autodiff_region entry.
+  Block *clonedEntry = map.lookup(&autodiffRegion.front());
+  entryBlock->getOperations().splice(entryBlock->getOperations().end(),
+                                     clonedEntry->getOperations());
+  clonedEntry->erase();
+  return outlinedFunc;
+}
+
+/// Replaces `op` with an enzyme.autodiff op that references a newly created
+/// function containing the region's body.
+///
+/// The function is named after the `fn` attribute of `op` when present and
+/// `defaultFuncName` otherwise, and is inserted right after the closest
+/// enclosing symbol op. Values captured from outside the region are passed as
+/// additional enzyme_const inputs.
+///
+/// Returns failure (after emitting a diagnostic) when `op` has no enclosing
+/// symbol op or when the body cannot be outlined.
+LogicalResult outlineEnzymeAutoDiffRegion(enzyme::AutoDiffRegionOp op,
+                                          StringRef defaultFuncName,
+                                          OpBuilder &builder) {
+  // Without an enclosing symbol there is no anchor to insert the outlined
+  // function after; passing a null op to setInsertionPointAfter would crash.
+  auto parentSymbol = op->getParentOfType<SymbolOpInterface>();
+  if (!parentSymbol)
+    return op.emitError()
+           << "enzyme.autodiff_region has no enclosing symbol op";
+
+  StringRef funcName = op.getFn().value_or(defaultFuncName);
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  builder.setInsertionPointAfter(parentSymbol);
+
+  SmallVector<enzyme::Activity> argActivities =
+      llvm::map_to_vector(op.getActivity().getAsRange<enzyme::ActivityAttr>(),
+                          [](auto attr) { return attr.getValue(); });
+  SmallVector<Value> inputs(op.getInputs());
+
+  // Both vectors are extended by the callee with one entry per value that is
+  // captured from outside the region.
+  FailureOr<func::FuncOp> outlinedFunc =
+      outlineAutoDiffFunc(op, funcName, inputs, argActivities, builder);
+  if (failed(outlinedFunc))
+    return failure();
+  builder.setInsertionPoint(op);
+  ArrayAttr argActivityAttr = builder.getArrayAttr(llvm::map_to_vector(
+      argActivities, [&op](enzyme::Activity actv) -> Attribute {
+        return enzyme::ActivityAttr::get(op.getContext(), actv);
+      }));
+  auto newOp = enzyme::AutoDiffOp::create(
+      builder, op.getLoc(), op.getResultTypes(), outlinedFunc->getName(),
+      inputs, argActivityAttr, op.getRetActivity(), op.getWidth(),
+      op.getStrongZero());
+  op.replaceAllUsesWith(newOp.getResults());
+  op.erase();
+  return success();
+}
+
+/// Pass that applies InlineEnzymeAutoDiff greedily to the operation it runs
+/// on, turning enzyme.autodiff ops into enzyme.autodiff_region ops.
+struct InlineEnzymeIntoRegion
+    : public enzyme::impl::InlineEnzymeIntoRegionPassBase<
+          InlineEnzymeIntoRegion> {
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    patterns.add<InlineEnzymeAutoDiff>(context);
+
+    // The result of the greedy driver is discarded.
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns),
+                                GreedyRewriteConfig());
+  }
+};
+
+/// Pass that outlines every enzyme.autodiff_region op nested under the
+/// operation it runs on into a function referenced by an enzyme.autodiff op.
+///
+/// A region without an `fn` attribute gets the default name
+/// `<enclosing symbol>_to_diff<N>`, where N counts the regions processed so
+/// far by this pass invocation.
+struct OutlineEnzymeFromRegion
+    : public enzyme::impl::OutlineEnzymeFromRegionPassBase<
+          OutlineEnzymeFromRegion> {
+  void runOnOperation() override {
+    // Collect the ops up front: outlining erases them, which must not happen
+    // while the walk is still in progress.
+    SmallVector<enzyme::AutoDiffRegionOp> toOutline;
+    getOperation()->walk(
+        [&](enzyme::AutoDiffRegionOp op) { toOutline.push_back(op); });
+
+    OpBuilder builder(getOperation());
+    unsigned increment = 0;
+    for (auto regionOp : toOutline) {
+      auto symbol = regionOp->getParentOfType<SymbolOpInterface>();
+      // The default name is derived from the enclosing symbol; fail cleanly
+      // instead of dereferencing a null interface when there is none.
+      if (!symbol) {
+        regionOp.emitError()
+            << "enzyme.autodiff_region has no enclosing symbol op";
+        return signalPassFailure();
+      }
+      std::string defaultName =
+          (Twine(symbol.getName(), "_to_diff") + Twine(increment)).str();
+      if (failed(outlineEnzymeAutoDiffRegion(regionOp, defaultName, builder)))
+        return signalPassFailure();
+
+      ++increment;
+    }
+  }
+};
+
+} // namespace
diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td
index 3976c9f..33d05fb 100644
--- a/enzyme/Enzyme/MLIR/Passes/Passes.td
+++ b/enzyme/Enzyme/MLIR/Passes/Passes.td
@@ -185,6 +185,17 @@
let summary = "Print the results of alias analysis";
}
+// Converts enzyme.autodiff ops into enzyme.autodiff_region ops by copying the
+// body of the referenced function into the region.
+def InlineEnzymeIntoRegionPass : Pass<"inline-enzyme-regions"> {
+ let summary = "Inline enzyme.autodiff ops into enzyme.autodiff_region ops.";
+}
+
+// Converts enzyme.autodiff_region ops back into enzyme.autodiff ops. Each
+// region body is moved into a newly created func.func, hence the dependency
+// on the func dialect.
+def OutlineEnzymeFromRegionPass : Pass<"outline-enzyme-regions"> {
+ let summary = "Outlines enzyme.autodiff_region ops into enzyme.autodiff ops.";
+ let dependentDialects = [
+ "func::FuncDialect"
+ ];
+}
+
def EnzymeOpsToMemRefPass : Pass<"convert-enzyme-to-memref"> {
let summary = "Lower custom Enzyme ops to the MemRef dialect";
let dependentDialects = [
diff --git a/enzyme/test/MLIR/Passes/region_inline.mlir b/enzyme/test/MLIR/Passes/region_inline.mlir
new file mode 100644
index 0000000..ce2421c
--- /dev/null
+++ b/enzyme/test/MLIR/Passes/region_inline.mlir
@@ -0,0 +1,54 @@
+// RUN: %eopt --inline-enzyme-regions --split-input-file %s | FileCheck %s
+
+func.func @square(%x: f64) -> f64 {
+ %next = arith.mulf %x, %x : f64
+ return %next : f64
+}
+
+func.func @dsquare(%x: f64, %dr: f64) -> f64 {
+ %r = enzyme.autodiff @square(%x, %dr)
+ {
+ activity=[#enzyme<activity enzyme_active>],
+ ret_activity=[#enzyme<activity enzyme_activenoneed>]
+ } : (f64, f64) -> f64
+ return %r : f64
+}
+
+// CHECK: func.func @dsquare(%arg0: f64, %arg1: f64) -> f64 {
+// CHECK-NEXT: %0 = enzyme.autodiff_region(%arg0, %arg1) {
+// CHECK-NEXT: ^bb0(%arg2: f64):
+// CHECK-NEXT: %1 = arith.mulf %arg2, %arg2 : f64
+// CHECK-NEXT: enzyme.yield %1 : f64
+// CHECK-NEXT: } attributes {activity = [#enzyme<activity enzyme_active>], fn = "square", fn_attrs = {sym_name = "square"}, ret_activity = [#enzyme<activity enzyme_activenoneed>]} : (f64, f64) -> f64
+// CHECK-NEXT: return %0 : f64
+// CHECK-NEXT: }
+
+// -----
+
+llvm.func internal @_Z6squarePfS_(%arg0: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg1: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.writeonly}) attributes {dso_local, frame_pointer = #llvm.framePointerKind<all>, memory_effects = #llvm.memory_effects<other = none, argMem = readwrite, inaccessibleMem = none>, no_unwind, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "sm_86"]], target_cpu = "sm_86", target_features = #llvm.target_features<["+ptx88", "+sm_86"]>, will_return, sym_visibility = "private"} {
+ %0 = llvm.mlir.constant(5.600000e+00 : f64) : f64
+ %1 = nvvm.read.ptx.sreg.tid.x : i32
+ %2 = llvm.zext nneg %1 : i32 to i64
+ %3 = llvm.getelementptr inbounds|nuw %arg0[%2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ %4 = llvm.load %3 {alignment = 4 : i64} : !llvm.ptr -> f32
+ %5 = llvm.fpext %4 : f32 to f64
+ %6 = llvm.fmul %5, %0 {fastmathFlags = #llvm.fastmath<contract>} : f64
+ %7 = llvm.fptrunc %6 : f64 to f32
+ %8 = llvm.getelementptr inbounds|nuw %arg1[%2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ llvm.store %7, %8 {alignment = 4 : i64} : f32, !llvm.ptr
+ llvm.return
+}
+
+llvm.func internal @d_Z6squarePfS_(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr, %arg3: !llvm.ptr) {
+ enzyme.autodiff @_Z6squarePfS_(%arg0, %arg1, %arg2, %arg3)
+ {
+ activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_dup>],
+ ret_activity=[]
+ } : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
+ llvm.return
+}
+
+// Make sure that function attributes are preserved
+// CHECK: llvm.func internal @d_Z6squarePfS_
+// CHECK: enzyme.autodiff_region(%arg0, %arg1, %arg2, %arg3) {
+// CHECK: } attributes {activity = [#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_dup>], fn = "_Z6squarePfS_", fn_attrs = {CConv = #llvm.cconv<ccc>, arg_attrs = [{llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.writeonly}], dso_local, frame_pointer = #llvm.framePointerKind<all>, linkage = #llvm.linkage<internal>, memory_effects = #llvm.memory_effects<other = none, argMem = readwrite, inaccessibleMem = none>, no_unwind, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "sm_86"]], sym_name = "_Z6squarePfS_", sym_visibility = "private", target_cpu = "sm_86", target_features = #llvm.target_features<["+ptx88", "+sm_86"]>, unnamed_addr = 0 : i64, visibility_ = 0 : i64, will_return}, ret_activity = []} : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> (
diff --git a/enzyme/test/MLIR/Passes/region_outline.mlir b/enzyme/test/MLIR/Passes/region_outline.mlir
new file mode 100644
index 0000000..5ce1878
--- /dev/null
+++ b/enzyme/test/MLIR/Passes/region_outline.mlir
@@ -0,0 +1,100 @@
+// RUN: %eopt --outline-enzyme-regions --split-input-file %s | FileCheck %s
+
+func.func @to_outline(%26: !llvm.ptr, %27: !llvm.ptr, %28: !llvm.ptr, %29: !llvm.ptr) {
+ %cst = arith.constant 5.6 : f64
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ scf.parallel (%arg2, %arg3, %arg4) = (%c0, %c0, %c0) to (%c1, %c1, %c1) step (%c1, %c1, %c1) {
+ scf.parallel (%arg5, %arg6, %arg7) = (%c0, %c0, %c0) to (%c4, %c1, %c1) step (%c1, %c1, %c1) {
+ memref.alloca_scope {
+ scf.execute_region {
+ enzyme.autodiff_region(%26, %27, %28, %29) {
+ ^bb0(%arg8: !llvm.ptr, %arg9: !llvm.ptr):
+ %63 = arith.index_castui %arg5 : index to i64
+ %64 = llvm.getelementptr inbounds|nuw %arg8[%63] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ %65 = llvm.load %64 invariant {alignment = 4 : i64} : !llvm.ptr -> f32
+ %66 = arith.extf %65 : f32 to f64
+ %67 = arith.mulf %66, %cst {fastmathFlags = #llvm.fastmath<contract>} : f64
+ %68 = arith.truncf %67 : f64 to f32
+ %69 = llvm.getelementptr inbounds|nuw %arg9[%63] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ llvm.store %68, %69 {alignment = 4 : i64} : f32, !llvm.ptr
+ enzyme.yield
+ } attributes {activity = [#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_dup>], fn = "outlined_func", ret_activity = []} : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
+ scf.yield
+ }
+ }
+ scf.reduce
+ }
+ scf.reduce
+ }
+ return
+}
+
+// CHECK: func.func @outlined_func(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: index, %arg3: f64) {
+// CHECK-NEXT: %0 = arith.index_castui %arg2 : index to i64
+// CHECK-NEXT: %1 = llvm.getelementptr inbounds|nuw %arg0[%0] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK-NEXT: %2 = llvm.load %1 invariant {alignment = 4 : i64} : !llvm.ptr -> f32
+// CHECK-NEXT: %3 = arith.extf %2 : f32 to f64
+// CHECK-NEXT: %4 = arith.mulf %3, %arg3 {fastmathFlags = #llvm.fastmath<contract>} : f64
+// CHECK-NEXT: %5 = arith.truncf %4 : f64 to f32
+// CHECK-NEXT: %6 = llvm.getelementptr inbounds|nuw %arg1[%0] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK-NEXT: llvm.store %5, %6 {alignment = 4 : i64} : f32, !llvm.ptr
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// -----
+
+llvm.func internal @d_Z6squarePfS_(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr, %arg3: !llvm.ptr) {
+ %0 = llvm.mlir.constant(5.600000e+00 : f64) : f64
+ enzyme.autodiff_region(%arg0, %arg1, %arg2, %arg3) {
+ ^bb0(%arg4: !llvm.ptr, %arg5: !llvm.ptr):
+ %1 = nvvm.read.ptx.sreg.tid.x : i32
+ %2 = llvm.zext nneg %1 : i32 to i64
+ %3 = llvm.getelementptr inbounds|nuw %arg4[%2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ %4 = llvm.load %3 {alignment = 4 : i64} : !llvm.ptr -> f32
+ %5 = llvm.fpext %4 : f32 to f64
+ %6 = llvm.fmul %5, %0 {fastmathFlags = #llvm.fastmath<contract>} : f64
+ %7 = llvm.fptrunc %6 : f64 to f32
+ %8 = llvm.getelementptr inbounds|nuw %arg5[%2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ llvm.store %7, %8 {alignment = 4 : i64} : f32, !llvm.ptr
+ enzyme.yield
+ } attributes {activity = [#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_dup>], fn = "_Z6squarePfS_", fn_attrs = {CConv = #llvm.cconv<ccc>, arg_attrs = [{llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.writeonly}], dso_local, frame_pointer = #llvm.framePointerKind<all>, linkage = #llvm.linkage<internal>, memory_effects = #llvm.memory_effects<other = none, argMem = readwrite, inaccessibleMem = none>, no_unwind, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "sm_86"]], sym_name = "_Z6squarePfS_", sym_visibility = "private", target_cpu = "sm_86", target_features = #llvm.target_features<["+ptx88", "+sm_86"]>, unnamed_addr = 0 : i64, visibility_ = 0 : i64, will_return}, ret_activity = []} : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
+ llvm.return
+}
+
+// Attributes should be passed back to the outlined function
+// CHECK: func.func private @_Z6squarePfS_(%arg0: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg1: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.writeonly}, %arg2: f64) attributes {CConv = #llvm.cconv<ccc>, dso_local, frame_pointer = #llvm.framePointerKind<all>, linkage = #llvm.linkage<internal>, memory_effects = #llvm.memory_effects<other = none, argMem = readwrite, inaccessibleMem = none>, no_unwind, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "sm_86"]], target_cpu = "sm_86", target_features = #llvm.target_features<["+ptx88", "+sm_86"]>, unnamed_addr = 0 : i64, visibility_ = 0 : i64, will_return} {
+
+// -----
+
+func.func @outline_multi(%x: f64, %dr: f64) -> (f64, f64) {
+ %r0 = enzyme.autodiff_region(%x, %dr) {
+ ^bb0(%arg0: f64):
+ %sq = arith.mulf %arg0, %arg0 : f64
+ enzyme.yield %sq : f64
+ } attributes {activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_activenoneed>]} : (f64, f64) -> f64
+
+ %r1 = enzyme.autodiff_region(%x, %dr) {
+ ^bb0(%arg0: f64):
+ %add = arith.addf %arg0, %arg0 : f64
+ enzyme.yield %add : f64
+ } attributes {activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_activenoneed>]} : (f64, f64) -> f64
+ return %r0, %r1 : f64, f64
+}
+
+// CHECK: func.func @outline_multi(%arg0: f64, %arg1: f64) -> (f64, f64) {
+// CHECK-NEXT: %0 = enzyme.autodiff @outline_multi_to_diff0(%arg0, %arg1) {activity = [#enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>]} : (f64, f64) -> f64
+// CHECK-NEXT: %1 = enzyme.autodiff @outline_multi_to_diff1(%arg0, %arg1) {activity = [#enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>]} : (f64, f64) -> f64
+// CHECK-NEXT: return %0, %1 : f64, f64
+// CHECK-NEXT: }
+
+// CHECK: func.func @outline_multi_to_diff1(%arg0: f64) -> f64 {
+// CHECK-NEXT: %0 = arith.addf %arg0, %arg0 : f64
+// CHECK-NEXT: return %0 : f64
+// CHECK-NEXT: }
+
+// CHECK: func.func @outline_multi_to_diff0(%arg0: f64) -> f64 {
+// CHECK-NEXT: %0 = arith.mulf %arg0, %arg0 : f64
+// CHECK-NEXT: return %0 : f64
+// CHECK-NEXT: }