| //===- CApi.cpp - Enzyme API exported to C for external use -----------===// |
| // |
| // Enzyme Project |
| // |
| // Part of the Enzyme Project, under the Apache License v2.0 with LLVM |
| // Exceptions. See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| // If using this code in an academic setting, please cite the following: |
| // @incollection{enzymeNeurips, |
| // title = {Instead of Rewriting Foreign Code for Machine Learning, |
| // Automatically Synthesize Fast Gradients}, |
| // author = {Moses, William S. and Churavy, Valentin}, |
| // booktitle = {Advances in Neural Information Processing Systems 33}, |
| // year = {2020}, |
| // note = {To appear in}, |
| // } |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file defines various utility functions of Enzyme for access via C |
| // |
| //===----------------------------------------------------------------------===// |
| #include "CApi.h" |
| #if LLVM_VERSION_MAJOR >= 16 |
| #define private public |
| #include "llvm/Analysis/ScalarEvolution.h" |
| #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" |
| #undef private |
| #else |
| #include "SCEV/ScalarEvolution.h" |
| #include "SCEV/ScalarEvolutionExpander.h" |
| #endif |
| |
| #include "DiffeGradientUtils.h" |
| #include "DifferentialUseAnalysis.h" |
| #include "EnzymeLogic.h" |
| #include "GradientUtils.h" |
| #include "LibraryFuncs.h" |
| #if LLVM_VERSION_MAJOR >= 16 |
| #include "llvm/Analysis/TargetLibraryInfo.h" |
| #else |
| #include "SCEV/TargetLibraryInfo.h" |
| #endif |
| #include "TraceInterface.h" |
| |
| // #include "llvm/ADT/Triple.h" |
| #include "llvm/Analysis/CallGraph.h" |
| #include "llvm/Analysis/GlobalsModRef.h" |
| #include "llvm/IR/DIBuilder.h" |
| #include "llvm/IR/MDBuilder.h" |
| #include "llvm/Transforms/Utils/Cloning.h" |
| |
| #include "llvm/IR/LegacyPassManager.h" |
| #include "llvm/Transforms/IPO/Attributor.h" |
| |
| #define addAttribute addAttributeAtIndex |
| #define removeAttribute removeAttributeAtIndex |
| #define getAttribute getAttributeAtIndex |
| #define hasAttribute hasAttributeAtIndex |
| |
| using namespace llvm; |
| |
| TargetLibraryInfo eunwrap(LLVMTargetLibraryInfoRef P) { |
| return TargetLibraryInfo(*reinterpret_cast<TargetLibraryInfoImpl *>(P)); |
| } |
| |
| EnzymeLogic &eunwrap(EnzymeLogicRef LR) { return *(EnzymeLogic *)LR; } |
| |
| TraceInterface *eunwrap(EnzymeTraceInterfaceRef Ref) { |
| return (TraceInterface *)Ref; |
| } |
| |
| TypeAnalysis &eunwrap(EnzymeTypeAnalysisRef TAR) { |
| return *(TypeAnalysis *)TAR; |
| } |
| AugmentedReturn *eunwrap(EnzymeAugmentedReturnPtr ARP) { |
| return (AugmentedReturn *)ARP; |
| } |
| EnzymeAugmentedReturnPtr ewrap(const AugmentedReturn &AR) { |
| return (EnzymeAugmentedReturnPtr)(&AR); |
| } |
| |
| ConcreteType eunwrap(CConcreteType CDT, llvm::LLVMContext &ctx) { |
| switch (CDT) { |
| case DT_Anything: |
| return BaseType::Anything; |
| case DT_Integer: |
| return BaseType::Integer; |
| case DT_Pointer: |
| return BaseType::Pointer; |
| case DT_Half: |
| return ConcreteType(llvm::Type::getHalfTy(ctx)); |
| case DT_Float: |
| return ConcreteType(llvm::Type::getFloatTy(ctx)); |
| case DT_Double: |
| return ConcreteType(llvm::Type::getDoubleTy(ctx)); |
| case DT_X86_FP80: |
| return ConcreteType(llvm::Type::getX86_FP80Ty(ctx)); |
| case DT_BFloat16: |
| return ConcreteType(llvm::Type::getBFloatTy(ctx)); |
| case DT_FP128: |
| return ConcreteType(llvm::Type::getFP128Ty(ctx)); |
| case DT_Unknown: |
| return BaseType::Unknown; |
| } |
| llvm_unreachable("Unknown concrete type to unwrap"); |
| } |
| |
| std::vector<int> eunwrap(IntList IL) { |
| std::vector<int> v; |
| for (size_t i = 0; i < IL.size; i++) { |
| v.push_back((int)IL.data[i]); |
| } |
| return v; |
| } |
| std::set<int64_t> eunwrap64(IntList IL) { |
| std::set<int64_t> v; |
| for (size_t i = 0; i < IL.size; i++) { |
| v.insert((int64_t)IL.data[i]); |
| } |
| return v; |
| } |
| TypeTree eunwrap(CTypeTreeRef CTT) { return *(TypeTree *)CTT; } |
| |
| CConcreteType ewrap(const ConcreteType &CT) { |
| if (auto flt = CT.isFloat()) { |
| if (flt->isHalfTy()) |
| return DT_Half; |
| if (flt->isFloatTy()) |
| return DT_Float; |
| if (flt->isDoubleTy()) |
| return DT_Double; |
| if (flt->isX86_FP80Ty()) |
| return DT_X86_FP80; |
| if (flt->isBFloatTy()) |
| return DT_BFloat16; |
| if (flt->isFP128Ty()) |
| return DT_FP128; |
| } else { |
| switch (CT.SubTypeEnum) { |
| case BaseType::Integer: |
| return DT_Integer; |
| case BaseType::Pointer: |
| return DT_Pointer; |
| case BaseType::Anything: |
| return DT_Anything; |
| case BaseType::Unknown: |
| return DT_Unknown; |
| case BaseType::Float: |
| llvm_unreachable("Illegal conversion of concretetype"); |
| } |
| } |
| llvm_unreachable("Illegal conversion of concretetype"); |
| } |
| |
| IntList ewrap(const std::vector<int> &offsets) { |
| IntList IL; |
| IL.size = offsets.size(); |
| IL.data = new int64_t[IL.size]; |
| for (size_t i = 0; i < offsets.size(); i++) { |
| IL.data[i] = offsets[i]; |
| } |
| return IL; |
| } |
| |
| CTypeTreeRef ewrap(const TypeTree &TT) { |
| return (CTypeTreeRef)(new TypeTree(TT)); |
| } |
| |
| FnTypeInfo eunwrap(CFnTypeInfo CTI, llvm::Function *F) { |
| FnTypeInfo FTI(F); |
| // auto &ctx = F->getContext(); |
| FTI.Return = eunwrap(CTI.Return); |
| |
| size_t argnum = 0; |
| for (auto &arg : F->args()) { |
| FTI.Arguments[&arg] = eunwrap(CTI.Arguments[argnum]); |
| FTI.KnownValues[&arg] = eunwrap64(CTI.KnownValues[argnum]); |
| argnum++; |
| } |
| return FTI; |
| } |
| |
| extern "C" { |
| |
| void EnzymeSetCLBool(void *ptr, uint8_t val) { |
| auto cl = (llvm::cl::opt<bool> *)ptr; |
| cl->setValue((bool)val); |
| } |
| |
| uint8_t EnzymeGetCLBool(void *ptr) { |
| auto cl = (llvm::cl::opt<bool> *)ptr; |
| return (uint8_t)(bool)cl->getValue(); |
| } |
| |
| void EnzymeSetCLInteger(void *ptr, int64_t val) { |
| auto cl = (llvm::cl::opt<int> *)ptr; |
| cl->setValue((int)val); |
| } |
| |
| int64_t EnzymeGetCLInteger(void *ptr) { |
| auto cl = (llvm::cl::opt<int> *)ptr; |
| return (int64_t)cl->getValue(); |
| } |
| |
| void EnzymeSetCLString(void *ptr, const char *val) { |
| if (auto *clopt = static_cast<cl::opt<std::string> *>(ptr)) |
| clopt->setValue(val); |
| } |
| |
| EnzymeLogicRef CreateEnzymeLogic(uint8_t PostOpt) { |
| return (EnzymeLogicRef)(new EnzymeLogic((bool)PostOpt)); |
| } |
| |
| void EnzymeLogicSetExternalContext(EnzymeLogicRef Ref, void *ExternalContext) { |
| eunwrap(Ref).ExternalContext = ExternalContext; |
| } |
| |
| void *EnzymeLogicGetExternalContext(EnzymeLogicRef Ref) { |
| return eunwrap(Ref).ExternalContext; |
| } |
| |
| EnzymeTraceInterfaceRef FindEnzymeStaticTraceInterface(LLVMModuleRef M) { |
| return (EnzymeTraceInterfaceRef)(new StaticTraceInterface(unwrap(M))); |
| } |
| |
| EnzymeTraceInterfaceRef CreateEnzymeStaticTraceInterface( |
| LLVMContextRef C, LLVMValueRef getTraceFunction, |
| LLVMValueRef getChoiceFunction, LLVMValueRef insertCallFunction, |
| LLVMValueRef insertChoiceFunction, LLVMValueRef insertArgumentFunction, |
| LLVMValueRef insertReturnFunction, LLVMValueRef insertFunctionFunction, |
| LLVMValueRef insertChoiceGradientFunction, |
| LLVMValueRef insertArgumentGradientFunction, LLVMValueRef newTraceFunction, |
| LLVMValueRef freeTraceFunction, LLVMValueRef hasCallFunction, |
| LLVMValueRef hasChoiceFunction) { |
| return (EnzymeTraceInterfaceRef)(new StaticTraceInterface( |
| *unwrap(C), cast<Function>(unwrap(getTraceFunction)), |
| cast<Function>(unwrap(getChoiceFunction)), |
| cast<Function>(unwrap(insertCallFunction)), |
| cast<Function>(unwrap(insertChoiceFunction)), |
| cast<Function>(unwrap(insertArgumentFunction)), |
| cast<Function>(unwrap(insertReturnFunction)), |
| cast<Function>(unwrap(insertFunctionFunction)), |
| cast<Function>(unwrap(insertChoiceGradientFunction)), |
| cast<Function>(unwrap(insertArgumentGradientFunction)), |
| cast<Function>(unwrap(newTraceFunction)), |
| cast<Function>(unwrap(freeTraceFunction)), |
| cast<Function>(unwrap(hasCallFunction)), |
| cast<Function>(unwrap(hasChoiceFunction)))); |
| }; |
| |
| EnzymeTraceInterfaceRef |
| CreateEnzymeDynamicTraceInterface(LLVMValueRef interface, LLVMValueRef F) { |
| return (EnzymeTraceInterfaceRef)(new DynamicTraceInterface( |
| unwrap(interface), cast<Function>(unwrap(F)))); |
| } |
| |
| void ClearEnzymeLogic(EnzymeLogicRef Ref) { eunwrap(Ref).clear(); } |
| |
| void EnzymeLogicErasePreprocessedFunctions(EnzymeLogicRef Ref) { |
| auto &Logic = eunwrap(Ref); |
| for (const auto &pair : Logic.PPC.cache) |
| pair.second->eraseFromParent(); |
| } |
| |
| void FreeEnzymeLogic(EnzymeLogicRef Ref) { delete (EnzymeLogic *)Ref; } |
| |
| void FreeTraceInterface(EnzymeTraceInterfaceRef Ref) { |
| delete (TraceInterface *)Ref; |
| } |
| |
| EnzymeTypeAnalysisRef CreateTypeAnalysis(EnzymeLogicRef Log, |
| char **customRuleNames, |
| CustomRuleType *customRules, |
| size_t numRules) { |
| TypeAnalysis *TA = new TypeAnalysis(((EnzymeLogic *)Log)->PPC.FAM); |
| for (size_t i = 0; i < numRules; i++) { |
| CustomRuleType rule = customRules[i]; |
| TA->CustomRules[customRuleNames[i]] = |
| [=](int direction, TypeTree &returnTree, ArrayRef<TypeTree> argTrees, |
| ArrayRef<std::set<int64_t>> knownValues, CallBase *call, |
| TypeAnalyzer *TA) -> uint8_t { |
| CTypeTreeRef creturnTree = (CTypeTreeRef)(&returnTree); |
| CTypeTreeRef *cargs = new CTypeTreeRef[argTrees.size()]; |
| IntList *kvs = new IntList[argTrees.size()]; |
| for (size_t i = 0; i < argTrees.size(); ++i) { |
| cargs[i] = (CTypeTreeRef)(&(argTrees[i])); |
| kvs[i].size = knownValues[i].size(); |
| kvs[i].data = new int64_t[kvs[i].size]; |
| size_t j = 0; |
| for (auto val : knownValues[i]) { |
| kvs[i].data[j] = val; |
| j++; |
| } |
| } |
| uint8_t result = rule(direction, creturnTree, cargs, kvs, argTrees.size(), |
| wrap(call), TA); |
| delete[] cargs; |
| for (size_t i = 0; i < argTrees.size(); ++i) { |
| delete[] kvs[i].data; |
| } |
| delete[] kvs; |
| return result; |
| }; |
| } |
| return (EnzymeTypeAnalysisRef)TA; |
| } |
| |
| void ClearTypeAnalysis(EnzymeTypeAnalysisRef TAR) { eunwrap(TAR).clear(); } |
| |
| void FreeTypeAnalysis(EnzymeTypeAnalysisRef TAR) { |
| TypeAnalysis *TA = (TypeAnalysis *)TAR; |
| delete TA; |
| } |
| |
| void *EnzymeAnalyzeTypes(EnzymeTypeAnalysisRef TAR, CFnTypeInfo CTI, |
| LLVMValueRef F) { |
| FnTypeInfo FTI(eunwrap(CTI, cast<Function>(unwrap(F)))); |
| return (void *)((TypeAnalysis *)TAR)->analyzeFunction(FTI).analyzer; |
| } |
| |
| void *EnzymeGradientUtilsTypeAnalyzer(GradientUtils *G) { |
| return (void *)&G->TR.analyzer; |
| } |
| |
| void EnzymeGradientUtilsErase(GradientUtils *G, LLVMValueRef I) { |
| return G->erase(cast<Instruction>(unwrap(I))); |
| } |
| void EnzymeGradientUtilsEraseWithPlaceholder(GradientUtils *G, LLVMValueRef I, |
| LLVMValueRef orig, uint8_t erase) { |
| return G->eraseWithPlaceholder(cast<Instruction>(unwrap(I)), |
| cast<Instruction>(unwrap(orig)), |
| "_replacementABI", erase != 0); |
| } |
| |
| void EnzymeGradientUtilsReplaceAWithB(GradientUtils *G, LLVMValueRef A, |
| LLVMValueRef B) { |
| return G->replaceAWithB(unwrap(A), unwrap(B)); |
| } |
| |
| void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle, |
| CustomShadowFree FHandle) { |
| shadowHandlers[Name] = [=](IRBuilder<> &B, CallInst *CI, |
| ArrayRef<Value *> Args, |
| GradientUtils *gutils) -> llvm::Value * { |
| SmallVector<LLVMValueRef, 3> refs; |
| for (auto a : Args) |
| refs.push_back(wrap(a)); |
| return unwrap( |
| AHandle(wrap(&B), wrap(CI), Args.size(), refs.data(), gutils)); |
| }; |
| if (FHandle) |
| shadowErasers[Name] = [=](IRBuilder<> &B, |
| Value *ToFree) -> llvm::CallInst * { |
| return cast_or_null<CallInst>(unwrap(FHandle(wrap(&B), wrap(ToFree)))); |
| }; |
| } |
| |
| void EnzymeRegisterCallHandler(const char *Name, |
| CustomAugmentedFunctionForward FwdHandle, |
| CustomFunctionReverse RevHandle) { |
| auto &pair = customCallHandlers[Name]; |
| pair.first = [=](IRBuilder<> &B, CallInst *CI, GradientUtils &gutils, |
| Value *&normalReturn, Value *&shadowReturn, |
| Value *&tape) -> bool { |
| LLVMValueRef normalR = wrap(normalReturn); |
| LLVMValueRef shadowR = wrap(shadowReturn); |
| LLVMValueRef tapeR = wrap(tape); |
| uint8_t noMod = |
| FwdHandle(wrap(&B), wrap(CI), &gutils, &normalR, &shadowR, &tapeR); |
| normalReturn = unwrap(normalR); |
| shadowReturn = unwrap(shadowR); |
| tape = unwrap(tapeR); |
| return noMod != 0; |
| }; |
| pair.second = [=](IRBuilder<> &B, CallInst *CI, DiffeGradientUtils &gutils, |
| Value *tape) { |
| RevHandle(wrap(&B), wrap(CI), &gutils, wrap(tape)); |
| }; |
| } |
| |
| void EnzymeRegisterFwdCallHandler(char *Name, CustomFunctionForward FwdHandle) { |
| auto &pair = customFwdCallHandlers[Name]; |
| pair = [=](IRBuilder<> &B, CallInst *CI, GradientUtils &gutils, |
| Value *&normalReturn, Value *&shadowReturn) -> bool { |
| LLVMValueRef normalR = wrap(normalReturn); |
| LLVMValueRef shadowR = wrap(shadowReturn); |
| uint8_t noMod = FwdHandle(wrap(&B), wrap(CI), &gutils, &normalR, &shadowR); |
| normalReturn = unwrap(normalR); |
| shadowReturn = unwrap(shadowR); |
| return noMod != 0; |
| }; |
| } |
| |
| void EnzymeRegisterDiffUseCallHandler(char *Name, |
| CustomFunctionDiffUse Handle) { |
| auto &pair = customDiffUseHandlers[Name]; |
| pair = [=](const CallInst *CI, const GradientUtils *gutils, const Value *arg, |
| bool isshadow, DerivativeMode mode, bool &useDefault) -> bool { |
| uint8_t useDefaultC = 0; |
| uint8_t noMod = Handle(wrap(CI), gutils, wrap(arg), isshadow, |
| (CDerivativeMode)(mode), &useDefaultC); |
| useDefault = useDefaultC != 0; |
| return noMod != 0; |
| }; |
| } |
| |
| uint8_t EnzymeGradientUtilsGetRuntimeActivity(GradientUtils *gutils) { |
| return gutils->runtimeActivity; |
| } |
| |
| void *EnzymeGradientUtilsGetExternalContext(GradientUtils *gutils) { |
| return gutils->Logic.ExternalContext; |
| } |
| |
| uint8_t EnzymeGradientUtilsGetStrongZero(GradientUtils *gutils) { |
| return gutils->strongZero; |
| } |
| |
| uint64_t EnzymeGradientUtilsGetWidth(GradientUtils *gutils) { |
| return gutils->getWidth(); |
| } |
| |
| LLVMTypeRef EnzymeGradientUtilsGetShadowType(GradientUtils *gutils, |
| LLVMTypeRef T) { |
| return wrap(gutils->getShadowType(unwrap(T))); |
| } |
| |
| LLVMTypeRef EnzymeGetShadowType(uint64_t width, LLVMTypeRef T) { |
| return wrap(GradientUtils::getShadowType(unwrap(T), width)); |
| } |
| |
| LLVMValueRef EnzymeGradientUtilsNewFromOriginal(GradientUtils *gutils, |
| LLVMValueRef val) { |
| return wrap(gutils->getNewFromOriginal(unwrap(val))); |
| } |
| |
| CDerivativeMode EnzymeGradientUtilsGetMode(GradientUtils *gutils) { |
| return (CDerivativeMode)gutils->mode; |
| } |
| |
| CDIFFE_TYPE |
| EnzymeGradientUtilsGetDiffeType(GradientUtils *G, LLVMValueRef oval, |
| uint8_t foreignFunction) { |
| return (CDIFFE_TYPE)(G->getDiffeType(unwrap(oval), foreignFunction != 0)); |
| } |
| |
| CDIFFE_TYPE |
| EnzymeGradientUtilsGetReturnDiffeType(GradientUtils *G, LLVMValueRef oval, |
| uint8_t *needsPrimal, |
| uint8_t *needsShadow, |
| CDerivativeMode mode) { |
| bool needsPrimalB; |
| bool needsShadowB; |
| auto res = (CDIFFE_TYPE)(G->getReturnDiffeType( |
| unwrap(oval), &needsPrimalB, &needsShadowB, (DerivativeMode)mode)); |
| if (needsPrimal) |
| *needsPrimal = needsPrimalB; |
| if (needsShadow) |
| *needsShadow = needsShadowB; |
| return res; |
| } |
| |
| void EnzymeGradientUtilsSetDebugLocFromOriginal(GradientUtils *gutils, |
| LLVMValueRef val, |
| LLVMValueRef orig) { |
| return cast<Instruction>(unwrap(val)) |
| ->setDebugLoc(gutils->getNewFromOriginal( |
| cast<Instruction>(unwrap(orig))->getDebugLoc())); |
| } |
| |
| LLVMValueRef EnzymeInsertValue(LLVMBuilderRef B, LLVMValueRef val, |
| LLVMValueRef val2, unsigned *sz, int64_t length, |
| const char *name) { |
| return wrap(unwrap(B)->CreateInsertValue( |
| unwrap(val), unwrap(val2), ArrayRef<unsigned>(sz, sz + length), name)); |
| } |
| |
| LLVMValueRef EnzymeGradientUtilsLookup(GradientUtils *gutils, LLVMValueRef val, |
| LLVMBuilderRef B) { |
| return wrap(gutils->lookupM(unwrap(val), *unwrap(B))); |
| } |
| |
| LLVMValueRef EnzymeGradientUtilsInvertPointer(GradientUtils *gutils, |
| LLVMValueRef val, |
| LLVMBuilderRef B) { |
| return wrap(gutils->invertPointerM(unwrap(val), *unwrap(B))); |
| } |
| |
| LLVMValueRef EnzymeGradientUtilsDiffe(DiffeGradientUtils *gutils, |
| LLVMValueRef val, LLVMBuilderRef B) { |
| return wrap(gutils->diffe(unwrap(val), *unwrap(B))); |
| } |
| |
| void EnzymeGradientUtilsAddToDiffe(DiffeGradientUtils *gutils, LLVMValueRef val, |
| LLVMValueRef diffe, LLVMBuilderRef B, |
| LLVMTypeRef T) { |
| gutils->addToDiffe(unwrap(val), unwrap(diffe), *unwrap(B), unwrap(T)); |
| } |
| |
| void EnzymeGradientUtilsAddToInvertedPointerDiffe( |
| DiffeGradientUtils *gutils, LLVMValueRef orig, LLVMValueRef origVal, |
| LLVMTypeRef addingType, unsigned start, unsigned size, LLVMValueRef origptr, |
| LLVMValueRef dif, LLVMBuilderRef BuilderM, unsigned align, |
| LLVMValueRef mask) { |
| MaybeAlign align2; |
| if (align) |
| align2 = MaybeAlign(align); |
| auto inst = cast_or_null<Instruction>(unwrap(orig)); |
| gutils->addToInvertedPtrDiffe(inst, unwrap(origVal), unwrap(addingType), |
| start, size, unwrap(origptr), unwrap(dif), |
| *unwrap(BuilderM), align2, unwrap(mask)); |
| } |
| |
| void EnzymeGradientUtilsAddToInvertedPointerDiffeTT( |
| DiffeGradientUtils *gutils, LLVMValueRef orig, LLVMValueRef origVal, |
| CTypeTreeRef vd, unsigned LoadSize, LLVMValueRef origptr, |
| LLVMValueRef prediff, LLVMBuilderRef BuilderM, unsigned align, |
| LLVMValueRef premask) { |
| MaybeAlign align2; |
| if (align) |
| align2 = MaybeAlign(align); |
| auto inst = cast_or_null<Instruction>(unwrap(orig)); |
| gutils->addToInvertedPtrDiffe(inst, unwrap(origVal), *(TypeTree *)vd, |
| LoadSize, unwrap(origptr), unwrap(prediff), |
| *unwrap(BuilderM), align2, unwrap(premask)); |
| } |
| |
| void EnzymeGradientUtilsSetDiffe(DiffeGradientUtils *gutils, LLVMValueRef val, |
| LLVMValueRef diffe, LLVMBuilderRef B) { |
| gutils->setDiffe(unwrap(val), unwrap(diffe), *unwrap(B)); |
| } |
| |
| uint8_t EnzymeGradientUtilsIsConstantValue(GradientUtils *gutils, |
| LLVMValueRef val) { |
| return gutils->isConstantValue(unwrap(val)); |
| } |
| |
| uint8_t EnzymeGradientUtilsIsConstantInstruction(GradientUtils *gutils, |
| LLVMValueRef val) { |
| return gutils->isConstantInstruction(cast<Instruction>(unwrap(val))); |
| } |
| |
| LLVMBasicBlockRef EnzymeGradientUtilsAllocationBlock(GradientUtils *gutils) { |
| return wrap(gutils->inversionAllocs); |
| } |
| |
| uint8_t EnzymeGradientUtilsGetUncacheableArgs(GradientUtils *gutils, |
| LLVMValueRef orig, uint8_t *data, |
| uint64_t size) { |
| if (gutils->mode == DerivativeMode::ForwardMode || |
| gutils->mode == DerivativeMode::ForwardModeError) |
| return 0; |
| |
| if (!gutils->overwritten_args_map_ptr) |
| return 0; |
| |
| CallInst *call = cast<CallInst>(unwrap(orig)); |
| |
| assert(gutils->overwritten_args_map_ptr); |
| auto found = gutils->overwritten_args_map_ptr->find(call); |
| if (found == gutils->overwritten_args_map_ptr->end()) { |
| llvm::errs() << " oldFunc " << *gutils->oldFunc << "\n"; |
| for (auto &pair : *gutils->overwritten_args_map_ptr) { |
| llvm::errs() << " + " << *pair.first << "\n"; |
| } |
| llvm::errs() << " could not find call orig in overwritten_args_map_ptr " |
| << *call << "\n"; |
| } |
| assert(found != gutils->overwritten_args_map_ptr->end()); |
| |
| const std::vector<bool> &overwritten_args = found->second.second; |
| |
| if (size != overwritten_args.size()) { |
| llvm::errs() << " orig: " << *call << "\n"; |
| llvm::errs() << " size: " << size |
| << " overwritten_args.size(): " << overwritten_args.size() |
| << "\n"; |
| } |
| assert(size == overwritten_args.size()); |
| for (uint64_t i = 0; i < size; i++) { |
| data[i] = overwritten_args[i]; |
| } |
| return 1; |
| } |
| |
| CTypeTreeRef EnzymeGradientUtilsAllocAndGetTypeTree(GradientUtils *gutils, |
| LLVMValueRef val) { |
| auto v = unwrap(val); |
| TypeTree TT = gutils->TR.query(v); |
| TypeTree *pTT = new TypeTree(TT); |
| return (CTypeTreeRef)pTT; |
| } |
| |
| void EnzymeGradientUtilsDumpTypeResults(GradientUtils *gutils) { |
| gutils->TR.dump(); |
| } |
| |
| void EnzymeGradientUtilsSubTransferHelper( |
| GradientUtils *gutils, CDerivativeMode mode, LLVMTypeRef secretty, |
| uint64_t intrinsic, uint64_t dstAlign, uint64_t srcAlign, uint64_t offset, |
| uint8_t dstConstant, LLVMValueRef shadow_dst, uint8_t srcConstant, |
| LLVMValueRef shadow_src, LLVMValueRef length, LLVMValueRef isVolatile, |
| LLVMValueRef MTI, uint8_t allowForward, uint8_t shadowsLookedUp) { |
| auto orig = unwrap(MTI); |
| assert(orig); |
| SubTransferHelper(gutils, (DerivativeMode)mode, unwrap(secretty), |
| (Intrinsic::ID)intrinsic, (unsigned)dstAlign, |
| (unsigned)srcAlign, (unsigned)offset, (bool)dstConstant, |
| unwrap(shadow_dst), (bool)srcConstant, unwrap(shadow_src), |
| unwrap(length), unwrap(isVolatile), cast<CallInst>(orig), |
| (bool)allowForward, (bool)shadowsLookedUp); |
| } |
| |
| LLVMBasicBlockRef EnzymeGradientUtilsAddReverseBlock(GradientUtils *gutils, |
| LLVMBasicBlockRef block, |
| const char *name, |
| uint8_t forkCache, |
| uint8_t push) { |
| return wrap(gutils->addReverseBlock(cast<BasicBlock>(unwrap(block)), name, |
| forkCache, push)); |
| } |
| |
| void EnzymeGradientUtilsSetReverseBlock(GradientUtils *gutils, |
| LLVMBasicBlockRef block) { |
| auto endBlock = cast<BasicBlock>(unwrap(block)); |
| auto found = gutils->reverseBlockToPrimal.find(endBlock); |
| assert(found != gutils->reverseBlockToPrimal.end()); |
| auto &vec = gutils->reverseBlocks[found->second]; |
| assert(vec.size()); |
| vec.push_back(endBlock); |
| } |
| |
| LLVMValueRef EnzymeCreateForwardDiff( |
| EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, |
| LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, |
| size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue, |
| CDerivativeMode mode, uint8_t freeMemory, uint8_t runtimeActivity, |
| uint8_t strongZero, unsigned width, LLVMTypeRef additionalArg, |
| CFnTypeInfo typeInfo, uint8_t subsequent_calls_may_write, |
| uint8_t *_overwritten_args, size_t overwritten_args_size, |
| EnzymeAugmentedReturnPtr augmented) { |
| SmallVector<DIFFE_TYPE, 4> nconstant_args((DIFFE_TYPE *)constant_args, |
| (DIFFE_TYPE *)constant_args + |
| constant_args_size); |
| std::vector<bool> overwritten_args; |
| assert(overwritten_args_size == cast<Function>(unwrap(todiff))->arg_size()); |
| for (uint64_t i = 0; i < overwritten_args_size; i++) { |
| overwritten_args.push_back(_overwritten_args[i]); |
| } |
| return wrap(eunwrap(Logic).CreateForwardDiff( |
| RequestContext(cast_or_null<Instruction>(unwrap(request_req)), |
| unwrap(request_ip)), |
| cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args, |
| eunwrap(TA), returnValue, (DerivativeMode)mode, freeMemory, |
| runtimeActivity, strongZero, width, unwrap(additionalArg), |
| eunwrap(typeInfo, cast<Function>(unwrap(todiff))), |
| subsequent_calls_may_write, overwritten_args, eunwrap(augmented))); |
| } |
| LLVMValueRef EnzymeCreatePrimalAndGradient( |
| EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, |
| LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, |
| size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue, |
| uint8_t dretUsed, CDerivativeMode mode, uint8_t runtimeActivity, |
| uint8_t strongZero, unsigned width, uint8_t freeMemory, |
| LLVMTypeRef additionalArg, uint8_t forceAnonymousTape, CFnTypeInfo typeInfo, |
| uint8_t subsequent_calls_may_write, uint8_t *_overwritten_args, |
| size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented, |
| uint8_t AtomicAdd) { |
| std::vector<DIFFE_TYPE> nconstant_args((DIFFE_TYPE *)constant_args, |
| (DIFFE_TYPE *)constant_args + |
| constant_args_size); |
| std::vector<bool> overwritten_args; |
| assert(overwritten_args_size == cast<Function>(unwrap(todiff))->arg_size()); |
| for (uint64_t i = 0; i < overwritten_args_size; i++) { |
| overwritten_args.push_back(_overwritten_args[i]); |
| } |
| return wrap(eunwrap(Logic).CreatePrimalAndGradient( |
| RequestContext(cast_or_null<Instruction>(unwrap(request_req)), |
| unwrap(request_ip)), |
| (ReverseCacheKey){ |
| .todiff = cast<Function>(unwrap(todiff)), |
| .retType = (DIFFE_TYPE)retType, |
| .constant_args = nconstant_args, |
| .subsequent_calls_may_write = (bool)subsequent_calls_may_write, |
| .overwritten_args = overwritten_args, |
| .returnUsed = (bool)returnValue, |
| .shadowReturnUsed = (bool)dretUsed, |
| .mode = (DerivativeMode)mode, |
| .width = width, |
| .freeMemory = (bool)freeMemory, |
| .AtomicAdd = (bool)AtomicAdd, |
| .additionalType = unwrap(additionalArg), |
| .forceAnonymousTape = (bool)forceAnonymousTape, |
| .typeInfo = eunwrap(typeInfo, cast<Function>(unwrap(todiff))), |
| .runtimeActivity = (bool)runtimeActivity, |
| .strongZero = (bool)strongZero}, |
| eunwrap(TA), eunwrap(augmented))); |
| } |
| EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal( |
| EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, |
| LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, |
| size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnUsed, |
| uint8_t shadowReturnUsed, CFnTypeInfo typeInfo, |
| uint8_t subsequent_calls_may_write, uint8_t *_overwritten_args, |
| size_t overwritten_args_size, uint8_t forceAnonymousTape, |
| uint8_t runtimeActivity, uint8_t strongZero, unsigned width, |
| uint8_t AtomicAdd) { |
| |
| SmallVector<DIFFE_TYPE, 4> nconstant_args((DIFFE_TYPE *)constant_args, |
| (DIFFE_TYPE *)constant_args + |
| constant_args_size); |
| std::vector<bool> overwritten_args; |
| assert(overwritten_args_size == cast<Function>(unwrap(todiff))->arg_size()); |
| for (uint64_t i = 0; i < overwritten_args_size; i++) { |
| overwritten_args.push_back(_overwritten_args[i]); |
| } |
| return ewrap(eunwrap(Logic).CreateAugmentedPrimal( |
| RequestContext(cast_or_null<Instruction>(unwrap(request_req)), |
| unwrap(request_ip)), |
| cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args, |
| eunwrap(TA), returnUsed, shadowReturnUsed, |
| eunwrap(typeInfo, cast<Function>(unwrap(todiff))), |
| subsequent_calls_may_write, overwritten_args, forceAnonymousTape, |
| runtimeActivity, strongZero, width, AtomicAdd)); |
| } |
| |
| LLVMValueRef EnzymeCreateBatch(EnzymeLogicRef Logic, LLVMValueRef request_req, |
| LLVMBuilderRef request_ip, LLVMValueRef tobatch, |
| unsigned width, CBATCH_TYPE *arg_types, |
| size_t arg_types_size, CBATCH_TYPE retType) { |
| |
| return wrap(eunwrap(Logic).CreateBatch( |
| RequestContext(cast_or_null<Instruction>(unwrap(request_req)), |
| unwrap(request_ip)), |
| cast<Function>(unwrap(tobatch)), width, |
| ArrayRef<BATCH_TYPE>((BATCH_TYPE *)arg_types, |
| (BATCH_TYPE *)arg_types + arg_types_size), |
| (BATCH_TYPE)retType)); |
| } |
| |
| LLVMValueRef EnzymeCreateTrace( |
| EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, |
| LLVMValueRef totrace, LLVMValueRef *sample_functions, |
| size_t sample_functions_size, LLVMValueRef *observe_functions, |
| size_t observe_functions_size, const char *active_random_variables[], |
| size_t active_random_variables_size, CProbProgMode mode, uint8_t autodiff, |
| EnzymeTraceInterfaceRef interface) { |
| |
| SmallPtrSet<Function *, 4> SampleFunctions; |
| for (size_t i = 0; i < sample_functions_size; i++) { |
| SampleFunctions.insert(cast<Function>(unwrap(sample_functions[i]))); |
| } |
| |
| SmallPtrSet<Function *, 4> ObserveFunctions; |
| for (size_t i = 0; i < observe_functions_size; i++) { |
| ObserveFunctions.insert(cast<Function>(unwrap(observe_functions[i]))); |
| } |
| |
| StringSet<> ActiveRandomVariables; |
| for (size_t i = 0; i < active_random_variables_size; i++) { |
| ActiveRandomVariables.insert(active_random_variables[i]); |
| } |
| |
| return wrap(eunwrap(Logic).CreateTrace( |
| RequestContext(cast_or_null<Instruction>(unwrap(request_req)), |
| unwrap(request_ip)), |
| cast<Function>(unwrap(totrace)), SampleFunctions, ObserveFunctions, |
| ActiveRandomVariables, (ProbProgMode)mode, (bool)autodiff, |
| eunwrap(interface))); |
| } |
| |
| LLVMValueRef |
| EnzymeExtractFunctionFromAugmentation(EnzymeAugmentedReturnPtr ret) { |
| auto AR = (AugmentedReturn *)ret; |
| return wrap(AR->fn); |
| } |
| |
| LLVMTypeRef |
| EnzymeExtractUnderlyingTapeTypeFromAugmentation(EnzymeAugmentedReturnPtr ret) { |
| auto AR = (AugmentedReturn *)ret; |
| return wrap(AR->tapeType); |
| } |
| |
| LLVMTypeRef |
| EnzymeExtractTapeTypeFromAugmentation(EnzymeAugmentedReturnPtr ret) { |
| auto AR = (AugmentedReturn *)ret; |
| auto found = AR->returns.find(AugmentedStruct::Tape); |
| if (found == AR->returns.end()) { |
| return wrap((Type *)nullptr); |
| } |
| if (found->second == -1) { |
| return wrap(AR->fn->getReturnType()); |
| } |
| return wrap( |
| cast<StructType>(AR->fn->getReturnType())->getTypeAtIndex(found->second)); |
| } |
| void EnzymeExtractReturnInfo(EnzymeAugmentedReturnPtr ret, int64_t *data, |
| uint8_t *existed, size_t len) { |
| assert(len == 3); |
| auto AR = (AugmentedReturn *)ret; |
| AugmentedStruct todo[] = {AugmentedStruct::Tape, AugmentedStruct::Return, |
| AugmentedStruct::DifferentialReturn}; |
| for (size_t i = 0; i < len; i++) { |
| auto found = AR->returns.find(todo[i]); |
| if (found != AR->returns.end()) { |
| existed[i] = true; |
| data[i] = (int64_t)found->second; |
| } else { |
| existed[i] = false; |
| } |
| } |
| } |
| |
| static MDNode *extractMDNode(MetadataAsValue *MAV) { |
| Metadata *MD = MAV->getMetadata(); |
| assert((isa<MDNode>(MD) || isa<ConstantAsMetadata>(MD)) && |
| "Expected a metadata node or a canonicalized constant"); |
| |
| if (MDNode *N = dyn_cast<MDNode>(MD)) |
| return N; |
| |
| return MDNode::get(MAV->getContext(), MD); |
| } |
| |
| CTypeTreeRef EnzymeTypeTreeFromMD(LLVMValueRef Val) { |
| TypeTree *Ret = new TypeTree(); |
| MDNode *N = Val ? extractMDNode(unwrap<MetadataAsValue>(Val)) : nullptr; |
| Ret->insertFromMD(N); |
| return (CTypeTreeRef)N; |
| } |
| |
| LLVMValueRef EnzymeTypeTreeToMD(CTypeTreeRef CTR, LLVMContextRef ctx) { |
| auto MD = ((TypeTree *)CTR)->toMD(*unwrap(ctx)); |
| return wrap(MetadataAsValue::get(MD->getContext(), MD)); |
| } |
| |
| CTypeTreeRef EnzymeNewTypeTree() { return (CTypeTreeRef)(new TypeTree()); } |
| CTypeTreeRef EnzymeNewTypeTreeCT(CConcreteType CT, LLVMContextRef ctx) { |
| return (CTypeTreeRef)(new TypeTree(eunwrap(CT, *unwrap(ctx)))); |
| } |
| CTypeTreeRef EnzymeNewTypeTreeTR(CTypeTreeRef CTR) { |
| return (CTypeTreeRef)(new TypeTree(*(TypeTree *)(CTR))); |
| } |
| void EnzymeFreeTypeTree(CTypeTreeRef CTT) { delete (TypeTree *)CTT; } |
| uint8_t EnzymeSetTypeTree(CTypeTreeRef dst, CTypeTreeRef src) { |
| return *(TypeTree *)dst = *(TypeTree *)src; |
| } |
| uint8_t EnzymeMergeTypeTree(CTypeTreeRef dst, CTypeTreeRef src) { |
| return ((TypeTree *)dst)->orIn(*(TypeTree *)src, /*PointerIntSame*/ false); |
| } |
| uint8_t EnzymeCheckedMergeTypeTree(CTypeTreeRef dst, CTypeTreeRef src, |
| uint8_t *legalP) { |
| bool legal = true; |
| bool res = |
| ((TypeTree *)dst) |
| ->checkedOrIn(*(TypeTree *)src, /*PointerIntSame*/ false, legal); |
| *legalP = legal; |
| return res; |
| } |
| |
| void EnzymeTypeTreeOnlyEq(CTypeTreeRef CTT, int64_t x) { |
| // TODO only inst |
| *(TypeTree *)CTT = ((TypeTree *)CTT)->Only(x, nullptr); |
| } |
| void EnzymeTypeTreeData0Eq(CTypeTreeRef CTT) { |
| *(TypeTree *)CTT = ((TypeTree *)CTT)->Data0(); |
| } |
| |
| void EnzymeTypeTreeLookupEq(CTypeTreeRef CTT, int64_t size, const char *dl) { |
| *(TypeTree *)CTT = ((TypeTree *)CTT)->Lookup(size, DataLayout(dl)); |
| } |
| void EnzymeTypeTreeCanonicalizeInPlace(CTypeTreeRef CTT, int64_t size, |
| const char *dl) { |
| ((TypeTree *)CTT)->CanonicalizeInPlace(size, DataLayout(dl)); |
| } |
| |
| CConcreteType EnzymeTypeTreeInner0(CTypeTreeRef CTT) { |
| return ewrap(((TypeTree *)CTT)->Inner0()); |
| } |
| |
| void EnzymeTypeTreeShiftIndiciesEq(CTypeTreeRef CTT, const char *datalayout, |
| int64_t offset, int64_t maxSize, |
| uint64_t addOffset) { |
| DataLayout DL(datalayout); |
| *(TypeTree *)CTT = |
| ((TypeTree *)CTT)->ShiftIndices(DL, offset, maxSize, addOffset); |
| } |
| void EnzymeTypeTreeInsertEq(CTypeTreeRef CTT, const int64_t *indices, |
| size_t len, CConcreteType ct, LLVMContextRef ctx) { |
| std::vector<int> seq; |
| for (size_t i = 0; i < len; i++) { |
| seq.push_back(indices[i]); |
| } |
| ((TypeTree *)CTT)->insert(seq, eunwrap(ct, *unwrap(ctx))); |
| } |
| const char *EnzymeTypeTreeToString(CTypeTreeRef src) { |
| std::string tmp = ((TypeTree *)src)->str(); |
| char *cstr = new char[tmp.length() + 1]; |
| std::strcpy(cstr, tmp.c_str()); |
| |
| return cstr; |
| } |
| |
// TODO deprecated
// Release a string buffer with delete[].
void EnzymeTypeTreeToStringFree(const char *cstr) {
  delete[] cstr;
}
| |
| const char *EnzymeTypeAnalyzerToString(void *src) { |
| auto TA = (TypeAnalyzer *)src; |
| std::string str; |
| raw_string_ostream ss(str); |
| TA->dump(ss); |
| ss.str(); |
| char *cstr = new char[str.length() + 1]; |
| std::strcpy(cstr, str.c_str()); |
| return cstr; |
| } |
| |
| const char *EnzymeGradientUtilsInvertedPointersToString(GradientUtils *gutils, |
| void *src) { |
| std::string str; |
| raw_string_ostream ss(str); |
| for (auto z : gutils->invertedPointers) { |
| ss << "available inversion for " << *z.first << " of " << *z.second << "\n"; |
| } |
| ss.str(); |
| char *cstr = new char[str.length() + 1]; |
| std::strcpy(cstr, str.c_str()); |
| return cstr; |
| } |
| |
| LLVMValueRef EnzymeGradientUtilsCallWithInvertedBundles( |
| GradientUtils *gutils, LLVMValueRef func, LLVMTypeRef funcTy, |
| LLVMValueRef *args_vr, uint64_t args_size, LLVMValueRef orig_vr, |
| CValueType *valTys, uint64_t valTys_size, LLVMBuilderRef B, |
| uint8_t lookup) { |
| auto orig = cast<CallInst>(unwrap(orig_vr)); |
| |
| ArrayRef<ValueType> ar((ValueType *)valTys, valTys_size); |
| |
| IRBuilder<> &BR = *unwrap(B); |
| |
| auto Defs = gutils->getInvertedBundles(orig, ar, BR, lookup != 0); |
| |
| SmallVector<Value *, 1> args; |
| for (size_t i = 0; i < args_size; i++) { |
| args.push_back(unwrap(args_vr[i])); |
| } |
| |
| auto callval = unwrap(func); |
| |
| auto res = |
| BR.CreateCall(cast<FunctionType>(unwrap(funcTy)), callval, args, Defs); |
| return wrap(res); |
| } |
| |
// Release a string buffer with delete[].
void EnzymeStringFree(const char *cstr) {
  delete[] cstr;
}
| |
| void EnzymeMoveBefore(LLVMValueRef inst1, LLVMValueRef inst2, |
| LLVMBuilderRef B) { |
| Instruction *I1 = cast<Instruction>(unwrap(inst1)); |
| Instruction *I2 = cast<Instruction>(unwrap(inst2)); |
| if (I1 != I2) { |
| if (B != nullptr) { |
| IRBuilder<> &BR = *unwrap(B); |
| if (I1->getIterator() == BR.GetInsertPoint()) { |
| if (I2->getNextNode() == nullptr) |
| BR.SetInsertPoint(I1->getParent()); |
| else |
| BR.SetInsertPoint(I1->getNextNode()); |
| } |
| } |
| I1->moveBefore(I2); |
| } |
| } |
| |
| void EnzymeSetStringMD(LLVMValueRef Inst, const char *Kind, LLVMValueRef Val) { |
| MDNode *N = Val ? extractMDNode(unwrap<MetadataAsValue>(Val)) : nullptr; |
| Value *V = unwrap(Inst); |
| if (auto I = dyn_cast<Instruction>(V)) |
| I->setMetadata(Kind, N); |
| else |
| cast<GlobalVariable>(V)->setMetadata(Kind, N); |
| } |
| |
| LLVMValueRef EnzymeGetStringMD(LLVMValueRef Inst, const char *Kind) { |
| auto *I = unwrap<Instruction>(Inst); |
| assert(I && "Expected instruction"); |
| if (auto *MD = I->getMetadata(Kind)) |
| return wrap(MetadataAsValue::get(I->getContext(), MD)); |
| return nullptr; |
| } |
| |
| void EnzymeSetMustCache(LLVMValueRef inst1) { |
| Instruction *I1 = cast<Instruction>(unwrap(inst1)); |
| I1->setMetadata("enzyme_mustcache", MDNode::get(I1->getContext(), {})); |
| } |
| |
| uint8_t EnzymeHasFromStack(LLVMValueRef inst1) { |
| Instruction *I1 = cast<Instruction>(unwrap(inst1)); |
| return hasMetadata(I1, "enzyme_fromstack") != 0; |
| } |
| |
| void EnzymeCloneFunctionDISubprogramInto(LLVMValueRef NF, LLVMValueRef F) { |
| auto &OldFunc = *cast<Function>(unwrap(F)); |
| auto &NewFunc = *cast<Function>(unwrap(NF)); |
| auto OldSP = OldFunc.getSubprogram(); |
| if (!OldSP) |
| return; |
| DIBuilder DIB(*OldFunc.getParent(), /*AllowUnresolved=*/false, |
| OldSP->getUnit()); |
| auto SPType = DIB.createSubroutineType(DIB.getOrCreateTypeArray({})); |
| DISubprogram::DISPFlags SPFlags = DISubprogram::SPFlagDefinition | |
| DISubprogram::SPFlagOptimized | |
| DISubprogram::SPFlagLocalToUnit; |
| auto NewSP = DIB.createFunction( |
| OldSP->getUnit(), NewFunc.getName(), NewFunc.getName(), OldSP->getFile(), |
| /*LineNo=*/0, SPType, /*ScopeLine=*/0, DINode::FlagZero, SPFlags); |
| NewFunc.setSubprogram(NewSP); |
| DIB.finalizeSubprogram(NewSP); |
| return; |
| } |
| |
| void EnzymeReplaceFunctionImplementation(LLVMModuleRef M) { |
| ReplaceFunctionImplementation(*unwrap(M)); |
| } |
| |
| void EnzymeDetectReadonlyOrThrow(LLVMModuleRef M) { |
| DetectReadonlyOrThrow(*unwrap(M)); |
| } |
| |
| void EnzymeDumpModuleRef(LLVMModuleRef M) { |
| llvm::errs() << *unwrap(M) << "\n"; |
| } |
| |
| static bool runAttributorOnFunctions(InformationCache &InfoCache, |
| SetVector<Function *> &Functions, |
| AnalysisGetter &AG, |
| CallGraphUpdater &CGUpdater, |
| bool DeleteFns, bool IsModulePass) { |
| if (Functions.empty()) |
| return false; |
| |
| // Create an Attributor and initially empty information cache that is filled |
| // while we identify default attribute opportunities. |
| AttributorConfig AC(CGUpdater); |
| AC.RewriteSignatures = false; |
| AC.IsModulePass = IsModulePass; |
| AC.DeleteFns = DeleteFns; |
| Attributor A(Functions, InfoCache, AC); |
| |
| for (Function *F : Functions) { |
| // Populate the Attributor with abstract attribute opportunities in the |
| // function and the information cache with IR information. |
| A.identifyDefaultAbstractAttributes(*F); |
| } |
| |
| ChangeStatus Changed = A.run(); |
| |
| return Changed == ChangeStatus::CHANGED; |
| } |
| |
| extern "C" void RunAttributorOnModule(LLVMModuleRef M0) { |
| auto &M = *unwrap(M0); |
| AnalysisGetter AG; |
| SetVector<Function *> Functions; |
| for (Function &F : M) |
| Functions.insert(&F); |
| |
| CallGraphUpdater CGUpdater; |
| BumpPtrAllocator Allocator; |
| InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ nullptr); |
| runAttributorOnFunctions(InfoCache, Functions, AG, CGUpdater, |
| /* DeleteFns*/ true, |
| /* IsModulePass */ true); |
| } |
| |
| struct MyAttributorLegacyPass : public ModulePass { |
| static char ID; |
| |
| MyAttributorLegacyPass() : ModulePass(ID) {} |
| |
| bool runOnModule(Module &M) override { |
| if (skipModule(M)) |
| return false; |
| |
| AnalysisGetter AG; |
| SetVector<Function *> Functions; |
| for (Function &F : M) |
| Functions.insert(&F); |
| |
| CallGraphUpdater CGUpdater; |
| BumpPtrAllocator Allocator; |
| InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ nullptr); |
| return runAttributorOnFunctions(InfoCache, Functions, AG, CGUpdater, |
| /* DeleteFns*/ true, |
| /* IsModulePass */ true); |
| } |
| |
| void getAnalysisUsage(AnalysisUsage &AU) const override { |
| // FIXME: Think about passes we will preserve and add them here. |
| AU.addRequired<TargetLibraryInfoWrapperPass>(); |
| } |
| }; |
| extern "C++" char MyAttributorLegacyPass::ID = 0; |
| void EnzymeAddAttributorLegacyPass(LLVMPassManagerRef PM) { |
| unwrap(PM)->add(new MyAttributorLegacyPass()); |
| } |
| |
| LLVMMetadataRef EnzymeMakeNonConstTBAA(LLVMMetadataRef MD) { |
| auto M = cast<MDNode>(unwrap(MD)); |
| if (M->getNumOperands() != 4) |
| return MD; |
| auto CAM = dyn_cast<ConstantAsMetadata>(M->getOperand(3)); |
| if (!CAM) |
| return MD; |
| if (!CAM->getValue()->isOneValue()) |
| return MD; |
| SmallVector<Metadata *, 4> MDs; |
| for (auto &M : M->operands()) |
| MDs.push_back(M); |
| MDs[3] = |
| ConstantAsMetadata::get(ConstantInt::get(CAM->getValue()->getType(), 0)); |
| return wrap(MDNode::get(M->getContext(), MDs)); |
| } |
| void EnzymeCopyMetadata(LLVMValueRef inst1, LLVMValueRef inst2) { |
| cast<Instruction>(unwrap(inst1)) |
| ->copyMetadata(*cast<Instruction>(unwrap(inst2))); |
| } |
| LLVMMetadataRef EnzymeAnonymousAliasScopeDomain(const char *str, |
| LLVMContextRef ctx) { |
| MDBuilder MDB(*unwrap(ctx)); |
| MDNode *scope = MDB.createAnonymousAliasScopeDomain(str); |
| return wrap(scope); |
| } |
| LLVMMetadataRef EnzymeAnonymousAliasScope(LLVMMetadataRef domain, |
| const char *str) { |
| auto dom = cast<MDNode>(unwrap(domain)); |
| MDBuilder MDB(dom->getContext()); |
| MDNode *scope = MDB.createAnonymousAliasScope(dom, str); |
| return wrap(scope); |
| } |
| uint8_t EnzymeLowerSparsification(LLVMValueRef F, uint8_t replaceAll) { |
| return LowerSparsification(cast<Function>(unwrap(F)), replaceAll != 0); |
| } |
| |
| void EnzymeAttributeKnownFunctions(LLVMValueRef FC) { |
| attributeKnownFunctions(*cast<Function>(unwrap(FC))); |
| } |
| |
| void EnzymeSetCalledFunction(LLVMValueRef C_CI, LLVMValueRef C_F, |
| uint64_t *argrem, uint64_t num_argrem) { |
| auto CI = cast<CallInst>(unwrap(C_CI)); |
| auto F = cast<Function>(unwrap(C_F)); |
| auto Attrs = CI->getAttributes(); |
| AttributeList NewAttrs; |
| |
| if (CI->getType() == F->getReturnType()) { |
| for (auto attr : Attrs.getAttributes(AttributeList::ReturnIndex)) |
| NewAttrs = NewAttrs.addAttribute(F->getContext(), |
| AttributeList::ReturnIndex, attr); |
| } |
| for (auto attr : Attrs.getAttributes(AttributeList::FunctionIndex)) |
| NewAttrs = NewAttrs.addAttribute(F->getContext(), |
| AttributeList::FunctionIndex, attr); |
| |
| size_t argremsz = 0; |
| size_t nexti = 0; |
| SmallVector<Value *, 1> vals; |
| for (size_t i = 0, end = CI->arg_size(); i < end; i++) { |
| if (argremsz < num_argrem) { |
| if (i == argrem[argremsz]) { |
| argremsz++; |
| continue; |
| } |
| } |
| for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i)) |
| NewAttrs = NewAttrs.addAttribute( |
| F->getContext(), AttributeList::FirstArgIndex + nexti, attr); |
| vals.push_back(CI->getArgOperand(i)); |
| nexti++; |
| } |
| assert(argremsz == num_argrem); |
| |
| IRBuilder<> B(CI); |
| SmallVector<OperandBundleDef, 1> Bundles; |
| for (unsigned I = 0, E = CI->getNumOperandBundles(); I != E; ++I) |
| Bundles.emplace_back(CI->getOperandBundleAt(I)); |
| auto NC = B.CreateCall(F, vals, Bundles); |
| NC->setAttributes(NewAttrs); |
| NC->copyMetadata(*CI); |
| |
| if (CI->getType() == F->getReturnType()) |
| CI->replaceAllUsesWith(NC); |
| |
| if (!NC->getType()->isVoidTy()) |
| NC->takeName(CI); |
| NC->setCallingConv(CI->getCallingConv()); |
| CI->eraseFromParent(); |
| } |
| |
// Creates a copy of the function FC in the same module, dropping selected
// parameters and optionally the return value, and returns the new function.
//
// FC          - the function to clone (must wrap a Function).
// keepReturnU - zero to give the clone a void return type, non-zero to keep
//               the original return type.
// argrem      - num_argrem parameter indices to remove. The array is consumed
//               front to back while parameters are visited in order, so
//               entries are only matched when they are in ascending order.
// The clone takes over FC's name, calling convention and all metadata except
// !dbg. FC itself is not erased by this function.
| // clones a function to now miss the return or args |
| LLVMValueRef EnzymeCloneFunctionWithoutReturnOrArgs(LLVMValueRef FC, |
| uint8_t keepReturnU, |
| uint64_t *argrem, |
| uint64_t num_argrem) { |
| auto F = cast<Function>(unwrap(FC)); |
| auto FT = F->getFunctionType(); |
| bool keepReturn = keepReturnU != 0; |
| |
// First pass over the parameters: collect the surviving parameter types and
// re-index their attributes to the new positions (nexti).
| size_t argremsz = 0; |
| size_t nexti = 0; |
| SmallVector<Type *, 1> types; |
| auto Attrs = F->getAttributes(); |
| AttributeList NewAttrs; |
| for (size_t i = 0, end = FT->getNumParams(); i < end; i++) { |
| if (argremsz < num_argrem) { |
| if (i == argrem[argremsz]) { |
| argremsz++; |
| continue; |
| } |
| } |
| for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i)) |
| NewAttrs = NewAttrs.addAttribute( |
| F->getContext(), AttributeList::FirstArgIndex + nexti, attr); |
| types.push_back(F->getFunctionType()->getParamType(i)); |
| nexti++; |
| } |
// Return attributes are only carried over when the return value is kept.
| if (keepReturn) { |
| for (auto attr : Attrs.getAttributes(AttributeList::ReturnIndex)) |
| NewAttrs = NewAttrs.addAttribute(F->getContext(), |
| AttributeList::ReturnIndex, attr); |
| } |
| for (auto attr : Attrs.getAttributes(AttributeList::FunctionIndex)) |
| NewAttrs = NewAttrs.addAttribute(F->getContext(), |
| AttributeList::FunctionIndex, attr); |
| |
| FunctionType *FTy = FunctionType::get( |
| keepReturn ? F->getReturnType() : Type::getVoidTy(F->getContext()), types, |
| FT->isVarArg()); |
| |
| // Create the new function |
| Function *NewF = Function::Create(FTy, F->getLinkage(), F->getAddressSpace(), |
| F->getName(), F->getParent()); |
| |
| ValueToValueMapTy VMap; |
| // Loop over the arguments, copying the names of the mapped arguments over... |
// Removed arguments are mapped to undef so that their uses in the cloned body
// refer to an undef value of the same type.
| nexti = 0; |
| argremsz = 0; |
| Function::arg_iterator DestI = NewF->arg_begin(); |
| for (const Argument &I : F->args()) { |
| if (argremsz < num_argrem) { |
| if (I.getArgNo() == argrem[argremsz]) { |
| VMap[&I] = UndefValue::get(I.getType()); |
| argremsz++; |
| continue; |
| } |
| } |
| DestI->setName(I.getName()); // Copy the name over... |
| VMap[&I] = &*DestI++; // Add mapping to VMap |
| } |
| |
| SmallVector<ReturnInst *, 8> Returns; // Ignore returns cloned. |
| CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, |
| Returns, "", nullptr); |
| |
// When the return value is dropped, every return instruction in the clone is
// replaced by `ret void`, keeping the metadata of the old return.
| if (!keepReturn) { |
| for (auto &B : *NewF) { |
| if (auto RI = dyn_cast<ReturnInst>(B.getTerminator())) { |
| IRBuilder<> B(RI); |
| auto NRI = B.CreateRetVoid(); |
| NRI->copyMetadata(*RI); |
| RI->eraseFromParent(); |
| } |
| } |
| } |
// Install the re-indexed attribute list built above.
| NewF->setAttributes(NewAttrs); |
// A `returned` parameter attribute is removed from every argument once the
// function no longer returns a value.
| if (!keepReturn) |
| for (auto &Arg : NewF->args()) |
| Arg.removeAttr(Attribute::Returned); |
// Copy all function-level metadata except the debug-info attachment.
| SmallVector<std::pair<unsigned, MDNode *>, 1> MD; |
| F->getAllMetadata(MD); |
| for (auto pair : MD) |
| if (pair.first != LLVMContext::MD_dbg) |
| NewF->addMetadata(pair.first, *pair.second); |
| NewF->takeName(F); |
| NewF->setCallingConv(F->getCallingConv()); |
// Record on the clone that its return value was removed.
| if (!keepReturn) |
| NewF->addFnAttr("enzyme_retremove", ""); |
| |
// Record which parameters were removed in the "enzyme_parmremove" attribute as
// a comma-separated list, merged with any list already present on F.
| if (num_argrem) { |
| SmallVector<uint64_t, 1> previdx; |
| if (Attrs.hasAttribute(AttributeList::FunctionIndex, "enzyme_parmremove")) { |
| auto attr = |
| Attrs.getAttribute(AttributeList::FunctionIndex, "enzyme_parmremove"); |
| auto prevstr = attr.getValueAsString(); |
| SmallVector<StringRef, 1> sub; |
| prevstr.split(sub, ","); |
| for (auto s : sub) { |
| uint64_t ival; |
// getAsInteger returns true on failure, hence the negated assert.
| bool b = s.getAsInteger(10, ival); |
| (void)b; |
| assert(!b); |
| previdx.push_back(ival); |
| } |
| } |
| SmallVector<uint64_t, 1> nextidx; |
| for (size_t i = 0; i < num_argrem; i++) { |
| auto val = argrem[i]; |
| nextidx.push_back(val); |
| } |
| |
// Merge the two lists. Each new index is shifted by the number of previous
// entries already emitted (prevcnt) before it is compared and stored.
// NOTE(review): presumably this translates the new indices back into the
// parameter numbering the previous entries use; confirm with the consumers of
// "enzyme_parmremove".
| size_t prevcnt = 0; |
| size_t nextcnt = 0; |
| SmallVector<uint64_t, 1> out; |
| while (prevcnt < previdx.size() && nextcnt < nextidx.size()) { |
| if (previdx[prevcnt] <= nextidx[nextcnt] + prevcnt) { |
| out.push_back(previdx[prevcnt]); |
| prevcnt++; |
| } else { |
| out.push_back(nextidx[nextcnt] + prevcnt); |
| nextcnt++; |
| } |
| } |
| while (prevcnt < previdx.size()) { |
| out.push_back(previdx[prevcnt]); |
| prevcnt++; |
| } |
| while (nextcnt < nextidx.size()) { |
| out.push_back(nextidx[nextcnt] + prevcnt); |
| nextcnt++; |
| } |
| |
| std::string remstr; |
| for (auto arg : out) { |
| if (remstr.size()) |
| remstr += ","; |
| remstr += std::to_string(arg); |
| } |
| |
| NewF->addFnAttr("enzyme_parmremove", remstr); |
| } |
| return wrap(NewF); |
| } |
| LLVMTypeRef EnzymeAllocaType(LLVMValueRef V) { |
| return wrap(cast<AllocaInst>(unwrap(V))->getAllocatedType()); |
| } |
| LLVMValueRef EnzymeComputeByteOffsetOfGEP(LLVMBuilderRef B_r, LLVMValueRef V_r, |
| LLVMTypeRef T_r) { |
| IRBuilder<> &B = *unwrap(B_r); |
| auto T = cast<IntegerType>(unwrap(T_r)); |
| auto width = T->getBitWidth(); |
| auto uw = unwrap(V_r); |
| GEPOperator *gep = isa<GetElementPtrInst>(uw) |
| ? cast<GEPOperator>(cast<GetElementPtrInst>(uw)) |
| : cast<GEPOperator>(cast<ConstantExpr>(uw)); |
| auto &DL = B.GetInsertBlock()->getParent()->getParent()->getDataLayout(); |
| |
| #if LLVM_VERSION_MAJOR >= 20 |
| SmallMapVector<Value *, APInt, 4> VariableOffsets; |
| #else |
| MapVector<Value *, APInt> VariableOffsets; |
| #endif |
| APInt Offset(width, 0); |
| bool success = collectOffset(gep, DL, width, VariableOffsets, Offset); |
| (void)success; |
| assert(success); |
| Value *start = ConstantInt::get(T, Offset); |
| for (auto &pair : VariableOffsets) |
| start = B.CreateAdd( |
| start, B.CreateMul(pair.first, ConstantInt::get(T, pair.second))); |
| return wrap(start); |
| } |
| } |
| |
| static size_t num_rooting(llvm::Type *T, llvm::Function *F) { |
| CountTrackedPointers tracked(T); |
| if (tracked.derived) { |
| llvm::errs() << *F << "\n"; |
| llvm::errs() << "Invalid Derived Type: " << *T << "\n"; |
| } |
| assert(!tracked.derived); |
| if (tracked.count != 0 && !tracked.all) |
| return tracked.count; |
| return 0; |
| } |
| |
| extern "C" { |
| |
// Rewrites the function F_C so that selected array-of-pointer parameters are
// passed as one pointer parameter per array element.
//
// A parameter is split when its type is an array of pointers and either the
// pointer address space is 11, 12 or 13, or the parameter carries the
// "enzyme_sret_v" string attribute. Nothing happens when the function has no
// body or when no parameter qualifies. Otherwise a new function with the
// flattened signature is created, the body is cloned into it, every caller is
// rewritten to pass the extracted elements, and the original function is
// erased after the new one has taken its name.
| void EnzymeFixupBatchedJuliaCallingConvention(LLVMValueRef F_C) { |
| auto F = cast<Function>(unwrap(F_C)); |
| if (F->empty()) |
| return; |
| auto RT = F->getReturnType(); |
| auto FT = F->getFunctionType(); |
| auto Attrs = F->getAttributes(); |
| |
// Build the new parameter type list and attribute list. `changed` records the
// indices of the original parameters that get split.
| AttributeList NewAttrs; |
| SmallVector<Type *, 1> types; |
| SmallSet<size_t, 1> changed; |
| for (auto pair : llvm::enumerate(FT->params())) { |
| auto T = pair.value(); |
| auto i = pair.index(); |
| bool sretv = false; |
// "enzyme_sret_v" is consumed here; every other attribute is re-attached at
// the position the parameter (or its first element) will occupy.
| for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i)) { |
| if (attr.isStringAttribute() && |
| attr.getKindAsString() == "enzyme_sret_v") { |
| sretv = true; |
| } else { |
| NewAttrs = NewAttrs.addAttribute( |
| F->getContext(), AttributeList::FirstArgIndex + types.size(), attr); |
| } |
| } |
| if (auto AT = dyn_cast<ArrayType>(T)) { |
| if (auto PT = dyn_cast<PointerType>(AT->getElementType())) { |
| auto AS = PT->getAddressSpace(); |
| if (AS == 11 || AS == 12 || AS == 13 || sretv) { |
// One pointer parameter per array element; each gets "enzyme_sret" when the
// array carried "enzyme_sret_v".
| for (unsigned i = 0; i < AT->getNumElements(); i++) { |
| if (sretv) { |
| NewAttrs = NewAttrs.addAttribute( |
| F->getContext(), AttributeList::FirstArgIndex + types.size(), |
| Attribute::get(F->getContext(), "enzyme_sret")); |
| } |
| types.push_back(PT); |
| } |
| changed.insert(i); |
| continue; |
| } |
| } |
| } |
| types.push_back(T); |
| } |
| if (changed.size() == 0) |
| return; |
| |
| for (auto attr : Attrs.getAttributes(AttributeList::FunctionIndex)) |
| NewAttrs = NewAttrs.addAttribute(F->getContext(), |
| AttributeList::FunctionIndex, attr); |
| |
| for (auto attr : Attrs.getAttributes(AttributeList::ReturnIndex)) |
| NewAttrs = NewAttrs.addAttribute(F->getContext(), |
| AttributeList::ReturnIndex, attr); |
| |
| FunctionType *FTy = |
| FunctionType::get(FT->getReturnType(), types, FT->isVarArg()); |
| |
| // Create the new function |
| Function *NewF = Function::Create(FTy, F->getLinkage(), F->getAddressSpace(), |
| F->getName(), F->getParent()); |
| |
| ValueToValueMapTy VMap; |
| // Loop over the arguments, copying the names of the mapped arguments over... |
| Function::arg_iterator DestI = NewF->arg_begin(); |
| |
| // To handle the deleted args, it needs to be replaced by a non-arg operand. |
| // This map contains the temporary phi nodes corresponding |
// For a split parameter the original array value is rebuilt from the new
// per-element arguments with a chain of insertvalue instructions. These are
// created unattached and inserted into the entry block after cloning.
| SmallVector<Instruction *, 1> toInsert; |
| for (Argument &I : F->args()) { |
| auto T = I.getType(); |
| if (auto AT = dyn_cast<ArrayType>(T)) { |
| if (changed.count(I.getArgNo())) { |
| Value *V = UndefValue::get(T); |
| for (unsigned i = 0; i < AT->getNumElements(); i++) { |
| DestI->setName(I.getName() + "." + |
| std::to_string(i)); // Copy the name over... |
| unsigned idx[1] = {i}; |
| auto IV = InsertValueInst::Create(V, (llvm::Value *)&*DestI++, idx); |
| toInsert.push_back(IV); |
| V = IV; |
| } |
| VMap[&I] = V; |
| continue; |
| } |
| } |
| DestI->setName(I.getName()); // Copy the name over... |
| VMap[&I] = &*DestI++; // Add mapping to VMap |
| } |
| |
| SmallVector<ReturnInst *, 8> Returns; // Ignore returns cloned. |
| CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, |
| Returns, "", nullptr); |
| |
// Place the pending insertvalue chain at the start of the cloned entry block.
| { |
| IRBuilder<> EB(&*NewF->getEntryBlock().begin()); |
| for (auto I : toInsert) |
| EB.Insert(I); |
| } |
| |
// Collect the callers first, since the rewrite below erases them. Every user
// of F is asserted to be a direct call to F.
| SmallVector<CallInst *, 1> callers; |
| for (auto U : F->users()) { |
| auto CI = dyn_cast<CallInst>(U); |
| assert(CI); |
| assert(CI->getCalledFunction() == F); |
| callers.push_back(CI); |
| } |
| |
| for (auto CI : callers) { |
// These locals shadow the function-level Attrs / NewAttrs; they describe the
// call site only.
| auto Attrs = CI->getAttributes(); |
| AttributeList NewAttrs; |
| IRBuilder<> B(CI); |
| |
| for (auto attr : Attrs.getAttributes(AttributeList::FunctionIndex)) |
| NewAttrs = NewAttrs.addAttribute(F->getContext(), |
| AttributeList::FunctionIndex, attr); |
| |
| for (auto attr : Attrs.getAttributes(AttributeList::ReturnIndex)) |
| NewAttrs = NewAttrs.addAttribute(F->getContext(), |
| AttributeList::ReturnIndex, attr); |
| |
| SmallVector<Value *, 1> vals; |
| for (size_t j = 0, end = CI->arg_size(); j < end; j++) { |
| |
| auto T = CI->getArgOperand(j)->getType(); |
| if (auto AT = dyn_cast<ArrayType>(T)) { |
| if (isa<PointerType>(AT->getElementType())) { |
| if (changed.count(j)) { |
// Split argument: pass each array element separately. Only the sret marker is
// carried over for these arguments on this path.
| bool sretv = false; |
| for (auto attr : |
| Attrs.getAttributes(AttributeList::FirstArgIndex + j)) { |
| if (attr.isStringAttribute() && |
| attr.getKindAsString() == "enzyme_sret_v") { |
| sretv = true; |
| } |
| } |
| for (unsigned i = 0; i < AT->getNumElements(); i++) { |
| if (sretv) |
| NewAttrs = NewAttrs.addAttribute( |
| F->getContext(), AttributeList::FirstArgIndex + vals.size(), |
| Attribute::get(F->getContext(), "enzyme_sret")); |
| vals.push_back( |
| GradientUtils::extractMeta(B, CI->getArgOperand(j), i)); |
| } |
| continue; |
| } |
| } |
| } |
| |
// Unsplit argument: copy its attributes to the new position, renaming
// "enzyme_sret_v" to "enzyme_sret".
| for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + j)) { |
| if (attr.isStringAttribute() && |
| attr.getKindAsString() == "enzyme_sret_v") { |
| NewAttrs = NewAttrs.addAttribute( |
| F->getContext(), AttributeList::FirstArgIndex + vals.size(), |
| Attribute::get(F->getContext(), "enzyme_sret")); |
| } else { |
| NewAttrs = NewAttrs.addAttribute( |
| F->getContext(), AttributeList::FirstArgIndex + vals.size(), |
| attr); |
| } |
| } |
| |
| vals.push_back(CI->getArgOperand(j)); |
| } |
| |
| SmallVector<OperandBundleDef, 1> Bundles; |
| for (unsigned I = 0, E = CI->getNumOperandBundles(); I != E; ++I) |
| Bundles.emplace_back(CI->getOperandBundleAt(I)); |
| auto NC = B.CreateCall(NewF, vals, Bundles); |
| NC->setAttributes(NewAttrs); |
| |
// Copy every metadata kind except the debug location, which is set separately
// just below.
| SmallVector<std::pair<unsigned, MDNode *>, 4> TheMDs; |
| CI->getAllMetadataOtherThanDebugLoc(TheMDs); |
| SmallVector<unsigned, 1> toCopy; |
| for (auto pair : TheMDs) |
| toCopy.push_back(pair.first); |
| if (!toCopy.empty()) |
| NC->copyMetadata(*CI, toCopy); |
| NC->setDebugLoc(CI->getDebugLoc()); |
| |
| if (!RT->isVoidTy()) { |
| NC->takeName(CI); |
| CI->replaceAllUsesWith(NC); |
| } |
| |
| NC->setCallingConv(CI->getCallingConv()); |
| CI->eraseFromParent(); |
| } |
// Finish the new function: attributes, metadata (except !dbg), name and
// calling convention come from F, which is then removed from the module.
| NewF->setAttributes(NewAttrs); |
| SmallVector<std::pair<unsigned, MDNode *>, 1> MD; |
| F->getAllMetadata(MD); |
| for (auto pair : MD) |
| if (pair.first != LLVMContext::MD_dbg) |
| NewF->addMetadata(pair.first, *pair.second); |
| NewF->takeName(F); |
| NewF->setCallingConv(F->getCallingConv()); |
| F->eraseFromParent(); |
| } |
| |
| void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) { |
| auto F = cast<Function>(unwrap(F_C)); |
| if (F->empty()) |
| return; |
| auto RT = F->getReturnType(); |
| std::set<size_t> srets; |
| std::set<size_t> enzyme_srets; |
| std::set<size_t> enzyme_srets_v; |
| std::set<size_t> rroots; |
| std::set<size_t> rroots_v; |
| |
| auto FT = F->getFunctionType(); |
| auto Attrs = F->getAttributes(); |
| for (size_t i = 0, end = FT->getNumParams(); i < end; i++) { |
| if (Attrs.hasAttribute(AttributeList::FirstArgIndex + i, |
| Attribute::StructRet)) |
| srets.insert(i); |
| if (Attrs.hasAttribute(AttributeList::FirstArgIndex + i, "enzyme_sret")) |
| enzyme_srets.insert(i); |
| if (Attrs.hasAttribute(AttributeList::FirstArgIndex + i, "enzyme_sret_v")) |
| enzyme_srets_v.insert(i); |
| if (Attrs.hasAttribute(AttributeList::FirstArgIndex + i, |
| "enzymejl_returnRoots")) |
| rroots.insert(i); |
| if (Attrs.hasAttribute(AttributeList::FirstArgIndex + i, |
| "enzymejl_returnRoots_v")) |
| rroots_v.insert(i); |
| } |
| // Regular julia function, needing no intervention |
| if (srets.size() == 1) { |
| assert(*srets.begin() == 0); |
| assert(enzyme_srets.size() == 0); |
| assert(enzyme_srets_v.size() == 0); |
| assert(rroots_v.size() == 0); |
| if (rroots.size()) { |
| assert(rroots.size() == 1); |
| assert(*rroots.begin() == 1); |
| } |
| return; |
| } |
| // No sret/rooting, no intervention needed. |
| if (srets.size() == 0 && enzyme_srets.size() == 0 && |
| enzyme_srets_v.size() == 0 && rroots.size() == 0 && |
| rroots_v.size() == 0) { |
| return; |
| } |
| |
| assert(srets.size() == 0); |
| |
| SmallVector<Type *, 1> Types; |
| if (!RT->isVoidTy()) { |
| Types.push_back(RT); |
| } |
| |
| for (auto idx : enzyme_srets) { |
| llvm::Type *T = nullptr; |
| #if LLVM_VERSION_MAJOR >= 17 |
| (void)idx; |
| llvm_unreachable("Unhandled"); |
| // T = F->getParamAttribute(idx, Attribute::AttrKind::ElementType) |
| // .getValueAsType(); |
| #else |
| T = FT->getParamType(idx)->getPointerElementType(); |
| #endif |
| Types.push_back(T); |
| } |
| for (auto idx : enzyme_srets_v) { |
| llvm::Type *T = nullptr; |
| auto AT = cast<ArrayType>(FT->getParamType(idx)); |
| #if LLVM_VERSION_MAJOR >= 17 |
| llvm_unreachable("Unhandled"); |
| // T = F->getParamAttribute(idx, Attribute::AttrKind::ElementType) |
| // .getValueAsType(); |
| #else |
| T = AT->getElementType()->getPointerElementType(); |
| #endif |
| for (size_t i = 0; i < AT->getNumElements(); i++) |
| Types.push_back(T); |
| } |
| |
| StructType *ST = |
| Types.size() <= 1 ? nullptr : StructType::get(F->getContext(), Types); |
| Type *sretTy = nullptr; |
| if (Types.size()) |
| sretTy = Types.size() == 1 ? Types[0] : ST; |
| size_t numRooting = sretTy ? num_rooting(sretTy, F) : 0; |
| |
| auto T_jlvalue = StructType::get(F->getContext(), {}); |
| auto T_prjlvalue = PointerType::get(T_jlvalue, AddressSpace::Tracked); |
| ArrayType *roots_AT = |
| numRooting ? ArrayType::get(T_prjlvalue, numRooting) : nullptr; |
| |
| AttributeList NewAttrs; |
| SmallVector<Type *, 1> types; |
| size_t nexti = 0; |
| if (sretTy) { |
| types.push_back(PointerType::getUnqual(sretTy)); |
| NewAttrs = NewAttrs.addAttribute( |
| F->getContext(), AttributeList::FirstArgIndex + nexti, |
| Attribute::get(F->getContext(), Attribute::StructRet, sretTy)); |
| NewAttrs = NewAttrs.addAttribute(F->getContext(), |
| AttributeList::FirstArgIndex + nexti, |
| Attribute::NoAlias); |
| nexti++; |
| } |
| if (roots_AT) { |
| NewAttrs = NewAttrs.addAttribute(F->getContext(), |
| AttributeList::FirstArgIndex + nexti, |
| "enzymejl_returnRoots"); |
| NewAttrs = NewAttrs.addAttribute(F->getContext(), |
| AttributeList::FirstArgIndex + nexti, |
| Attribute::NoAlias); |
| NewAttrs = NewAttrs.addAttribute(F->getContext(), |
| AttributeList::FirstArgIndex + nexti, |
| Attribute::WriteOnly); |
| types.push_back(PointerType::getUnqual(roots_AT)); |
| nexti++; |
| } |
| for (size_t i = 0, end = FT->getNumParams(); i < end; i++) { |
| if (enzyme_srets.count(i) || enzyme_srets_v.count(i) || rroots.count(i) || |
| rroots_v.count(i)) |
| continue; |
| |
| for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i)) |
| NewAttrs = NewAttrs.addAttribute( |
| F->getContext(), AttributeList::FirstArgIndex + nexti, attr); |
| types.push_back(F->getFunctionType()->getParamType(i)); |
| nexti++; |
| } |
| for (auto attr : Attrs.getAttributes(AttributeList::FunctionIndex)) |
| NewAttrs = NewAttrs.addAttribute(F->getContext(), |
| AttributeList::FunctionIndex, attr); |
| |
| FunctionType *FTy = FunctionType::get(Type::getVoidTy(F->getContext()), types, |
| FT->isVarArg()); |
| |
| // Create the new function |
| auto &M = *F->getParent(); |
| Function *NewF = Function::Create(FTy, F->getLinkage(), F->getAddressSpace(), |
| F->getName(), &M); |
| |
| ValueToValueMapTy VMap; |
| // Loop over the arguments, copying the names of the mapped arguments over... |
| Function::arg_iterator DestI = NewF->arg_begin(); |
| Argument *sret = nullptr; |
| if (sretTy) { |
| sret = &*DestI; |
| DestI++; |
| } |
| Argument *roots = nullptr; |
| if (roots_AT) { |
| roots = &*DestI; |
| DestI++; |
| } |
| // To handle the deleted args, it needs to be replaced by a non-arg operand. |
| // This map contains the temporary phi nodes corresponding |
| // |
| |
| std::map<size_t, PHINode *> delArgMap; |
| for (Argument &I : F->args()) { |
| auto i = I.getArgNo(); |
| if (enzyme_srets.count(i) || enzyme_srets_v.count(i) || rroots.count(i) || |
| rroots_v.count(i)) { |
| VMap[&I] = delArgMap[i] = PHINode::Create(I.getType(), 0); |
| continue; |
| } |
| assert(DestI != NewF->arg_end()); |
| DestI->setName(I.getName()); // Copy the name over... |
| VMap[&I] = &*DestI++; // Add mapping to VMap |
| } |
| |
| SmallVector<ReturnInst *, 8> Returns; // Ignore returns cloned. |
| CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, |
| Returns, "", nullptr); |
| |
| SmallVector<CallInst *, 1> callers; |
| for (auto U : F->users()) { |
| auto CI = dyn_cast<CallInst>(U); |
| assert(CI); |
| assert(CI->getCalledFunction() == F); |
| callers.push_back(CI); |
| } |
| |
| size_t curOffset = 0; |
| |
| std::function<size_t(IRBuilder<> &, Value *, size_t)> recur = |
| [&](IRBuilder<> &B, Value *V, size_t offset) -> size_t { |
| auto T = V->getType(); |
| if (CountTrackedPointers(T).count == 0) |
| return offset; |
| if (roots_AT == nullptr) |
| return offset; |
| if (isa<PointerType>(T)) { |
| if (isSpecialPtr(T)) { |
| if (!roots_AT) { |
| llvm::errs() << *V << " \n"; |
| llvm::errs() << *cast<Instruction>(V)->getParent()->getParent() |
| << " \n"; |
| } |
| assert(roots_AT); |
| assert(roots); |
| auto gep = B.CreateConstInBoundsGEP2_32(roots_AT, roots, 0, offset); |
| if (T != T_prjlvalue) |
| V = B.CreatePointerCast(V, T_prjlvalue); |
| B.CreateStore(V, gep); |
| offset++; |
| } |
| return offset; |
| } else if (auto ST = dyn_cast<StructType>(T)) { |
| for (size_t i = 0; i < ST->getNumElements(); i++) { |
| offset = recur(B, GradientUtils::extractMeta(B, V, i), offset); |
| } |
| return offset; |
| } else if (auto AT = dyn_cast<ArrayType>(T)) { |
| for (size_t i = 0; i < AT->getNumElements(); i++) { |
| offset = recur(B, GradientUtils::extractMeta(B, V, i), offset); |
| } |
| return offset; |
| } else if (auto VT = dyn_cast<VectorType>(T)) { |
| size_t count = VT->getElementCount().getKnownMinValue(); |
| for (size_t i = 0; i < count; i++) { |
| offset = recur(B, B.CreateExtractElement(V, i), offset); |
| } |
| return offset; |
| } |
| return offset; |
| }; |
| |
| size_t sretCount = 0; |
| if (!RT->isVoidTy()) { |
| for (auto &RT : Returns) { |
| IRBuilder<> B(RT); |
| Value *gep = ST ? B.CreateConstInBoundsGEP2_32(ST, sret, 0, 0) : sret; |
| Value *rval = RT->getReturnValue(); |
| B.CreateStore(rval, gep); |
| recur(B, rval, 0); |
| auto NR = B.CreateRetVoid(); |
| RT->eraseFromParent(); |
| RT = NR; |
| } |
| if (roots_AT) |
| curOffset = CountTrackedPointers(RT).count; |
| sretCount++; |
| } |
| |
| for (auto i : enzyme_srets) { |
| auto arg = delArgMap[i]; |
| assert(arg); |
| SmallVector<Instruction *, 1> uses; |
| SmallVector<unsigned, 1> op; |
| for (auto &U : arg->uses()) { |
| auto I = cast<Instruction>(U.getUser()); |
| uses.push_back(I); |
| op.push_back(U.getOperandNo()); |
| } |
| IRBuilder<> EB(&NewF->getEntryBlock().front()); |
| auto gep = |
| ST ? EB.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount) : sret; |
| for (size_t i = 0; i < uses.size(); i++) { |
| uses[i]->setOperand(op[i], gep); |
| } |
| for (auto &RT : Returns) { |
| IRBuilder<> B(RT); |
| auto val = B.CreateLoad(Types[sretCount], gep); |
| recur(B, val, curOffset); |
| } |
| if (roots_AT) |
| curOffset += CountTrackedPointers(Types[sretCount]).count; |
| sretCount++; |
| delete arg; |
| } |
| for (auto i : enzyme_srets_v) { |
| auto AT = cast<ArrayType>(FT->getParamType(i)); |
| auto arg = delArgMap[i]; |
| assert(arg); |
| SmallVector<Instruction *, 1> uses; |
| SmallVector<unsigned, 1> op; |
| for (auto &U : arg->uses()) { |
| auto I = cast<Instruction>(U.getUser()); |
| uses.push_back(I); |
| op.push_back(U.getOperandNo()); |
| } |
| IRBuilder<> EB(&NewF->getEntryBlock().front()); |
| Value *val = UndefValue::get(AT); |
| for (size_t j = 0; j < AT->getNumElements(); j++) { |
| auto gep = |
| ST ? EB.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount + j) : sret; |
| val = EB.CreateInsertValue(val, gep, j); |
| } |
| for (size_t i = 0; i < uses.size(); i++) { |
| uses[i]->setOperand(op[i], val); |
| } |
| for (auto &RT : Returns) { |
| IRBuilder<> B(RT); |
| for (size_t j = 0; j < AT->getNumElements(); j++) { |
| Value *em = GradientUtils::extractMeta(B, val, j); |
| em = B.CreateLoad(Types[sretCount + j], em); |
| recur(B, em, curOffset); |
| } |
| } |
| if (roots_AT) |
| curOffset += |
| CountTrackedPointers(Types[sretCount]).count * AT->getNumElements(); |
| sretCount += AT->getNumElements(); |
| delete arg; |
| } |
| |
| for (auto i : rroots) { |
| auto arg = delArgMap[i]; |
| assert(arg); |
| llvm::Type *T = nullptr; |
| #if LLVM_VERSION_MAJOR >= 17 |
| llvm_unreachable("Unhandled"); |
| // T = F->getParamAttribute(i, Attribute::AttrKind::ElementType) |
| // .getValueAsType(); |
| #else |
| T = FT->getParamType(i)->getPointerElementType(); |
| #endif |
| IRBuilder<> EB(&NewF->getEntryBlock().front()); |
| auto AL = EB.CreateAlloca(T, 0, "stack_roots"); |
| arg->replaceAllUsesWith(AL); |
| delete arg; |
| } |
| for (auto i : rroots_v) { |
| auto arg = delArgMap[i]; |
| assert(arg); |
| auto AT = cast<ArrayType>(FT->getParamType(i)); |
| llvm::Type *T = nullptr; |
| #if LLVM_VERSION_MAJOR >= 17 |
| llvm_unreachable("Unhandled"); |
| // T = F->getParamAttribute(i, Attribute::AttrKind::ElementType) |
| // .getValueAsType(); |
| #else |
| T = AT->getElementType()->getPointerElementType(); |
| #endif |
| IRBuilder<> EB(&NewF->getEntryBlock().front()); |
| Value *val = UndefValue::get(AT); |
| for (size_t j = 0; j < AT->getNumElements(); j++) { |
| auto AL = EB.CreateAlloca(T, 0, "stack_roots_v"); |
| val = EB.CreateInsertValue(val, AL, j); |
| } |
| arg->replaceAllUsesWith(val); |
| delete arg; |
| } |
| assert(curOffset == numRooting); |
| assert(sretCount == Types.size()); |
| |
| for (auto CI : callers) { |
| auto Attrs = CI->getAttributes(); |
| AttributeList NewAttrs; |
| IRBuilder<> B(CI); |
| IRBuilder<> EB(&CI->getParent()->getParent()->getEntryBlock().front()); |
| SmallVector<Value *, 1> vals; |
| size_t nexti = 0; |
| Value *sret = nullptr; |
| if (sretTy) { |
| sret = EB.CreateAlloca(sretTy, 0, "stack_sret"); |
| vals.push_back(sret); |
| NewAttrs = NewAttrs.addAttribute( |
| F->getContext(), AttributeList::FirstArgIndex + nexti, |
| Attribute::get(F->getContext(), Attribute::StructRet, sretTy)); |
| nexti++; |
| } |
| AllocaInst *roots = nullptr; |
| if (roots_AT) { |
| roots = EB.CreateAlloca(roots_AT, 0, "stack_roots_AT"); |
| vals.push_back(roots); |
| NewAttrs = NewAttrs.addAttribute( |
| |
| F->getContext(), AttributeList::FirstArgIndex + nexti, |
| "enzymejl_returnRoots"); |
| nexti++; |
| } |
| |
| for (auto attr : Attrs.getAttributes(AttributeList::FunctionIndex)) |
| NewAttrs = NewAttrs.addAttribute(F->getContext(), |
| AttributeList::FunctionIndex, attr); |
| |
| SmallVector<Value *, 1> sret_vals; |
| SmallVector<Value *, 1> sretv_vals; |
| for (size_t i = 0, end = CI->arg_size(); i < end; i++) { |
| if (rroots.count(i) || rroots_v.count(i)) { |
| continue; |
| } |
| if (enzyme_srets.count(i)) { |
| sret_vals.push_back(CI->getArgOperand(i)); |
| continue; |
| } |
| if (enzyme_srets_v.count(i)) { |
| sretv_vals.push_back(CI->getArgOperand(i)); |
| continue; |
| } |
| |
| for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i)) |
| NewAttrs = NewAttrs.addAttribute( |
| F->getContext(), AttributeList::FirstArgIndex + nexti, attr); |
| vals.push_back(CI->getArgOperand(i)); |
| nexti++; |
| } |
| |
| sretCount = 0; |
| if (!RT->isVoidTy()) { |
| sretCount++; |
| } |
| |
| std::function<void(Type *, Value *, Value *, ArrayRef<int>, int, Type *, |
| bool)> |
| copyNonJLValue = [&](Type *curType, Value *out, Value *in, |
| ArrayRef<int> inds, int outPrefix, Type *ptrTy, |
| bool shouldZero) { |
| if (auto PT = dyn_cast<PointerType>(curType)) { |
| if (PT->getAddressSpace() == 10) { |
| if (shouldZero) { |
| SmallVector<Value *, 1> outinds; |
| auto c0 = ConstantInt::get(B.getInt64Ty(), 0); |
| outinds.push_back(c0); |
| if (outPrefix >= 0) |
| outinds.push_back( |
| ConstantInt::get(B.getInt32Ty(), outPrefix)); |
| for (auto v : inds) { |
| outinds.push_back(ConstantInt::get(B.getInt32Ty(), v)); |
| } |
| if (outinds.size() > 1) |
| out = B.CreateInBoundsGEP(sretTy, out, outinds); |
| B.CreateStore(getUndefinedValueForType(M, PT), out); |
| } |
| return; |
| } |
| } |
| |
| if (auto AT = dyn_cast<ArrayType>(curType)) { |
| for (size_t i = 0; i < AT->getNumElements(); i++) { |
| SmallVector<int, 1> next(inds.begin(), inds.end()); |
| next.push_back(i); |
| copyNonJLValue(AT->getElementType(), out, in, next, outPrefix, |
| ptrTy, shouldZero); |
| } |
| return; |
| } |
| if (auto ST = dyn_cast<StructType>(curType)) { |
| for (size_t i = 0; i < ST->getNumElements(); i++) { |
| SmallVector<int, 1> next(inds.begin(), inds.end()); |
| next.push_back(i); |
| copyNonJLValue(ST->getElementType(i), out, in, next, outPrefix, |
| ptrTy, shouldZero); |
| } |
| return; |
| } |
| |
| SmallVector<Value *, 1> ininds; |
| SmallVector<Value *, 1> outinds; |
| auto c0 = ConstantInt::get(B.getInt64Ty(), 0); |
| ininds.push_back(c0); |
| outinds.push_back(c0); |
| if (outPrefix >= 0) |
| outinds.push_back(ConstantInt::get(B.getInt32Ty(), outPrefix)); |
| for (auto v : inds) { |
| ininds.push_back(ConstantInt::get(B.getInt32Ty(), v)); |
| outinds.push_back(ConstantInt::get(B.getInt32Ty(), v)); |
| } |
| |
| if (outinds.size() > 1) |
| out = B.CreateInBoundsGEP(sretTy, out, outinds); |
| if (ininds.size() > 1) |
| in = B.CreateInBoundsGEP(ptrTy, in, ininds); |
| |
| auto ld = B.CreateLoad(curType, in); |
| B.CreateStore(ld, out); |
| }; |
| |
| for (Value *ptr : sret_vals) { |
| copyNonJLValue(Types[sretCount], sret, ptr, {}, ST ? sretCount : -1, |
| Types[sretCount], true); |
| sretCount++; |
| } |
| for (Value *ptr_v : sretv_vals) { |
| auto AT = cast<ArrayType>(ptr_v->getType()); |
| for (size_t j = 0; j < AT->getNumElements(); j++) { |
| auto ptr = GradientUtils::extractMeta(B, ptr_v, j); |
| copyNonJLValue(Types[sretCount], sret, ptr, {}, |
| ST ? (sretCount + j) : -1, Types[sretCount], true); |
| } |
| sretCount += AT->getNumElements(); |
| } |
| |
| SmallVector<OperandBundleDef, 1> Bundles; |
| for (unsigned I = 0, E = CI->getNumOperandBundles(); I != E; ++I) |
| Bundles.emplace_back(CI->getOperandBundleAt(I)); |
| auto NC = B.CreateCall(NewF, vals, Bundles); |
| NC->setAttributes(NewAttrs); |
| |
| SmallVector<std::pair<unsigned, MDNode *>, 4> TheMDs; |
| CI->getAllMetadataOtherThanDebugLoc(TheMDs); |
| SmallVector<unsigned, 1> toCopy; |
| for (auto pair : TheMDs) |
| if (pair.first != LLVMContext::MD_range) { |
| toCopy.push_back(pair.first); |
| } |
| if (!toCopy.empty()) |
| NC->copyMetadata(*CI, toCopy); |
| NC->setDebugLoc(CI->getDebugLoc()); |
| |
| sretCount = 0; |
| if (!RT->isVoidTy()) { |
| auto gep = ST ? B.CreateConstInBoundsGEP2_32(ST, sret, 0, 0) : sret; |
| auto ld = B.CreateLoad(RT, gep); |
| if (auto MD = CI->getMetadata(LLVMContext::MD_range)) |
| ld->setMetadata(LLVMContext::MD_range, MD); |
| ld->takeName(CI); |
| CI->replaceAllUsesWith(ld); |
| sretCount++; |
| } |
| |
| for (auto ptr : sret_vals) { |
| if (!isa<UndefValue>(ptr) && !isa<PoisonValue>(ptr)) { |
| auto gep = |
| ST ? B.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount) : sret; |
| auto ld = B.CreateLoad(Types[sretCount], gep); |
| auto SI = B.CreateStore(ld, ptr); |
| PostCacheStore(SI, B); |
| } |
| sretCount++; |
| } |
| for (auto ptr_v : sretv_vals) { |
| auto AT = cast<ArrayType>(ptr_v->getType()); |
| for (size_t j = 0; j < AT->getNumElements(); j++) { |
| auto gep = ST ? B.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount + j) |
| : sret; |
| auto ptr = GradientUtils::extractMeta(B, ptr_v, j); |
| if (!isa<UndefValue>(ptr) && !isa<PoisonValue>(ptr)) { |
| auto ld = B.CreateLoad(Types[sretCount], gep); |
| auto SI = B.CreateStore(ld, ptr); |
| PostCacheStore(SI, B); |
| } |
| } |
| sretCount += AT->getNumElements(); |
| } |
| |
| NC->setCallingConv(CI->getCallingConv()); |
| CI->eraseFromParent(); |
| } |
| NewF->setAttributes(NewAttrs); |
| SmallVector<std::pair<unsigned, MDNode *>, 1> MD; |
| F->getAllMetadata(MD); |
| for (auto pair : MD) |
| if (pair.first != LLVMContext::MD_dbg) |
| NewF->addMetadata(pair.first, *pair.second); |
| NewF->takeName(F); |
| NewF->setCallingConv(F->getCallingConv()); |
| F->eraseFromParent(); |
| } |
| |
| LLVMValueRef EnzymeBuildExtractValue(LLVMBuilderRef B, LLVMValueRef AggVal, |
| unsigned *Index, unsigned Size, |
| const char *Name) { |
| return wrap(unwrap(B)->CreateExtractValue( |
| unwrap(AggVal), ArrayRef<unsigned>(Index, Size), Name)); |
| } |
| |
| LLVMValueRef EnzymeBuildInsertValue(LLVMBuilderRef B, LLVMValueRef AggVal, |
| LLVMValueRef EltVal, unsigned *Index, |
| unsigned Size, const char *Name) { |
| return wrap(unwrap(B)->CreateInsertValue( |
| unwrap(AggVal), unwrap(EltVal), ArrayRef<unsigned>(Index, Size), Name)); |
| } |
| } |