[WIP] Work on resolving SROA-able input functions via AA/TLI
diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 793d6da..669e6d0 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp
@@ -530,7 +530,7 @@ // initializeLowerAutodiffIntrinsicPass(*PassRegistry::getPassRegistry()); } - Function *parseFunctionParameter(CallInst *CI) { + Function *parseFunctionParameter(CallInst *CI, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI) { Value *fn = CI->getArgOperand(0); // determine function to differentiate @@ -539,7 +539,7 @@ } Value *ofn = fn; - fn = GetFunctionFromValue(fn); + fn = GetFunctionFromValue(fn, AA, TLI); if (!fn || !isa<Function>(fn)) { assert(ofn); @@ -1224,14 +1224,14 @@ return type_args; } - bool HandleBatch(CallInst *CI) { + bool HandleBatch(CallInst *CI, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI) { unsigned width = 1; unsigned truei = 0; std::map<unsigned, Value *> batchOffset; SmallVector<Value *, 4> args; SmallVector<BATCH_TYPE, 4> arg_types; IRBuilder<> Builder(CI); - Function *F = parseFunctionParameter(CI); + Function *F = parseFunctionParameter(CI, AA, TLI); if (!F) return false; @@ -1757,11 +1757,11 @@ } /// Return whether successful - bool HandleAutoDiffArguments(CallInst *CI, DerivativeMode mode, + bool HandleAutoDiffArguments(CallInst *CI, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,DerivativeMode mode, bool sizeOnly) { // determine function to differentiate - Function *fn = parseFunctionParameter(CI); + Function *fn = parseFunctionParameter(CI, AA, TLI); if (!fn) return false; @@ -1805,9 +1805,10 @@ #endif } - bool HandleProbProg(CallInst *CI, ProbProgMode mode) { + bool HandleProbProg(CallInst *CI, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, +ProbProgMode mode) { IRBuilder<> Builder(CI); - Function *F = parseFunctionParameter(CI); + Function *F = parseFunctionParameter(CI, AA, TLI); if (!F) return false; @@ -2012,6 +2013,8 @@ MapVector<CallInst *, ProbProgMode> toProbProg; SetVector<CallInst *> InactiveCalls; SetVector<CallInst *> IterCalls; + auto &TLI = Logic.PPC.FAM.getResult<TargetLibraryAnalysis>(F); + auto &AA = Logic.PPC.getAAResultsFromFunction(&F); retry:; for (BasicBlock &BB : F) { for 
(Instruction &I : BB) { @@ -2364,15 +2367,7 @@ if (enableEnzyme) { Value *fn = CI->getArgOperand(0); - while (auto ci = dyn_cast<CastInst>(fn)) { - fn = ci->getOperand(0); - } - while (auto ci = dyn_cast<BlockAddress>(fn)) { - fn = ci->getFunction(); - } - while (auto ci = dyn_cast<ConstantExpr>(fn)) { - fn = ci->getOperand(0); - } + GetFunctionFromValue(fn, AA, TLI, &fn); if (auto si = dyn_cast<SelectInst>(fn)) { BasicBlock *post = BB.splitBasicBlock(CI); BasicBlock *sel1 = BasicBlock::Create(BB.getContext(), "sel1", &F); @@ -2449,14 +2444,14 @@ // Perform all the size replacements first to create constants for (auto pair : toSize) { - bool successful = HandleAutoDiffArguments(pair.first, pair.second, + bool successful = HandleAutoDiffArguments(pair.first, AA, TLI, pair.second, /*sizeOnly*/ true); Changed = true; if (!successful) break; } for (auto pair : toLower) { - bool successful = HandleAutoDiffArguments(pair.first, pair.second, + bool successful = HandleAutoDiffArguments(pair.first, AA, TLI, pair.second, /*sizeOnly*/ false); Changed = true; if (!successful) @@ -2491,11 +2486,11 @@ } for (auto call : toBatch) { - HandleBatch(call); + HandleBatch(call, AA, TLI); } for (auto &&[call, mode] : toProbProg) { - HandleProbProg(call, mode); + HandleProbProg(call, AA, TLI, mode); } if (Changed && EnzymeAttributor) { @@ -2685,7 +2680,7 @@ "IllegalNumberOfArguments", CI->getDebugLoc(), CI, "Not enough arguments passed to call to __enzyme_sample"); } - Function *samplefn = GetFunctionFromValue(CI->getOperand(0)); + Function *samplefn = GetFunctionFromValue(CI->getOperand(0), AA, TLI); unsigned expected = samplefn->getFunctionType()->getNumParams() + 3; #if LLVM_VERSION_MAJOR >= 14 @@ -2699,7 +2694,7 @@ "__enzyme_sample.", " Expected: ", expected, " got: ", actual); } - Function *pdf = GetFunctionFromValue(CI->getArgOperand(1)); + Function *pdf = GetFunctionFromValue(CI->getArgOperand(1), AA, TLI); for (unsigned i = 0; i < samplefn->getFunctionType()->getNumParams(); ++i) 
{ @@ -2744,7 +2739,7 @@ "Not enough arguments passed to call to __enzyme_sample"); } Value *observed = CI->getOperand(0); - Function *pdf = GetFunctionFromValue(CI->getArgOperand(1)); + Function *pdf = GetFunctionFromValue(CI->getArgOperand(1), AA, TLI); unsigned expected = pdf->getFunctionType()->getNumParams() - 1; #if LLVM_VERSION_MAJOR >= 14
diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 94de076..9d4a4b5 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp
@@ -2209,7 +2209,7 @@ llvm_unreachable("unknown inst2"); } -Function *GetFunctionFromValue(Value *fn) { +Function *GetFunctionFromValue(Value *fn, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, llvm::Value **lastVal) { while (!isa<Function>(fn)) { if (auto ci = dyn_cast<CastInst>(fn)) { fn = ci->getOperand(0); @@ -2321,8 +2321,12 @@ if (isa<LoadInst>(cur)) continue; - if (!cur->mayWriteToMemory() && cur->getType()->isVoidTy()) - continue; + if (cur->getType()->isVoidTy()) { + if (!cur->mayWriteToMemory()) + continue; + if (!writesToMemoryReadBy(AA, TLI, LI, cur)) + continue; + } legal = false; break; @@ -2337,6 +2341,7 @@ break; } + if (lastVal) *lastVal = fn; return dyn_cast<Function>(fn); }
diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 3ace540..12ea7ac 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h
@@ -1192,7 +1192,7 @@ llvm::Value *shadow, const char *Message, llvm::DebugLoc &&loc, llvm::Instruction *orig); -llvm::Function *GetFunctionFromValue(llvm::Value *fn); +llvm::Function *GetFunctionFromValue(llvm::Value *fn, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, llvm::Value **lastVal=nullptr); static inline bool shouldDisableNoWrite(const llvm::CallInst *CI) { auto F = getFunctionFromCall(CI);