save
diff --git a/enzyme/Enzyme/MLIR/Analysis/SampleDependenceAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/SampleDependenceAnalysis.cpp
index 7931d8e..cde8dc3 100644
--- a/enzyme/Enzyme/MLIR/Analysis/SampleDependenceAnalysis.cpp
+++ b/enzyme/Enzyme/MLIR/Analysis/SampleDependenceAnalysis.cpp
@@ -9,34 +9,34 @@
 using namespace mlir::enzyme;

 

 SampleDependenceAnalysis::SampleDependenceAnalysis(MCMCRegionOp regionOp)

-    : regionOp(regionOp) {

-  runAnalysis();

+    : regionOp(regionOp), target(AnalysisTarget::Sampler) {

+  runSamplerAnalysis();

+}

+

+SampleDependenceAnalysis::SampleDependenceAnalysis(MCMCRegionOp regionOp,

+                                                   AnalysisTarget target)

+    : regionOp(regionOp), target(target) {

+  if (target == AnalysisTarget::Logpdf)

+    runLogpdfAnalysis();

+  else

+    runSamplerAnalysis();

+}

+

+Region &SampleDependenceAnalysis::getTargetRegion() {

+  if (target == AnalysisTarget::Logpdf)

+    return regionOp.getLogpdf();

+  return regionOp.getSampler();

 }

 

 void SampleDependenceAnalysis::markSampleDependent(Value value) {

   sampleDependentValues.insert(value);

 }

 

-void SampleDependenceAnalysis::runAnalysis() {

-  // TODO: Handle recursive regions.

-  if (regionOp.getLogpdfFnAttr()) {

-    Block &entry = regionOp.getBody().front();

-    if (!entry.getArguments().empty()) {

-      markSampleDependent(entry.getArgument(0));

-    }

-  }

-

-  regionOp.getBody().walk([&](SampleRegionOp sampleOp) {

-    sampleOps.push_back(sampleOp);

-    for (Value result : sampleOp.getResults()) {

-      markSampleDependent(result);

-    }

-  });

-

+void SampleDependenceAnalysis::propagateDependence(Region &region) {

   bool changed = true;

   while (changed) {

     changed = false;

-    regionOp.getBody().walk([&](Operation *op) {

+    region.walk([&](Operation *op) {

       if (isa<SampleRegionOp>(op))

         return;

 

@@ -60,6 +60,56 @@
   }

 }

 

+void SampleDependenceAnalysis::runSamplerAnalysis() {

+  // TODO: Handle recursive regions.

+  if (regionOp.getLogpdfFnAttr()) {

+    Block &entry = regionOp.getSampler().front();

+    if (!entry.getArguments().empty()) {

+      markSampleDependent(entry.getArgument(0));

+    }

+  }

+

+  DenseSet<Attribute> selectedSymbols;

+  bool hasSelection = false;

+  if (auto selection = regionOp.getSelectionAttr()) {

+    hasSelection = true;

+    for (auto addr : selection) {

+      auto address = cast<ArrayAttr>(addr);

+      if (!address.empty())

+        selectedSymbols.insert(address[0]);

+    }

+  }

+

+  regionOp.getSampler().walk([&](SampleRegionOp sampleOp) {

+    sampleOps.push_back(sampleOp);

+    auto symbol = sampleOp.getSymbolAttr();

+    bool isSelected =

+        !hasSelection || !symbol || selectedSymbols.contains(symbol);

+    if (isSelected) {

+      for (Value result : sampleOp.getResults()) {

+        markSampleDependent(result);

+      }

+    }

+  });

+

+  propagateDependence(regionOp.getSampler());

+}

+

+void SampleDependenceAnalysis::runLogpdfAnalysis() {

+  Region &logpdf = regionOp.getLogpdf();

+  if (logpdf.empty())

+    return;

+

+  Block &entry = logpdf.front();

+  int64_t numPositionArgs = regionOp.getNumPositionArgs();

+

+  for (int64_t i = 0;

+       i < numPositionArgs && i < (int64_t)entry.getNumArguments(); ++i)

+    markSampleDependent(entry.getArgument(i));

+

+  propagateDependence(logpdf);

+}

+

 bool SampleDependenceAnalysis::isSampleDependent(Value value) const {

   return sampleDependentValues.contains(value);

 }

@@ -72,6 +122,15 @@
   return false;

 }

 

+bool SampleDependenceAnalysis::isInTargetRegion(Operation *op) {

+  Region *targetRegion;

+  if (target == AnalysisTarget::Logpdf)

+    targetRegion = &regionOp.getLogpdf();

+  else

+    targetRegion = &regionOp.getSampler();

+  return targetRegion->isAncestor(op->getParentRegion());

+}

+

 bool SampleDependenceAnalysis::canHoist(Operation *op) const {

   if (isa<SampleRegionOp>(op))

     return false;

@@ -134,26 +193,26 @@
   return false;

 }

 

-bool enzyme::hoistSampleInvariantOps(MCMCRegionOp regionOp) {

-  DominanceInfo dom(regionOp);

-  PostDominanceInfo pdom(regionOp);

-  SampleDependenceAnalysis sampleAnalysis(regionOp);

-

-  Region &region = regionOp.getBody();

+static bool hoistFromRegion(MCMCRegionOp regionOp,

+                            SampleDependenceAnalysis &sampleAnalysis,

+                            Region &region) {

   if (region.empty())

     return false;

 

+  DominanceInfo dom(regionOp);

+  PostDominanceInfo pdom(regionOp);

+

   IRMapping regionToOuter;

   Block &entryBlock = region.front();

-  auto inputs = regionOp.getInputs();

 

-  bool isLogpdfMode = static_cast<bool>(regionOp.getLogpdfFnAttr());

-

-  for (auto [idx, blockArg] : llvm::enumerate(entryBlock.getArguments())) {

-    if (isLogpdfMode && idx == 0)

-      continue;

-    if (idx < inputs.size()) {

-      regionToOuter.map(blockArg, inputs[idx]);

+  if (sampleAnalysis.getTarget() == AnalysisTarget::Sampler) {

+    auto inputs = regionOp.getInputs();

+    bool isLogpdfMode = static_cast<bool>(regionOp.getLogpdfFnAttr());

+    for (auto [idx, blockArg] : llvm::enumerate(entryBlock.getArguments())) {

+      if (isLogpdfMode && idx == 0)

+        continue;

+      if (idx < inputs.size())

+        regionToOuter.map(blockArg, inputs[idx]);

     }

   }

 

@@ -225,3 +284,13 @@
 

   return !sortedToHoist.empty();

 }

+

+bool enzyme::hoistSampleInvariantOps(MCMCRegionOp regionOp) {

+  return hoistSampleInvariantOps(regionOp, AnalysisTarget::Sampler);

+}

+

+bool enzyme::hoistSampleInvariantOps(MCMCRegionOp regionOp,

+                                     AnalysisTarget target) {

+  SampleDependenceAnalysis analysis(regionOp, target);

+  return hoistFromRegion(regionOp, analysis, analysis.getTargetRegion());

+}

diff --git a/enzyme/Enzyme/MLIR/Analysis/SampleDependenceAnalysis.h b/enzyme/Enzyme/MLIR/Analysis/SampleDependenceAnalysis.h
index 1ab05b5..3ef22d0 100644
--- a/enzyme/Enzyme/MLIR/Analysis/SampleDependenceAnalysis.h
+++ b/enzyme/Enzyme/MLIR/Analysis/SampleDependenceAnalysis.h
@@ -10,10 +10,17 @@
 namespace mlir {

 namespace enzyme {

 

+enum class AnalysisTarget {

+  Sampler,

+  Logpdf,

+};

+

 class SampleDependenceAnalysis {

 public:

   explicit SampleDependenceAnalysis(MCMCRegionOp regionOp);

 

+  SampleDependenceAnalysis(MCMCRegionOp regionOp, AnalysisTarget target);

+

   bool isSampleDependent(Value value) const;

   bool isSampleDependent(Operation *op) const;

   bool canHoist(Operation *op) const;

@@ -21,18 +28,30 @@
   ArrayRef<SampleRegionOp> getSampleOps() const { return sampleOps; }

 

   MCMCRegionOp getRegionOp() const { return regionOp; }

+  AnalysisTarget getTarget() const { return target; }

+

+  bool isInTargetRegion(Operation *op);

+

+  Region &getTargetRegion();

 

 private:

   MCMCRegionOp regionOp;

+  AnalysisTarget target;

   DenseSet<Value> sampleDependentValues;

   SmallVector<SampleRegionOp> sampleOps;

 

-  void runAnalysis();

+  void runSamplerAnalysis();

+  void runLogpdfAnalysis();

   void markSampleDependent(Value value);

+  void propagateDependence(Region &region);

 };

 

 bool hoistSampleInvariantOps(MCMCRegionOp regionOp);

 

+bool hoistSampleInvariantOps(MCMCRegionOp regionOp, AnalysisTarget target);

+

+bool constructUnifiedLogpdf(MCMCRegionOp regionOp);

+

 } // namespace enzyme

 } // namespace mlir

 

diff --git a/enzyme/Enzyme/MLIR/Dialect/Dialect.cpp b/enzyme/Enzyme/MLIR/Dialect/Dialect.cpp
index 5df615d..a791273 100644
--- a/enzyme/Enzyme/MLIR/Dialect/Dialect.cpp
+++ b/enzyme/Enzyme/MLIR/Dialect/Dialect.cpp
@@ -71,3 +71,31 @@
 
 #define GET_ATTRDEF_CLASSES
 #include "Dialect/EnzymeAttributes.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// SymbolAttr custom print/parse
+//===----------------------------------------------------------------------===//
+
+void SymbolAttr::print(AsmPrinter &printer) const {
+  printer << "<";
+  llvm::interleaveComma(getPath(), printer);
+  printer << ">";
+}
+
+Attribute SymbolAttr::parse(AsmParser &parser, Type type) {
+  if (parser.parseLess())
+    return {};
+  SmallVector<uint64_t> path;
+  uint64_t val;
+  if (failed(parser.parseInteger(val)))
+    return {};
+  path.push_back(val);
+  while (succeeded(parser.parseOptionalComma())) {
+    if (failed(parser.parseInteger(val)))
+      return {};
+    path.push_back(val);
+  }
+  if (parser.parseGreater())
+    return {};
+  return get(parser.getContext(), path);
+}
diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeEnums.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeEnums.td
index fbfaa78..7d9c866 100644
--- a/enzyme/Enzyme/MLIR/Dialect/EnzymeEnums.td
+++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeEnums.td
@@ -88,6 +88,8 @@
   let assemblyFormat = "`<` $kind (`,` `lower` `=` $lower_bound^)? (`,` `upper` `=` $upper_bound^)? `>`";
 }
 
+def SupportArrayAttr : TypedArrayAttrBase<SupportAttr, "Array of support specifications">;
+
 def HMCConfigAttr : Enzyme_Attr<"HMCConfig", "hmc_config"> {
   let summary = "Configuration for HMC inference";
   let description = [{
diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
index 6ddbf05..f449bc5 100644
--- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
+++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
@@ -447,10 +447,12 @@
 def SymbolAttr : Enzyme_Attr<"Symbol", "symbol"> {
   let summary = "Symbol associated with a Sample op";
   let description = [{
-  Symbol associated with a Sample op.
+  Symbol associated with a Sample op. Holds one or more uint64_t values.
+  A single value is a leaf symbol; multiple values encode a composite path
+  created by SICM submodel inlining.
   }];
-  let parameters = (ins "uint64_t":$ptr);
-  let assemblyFormat = "`<` $ptr `>`";
+  let parameters = (ins ArrayRefParameter<"uint64_t", "symbol path">:$path);
+  let hasCustomAssemblyFormat = 1;
 }
 
 def AddressAttr : TypedArrayAttrBase<SymbolAttr, "Address as an array of symbols">;
@@ -790,11 +792,25 @@
 }
 
 def MCMCRegionOp : Enzyme_Op<"mcmc_region", [AutomaticAllocationScope, AttrSizedOperandSegments]> {
-  let summary = "MCMC inference with inlined model body";
+  let summary = "Inference region with sampler and unified logpdf";
   let description = [{
-    A region-based MCMC operation where the model function body is inlined
-    into a region. This is an internal construct for optimization passes to see the full
-    computation before expansion to HMC/NUTS loops.
+    A region-based inference operation with two regions:
+
+    - `sampler`: The model computation containing sample_region ops with sampler
+      regions (logpdf regions cleared after merging). Used by proposal-based
+      algorithms (MH, Gibbs) and preserved for outline compatibility.
+
+    - `logpdf`: Unified logpdf computation for ALL position-dependent terms.
+      Block arguments follow a calling convention:
+        args[0..num_position_args-1] = position values (sample-dependent,
+          ordered by `selection` attribute)
+        Captured values from enclosing scope = parameters (sample-invariant)
+
+      SICM operates on this region. The outline pass converts it to a
+      standalone function referenced by `logpdf_fn` on the resulting MCMCOp.
+
+    If the `logpdf` region is empty, the operation falls back to the sampler-only
+    mode (pre-logpdf-merging state or custom logpdf_fn).
   }];
 
   let arguments = (ins
@@ -821,10 +837,14 @@
     Optional<AnyRankedTensor>:$initial_potential_energy,
 
     OptionalAttr<StrAttr>:$fn,
-    DefaultValuedStrAttr<StrAttr, "">:$name
+    DefaultValuedStrAttr<StrAttr, "">:$name,
+
+    DefaultValuedAttr<I64Attr, "0">:$num_position_args,
+    DefaultValuedAttr<I64Attr, "0">:$position_size,
+    OptionalAttr<SupportArrayAttr>:$supports
   );
 
-  let regions = (region AnyRegion:$body);
+  let regions = (region AnyRegion:$sampler, AnyRegion:$logpdf);
 
   let results = (outs
     AnyRankedTensor:$new_trace,
@@ -844,7 +864,8 @@
     (`initial_position` `=` $initial_position^)?
     (`initial_gradient` `=` $initial_gradient^)?
     (`initial_potential_energy` `=` $initial_potential_energy^)?
-    $body attr-dict-with-keyword `:` functional-type(operands, results)
+    $sampler (`logpdf` $logpdf^)?
+    attr-dict-with-keyword `:` functional-type(operands, results)
   }];
 }
 
@@ -1014,6 +1035,7 @@
   }];
 }
 
+
 def AffineAtomicRMWOp : Enzyme_Op<"affine_atomic_rmw"> {
   let summary = "affine atomic rmw operation";
   let description = [{
diff --git a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp
index 5a35244..6043fdc 100644
--- a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp
+++ b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp
@@ -90,5 +90,6 @@
 }
 
 MlirAttribute enzymeSymbolAttrGet(MlirContext ctx, uint64_t ptr) {
-  return wrap(mlir::enzyme::SymbolAttr::get(unwrap(ctx), ptr));
+  llvm::SmallVector<uint64_t, 1> path = {ptr};
+  return wrap(mlir::enzyme::SymbolAttr::get(unwrap(ctx), path));
 }
diff --git a/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp
index f7cbec2..ea70ca7 100644
--- a/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp
+++ b/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp
@@ -679,11 +679,22 @@
 
   auto initSplit = enzyme::RandomSplitOp::create(
       builder, loc, TypeRange{rng.getType(), rng.getType()}, rng);
+  conditionalDump(builder, loc, initSplit.getResult(0),
+                  "InitHMC: 2-way split[0] (rng_mcmc)", debugDump);
+  conditionalDump(builder, loc, initSplit.getResult(1),
+                  "InitHMC: 2-way split[1]", debugDump);
   auto kernelSplit = enzyme::RandomSplitOp::create(
       builder, loc, TypeRange{rng.getType(), rng.getType(), rng.getType()},
       initSplit.getResult(0));
   auto rngForSampleKernel = kernelSplit.getResult(0);
   auto rngForAutodiff = kernelSplit.getResult(1);
+  conditionalDump(builder, loc, rngForSampleKernel,
+                  "InitHMC: 3-way split[0] (rngForSampleKernel)", debugDump);
+  conditionalDump(builder, loc, rngForAutodiff,
+                  "InitHMC: 3-way split[1] (rngForAutodiff)", debugDump);
+  if (kernelSplit.getNumResults() > 2)
+    conditionalDump(builder, loc, kernelSplit.getResult(2),
+                    "InitHMC: 3-way split[2] (unused)", debugDump);
 
   Value q0;
   Value U0;
@@ -804,6 +815,11 @@
   // (U, rng, grad)
   auto grad0 = autodiffInit.getResult(2);
 
+  conditionalDump(builder, loc, q0, "InitHMC: q0 (unconstrained)", debugDump);
+  conditionalDump(builder, loc, U0, "InitHMC: U0 (potential energy)",
+                  debugDump);
+  conditionalDump(builder, loc, grad0, "InitHMC: grad0 (gradient)", debugDump);
+
   return {q0, U0, grad0, rngForSampleKernel};
 }
 
@@ -953,18 +969,30 @@
   auto i1TensorType = RankedTensorType::get({}, builder.getI1Type());
 
   // 1. Split RNG: [rngNext, rngMomentum, rngTree]
+  conditionalDump(builder, loc, rng,
+                  "SampleNUTS: input rng (before 3-way split)", debugDump);
   auto sampleKernelSplit = enzyme::RandomSplitOp::create(
       builder, loc, TypeRange{rng.getType(), rng.getType(), rng.getType()},
       rng);
   auto rngNext = sampleKernelSplit.getResult(0);
   auto rngMomentum = sampleKernelSplit.getResult(1);
   auto rngTree = sampleKernelSplit.getResult(2);
+  conditionalDump(builder, loc, rngNext, "SampleNUTS: 3-way split[0] (rngNext)",
+                  debugDump);
+  conditionalDump(builder, loc, rngMomentum,
+                  "SampleNUTS: 3-way split[1] (rngMomentum)", debugDump);
+  conditionalDump(builder, loc, rngTree, "SampleNUTS: 3-way split[2] (rngTree)",
+                  debugDump);
 
   // 2. Sample fresh momentum p ~ N(0, M)
   Value rngForMomentum = rngMomentum;
   if (!ctx.hasCustomLogpdf()) {
     auto momSplit = enzyme::RandomSplitOp::create(
         builder, loc, TypeRange{rng.getType(), rng.getType()}, rngMomentum);
+    conditionalDump(builder, loc, momSplit.getResult(0),
+                    "SampleNUTS: momSplit[0] (rngForMomentum)", debugDump);
+    conditionalDump(builder, loc, momSplit.getResult(1),
+                    "SampleNUTS: momSplit[1]", debugDump);
     rngForMomentum = momSplit.getResult(0);
   }
   auto [p0, rngAfterMomentum] =
@@ -977,6 +1005,11 @@
   // 4. Compute H0 = U + K0
   auto H0 = arith::AddFOp::create(builder, loc, U, K0);
 
+  conditionalDump(builder, loc, p0, "SampleNUTS: p0 (momentum)", debugDump);
+  conditionalDump(builder, loc, K0, "SampleNUTS: K0 (kinetic energy)",
+                  debugDump);
+  conditionalDump(builder, loc, H0, "SampleNUTS: H0 (hamiltonian)", debugDump);
+
   // 5. Initialize NUTS tree state
   auto iterCtx = ctx.withH0(H0);
 
diff --git a/enzyme/Enzyme/MLIR/Passes/InlineMCMCRegions.cpp b/enzyme/Enzyme/MLIR/Passes/InlineMCMCRegions.cpp
index 24d3890..5074b4d 100644
--- a/enzyme/Enzyme/MLIR/Passes/InlineMCMCRegions.cpp
+++ b/enzyme/Enzyme/MLIR/Passes/InlineMCMCRegions.cpp
@@ -1,6 +1,8 @@
+#include "Analysis/SampleDependenceAnalysis.h"

 #include "Dialect/Ops.h"

 #include "Passes/Passes.h"

 

+#include "mlir/Dialect/Arith/IR/Arith.h"

 #include "mlir/Dialect/Func/IR/FuncOps.h"

 #include "mlir/Interfaces/FunctionInterfaces.h"

 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

@@ -16,6 +18,391 @@
 } // namespace enzyme

 } // namespace mlir

 

+static enzyme::SymbolAttr composeSymbols(enzyme::SymbolAttr outer,

+                                         enzyme::SymbolAttr inner,

+                                         MLIRContext *ctx) {

+  SmallVector<uint64_t> composed(outer.getPath());

+  composed.append(inner.getPath().begin(), inner.getPath().end());

+  return enzyme::SymbolAttr::get(ctx, composed);

+}

+

+static ArrayAttr flattenAddressesForSymbol(ArrayAttr addresses,

+                                           enzyme::SymbolAttr outerSymbol,

+                                           MLIRContext *ctx) {

+  SmallVector<Attribute> newAddresses;

+  for (auto addr : addresses) {

+    auto address = cast<ArrayAttr>(addr);

+    if (address.size() >= 2 && address[0] == outerSymbol) {

+      auto inner = cast<enzyme::SymbolAttr>(address[1]);

+      auto composite = composeSymbols(outerSymbol, inner, ctx);

+      SmallVector<Attribute> newAddr;

+      newAddr.push_back(composite);

+      for (unsigned i = 2; i < address.size(); ++i)

+        newAddr.push_back(address[i]);

+      newAddresses.push_back(ArrayAttr::get(ctx, newAddr));

+    } else {

+      newAddresses.push_back(addr);

+    }

+  }

+  return ArrayAttr::get(ctx, newAddresses);

+}

+

+static bool inlineSubmodelSampleRegions(enzyme::MCMCRegionOp regionOp) {

+  bool anyChanged = false;

+

+  SmallVector<enzyme::SampleRegionOp> sampleOps;

+  regionOp.getSampler().walk(

+      [&](enzyme::SampleRegionOp op) { sampleOps.push_back(op); });

+

+  for (enzyme::SampleRegionOp sampleOp : sampleOps) {

+    Region &logpdf = sampleOp.getLogpdf();

+    if (!logpdf.empty())

+      continue;

+    if (sampleOp.getLogpdfFnAttr())

+      continue;

+

+    Region &sampler = sampleOp.getSampler();

+    if (sampler.empty() || !sampler.hasOneBlock())

+      continue;

+

+    auto outerSymbol = sampleOp.getSymbolAttr();

+    if (!outerSymbol)

+      continue;

+

+    Block &samplerEntry = sampler.front();

+    auto *ctx = regionOp.getContext();

+

+    OpBuilder builder(sampleOp);

+    IRMapping mapper;

+    auto inputs = sampleOp.getInputs();

+    for (unsigned i = 0, e = samplerEntry.getNumArguments(); i < e; ++i) {

+      if (i < inputs.size())

+        mapper.map(samplerEntry.getArgument(i), inputs[i]);

+    }

+

+    for (Operation &op : samplerEntry.without_terminator()) {

+      Operation *cloned = builder.clone(op, mapper);

+      if (auto innerSample = dyn_cast<enzyme::SampleRegionOp>(cloned)) {

+        if (auto innerSymbol = innerSample.getSymbolAttr())

+          innerSample.setSymbolAttr(

+              composeSymbols(outerSymbol, innerSymbol, ctx));

+      } else if (auto innerSampleOp = dyn_cast<enzyme::SampleOp>(cloned)) {

+        if (auto innerSymbol = innerSampleOp.getSymbolAttr())

+          innerSampleOp.setSymbolAttr(

+              composeSymbols(outerSymbol, innerSymbol, ctx));

+      }

+    }

+

+    auto *yield = samplerEntry.getTerminator();

+    for (auto [oldResult, yieldOperand] :

+         llvm::zip(sampleOp.getResults(), yield->getOperands()))

+      oldResult.replaceAllUsesWith(mapper.lookupOrDefault(yieldOperand));

+

+    sampleOp.erase();

+

+    if (auto allAddrs = regionOp.getAllAddressesAttr())

+      regionOp.setAllAddressesAttr(

+          flattenAddressesForSymbol(allAddrs, outerSymbol, ctx));

+    if (auto sel = regionOp.getSelectionAttr())

+      regionOp.setSelectionAttr(

+          flattenAddressesForSymbol(sel, outerSymbol, ctx));

+

+    anyChanged = true;

+  }

+

+  return anyChanged;

+}

+

+static bool extractUnselectedSampleValues(enzyme::MCMCRegionOp regionOp) {

+  auto selection = regionOp.getSelectionAttr();

+  if (!selection)

+    return false;

+

+  DenseSet<Attribute> selectedSymbols;

+  for (auto addr : selection) {

+    auto address = cast<ArrayAttr>(addr);

+    if (!address.empty())

+      selectedSymbols.insert(address[0]);

+  }

+

+  Block &entry = regionOp.getSampler().front();

+  auto inputs = regionOp.getInputs();

+

+  IRMapping blockToOuter;

+  for (auto [idx, blockArg] : llvm::enumerate(entry.getArguments())) {

+    if (idx < inputs.size())

+      blockToOuter.map(blockArg, inputs[idx]);

+  }

+

+  SmallVector<enzyme::SampleRegionOp> toExtract;

+  for (auto &op : entry.without_terminator()) {

+    auto sampleOp = dyn_cast<enzyme::SampleRegionOp>(&op);

+    if (!sampleOp)

+      continue;

+    auto symbol = sampleOp.getSymbolAttr();

+    if (!symbol || selectedSymbols.contains(symbol))

+      continue;

+

+    bool canExtract = true;

+    for (Value operand : sampleOp->getOperands()) {

+      if (!blockToOuter.contains(operand) &&

+          !operand.getDefiningOp<arith::ConstantOp>()) {

+        canExtract = false;

+        break;

+      }

+    }

+    if (!canExtract)

+      continue;

+

+    SetVector<Value> nestedValues;

+    for (Region &nestedRegion : sampleOp->getRegions())

+      getUsedValuesDefinedAbove(nestedRegion, nestedValues);

+    for (Value v : nestedValues) {

+      if (!blockToOuter.contains(v) && !v.getDefiningOp<arith::ConstantOp>()) {

+        canExtract = false;

+        break;

+      }

+    }

+    if (!canExtract)

+      continue;

+

+    toExtract.push_back(sampleOp);

+  }

+

+  if (toExtract.empty())

+    return false;

+

+  unsigned numInputs = inputs.size();

+  bool anyChanged = false;

+

+  for (enzyme::SampleRegionOp sampleOp : toExtract) {

+    OpBuilder builder(regionOp);

+    IRMapping cloneMapper(blockToOuter);

+    Operation *cloned = builder.clone(*sampleOp, cloneMapper);

+

+    for (auto [original, clonedResult] :

+         llvm::zip(sampleOp->getResults(), cloned->getResults())) {

+      regionOp->insertOperands(numInputs, {clonedResult});

+      auto segSizes =

+          regionOp->getAttrOfType<DenseI32ArrayAttr>("operandSegmentSizes");

+      SmallVector<int32_t> newSizes(segSizes.asArrayRef());

+      newSizes[0]++;

+      regionOp->setAttr("operandSegmentSizes",

+                        builder.getDenseI32ArrayAttr(newSizes));

+      numInputs++;

+

+      Value newBlockArg =

+          entry.addArgument(original.getType(), sampleOp.getLoc());

+      original.replaceAllUsesWith(newBlockArg);

+      blockToOuter.map(newBlockArg, clonedResult);

+    }

+

+    sampleOp.erase();

+    anyChanged = true;

+  }

+

+  return anyChanged;

+}

+

+static Value resolveValueForLogpdf(OpBuilder &builder, Location loc,

+                                   Value value, IRMapping &mapping,

+                                   enzyme::MCMCRegionOp regionOp) {

+  if (mapping.contains(value))

+    return mapping.lookup(value);

+

+  if (auto blockArg = dyn_cast<BlockArgument>(value)) {

+    if (blockArg.getOwner() == &regionOp.getSampler().front()) {

+      unsigned idx = blockArg.getArgNumber();

+      auto inputs = regionOp.getInputs();

+      if (idx < inputs.size()) {

+        mapping.map(value, inputs[idx]);

+        return inputs[idx];

+      }

+    }

+    return value;

+  }

+

+  Operation *defOp = value.getDefiningOp();

+  if (!defOp)

+    return value;

+

+  if (defOp->getParentRegion() != &regionOp.getSampler())

+    return value;

+

+  if (isa<enzyme::SampleRegionOp>(defOp)) {

+    Block *logpdfBlock = &regionOp.getLogpdf().front();

+    Value newArg = logpdfBlock->addArgument(value.getType(), defOp->getLoc());

+    mapping.map(value, newArg);

+    return newArg;

+  }

+

+  for (Value operand : defOp->getOperands())

+    resolveValueForLogpdf(builder, loc, operand, mapping, regionOp);

+

+  Operation *cloned = builder.clone(*defOp, mapping);

+  for (auto [orig, clonedRes] :

+       llvm::zip(defOp->getResults(), cloned->getResults()))

+    mapping.map(orig, clonedRes);

+

+  return mapping.lookup(value);

+}

+

+bool enzyme::constructUnifiedLogpdf(enzyme::MCMCRegionOp regionOp) {

+  Region &samplerRegion = regionOp.getSampler();

+  Region &logpdfRegion = regionOp.getLogpdf();

+  auto selection = regionOp.getSelectionAttr();

+  if (!selection || selection.empty())

+    return false;

+

+  SmallVector<enzyme::SampleRegionOp> allSampleOps;

+  samplerRegion.walk(

+      [&](enzyme::SampleRegionOp op) { allSampleOps.push_back(op); });

+

+  DenseMap<Attribute, enzyme::SampleRegionOp> symbolToSampleOp;

+  for (auto sampleOp : allSampleOps) {

+    if (auto sym = sampleOp.getSymbolAttr())

+      symbolToSampleOp[sym] = sampleOp;

+  }

+

+  DenseSet<Attribute> selectedSymbols;

+  SmallVector<Attribute> selectionOrder;

+  for (auto addr : selection) {

+    auto address = cast<ArrayAttr>(addr);

+    if (!address.empty()) {

+      selectedSymbols.insert(address[0]);

+      selectionOrder.push_back(address[0]);

+    }

+  }

+

+  Block *logpdfBlock = new Block();

+  logpdfRegion.push_back(logpdfBlock);

+

+  Location loc = regionOp.getLoc();

+  IRMapping positionMapping;

+  SmallVector<enzyme::SupportAttr> supportsVec;

+  int64_t totalPositionSize = 0;

+

+  for (auto symbol : selectionOrder) {

+    auto it = symbolToSampleOp.find(symbol);

+    if (it == symbolToSampleOp.end())

+      continue;

+    enzyme::SampleRegionOp sampleOp = it->second;

+

+    for (unsigned i = 1; i < sampleOp.getNumResults(); ++i) {

+      Value sampleResult = sampleOp.getResult(i);

+      auto resultType = sampleResult.getType();

+      Value blockArg = logpdfBlock->addArgument(resultType, loc);

+      positionMapping.map(sampleResult, blockArg);

+

+      if (auto tensorType = dyn_cast<RankedTensorType>(resultType))

+        totalPositionSize += tensorType.getNumElements();

+      else

+        totalPositionSize += 1;

+    }

+

+    if (auto support = sampleOp.getSupportAttr())

+      supportsVec.push_back(support);

+    else

+      supportsVec.push_back(enzyme::SupportAttr::get(

+          regionOp.getContext(), enzyme::SupportKind::REAL, nullptr, nullptr));

+  }

+

+  int64_t numPositionArgs = logpdfBlock->getNumArguments();

+

+  OpBuilder logpdfBuilder(logpdfBlock, logpdfBlock->end());

+  Value totalLogpdf;

+  auto scalarF64 = RankedTensorType::get({}, logpdfBuilder.getF64Type());

+

+  for (auto sampleOp : allSampleOps) {

+    Region &siteLogpdf = sampleOp.getLogpdf();

+    if (siteLogpdf.empty() || !siteLogpdf.hasOneBlock())

+      continue;

+

+    Block &siteEntry = siteLogpdf.front();

+    if (siteEntry.getNumArguments() == 0)

+      continue;

+

+    unsigned numSampleOutputs = sampleOp.getNumResults() - 1;

+    auto inputs = sampleOp.getInputs();

+

+    IRMapping siteMapping(positionMapping);

+

+    for (unsigned i = 0; i < numSampleOutputs; ++i) {

+      if (i >= siteEntry.getNumArguments())

+        break;

+      Value sampleResult = sampleOp.getResult(i + 1);

+      Value resolved = resolveValueForLogpdf(logpdfBuilder, loc, sampleResult,

+                                             siteMapping, regionOp);

+      siteMapping.map(siteEntry.getArgument(i), resolved);

+    }

+

+    for (unsigned i = numSampleOutputs, e = siteEntry.getNumArguments(); i < e;

+         ++i) {

+      unsigned inputIdx = i - numSampleOutputs + 1;

+      if (inputIdx < inputs.size()) {

+        Value contextVal = inputs[inputIdx];

+        Value resolved = resolveValueForLogpdf(logpdfBuilder, loc, contextVal,

+                                               siteMapping, regionOp);

+        siteMapping.map(siteEntry.getArgument(i), resolved);

+      }

+    }

+

+    for (Operation &op : siteEntry.without_terminator())

+      logpdfBuilder.clone(op, siteMapping);

+

+    auto *yield = siteEntry.getTerminator();

+    assert(yield->getNumOperands() > 0);

+    Value siteResult = siteMapping.lookupOrDefault(yield->getOperand(0));

+

+    if (!totalLogpdf)

+      totalLogpdf = siteResult;

+    else

+      totalLogpdf =

+          arith::AddFOp::create(logpdfBuilder, loc, totalLogpdf, siteResult);

+  }

+

+  if (totalLogpdf) {

+    enzyme::YieldOp::create(logpdfBuilder, loc, {totalLogpdf});

+  } else {

+    auto zeroConst = arith::ConstantOp::create(

+        logpdfBuilder, loc, scalarF64,

+        DenseElementsAttr::get(scalarF64, logpdfBuilder.getF64FloatAttr(0.0)));

+    enzyme::YieldOp::create(logpdfBuilder, loc, {zeroConst});

+  }

+

+  logpdfBuilder.getContext();

+  regionOp.setNumPositionArgsAttr(

+      logpdfBuilder.getI64IntegerAttr(numPositionArgs));

+  regionOp.setPositionSizeAttr(

+      logpdfBuilder.getI64IntegerAttr(totalPositionSize));

+  if (!supportsVec.empty()) {

+    regionOp.setSupportsAttr(ArrayAttr::get(

+        regionOp.getContext(),

+        SmallVector<Attribute>(supportsVec.begin(), supportsVec.end())));

+  }

+

+  for (auto sampleOp : allSampleOps) {

+    Region &siteLogpdf = sampleOp.getLogpdf();

+    if (siteLogpdf.empty())

+      continue;

+    Block &entry = siteLogpdf.front();

+    SmallVector<Operation *> opsToErase;

+    for (auto &op : entry)

+      opsToErase.push_back(&op);

+    for (auto *op : llvm::reverse(opsToErase)) {

+      op->dropAllUses();

+      op->erase();

+    }

+    while (entry.getNumArguments() > 0) {

+      entry.getArgument(0).dropAllUses();

+      entry.eraseArgument(0);

+    }

+    entry.erase();

+  }

+

+  return true;

+}

+

 namespace {

 

 static void inlineFunctionIntoRegion(OpBuilder &builder, FunctionOpInterface fn,

@@ -137,9 +524,12 @@
         mcmcOp.getHmcConfigAttr(), mcmcOp.getNutsConfigAttr(),

         mcmcOp.getLogpdfFnAttr(), mcmcOp.getInitialPosition(),

         mcmcOp.getInitialGradient(), mcmcOp.getInitialPotentialEnergy(),

-        fnStrAttr, mcmcOp.getNameAttr());

+        fnStrAttr, mcmcOp.getNameAttr(),

+        /*num_position_args=*/rewriter.getI64IntegerAttr(0),

+        /*position_size=*/rewriter.getI64IntegerAttr(0),

+        /*supports=*/ArrayAttr());

 

-    Block *bodyBlock = rewriter.createBlock(&mcmcRegionOp.getBody());

+    Block *bodyBlock = rewriter.createBlock(&mcmcRegionOp.getSampler());

 

     Block &fnEntry = targetFn.getFunctionBody().front();

     for (auto arg : fnEntry.getArguments()) {

@@ -182,22 +572,83 @@
   }

 };

 

+static func::FuncOp outlineSampleSubRegion(OpBuilder &moduleBuilder,

+                                           Region &region, StringRef funcName) {

+  assert(!region.empty() && region.hasOneBlock());

+  Block &entry = region.front();

+

+  SmallVector<Type> argTypes(entry.getArgumentTypes());

+  SmallVector<Location> argLocs;

+  for (auto arg : entry.getArguments())

+    argLocs.push_back(arg.getLoc());

+

+  auto *yield = entry.getTerminator();

+  SmallVector<Type> resultTypes(yield->getOperandTypes());

+

+  auto fnType = moduleBuilder.getFunctionType(argTypes, resultTypes);

+  auto func =

+      func::FuncOp::create(moduleBuilder, region.getLoc(), funcName, fnType);

+  func.setPrivate();

+

+  OpBuilder bodyBuilder(func.getContext());

+  Block *newEntry = bodyBuilder.createBlock(

+      &func.getBody(), func.getBody().begin(), argTypes, argLocs);

+  bodyBuilder.setInsertionPointToEnd(newEntry);

+

+  IRMapping map;

+  for (auto [oldArg, newArg] :

+       llvm::zip(entry.getArguments(), newEntry->getArguments()))

+    map.map(oldArg, newArg);

+

+  for (Operation &op : entry.getOperations()) {

+    if (isa<enzyme::YieldOp>(&op)) {

+      SmallVector<Value> returnOperands;

+      for (Value operand : op.getOperands())

+        returnOperands.push_back(map.lookupOrDefault(operand));

+      func::ReturnOp::create(bodyBuilder, op.getLoc(), returnOperands);

+      continue;

+    }

+    bodyBuilder.clone(op, map);

+  }

+

+  return func;

+}

+

 static void convertSampleRegionToSample(func::FuncOp outlinedFunc) {

   SmallVector<enzyme::SampleRegionOp> toConvert;

   outlinedFunc.walk(

       [&](enzyme::SampleRegionOp op) { toConvert.push_back(op); });

 

+  auto *parentOp = outlinedFunc->getParentOp();

+  OpBuilder moduleBuilder(&parentOp->getRegion(0));

+  moduleBuilder.setInsertionPointAfter(outlinedFunc);

+

+  unsigned counter = 0;

   for (auto sampleRegionOp : toConvert) {

     OpBuilder builder(sampleRegionOp);

     auto *ctx = builder.getContext();

 

     FlatSymbolRefAttr fnAttr;

-    if (auto fnStrAttr = sampleRegionOp.getFnAttr()) {

+    Region &samplerRegion = sampleRegionOp.getSampler();

+    if (!samplerRegion.empty() && samplerRegion.hasOneBlock()) {

+      std::string samplerName =

+          (Twine(outlinedFunc.getName()) + "_sampler_" + Twine(counter)).str();

+      auto samplerFunc =

+          outlineSampleSubRegion(moduleBuilder, samplerRegion, samplerName);

+      fnAttr = FlatSymbolRefAttr::get(ctx, samplerFunc.getName());

+    } else if (auto fnStrAttr = sampleRegionOp.getFnAttr()) {

       fnAttr = FlatSymbolRefAttr::get(ctx, fnStrAttr);

     }

 

     FlatSymbolRefAttr logpdfAttr;

-    if (auto logpdfStrAttr = sampleRegionOp.getLogpdfFnAttr()) {

+    Region &logpdfRegion = sampleRegionOp.getLogpdf();

+    if (!logpdfRegion.empty() && logpdfRegion.hasOneBlock()) {

+      std::string logpdfName =

+          (Twine(outlinedFunc.getName()) + "_logpdf_" + Twine(counter)).str();

+      auto logpdfFunc =

+          outlineSampleSubRegion(moduleBuilder, logpdfRegion, logpdfName);

+      logpdfAttr = FlatSymbolRefAttr::get(ctx, logpdfFunc.getName());

+    } else if (auto logpdfStrAttr = sampleRegionOp.getLogpdfFnAttr()) {

       logpdfAttr = FlatSymbolRefAttr::get(ctx, logpdfStrAttr);

     }

 

@@ -209,6 +660,7 @@
 

     sampleRegionOp.replaceAllUsesWith(sampleOp.getResults());

     sampleRegionOp.erase();

+    ++counter;

   }

 }

 

@@ -295,17 +747,232 @@
   return outlinedFunc;

 }

 

+static bool canOutlineLogpdf(enzyme::MCMCRegionOp regionOp) {

+  Region &logpdf = regionOp.getLogpdf();

+  if (logpdf.empty())

+    return false;

+

+  int64_t numPosArgs = regionOp.getNumPositionArgs();

+  if (numPosArgs == 0)

+    return false;

+

+  Block &entry = logpdf.front();

+  if (static_cast<int64_t>(entry.getNumArguments()) != numPosArgs)

+    return false;

+

+  if (!regionOp.getOriginalTrace())

+    return false;

+

+  return true;

+}

+

+static LogicalResult outlineLogpdfToFunction(enzyme::MCMCRegionOp regionOp,

+                                             StringRef logpdfFuncName,

+                                             OpBuilder &builder) {

+  Location loc = regionOp.getLoc();

+  auto elemType = builder.getF64Type();

+  int64_t positionSize = regionOp.getPositionSize();

+  int64_t numPosArgs = regionOp.getNumPositionArgs();

+  auto scalarType = RankedTensorType::get({}, elemType);

+  auto positionType = RankedTensorType::get({1, positionSize}, elemType);

+  auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());

+

+  Region &logpdfRegion = regionOp.getLogpdf();

+  Block &logpdfEntry = logpdfRegion.front();

+

+  llvm::SetVector<Value> logpdfFreeValues;

+  getUsedValuesDefinedAbove(logpdfRegion, logpdfFreeValues);

+

+  SmallVector<Type> wrapperArgTypes;

+  wrapperArgTypes.push_back(positionType);

+  for (Value freeVal : logpdfFreeValues)

+    wrapperArgTypes.push_back(freeVal.getType());

+

+  auto wrapperFnType = builder.getFunctionType(wrapperArgTypes, {scalarType});

+  auto wrapperFunc =

+      func::FuncOp::create(builder, loc, logpdfFuncName, wrapperFnType);

+  wrapperFunc.setPrivate();

+

+  Block *wrapperBlock = builder.createBlock(

+      &wrapperFunc.getBody(), wrapperFunc.getBody().begin(), wrapperArgTypes,

+      SmallVector<Location>(wrapperArgTypes.size(), loc));

+  builder.setInsertionPointToStart(wrapperBlock);

+

+  Value flatPosition = wrapperBlock->getArgument(0);

+

+  IRMapping logpdfMapping;

+  unsigned wrapperArgIdx = 1;

+  for (Value freeVal : logpdfFreeValues)

+    logpdfMapping.map(freeVal, wrapperBlock->getArgument(wrapperArgIdx++));

+

+  auto c0 = arith::ConstantOp::create(

+      builder, loc, i64TensorType,

+      DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));

+  int64_t curOffset = 0;

+  for (unsigned i = 0; i < static_cast<unsigned>(numPosArgs); ++i) {

+    Type posArgType = logpdfEntry.getArgument(i).getType();

+    auto tensorType = cast<RankedTensorType>(posArgType);

+    int64_t numElements = tensorType.getNumElements();

+    auto sliceType = RankedTensorType::get({1, numElements}, elemType);

+

+    auto offsetConst = arith::ConstantOp::create(

+        builder, loc, i64TensorType,

+        DenseElementsAttr::get(i64TensorType,

+                               builder.getI64IntegerAttr(curOffset)));

+    auto slice = enzyme::DynamicSliceOp::create(

+        builder, loc, sliceType, flatPosition, ValueRange{c0, offsetConst},

+        builder.getDenseI64ArrayAttr({1, numElements}));

+    auto component = enzyme::ReshapeOp::create(builder, loc, posArgType, slice);

+

+    logpdfMapping.map(logpdfEntry.getArgument(i), component);

+    curOffset += numElements;

+  }

+

+  for (Operation &op : logpdfEntry.without_terminator())

+    builder.clone(op, logpdfMapping);

+

+  auto *yield = logpdfEntry.getTerminator();

+  Value logpdfResult = logpdfMapping.lookupOrDefault(yield->getOperand(0));

+  func::ReturnOp::create(builder, loc, {logpdfResult});

+

+  return success();

+}

+

+static Value computeInitialPositionFromTrace(enzyme::MCMCRegionOp regionOp,

+                                             OpBuilder &builder) {

+  Location loc = regionOp.getLoc();

+  auto elemType = builder.getF64Type();

+  int64_t positionSize = regionOp.getPositionSize();

+  auto positionType = RankedTensorType::get({1, positionSize}, elemType);

+  auto i64TensorType = RankedTensorType::get({}, builder.getI64Type());

+  Value trace = regionOp.getOriginalTrace();

+

+  DenseMap<Attribute, std::pair<int64_t, int64_t>> traceOffsets;

+  {

+    DenseMap<Attribute, enzyme::SampleRegionOp> symbolToOp;

+    regionOp.getSampler().walk([&](enzyme::SampleRegionOp sampleOp) {

+      if (auto sym = sampleOp.getSymbolAttr())

+        symbolToOp[sym] = sampleOp;

+    });

+

+    int64_t offset = 0;

+    for (auto addr : regionOp.getAllAddressesAttr()) {

+      auto addressArray = cast<ArrayAttr>(addr);

+      if (addressArray.empty())

+        continue;

+      auto symbol = addressArray[0];

+      auto it = symbolToOp.find(symbol);

+      if (it == symbolToOp.end())

+        continue;

+

+      auto sampleOp = it->second;

+      int64_t size = 0;

+      for (unsigned i = 1; i < sampleOp.getNumResults(); ++i) {

+        auto resultType = sampleOp.getResult(i).getType();

+        if (auto tensorType = dyn_cast<RankedTensorType>(resultType))

+          size += tensorType.getNumElements();

+        else

+          size += 1;

+      }

+      traceOffsets[symbol] = {offset, size};

+      offset += size;

+    }

+  }

+

+  auto c0 = arith::ConstantOp::create(

+      builder, loc, i64TensorType,

+      DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0)));

+  auto zeroPos = arith::ConstantOp::create(

+      builder, loc, positionType,

+      DenseElementsAttr::get(positionType,

+                             builder.getFloatAttr(elemType, 0.0)));

+  Value result = zeroPos;

+

+  int64_t posOffset = 0;

+  for (auto addr : regionOp.getSelectionAttr()) {

+    auto addressArray = cast<ArrayAttr>(addr);

+    if (addressArray.empty())

+      continue;

+    auto symbol = addressArray[0];

+    auto it = traceOffsets.find(symbol);

+    if (it == traceOffsets.end())

+      continue;

+

+    int64_t traceOff = it->second.first;

+    int64_t size = it->second.second;

+    auto sliceType = RankedTensorType::get({1, size}, elemType);

+

+    auto traceOffConst = arith::ConstantOp::create(

+        builder, loc, i64TensorType,

+        DenseElementsAttr::get(i64TensorType,

+                               builder.getI64IntegerAttr(traceOff)));

+    auto posOffConst = arith::ConstantOp::create(

+        builder, loc, i64TensorType,

+        DenseElementsAttr::get(i64TensorType,

+                               builder.getI64IntegerAttr(posOffset)));

+

+    auto traceSlice = enzyme::DynamicSliceOp::create(

+        builder, loc, sliceType, trace, ValueRange{c0, traceOffConst},

+        builder.getDenseI64ArrayAttr({1, size}));

+    result = enzyme::DynamicUpdateSliceOp::create(builder, loc, positionType,

+                                                  result, traceSlice,

+                                                  ValueRange{c0, posOffConst});

+

+    posOffset += size;

+  }

+

+  return result;

+}

+

 LogicalResult outlineMCMCRegion(enzyme::MCMCRegionOp regionOp,

                                 StringRef funcName, OpBuilder &builder) {

   OpBuilder::InsertionGuard insertionGuard(builder);

+

+  if (canOutlineLogpdf(regionOp)) {

+    builder.setInsertionPointAfter(

+        regionOp->getParentOfType<SymbolOpInterface>());

+

+    std::string logpdfFuncName = (Twine(funcName) + "_logpdf").str();

+    if (failed(outlineLogpdfToFunction(regionOp, logpdfFuncName, builder)))

+      return failure();

+

+    auto logpdfSymRef =

+        FlatSymbolRefAttr::get(builder.getContext(), logpdfFuncName);

+

+    llvm::SetVector<Value> logpdfFreeValues;

+    getUsedValuesDefinedAbove(regionOp.getLogpdf(), logpdfFreeValues);

+

+    builder.setInsertionPoint(regionOp);

+    Value initialPosition = computeInitialPositionFromTrace(regionOp, builder);

+

+    SmallVector<Value> mcmcInputs;

+    mcmcInputs.push_back(regionOp.getInputs()[0]); // rng

+    mcmcInputs.append(logpdfFreeValues.begin(), logpdfFreeValues.end());

+

+    auto newOp = enzyme::MCMCOp::create(

+        builder, regionOp.getLoc(), regionOp.getResultTypes(),

+        /*fn=*/FlatSymbolRefAttr{}, mcmcInputs,

+        /*original_trace=*/Value(), regionOp.getSelectionAttr(),

+        regionOp.getAllAddressesAttr(), regionOp.getNumWarmupAttr(),

+        regionOp.getNumSamplesAttr(), regionOp.getThinningAttr(),

+        regionOp.getInverseMassMatrix(), regionOp.getStepSize(),

+        regionOp.getHmcConfigAttr(), regionOp.getNutsConfigAttr(), logpdfSymRef,

+        initialPosition, regionOp.getInitialGradient(),

+        regionOp.getInitialPotentialEnergy(), regionOp.getNameAttr());

+

+    regionOp.replaceAllUsesWith(newOp.getResults());

+    regionOp.erase();

+    return success();

+  }

+

   builder.setInsertionPointAfter(

       regionOp->getParentOfType<SymbolOpInterface>());

 

   llvm::SetVector<Value> freeValues;

-  getUsedValuesDefinedAbove(regionOp.getBody(), freeValues);

+  getUsedValuesDefinedAbove(regionOp.getSampler(), freeValues);

 

   FailureOr<func::FuncOp> outlinedFunc =

-      outlineRegionToFunction(regionOp.getBody(), funcName, builder);

+      outlineRegionToFunction(regionOp.getSampler(), funcName, builder);

   if (failed(outlinedFunc))

     return failure();

 

@@ -345,6 +1012,18 @@
 

     GreedyRewriteConfig config;

     (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);

+

+    SmallVector<enzyme::MCMCRegionOp> regionOps;

+    getOperation()->walk(

+        [&](enzyme::MCMCRegionOp op) { regionOps.push_back(op); });

+

+    for (auto regionOp : regionOps) {

+      bool submodelChanged = true;

+      while (submodelChanged)

+        submodelChanged = inlineSubmodelSampleRegions(regionOp);

+

+      extractUnselectedSampleValues(regionOp);

+    }

   }

 };

 

diff --git a/enzyme/Enzyme/MLIR/Passes/PrintSampleDependenceAnalysis.cpp b/enzyme/Enzyme/MLIR/Passes/PrintSampleDependenceAnalysis.cpp
index 4ead38c..2407271 100644
--- a/enzyme/Enzyme/MLIR/Passes/PrintSampleDependenceAnalysis.cpp
+++ b/enzyme/Enzyme/MLIR/Passes/PrintSampleDependenceAnalysis.cpp
@@ -28,7 +28,7 @@
 

       os << "Sample regions: " << analysis.getSampleOps().size() << "\n";

 

-      regionOp.getBody().walk([&](Operation *op) {

+      regionOp.getSampler().walk([&](Operation *op) {

         bool dependent = analysis.isSampleDependent(op);

         bool hoistable = analysis.canHoist(op);

 

diff --git a/enzyme/Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp
index 42b5d99..b38f1c4 100644
--- a/enzyme/Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp
+++ b/enzyme/Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp
@@ -776,6 +776,8 @@
       } else {
         auto baseCtx =
             makeHMCContext(adaptedInvMass, adaptedMassMatrixSqrt, stepSize);
+        rngInput = conditionalDump(rewriter, loc, rngInput,
+                                   "MCMC: initial rng state", debugDump);
         auto initState = InitHMC(
             rewriter, loc, rngInput, baseCtx,
             hasLogpdfFn ? mcmcOp.getInitialPosition() : Value(), debugDump);
@@ -1093,10 +1095,29 @@
           SmallVector<Value> trueYieldValues;
 
           if (adaptMassMatrix) {
+            conditionalDump(rewriter, loc, iterT, "WINDOW_END: iteration",
+                            debugDump);
+            conditionalDump(rewriter, loc, windowIdxLoop,
+                            "WINDOW_END: window_idx", debugDump);
+            conditionalDump(rewriter, loc, conditionalWelford.mean,
+                            "WINDOW_END: welford_mean", debugDump);
+            conditionalDump(rewriter, loc, conditionalWelford.m2,
+                            "WINDOW_END: welford_m2", debugDump);
+            conditionalDump(rewriter, loc, conditionalWelford.n,
+                            "WINDOW_END: welford_n", debugDump);
+
             auto newInvMass = finalizeWelford(rewriter, loc, conditionalWelford,
                                               welfordConfig);
             auto newMassMatrixSqrt =
                 computeMassMatrixSqrt(rewriter, loc, newInvMass, positionType);
+
+            conditionalDump(rewriter, loc, newInvMass,
+                            "WINDOW_END: new_inv_mass", debugDump);
+            conditionalDump(rewriter, loc, newMassMatrixSqrt,
+                            "WINDOW_END: new_mass_sqrt", debugDump);
+            conditionalDump(rewriter, loc, adaptedStepSizeInLoop,
+                            "WINDOW_END: step_size", debugDump);
+
             auto reinitWelford =
                 initWelford(rewriter, loc, positionSize, diagonal);
 
@@ -1314,6 +1335,13 @@
           conditionalDump(rewriter, loc, finalSamplesBuffer,
                           "MCMC: collected samples", debugDump);
 
+      auto expectedAcceptedType =
+          cast<RankedTensorType>(mcmcOp.getResult(1).getType());
+      if (finalAcceptedBuffer.getType() != expectedAcceptedType) {
+        finalAcceptedBuffer = enzyme::ReshapeOp::create(
+            rewriter, loc, expectedAcceptedType, finalAcceptedBuffer);
+      }
+
       rewriter.replaceOp(mcmcOp, {finalSamplesBuffer, finalAcceptedBuffer,
                                   finalRng, finalQ, finalGrad, finalU,
                                   adaptedStepSize, adaptedInvMass});
diff --git a/enzyme/test/MLIR/ProbProg/outline_logpdf.mlir b/enzyme/test/MLIR/ProbProg/outline_logpdf.mlir
new file mode 100644
index 0000000..e8b5249
--- /dev/null
+++ b/enzyme/test/MLIR/ProbProg/outline_logpdf.mlir
@@ -0,0 +1,142 @@
+// RUN: %eopt --outline-mcmc-regions %s | FileCheck %s
+
+// ============================================================================
+// Test 1: Basic outline of unified logpdf region (no free values)
+// ============================================================================
+
+// CHECK-LABEL: func.func @test_basic_outline
+// Initial position extracted from trace via slice/update ops
+// CHECK: enzyme.dynamic_slice
+// CHECK: enzyme.dynamic_update_slice
+// CHECK: enzyme.dynamic_slice
+// CHECK: %[[INIT_POS:.+]] = enzyme.dynamic_update_slice
+// MCMCOp in pure logpdf_fn mode (no @fn reference before the parens)
+// CHECK: enzyme.mcmc(
+// CHECK-SAME: logpdf_fn = @[[LOGPDF:[a-zA-Z0-9_]+]]
+// CHECK-SAME: initial_position = %[[INIT_POS]]
+//
+// Outlined logpdf wrapper function
+// CHECK: func.func private @[[LOGPDF]](%[[POS:[^:]+]]: tensor<1x2xf64>) -> tensor<f64>
+// Slice position into two scalar components
+// CHECK: %[[S0:.+]] = enzyme.dynamic_slice %[[POS]]
+// CHECK-NEXT: %[[X0:.+]] = enzyme.reshape %[[S0]]
+// CHECK: %[[S1:.+]] = enzyme.dynamic_slice %[[POS]]
+// CHECK-NEXT: %[[X1:.+]] = enzyme.reshape %[[S1]]
+// Logpdf body: -x0 + -x1
+// CHECK: %[[N0:.+]] = arith.negf %[[X0]]
+// CHECK-NEXT: %[[N1:.+]] = arith.negf %[[X1]]
+// CHECK-NEXT: %[[SUM:.+]] = arith.addf %[[N0]], %[[N1]]
+// CHECK-NEXT: return %[[SUM]]
+
+// ============================================================================
+// Test 2: Logpdf with free values (hoisted op from enclosing scope)
+// ============================================================================
+
+// CHECK-LABEL: func.func @test_outline_with_free_values
+// Hoisted log computed before mcmc_region
+// CHECK: %[[LOG:.+]] = math.log
+// CHECK: enzyme.mcmc(
+// CHECK-SAME: logpdf_fn = @[[LP2:[a-zA-Z0-9_]+]]
+// The logpdf function receives the hoisted log as an extra parameter
+// CHECK: func.func private @[[LP2]](%[[P2:[^:]+]]: tensor<1x2xf64>, %[[FLOG:[^:]+]]: tensor<f64>) -> tensor<f64>
+// Body uses the hoisted log value
+// CHECK: arith.subf %{{.*}}, %[[FLOG]]
+// CHECK: arith.subf %{{.*}}, %[[FLOG]]
+// CHECK: arith.addf
+// CHECK-NEXT: return
+
+module {
+  // Test 1: Basic (no free values)
+  func.func @test_basic_outline(
+      %rng : tensor<2xui64>,
+      %mu : tensor<f64>,
+      %trace : tensor<1x2xf64>) {
+
+    %step_size = arith.constant dense<0.1> : tensor<f64>
+
+    %result:8 = enzyme.mcmc_region(%rng, %mu) given %trace
+        step_size = %step_size {
+    ^bb0(%r: tensor<2xui64>, %m: tensor<f64>):
+      %x0:2 = enzyme.sample_region(%r, %m) sampler {
+      ^bb0(%sr: tensor<2xui64>, %sm: tensor<f64>):
+        enzyme.yield %sr, %sm : tensor<2xui64>, tensor<f64>
+      } logpdf {
+      } {symbol = #enzyme.symbol<0>}
+        : (tensor<2xui64>, tensor<f64>) -> (tensor<2xui64>, tensor<f64>)
+
+      %x1:2 = enzyme.sample_region(%x0#0, %m) sampler {
+      ^bb0(%sr2: tensor<2xui64>, %sm2: tensor<f64>):
+        enzyme.yield %sr2, %sm2 : tensor<2xui64>, tensor<f64>
+      } logpdf {
+      } {symbol = #enzyme.symbol<1>}
+        : (tensor<2xui64>, tensor<f64>) -> (tensor<2xui64>, tensor<f64>)
+
+      enzyme.yield %x1#0, %x1#1 : tensor<2xui64>, tensor<f64>
+    } logpdf {
+    ^bb0(%lp_x0: tensor<f64>, %lp_x1: tensor<f64>):
+      %neg0 = arith.negf %lp_x0 : tensor<f64>
+      %neg1 = arith.negf %lp_x1 : tensor<f64>
+      %total = arith.addf %neg0, %neg1 : tensor<f64>
+      enzyme.yield %total : tensor<f64>
+    } attributes {
+      selection = [[#enzyme.symbol<0>], [#enzyme.symbol<1>]],
+      all_addresses = [[#enzyme.symbol<0>], [#enzyme.symbol<1>]],
+      nuts_config = #enzyme.nuts_config<max_tree_depth = 10, max_delta_energy = 1000.0>,
+      num_samples = 1 : i64, thinning = 1 : i64, num_warmup = 0 : i64,
+      num_position_args = 2 : i64, position_size = 2 : i64
+    } : (tensor<2xui64>, tensor<f64>, tensor<1x2xf64>, tensor<f64>)
+        -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>,
+            tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>,
+            tensor<f64>, tensor<1x2xf64>)
+    return
+  }
+
+  // Test 2: Free values (hoisted math.log from enclosing scope)
+  func.func @test_outline_with_free_values(
+      %rng : tensor<2xui64>,
+      %sigma : tensor<f64>,
+      %trace : tensor<1x2xf64>) {
+
+    %step_size = arith.constant dense<0.1> : tensor<f64>
+    // This math.log is hoisted by SICM before the mcmc_region.
+    // The logpdf region captures it as a free value.
+    %log_sigma = math.log %sigma : tensor<f64>
+
+    %result:8 = enzyme.mcmc_region(%rng, %sigma) given %trace
+        step_size = %step_size {
+    ^bb0(%r: tensor<2xui64>, %s: tensor<f64>):
+      %x0:2 = enzyme.sample_region(%r, %s) sampler {
+      ^bb0(%sr: tensor<2xui64>, %ss: tensor<f64>):
+        enzyme.yield %sr, %ss : tensor<2xui64>, tensor<f64>
+      } logpdf {
+      } {symbol = #enzyme.symbol<0>}
+        : (tensor<2xui64>, tensor<f64>) -> (tensor<2xui64>, tensor<f64>)
+
+      %x1:2 = enzyme.sample_region(%x0#0, %s) sampler {
+      ^bb0(%sr2: tensor<2xui64>, %ss2: tensor<f64>):
+        enzyme.yield %sr2, %ss2 : tensor<2xui64>, tensor<f64>
+      } logpdf {
+      } {symbol = #enzyme.symbol<1>}
+        : (tensor<2xui64>, tensor<f64>) -> (tensor<2xui64>, tensor<f64>)
+
+      enzyme.yield %x1#0, %x1#1 : tensor<2xui64>, tensor<f64>
+    } logpdf {
+    ^bb0(%lp_x0: tensor<f64>, %lp_x1: tensor<f64>):
+      // Both sites subtract the hoisted log_sigma (CSE'd free value)
+      %sub0 = arith.subf %lp_x0, %log_sigma : tensor<f64>
+      %sub1 = arith.subf %lp_x1, %log_sigma : tensor<f64>
+      %total = arith.addf %sub0, %sub1 : tensor<f64>
+      enzyme.yield %total : tensor<f64>
+    } attributes {
+      selection = [[#enzyme.symbol<0>], [#enzyme.symbol<1>]],
+      all_addresses = [[#enzyme.symbol<0>], [#enzyme.symbol<1>]],
+      nuts_config = #enzyme.nuts_config<max_tree_depth = 10, max_delta_energy = 1000.0>,
+      num_samples = 1 : i64, thinning = 1 : i64, num_warmup = 0 : i64,
+      num_position_args = 2 : i64, position_size = 2 : i64
+    } : (tensor<2xui64>, tensor<f64>, tensor<1x2xf64>, tensor<f64>)
+        -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>,
+            tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>,
+            tensor<f64>, tensor<1x2xf64>)
+    return
+  }
+}
diff --git a/enzyme/test/MLIR/ProbProg/outline_logpdf_lower.mlir b/enzyme/test/MLIR/ProbProg/outline_logpdf_lower.mlir
new file mode 100644
index 0000000..4a67043
--- /dev/null
+++ b/enzyme/test/MLIR/ProbProg/outline_logpdf_lower.mlir
@@ -0,0 +1,168 @@
+// RUN: %eopt --outline-mcmc-regions --probprog %s | FileCheck %s
+
+// ============================================================================
+// Test 1: Full pipeline for outlined unified logpdf (no free values)
+//   mcmc_region (with unified logpdf) → outline → probprog lowering
+// ============================================================================
+//
+// Verifies that the outlined logpdf wrapper function is correctly called
+// during NUTS lowering: initial U0 computation, gradient via autodiff_region,
+// and leapfrog integration inside the tree-building loop.
+
+// CHECK-LABEL: func.func @test_outline_lower_nuts
+// Position extracted from trace → outline builds initial_position
+// CHECK: enzyme.dynamic_slice
+// CHECK: enzyme.dynamic_update_slice
+// CHECK: enzyme.dynamic_slice
+// CHECK: %[[INIT_POS:.+]] = enzyme.dynamic_update_slice
+// Initial U0: call logpdf, negf
+// CHECK: call @[[LOGPDF:[a-zA-Z0-9_]+]](%[[INIT_POS]])
+// CHECK-NEXT: %[[U0:.+]] = arith.negf
+// Gradient via autodiff_region
+// CHECK: enzyme.autodiff_region
+// CHECK: func.call @[[LOGPDF]]
+// CHECK-NEXT: %[[NEG:.+]] = arith.negf
+// CHECK-NEXT: enzyme.yield
+// NUTS sample loop
+// CHECK: enzyme.for_loop
+// Gradient inside tree building
+// CHECK: enzyme.autodiff_region
+// CHECK: func.call @[[LOGPDF]]
+// CHECK-NEXT: %{{.+}} = arith.negf
+// CHECK-NEXT: enzyme.yield
+
+// The outlined logpdf wrapper function
+// CHECK: func.func private @[[LOGPDF]](%[[POS:[^:]+]]: tensor<1x2xf64>) -> tensor<f64>
+// Slice position into per-site scalars
+// CHECK: %[[S0:.+]] = enzyme.dynamic_slice %[[POS]]
+// CHECK-NEXT: %[[X0:.+]] = enzyme.reshape %[[S0]]
+// CHECK: %[[S1:.+]] = enzyme.dynamic_slice %[[POS]]
+// CHECK-NEXT: %[[X1:.+]] = enzyme.reshape %[[S1]]
+// Logpdf body: -x0 + -x1
+// CHECK: %[[N0:.+]] = arith.negf %[[X0]]
+// CHECK-NEXT: %[[N1:.+]] = arith.negf %[[X1]]
+// CHECK-NEXT: %[[SUM:.+]] = arith.addf %[[N0]], %[[N1]]
+// CHECK-NEXT: return %[[SUM]]
+
+module {
+  func.func @test_outline_lower_nuts(
+      %rng : tensor<2xui64>,
+      %mu : tensor<f64>,
+      %trace : tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) {
+
+    %step_size = arith.constant dense<0.1> : tensor<f64>
+
+    %result:8 = enzyme.mcmc_region(%rng, %mu) given %trace
+        step_size = %step_size {
+    ^bb0(%r: tensor<2xui64>, %m: tensor<f64>):
+      %x0:2 = enzyme.sample_region(%r, %m) sampler {
+      ^bb0(%sr: tensor<2xui64>, %sm: tensor<f64>):
+        enzyme.yield %sr, %sm : tensor<2xui64>, tensor<f64>
+      } logpdf {
+      } {symbol = #enzyme.symbol<0>}
+        : (tensor<2xui64>, tensor<f64>) -> (tensor<2xui64>, tensor<f64>)
+
+      %x1:2 = enzyme.sample_region(%x0#0, %m) sampler {
+      ^bb0(%sr2: tensor<2xui64>, %sm2: tensor<f64>):
+        enzyme.yield %sr2, %sm2 : tensor<2xui64>, tensor<f64>
+      } logpdf {
+      } {symbol = #enzyme.symbol<1>}
+        : (tensor<2xui64>, tensor<f64>) -> (tensor<2xui64>, tensor<f64>)
+
+      enzyme.yield %x1#0, %x1#1 : tensor<2xui64>, tensor<f64>
+    } logpdf {
+    ^bb0(%lp_x0: tensor<f64>, %lp_x1: tensor<f64>):
+      %neg0 = arith.negf %lp_x0 : tensor<f64>
+      %neg1 = arith.negf %lp_x1 : tensor<f64>
+      %total = arith.addf %neg0, %neg1 : tensor<f64>
+      enzyme.yield %total : tensor<f64>
+    } attributes {
+      selection = [[#enzyme.symbol<0>], [#enzyme.symbol<1>]],
+      all_addresses = [[#enzyme.symbol<0>], [#enzyme.symbol<1>]],
+      nuts_config = #enzyme.nuts_config<max_tree_depth = 10, max_delta_energy = 1000.0>,
+      num_samples = 1 : i64, thinning = 1 : i64, num_warmup = 0 : i64,
+      num_position_args = 2 : i64, position_size = 2 : i64
+    } : (tensor<2xui64>, tensor<f64>, tensor<1x2xf64>, tensor<f64>)
+        -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>,
+            tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>,
+            tensor<f64>, tensor<1x2xf64>)
+    return %result#0, %result#1, %result#2 : tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>
+  }
+
+  // ==========================================================================
+  // Test 2: Logpdf with free values (hoisted math.log from enclosing scope)
+  //   The logpdf wrapper receives the hoisted value as an extra parameter.
+  //   Lowering must pass it through as an MCMCOp input.
+  // ==========================================================================
+
+  // CHECK-LABEL: func.func @test_outline_lower_free_values
+  // Hoisted log before mcmc
+  // CHECK: %[[LOG:.+]] = math.log
+  // Initial U0: call logpdf with position AND free value
+  // CHECK: call @[[LP2:[a-zA-Z0-9_]+]](%{{.+}}, %[[LOG]])
+  // CHECK-NEXT: arith.negf
+  // Gradient via autodiff_region
+  // CHECK: enzyme.autodiff_region
+  // CHECK: func.call @[[LP2]](%{{.+}}, %[[LOG]])
+  // CHECK-NEXT: arith.negf
+  // CHECK-NEXT: enzyme.yield
+  // NUTS sample loop
+  // CHECK: enzyme.for_loop
+  // Gradient inside tree building
+  // CHECK: enzyme.autodiff_region
+  // CHECK: func.call @[[LP2]](%{{.+}}, %[[LOG]])
+  // CHECK-NEXT: arith.negf
+  // CHECK-NEXT: enzyme.yield
+  //
+  // The logpdf wrapper with free value parameter
+  // CHECK: func.func private @[[LP2]](%[[P2:[^:]+]]: tensor<1x2xf64>, %[[FLOG:[^:]+]]: tensor<f64>) -> tensor<f64>
+  // CHECK: arith.subf %{{.*}}, %[[FLOG]]
+  // CHECK: arith.subf %{{.*}}, %[[FLOG]]
+  // CHECK: arith.addf
+  // CHECK-NEXT: return
+
+  func.func @test_outline_lower_free_values(
+      %rng : tensor<2xui64>,
+      %sigma : tensor<f64>,
+      %trace : tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) {
+
+    %step_size = arith.constant dense<0.1> : tensor<f64>
+    %log_sigma = math.log %sigma : tensor<f64>
+
+    %result:8 = enzyme.mcmc_region(%rng, %sigma) given %trace
+        step_size = %step_size {
+    ^bb0(%r: tensor<2xui64>, %s: tensor<f64>):
+      %x0:2 = enzyme.sample_region(%r, %s) sampler {
+      ^bb0(%sr: tensor<2xui64>, %ss: tensor<f64>):
+        enzyme.yield %sr, %ss : tensor<2xui64>, tensor<f64>
+      } logpdf {
+      } {symbol = #enzyme.symbol<0>}
+        : (tensor<2xui64>, tensor<f64>) -> (tensor<2xui64>, tensor<f64>)
+
+      %x1:2 = enzyme.sample_region(%x0#0, %s) sampler {
+      ^bb0(%sr2: tensor<2xui64>, %ss2: tensor<f64>):
+        enzyme.yield %sr2, %ss2 : tensor<2xui64>, tensor<f64>
+      } logpdf {
+      } {symbol = #enzyme.symbol<1>}
+        : (tensor<2xui64>, tensor<f64>) -> (tensor<2xui64>, tensor<f64>)
+
+      enzyme.yield %x1#0, %x1#1 : tensor<2xui64>, tensor<f64>
+    } logpdf {
+    ^bb0(%lp_x0: tensor<f64>, %lp_x1: tensor<f64>):
+      %sub0 = arith.subf %lp_x0, %log_sigma : tensor<f64>
+      %sub1 = arith.subf %lp_x1, %log_sigma : tensor<f64>
+      %total = arith.addf %sub0, %sub1 : tensor<f64>
+      enzyme.yield %total : tensor<f64>
+    } attributes {
+      selection = [[#enzyme.symbol<0>], [#enzyme.symbol<1>]],
+      all_addresses = [[#enzyme.symbol<0>], [#enzyme.symbol<1>]],
+      nuts_config = #enzyme.nuts_config<max_tree_depth = 10, max_delta_energy = 1000.0>,
+      num_samples = 1 : i64, thinning = 1 : i64, num_warmup = 0 : i64,
+      num_position_args = 2 : i64, position_size = 2 : i64
+    } : (tensor<2xui64>, tensor<f64>, tensor<1x2xf64>, tensor<f64>)
+        -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>,
+            tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>,
+            tensor<f64>, tensor<1x2xf64>)
+    return %result#0, %result#1, %result#2 : tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>
+  }
+}
diff --git a/enzyme/test/MLIR/ProbProg/sample_dependence_analysis.mlir b/enzyme/test/MLIR/ProbProg/sample_dependence_analysis.mlir
index 4ed8775..d9b56bc 100644
--- a/enzyme/test/MLIR/ProbProg/sample_dependence_analysis.mlir
+++ b/enzyme/test/MLIR/ProbProg/sample_dependence_analysis.mlir
@@ -14,9 +14,8 @@
 // CHECK: [INV] [HOIST] arith.addf -> tensor<f64>
 // CHECK: [INV] [KEEP]  enzyme.yield
 
-// Operations inside logpdf region (nested in sample_region)
-// CHECK: [INV] [HOIST] arith.negf -> tensor<f64>
-// CHECK: [INV] [KEEP]  enzyme.yield
+// Per-site logpdf bodies are merged into MCMCRegionOp's unified logpdf
+// region by inline-mcmc-regions Phase 2, so arith.negf no longer appears here.
 
 // The sample_region itself - sample-dependent
 // CHECK: [DEP] [KEEP]  enzyme.sample_region -> tensor<2xui64>, tensor<f64>