[WIP] Work on SROA input functions (use alias analysis in GetFunctionFromValue)
diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp
index 793d6da..669e6d0 100644
--- a/enzyme/Enzyme/Enzyme.cpp
+++ b/enzyme/Enzyme/Enzyme.cpp
@@ -530,7 +530,7 @@
     // initializeLowerAutodiffIntrinsicPass(*PassRegistry::getPassRegistry());
   }
 
-  Function *parseFunctionParameter(CallInst *CI) {
+  Function *parseFunctionParameter(CallInst *CI, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI) {
     Value *fn = CI->getArgOperand(0);
 
     // determine function to differentiate
@@ -539,7 +539,7 @@
     }
 
     Value *ofn = fn;
-    fn = GetFunctionFromValue(fn);
+    fn = GetFunctionFromValue(fn, AA, TLI);
 
     if (!fn || !isa<Function>(fn)) {
       assert(ofn);
@@ -1224,14 +1224,14 @@
     return type_args;
   }
 
-  bool HandleBatch(CallInst *CI) {
+  bool HandleBatch(CallInst *CI, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI) {
     unsigned width = 1;
     unsigned truei = 0;
     std::map<unsigned, Value *> batchOffset;
     SmallVector<Value *, 4> args;
     SmallVector<BATCH_TYPE, 4> arg_types;
     IRBuilder<> Builder(CI);
-    Function *F = parseFunctionParameter(CI);
+    Function *F = parseFunctionParameter(CI, AA, TLI);
     if (!F)
       return false;
 
@@ -1757,11 +1757,11 @@
   }
 
   /// Return whether successful
-  bool HandleAutoDiffArguments(CallInst *CI, DerivativeMode mode,
+  bool HandleAutoDiffArguments(CallInst *CI, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, DerivativeMode mode,
                                bool sizeOnly) {
 
     // determine function to differentiate
-    Function *fn = parseFunctionParameter(CI);
+    Function *fn = parseFunctionParameter(CI, AA, TLI);
     if (!fn)
       return false;
 
@@ -1805,9 +1805,10 @@
 #endif
   }
 
-  bool HandleProbProg(CallInst *CI, ProbProgMode mode) {
+  bool HandleProbProg(CallInst *CI, llvm::AAResults &AA,
+                      llvm::TargetLibraryInfo &TLI, ProbProgMode mode) {
     IRBuilder<> Builder(CI);
-    Function *F = parseFunctionParameter(CI);
+    Function *F = parseFunctionParameter(CI, AA, TLI);
     if (!F)
       return false;
 
@@ -2012,6 +2013,8 @@
     MapVector<CallInst *, ProbProgMode> toProbProg;
     SetVector<CallInst *> InactiveCalls;
     SetVector<CallInst *> IterCalls;
+    auto &TLI = Logic.PPC.FAM.getResult<TargetLibraryAnalysis>(F);
+    auto &AA = Logic.PPC.getAAResultsFromFunction(&F);
   retry:;
     for (BasicBlock &BB : F) {
       for (Instruction &I : BB) {
@@ -2364,15 +2367,7 @@
         if (enableEnzyme) {
 
           Value *fn = CI->getArgOperand(0);
-          while (auto ci = dyn_cast<CastInst>(fn)) {
-            fn = ci->getOperand(0);
-          }
-          while (auto ci = dyn_cast<BlockAddress>(fn)) {
-            fn = ci->getFunction();
-          }
-          while (auto ci = dyn_cast<ConstantExpr>(fn)) {
-            fn = ci->getOperand(0);
-          }
+          GetFunctionFromValue(fn, AA, TLI, &fn);
           if (auto si = dyn_cast<SelectInst>(fn)) {
             BasicBlock *post = BB.splitBasicBlock(CI);
             BasicBlock *sel1 = BasicBlock::Create(BB.getContext(), "sel1", &F);
@@ -2449,14 +2444,14 @@
 
     // Perform all the size replacements first to create constants
     for (auto pair : toSize) {
-      bool successful = HandleAutoDiffArguments(pair.first, pair.second,
+      bool successful = HandleAutoDiffArguments(pair.first, AA, TLI, pair.second,
                                                 /*sizeOnly*/ true);
       Changed = true;
       if (!successful)
         break;
     }
     for (auto pair : toLower) {
-      bool successful = HandleAutoDiffArguments(pair.first, pair.second,
+      bool successful = HandleAutoDiffArguments(pair.first, AA, TLI, pair.second,
                                                 /*sizeOnly*/ false);
       Changed = true;
       if (!successful)
@@ -2491,11 +2486,11 @@
     }
 
     for (auto call : toBatch) {
-      HandleBatch(call);
+      HandleBatch(call, AA, TLI);
     }
 
     for (auto &&[call, mode] : toProbProg) {
-      HandleProbProg(call, mode);
+      HandleProbProg(call, AA, TLI, mode);
     }
 
     if (Changed && EnzymeAttributor) {
@@ -2685,7 +2680,7 @@
                     "IllegalNumberOfArguments", CI->getDebugLoc(), CI,
                     "Not enough arguments passed to call to __enzyme_sample");
               }
-              Function *samplefn = GetFunctionFromValue(CI->getOperand(0));
+              Function *samplefn = GetFunctionFromValue(CI->getOperand(0), AA, TLI);
               unsigned expected =
                   samplefn->getFunctionType()->getNumParams() + 3;
 #if LLVM_VERSION_MAJOR >= 14
@@ -2699,7 +2694,7 @@
                             "__enzyme_sample.",
                             " Expected: ", expected, " got: ", actual);
               }
-              Function *pdf = GetFunctionFromValue(CI->getArgOperand(1));
+              Function *pdf = GetFunctionFromValue(CI->getArgOperand(1), AA, TLI);
 
               for (unsigned i = 0;
                    i < samplefn->getFunctionType()->getNumParams(); ++i) {
@@ -2744,7 +2739,7 @@
                     "Not enough arguments passed to call to __enzyme_sample");
               }
               Value *observed = CI->getOperand(0);
-              Function *pdf = GetFunctionFromValue(CI->getArgOperand(1));
+              Function *pdf = GetFunctionFromValue(CI->getArgOperand(1), AA, TLI);
               unsigned expected = pdf->getFunctionType()->getNumParams() - 1;
 
 #if LLVM_VERSION_MAJOR >= 14
diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp
index 94de076..9d4a4b5 100644
--- a/enzyme/Enzyme/Utils.cpp
+++ b/enzyme/Enzyme/Utils.cpp
@@ -2209,7 +2209,7 @@
   llvm_unreachable("unknown inst2");
 }
 
-Function *GetFunctionFromValue(Value *fn) {
+Function *GetFunctionFromValue(Value *fn, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, llvm::Value **lastVal) {
   while (!isa<Function>(fn)) {
     if (auto ci = dyn_cast<CastInst>(fn)) {
       fn = ci->getOperand(0);
@@ -2321,8 +2321,12 @@
           if (isa<LoadInst>(cur))
             continue;
 
-          if (!cur->mayWriteToMemory() && cur->getType()->isVoidTy())
-            continue;
+          if (cur->getType()->isVoidTy()) {
+            if (!cur->mayWriteToMemory())
+              continue;
+            if (!writesToMemoryReadBy(AA, TLI, LI, cur))
+              continue;
+          }
 
           legal = false;
           break;
@@ -2337,6 +2341,7 @@
     break;
   }
 
+  if (lastVal) *lastVal = fn;
   return dyn_cast<Function>(fn);
 }
 
diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h
index 3ace540..12ea7ac 100644
--- a/enzyme/Enzyme/Utils.h
+++ b/enzyme/Enzyme/Utils.h
@@ -1192,7 +1192,7 @@
                             llvm::Value *shadow, const char *Message,
                             llvm::DebugLoc &&loc, llvm::Instruction *orig);
 
-llvm::Function *GetFunctionFromValue(llvm::Value *fn);
+llvm::Function *GetFunctionFromValue(llvm::Value *fn, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, llvm::Value **lastVal=nullptr);
 
 static inline bool shouldDisableNoWrite(const llvm::CallInst *CI) {
   auto F = getFunctionFromCall(CI);