Expose DataFlowActivityAnalysis to headers (#2854)
diff --git a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp
index 77304fa..a539272 100644
--- a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp
+++ b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp
@@ -31,10 +31,6 @@
 
 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
-#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
-#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
-#include "mlir/Analysis/DataFlowFramework.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
 
 // TODO: Don't depend on specific dialects
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -43,127 +39,15 @@
 
 #include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h"
 
-#include "Interfaces/AutoDiffOpInterface.h"
-
 using namespace mlir;
 using namespace mlir::dataflow;
 using enzyme::AliasClassLattice;
 
-/// From LLVM Enzyme's activity analysis, there are four activity states.
-// constant instruction vs constant value, a value/instruction (one and the same
-// in LLVM) can be a constant instruction but active value, active instruction
-// but constant value, or active/constant both.
-
-// The result of activity states are potentially different for multiple
-// enzyme.autodiff calls.
-
-enum class ActivityKind { Constant, ActiveVal, Unknown };
-
-using llvm::errs;
-class ValueActivity {
-public:
-  static ValueActivity getConstant() {
-    return ValueActivity(ActivityKind::Constant);
-  }
-
-  static ValueActivity getActiveVal() {
-    return ValueActivity(ActivityKind::ActiveVal);
-  }
-
-  static ValueActivity getUnknown() {
-    return ValueActivity(ActivityKind::Unknown);
-  }
-
-  bool isActiveVal() const { return value == ActivityKind::ActiveVal; }
-
-  bool isConstant() const { return value == ActivityKind::Constant; }
-
-  bool isUnknown() const { return value == ActivityKind::Unknown; }
-
-  ValueActivity() {}
-  ValueActivity(ActivityKind value) : value(value) {}
-
-  /// Get the known activity state.
-  const ActivityKind &getValue() const { return value; }
-
-  bool operator==(const ValueActivity &rhs) const { return value == rhs.value; }
-
-  static ValueActivity merge(const ValueActivity &lhs,
-                             const ValueActivity &rhs) {
-    if (lhs.isUnknown() || rhs.isUnknown())
-      return ValueActivity::getUnknown();
-
-    if (lhs.isConstant() && rhs.isConstant())
-      return ValueActivity::getConstant();
-    return ValueActivity::getActiveVal();
-  }
-
-  static ValueActivity join(const ValueActivity &lhs,
-                            const ValueActivity &rhs) {
-    return ValueActivity::merge(lhs, rhs);
-  }
-
-  void print(raw_ostream &os) const {
-    switch (value) {
-    case ActivityKind::ActiveVal:
-      os << "ActiveVal";
-      break;
-    case ActivityKind::Constant:
-      os << "Constant";
-      break;
-    case ActivityKind::Unknown:
-      os << "Unknown";
-      break;
-    }
-  }
-
-  raw_ostream &operator<<(raw_ostream &os) const {
-    print(os);
-    return os;
-  }
-
-private:
-  /// The activity kind. Optimistically initialized to constant.
-  ActivityKind value = ActivityKind::Constant;
-};
-
-raw_ostream &operator<<(raw_ostream &os, const ValueActivity &val) {
+raw_ostream &operator<<(raw_ostream &os, const enzyme::ValueActivity &val) {
   val.print(os);
   return os;
 }
 
-class ForwardValueActivity : public Lattice<ValueActivity> {
-public:
-  using Lattice::Lattice;
-};
-
-class BackwardValueActivity : public AbstractSparseLattice {
-public:
-  using AbstractSparseLattice::AbstractSparseLattice;
-
-  ChangeResult meet(const AbstractSparseLattice &other) override {
-    const auto *rhs = reinterpret_cast<const BackwardValueActivity *>(&other);
-    return meet(rhs->getValue());
-  }
-
-  void print(raw_ostream &os) const override { value.print(os); }
-
-  ValueActivity getValue() const { return value; }
-
-  ChangeResult meet(ValueActivity other) {
-    auto met = ValueActivity::merge(getValue(), other);
-    if (getValue() == met) {
-      return ChangeResult::NoChange;
-    }
-
-    value = met;
-    return ChangeResult::Change;
-  }
-
-private:
-  ValueActivity value;
-};
-
 raw_ostream &operator<<(raw_ostream &os, const CallControlFlowAction &action) {
   switch (action) {
   case CallControlFlowAction::EnterCallee:
@@ -179,296 +63,139 @@
   return os;
 }
 
-/// This needs to keep track of three things:
-///   1. Could active info store in?
-///   2. Could active info load out?
-///   TODO: Necessary for run-time activity
-///   3. Could constant info propagate (store?) in?
-///
-/// Active: (forward) active in && (backward) active out && (??) !const in
-/// ActiveOrConstant: active in && active out && const in
-/// Constant: everything else
-struct MemoryActivityState {
-  /// Whether active data has stored into this memory location.
-  bool activeIn = false;
-  /// Whether active data was loaded out of this memory location.
-  bool activeOut = false;
+//===----------------------------------------------------------------------===//
+// ValueActivity
+//===----------------------------------------------------------------------===//
+raw_ostream &enzyme::ValueActivity::operator<<(raw_ostream &os) const {
+  print(os); // Formatting is delegated to print().
+  return os; // Member operator: the ValueActivity is the left-hand operand.
+}
 
-  bool operator==(const MemoryActivityState &other) {
-    return activeIn == other.activeIn && activeOut == other.activeOut;
+void enzyme::ValueActivity::print(llvm::raw_ostream &os) const {
+  switch (value) { // Writes the name of the stored kind, with no newline.
+  case ActivityKind::ActiveVal:
+    os << "ActiveVal";
+    break;
+  case ActivityKind::Constant:
+    os << "Constant";
+    break;
+  case ActivityKind::Unknown:
+    os << "Unknown";
+    break;
+  } // No default label: a value matching no case writes nothing.
+}
+
+//===----------------------------------------------------------------------===//
+// BackwardValueActivity
+//===----------------------------------------------------------------------===//
+ChangeResult
+enzyme::BackwardValueActivity::meet(const AbstractSparseLattice &other) {
+  // `other` is assumed to be a BackwardValueActivity; downcast statically.
+  const auto &rhs = static_cast<const enzyme::BackwardValueActivity &>(other);
+  return meet(rhs.getValue());
+}
+
+ChangeResult enzyme::BackwardValueActivity::meet(enzyme::ValueActivity other) {
+  auto met = enzyme::ValueActivity::merge(getValue(), other);
+  if (getValue() == met) {
+    return ChangeResult::NoChange;
   }
 
-  bool operator!=(const MemoryActivityState &other) {
-    return !(*this == other);
+  value = met;
+  return ChangeResult::Change;
+}
+
+void enzyme::BackwardValueActivity::print(raw_ostream &os) const {
+  value.print(os); // Forward to the stored value's own print().
+}
+
+//===----------------------------------------------------------------------===//
+// SparseForwardActivityAnalysis
+//===----------------------------------------------------------------------===//
+
+/// In general, we don't know anything about entry operands, so each entry
+/// lattice is simply joined with a default-constructed ValueActivity.
+void enzyme::SparseForwardActivityAnalysis::setToEntryState(
+    enzyme::ForwardValueActivity *lattice) {
+  propagateIfChanged(lattice, lattice->join(enzyme::ValueActivity()));
+}
+
+/// Forward visit of one op: only pure, non-constant-like ops run transfer().
+LogicalResult enzyme::SparseForwardActivityAnalysis::visitOperation(
+    Operation *op, ArrayRef<const enzyme::ForwardValueActivity *> operands,
+    ArrayRef<enzyme::ForwardValueActivity *> results) {
+  // Constant-like ops are left alone, and so is anything that is not pure
+  // (i.e. affects memory); neither gets a value-based transfer here.
+  const bool isConstantLike = op->hasTrait<OpTrait::ConstantLike>();
+  const bool skip = isConstantLike || !isPure(op);
+  if (!skip)
+    transfer(op, operands, results);
+
+  // Every path reports success.
+  return success();
+}
+
+void enzyme::SparseForwardActivityAnalysis::visitExternalCall(
+    CallOpInterface call,
+    ArrayRef<const enzyme::ForwardValueActivity *> operands,
+    ArrayRef<enzyme::ForwardValueActivity *> results) {
+  transfer(call, operands, results); // Same transfer visitOperation() uses.
+}
+
+/// Value-based forward transfer: merge the operands' activities and join the
+/// outcome into every result lattice.
+void enzyme::SparseForwardActivityAnalysis::transfer(
+    Operation *op, ArrayRef<const enzyme::ForwardValueActivity *> operands,
+    ArrayRef<enzyme::ForwardValueActivity *> results) {
+  // Summarize the operands: merge their activities, starting from a
+  // default-constructed ValueActivity.
+  enzyme::ValueActivity merged;
+  for (const enzyme::ForwardValueActivity *operand : operands)
+    merged = enzyme::ValueActivity::merge(merged, operand->getValue());
+
+  // Only results whose type can carry perturbations and has value semantics
+  // (float or complex) may become active; for any other type an active
+  // summary is replaced by Constant. Non-active summaries pass through as-is.
+  const bool summaryIsActive = merged.isActiveVal();
+  for (enzyme::ForwardValueActivity *result : results) {
+    enzyme::ValueActivity incoming = merged;
+    if (summaryIsActive &&
+        !isa<FloatType, ComplexType>(result->getAnchor().getType()))
+      incoming = enzyme::ValueActivity::getConstant();
+    propagateIfChanged(result, result->join(incoming));
   }
+}
 
-  ChangeResult reset() {
-    if (!activeIn && !activeOut)
-      return ChangeResult::NoChange;
-    activeIn = false;
-    activeOut = false;
-    return ChangeResult::Change;
-  }
+//===----------------------------------------------------------------------===//
+// SparseBackwardActivityAnalysis
+//===----------------------------------------------------------------------===//
 
-  ChangeResult merge(const MemoryActivityState &other) {
-    if (*this == other) {
-      return ChangeResult::NoChange;
-    }
+void enzyme::SparseBackwardActivityAnalysis::transfer(
+    Operation *op, ArrayRef<enzyme::BackwardValueActivity *> operands,
+    ArrayRef<const enzyme::BackwardValueActivity *> results) {
+  // Meet every result's state into every operand's state.
+  for (enzyme::BackwardValueActivity *operandState : operands)
+    for (const enzyme::BackwardValueActivity *resultState : results)
+      meet(operandState, *resultState);
+}
 
-    activeIn |= other.activeIn;
-    activeOut |= other.activeOut;
-    return ChangeResult::Change;
-  }
-};
-
-class MemoryActivity : public AbstractDenseLattice {
-public:
-  using AbstractDenseLattice::AbstractDenseLattice;
-
-  /// Clear all modifications.
-  ChangeResult reset() {
-    if (activityStates.empty())
-      return otherMemoryActivity.reset();
-    activityStates.clear();
-    return otherMemoryActivity.reset();
-  }
-
-  bool hasActiveData(DistinctAttr aliasClass) const {
-    if (!aliasClass)
-      return otherMemoryActivity.activeIn;
-    auto it = activityStates.find(aliasClass);
-    if (it != activityStates.end())
-      return it->getSecond().activeIn;
-    return otherMemoryActivity.activeIn;
-  }
-
-  bool activeDataFlowsOut(DistinctAttr aliasClass) const {
-    if (!aliasClass)
-      return otherMemoryActivity.activeOut;
-
-    auto it = activityStates.find(aliasClass);
-    if (it != activityStates.end())
-      return it->getSecond().activeOut;
-    return otherMemoryActivity.activeOut;
-  }
-
-  /// Set the internal activity state. Accepts null attribute to indicate "other
-  /// classes".
-  ChangeResult setActiveIn(DistinctAttr aliasClass) {
-    if (!aliasClass)
-      return setActiveIn();
-
-    auto &state = activityStates[aliasClass];
-    ChangeResult result =
-        state.activeIn ? ChangeResult::NoChange : ChangeResult::Change;
-    state.activeIn = true;
-    return result;
-  }
-  ChangeResult setActiveIn() {
-    if (otherMemoryActivity.activeIn && activityStates.empty())
-      return ChangeResult::NoChange;
-    otherMemoryActivity.activeIn = true;
-    activityStates.clear();
-    return ChangeResult::Change;
-  }
-  ChangeResult setActiveOut(DistinctAttr aliasClass) {
-    if (!aliasClass)
-      return setActiveOut();
-
-    auto &state = activityStates[aliasClass];
-    ChangeResult result =
-        state.activeOut ? ChangeResult::NoChange : ChangeResult::Change;
-    state.activeOut = true;
-    return result;
-  }
-  ChangeResult setActiveOut() {
-    if (otherMemoryActivity.activeOut && activityStates.empty())
-      return ChangeResult::NoChange;
-    otherMemoryActivity.activeOut = true;
-    activityStates.clear();
-    return ChangeResult::Change;
-  }
-
-  void print(raw_ostream &os) const override {
-    if (activityStates.empty()) {
-      os << "<memory activity state was empty>"
-         << "\n";
-    }
-    for (const auto &[value, state] : activityStates) {
-      os << value << ": in " << state.activeIn << " out " << state.activeOut
-         << "\n";
-    }
-    os << "other classes: in " << otherMemoryActivity.activeIn << " out "
-       << otherMemoryActivity.activeOut << "\n";
-  }
-
-  raw_ostream &operator<<(raw_ostream &os) const {
-    print(os);
-    return os;
-  }
-
-protected:
-  ChangeResult merge(const AbstractDenseLattice &lattice) {
-    const auto &rhs = static_cast<const MemoryActivity &>(lattice);
-    ChangeResult result = ChangeResult::NoChange;
-    DenseSet<DistinctAttr> known;
-    auto lhsRange = llvm::make_first_range(activityStates);
-    auto rhsRange = llvm::make_first_range(rhs.activityStates);
-    known.insert(lhsRange.begin(), lhsRange.end());
-    known.insert(rhsRange.begin(), rhsRange.end());
-
-    MemoryActivityState updatedOther(otherMemoryActivity);
-    result |= updatedOther.merge(rhs.otherMemoryActivity);
-    DenseMap<DistinctAttr, MemoryActivityState> updated;
-    for (DistinctAttr d : known) {
-      auto lhsIt = activityStates.find(d);
-      auto rhsIt = rhs.activityStates.find(d);
-      bool isKnownInLHS = lhsIt != activityStates.end();
-      bool isKnownInRHS = rhsIt != rhs.activityStates.end();
-      const MemoryActivityState *lhsActivity =
-          isKnownInLHS ? &lhsIt->getSecond() : &otherMemoryActivity;
-      const MemoryActivityState *rhsActivity =
-          isKnownInRHS ? &rhsIt->getSecond() : &rhs.otherMemoryActivity;
-      MemoryActivityState updatedActivity(*lhsActivity);
-      (void)updatedActivity.merge(*rhsActivity);
-      if ((lhsIt != activityStates.end() &&
-           updatedActivity != lhsIt->getSecond()) ||
-          (lhsIt == activityStates.end() &&
-           updatedActivity != otherMemoryActivity)) {
-        result |= ChangeResult::Change;
-      }
-      if (updatedActivity != updatedOther)
-        updated.try_emplace(d, updatedActivity);
-    }
-    std::swap(updated, activityStates);
-    return otherMemoryActivity.merge(rhs.otherMemoryActivity) | result;
-  }
-
-private:
-  DenseMap<DistinctAttr, MemoryActivityState> activityStates;
-  MemoryActivityState otherMemoryActivity;
-};
-
-class ForwardMemoryActivity : public MemoryActivity {
-public:
-  using MemoryActivity::MemoryActivity;
-
-  /// Join the activity states.
-  ChangeResult join(const AbstractDenseLattice &lattice) {
-    return merge(lattice);
-  }
-};
-
-class BackwardMemoryActivity : public MemoryActivity {
-public:
-  using MemoryActivity::MemoryActivity;
-
-  ChangeResult meet(const AbstractDenseLattice &lattice) override {
-    return merge(lattice);
-  }
-};
-
-/// Sparse activity analysis reasons about activity by traversing forward down
-/// the def-use chains starting from active function arguments.
-class SparseForwardActivityAnalysis
-    : public SparseForwardDataFlowAnalysis<ForwardValueActivity> {
-public:
-  using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
-
-  /// In general, we don't know anything about entry operands.
-  void setToEntryState(ForwardValueActivity *lattice) override {
-    // errs() << "sparse forward setting to entry state\n";
-    propagateIfChanged(lattice, lattice->join(ValueActivity()));
-  }
-
-  LogicalResult
-  visitOperation(Operation *op, ArrayRef<const ForwardValueActivity *> operands,
-                 ArrayRef<ForwardValueActivity *> results) override {
-    if (op->hasTrait<OpTrait::ConstantLike>())
-      return success();
-
-    // Bail out if this op affects memory.
-    if (!isPure(op))
-      return success();
-
-    transfer(op, operands, results);
-
+LogicalResult enzyme::SparseBackwardActivityAnalysis::visitOperation(
+    Operation *op, ArrayRef<enzyme::BackwardValueActivity *> operands,
+    ArrayRef<const enzyme::BackwardValueActivity *> results) {
+  // Bail out if the op propagates memory
+  if (!isPure(op)) {
     return success();
   }
 
-  void visitExternalCall(CallOpInterface call,
-                         ArrayRef<const ForwardValueActivity *> operands,
-                         ArrayRef<ForwardValueActivity *> results) override {
-    transfer(call, operands, results);
-  }
+  transfer(op, operands, results);
+  return success();
+}
 
-  void transfer(Operation *op, ArrayRef<const ForwardValueActivity *> operands,
-                ArrayRef<ForwardValueActivity *> results) {
-    // For value-based AA, assume any active argument leads to an active
-    // result.
-    ValueActivity joinedResult;
-    for (const ForwardValueActivity *operand : operands)
-      joinedResult = ValueActivity::merge(joinedResult, operand->getValue());
-
-    // Only mark results as active data if the type can carry perturbations and
-    // has value semantics
-    for (ForwardValueActivity *result : results) {
-      if (joinedResult.isActiveVal())
-        propagateIfChanged(result,
-                           result->join(isa<FloatType, ComplexType>(
-                                            result->getAnchor().getType())
-                                            ? joinedResult
-                                            : ValueActivity::getConstant()));
-      else
-        propagateIfChanged(result, result->join(joinedResult));
-    }
-  }
-};
-
-class SparseBackwardActivityAnalysis
-    : public SparseBackwardDataFlowAnalysis<BackwardValueActivity> {
-public:
-  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
-
-  void setToExitState(BackwardValueActivity *lattice) override {
-    // errs() << "backward sparse setting to exit state\n";
-  }
-
-  void visitBranchOperand(OpOperand &operand) override {}
-
-  void visitCallOperand(OpOperand &operand) override {}
-
-  void
-  visitNonControlFlowArguments(RegionSuccessor &successor,
-                               ArrayRef<BlockArgument> arguments) override {}
-
-  void transfer(Operation *op, ArrayRef<BackwardValueActivity *> operands,
-                ArrayRef<const BackwardValueActivity *> results) {
-    // Propagate all operands to all results
-    for (auto operand : operands)
-      for (auto result : results)
-        meet(operand, *result);
-  }
-
-  LogicalResult
-  visitOperation(Operation *op, ArrayRef<BackwardValueActivity *> operands,
-                 ArrayRef<const BackwardValueActivity *> results) override {
-    // Bail out if the op propagates memory
-    if (!isPure(op)) {
-      return success();
-    }
-
-    transfer(op, operands, results);
-    return success();
-  }
-
-  void
-  visitExternalCall(CallOpInterface call,
-                    ArrayRef<BackwardValueActivity *> operands,
-                    ArrayRef<const BackwardValueActivity *> results) override {
-    transfer(call, operands, results);
-  }
-};
+void enzyme::SparseBackwardActivityAnalysis::visitExternalCall(
+    CallOpInterface call, ArrayRef<enzyme::BackwardValueActivity *> operands,
+    ArrayRef<const enzyme::BackwardValueActivity *> results) {
+  transfer(call, operands, results); // Same transfer visitOperation() uses.
+}
 
 // When applying a transfer function to a store from memory, we need to know
 // what value is being stored.
@@ -504,324 +231,6 @@
       });
 }
 
-class DenseForwardActivityAnalysis
-    : public DenseForwardDataFlowAnalysis<ForwardMemoryActivity> {
-public:
-  DenseForwardActivityAnalysis(DataFlowSolver &solver, Block *entryBlock,
-                               ArrayRef<enzyme::Activity> argumentActivity)
-      : DenseForwardDataFlowAnalysis(solver), entryBlock(entryBlock),
-        argumentActivity(argumentActivity) {}
-
-  LogicalResult visitOperation(Operation *op,
-                               const ForwardMemoryActivity &before,
-                               ForwardMemoryActivity *after) override {
-    join(after, before);
-    ChangeResult result = ChangeResult::NoChange;
-
-    // TODO If we know this is inactive by definition
-    // if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
-    //   if (ifaceOp.isInactive()) {
-    //     propagateIfChanged(after, result);
-    //     return;
-    //   }
-    // }
-
-    auto memory = dyn_cast<MemoryEffectOpInterface>(op);
-    // If we can't reason about the memory effects, then conservatively assume
-    // we can't deduce anything about activity via side-effects.
-    if (!memory)
-      return success();
-
-    SmallVector<MemoryEffects::EffectInstance> effects;
-    memory.getEffects(effects);
-
-    for (const auto &effect : effects) {
-      Value value = effect.getValue();
-
-      // If we see an effect on anything other than a value, assume we can't
-      // deduce anything about the activity.
-      if (!value)
-        return success();
-
-      // In forward-flow, a value is active if loaded from a memory resource
-      // that has previously been actively stored to.
-      if (isa<MemoryEffects::Read>(effect.getEffect())) {
-        auto *ptrAliasClass =
-            getOrCreateFor<AliasClassLattice>(getProgramPointAfter(op), value);
-        forEachAliasedAlloc(ptrAliasClass, [&](DistinctAttr alloc) {
-          if (before.hasActiveData(alloc)) {
-            for (OpResult opResult : op->getResults()) {
-              // Mark the result as (forward) active
-              // TODO: We might need type analysis here
-              // Structs and tensors also have value semantics
-              if (isa<FloatType, ComplexType>(opResult.getType())) {
-                auto *valueState = getOrCreate<ForwardValueActivity>(opResult);
-                propagateIfChanged(
-                    valueState,
-                    valueState->join(ValueActivity::getActiveVal()));
-              }
-            }
-
-            if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
-              // propagate from input to block argument
-              for (OpOperand *inputOperand : linalgOp.getDpsInputOperands()) {
-                if (inputOperand->get() == value) {
-                  auto *valueState = getOrCreate<ForwardValueActivity>(
-                      linalgOp.getMatchingBlockArgument(inputOperand));
-                  propagateIfChanged(
-                      valueState,
-                      valueState->join(ValueActivity::getActiveVal()));
-                }
-              }
-            }
-          }
-        });
-      }
-
-      if (isa<MemoryEffects::Write>(effect.getEffect())) {
-        std::optional<Value> stored = getStored(op);
-        if (stored.has_value()) {
-          auto *valueState = getOrCreateFor<ForwardValueActivity>(
-              getProgramPointAfter(op), *stored);
-          if (valueState->getValue().isActiveVal()) {
-            auto *ptrAliasClass = getOrCreateFor<AliasClassLattice>(
-                getProgramPointAfter(op), value);
-            forEachAliasedAlloc(ptrAliasClass, [&](DistinctAttr alloc) {
-              // Mark the pointer as having been actively stored into
-              result |= after->setActiveIn(alloc);
-            });
-          }
-        } else if (auto copySource = getCopySource(op)) {
-          auto *srcAliasClass = getOrCreateFor<AliasClassLattice>(
-              getProgramPointAfter(op), *copySource);
-          forEachAliasedAlloc(srcAliasClass, [&](DistinctAttr srcAlloc) {
-            if (before.hasActiveData(srcAlloc)) {
-              auto *destAliasClass = getOrCreateFor<AliasClassLattice>(
-                  getProgramPointAfter(op), value);
-              forEachAliasedAlloc(destAliasClass, [&](DistinctAttr destAlloc) {
-                result |= after->setActiveIn(destAlloc);
-              });
-            }
-          });
-        } else if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
-          // linalg.yield stores to the corresponding value.
-          for (OpOperand &dpsInit : linalgOp.getDpsInitsMutable()) {
-            if (dpsInit.get() == value) {
-              int64_t resultIndex =
-                  dpsInit.getOperandNumber() - linalgOp.getNumDpsInputs();
-              Value yieldOperand =
-                  linalgOp.getBlock()->getTerminator()->getOperand(resultIndex);
-              auto *valueState = getOrCreateFor<ForwardValueActivity>(
-                  getProgramPointAfter(op), yieldOperand);
-              if (valueState->getValue().isActiveVal()) {
-                auto *ptrAliasClass = getOrCreateFor<AliasClassLattice>(
-                    getProgramPointAfter(op), value);
-                forEachAliasedAlloc(ptrAliasClass, [&](DistinctAttr alloc) {
-                  result |= after->setActiveIn(alloc);
-                });
-              }
-            }
-          }
-        }
-      }
-    }
-    propagateIfChanged(after, result);
-    return success();
-  }
-
-  void visitCallControlFlowTransfer(CallOpInterface call,
-                                    CallControlFlowAction action,
-                                    const ForwardMemoryActivity &before,
-                                    ForwardMemoryActivity *after) override {
-    join(after, before);
-  }
-
-  /// Initialize the entry block with the supplied argument activities.
-  void setToEntryState(ForwardMemoryActivity *lattice) override {
-    if (auto pp = dyn_cast_if_present<ProgramPoint *>(lattice->getAnchor()))
-      if (Block *block = pp->getBlock();
-          block && block == entryBlock && pp->isBlockStart()) {
-        for (const auto &[arg, activity] :
-             llvm::zip(block->getArguments(), argumentActivity)) {
-          if (activity != enzyme::Activity::enzyme_dup &&
-              activity != enzyme::Activity::enzyme_dupnoneed)
-            continue;
-          auto *argAliasClasses = getOrCreateFor<AliasClassLattice>(
-              getProgramPointBefore(block), arg);
-          ChangeResult changed =
-              argAliasClasses->getAliasClassesObject().foreachElement(
-                  [lattice](DistinctAttr argAliasClass,
-                            enzyme::AliasClassSet::State state) {
-                    if (state == enzyme::AliasClassSet::State::Undefined)
-                      return ChangeResult::NoChange;
-                    return lattice->setActiveIn(argAliasClass);
-                  });
-          propagateIfChanged(lattice, changed);
-        }
-      }
-  }
-
-private:
-  // A pointer to the entry block and argument activities of the top-level
-  // function being differentiated. This is used to set the entry state
-  // because we need access to the results of points-to analysis.
-  Block *entryBlock;
-  SmallVector<enzyme::Activity> argumentActivity;
-};
-
-class DenseBackwardActivityAnalysis
-    : public DenseBackwardDataFlowAnalysis<BackwardMemoryActivity> {
-public:
-  DenseBackwardActivityAnalysis(DataFlowSolver &solver,
-                                SymbolTableCollection &symbolTable,
-                                FunctionOpInterface parentOp,
-                                ArrayRef<enzyme::Activity> argumentActivity)
-      : DenseBackwardDataFlowAnalysis(solver, symbolTable), parentOp(parentOp),
-        argumentActivity(argumentActivity) {}
-
-  LogicalResult visitOperation(Operation *op,
-                               const BackwardMemoryActivity &after,
-                               BackwardMemoryActivity *before) override {
-
-    // TODO: If we know this is inactive by definition
-    // if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
-    //   if (ifaceOp.isInactive()) {
-    //     return;
-    //   }
-    // }
-
-    // Initialize the return activity of arguments.
-    if (op->hasTrait<OpTrait::ReturnLike>() && op->getParentOp() == parentOp) {
-      for (const auto &[arg, argActivity] :
-           llvm::zip(parentOp->getRegions().front().getArguments(),
-                     argumentActivity)) {
-        if (argActivity != enzyme::Activity::enzyme_dup &&
-            argActivity != enzyme::Activity::enzyme_dupnoneed) {
-          continue;
-        }
-        auto *argAliasClasses =
-            getOrCreateFor<AliasClassLattice>(getProgramPointBefore(op), arg);
-        ChangeResult changed =
-            argAliasClasses->getAliasClassesObject().foreachElement(
-                [before](DistinctAttr argAliasClass,
-                         enzyme::AliasClassSet::State state) {
-                  if (state == enzyme::AliasClassSet::State::Undefined)
-                    return ChangeResult::NoChange;
-                  return before->setActiveOut(argAliasClass);
-                });
-        propagateIfChanged(before, changed);
-      }
-
-      // Initialize the return activity of the operands
-      for (Value operand : op->getOperands()) {
-        if (isa<MemRefType, LLVM::LLVMPointerType>(operand.getType())) {
-          auto *retAliasClasses = getOrCreateFor<AliasClassLattice>(
-              getProgramPointBefore(op), operand);
-          ChangeResult changed =
-              retAliasClasses->getAliasClassesObject().foreachElement(
-                  [before](DistinctAttr retAliasClass,
-                           enzyme::AliasClassSet::State state) {
-                    if (state == enzyme::AliasClassSet::State::Undefined)
-                      return ChangeResult::NoChange;
-                    return before->setActiveOut(retAliasClass);
-                  });
-          propagateIfChanged(before, changed);
-        }
-      }
-    }
-
-    meet(before, after);
-    ChangeResult result = ChangeResult::NoChange;
-    auto memory = dyn_cast<MemoryEffectOpInterface>(op);
-    // If we can't reason about the memory effects, then conservatively assume
-    // we can't deduce anything about activity via side-effects.
-    if (!memory)
-      return success();
-
-    SmallVector<MemoryEffects::EffectInstance> effects;
-    memory.getEffects(effects);
-
-    for (const auto &effect : effects) {
-      Value value = effect.getValue();
-
-      // If we see an effect on anything other than a value, assume we can't
-      // deduce anything about the activity.
-      if (!value)
-        return success();
-
-      // In backward-flow, a value is active if stored into a memory resource
-      // that has subsequently been actively loaded from.
-      if (isa<MemoryEffects::Read>(effect.getEffect())) {
-        for (Value opResult : op->getResults()) {
-          auto *valueState = getOrCreateFor<BackwardValueActivity>(
-              getProgramPointBefore(op), opResult);
-          if (valueState->getValue().isActiveVal()) {
-            auto *ptrAliasClass = getOrCreateFor<AliasClassLattice>(
-                getProgramPointBefore(op), value);
-            forEachAliasedAlloc(ptrAliasClass, [&](DistinctAttr alloc) {
-              result |= before->setActiveOut(alloc);
-            });
-          }
-        }
-      }
-      if (isa<MemoryEffects::Write>(effect.getEffect())) {
-        auto *ptrAliasClass =
-            getOrCreateFor<AliasClassLattice>(getProgramPointBefore(op), value);
-        std::optional<Value> stored = getStored(op);
-        std::optional<Value> copySource = getCopySource(op);
-        forEachAliasedAlloc(ptrAliasClass, [&](DistinctAttr alloc) {
-          if (stored.has_value() && after.activeDataFlowsOut(alloc)) {
-            if (isa<FloatType, ComplexType>(stored->getType())) {
-              auto *valueState = getOrCreate<BackwardValueActivity>(*stored);
-              propagateIfChanged(
-                  valueState, valueState->meet(ValueActivity::getActiveVal()));
-            }
-          } else if (copySource.has_value() &&
-                     after.activeDataFlowsOut(alloc)) {
-            auto *srcAliasClass = getOrCreateFor<AliasClassLattice>(
-                getProgramPointBefore(op), *copySource);
-            forEachAliasedAlloc(srcAliasClass, [&](DistinctAttr srcAlloc) {
-              result |= before->setActiveOut(srcAlloc);
-            });
-          } else if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
-            if (after.activeDataFlowsOut(alloc)) {
-              for (OpOperand &dpsInit : linalgOp.getDpsInitsMutable()) {
-                if (dpsInit.get() == value) {
-                  int64_t resultIndex =
-                      dpsInit.getOperandNumber() - linalgOp.getNumDpsInputs();
-                  Value yieldOperand =
-                      linalgOp.getBlock()->getTerminator()->getOperand(
-                          resultIndex);
-                  auto *valueState =
-                      getOrCreate<BackwardValueActivity>(yieldOperand);
-                  propagateIfChanged(
-                      valueState,
-                      valueState->meet(ValueActivity::getActiveVal()));
-                }
-              }
-            }
-          }
-        });
-      }
-    }
-    propagateIfChanged(before, result);
-    return success();
-  }
-
-  void visitCallControlFlowTransfer(CallOpInterface call,
-                                    CallControlFlowAction action,
-                                    const BackwardMemoryActivity &after,
-                                    BackwardMemoryActivity *before) override {
-    meet(before, after);
-  }
-
-  void setToExitState(BackwardMemoryActivity *lattice) override {}
-
-private:
-  FunctionOpInterface parentOp;
-  SmallVector<enzyme::Activity> argumentActivity;
-};
-
 void traverseCallGraph(FunctionOpInterface root,
                        SymbolTableCollection *symbolTable,
                        function_ref<void(FunctionOpInterface)> processFunc) {
@@ -844,14 +253,440 @@
   }
 }
 
+//===----------------------------------------------------------------------===//
+// MemoryActivityState
+//===----------------------------------------------------------------------===//
+ChangeResult enzyme::MemoryActivityState::reset() {
+  // Clearing only counts as a change when at least one flag was set.
+  const bool hadActivity = activeIn || activeOut;
+  activeIn = false;
+  activeOut = false;
+  return hadActivity ? ChangeResult::Change : ChangeResult::NoChange;
+}
+
+ChangeResult
+enzyme::MemoryActivityState::merge(const enzyme::MemoryActivityState &other) {
+  // OR `other` into this state. Only a flag that actually flips is a change;
+  // a differing but already-subsumed `other` must not re-trigger the solver.
+  const bool flips =
+      (other.activeIn && !activeIn) || (other.activeOut && !activeOut);
+  activeIn |= other.activeIn;
+  activeOut |= other.activeOut;
+  return flips ? ChangeResult::Change : ChangeResult::NoChange;
+}
+
+//===----------------------------------------------------------------------===//
+// MemoryActivity
+//===----------------------------------------------------------------------===//
+
+ChangeResult enzyme::MemoryActivity::reset() {
+  const bool hadEntries = !activityStates.empty(); // dropping them is a change
+  activityStates.clear();
+  ChangeResult result = otherMemoryActivity.reset();
+  return hadEntries ? ChangeResult::Change : result;
+}
+
+bool enzyme::MemoryActivity::hasActiveData(DistinctAttr aliasClass) const {
+  // A null class, or one without its own entry, reads the "other" state.
+  auto entry = aliasClass ? activityStates.find(aliasClass)
+                          : activityStates.end();
+  if (entry == activityStates.end())
+    return otherMemoryActivity.activeIn;
+  return entry->getSecond().activeIn;
+}
+
+bool enzyme::MemoryActivity::activeDataFlowsOut(DistinctAttr aliasClass) const {
+  // A null class, or a class without its own entry, reads the "other classes"
+  // state instead.
+  auto entry = aliasClass ? activityStates.find(aliasClass)
+                          : activityStates.end();
+  if (entry == activityStates.end())
+    return otherMemoryActivity.activeOut;
+  return entry->getSecond().activeOut;
+}
+
+/// Set the internal activity state. Accepts null attribute to indicate "other
+/// classes". A new entry starts from the "other" state it fell back to.
+ChangeResult enzyme::MemoryActivity::setActiveIn(DistinctAttr aliasClass) {
+  if (!aliasClass)
+    return setActiveIn();
+  // Seed from the fallback so the class keeps any inherited activeOut.
+  auto it = activityStates.try_emplace(aliasClass, otherMemoryActivity).first;
+  if (it->getSecond().activeIn)
+    return ChangeResult::NoChange;
+  it->getSecond().activeIn = true;
+  return ChangeResult::Change;
+}
+ChangeResult enzyme::MemoryActivity::setActiveIn() {
+  // Collapses all classes onto "other"; this drops their per-class flags.
+  const bool wasSet = otherMemoryActivity.activeIn && activityStates.empty();
+  otherMemoryActivity.activeIn = true;
+  activityStates.clear();
+  return wasSet ? ChangeResult::NoChange : ChangeResult::Change;
+}
+ChangeResult enzyme::MemoryActivity::setActiveOut(DistinctAttr aliasClass) {
+  if (!aliasClass)
+    return setActiveOut();
+  // Seed from the fallback so the class keeps any inherited activeIn.
+  auto it = activityStates.try_emplace(aliasClass, otherMemoryActivity).first;
+  if (it->getSecond().activeOut)
+    return ChangeResult::NoChange;
+  it->getSecond().activeOut = true;
+  return ChangeResult::Change;
+}
+ChangeResult enzyme::MemoryActivity::setActiveOut() {
+  // Collapses all classes onto "other"; this drops their per-class flags.
+  const bool wasSet = otherMemoryActivity.activeOut && activityStates.empty();
+  otherMemoryActivity.activeOut = true;
+  activityStates.clear();
+  return wasSet ? ChangeResult::NoChange : ChangeResult::Change;
+}
+
+void enzyme::MemoryActivity::print(raw_ostream &os) const {
+  // Prints one line per alias class that has its own entry, followed by the
+  // state shared by all other classes.
+  if (activityStates.empty())
+    os << "<memory activity state was empty>"
+       << "\n";
+  for (const auto &entry : activityStates)
+    os << entry.getFirst() << ": in " << entry.getSecond().activeIn << " out "
+       << entry.getSecond().activeOut << "\n";
+  os << "other classes: in " << otherMemoryActivity.activeIn << " out "
+     << otherMemoryActivity.activeOut << "\n";
+}
+
+raw_ostream &enzyme::MemoryActivity::operator<<(raw_ostream &os) const {
+  print(os); // delegates to print(); being a member, the lattice is the LHS
+  return os;
+}
+
+ChangeResult
+enzyme::MemoryActivity::merge(const AbstractDenseLattice &lattice) {
+  // OR `lattice` into this one, per alias class and for the "other" state.
+  const auto &rhs = static_cast<const MemoryActivity &>(lattice);
+  ChangeResult result = ChangeResult::NoChange;
+  DenseSet<DistinctAttr> known;
+  auto lhsRange = llvm::make_first_range(activityStates);
+  auto rhsRange = llvm::make_first_range(rhs.activityStates);
+  known.insert(lhsRange.begin(), lhsRange.end());
+  known.insert(rhsRange.begin(), rhsRange.end());
+  // `known` is now every alias class with an explicit entry on either side.
+  MemoryActivityState updatedOther(otherMemoryActivity);
+  result |= updatedOther.merge(rhs.otherMemoryActivity);
+  DenseMap<DistinctAttr, MemoryActivityState> updated;
+  for (DistinctAttr d : known) {
+    auto lhsIt = activityStates.find(d);
+    auto rhsIt = rhs.activityStates.find(d);
+    bool isKnownInLHS = lhsIt != activityStates.end();
+    bool isKnownInRHS = rhsIt != rhs.activityStates.end();
+    const MemoryActivityState *lhsActivity =
+        isKnownInLHS ? &lhsIt->getSecond() : &otherMemoryActivity;
+    const MemoryActivityState *rhsActivity =
+        isKnownInRHS ? &rhsIt->getSecond() : &rhs.otherMemoryActivity;
+    MemoryActivityState updatedActivity(*lhsActivity);
+    (void)updatedActivity.merge(*rhsActivity);
+    if ((lhsIt != activityStates.end() &&
+         updatedActivity != lhsIt->getSecond()) ||
+        (lhsIt == activityStates.end() &&
+         updatedActivity != otherMemoryActivity)) {
+      result |= ChangeResult::Change;
+    }
+    if (updatedActivity != updatedOther)
+      updated.try_emplace(d, updatedActivity);
+  }
+  // Entries matching the merged "other" state were not re-added above.
+  std::swap(updated, activityStates);
+  return otherMemoryActivity.merge(rhs.otherMemoryActivity) | result;
+}
+
+//===----------------------------------------------------------------------===//
+// DenseForwardActivityAnalysis
+//===----------------------------------------------------------------------===//
+
+LogicalResult enzyme::DenseForwardActivityAnalysis::visitOperation(
+    Operation *op, const ForwardMemoryActivity &before,
+    ForwardMemoryActivity *after) {
+  // Forward transfer of activity between memory (alias classes) and values.
+  join(after, before);
+  ChangeResult result = ChangeResult::NoChange;
+  // TODO If we know this is inactive by definition
+  // if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
+  //   if (ifaceOp.isInactive()) {
+  //     propagateIfChanged(after, result);
+  //     return;
+  //   }
+  // }
+
+  auto memory = dyn_cast<MemoryEffectOpInterface>(op);
+  // If we can't reason about the memory effects, then conservatively assume
+  // we can't deduce anything about activity via side-effects.
+  if (!memory)
+    return success();
+
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  memory.getEffects(effects);
+
+  for (const auto &effect : effects) {
+    Value value = effect.getValue();
+
+    // An effect on anything other than a value ends the deduction, but what
+    // earlier effects already wrote into `after` must still be propagated.
+    if (!value)
+      break;
+
+    // In forward-flow, a value is active if loaded from a memory resource
+    // that has previously been actively stored to.
+    if (isa<MemoryEffects::Read>(effect.getEffect())) {
+      auto *ptrAliasClass =
+          getOrCreateFor<AliasClassLattice>(getProgramPointAfter(op), value);
+      forEachAliasedAlloc(ptrAliasClass, [&](DistinctAttr alloc) {
+        if (before.hasActiveData(alloc)) {
+          for (OpResult opResult : op->getResults()) {
+            // Mark the result as (forward) active
+            // TODO: We might need type analysis here
+            // Structs and tensors also have value semantics
+            if (isa<FloatType, ComplexType>(opResult.getType())) {
+              auto *valueState =
+                  getOrCreate<enzyme::ForwardValueActivity>(opResult);
+              propagateIfChanged(
+                  valueState,
+                  valueState->join(enzyme::ValueActivity::getActiveVal()));
+            }
+          }
+
+          if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+            // propagate from input to block argument
+            for (OpOperand *inputOperand : linalgOp.getDpsInputOperands()) {
+              if (inputOperand->get() == value) {
+                auto *valueState = getOrCreate<enzyme::ForwardValueActivity>(
+                    linalgOp.getMatchingBlockArgument(inputOperand));
+                propagateIfChanged(
+                    valueState,
+                    valueState->join(enzyme::ValueActivity::getActiveVal()));
+              }
+            }
+          }
+        }
+      });
+    }
+
+    if (isa<MemoryEffects::Write>(effect.getEffect())) {
+      std::optional<Value> stored = getStored(op);
+      if (stored.has_value()) {
+        auto *valueState = getOrCreateFor<enzyme::ForwardValueActivity>(
+            getProgramPointAfter(op), *stored);
+        if (valueState->getValue().isActiveVal()) {
+          auto *ptrAliasClass = getOrCreateFor<AliasClassLattice>(
+              getProgramPointAfter(op), value);
+          forEachAliasedAlloc(ptrAliasClass, [&](DistinctAttr alloc) {
+            // Mark the pointer as having been actively stored into
+            result |= after->setActiveIn(alloc);
+          });
+        }
+      } else if (auto copySource = getCopySource(op)) {
+        auto *srcAliasClass = getOrCreateFor<AliasClassLattice>(
+            getProgramPointAfter(op), *copySource);
+        forEachAliasedAlloc(srcAliasClass, [&](DistinctAttr srcAlloc) {
+          if (before.hasActiveData(srcAlloc)) {
+            auto *destAliasClass = getOrCreateFor<AliasClassLattice>(
+                getProgramPointAfter(op), value);
+            forEachAliasedAlloc(destAliasClass, [&](DistinctAttr destAlloc) {
+              result |= after->setActiveIn(destAlloc);
+            });
+          }
+        });
+      } else if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+        // linalg.yield stores to the corresponding value.
+        for (OpOperand &dpsInit : linalgOp.getDpsInitsMutable()) {
+          if (dpsInit.get() == value) {
+            int64_t resultIndex =
+                dpsInit.getOperandNumber() - linalgOp.getNumDpsInputs();
+            Value yieldOperand =
+                linalgOp.getBlock()->getTerminator()->getOperand(resultIndex);
+            auto *valueState = getOrCreateFor<enzyme::ForwardValueActivity>(
+                getProgramPointAfter(op), yieldOperand);
+            if (valueState->getValue().isActiveVal()) {
+              auto *ptrAliasClass = getOrCreateFor<AliasClassLattice>(
+                  getProgramPointAfter(op), value);
+              forEachAliasedAlloc(ptrAliasClass, [&](DistinctAttr alloc) {
+                result |= after->setActiveIn(alloc);
+              });
+            }
+          }
+        }
+      }
+    }
+  }
+  propagateIfChanged(after, result);
+  return success();
+}
+
+void enzyme::DenseForwardActivityAnalysis::setToEntryState(
+    ForwardMemoryActivity *lattice) {
+  // Only the block-start point of the tracked entry block gets seeded.
+  auto point = dyn_cast_if_present<ProgramPoint *>(lattice->getAnchor());
+  Block *block = point ? point->getBlock() : nullptr;
+  if (!block || block != entryBlock || !point->isBlockStart())
+    return;
+  auto markActiveIn = [lattice](DistinctAttr argAliasClass,
+                                enzyme::AliasClassSet::State state) {
+    if (state == enzyme::AliasClassSet::State::Undefined)
+      return ChangeResult::NoChange;
+    return lattice->setActiveIn(argAliasClass);
+  };
+  for (const auto &[arg, activity] :
+       llvm::zip(block->getArguments(), argumentActivity)) {
+    if (activity != enzyme::Activity::enzyme_dup &&
+        activity != enzyme::Activity::enzyme_dupnoneed)
+      continue;
+    auto *argAliasClasses = getOrCreateFor<AliasClassLattice>(
+        getProgramPointBefore(block), arg);
+    propagateIfChanged(lattice, argAliasClasses->getAliasClassesObject()
+                                    .foreachElement(markActiveIn));
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// DenseBackwardActivityAnalysis
+//===----------------------------------------------------------------------===//
+
+LogicalResult enzyme::DenseBackwardActivityAnalysis::visitOperation(
+    Operation *op, const BackwardMemoryActivity &after,
+    BackwardMemoryActivity *before) {
+  // Backward transfer of activity between memory (alias classes) and values.
+  // TODO: If we know this is inactive by definition
+  // if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
+  //   if (ifaceOp.isInactive()) {
+  //     return;
+  //   }
+  // }
+
+  // Initialize the return activity of arguments.
+  if (op->hasTrait<OpTrait::ReturnLike>() && op->getParentOp() == parentOp) {
+    for (const auto &[arg, argActivity] : llvm::zip(
+             parentOp->getRegions().front().getArguments(), argumentActivity)) {
+      if (argActivity != enzyme::Activity::enzyme_dup &&
+          argActivity != enzyme::Activity::enzyme_dupnoneed) {
+        continue;
+      }
+      auto *argAliasClasses =
+          getOrCreateFor<AliasClassLattice>(getProgramPointBefore(op), arg);
+      ChangeResult changed =
+          argAliasClasses->getAliasClassesObject().foreachElement(
+              [before](DistinctAttr argAliasClass,
+                       enzyme::AliasClassSet::State state) {
+                if (state == enzyme::AliasClassSet::State::Undefined)
+                  return ChangeResult::NoChange;
+                return before->setActiveOut(argAliasClass);
+              });
+      propagateIfChanged(before, changed);
+    }
+
+    // Initialize the return activity of the operands
+    for (Value operand : op->getOperands()) {
+      if (isa<MemRefType, LLVM::LLVMPointerType>(operand.getType())) {
+        auto *retAliasClasses = getOrCreateFor<AliasClassLattice>(
+            getProgramPointBefore(op), operand);
+        ChangeResult changed =
+            retAliasClasses->getAliasClassesObject().foreachElement(
+                [before](DistinctAttr retAliasClass,
+                         enzyme::AliasClassSet::State state) {
+                  if (state == enzyme::AliasClassSet::State::Undefined)
+                    return ChangeResult::NoChange;
+                  return before->setActiveOut(retAliasClass);
+                });
+        propagateIfChanged(before, changed);
+      }
+    }
+  }
+
+  meet(before, after);
+  ChangeResult result = ChangeResult::NoChange;
+  auto memory = dyn_cast<MemoryEffectOpInterface>(op);
+  // If we can't reason about the memory effects, then conservatively assume
+  // we can't deduce anything about activity via side-effects.
+  if (!memory)
+    return success();
+
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  memory.getEffects(effects);
+
+  for (const auto &effect : effects) {
+    Value value = effect.getValue();
+
+    // An effect on anything other than a value ends the deduction, but what
+    // earlier effects already wrote into `before` must still be propagated.
+    if (!value)
+      break;
+
+    // In backward-flow, a value is active if stored into a memory resource
+    // that has subsequently been actively loaded from.
+    if (isa<MemoryEffects::Read>(effect.getEffect())) {
+      for (Value opResult : op->getResults()) {
+        auto *valueState = getOrCreateFor<enzyme::BackwardValueActivity>(
+            getProgramPointBefore(op), opResult);
+        if (valueState->getValue().isActiveVal()) {
+          auto *ptrAliasClass = getOrCreateFor<AliasClassLattice>(
+              getProgramPointBefore(op), value);
+          forEachAliasedAlloc(ptrAliasClass, [&](DistinctAttr alloc) {
+            result |= before->setActiveOut(alloc);
+          });
+        }
+      }
+    }
+    if (isa<MemoryEffects::Write>(effect.getEffect())) {
+      auto *ptrAliasClass =
+          getOrCreateFor<AliasClassLattice>(getProgramPointBefore(op), value);
+      std::optional<Value> stored = getStored(op);
+      std::optional<Value> copySource = getCopySource(op);
+      forEachAliasedAlloc(ptrAliasClass, [&](DistinctAttr alloc) {
+        if (stored.has_value() && after.activeDataFlowsOut(alloc)) {
+          if (isa<FloatType, ComplexType>(stored->getType())) {
+            auto *valueState =
+                getOrCreate<enzyme::BackwardValueActivity>(*stored);
+            propagateIfChanged(
+                valueState,
+                valueState->meet(enzyme::ValueActivity::getActiveVal()));
+          }
+        } else if (copySource.has_value() && after.activeDataFlowsOut(alloc)) {
+          auto *srcAliasClass = getOrCreateFor<AliasClassLattice>(
+              getProgramPointBefore(op), *copySource);
+          forEachAliasedAlloc(srcAliasClass, [&](DistinctAttr srcAlloc) {
+            result |= before->setActiveOut(srcAlloc);
+          });
+        } else if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+          if (after.activeDataFlowsOut(alloc)) {
+            for (OpOperand &dpsInit : linalgOp.getDpsInitsMutable()) {
+              if (dpsInit.get() == value) {
+                int64_t resultIndex =
+                    dpsInit.getOperandNumber() - linalgOp.getNumDpsInputs();
+                Value yieldOperand =
+                    linalgOp.getBlock()->getTerminator()->getOperand(
+                        resultIndex);
+                auto *valueState =
+                    getOrCreate<enzyme::BackwardValueActivity>(yieldOperand);
+                propagateIfChanged(
+                    valueState,
+                    valueState->meet(enzyme::ValueActivity::getActiveVal()));
+              }
+            }
+          }
+        }
+      });
+    }
+  }
+  propagateIfChanged(before, result);
+  return success();
+}
+
 void printActivityAnalysisResults(DataFlowSolver &solver,
                                   FunctionOpInterface callee,
                                   const SmallPtrSet<Operation *, 2> &returnOps,
                                   SymbolTableCollection *symbolTable,
                                   bool verbose, bool annotate) {
   auto isActiveData = [&](Value value) {
-    auto fva = solver.lookupState<ForwardValueActivity>(value);
-    auto bva = solver.lookupState<BackwardValueActivity>(value);
+    auto fva = solver.lookupState<enzyme::ForwardValueActivity>(value);
+    auto bva = solver.lookupState<enzyme::BackwardValueActivity>(value);
     bool forwardActive = fva && fva->getValue().isActiveVal();
     bool backwardActive = bva && bva->getValue().isActiveVal();
     return forwardActive && backwardActive;
@@ -861,9 +696,9 @@
     // TODO: integers/vectors that might be pointers
     if (isa<LLVM::LLVMPointerType, MemRefType>(value.getType())) {
       assert(returnOps.size() == 1);
-      auto *fma = solver.lookupState<ForwardMemoryActivity>(
+      auto *fma = solver.lookupState<enzyme::ForwardMemoryActivity>(
           solver.getProgramPointAfter(*returnOps.begin()));
-      auto *bma = solver.lookupState<BackwardMemoryActivity>(
+      auto *bma = solver.lookupState<enzyme::BackwardMemoryActivity>(
           solver.getProgramPointBefore(
               &callee.getFunctionBody().front().front()));
 
@@ -963,12 +798,12 @@
            llvm::none_of(op->getResults(), isActiveData);
   };
 
-  errs() << FlatSymbolRefAttr::get(callee) << ":\n";
+  llvm::errs() << FlatSymbolRefAttr::get(callee) << ":\n";
   for (BlockArgument arg : callee.getArguments()) {
     if (Attribute tagAttr =
             callee.getArgAttr(arg.getArgNumber(), "enzyme.tag")) {
-      errs() << "  " << tagAttr << ": "
-             << (isConstantValue(arg) ? "Constant" : "Active") << "\n";
+      llvm::errs() << "  " << tagAttr << ": "
+                   << (isConstantValue(arg) ? "Constant" : "Active") << "\n";
     }
   }
 
@@ -1003,16 +838,17 @@
   }
   callee.walk([&](Operation *op) {
     if (op->hasAttr("tag")) {
-      errs() << "  " << op->getAttr("tag") << ": ";
+      llvm::errs() << "  " << op->getAttr("tag") << ": ";
       for (OpResult opResult : op->getResults()) {
-        errs() << (isConstantValue(opResult) ? "Constant" : "Active") << "\n";
+        llvm::errs() << (isConstantValue(opResult) ? "Constant" : "Active")
+                     << "\n";
       }
     }
     if (verbose) {
       // Annotate each op's results with its value activity states
       for (OpResult result : op->getResults()) {
         auto forwardValueActivity =
-            solver.lookupState<ForwardValueActivity>(result);
+            solver.lookupState<enzyme::ForwardValueActivity>(result);
         if (forwardValueActivity) {
           std::string dest, key{"fva"};
           llvm::raw_string_ostream os(dest);
@@ -1023,7 +859,7 @@
         }
 
         auto backwardValueActivity =
-            solver.lookupState<BackwardValueActivity>(result);
+            solver.lookupState<enzyme::BackwardValueActivity>(result);
         if (backwardValueActivity) {
           std::string dest, key{"bva"};
           llvm::raw_string_ostream os(dest);
@@ -1040,7 +876,7 @@
     // Annotate function attributes
     for (BlockArgument arg : callee.getArguments()) {
       auto backwardValueActivity =
-          solver.lookupState<BackwardValueActivity>(arg);
+          solver.lookupState<enzyme::BackwardValueActivity>(arg);
       if (backwardValueActivity) {
         std::string dest;
         llvm::raw_string_ostream os(dest);
@@ -1051,20 +887,20 @@
     }
 
     for (Operation *returnOp : returnOps) {
-      auto *state = solver.lookupState<ForwardMemoryActivity>(
+      auto *state = solver.lookupState<enzyme::ForwardMemoryActivity>(
           solver.getProgramPointAfter(returnOp));
       if (state)
-        errs() << "forward end state:\n" << *state << "\n";
+        llvm::errs() << "forward end state:\n" << *state << "\n";
       else
-        errs() << "state was null\n";
+        llvm::errs() << "state was null\n";
     }
 
-    auto startState = solver.lookupState<BackwardMemoryActivity>(
+    auto startState = solver.lookupState<enzyme::BackwardMemoryActivity>(
         solver.getProgramPointAfter(&callee.getFunctionBody().front().front()));
     if (startState)
-      errs() << "backwards end state:\n" << *startState << "\n";
+      llvm::errs() << "backwards end state:\n" << *startState << "\n";
     else
-      errs() << "backwards end state was null\n";
+      llvm::errs() << "backwards end state was null\n";
   }
 }
 
@@ -1093,7 +929,8 @@
     // enzyme_dup, dupnoneed are initialized within the dense forward/backward
     // analyses, enzyme_const is the default.
     if (activity == enzyme::Activity::enzyme_active) {
-      auto *argLattice = solver.getOrCreateState<ForwardValueActivity>(arg);
+      auto *argLattice =
+          solver.getOrCreateState<enzyme::ForwardValueActivity>(arg);
       (void)argLattice->join(ValueActivity::getActiveVal());
     }
   }
@@ -1106,7 +943,7 @@
       returnOps.insert(&op);
       for (Value operand : op.getOperands()) {
         auto *returnLattice =
-            solver.getOrCreateState<BackwardValueActivity>(operand);
+            solver.getOrCreateState<enzyme::BackwardValueActivity>(operand);
         // Very basic type inference of the type
         if (isa<FloatType, ComplexType>(operand.getType())) {
           (void)returnLattice->meet(ValueActivity::getActiveVal());
diff --git a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.h b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.h
index 0aca1eb..4e41744 100644
--- a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.h
+++ b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.h
@@ -27,8 +27,20 @@
 #ifndef ENZYME_MLIR_ANALYSIS_DATAFLOW_ACTIVITYANALYSIS_H
 #define ENZYME_MLIR_ANALYSIS_DATAFLOW_ACTIVITYANALYSIS_H
 
+#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/IR/Block.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LLVM.h"
 
+#include "DataFlowAliasAnalysis.h"
+#include "Dialect/Ops.h"
+
+#include "Interfaces/AutoDiffOpInterface.h"
+
+using namespace mlir::dataflow;
 namespace mlir {
 class FunctionOpInterface;
 
@@ -36,6 +48,279 @@
 
 enum class Activity : uint32_t;
 
+// LLVM Enzyme's activity analysis has four activity states because it tracks
+// "constant instruction" separately from "constant value". A value/instruction
+// (one and the same in LLVM) can be a constant instruction but active value,
+// an active instruction but constant value, or active/constant in both senses.
+
+// The resulting activity states can differ between separate enzyme.autodiff
+// calls.
+enum class ActivityKind { Constant, ActiveVal, Unknown };
+
+class ValueActivity { // activity of one value: Constant/ActiveVal/Unknown
+public:
+  static ValueActivity getConstant() {
+    return ValueActivity(ActivityKind::Constant);
+  }
+
+  static ValueActivity getActiveVal() {
+    return ValueActivity(ActivityKind::ActiveVal);
+  }
+
+  static ValueActivity getUnknown() {
+    return ValueActivity(ActivityKind::Unknown);
+  }
+
+  bool isActiveVal() const { return value == ActivityKind::ActiveVal; }
+
+  bool isConstant() const { return value == ActivityKind::Constant; }
+
+  bool isUnknown() const { return value == ActivityKind::Unknown; }
+  /// Default construction yields Constant (see the initializer of `value`).
+  ValueActivity() {}
+  ValueActivity(ActivityKind value) : value(value) {}
+
+  /// Get the known activity state.
+  const ActivityKind &getValue() const { return value; }
+
+  bool operator==(const ValueActivity &rhs) const { return value == rhs.value; }
+  /// Unknown if either is unknown, Constant if both are, ActiveVal otherwise.
+  static ValueActivity merge(const ValueActivity &lhs,
+                             const ValueActivity &rhs) {
+    if (lhs.isUnknown() || rhs.isUnknown())
+      return ValueActivity::getUnknown();
+
+    if (lhs.isConstant() && rhs.isConstant())
+      return ValueActivity::getConstant();
+    return ValueActivity::getActiveVal();
+  }
+  /// Join is defined as merge.
+  static ValueActivity join(const ValueActivity &lhs,
+                            const ValueActivity &rhs) {
+    return ValueActivity::merge(lhs, rhs);
+  }
+  /// Declared only in this class body; no inline definition is given here.
+  void print(raw_ostream &os) const;
+  raw_ostream &operator<<(raw_ostream &os) const;
+
+private:
+  /// The activity kind. Optimistically initialized to constant.
+  ActivityKind value = ActivityKind::Constant;
+};
+
+//===----------------------------------------------------------------------===//
+// ForwardValueActivity
+//===----------------------------------------------------------------------===//
+class ForwardValueActivity : public Lattice<enzyme::ValueActivity> {
+public:
+  using Lattice::Lattice; // inherits the base constructors; adds no members
+};
+
+//===----------------------------------------------------------------------===//
+// BackwardValueActivity
+//===----------------------------------------------------------------------===//
+class BackwardValueActivity : public AbstractSparseLattice {
+public:
+  using AbstractSparseLattice::AbstractSparseLattice;
+  /// Returns a copy of the activity currently stored in this lattice element.
+  enzyme::ValueActivity getValue() const { return value; }
+
+  void print(raw_ostream &os) const override;
+
+  ChangeResult meet(const AbstractSparseLattice &other) override;
+  /// Overload taking a raw activity instead of another lattice element.
+  ChangeResult meet(enzyme::ValueActivity other);
+
+private:
+  enzyme::ValueActivity value;
+};
+
+/// Sparse activity analysis reasons about activity by traversing forward down
+/// the def-use chains starting from active function arguments.
+class SparseForwardActivityAnalysis
+    : public SparseForwardDataFlowAnalysis<enzyme::ForwardValueActivity> {
+public:
+  using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
+
+  void setToEntryState(enzyme::ForwardValueActivity *lattice) override;
+
+  LogicalResult
+  visitOperation(Operation *op,
+                 ArrayRef<const enzyme::ForwardValueActivity *> operands,
+                 ArrayRef<enzyme::ForwardValueActivity *> results) override;
+
+  void
+  visitExternalCall(CallOpInterface call,
+                    ArrayRef<const enzyme::ForwardValueActivity *> operands,
+                    ArrayRef<enzyme::ForwardValueActivity *> results) override;
+  /// Declared without `override`; presumably a shared helper -- TODO confirm.
+  void transfer(Operation *op,
+                ArrayRef<const enzyme::ForwardValueActivity *> operands,
+                ArrayRef<enzyme::ForwardValueActivity *> results);
+};
+
+class SparseBackwardActivityAnalysis
+    : public SparseBackwardDataFlowAnalysis<enzyme::BackwardValueActivity> {
+public:
+  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
+  /// Empty body: only a disabled debug print remains.
+  void setToExitState(enzyme::BackwardValueActivity *lattice) override {
+    // llvm::errs() << "backward sparse setting to exit state\n";
+  }
+  /// The next three hooks have empty bodies.
+  void visitBranchOperand(OpOperand &operand) override {}
+
+  void visitCallOperand(OpOperand &operand) override {}
+
+  void
+  visitNonControlFlowArguments(RegionSuccessor &successor,
+                               ArrayRef<BlockArgument> arguments) override {}
+  /// Declared without `override`; presumably a shared helper -- TODO confirm.
+  void transfer(Operation *op,
+                ArrayRef<enzyme::BackwardValueActivity *> operands,
+                ArrayRef<const enzyme::BackwardValueActivity *> results);
+
+  LogicalResult visitOperation(
+      Operation *op, ArrayRef<enzyme::BackwardValueActivity *> operands,
+      ArrayRef<const enzyme::BackwardValueActivity *> results) override;
+
+  void visitExternalCall(
+      CallOpInterface call, ArrayRef<enzyme::BackwardValueActivity *> operands,
+      ArrayRef<const enzyme::BackwardValueActivity *> results) override;
+};
+
+/// Activity of one memory location (alias class). Tracks:
+///   1. Could active info store in?   (activeIn)
+///   2. Could active info load out?   (activeOut)
+///   TODO: Necessary for run-time activity
+///   3. Could constant info propagate (store?) in?   (not tracked yet)
+///
+/// Active: (forward) active in && (backward) active out && (??) !const in
+/// ActiveOrConstant: active in && active out && const in
+/// Constant: everything else
+struct MemoryActivityState {
+  /// Whether active data has been stored into this memory location.
+  bool activeIn = false;
+  /// Whether active data was loaded out of this memory location.
+  bool activeOut = false;
+
+  /// Const-qualified so that const states can be compared too.
+  bool operator==(const MemoryActivityState &other) const {
+    return activeIn == other.activeIn && activeOut == other.activeOut;
+  }
+  bool operator!=(const MemoryActivityState &other) const {
+    return !(*this == other);
+  }
+
+  ChangeResult reset();
+  ChangeResult merge(const MemoryActivityState &other);
+};
+
+class MemoryActivity : public AbstractDenseLattice {
+public:
+  using AbstractDenseLattice::AbstractDenseLattice;
+
+  /// Clear all modifications.
+  ChangeResult reset();
+  /// Query the activeIn flag recorded for `aliasClass`.
+  bool hasActiveData(DistinctAttr aliasClass) const;
+  /// Query the activeOut flag recorded for `aliasClass`.
+  bool activeDataFlowsOut(DistinctAttr aliasClass) const;
+
+  /// Set the internal activity state. Accepts null attribute to indicate "other
+  /// classes".
+  ChangeResult setActiveIn(DistinctAttr aliasClass);
+  ChangeResult setActiveIn();
+  ChangeResult setActiveOut(DistinctAttr aliasClass);
+  ChangeResult setActiveOut();
+  void print(raw_ostream &os) const override;
+  raw_ostream &operator<<(raw_ostream &os) const;
+
+protected:
+  ChangeResult merge(const AbstractDenseLattice &lattice); // join/meet impl
+
+private:
+  DenseMap<DistinctAttr, MemoryActivityState> activityStates;
+  MemoryActivityState otherMemoryActivity; // classes without an entry
+};
+
+class ForwardMemoryActivity : public MemoryActivity {
+public:
+  using MemoryActivity::MemoryActivity;
+
+  /// Join the activity states; overrides the lattice's virtual join.
+  ChangeResult join(const AbstractDenseLattice &lattice) override {
+    return merge(lattice);
+  }
+};
+
+class BackwardMemoryActivity : public MemoryActivity {
+public:
+  using MemoryActivity::MemoryActivity;
+
+  ChangeResult meet(const AbstractDenseLattice &lattice) override {
+    return merge(lattice); // same merge the forward lattice uses for join
+  }
+};
+
+class DenseForwardActivityAnalysis
+    : public DenseForwardDataFlowAnalysis<ForwardMemoryActivity> {
+public:
+  DenseForwardActivityAnalysis(DataFlowSolver &solver, Block *entryBlock,
+                               ArrayRef<enzyme::Activity> argumentActivity)
+      : DenseForwardDataFlowAnalysis(solver), entryBlock(entryBlock),
+        argumentActivity(argumentActivity) {}
+  /// Per-operation transfer function; only declared in this class body.
+  LogicalResult visitOperation(Operation *op,
+                               const ForwardMemoryActivity &before,
+                               ForwardMemoryActivity *after) override;
+  /// For calls, the state before is joined into the state after.
+  void visitCallControlFlowTransfer(CallOpInterface call,
+                                    CallControlFlowAction action,
+                                    const ForwardMemoryActivity &before,
+                                    ForwardMemoryActivity *after) override {
+    join(after, before);
+  }
+
+  /// Initialize the entry block with the supplied argument activities.
+  void setToEntryState(ForwardMemoryActivity *lattice) override;
+
+private:
+  // A pointer to the entry block and argument activities of the top-level
+  // function being differentiated. This is used to set the entry state
+  // because we need access to the results of points-to analysis.
+  Block *entryBlock;
+  SmallVector<enzyme::Activity> argumentActivity;
+};
+
+class DenseBackwardActivityAnalysis
+    : public DenseBackwardDataFlowAnalysis<BackwardMemoryActivity> {
+public:
+  DenseBackwardActivityAnalysis(DataFlowSolver &solver,
+                                SymbolTableCollection &symbolTable,
+                                FunctionOpInterface parentOp,
+                                ArrayRef<enzyme::Activity> argumentActivity)
+      : DenseBackwardDataFlowAnalysis(solver, symbolTable), parentOp(parentOp),
+        argumentActivity(argumentActivity) {}
+  /// Per-operation transfer function; only declared in this class body.
+  LogicalResult visitOperation(Operation *op,
+                               const BackwardMemoryActivity &after,
+                               BackwardMemoryActivity *before) override;
+  /// For calls, the state after is met into the state before.
+  void visitCallControlFlowTransfer(CallOpInterface call,
+                                    CallControlFlowAction action,
+                                    const BackwardMemoryActivity &after,
+                                    BackwardMemoryActivity *before) override {
+    meet(before, after);
+  }
+  /// Empty body: no exit-state initialization happens here.
+  void setToExitState(BackwardMemoryActivity *lattice) override {}
+
+private:
+  FunctionOpInterface parentOp; // function whose return-like ops seed activity
+  SmallVector<enzyme::Activity> argumentActivity;
+};
+
 void runDataFlowActivityAnalysis(FunctionOpInterface callee,
                                  ArrayRef<enzyme::Activity> argumentActivity,
                                  bool print = false, bool verbose = false,