save
diff --git a/enzyme/Enzyme/MLIR/Analysis/SampleDependenceAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/SampleDependenceAnalysis.cpp
index 7931d8e..cde8dc3 100644
--- a/enzyme/Enzyme/MLIR/Analysis/SampleDependenceAnalysis.cpp
+++ b/enzyme/Enzyme/MLIR/Analysis/SampleDependenceAnalysis.cpp
@@ -9,34 +9,34 @@
 using namespace mlir::enzyme;

 SampleDependenceAnalysis::SampleDependenceAnalysis(MCMCRegionOp regionOp)
-    : regionOp(regionOp) {
-  runAnalysis();
+    : regionOp(regionOp), target(AnalysisTarget::Sampler) {
+  runSamplerAnalysis();
+}
+
+SampleDependenceAnalysis::SampleDependenceAnalysis(MCMCRegionOp regionOp,
+                                                   AnalysisTarget target)
+    : regionOp(regionOp), target(target) {
+  if (target == AnalysisTarget::Logpdf)
+    runLogpdfAnalysis();
+  else
+    runSamplerAnalysis();
+}
+
+Region &SampleDependenceAnalysis::getTargetRegion() {
+  if (target == AnalysisTarget::Logpdf)
+    return regionOp.getLogpdf();
+  return regionOp.getSampler();
 }

 void SampleDependenceAnalysis::markSampleDependent(Value value) {
   sampleDependentValues.insert(value);
 }

-void SampleDependenceAnalysis::runAnalysis() {
-  // TODO: Handle recursive regions.
-  if (regionOp.getLogpdfFnAttr()) {
-    Block &entry = regionOp.getBody().front();
-    if (!entry.getArguments().empty()) {
-      markSampleDependent(entry.getArgument(0));
-    }
-  }
-
-  regionOp.getBody().walk([&](SampleRegionOp sampleOp) {
-    sampleOps.push_back(sampleOp);
-    for (Value result : sampleOp.getResults()) {
-      markSampleDependent(result);
-    }
-  });
-
+void SampleDependenceAnalysis::propagateDependence(Region &region) {
   bool changed = true;
   while (changed) {
     changed = false;
-    regionOp.getBody().walk([&](Operation *op) {
+    region.walk([&](Operation *op) {
       if (isa<SampleRegionOp>(op))
         return;

@@ -60,6 +60,56 @@
   }
 }

+void SampleDependenceAnalysis::runSamplerAnalysis() {
+  // TODO: Handle recursive regions.
+  if (regionOp.getLogpdfFnAttr()) {
+    Block &entry = regionOp.getSampler().front();
+    if (!entry.getArguments().empty()) {
+      markSampleDependent(entry.getArgument(0));
+    }
+  }
+
+  DenseSet<Attribute> selectedSymbols;
+  bool hasSelection = false;
+  if (auto selection = regionOp.getSelectionAttr()) {
+    hasSelection = true;
+    for (auto addr : selection) {
+      auto address = cast<ArrayAttr>(addr);
+      if (!address.empty())
+        selectedSymbols.insert(address[0]);
+    }
+  }
+
+  regionOp.getSampler().walk([&](SampleRegionOp sampleOp) {
+    sampleOps.push_back(sampleOp);
+    auto symbol = sampleOp.getSymbolAttr();
+    bool isSelected =
+        !hasSelection || !symbol || selectedSymbols.contains(symbol);
+    if (isSelected) {
+      for (Value result : sampleOp.getResults()) {
+        markSampleDependent(result);
+      }
+    }
+  });
+
+  propagateDependence(regionOp.getSampler());
+}
+
+void SampleDependenceAnalysis::runLogpdfAnalysis() {
+  Region &logpdf = regionOp.getLogpdf();
+  if (logpdf.empty())
+    return;
+
+  Block &entry = logpdf.front();
+  int64_t numPositionArgs = regionOp.getNumPositionArgs();
+
+  int64_t numArgs = static_cast<int64_t>(entry.getNumArguments());
+  for (int64_t i = 0; i < numPositionArgs && i < numArgs; ++i)
+    markSampleDependent(entry.getArgument(i));
+
+  propagateDependence(logpdf);
+}
+
 bool SampleDependenceAnalysis::isSampleDependent(Value value) const {
   return sampleDependentValues.contains(value);
 }
@@ -72,6 +122,10 @@
   return false;
 }

+bool SampleDependenceAnalysis::isInTargetRegion(Operation *op) {
+  return getTargetRegion().isAncestor(op->getParentRegion());
+}
+
 bool SampleDependenceAnalysis::canHoist(Operation *op) const {
   if (isa<SampleRegionOp>(op))
     return false;
@@ -134,26 +188,26 @@
   return false;
 }

-bool enzyme::hoistSampleInvariantOps(MCMCRegionOp regionOp) {
-  DominanceInfo dom(regionOp);
-  PostDominanceInfo pdom(regionOp);
-  SampleDependenceAnalysis sampleAnalysis(regionOp);
-
-  Region &region = regionOp.getBody();
+static bool hoistFromRegion(MCMCRegionOp regionOp,
+                            SampleDependenceAnalysis &sampleAnalysis,
+                            Region &region) {
   if (region.empty())
     return false;

+  DominanceInfo dom(regionOp);
+  PostDominanceInfo pdom(regionOp);
+
   IRMapping regionToOuter;
   Block &entryBlock = region.front();

-  auto inputs = regionOp.getInputs();
-  bool isLogpdfMode = static_cast<bool>(regionOp.getLogpdfFnAttr());
-
-  for (auto [idx, blockArg] : llvm::enumerate(entryBlock.getArguments())) {
-    if (isLogpdfMode && idx == 0)
-      continue;
-    if (idx < inputs.size()) {
-      regionToOuter.map(blockArg, inputs[idx]);
+  if (sampleAnalysis.getTarget() == AnalysisTarget::Sampler) {
+    auto inputs = regionOp.getInputs();
+    bool isLogpdfMode = static_cast<bool>(regionOp.getLogpdfFnAttr());
+    for (auto [idx, blockArg] : llvm::enumerate(entryBlock.getArguments())) {
+      if (isLogpdfMode && idx == 0)
+        continue;
+      if (idx < inputs.size())
+        regionToOuter.map(blockArg, inputs[idx]);
     }
   }

@@ -225,3 +279,13 @@

   return !sortedToHoist.empty();
 }
+
+bool enzyme::hoistSampleInvariantOps(MCMCRegionOp regionOp) {
+  return hoistSampleInvariantOps(regionOp, AnalysisTarget::Sampler);
+}
+
+bool enzyme::hoistSampleInvariantOps(MCMCRegionOp regionOp,
+                                     AnalysisTarget target) {
+  SampleDependenceAnalysis analysis(regionOp, target);
+  return hoistFromRegion(regionOp, analysis, analysis.getTargetRegion());
+}
diff --git a/enzyme/Enzyme/MLIR/Analysis/SampleDependenceAnalysis.h b/enzyme/Enzyme/MLIR/Analysis/SampleDependenceAnalysis.h
index 1ab05b5..3ef22d0 100644
--- a/enzyme/Enzyme/MLIR/Analysis/SampleDependenceAnalysis.h
+++ b/enzyme/Enzyme/MLIR/Analysis/SampleDependenceAnalysis.h
@@ -10,10 +10,17 @@
 namespace mlir {
 namespace enzyme {

+enum class AnalysisTarget {
+  Sampler,
+  Logpdf,
+};
+
 class SampleDependenceAnalysis {
 public:
   explicit SampleDependenceAnalysis(MCMCRegionOp regionOp);

+  SampleDependenceAnalysis(MCMCRegionOp regionOp, AnalysisTarget target);
+
   bool isSampleDependent(Value value) const;
   bool isSampleDependent(Operation *op) const;
   bool canHoist(Operation *op) const;
@@ -21,18 +28,30 @@

   ArrayRef<SampleRegionOp> getSampleOps() const { return sampleOps; }
   MCMCRegionOp getRegionOp() const { return regionOp; }
+  AnalysisTarget getTarget() const { return target; }
+
+  bool isInTargetRegion(Operation *op);
+
+  Region &getTargetRegion();

 private:
   MCMCRegionOp regionOp;
+  AnalysisTarget target;
   DenseSet<Value> sampleDependentValues;
   SmallVector<SampleRegionOp> sampleOps;

-  void runAnalysis();
+  void runSamplerAnalysis();
+  void runLogpdfAnalysis();
   void markSampleDependent(Value value);
+  void propagateDependence(Region &region);
 };

 bool hoistSampleInvariantOps(MCMCRegionOp regionOp);

+bool hoistSampleInvariantOps(MCMCRegionOp regionOp, AnalysisTarget target);
+
+bool constructUnifiedLogpdf(MCMCRegionOp regionOp);
+
 } // namespace enzyme
 } // namespace mlir

diff --git a/enzyme/Enzyme/MLIR/Dialect/Dialect.cpp b/enzyme/Enzyme/MLIR/Dialect/Dialect.cpp
index 5df615d..a791273 100644
--- a/enzyme/Enzyme/MLIR/Dialect/Dialect.cpp
+++ b/enzyme/Enzyme/MLIR/Dialect/Dialect.cpp
@@ -71,3 +71,31 @@

 #define GET_ATTRDEF_CLASSES
 #include "Dialect/EnzymeAttributes.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// SymbolAttr custom print/parse
+//===----------------------------------------------------------------------===//
+
+void SymbolAttr::print(AsmPrinter &printer) const {
+  printer << "<";
+  llvm::interleaveComma(getPath(), printer);
+  printer << ">";
+}
+
+Attribute SymbolAttr::parse(AsmParser &parser, Type type) {
+  if (parser.parseLess())
+    return {};
+  SmallVector<uint64_t> path;
+  uint64_t val;
+  if (failed(parser.parseInteger(val)))
+    return {};
+  path.push_back(val);
+  while (succeeded(parser.parseOptionalComma())) {
+    if (failed(parser.parseInteger(val)))
+      return {};
+    path.push_back(val);
+  }
+  if (parser.parseGreater())
+    return {};
+  return get(parser.getContext(), path);
+}
diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeEnums.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeEnums.td
index fbfaa78..7d9c866 100644
--- a/enzyme/Enzyme/MLIR/Dialect/EnzymeEnums.td
+++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeEnums.td
@@ -88,6 +88,8 @@
   let assemblyFormat = "`<` $kind (`,` `lower` `=` $lower_bound^)? (`,` `upper` `=` $upper_bound^)? `>`";
 }

+def SupportArrayAttr : TypedArrayAttrBase<SupportAttr, "Array of support specifications">;
+
 def HMCConfigAttr : Enzyme_Attr<"HMCConfig", "hmc_config"> {
   let summary = "Configuration for HMC inference";
   let description = [{
diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
index 6ddbf05..f449bc5 100644
--- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
+++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
@@ -447,10 +447,12 @@
 def SymbolAttr : Enzyme_Attr<"Symbol", "symbol"> {
   let summary = "Symbol associated with a Sample op";
   let description = [{
-    Symbol associated with a Sample op.
+    Symbol associated with a Sample op. Holds one or more uint64_t values.
+    A single value is a leaf symbol; multiple values encode a composite path
+    created by SICM submodel inlining.
   }];
-  let parameters = (ins "uint64_t":$ptr);
-  let assemblyFormat = "`<` $ptr `>`";
+  let parameters = (ins ArrayRefParameter<"uint64_t", "symbol path">:$path);
+  let hasCustomAssemblyFormat = 1;
 }

 def AddressAttr : TypedArrayAttrBase<SymbolAttr, "Address as an array of symbols">;
@@ -790,11 +792,25 @@
 }

 def MCMCRegionOp : Enzyme_Op<"mcmc_region", [AutomaticAllocationScope, AttrSizedOperandSegments]> {
-  let summary = "MCMC inference with inlined model body";
+  let summary = "Inference region with sampler and unified logpdf";
   let description = [{
-    A region-based MCMC operation where the model function body is inlined
-    into a region. This is an internal construct for optimization passes to see the full
-    computation before expansion to HMC/NUTS loops.
+    A region-based inference operation with two regions:
+
+    - `sampler`: The model computation containing sample_region ops with sampler
+      regions (logpdf regions cleared after merging). Used by proposal-based
+      algorithms (MH, Gibbs) and preserved for outline compatibility.
+
+    - `logpdf`: Unified logpdf computation for ALL position-dependent terms.
+      Block arguments follow a calling convention:
+        args[0..num_position_args-1] = position values (sample-dependent,
+                                       ordered by `selection` attribute)
+        Captured values from enclosing scope = parameters (sample-invariant)
+
+      SICM operates on this region. The outline pass converts it to a
+      standalone function referenced by `logpdf_fn` on the resulting MCMCOp.
+
+    If the `logpdf` region is empty, the operation falls back to the sampler-only
+    mode (pre-logpdf-merging state or custom logpdf_fn).
   }];

   let arguments = (ins
@@ -821,10 +837,14 @@
     Optional<AnyRankedTensor>:$initial_potential_energy,

     OptionalAttr<StrAttr>:$fn,
-    DefaultValuedStrAttr<StrAttr, "">:$name
+    DefaultValuedStrAttr<StrAttr, "">:$name,
+
+    DefaultValuedAttr<I64Attr, "0">:$num_position_args,
+    DefaultValuedAttr<I64Attr, "0">:$position_size,
+    OptionalAttr<SupportArrayAttr>:$supports
   );

-  let regions = (region AnyRegion:$body);
+  let regions = (region AnyRegion:$sampler, AnyRegion:$logpdf);

   let results = (outs
     AnyRankedTensor:$new_trace,
@@ -844,7 +864,8 @@
     (`initial_position` `=` $initial_position^)?
     (`initial_gradient` `=` $initial_gradient^)?
     (`initial_potential_energy` `=` $initial_potential_energy^)?
-    $body attr-dict-with-keyword `:` functional-type(operands, results)
+    $sampler (`logpdf` $logpdf^)?
+    attr-dict-with-keyword `:` functional-type(operands, results)
   }];
 }

diff --git a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp
index 5a35244..6043fdc 100644
--- a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp
+++ b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp
@@ -90,5 +90,6 @@
 }

 MlirAttribute enzymeSymbolAttrGet(MlirContext ctx, uint64_t ptr) {
-  return wrap(mlir::enzyme::SymbolAttr::get(unwrap(ctx), ptr));
+  llvm::SmallVector<uint64_t, 1> path = {ptr};
+  return wrap(mlir::enzyme::SymbolAttr::get(unwrap(ctx), path));
 }
diff --git a/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp
index f7cbec2..ea70ca7 100644
--- a/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp
+++ b/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp
@@ -679,11 +679,21 @@

   auto initSplit = enzyme::RandomSplitOp::create(
       builder, loc, TypeRange{rng.getType(), rng.getType()}, rng);
+  conditionalDump(builder, loc, initSplit.getResult(0),
+                  "InitHMC: 2-way split[0] (rng_mcmc)", debugDump);
+  conditionalDump(builder, loc, initSplit.getResult(1),
+                  "InitHMC: 2-way split[1]", debugDump);
   auto kernelSplit = enzyme::RandomSplitOp::create(
       builder, loc, TypeRange{rng.getType(), rng.getType(), rng.getType()},
       initSplit.getResult(0));
   auto rngForSampleKernel = kernelSplit.getResult(0);
   auto rngForAutodiff = kernelSplit.getResult(1);
+  conditionalDump(builder, loc, rngForSampleKernel,
+                  "InitHMC: 3-way split[0] (rngForSampleKernel)", debugDump);
+  conditionalDump(builder, loc, rngForAutodiff,
+                  "InitHMC: 3-way split[1] (rngForAutodiff)", debugDump);
+  conditionalDump(builder, loc, kernelSplit.getResult(2),
+                  "InitHMC: 3-way split[2] (unused)", debugDump);

   Value q0;
   Value U0;
@@ -804,6 +814,11 @@
   // (U, rng, grad)
   auto grad0 = autodiffInit.getResult(2);

+  conditionalDump(builder, loc, q0, "InitHMC: q0 (unconstrained)", debugDump);
+  conditionalDump(builder, loc, U0, "InitHMC: U0 (potential energy)",
+                  debugDump);
+  conditionalDump(builder, loc, grad0, "InitHMC: grad0 (gradient)", debugDump);
+
   return {q0, U0, grad0, rngForSampleKernel};
 }

@@ -953,18 +968,30 @@
   auto i1TensorType = RankedTensorType::get({}, builder.getI1Type());

   // 1. Split RNG: [rngNext, rngMomentum, rngTree]
+  conditionalDump(builder, loc, rng,
+                  "SampleNUTS: input rng (before 3-way split)", debugDump);
   auto sampleKernelSplit = enzyme::RandomSplitOp::create(
       builder, loc, TypeRange{rng.getType(), rng.getType(), rng.getType()},
       rng);
   auto rngNext = sampleKernelSplit.getResult(0);
   auto rngMomentum = sampleKernelSplit.getResult(1);
   auto rngTree = sampleKernelSplit.getResult(2);
+  conditionalDump(builder, loc, rngNext, "SampleNUTS: 3-way split[0] (rngNext)",
+                  debugDump);
+  conditionalDump(builder, loc, rngMomentum,
+                  "SampleNUTS: 3-way split[1] (rngMomentum)", debugDump);
+  conditionalDump(builder, loc, rngTree, "SampleNUTS: 3-way split[2] (rngTree)",
+                  debugDump);

   // 2. Sample fresh momentum p ~ N(0, M)
   Value rngForMomentum = rngMomentum;
   if (!ctx.hasCustomLogpdf()) {
     auto momSplit = enzyme::RandomSplitOp::create(
         builder, loc, TypeRange{rng.getType(), rng.getType()}, rngMomentum);
+    conditionalDump(builder, loc, momSplit.getResult(0),
+                    "SampleNUTS: momSplit[0] (rngForMomentum)", debugDump);
+    conditionalDump(builder, loc, momSplit.getResult(1),
+                    "SampleNUTS: momSplit[1]", debugDump);
     rngForMomentum = momSplit.getResult(0);
   }
   auto [p0, rngAfterMomentum] =
@@ -977,6 +1004,11 @@
   // 4. Compute H0 = U + K0
   auto H0 = arith::AddFOp::create(builder, loc, U, K0);

+  conditionalDump(builder, loc, p0, "SampleNUTS: p0 (momentum)", debugDump);
+  conditionalDump(builder, loc, K0, "SampleNUTS: K0 (kinetic energy)",
+                  debugDump);
+  conditionalDump(builder, loc, H0, "SampleNUTS: H0 (hamiltonian)", debugDump);
+
   // 5. Initialize NUTS tree state
   auto iterCtx = ctx.withH0(H0);

diff --git a/enzyme/Enzyme/MLIR/Passes/InlineMCMCRegions.cpp b/enzyme/Enzyme/MLIR/Passes/InlineMCMCRegions.cpp
index 24d3890..5074b4d 100644
--- a/enzyme/Enzyme/MLIR/Passes/InlineMCMCRegions.cpp
+++ b/enzyme/Enzyme/MLIR/Passes/InlineMCMCRegions.cpp
@@ -1,6 +1,8 @@
+#include "Analysis/SampleDependenceAnalysis.h"
 #include "Dialect/Ops.h"
 #include "Passes/Passes.h"

+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -16,6 +18,404 @@
 } // namespace enzyme
 } // namespace mlir

+static enzyme::SymbolAttr composeSymbols(enzyme::SymbolAttr outer,
+                                         enzyme::SymbolAttr inner,
+                                         MLIRContext *ctx) {
+  SmallVector<uint64_t> composed(outer.getPath());
+  composed.append(inner.getPath().begin(), inner.getPath().end());
+  return enzyme::SymbolAttr::get(ctx, composed);
+}
+
+static ArrayAttr flattenAddressesForSymbol(ArrayAttr addresses,
+                                           enzyme::SymbolAttr outerSymbol,
+                                           MLIRContext *ctx) {
+  SmallVector<Attribute> newAddresses;
+  for (auto addr : addresses) {
+    auto address = cast<ArrayAttr>(addr);
+    if (address.size() >= 2 && address[0] == outerSymbol) {
+      auto inner = cast<enzyme::SymbolAttr>(address[1]);
+      auto composite = composeSymbols(outerSymbol, inner, ctx);
+      SmallVector<Attribute> newAddr;
+      newAddr.push_back(composite);
+      for (unsigned i = 2; i < address.size(); ++i)
+        newAddr.push_back(address[i]);
+      newAddresses.push_back(ArrayAttr::get(ctx, newAddr));
+    } else {
+      newAddresses.push_back(addr);
+    }
+  }
+  return ArrayAttr::get(ctx, newAddresses);
+}
+
+static bool inlineSubmodelSampleRegions(enzyme::MCMCRegionOp regionOp) {
+  bool anyChanged = false;
+
+  SmallVector<enzyme::SampleRegionOp> sampleOps;
+  regionOp.getSampler().walk(
+      [&](enzyme::SampleRegionOp op) { sampleOps.push_back(op); });
+
+  for (enzyme::SampleRegionOp sampleOp : sampleOps) {
+    Region &logpdf = sampleOp.getLogpdf();
+    if (!logpdf.empty())
+      continue;
+    if (sampleOp.getLogpdfFnAttr())
+      continue;
+
+    Region &sampler = sampleOp.getSampler();
+    if (sampler.empty() || !sampler.hasOneBlock())
+      continue;
+
+    auto outerSymbol = sampleOp.getSymbolAttr();
+    if (!outerSymbol)
+      continue;
+
+    Block &samplerEntry = sampler.front();
+    auto *ctx = regionOp.getContext();
+
+    OpBuilder builder(sampleOp);
+    IRMapping mapper;
+    auto inputs = sampleOp.getInputs();
+    for (unsigned i = 0, e = samplerEntry.getNumArguments(); i < e; ++i) {
+      if (i < inputs.size())
+        mapper.map(samplerEntry.getArgument(i), inputs[i]);
+    }
+
+    for (Operation &op : samplerEntry.without_terminator()) {
+      Operation *cloned = builder.clone(op, mapper);
+      if (auto innerSample = dyn_cast<enzyme::SampleRegionOp>(cloned)) {
+        if (auto innerSymbol = innerSample.getSymbolAttr())
+          innerSample.setSymbolAttr(
+              composeSymbols(outerSymbol, innerSymbol, ctx));
+      } else if (auto innerSampleOp = dyn_cast<enzyme::SampleOp>(cloned)) {
+        if (auto innerSymbol = innerSampleOp.getSymbolAttr())
+          innerSampleOp.setSymbolAttr(
+              composeSymbols(outerSymbol, innerSymbol, ctx));
+      }
+    }
+
+    auto *yield = samplerEntry.getTerminator();
+    for (auto [oldResult, yieldOperand] :
+         llvm::zip(sampleOp.getResults(), yield->getOperands()))
+      oldResult.replaceAllUsesWith(mapper.lookupOrDefault(yieldOperand));
+
+    sampleOp.erase();
+
+    if (auto allAddrs = regionOp.getAllAddressesAttr())
+      regionOp.setAllAddressesAttr(
+          flattenAddressesForSymbol(allAddrs, outerSymbol, ctx));
+    if (auto sel = regionOp.getSelectionAttr())
+      regionOp.setSelectionAttr(
+          flattenAddressesForSymbol(sel, outerSymbol, ctx));
+
+    anyChanged = true;
+  }
+
+  return anyChanged;
+}
+
+static bool extractUnselectedSampleValues(enzyme::MCMCRegionOp regionOp) {
+  auto selection = regionOp.getSelectionAttr();
+  if (!selection)
+    return false;
+
+  DenseSet<Attribute> selectedSymbols;
+  for (auto addr : selection) {
+    auto address = cast<ArrayAttr>(addr);
+    if (!address.empty())
+      selectedSymbols.insert(address[0]);
+  }
+
+  Block &entry = regionOp.getSampler().front();
+  auto inputs = regionOp.getInputs();
+
+  IRMapping blockToOuter;
+  for (auto [idx, blockArg] : llvm::enumerate(entry.getArguments())) {
+    if (idx < inputs.size())
+      blockToOuter.map(blockArg, inputs[idx]);
+  }
+
+  SmallVector<enzyme::SampleRegionOp> toExtract;
+  for (auto &op : entry.without_terminator()) {
+    auto sampleOp = dyn_cast<enzyme::SampleRegionOp>(&op);
+    if (!sampleOp)
+      continue;
+    auto symbol = sampleOp.getSymbolAttr();
+    if (!symbol || selectedSymbols.contains(symbol))
+      continue;
+
+    bool canExtract = true;
+    for (Value operand : sampleOp->getOperands()) {
+      if (!blockToOuter.contains(operand) &&
+          !operand.getDefiningOp<arith::ConstantOp>()) {
+        canExtract = false;
+        break;
+      }
+    }
+    if (!canExtract)
+      continue;
+
+    SetVector<Value> nestedValues;
+    for (Region &nestedRegion : sampleOp->getRegions())
+      getUsedValuesDefinedAbove(nestedRegion, nestedValues);
+    for (Value v : nestedValues) {
+      if (!blockToOuter.contains(v) && !v.getDefiningOp<arith::ConstantOp>()) {
+        canExtract = false;
+        break;
+      }
+    }
+    if (!canExtract)
+      continue;
+
+    toExtract.push_back(sampleOp);
+  }
+
+  if (toExtract.empty())
+    return false;
+
+  unsigned numInputs = inputs.size();
+  bool anyChanged = false;
+
+  for (enzyme::SampleRegionOp sampleOp : toExtract) {
+    OpBuilder builder(regionOp);
+    IRMapping cloneMapper(blockToOuter);
+    // Constants accepted by the legality check above may be defined inside the
+    // sampler body; rematerialize them before the region so the hoisted clone
+    // never uses a value that does not dominate it.
+    SetVector<Value> usedValues;
+    usedValues.insert(sampleOp->operand_begin(), sampleOp->operand_end());
+    for (Region &nestedRegion : sampleOp->getRegions())
+      getUsedValuesDefinedAbove(nestedRegion, usedValues);
+    for (Value v : usedValues) {
+      auto cst = v.getDefiningOp<arith::ConstantOp>();
+      if (cst && !cloneMapper.contains(v) && regionOp->isProperAncestor(cst))
+        builder.clone(*cst.getOperation(), cloneMapper);
+    }
+    Operation *cloned = builder.clone(*sampleOp, cloneMapper);
+
+    for (auto [original, clonedResult] :
+         llvm::zip(sampleOp->getResults(), cloned->getResults())) {
+      regionOp->insertOperands(numInputs, {clonedResult});
+      auto segSizes =
+          regionOp->getAttrOfType<DenseI32ArrayAttr>("operandSegmentSizes");
+      SmallVector<int32_t> newSizes(segSizes.asArrayRef());
+      newSizes[0]++;
+      regionOp->setAttr("operandSegmentSizes",
+                        builder.getDenseI32ArrayAttr(newSizes));
+      numInputs++;
+
+      Value newBlockArg =
+          entry.addArgument(original.getType(), sampleOp.getLoc());
+      original.replaceAllUsesWith(newBlockArg);
+      blockToOuter.map(newBlockArg, clonedResult);
+    }
+
+    sampleOp.erase();
+    anyChanged = true;
+  }
+
+  return anyChanged;
+}
+
+static Value resolveValueForLogpdf(OpBuilder &builder, Location loc,
+                                   Value value, IRMapping &mapping,
+                                   enzyme::MCMCRegionOp regionOp) {
+  if (mapping.contains(value))
+    return mapping.lookup(value);
+
+  if (auto blockArg = dyn_cast<BlockArgument>(value)) {
+    if (blockArg.getOwner() == &regionOp.getSampler().front()) {
+      unsigned idx = blockArg.getArgNumber();
+      auto inputs = regionOp.getInputs();
+      if (idx < inputs.size()) {
+        mapping.map(value, inputs[idx]);
+        return inputs[idx];
+      }
+    }
+    return value;
+  }
+
+  Operation *defOp = value.getDefiningOp();
+  if (!defOp)
+    return value;
+
+  if (defOp->getParentRegion() != &regionOp.getSampler())
+    return value;
+
+  if (isa<enzyme::SampleRegionOp>(defOp)) {
+    Block *logpdfBlock = &regionOp.getLogpdf().front();
+    Value newArg = logpdfBlock->addArgument(value.getType(), defOp->getLoc());
+    mapping.map(value, newArg);
+    return newArg;
+  }
+
+  for (Value operand : defOp->getOperands())
+    resolveValueForLogpdf(builder, loc, operand, mapping, regionOp);
+
+  Operation *cloned = builder.clone(*defOp, mapping);
+  for (auto [orig, clonedRes] :
+       llvm::zip(defOp->getResults(), cloned->getResults()))
+    mapping.map(orig, clonedRes);
+
+  return mapping.lookup(value);
+}
+
+bool enzyme::constructUnifiedLogpdf(enzyme::MCMCRegionOp regionOp) {
+  Region &samplerRegion = regionOp.getSampler();
+  Region &logpdfRegion = regionOp.getLogpdf();
+  auto selection = regionOp.getSelectionAttr();
+  // Nothing to merge without a selection; a non-empty logpdf region means the
+  // unified logpdf was already built, and a second block would be invalid.
+  if (!selection || selection.empty() || !logpdfRegion.empty())
+    return false;
+
+  SmallVector<enzyme::SampleRegionOp> allSampleOps;
+  samplerRegion.walk(
+      [&](enzyme::SampleRegionOp op) { allSampleOps.push_back(op); });
+
+  DenseMap<Attribute, enzyme::SampleRegionOp> symbolToSampleOp;
+  for (auto sampleOp : allSampleOps) {
+    if (auto sym = sampleOp.getSymbolAttr())
+      symbolToSampleOp[sym] = sampleOp;
+  }
+
+  DenseSet<Attribute> selectedSymbols;
+  SmallVector<Attribute> selectionOrder;
+  for (auto addr : selection) {
+    auto address = cast<ArrayAttr>(addr);
+    if (!address.empty()) {
+      selectedSymbols.insert(address[0]);
+      selectionOrder.push_back(address[0]);
+    }
+  }
+
+  Block *logpdfBlock = new Block();
+  logpdfRegion.push_back(logpdfBlock);
+
+  Location loc = regionOp.getLoc();
+  IRMapping positionMapping;
+  SmallVector<enzyme::SupportAttr> supportsVec;
+  int64_t totalPositionSize = 0;
+
+  for (auto symbol : selectionOrder) {
+    auto it = symbolToSampleOp.find(symbol);
+    if (it == symbolToSampleOp.end())
+      continue;
+    enzyme::SampleRegionOp sampleOp = it->second;
+
+    for (unsigned i = 1; i < sampleOp.getNumResults(); ++i) {
+      Value sampleResult = sampleOp.getResult(i);
+      auto resultType = sampleResult.getType();
+      Value blockArg = logpdfBlock->addArgument(resultType, loc);
+      positionMapping.map(sampleResult, blockArg);
+
+      if (auto tensorType = dyn_cast<RankedTensorType>(resultType))
+        totalPositionSize += tensorType.getNumElements();
+      else
+        totalPositionSize += 1;
+    }
+
+    if (auto support = sampleOp.getSupportAttr())
+      supportsVec.push_back(support);
+    else
+      supportsVec.push_back(enzyme::SupportAttr::get(
+          regionOp.getContext(), enzyme::SupportKind::REAL, nullptr, nullptr));
+  }
+
+  int64_t numPositionArgs = logpdfBlock->getNumArguments();
+
+  OpBuilder logpdfBuilder(logpdfBlock, logpdfBlock->end());
+  Value totalLogpdf;
+  auto scalarF64 = RankedTensorType::get({}, logpdfBuilder.getF64Type());
+
+  for (auto sampleOp : allSampleOps) {
+    Region &siteLogpdf = sampleOp.getLogpdf();
+    if (siteLogpdf.empty() || !siteLogpdf.hasOneBlock())
+      continue;
+
+    Block &siteEntry = siteLogpdf.front();
+    if (siteEntry.getNumArguments() == 0)
+      continue;
+
+    unsigned numSampleOutputs = sampleOp.getNumResults() - 1;
+    auto inputs = sampleOp.getInputs();
+
+    IRMapping siteMapping(positionMapping);
+
+    for (unsigned i = 0; i < numSampleOutputs; ++i) {
+      if (i >= siteEntry.getNumArguments())
+        break;
+      Value sampleResult = sampleOp.getResult(i + 1);
+      Value resolved = resolveValueForLogpdf(logpdfBuilder, loc, sampleResult,
+                                             siteMapping, regionOp);
+      siteMapping.map(siteEntry.getArgument(i), resolved);
+    }
+
+    for (unsigned i = numSampleOutputs, e = siteEntry.getNumArguments(); i < e;
+         ++i) {
+      unsigned inputIdx = i - numSampleOutputs + 1;
+      if (inputIdx < inputs.size()) {
+        Value contextVal = inputs[inputIdx];
+        Value resolved = resolveValueForLogpdf(logpdfBuilder, loc, contextVal,
+                                               siteMapping, regionOp);
+        siteMapping.map(siteEntry.getArgument(i), resolved);
+      }
+    }
+
+    for (Operation &op : siteEntry.without_terminator())
+      logpdfBuilder.clone(op, siteMapping);
+
+    auto *yield = siteEntry.getTerminator();
+    assert(yield->getNumOperands() > 0);
+    Value siteResult = siteMapping.lookupOrDefault(yield->getOperand(0));
+
+    if (!totalLogpdf)
+      totalLogpdf = siteResult;
+    else
+      totalLogpdf =
+          arith::AddFOp::create(logpdfBuilder, loc, totalLogpdf, siteResult);
+  }
+
+  if (totalLogpdf) {
+    enzyme::YieldOp::create(logpdfBuilder, loc, {totalLogpdf});
+  } else {
+    auto zeroConst = arith::ConstantOp::create(
+        logpdfBuilder, loc, scalarF64,
+        DenseElementsAttr::get(scalarF64, logpdfBuilder.getF64FloatAttr(0.0)));
+    enzyme::YieldOp::create(logpdfBuilder, loc, {zeroConst});
+  }
+
+  regionOp.setNumPositionArgsAttr(
+      logpdfBuilder.getI64IntegerAttr(numPositionArgs));
+  regionOp.setPositionSizeAttr(
+      logpdfBuilder.getI64IntegerAttr(totalPositionSize));
+  if (!supportsVec.empty()) {
+    regionOp.setSupportsAttr(ArrayAttr::get(
+        regionOp.getContext(),
+        SmallVector<Attribute>(supportsVec.begin(), supportsVec.end())));
+  }
+
+  for (auto sampleOp : allSampleOps) {
+    Region &siteLogpdf = sampleOp.getLogpdf();
+    if (siteLogpdf.empty())
+      continue;
+    Block &entry = siteLogpdf.front();
+    SmallVector<Operation *> opsToErase;
+    for (auto &op : entry)
+      opsToErase.push_back(&op);
+    for (auto *op : llvm::reverse(opsToErase)) {
+      op->dropAllUses();
+      op->erase();
+    }
+    while (entry.getNumArguments() > 0) {
+      entry.getArgument(0).dropAllUses();
+      entry.eraseArgument(0);
+    }
+    entry.erase();
+  }
+
+  return true;
+}
+
 namespace {

 static void inlineFunctionIntoRegion(OpBuilder &builder, FunctionOpInterface fn,
@@ -137,9 +537,12 @@
mcmcOp.getHmcConfigAttr(), mcmcOp.getNutsConfigAttr(), mcmcOp.getLogpdfFnAttr(), mcmcOp.getInitialPosition(), mcmcOp.getInitialGradient(), mcmcOp.getInitialPotentialEnergy(), - fnStrAttr, mcmcOp.getNameAttr()); + fnStrAttr, mcmcOp.getNameAttr(), + /*num_position_args=*/rewriter.getI64IntegerAttr(0), + /*position_size=*/rewriter.getI64IntegerAttr(0), + /*supports=*/ArrayAttr()); - Block *bodyBlock = rewriter.createBlock(&mcmcRegionOp.getBody()); + Block *bodyBlock = rewriter.createBlock(&mcmcRegionOp.getSampler()); Block &fnEntry = targetFn.getFunctionBody().front(); for (auto arg : fnEntry.getArguments()) { @@ -182,22 +572,83 @@ } }; +static func::FuncOp outlineSampleSubRegion(OpBuilder &moduleBuilder, + Region ®ion, StringRef funcName) { + assert(!region.empty() && region.hasOneBlock()); + Block &entry = region.front(); + + SmallVector<Type> argTypes(entry.getArgumentTypes()); + SmallVector<Location> argLocs; + for (auto arg : entry.getArguments()) + argLocs.push_back(arg.getLoc()); + + auto *yield = entry.getTerminator(); + SmallVector<Type> resultTypes(yield->getOperandTypes()); + + auto fnType = moduleBuilder.getFunctionType(argTypes, resultTypes); + auto func = + func::FuncOp::create(moduleBuilder, region.getLoc(), funcName, fnType); + func.setPrivate(); + + OpBuilder bodyBuilder(func.getContext()); + Block *newEntry = bodyBuilder.createBlock( + &func.getBody(), func.getBody().begin(), argTypes, argLocs); + bodyBuilder.setInsertionPointToEnd(newEntry); + + IRMapping map; + for (auto [oldArg, newArg] : + llvm::zip(entry.getArguments(), newEntry->getArguments())) + map.map(oldArg, newArg); + + for (Operation &op : entry.getOperations()) { + if (isa<enzyme::YieldOp>(&op)) { + SmallVector<Value> returnOperands; + for (Value operand : op.getOperands()) + returnOperands.push_back(map.lookupOrDefault(operand)); + func::ReturnOp::create(bodyBuilder, op.getLoc(), returnOperands); + continue; + } + bodyBuilder.clone(op, map); + } + + return func; +} + static void 
convertSampleRegionToSample(func::FuncOp outlinedFunc) { SmallVector<enzyme::SampleRegionOp> toConvert; outlinedFunc.walk( [&](enzyme::SampleRegionOp op) { toConvert.push_back(op); }); + auto *parentOp = outlinedFunc->getParentOp(); + OpBuilder moduleBuilder(&parentOp->getRegion(0)); + moduleBuilder.setInsertionPointAfter(outlinedFunc); + + unsigned counter = 0; for (auto sampleRegionOp : toConvert) { OpBuilder builder(sampleRegionOp); auto *ctx = builder.getContext(); FlatSymbolRefAttr fnAttr; - if (auto fnStrAttr = sampleRegionOp.getFnAttr()) { + Region &samplerRegion = sampleRegionOp.getSampler(); + if (!samplerRegion.empty() && samplerRegion.hasOneBlock()) { + std::string samplerName = + (Twine(outlinedFunc.getName()) + "_sampler_" + Twine(counter)).str(); + auto samplerFunc = + outlineSampleSubRegion(moduleBuilder, samplerRegion, samplerName); + fnAttr = FlatSymbolRefAttr::get(ctx, samplerFunc.getName()); + } else if (auto fnStrAttr = sampleRegionOp.getFnAttr()) { fnAttr = FlatSymbolRefAttr::get(ctx, fnStrAttr); } FlatSymbolRefAttr logpdfAttr; - if (auto logpdfStrAttr = sampleRegionOp.getLogpdfFnAttr()) { + Region &logpdfRegion = sampleRegionOp.getLogpdf(); + if (!logpdfRegion.empty() && logpdfRegion.hasOneBlock()) { + std::string logpdfName = + (Twine(outlinedFunc.getName()) + "_logpdf_" + Twine(counter)).str(); + auto logpdfFunc = + outlineSampleSubRegion(moduleBuilder, logpdfRegion, logpdfName); + logpdfAttr = FlatSymbolRefAttr::get(ctx, logpdfFunc.getName()); + } else if (auto logpdfStrAttr = sampleRegionOp.getLogpdfFnAttr()) { logpdfAttr = FlatSymbolRefAttr::get(ctx, logpdfStrAttr); } @@ -209,6 +660,7 @@ sampleRegionOp.replaceAllUsesWith(sampleOp.getResults()); sampleRegionOp.erase(); + ++counter; } } @@ -295,17 +747,232 @@ return outlinedFunc; } +static bool canOutlineLogpdf(enzyme::MCMCRegionOp regionOp) { + Region &logpdf = regionOp.getLogpdf(); + if (logpdf.empty()) + return false; + + int64_t numPosArgs = regionOp.getNumPositionArgs(); + if 
(numPosArgs == 0) + return false; + + Block &entry = logpdf.front(); + if (static_cast<int64_t>(entry.getNumArguments()) != numPosArgs) + return false; + + if (!regionOp.getOriginalTrace()) + return false; + + return true; +} + +static LogicalResult outlineLogpdfToFunction(enzyme::MCMCRegionOp regionOp, + StringRef logpdfFuncName, + OpBuilder &builder) { + Location loc = regionOp.getLoc(); + auto elemType = builder.getF64Type(); + int64_t positionSize = regionOp.getPositionSize(); + int64_t numPosArgs = regionOp.getNumPositionArgs(); + auto scalarType = RankedTensorType::get({}, elemType); + auto positionType = RankedTensorType::get({1, positionSize}, elemType); + auto i64TensorType = RankedTensorType::get({}, builder.getI64Type()); + + Region &logpdfRegion = regionOp.getLogpdf(); + Block &logpdfEntry = logpdfRegion.front(); + + llvm::SetVector<Value> logpdfFreeValues; + getUsedValuesDefinedAbove(logpdfRegion, logpdfFreeValues); + + SmallVector<Type> wrapperArgTypes; + wrapperArgTypes.push_back(positionType); + for (Value freeVal : logpdfFreeValues) + wrapperArgTypes.push_back(freeVal.getType()); + + auto wrapperFnType = builder.getFunctionType(wrapperArgTypes, {scalarType}); + auto wrapperFunc = + func::FuncOp::create(builder, loc, logpdfFuncName, wrapperFnType); + wrapperFunc.setPrivate(); + + Block *wrapperBlock = builder.createBlock( + &wrapperFunc.getBody(), wrapperFunc.getBody().begin(), wrapperArgTypes, + SmallVector<Location>(wrapperArgTypes.size(), loc)); + builder.setInsertionPointToStart(wrapperBlock); + + Value flatPosition = wrapperBlock->getArgument(0); + + IRMapping logpdfMapping; + unsigned wrapperArgIdx = 1; + for (Value freeVal : logpdfFreeValues) + logpdfMapping.map(freeVal, wrapperBlock->getArgument(wrapperArgIdx++)); + + auto c0 = arith::ConstantOp::create( + builder, loc, i64TensorType, + DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0))); + int64_t curOffset = 0; + for (unsigned i = 0; i < static_cast<unsigned>(numPosArgs); 
++i) { + Type posArgType = logpdfEntry.getArgument(i).getType(); + auto tensorType = cast<RankedTensorType>(posArgType); + int64_t numElements = tensorType.getNumElements(); + auto sliceType = RankedTensorType::get({1, numElements}, elemType); + + auto offsetConst = arith::ConstantOp::create( + builder, loc, i64TensorType, + DenseElementsAttr::get(i64TensorType, + builder.getI64IntegerAttr(curOffset))); + auto slice = enzyme::DynamicSliceOp::create( + builder, loc, sliceType, flatPosition, ValueRange{c0, offsetConst}, + builder.getDenseI64ArrayAttr({1, numElements})); + auto component = enzyme::ReshapeOp::create(builder, loc, posArgType, slice); + + logpdfMapping.map(logpdfEntry.getArgument(i), component); + curOffset += numElements; + } + + for (Operation &op : logpdfEntry.without_terminator()) + builder.clone(op, logpdfMapping); + + auto *yield = logpdfEntry.getTerminator(); + Value logpdfResult = logpdfMapping.lookupOrDefault(yield->getOperand(0)); + func::ReturnOp::create(builder, loc, {logpdfResult}); + + return success(); +} + +static Value computeInitialPositionFromTrace(enzyme::MCMCRegionOp regionOp, + OpBuilder &builder) { + Location loc = regionOp.getLoc(); + auto elemType = builder.getF64Type(); + int64_t positionSize = regionOp.getPositionSize(); + auto positionType = RankedTensorType::get({1, positionSize}, elemType); + auto i64TensorType = RankedTensorType::get({}, builder.getI64Type()); + Value trace = regionOp.getOriginalTrace(); + + DenseMap<Attribute, std::pair<int64_t, int64_t>> traceOffsets; + { + DenseMap<Attribute, enzyme::SampleRegionOp> symbolToOp; + regionOp.getSampler().walk([&](enzyme::SampleRegionOp sampleOp) { + if (auto sym = sampleOp.getSymbolAttr()) + symbolToOp[sym] = sampleOp; + }); + + int64_t offset = 0; + for (auto addr : regionOp.getAllAddressesAttr()) { + auto addressArray = cast<ArrayAttr>(addr); + if (addressArray.empty()) + continue; + auto symbol = addressArray[0]; + auto it = symbolToOp.find(symbol); + if (it == 
symbolToOp.end()) + continue; + + auto sampleOp = it->second; + int64_t size = 0; + for (unsigned i = 1; i < sampleOp.getNumResults(); ++i) { + auto resultType = sampleOp.getResult(i).getType(); + if (auto tensorType = dyn_cast<RankedTensorType>(resultType)) + size += tensorType.getNumElements(); + else + size += 1; + } + traceOffsets[symbol] = {offset, size}; + offset += size; + } + } + + auto c0 = arith::ConstantOp::create( + builder, loc, i64TensorType, + DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0))); + auto zeroPos = arith::ConstantOp::create( + builder, loc, positionType, + DenseElementsAttr::get(positionType, + builder.getFloatAttr(elemType, 0.0))); + Value result = zeroPos; + + int64_t posOffset = 0; + for (auto addr : regionOp.getSelectionAttr()) { + auto addressArray = cast<ArrayAttr>(addr); + if (addressArray.empty()) + continue; + auto symbol = addressArray[0]; + auto it = traceOffsets.find(symbol); + if (it == traceOffsets.end()) + continue; + + int64_t traceOff = it->second.first; + int64_t size = it->second.second; + auto sliceType = RankedTensorType::get({1, size}, elemType); + + auto traceOffConst = arith::ConstantOp::create( + builder, loc, i64TensorType, + DenseElementsAttr::get(i64TensorType, + builder.getI64IntegerAttr(traceOff))); + auto posOffConst = arith::ConstantOp::create( + builder, loc, i64TensorType, + DenseElementsAttr::get(i64TensorType, + builder.getI64IntegerAttr(posOffset))); + + auto traceSlice = enzyme::DynamicSliceOp::create( + builder, loc, sliceType, trace, ValueRange{c0, traceOffConst}, + builder.getDenseI64ArrayAttr({1, size})); + result = enzyme::DynamicUpdateSliceOp::create(builder, loc, positionType, + result, traceSlice, + ValueRange{c0, posOffConst}); + + posOffset += size; + } + + return result; +} + LogicalResult outlineMCMCRegion(enzyme::MCMCRegionOp regionOp, StringRef funcName, OpBuilder &builder) { OpBuilder::InsertionGuard insertionGuard(builder); + + if (canOutlineLogpdf(regionOp)) { + 
builder.setInsertionPointAfter( + regionOp->getParentOfType<SymbolOpInterface>()); + + std::string logpdfFuncName = (Twine(funcName) + "_logpdf").str(); + if (failed(outlineLogpdfToFunction(regionOp, logpdfFuncName, builder))) + return failure(); + + auto logpdfSymRef = + FlatSymbolRefAttr::get(builder.getContext(), logpdfFuncName); + + llvm::SetVector<Value> logpdfFreeValues; + getUsedValuesDefinedAbove(regionOp.getLogpdf(), logpdfFreeValues); + + builder.setInsertionPoint(regionOp); + Value initialPosition = computeInitialPositionFromTrace(regionOp, builder); + + SmallVector<Value> mcmcInputs; + mcmcInputs.push_back(regionOp.getInputs()[0]); // rng + mcmcInputs.append(logpdfFreeValues.begin(), logpdfFreeValues.end()); + + auto newOp = enzyme::MCMCOp::create( + builder, regionOp.getLoc(), regionOp.getResultTypes(), + /*fn=*/FlatSymbolRefAttr{}, mcmcInputs, + /*original_trace=*/Value(), regionOp.getSelectionAttr(), + regionOp.getAllAddressesAttr(), regionOp.getNumWarmupAttr(), + regionOp.getNumSamplesAttr(), regionOp.getThinningAttr(), + regionOp.getInverseMassMatrix(), regionOp.getStepSize(), + regionOp.getHmcConfigAttr(), regionOp.getNutsConfigAttr(), logpdfSymRef, + initialPosition, regionOp.getInitialGradient(), + regionOp.getInitialPotentialEnergy(), regionOp.getNameAttr()); + + regionOp.replaceAllUsesWith(newOp.getResults()); + regionOp.erase(); + return success(); + } + builder.setInsertionPointAfter( regionOp->getParentOfType<SymbolOpInterface>()); llvm::SetVector<Value> freeValues; - getUsedValuesDefinedAbove(regionOp.getBody(), freeValues); + getUsedValuesDefinedAbove(regionOp.getSampler(), freeValues); FailureOr<func::FuncOp> outlinedFunc = - outlineRegionToFunction(regionOp.getBody(), funcName, builder); + outlineRegionToFunction(regionOp.getSampler(), funcName, builder); if (failed(outlinedFunc)) return failure(); @@ -345,6 +1012,18 @@ GreedyRewriteConfig config; (void)applyPatternsGreedily(getOperation(), std::move(patterns), config); + + 
SmallVector<enzyme::MCMCRegionOp> regionOps; + getOperation()->walk( + [&](enzyme::MCMCRegionOp op) { regionOps.push_back(op); }); + + for (auto regionOp : regionOps) { + bool submodelChanged = true; + while (submodelChanged) + submodelChanged = inlineSubmodelSampleRegions(regionOp); + + extractUnselectedSampleValues(regionOp); + } } };
diff --git a/enzyme/Enzyme/MLIR/Passes/PrintSampleDependenceAnalysis.cpp b/enzyme/Enzyme/MLIR/Passes/PrintSampleDependenceAnalysis.cpp index 4ead38c..2407271 100644 --- a/enzyme/Enzyme/MLIR/Passes/PrintSampleDependenceAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Passes/PrintSampleDependenceAnalysis.cpp
@@ -28,7 +28,7 @@ os << "Sample regions: " << analysis.getSampleOps().size() << "\n"; - regionOp.getBody().walk([&](Operation *op) { + regionOp.getSampler().walk([&](Operation *op) { bool dependent = analysis.isSampleDependent(op); bool hoistable = analysis.canHoist(op);
diff --git a/enzyme/Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp index 42b5d99..b38f1c4 100644 --- a/enzyme/Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp
@@ -776,6 +776,8 @@ } else { auto baseCtx = makeHMCContext(adaptedInvMass, adaptedMassMatrixSqrt, stepSize); + rngInput = conditionalDump(rewriter, loc, rngInput, + "MCMC: initial rng state", debugDump); auto initState = InitHMC( rewriter, loc, rngInput, baseCtx, hasLogpdfFn ? mcmcOp.getInitialPosition() : Value(), debugDump); @@ -1093,10 +1095,29 @@ SmallVector<Value> trueYieldValues; if (adaptMassMatrix) { + conditionalDump(rewriter, loc, iterT, "WINDOW_END: iteration", + debugDump); + conditionalDump(rewriter, loc, windowIdxLoop, + "WINDOW_END: window_idx", debugDump); + conditionalDump(rewriter, loc, conditionalWelford.mean, + "WINDOW_END: welford_mean", debugDump); + conditionalDump(rewriter, loc, conditionalWelford.m2, + "WINDOW_END: welford_m2", debugDump); + conditionalDump(rewriter, loc, conditionalWelford.n, + "WINDOW_END: welford_n", debugDump); + auto newInvMass = finalizeWelford(rewriter, loc, conditionalWelford, welfordConfig); auto newMassMatrixSqrt = computeMassMatrixSqrt(rewriter, loc, newInvMass, positionType); + + conditionalDump(rewriter, loc, newInvMass, + "WINDOW_END: new_inv_mass", debugDump); + conditionalDump(rewriter, loc, newMassMatrixSqrt, + "WINDOW_END: new_mass_sqrt", debugDump); + conditionalDump(rewriter, loc, adaptedStepSizeInLoop, + "WINDOW_END: step_size", debugDump); + auto reinitWelford = initWelford(rewriter, loc, positionSize, diagonal); @@ -1314,6 +1335,13 @@ conditionalDump(rewriter, loc, finalSamplesBuffer, "MCMC: collected samples", debugDump); + auto expectedAcceptedType = + cast<RankedTensorType>(mcmcOp.getResult(1).getType()); + if (finalAcceptedBuffer.getType() != expectedAcceptedType) { + finalAcceptedBuffer = enzyme::ReshapeOp::create( + rewriter, loc, expectedAcceptedType, finalAcceptedBuffer); + } + rewriter.replaceOp(mcmcOp, {finalSamplesBuffer, finalAcceptedBuffer, finalRng, finalQ, finalGrad, finalU, adaptedStepSize, adaptedInvMass});
diff --git a/enzyme/test/MLIR/ProbProg/outline_logpdf.mlir b/enzyme/test/MLIR/ProbProg/outline_logpdf.mlir new file mode 100644 index 0000000..e8b5249 --- /dev/null +++ b/enzyme/test/MLIR/ProbProg/outline_logpdf.mlir
@@ -0,0 +1,142 @@ +// RUN: %eopt --outline-mcmc-regions %s | FileCheck %s + +// ============================================================================ +// Test 1: Basic outline of unified logpdf region (no free values) +// ============================================================================ + +// CHECK-LABEL: func.func @test_basic_outline +// Initial position extracted from trace via slice/update ops +// CHECK: enzyme.dynamic_slice +// CHECK: enzyme.dynamic_update_slice +// CHECK: enzyme.dynamic_slice +// CHECK: %[[INIT_POS:.+]] = enzyme.dynamic_update_slice +// MCMCOp in pure logpdf_fn mode (no @fn reference before the parens) +// CHECK: enzyme.mcmc( +// CHECK-SAME: logpdf_fn = @[[LOGPDF:[a-zA-Z0-9_]+]] +// CHECK-SAME: initial_position = %[[INIT_POS]] +// +// Outlined logpdf wrapper function +// CHECK: func.func private @[[LOGPDF]](%[[POS:[^:]+]]: tensor<1x2xf64>) -> tensor<f64> +// Slice position into two scalar components +// CHECK: %[[S0:.+]] = enzyme.dynamic_slice %[[POS]] +// CHECK-NEXT: %[[X0:.+]] = enzyme.reshape %[[S0]] +// CHECK: %[[S1:.+]] = enzyme.dynamic_slice %[[POS]] +// CHECK-NEXT: %[[X1:.+]] = enzyme.reshape %[[S1]] +// Logpdf body: -x0 + -x1 +// CHECK: %[[N0:.+]] = arith.negf %[[X0]] +// CHECK-NEXT: %[[N1:.+]] = arith.negf %[[X1]] +// CHECK-NEXT: %[[SUM:.+]] = arith.addf %[[N0]], %[[N1]] +// CHECK-NEXT: return %[[SUM]] + +// ============================================================================ +// Test 2: Logpdf with free values (hoisted op from enclosing scope) +// ============================================================================ + +// CHECK-LABEL: func.func @test_outline_with_free_values +// Hoisted log computed before mcmc_region +// CHECK: %[[LOG:.+]] = math.log +// CHECK: enzyme.mcmc( +// CHECK-SAME: logpdf_fn = @[[LP2:[a-zA-Z0-9_]+]] +// The logpdf function receives the hoisted log as an extra parameter +// CHECK: func.func private @[[LP2]](%[[P2:[^:]+]]: tensor<1x2xf64>, %[[FLOG:[^:]+]]: tensor<f64>) -> 
tensor<f64> +// Body uses the hoisted log value +// CHECK: arith.subf %{{.*}}, %[[FLOG]] +// CHECK: arith.subf %{{.*}}, %[[FLOG]] +// CHECK: arith.addf +// CHECK-NEXT: return + +module { + // Test 1: Basic (no free values) + func.func @test_basic_outline( + %rng : tensor<2xui64>, + %mu : tensor<f64>, + %trace : tensor<1x2xf64>) { + + %step_size = arith.constant dense<0.1> : tensor<f64> + + %result:8 = enzyme.mcmc_region(%rng, %mu) given %trace + step_size = %step_size { + ^bb0(%r: tensor<2xui64>, %m: tensor<f64>): + %x0:2 = enzyme.sample_region(%r, %m) sampler { + ^bb0(%sr: tensor<2xui64>, %sm: tensor<f64>): + enzyme.yield %sr, %sm : tensor<2xui64>, tensor<f64> + } logpdf { + } {symbol = #enzyme.symbol<0>} + : (tensor<2xui64>, tensor<f64>) -> (tensor<2xui64>, tensor<f64>) + + %x1:2 = enzyme.sample_region(%x0#0, %m) sampler { + ^bb0(%sr2: tensor<2xui64>, %sm2: tensor<f64>): + enzyme.yield %sr2, %sm2 : tensor<2xui64>, tensor<f64> + } logpdf { + } {symbol = #enzyme.symbol<1>} + : (tensor<2xui64>, tensor<f64>) -> (tensor<2xui64>, tensor<f64>) + + enzyme.yield %x1#0, %x1#1 : tensor<2xui64>, tensor<f64> + } logpdf { + ^bb0(%lp_x0: tensor<f64>, %lp_x1: tensor<f64>): + %neg0 = arith.negf %lp_x0 : tensor<f64> + %neg1 = arith.negf %lp_x1 : tensor<f64> + %total = arith.addf %neg0, %neg1 : tensor<f64> + enzyme.yield %total : tensor<f64> + } attributes { + selection = [[#enzyme.symbol<0>], [#enzyme.symbol<1>]], + all_addresses = [[#enzyme.symbol<0>], [#enzyme.symbol<1>]], + nuts_config = #enzyme.nuts_config<max_tree_depth = 10, max_delta_energy = 1000.0>, + num_samples = 1 : i64, thinning = 1 : i64, num_warmup = 0 : i64, + num_position_args = 2 : i64, position_size = 2 : i64 + } : (tensor<2xui64>, tensor<f64>, tensor<1x2xf64>, tensor<f64>) + -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>, + tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>, + tensor<f64>, tensor<1x2xf64>) + return + } + + // Test 2: Free values (hoisted math.log from enclosing scope) + func.func 
@test_outline_with_free_values( + %rng : tensor<2xui64>, + %sigma : tensor<f64>, + %trace : tensor<1x2xf64>) { + + %step_size = arith.constant dense<0.1> : tensor<f64> + // This math.log is hoisted by SICM before the mcmc_region. + // The logpdf region captures it as a free value. + %log_sigma = math.log %sigma : tensor<f64> + + %result:8 = enzyme.mcmc_region(%rng, %sigma) given %trace + step_size = %step_size { + ^bb0(%r: tensor<2xui64>, %s: tensor<f64>): + %x0:2 = enzyme.sample_region(%r, %s) sampler { + ^bb0(%sr: tensor<2xui64>, %ss: tensor<f64>): + enzyme.yield %sr, %ss : tensor<2xui64>, tensor<f64> + } logpdf { + } {symbol = #enzyme.symbol<0>} + : (tensor<2xui64>, tensor<f64>) -> (tensor<2xui64>, tensor<f64>) + + %x1:2 = enzyme.sample_region(%x0#0, %s) sampler { + ^bb0(%sr2: tensor<2xui64>, %ss2: tensor<f64>): + enzyme.yield %sr2, %ss2 : tensor<2xui64>, tensor<f64> + } logpdf { + } {symbol = #enzyme.symbol<1>} + : (tensor<2xui64>, tensor<f64>) -> (tensor<2xui64>, tensor<f64>) + + enzyme.yield %x1#0, %x1#1 : tensor<2xui64>, tensor<f64> + } logpdf { + ^bb0(%lp_x0: tensor<f64>, %lp_x1: tensor<f64>): + // Both sites subtract the hoisted log_sigma (CSE'd free value) + %sub0 = arith.subf %lp_x0, %log_sigma : tensor<f64> + %sub1 = arith.subf %lp_x1, %log_sigma : tensor<f64> + %total = arith.addf %sub0, %sub1 : tensor<f64> + enzyme.yield %total : tensor<f64> + } attributes { + selection = [[#enzyme.symbol<0>], [#enzyme.symbol<1>]], + all_addresses = [[#enzyme.symbol<0>], [#enzyme.symbol<1>]], + nuts_config = #enzyme.nuts_config<max_tree_depth = 10, max_delta_energy = 1000.0>, + num_samples = 1 : i64, thinning = 1 : i64, num_warmup = 0 : i64, + num_position_args = 2 : i64, position_size = 2 : i64 + } : (tensor<2xui64>, tensor<f64>, tensor<1x2xf64>, tensor<f64>) + -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>, + tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>, + tensor<f64>, tensor<1x2xf64>) + return + } +}
diff --git a/enzyme/test/MLIR/ProbProg/outline_logpdf_lower.mlir b/enzyme/test/MLIR/ProbProg/outline_logpdf_lower.mlir new file mode 100644 index 0000000..4a67043 --- /dev/null +++ b/enzyme/test/MLIR/ProbProg/outline_logpdf_lower.mlir
@@ -0,0 +1,168 @@ +// RUN: %eopt --outline-mcmc-regions --probprog %s | FileCheck %s + +// ============================================================================ +// Test: Full pipeline for outlined unified logpdf +// mcmc_region (with unified logpdf) → outline → probprog lowering +// ============================================================================ +// +// Verifies that the outlined logpdf wrapper function is correctly called +// during NUTS lowering: initial U0 computation, gradient via autodiff_region, +// and leapfrog integration inside the tree-building loop. + +// CHECK-LABEL: func.func @test_outline_lower_nuts +// Position extracted from trace → outline builds initial_position +// CHECK: enzyme.dynamic_slice +// CHECK: enzyme.dynamic_update_slice +// CHECK: enzyme.dynamic_slice +// CHECK: %[[INIT_POS:.+]] = enzyme.dynamic_update_slice +// Initial U0: call logpdf, negf +// CHECK: call @[[LOGPDF:[a-zA-Z0-9_]+]](%[[INIT_POS]]) +// CHECK-NEXT: %[[U0:.+]] = arith.negf +// Gradient via autodiff_region +// CHECK: enzyme.autodiff_region +// CHECK: func.call @[[LOGPDF]] +// CHECK-NEXT: %[[NEG:.+]] = arith.negf +// CHECK-NEXT: enzyme.yield +// NUTS sample loop +// CHECK: enzyme.for_loop +// Gradient inside tree building +// CHECK: enzyme.autodiff_region +// CHECK: func.call @[[LOGPDF]] +// CHECK-NEXT: %{{.+}} = arith.negf +// CHECK-NEXT: enzyme.yield + +// The outlined logpdf wrapper function +// CHECK: func.func private @[[LOGPDF]](%[[POS:[^:]+]]: tensor<1x2xf64>) -> tensor<f64> +// Slice position into per-site scalars +// CHECK: %[[S0:.+]] = enzyme.dynamic_slice %[[POS]] +// CHECK-NEXT: %[[X0:.+]] = enzyme.reshape %[[S0]] +// CHECK: %[[S1:.+]] = enzyme.dynamic_slice %[[POS]] +// CHECK-NEXT: %[[X1:.+]] = enzyme.reshape %[[S1]] +// Logpdf body: -x0 + -x1 +// CHECK: %[[N0:.+]] = arith.negf %[[X0]] +// CHECK-NEXT: %[[N1:.+]] = arith.negf %[[X1]] +// CHECK-NEXT: %[[SUM:.+]] = arith.addf %[[N0]], %[[N1]] +// CHECK-NEXT: return %[[SUM]] + +module { + 
func.func @test_outline_lower_nuts( + %rng : tensor<2xui64>, + %mu : tensor<f64>, + %trace : tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) { + + %step_size = arith.constant dense<0.1> : tensor<f64> + + %result:8 = enzyme.mcmc_region(%rng, %mu) given %trace + step_size = %step_size { + ^bb0(%r: tensor<2xui64>, %m: tensor<f64>): + %x0:2 = enzyme.sample_region(%r, %m) sampler { + ^bb0(%sr: tensor<2xui64>, %sm: tensor<f64>): + enzyme.yield %sr, %sm : tensor<2xui64>, tensor<f64> + } logpdf { + } {symbol = #enzyme.symbol<0>} + : (tensor<2xui64>, tensor<f64>) -> (tensor<2xui64>, tensor<f64>) + + %x1:2 = enzyme.sample_region(%x0#0, %m) sampler { + ^bb0(%sr2: tensor<2xui64>, %sm2: tensor<f64>): + enzyme.yield %sr2, %sm2 : tensor<2xui64>, tensor<f64> + } logpdf { + } {symbol = #enzyme.symbol<1>} + : (tensor<2xui64>, tensor<f64>) -> (tensor<2xui64>, tensor<f64>) + + enzyme.yield %x1#0, %x1#1 : tensor<2xui64>, tensor<f64> + } logpdf { + ^bb0(%lp_x0: tensor<f64>, %lp_x1: tensor<f64>): + %neg0 = arith.negf %lp_x0 : tensor<f64> + %neg1 = arith.negf %lp_x1 : tensor<f64> + %total = arith.addf %neg0, %neg1 : tensor<f64> + enzyme.yield %total : tensor<f64> + } attributes { + selection = [[#enzyme.symbol<0>], [#enzyme.symbol<1>]], + all_addresses = [[#enzyme.symbol<0>], [#enzyme.symbol<1>]], + nuts_config = #enzyme.nuts_config<max_tree_depth = 10, max_delta_energy = 1000.0>, + num_samples = 1 : i64, thinning = 1 : i64, num_warmup = 0 : i64, + num_position_args = 2 : i64, position_size = 2 : i64 + } : (tensor<2xui64>, tensor<f64>, tensor<1x2xf64>, tensor<f64>) + -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>, + tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>, + tensor<f64>, tensor<1x2xf64>) + return %result#0, %result#1, %result#2 : tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64> + } + + // ========================================================================== + // Test 2: Logpdf with free values (hoisted math.log from enclosing scope) + // The logpdf wrapper 
receives the hoisted value as an extra parameter. + // Lowering must pass it through as an MCMCOp input. + // ========================================================================== + + // CHECK-LABEL: func.func @test_outline_lower_free_values + // Hoisted log before mcmc + // CHECK: %[[LOG:.+]] = math.log + // Initial U0: call logpdf with position AND free value + // CHECK: call @[[LP2:[a-zA-Z0-9_]+]](%{{.+}}, %[[LOG]]) + // CHECK-NEXT: arith.negf + // Gradient via autodiff_region + // CHECK: enzyme.autodiff_region + // CHECK: func.call @[[LP2]](%{{.+}}, %[[LOG]]) + // CHECK-NEXT: arith.negf + // CHECK-NEXT: enzyme.yield + // NUTS sample loop + // CHECK: enzyme.for_loop + // Gradient inside tree building + // CHECK: enzyme.autodiff_region + // CHECK: func.call @[[LP2]](%{{.+}}, %[[LOG]]) + // CHECK-NEXT: arith.negf + // CHECK-NEXT: enzyme.yield + // + // The logpdf wrapper with free value parameter + // CHECK: func.func private @[[LP2]](%[[P2:[^:]+]]: tensor<1x2xf64>, %[[FLOG:[^:]+]]: tensor<f64>) -> tensor<f64> + // CHECK: arith.subf %{{.*}}, %[[FLOG]] + // CHECK: arith.subf %{{.*}}, %[[FLOG]] + // CHECK: arith.addf + // CHECK-NEXT: return + + func.func @test_outline_lower_free_values( + %rng : tensor<2xui64>, + %sigma : tensor<f64>, + %trace : tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) { + + %step_size = arith.constant dense<0.1> : tensor<f64> + %log_sigma = math.log %sigma : tensor<f64> + + %result:8 = enzyme.mcmc_region(%rng, %sigma) given %trace + step_size = %step_size { + ^bb0(%r: tensor<2xui64>, %s: tensor<f64>): + %x0:2 = enzyme.sample_region(%r, %s) sampler { + ^bb0(%sr: tensor<2xui64>, %ss: tensor<f64>): + enzyme.yield %sr, %ss : tensor<2xui64>, tensor<f64> + } logpdf { + } {symbol = #enzyme.symbol<0>} + : (tensor<2xui64>, tensor<f64>) -> (tensor<2xui64>, tensor<f64>) + + %x1:2 = enzyme.sample_region(%x0#0, %s) sampler { + ^bb0(%sr2: tensor<2xui64>, %ss2: tensor<f64>): + enzyme.yield %sr2, %ss2 : tensor<2xui64>, tensor<f64> + 
} logpdf { + } {symbol = #enzyme.symbol<1>} + : (tensor<2xui64>, tensor<f64>) -> (tensor<2xui64>, tensor<f64>) + + enzyme.yield %x1#0, %x1#1 : tensor<2xui64>, tensor<f64> + } logpdf { + ^bb0(%lp_x0: tensor<f64>, %lp_x1: tensor<f64>): + %sub0 = arith.subf %lp_x0, %log_sigma : tensor<f64> + %sub1 = arith.subf %lp_x1, %log_sigma : tensor<f64> + %total = arith.addf %sub0, %sub1 : tensor<f64> + enzyme.yield %total : tensor<f64> + } attributes { + selection = [[#enzyme.symbol<0>], [#enzyme.symbol<1>]], + all_addresses = [[#enzyme.symbol<0>], [#enzyme.symbol<1>]], + nuts_config = #enzyme.nuts_config<max_tree_depth = 10, max_delta_energy = 1000.0>, + num_samples = 1 : i64, thinning = 1 : i64, num_warmup = 0 : i64, + num_position_args = 2 : i64, position_size = 2 : i64 + } : (tensor<2xui64>, tensor<f64>, tensor<1x2xf64>, tensor<f64>) + -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>, + tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>, + tensor<f64>, tensor<1x2xf64>) + return %result#0, %result#1, %result#2 : tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64> + } +}
diff --git a/enzyme/test/MLIR/ProbProg/sample_dependence_analysis.mlir b/enzyme/test/MLIR/ProbProg/sample_dependence_analysis.mlir index 4ed8775..d9b56bc 100644 --- a/enzyme/test/MLIR/ProbProg/sample_dependence_analysis.mlir +++ b/enzyme/test/MLIR/ProbProg/sample_dependence_analysis.mlir
@@ -14,9 +14,8 @@ // CHECK: [INV] [HOIST] arith.addf -> tensor<f64> // CHECK: [INV] [KEEP] enzyme.yield -// Operations inside logpdf region (nested in sample_region) -// CHECK: [INV] [HOIST] arith.negf -> tensor<f64> -// CHECK: [INV] [KEEP] enzyme.yield +// Per-site logpdf bodies are merged into MCMCRegionOp's unified logpdf +// region by inline-mcmc-regions Phase 2, so arith.negf no longer appears here. // The sample_region itself - sample-dependent // CHECK: [DEP] [KEEP] enzyme.sample_region -> tensor<2xui64>, tensor<f64>