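Fix cbrt derivative and unify callee lookup via getFunctionFromCall

- cbrt: d/dx cbrt(x) = cbrt(x) / (3x); the adjoint previously computed
  x / (3 * cbrt(x)). Update the ReverseMode/cbrt.ll CHECK lines to match.
- For calls without a direct Function callee, bitcast arguments whose types
  differ from the augmented/gradient function's parameter types instead of
  bailing out immediately.
- Replace the duplicated "strip a cast ConstantExpr to find the callee"
  snippets with getFunctionFromCall, now a template so it also accepts
  InvokeInst. Directly-called allocation functions are now treated the same
  as deallocation functions in the early-return check.
- TypeAnalysis: mark an integer result of a mem-transfer call as Integer.
- Print the function and argument before asserting on a missing
  uncacheable-args entry.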
diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h
index 980487c..0986fd4 100644
--- a/enzyme/Enzyme/AdjointGenerator.h
+++ b/enzyme/Enzyme/AdjointGenerator.h
@@ -5748,8 +5748,8 @@
         cubcall->setDebugLoc(gutils->getNewFromOriginal(orig->getDebugLoc()));
         cubcall->setCallingConv(orig->getCallingConv());
         Value *dif0 = Builder2.CreateFDiv(
-            Builder2.CreateFMul(diffe(orig, Builder2), x),
-            Builder2.CreateFMul(ConstantFP::get(x->getType(), 3), cubcall));
+            Builder2.CreateFMul(diffe(orig, Builder2), cubcall),
+            Builder2.CreateFMul(ConstantFP::get(x->getType(), 3), x));
         addToDiffe(orig->getArgOperand(0), dif0, Builder2, x->getType());
         return;
       }
@@ -6876,7 +6876,16 @@
         goto badaugmentedfn;
 
       for (unsigned i = 0; i < pre_args.size(); ++i) {
-        if (pre_args[i]->getType() != FT->getParamType(i))
+        // Without a direct Function callee, argument types may differ from
+        // FT; bitcast them when the cast is legal rather than bailing out.
+        if (pre_args[i]->getType() == FT->getParamType(i))
+          continue;
+        else if (!orig->getCalledFunction() &&
+                 CastInst::isBitCastable(pre_args[i]->getType(),
+                                         FT->getParamType(i)))
+          pre_args[i] =
+              BuilderZ.CreateBitCast(pre_args[i], FT->getParamType(i));
+        else
           goto badaugmentedfn;
       }
 
@@ -7194,7 +7203,15 @@
         goto badfn;
 
       for (unsigned i = 0; i < args.size(); ++i) {
-        if (args[i]->getType() != FT->getParamType(i))
+        // Without a direct Function callee, argument types may differ from
+        // FT; bitcast them when the cast is legal rather than bailing out.
+        if (args[i]->getType() == FT->getParamType(i))
+          continue;
+        else if (!orig->getCalledFunction() &&
+                 CastInst::isBitCastable(args[i]->getType(),
+                                         FT->getParamType(i)))
+          args[i] = Builder2.CreateBitCast(args[i], FT->getParamType(i));
+        else
           goto badfn;
       }
 
diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp
index c70fddb..1baa5c3 100644
--- a/enzyme/Enzyme/EnzymeLogic.cpp
+++ b/enzyme/Enzyme/EnzymeLogic.cpp
@@ -505,12 +505,7 @@
 
   std::map<Argument *, bool>
   compute_uncacheable_args_for_one_callsite(CallInst *callsite_op) {
-    Function *Fn = callsite_op->getCalledFunction();
-
-#if LLVM_VERSION_MAJOR >= 11
-    if (auto alias = dyn_cast<GlobalAlias>(callsite_op->getCalledOperand()))
-      Fn = dyn_cast<Function>(alias->getAliasee());
-#endif
+    Function *Fn = getFunctionFromCall(callsite_op);
 
     if (!Fn)
       return {};
@@ -575,19 +570,7 @@
     allFollowersOf(callsite_op, [&](Instruction *inst2) {
       // Don't consider modref from malloc/free as a need to cache
       if (auto obj_op = dyn_cast<CallInst>(inst2)) {
-        Function *called = obj_op->getCalledFunction();
-#if LLVM_VERSION_MAJOR >= 11
-        if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledOperand()))
-#else
-        if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue()))
-#endif
-        {
-          if (castinst->isCast()) {
-            if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
-              called = fn;
-            }
-          }
-        }
+        Function *called = getFunctionFromCall(obj_op);
         if (called && isCertainPrintMallocOrFree(called)) {
           return false;
         }
@@ -802,20 +785,7 @@
 
       bool isLibMFn = false;
       if (auto obj_op = dyn_cast<CallInst>(inst)) {
-        Function *called = obj_op->getCalledFunction();
-#if LLVM_VERSION_MAJOR >= 11
-        if (auto castinst =
-                dyn_cast<ConstantExpr>(obj_op->getCalledOperand())) {
-#else
-        if (auto castinst =
-                dyn_cast<ConstantExpr>(obj_op->getCalledValue())) {
-#endif
-          if (castinst->isCast()) {
-            if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
-              called = fn;
-            }
-          }
-        }
+        Function *called = getFunctionFromCall(obj_op);
         if (called && isDeallocationFunction(*called, TLI)) {
           if ((mode == DerivativeMode::ReverseModePrimal ||
                mode == DerivativeMode::ReverseModeCombined) &&
@@ -1213,19 +1183,9 @@
     }
 
     if (auto op = dyn_cast<CallInst>(I)) {
-      Function *called = op->getCalledFunction();
-
-      if (auto castinst = dyn_cast<ConstantExpr>(calledValue)) {
-        if (castinst->isCast()) {
-          if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
-            if (isAllocationFunction(*fn, gutils->TLI) ||
-                isDeallocationFunction(*fn, gutils->TLI)) {
-              return;
-            }
-          }
-        }
-      }
-      if (called && isDeallocationFunction(*called, gutils->TLI))
+      Function *called = getFunctionFromCall(op);
+      if (called && (isAllocationFunction(*called, gutils->TLI) ||
+                     isDeallocationFunction(*called, gutils->TLI)))
         return;
     }
 
@@ -1565,6 +1525,10 @@
   auto in_arg = todiff->arg_begin();
   auto pp_arg = gutils->oldFunc->arg_begin();
   for (; pp_arg != gutils->oldFunc->arg_end();) {
+    if (_uncacheable_args.find(in_arg) == _uncacheable_args.end()) {
+      llvm::errs() << " todiff: " << *todiff << "\n";
+      llvm::errs() << " inargs: " << *in_arg << "\n";
+    }
     assert(_uncacheable_args.find(in_arg) != _uncacheable_args.end());
     _uncacheable_argsPP[pp_arg] = _uncacheable_args.find(in_arg)->second;
     ++pp_arg;
diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp
index d2ab3c9..536ee2d 100644
--- a/enzyme/Enzyme/GradientUtils.cpp
+++ b/enzyme/Enzyme/GradientUtils.cpp
@@ -2607,19 +2607,7 @@
             if (isa<IntrinsicInst>(CI))
               continue;
             if (!isConstantInstruction(CI)) {
-              Function *F = CI->getCalledFunction();
-#if LLVM_VERSION_MAJOR >= 11
-              if (auto castinst =
-                      dyn_cast<ConstantExpr>(CI->getCalledOperand()))
-#else
-              if (auto castinst = dyn_cast<ConstantExpr>(CI->getCalledValue()))
-#endif
-              {
-                if (castinst->isCast())
-                  if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
-                    F = fn;
-                  }
-              }
+              Function *F = getFunctionFromCall(CI);
               if (F && (isMemFreeLibMFunction(F->getName()) ||
                         F->getName() == "__fd_sincos_1")) {
                 continue;
@@ -4769,20 +4757,7 @@
           }
         }
       } else if (auto CI = dyn_cast<CallInst>(&I)) {
-        Function *F = CI->getCalledFunction();
-
-#if LLVM_VERSION_MAJOR >= 11
-        if (auto castinst = dyn_cast<ConstantExpr>(CI->getCalledOperand()))
-#else
-        if (auto castinst = dyn_cast<ConstantExpr>(CI->getCalledValue()))
-#endif
-        {
-          if (castinst->isCast())
-            if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
-              F = fn;
-            }
-        }
-
+        Function *F = getFunctionFromCall(CI);
         if (F && isAllocationFunction(*F, TLI))
           Available[CI] = CI;
       }
diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h
index bff5d1f..e1a6d1e 100644
--- a/enzyme/Enzyme/GradientUtils.h
+++ b/enzyme/Enzyme/GradientUtils.h
@@ -511,18 +511,7 @@
         if (!CI)
           continue;
 
-        Function *called = CI->getCalledFunction();
-
-#if LLVM_VERSION_MAJOR >= 11
-        if (auto castinst = dyn_cast<ConstantExpr>(CI->getCalledOperand()))
-#else
-        if (auto castinst = dyn_cast<ConstantExpr>(CI->getCalledValue()))
-#endif
-          if (castinst->isCast()) {
-            if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
-              called = fn;
-            }
-          }
+        Function *called = getFunctionFromCall(CI);
         if (!called)
           continue;
         if (isDeallocationFunction(*called, TLI)) {
diff --git a/enzyme/Enzyme/LibraryFuncs.h b/enzyme/Enzyme/LibraryFuncs.h
index b7149a6..f02fb62 100644
--- a/enzyme/Enzyme/LibraryFuncs.h
+++ b/enzyme/Enzyme/LibraryFuncs.h
@@ -369,19 +369,7 @@
          maybeWriter->getParent()->getParent());
   using namespace llvm;
   if (auto call = dyn_cast<CallInst>(maybeWriter)) {
-    Function *called = call->getCalledFunction();
-#if LLVM_VERSION_MAJOR >= 11
-    if (auto castinst = dyn_cast<ConstantExpr>(call->getCalledOperand()))
-#else
-    if (auto castinst = dyn_cast<ConstantExpr>(call->getCalledValue()))
-#endif
-    {
-      if (castinst->isCast()) {
-        if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
-          called = fn;
-        }
-      }
-    }
+    Function *called = getFunctionFromCall(call);
     if (called && isCertainPrintMallocOrFree(called)) {
       return false;
     }
@@ -414,19 +402,7 @@
     }
   }
   if (auto call = dyn_cast<CallInst>(maybeReader)) {
-    Function *called = call->getCalledFunction();
-#if LLVM_VERSION_MAJOR >= 11
-    if (auto castinst = dyn_cast<ConstantExpr>(call->getCalledOperand()))
-#else
-    if (auto castinst = dyn_cast<ConstantExpr>(call->getCalledValue()))
-#endif
-    {
-      if (castinst->isCast()) {
-        if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
-          called = fn;
-        }
-      }
-    }
+    Function *called = getFunctionFromCall(call);
     if (called && isCertainMallocOrFree(called)) {
       return false;
     }
@@ -447,19 +423,7 @@
     }
   }
   if (auto call = dyn_cast<InvokeInst>(maybeWriter)) {
-    Function *called = call->getCalledFunction();
-#if LLVM_VERSION_MAJOR >= 11
-    if (auto castinst = dyn_cast<ConstantExpr>(call->getCalledOperand()))
-#else
-    if (auto castinst = dyn_cast<ConstantExpr>(call->getCalledValue()))
-#endif
-    {
-      if (castinst->isCast()) {
-        if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
-          called = fn;
-        }
-      }
-    }
+    Function *called = getFunctionFromCall(call);
     if (called && isCertainMallocOrFree(called)) {
       return false;
     }
@@ -480,19 +444,7 @@
     }
   }
   if (auto call = dyn_cast<InvokeInst>(maybeReader)) {
-    Function *called = call->getCalledFunction();
-#if LLVM_VERSION_MAJOR >= 11
-    if (auto castinst = dyn_cast<ConstantExpr>(call->getCalledOperand()))
-#else
-    if (auto castinst = dyn_cast<ConstantExpr>(call->getCalledValue()))
-#endif
-    {
-      if (castinst->isCast()) {
-        if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
-          called = fn;
-        }
-      }
-    }
+    Function *called = getFunctionFromCall(call);
     if (called && isCertainMallocOrFree(called)) {
       return false;
     }
diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
index 0c5f30a..80f1d2e 100644
--- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
+++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
@@ -2429,6 +2429,10 @@
 }
 
 void TypeAnalyzer::visitMemTransferCommon(llvm::CallInst &MTI) {
+  if (MTI.getType()->isIntegerTy()) {
+    updateAnalysis(&MTI, TypeTree(BaseType::Integer).Only(-1), &MTI);
+  }
+
   if (!(direction & UP))
     return;
 
diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h
index 02e50be..50bbce5 100644
--- a/enzyme/Enzyme/Utils.h
+++ b/enzyme/Enzyme/Utils.h
@@ -821,7 +821,7 @@
   virtual ~AssertingReplacingVH() {}
 };
 
-static inline llvm::Function *getFunctionFromCall(llvm::CallInst *op) {
+template <typename T> static inline llvm::Function *getFunctionFromCall(T *op) {
   llvm::Function *called = nullptr;
   using namespace llvm;
   llvm::Value *callVal;
diff --git a/enzyme/test/Enzyme/ReverseMode/cbrt.ll b/enzyme/test/Enzyme/ReverseMode/cbrt.ll
index 8c3513e..be185f6 100644
--- a/enzyme/test/Enzyme/ReverseMode/cbrt.ll
+++ b/enzyme/test/Enzyme/ReverseMode/cbrt.ll
@@ -21,8 +21,8 @@
 ; CHECK: define internal { double } @diffetester(double %x, double %differeturn) {
 ; CHECK-NEXT: entry:
 ; CHECK-NEXT:   %0 = call fast double @cbrt(double %x)
-; CHECK-NEXT:   %1 = fmul fast double 3.000000e+00, %0
-; CHECK-NEXT:   %2 = fmul fast double %differeturn, %x
+; CHECK-NEXT:   %1 = fmul fast double 3.000000e+00, %x
+; CHECK-NEXT:   %2 = fmul fast double %differeturn, %0
 ; CHECK-NEXT:   %3 = fdiv fast double %2, %1
 ; CHECK-NEXT:   %4 = insertvalue { double } undef, double %3, 0
 ; CHECK-NEXT:   ret { double } %4