| //===- Utils.cpp - Definition of miscellaneous utilities ------------------===// |
| // |
| // Enzyme Project |
| // |
| // Part of the Enzyme Project, under the Apache License v2.0 with LLVM |
| // Exceptions. See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| // If using this code in an academic setting, please cite the following: |
| // @incollection{enzymeNeurips, |
| // title = {Instead of Rewriting Foreign Code for Machine Learning, |
| // Automatically Synthesize Fast Gradients}, |
| // author = {Moses, William S. and Churavy, Valentin}, |
| // booktitle = {Advances in Neural Information Processing Systems 33}, |
| // year = {2020}, |
| // note = {To appear in}, |
| // } |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file defines miscellaneous utilities that are used as part of the |
| // AD process. |
| // |
| //===----------------------------------------------------------------------===// |
| #include "Utils.h" |
| #include "TypeAnalysis/TypeAnalysis.h" |
| |
| #if LLVM_VERSION_MAJOR >= 16 |
| #include "llvm/Analysis/ScalarEvolution.h" |
| #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" |
| #else |
| #include "SCEV/ScalarEvolution.h" |
| #include "SCEV/ScalarEvolutionExpander.h" |
| #endif |
| |
| #include "TypeAnalysis/TBAA.h" |
| #include "llvm/IR/BasicBlock.h" |
| #include "llvm/IR/DerivedTypes.h" |
| #include "llvm/IR/Function.h" |
| #include "llvm/IR/IRBuilder.h" |
| #include "llvm/IR/InlineAsm.h" |
| #include "llvm/IR/Module.h" |
| #include "llvm/IR/Type.h" |
| |
| #include "llvm-c/Core.h" |
| |
| #include "LibraryFuncs.h" |
| |
| using namespace llvm; |
| |
| extern "C" { |
| LLVMValueRef (*CustomErrorHandler)(const char *, LLVMValueRef, ErrorType, |
| const void *, LLVMValueRef, |
| LLVMBuilderRef) = nullptr; |
| LLVMValueRef (*CustomAllocator)(LLVMBuilderRef, LLVMTypeRef, |
| /*Count*/ LLVMValueRef, |
| /*Align*/ LLVMValueRef, uint8_t, |
| LLVMValueRef *) = nullptr; |
| void (*CustomZero)(LLVMBuilderRef, LLVMTypeRef, |
| /*Ptr*/ LLVMValueRef, uint8_t) = nullptr; |
| LLVMValueRef (*CustomDeallocator)(LLVMBuilderRef, LLVMValueRef) = nullptr; |
| void (*CustomRuntimeInactiveError)(LLVMBuilderRef, LLVMValueRef, |
| LLVMValueRef) = nullptr; |
| LLVMValueRef *(*EnzymePostCacheStore)(LLVMValueRef, LLVMBuilderRef, |
| uint64_t *size) = nullptr; |
| LLVMTypeRef (*EnzymeDefaultTapeType)(LLVMContextRef) = nullptr; |
| LLVMValueRef (*EnzymeUndefinedValueForType)(LLVMTypeRef, uint8_t) = nullptr; |
| |
| LLVMValueRef (*EnzymeSanitizeDerivatives)(LLVMValueRef, LLVMValueRef toset, |
| LLVMBuilderRef, |
| LLVMValueRef) = nullptr; |
| |
| extern llvm::cl::opt<bool> EnzymeZeroCache; |
| |
| // default to false because lacpy is slow |
| llvm::cl::opt<bool> |
| EnzymeLapackCopy("enzyme-lapack-copy", cl::init(false), cl::Hidden, |
| cl::desc("Use blas copy calls to cache matrices")); |
| llvm::cl::opt<bool> |
| EnzymeBlasCopy("enzyme-blas-copy", cl::init(true), cl::Hidden, |
| cl::desc("Use blas copy calls to cache vectors")); |
| llvm::cl::opt<bool> |
| EnzymeFastMath("enzyme-fast-math", cl::init(true), cl::Hidden, |
| cl::desc("Use fast math on derivative compuation")); |
| llvm::cl::opt<bool> |
| EnzymeStrongZero("enzyme-strong-zero", cl::init(false), cl::Hidden, |
| cl::desc("Use additional checks to ensure correct " |
| "behavior when handling functions with inf")); |
| } |
| |
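/// Zero the object of type T pointed to by obj, using the registered
/// CustomZero callback if present, otherwise a plain store of the zero value
/// of T. isTape indicates whether the object being zeroed is a tape.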
| void ZeroMemory(llvm::IRBuilder<> &Builder, llvm::Type *T, llvm::Value *obj, |
| bool isTape) { |
| if (CustomZero) { |
| CustomZero(wrap(&Builder), wrap(T), wrap(obj), isTape); |
| } else { |
| Builder.CreateStore(Constant::getNullValue(T), obj); |
| } |
| } |
| |
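/// Invoke the registered EnzymePostCacheStore callback (if any) on a store
/// emitted while caching a value, returning any additional instructions the
/// callback created.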
| llvm::SmallVector<llvm::Instruction *, 2> PostCacheStore(llvm::StoreInst *SI, |
| llvm::IRBuilder<> &B) { |
| SmallVector<llvm::Instruction *, 2> res; |
| if (EnzymePostCacheStore) { |
| uint64_t size = 0; |
| auto ptr = EnzymePostCacheStore(wrap(SI), wrap(&B), &size); |
| for (size_t i = 0; i < size; i++) { |
| res.push_back(cast<Instruction>(unwrap(ptr[i]))); |
| } |
| free(ptr); |
| } |
| return res; |
| } |
| |
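/// Return the pointer type used for anonymous tape allocations: the type
/// provided by the EnzymeDefaultTapeType callback if registered, otherwise
/// i8*.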
| llvm::PointerType *getDefaultAnonymousTapeType(llvm::LLVMContext &C) { |
| if (EnzymeDefaultTapeType) |
| return cast<PointerType>(unwrap(EnzymeDefaultTapeType(wrap(&C)))); |
| return Type::getInt8PtrTy(C); |
| } |
| |
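/// Create (or reuse) a helper used to grow tape buffers geometrically.
/// Arguments are (ptr, size, tsize), where size is the current number of
/// outer iterations and tsize the number of bytes stored per iteration. When
/// size has just crossed a power of two, the buffer is reallocated (via
/// realloc, or via the custom allocator followed by a memcpy of the old
/// contents) to tsize times the next power of two, optionally zeroing the
/// newly added bytes; otherwise the incoming pointer is returned unchanged.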
| Function *getOrInsertExponentialAllocator(Module &M, Function *newFunc, |
| bool ZeroInit, llvm::Type *RT) { |
| bool custom = true; |
| llvm::PointerType *allocType; |
| { |
| auto i64 = Type::getInt64Ty(newFunc->getContext()); |
| BasicBlock *BB = BasicBlock::Create(M.getContext(), "entry", newFunc); |
| IRBuilder<> B(BB); |
| auto P = B.CreatePHI(i64, 1); |
| CallInst *malloccall; |
| Instruction *SubZero = nullptr; |
| CreateAllocation(B, RT, P, "tapemem", &malloccall, &SubZero)->getType(); |
| if (auto F = getFunctionFromCall(malloccall)) { |
| custom = F->getName() != "malloc"; |
| } |
| allocType = cast<PointerType>(malloccall->getType()); |
| BB->eraseFromParent(); |
| } |
| |
| Type *types[] = {allocType, Type::getInt64Ty(M.getContext()), |
| Type::getInt64Ty(M.getContext())}; |
| std::string name = "__enzyme_exponentialallocation"; |
| if (ZeroInit) |
| name += "zero"; |
| if (custom) |
| name += ".custom@" + std::to_string((size_t)RT); |
| |
| FunctionType *FT = FunctionType::get(allocType, types, false); |
| Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee()); |
| |
| if (!F->empty()) |
| return F; |
| |
| F->setLinkage(Function::LinkageTypes::InternalLinkage); |
| F->addFnAttr(Attribute::AlwaysInline); |
| F->addFnAttr(Attribute::NoUnwind); |
| BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); |
| BasicBlock *grow = BasicBlock::Create(M.getContext(), "grow", F); |
| BasicBlock *ok = BasicBlock::Create(M.getContext(), "ok", F); |
| |
| IRBuilder<> B(entry); |
| |
| Argument *ptr = F->arg_begin(); |
| ptr->setName("ptr"); |
| Argument *size = ptr + 1; |
| size->setName("size"); |
| Argument *tsize = size + 1; |
| tsize->setName("tsize"); |
| |
| Value *hasOne = B.CreateICmpNE( |
| B.CreateAnd(size, ConstantInt::get(size->getType(), 1, false)), |
| ConstantInt::get(size->getType(), 0, false)); |
| auto popCnt = Intrinsic::getDeclaration(&M, Intrinsic::ctpop, {types[1]}); |
| |
| B.CreateCondBr( |
| B.CreateAnd(B.CreateICmpULT(B.CreateCall(popCnt, {size}), |
| ConstantInt::get(types[1], 3, false)), |
| hasOne), |
| grow, ok); |
| |
| B.SetInsertPoint(grow); |
| |
| auto lz = |
| B.CreateCall(Intrinsic::getDeclaration(&M, Intrinsic::ctlz, {types[1]}), |
| {size, ConstantInt::getTrue(M.getContext())}); |
| Value *next = |
| B.CreateShl(tsize, B.CreateSub(ConstantInt::get(types[1], 64, false), lz, |
| "", true, true)); |
| |
| Value *gVal; |
| |
| Value *prevSize = |
| B.CreateSelect(B.CreateICmpEQ(size, ConstantInt::get(size->getType(), 1)), |
| ConstantInt::get(next->getType(), 0), |
| B.CreateLShr(next, ConstantInt::get(next->getType(), 1))); |
| |
| if (!custom) { |
| auto reallocF = M.getOrInsertFunction("realloc", allocType, allocType, |
| Type::getInt64Ty(M.getContext())); |
| |
| Value *args[] = {B.CreatePointerCast(ptr, allocType), next}; |
| gVal = B.CreateCall(reallocF, args); |
| } else { |
| Value *tsize = ConstantInt::get( |
| next->getType(), |
| newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(RT) / 8); |
| auto elSize = B.CreateUDiv(next, tsize, "", /*isExact*/ true); |
| Instruction *SubZero = nullptr; |
| gVal = CreateAllocation(B, RT, elSize, "", nullptr, &SubZero); |
| |
| Type *bTy = |
| PointerType::get(Type::getInt8Ty(gVal->getContext()), |
| cast<PointerType>(gVal->getType())->getAddressSpace()); |
| gVal = B.CreatePointerCast(gVal, bTy); |
| auto pVal = B.CreatePointerCast(ptr, gVal->getType()); |
| |
| Value *margs[] = {gVal, pVal, prevSize, |
| ConstantInt::getFalse(M.getContext())}; |
| Type *tys[] = {margs[0]->getType(), margs[1]->getType(), |
| margs[2]->getType()}; |
| auto memsetF = Intrinsic::getDeclaration(&M, Intrinsic::memcpy, tys); |
| B.CreateCall(memsetF, margs); |
| if (SubZero) { |
| ZeroInit = false; |
| IRBuilder<> BB(SubZero); |
| Value *zeroSize = BB.CreateSub(next, prevSize); |
| Value *tmp = SubZero->getOperand(0); |
| Type *tmpT = tmp->getType(); |
| tmp = BB.CreatePointerCast(tmp, bTy); |
| tmp = BB.CreateInBoundsGEP(Type::getInt8Ty(tmp->getContext()), tmp, |
| prevSize); |
| tmp = BB.CreatePointerCast(tmp, tmpT); |
| SubZero->setOperand(0, tmp); |
| SubZero->setOperand(2, zeroSize); |
| } |
| } |
| |
| if (ZeroInit) { |
| Value *zeroSize = B.CreateSub(next, prevSize); |
| |
| Value *margs[] = {B.CreateInBoundsGEP(B.getInt8Ty(), gVal, prevSize), |
| B.getInt8(0), zeroSize, B.getFalse()}; |
| Type *tys[] = {margs[0]->getType(), margs[2]->getType()}; |
| auto memsetF = Intrinsic::getDeclaration(&M, Intrinsic::memset, tys); |
| B.CreateCall(memsetF, margs); |
| } |
| gVal = B.CreatePointerCast(gVal, ptr->getType()); |
| |
| B.CreateBr(ok); |
| B.SetInsertPoint(ok); |
| auto phi = B.CreatePHI(ptr->getType(), 2); |
| phi->addIncoming(gVal, grow); |
| phi->addIncoming(ptr, entry); |
| B.CreateRet(phi); |
| return F; |
| } |
| |
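/// Grow (if needed) the cache buffer `prev` so it can hold OuterCount
/// iterations of InnerCount elements of type T, using the exponential
/// allocation helper above.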
| llvm::Value *CreateReAllocation(llvm::IRBuilder<> &B, llvm::Value *prev, |
| llvm::Type *T, llvm::Value *OuterCount, |
| llvm::Value *InnerCount, |
| const llvm::Twine &Name, |
| llvm::CallInst **caller, bool ZeroMem) { |
| auto newFunc = B.GetInsertBlock()->getParent(); |
| |
| Value *tsize = ConstantInt::get( |
| InnerCount->getType(), |
| newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(T) / 8); |
| |
| Value *idxs[] = { |
| /*ptr*/ |
| prev, |
      /*outer count; the buffer grows when this passes a power of two*/
| OuterCount, |
| /*buffer size (element x subloops)*/ |
| B.CreateMul(tsize, InnerCount, "", /*NUW*/ true, |
| /*NSW*/ true)}; |
| |
| auto realloccall = |
| B.CreateCall(getOrInsertExponentialAllocator(*newFunc->getParent(), |
| newFunc, ZeroMem, T), |
| idxs, Name); |
| if (caller) |
| *caller = realloccall; |
| return realloccall; |
| } |
| |
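/// Allocate space for Count elements of type T, either through the registered
/// CustomAllocator callback or via malloc. If ZeroMem is non-null, also emit
/// code zeroing the allocation and return that instruction through ZeroMem;
/// if caller is non-null, return the underlying allocation call through it.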
| Value *CreateAllocation(IRBuilder<> &Builder, llvm::Type *T, Value *Count, |
| const Twine &Name, CallInst **caller, |
| Instruction **ZeroMem, bool isDefault) { |
| Value *res; |
| auto &M = *Builder.GetInsertBlock()->getParent()->getParent(); |
| auto AlignI = M.getDataLayout().getTypeAllocSizeInBits(T) / 8; |
| auto Align = ConstantInt::get(Count->getType(), AlignI); |
| CallInst *malloccall = nullptr; |
| if (CustomAllocator) { |
| LLVMValueRef wzeromem = nullptr; |
| res = unwrap(CustomAllocator(wrap(&Builder), wrap(T), wrap(Count), |
| wrap(Align), isDefault, |
| ZeroMem ? &wzeromem : nullptr)); |
| if (auto I = dyn_cast<Instruction>(res)) |
| I->setName(Name); |
| |
| malloccall = dyn_cast<CallInst>(res); |
| if (malloccall == nullptr) { |
| malloccall = cast<CallInst>(cast<Instruction>(res)->getOperand(0)); |
| } |
| if (ZeroMem) { |
| *ZeroMem = cast_or_null<Instruction>(unwrap(wzeromem)); |
| ZeroMem = nullptr; |
| } |
| } else { |
| if (Builder.GetInsertPoint() == Builder.GetInsertBlock()->end()) { |
| res = CallInst::CreateMalloc(Builder.GetInsertBlock(), Count->getType(), |
| T, Align, Count, nullptr, Name); |
| Builder.SetInsertPoint(Builder.GetInsertBlock()); |
| } else { |
| res = CallInst::CreateMalloc(&*Builder.GetInsertPoint(), Count->getType(), |
| T, Align, Count, nullptr, Name); |
| } |
| if (!cast<Instruction>(res)->getParent()) |
| Builder.Insert(cast<Instruction>(res)); |
| |
| malloccall = dyn_cast<CallInst>(res); |
| if (malloccall == nullptr) { |
| malloccall = cast<CallInst>(cast<Instruction>(res)->getOperand(0)); |
| } |
| |
| // Assert computation of size of array doesn't wrap |
| if (auto BI = dyn_cast<BinaryOperator>(malloccall->getArgOperand(0))) { |
| if (BI->getOpcode() == BinaryOperator::Mul) { |
| if ((BI->getOperand(0) == Align && BI->getOperand(1) == Count) || |
| (BI->getOperand(1) == Align && BI->getOperand(0) == Count)) |
| BI->setHasNoSignedWrap(true); |
| BI->setHasNoUnsignedWrap(true); |
| } |
| } |
| |
| if (auto ci = dyn_cast<ConstantInt>(Count)) { |
| #if LLVM_VERSION_MAJOR >= 14 |
| malloccall->addDereferenceableRetAttr(ci->getLimitedValue() * AlignI); |
| #if !defined(FLANG) && !defined(ROCM) |
| AttrBuilder B(ci->getContext()); |
| #else |
| AttrBuilder B; |
| #endif |
| B.addDereferenceableOrNullAttr(ci->getLimitedValue() * AlignI); |
| malloccall->setAttributes(malloccall->getAttributes().addRetAttributes( |
| malloccall->getContext(), B)); |
| #else |
| malloccall->addDereferenceableAttr(llvm::AttributeList::ReturnIndex, |
| ci->getLimitedValue() * AlignI); |
| malloccall->addDereferenceableOrNullAttr(llvm::AttributeList::ReturnIndex, |
| ci->getLimitedValue() * AlignI); |
| #endif |
| // malloccall->removeAttribute(llvm::AttributeList::ReturnIndex, |
| // Attribute::DereferenceableOrNull); |
| } |
| #if LLVM_VERSION_MAJOR >= 14 |
| malloccall->addAttributeAtIndex(AttributeList::ReturnIndex, |
| Attribute::NoAlias); |
| malloccall->addAttributeAtIndex(AttributeList::ReturnIndex, |
| Attribute::NonNull); |
| #else |
| malloccall->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias); |
| malloccall->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); |
| #endif |
| } |
| if (caller) { |
| *caller = malloccall; |
| } |
| if (ZeroMem) { |
| auto PT = cast<PointerType>(malloccall->getType()); |
| Value *tozero = malloccall; |
| |
| bool needsCast = false; |
| #if LLVM_VERSION_MAJOR < 18 |
| #if LLVM_VERSION_MAJOR >= 15 |
| if (PT->getContext().supportsTypedPointers()) { |
| #endif |
| needsCast = !PT->getPointerElementType()->isIntegerTy(8); |
| #if LLVM_VERSION_MAJOR >= 15 |
| } |
| #endif |
| #endif |
| if (needsCast) |
| tozero = Builder.CreatePointerCast( |
| tozero, PointerType::get(Type::getInt8Ty(PT->getContext()), |
| PT->getAddressSpace())); |
| Value *args[] = { |
| tozero, ConstantInt::get(Type::getInt8Ty(malloccall->getContext()), 0), |
| Builder.CreateMul(Align, Count, "", true, true), |
| ConstantInt::getFalse(malloccall->getContext())}; |
| Type *tys[] = {args[0]->getType(), args[2]->getType()}; |
| |
| *ZeroMem = Builder.CreateCall( |
| Intrinsic::getDeclaration(&M, Intrinsic::memset, tys), args); |
| } |
| return res; |
| } |
| |
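/// Free memory obtained from CreateAllocation, either through the registered
/// CustomDeallocator callback or by emitting a call to free.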
| CallInst *CreateDealloc(llvm::IRBuilder<> &Builder, llvm::Value *ToFree) { |
| CallInst *res = nullptr; |
| |
| if (CustomDeallocator) { |
| res = dyn_cast_or_null<CallInst>( |
| unwrap(CustomDeallocator(wrap(&Builder), wrap(ToFree)))); |
| } else { |
| |
| ToFree = Builder.CreatePointerCast( |
| ToFree, Type::getInt8PtrTy(ToFree->getContext())); |
| if (Builder.GetInsertPoint() == Builder.GetInsertBlock()->end()) { |
| res = cast<CallInst>( |
| CallInst::CreateFree(ToFree, Builder.GetInsertBlock())); |
| Builder.SetInsertPoint(Builder.GetInsertBlock()); |
| } else { |
| res = cast<CallInst>( |
| CallInst::CreateFree(ToFree, &*Builder.GetInsertPoint())); |
| } |
| if (!cast<Instruction>(res)->getParent()) |
| Builder.Insert(cast<Instruction>(res)); |
| #if LLVM_VERSION_MAJOR >= 14 |
| res->addAttributeAtIndex(AttributeList::FirstArgIndex, Attribute::NonNull); |
| #else |
| res->addAttribute(AttributeList::FirstArgIndex, Attribute::NonNull); |
| #endif |
| } |
| return res; |
| } |
| |
| EnzymeFailure::EnzymeFailure(const llvm::Twine &RemarkName, |
| const llvm::DiagnosticLocation &Loc, |
| const llvm::Instruction *CodeRegion) |
| : DiagnosticInfoUnsupported(*CodeRegion->getParent()->getParent(), |
| RemarkName, Loc) {} |
| |
| /// Convert a floating type to a string |
| static inline std::string tofltstr(Type *T) { |
| switch (T->getTypeID()) { |
| case Type::HalfTyID: |
| return "half"; |
| case Type::FloatTyID: |
| return "float"; |
| case Type::DoubleTyID: |
| return "double"; |
| case Type::X86_FP80TyID: |
| return "x87d"; |
| case Type::FP128TyID: |
| return "quad"; |
| case Type::PPC_FP128TyID: |
| return "ppcddouble"; |
| default: |
| llvm_unreachable("Invalid floating type"); |
| } |
| } |
| |
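/// Materialize Str as a private constant global and return an i8* pointing to
/// its first character.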
| Constant *getString(Module &M, StringRef Str) { |
| llvm::Constant *s = llvm::ConstantDataArray::getString(M.getContext(), Str); |
| auto *gv = new llvm::GlobalVariable( |
| M, s->getType(), true, llvm::GlobalValue::PrivateLinkage, s, ".str"); |
| gv->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global); |
| Value *Idxs[2] = {ConstantInt::get(Type::getInt32Ty(M.getContext()), 0), |
| ConstantInt::get(Type::getInt32Ty(M.getContext()), 0)}; |
| return ConstantExpr::getInBoundsGetElementPtr(s->getType(), gv, Idxs); |
| } |
| |
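/// Emit a call to a helper function that, at runtime, reports an error (via
/// the CustomRuntimeInactiveError callback if registered, otherwise puts and
/// exit) when the primal and shadow pointers are equal, i.e. when a value
/// expected to have a distinct shadow turned out to be inactive at runtime.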
| void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal, |
| llvm::Value *shadow, const char *Message, |
| llvm::DebugLoc &&loc, llvm::Instruction *orig) { |
| Module &M = *B.GetInsertBlock()->getParent()->getParent(); |
| std::string name = "__enzyme_runtimeinactiveerr"; |
| if (CustomRuntimeInactiveError) { |
| static int count = 0; |
| name += std::to_string(count); |
| count++; |
| } |
| FunctionType *FT = FunctionType::get(Type::getVoidTy(M.getContext()), |
| {Type::getInt8PtrTy(M.getContext()), |
| Type::getInt8PtrTy(M.getContext()), |
| Type::getInt8PtrTy(M.getContext())}, |
| false); |
| |
| Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee()); |
| |
| if (F->empty()) { |
| F->setLinkage(Function::LinkageTypes::InternalLinkage); |
| F->addFnAttr(Attribute::AlwaysInline); |
| F->addParamAttr(0, Attribute::NoCapture); |
| F->addParamAttr(1, Attribute::NoCapture); |
| |
| BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); |
| BasicBlock *error = BasicBlock::Create(M.getContext(), "error", F); |
| BasicBlock *end = BasicBlock::Create(M.getContext(), "end", F); |
| |
| auto prim = F->arg_begin(); |
| prim->setName("primal"); |
| auto shadow = prim + 1; |
| shadow->setName("shadow"); |
| auto msg = prim + 2; |
| msg->setName("msg"); |
| |
| IRBuilder<> EB(entry); |
| EB.CreateCondBr(EB.CreateICmpEQ(prim, shadow), error, end); |
| |
| EB.SetInsertPoint(error); |
| |
| if (CustomRuntimeInactiveError) { |
| CustomRuntimeInactiveError(wrap(&EB), wrap(msg), wrap(orig)); |
| } else { |
| FunctionType *FT = |
| FunctionType::get(Type::getInt32Ty(M.getContext()), |
| {Type::getInt8PtrTy(M.getContext())}, false); |
| |
| auto PutsF = M.getOrInsertFunction("puts", FT); |
| EB.CreateCall(PutsF, msg); |
| |
| FunctionType *FT2 = |
| FunctionType::get(Type::getVoidTy(M.getContext()), |
| {Type::getInt32Ty(M.getContext())}, false); |
| |
| auto ExitF = M.getOrInsertFunction("exit", FT2); |
| EB.CreateCall(ExitF, |
| ConstantInt::get(Type::getInt32Ty(M.getContext()), 1)); |
| } |
| EB.CreateUnreachable(); |
| |
| EB.SetInsertPoint(end); |
| EB.CreateRetVoid(); |
| } |
| |
| Value *args[] = { |
| B.CreatePointerCast(primal, Type::getInt8PtrTy(M.getContext())), |
| B.CreatePointerCast(shadow, Type::getInt8PtrTy(M.getContext())), |
| getString(M, Message)}; |
| auto call = B.CreateCall(F, args); |
| call->setDebugLoc(loc); |
| } |
| |
/// Create a function, specialized to the element type, that implements the
/// reverse pass of memcpy: it accumulates the values of dst into src and
/// zeros dst, rather than performing a direct copy; arguments are dst, src,
/// numelems.
| Function *getOrInsertDifferentialFloatMemcpy(Module &M, Type *elementType, |
| unsigned dstalign, |
| unsigned srcalign, |
| unsigned dstaddr, unsigned srcaddr, |
| unsigned bitwidth) { |
| assert(elementType->isFloatingPointTy()); |
| std::string name = "__enzyme_memcpy"; |
| if (bitwidth != 64) |
| name += std::to_string(bitwidth); |
| name += "add_" + tofltstr(elementType) + "da" + std::to_string(dstalign) + |
| "sa" + std::to_string(srcalign); |
| if (dstaddr) |
| name += "dadd" + std::to_string(dstaddr); |
| if (srcaddr) |
| name += "sadd" + std::to_string(srcaddr); |
| FunctionType *FT = |
| FunctionType::get(Type::getVoidTy(M.getContext()), |
| {PointerType::get(elementType, dstaddr), |
| PointerType::get(elementType, srcaddr), |
| IntegerType::get(M.getContext(), bitwidth)}, |
| false); |
| |
| Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee()); |
| |
| if (!F->empty()) |
| return F; |
| |
| F->setLinkage(Function::LinkageTypes::InternalLinkage); |
| #if LLVM_VERSION_MAJOR >= 16 |
| F->setOnlyAccessesArgMemory(); |
| #else |
| F->addFnAttr(Attribute::ArgMemOnly); |
| #endif |
| F->addFnAttr(Attribute::NoUnwind); |
| F->addFnAttr(Attribute::AlwaysInline); |
| F->addParamAttr(0, Attribute::NoCapture); |
| F->addParamAttr(1, Attribute::NoCapture); |
| |
| BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); |
| BasicBlock *body = BasicBlock::Create(M.getContext(), "for.body", F); |
| BasicBlock *end = BasicBlock::Create(M.getContext(), "for.end", F); |
| |
| auto dst = F->arg_begin(); |
| dst->setName("dst"); |
| auto src = dst + 1; |
| src->setName("src"); |
| auto num = src + 1; |
| num->setName("num"); |
| |
| { |
| IRBuilder<> B(entry); |
| B.CreateCondBr(B.CreateICmpEQ(num, ConstantInt::get(num->getType(), 0)), |
| end, body); |
| } |
| |
| { |
| IRBuilder<> B(body); |
| B.setFastMathFlags(getFast()); |
| PHINode *idx = B.CreatePHI(num->getType(), 2, "idx"); |
| idx->addIncoming(ConstantInt::get(num->getType(), 0), entry); |
| |
| Value *dsti = B.CreateInBoundsGEP(elementType, dst, idx, "dst.i"); |
| LoadInst *dstl = B.CreateLoad(elementType, dsti, "dst.i.l"); |
| StoreInst *dsts = B.CreateStore(Constant::getNullValue(elementType), dsti); |
| if (dstalign) { |
| dstl->setAlignment(Align(dstalign)); |
| dsts->setAlignment(Align(dstalign)); |
| } |
| |
| Value *srci = B.CreateInBoundsGEP(elementType, src, idx, "src.i"); |
| LoadInst *srcl = B.CreateLoad(elementType, srci, "src.i.l"); |
| StoreInst *srcs = B.CreateStore(B.CreateFAdd(srcl, dstl), srci); |
| if (srcalign) { |
| srcl->setAlignment(Align(srcalign)); |
| srcs->setAlignment(Align(srcalign)); |
| } |
| |
| Value *next = |
| B.CreateNUWAdd(idx, ConstantInt::get(num->getType(), 1), "idx.next"); |
| idx->addIncoming(next, body); |
| B.CreateCondBr(B.CreateICmpEQ(num, next), end, body); |
| } |
| |
| { |
| IRBuilder<> B(end); |
| B.CreateRetVoid(); |
| } |
| return F; |
| } |
| |
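/// Emit a call to the BLAS copy routine (e.g. dcopy) described by blas in
/// order to cache a strided vector.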
| void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, |
| llvm::ArrayRef<llvm::Value *> args, |
| llvm::ArrayRef<llvm::OperandBundleDef> bundles) { |
| std::string copy_name = |
| (blas.prefix + blas.floatType + "copy" + blas.suffix).str(); |
| |
| SmallVector<Type *, 1> tys; |
| for (auto arg : args) |
| tys.push_back(arg->getType()); |
| |
| auto FT = FunctionType::get(Type::getVoidTy(M.getContext()), tys, false); |
| auto fn = M.getOrInsertFunction(copy_name, FT); |
| |
| B.CreateCall(fn, args, bundles); |
| } |
| |
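/// Emit a call to the LAPACK lacpy routine described by blas in order to
/// cache a matrix.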
| void callMemcpyStridedLapack(llvm::IRBuilder<> &B, llvm::Module &M, |
| BlasInfo blas, llvm::ArrayRef<llvm::Value *> args, |
| llvm::ArrayRef<llvm::OperandBundleDef> bundles) { |
| std::string copy_name = (blas.floatType + "lacpy" + blas.suffix).str(); |
| |
| SmallVector<Type *, 1> tys; |
| for (auto arg : args) |
| tys.push_back(arg->getType()); |
| |
| auto FT = FunctionType::get(Type::getVoidTy(M.getContext()), tys, false); |
| auto fn = M.getOrInsertFunction(copy_name, FT); |
| |
| B.CreateCall(fn, args, bundles); |
| } |
| |
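/// Emit a call to (and define on first use) a helper that subtracts
/// alpha * x[i] * ya[i] from the diagonal entries of the packed adjoint
/// matrix APa, handling both the upper and lower packed layouts; this
/// correction is used when differentiating BLAS spmv with packed storage.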
| void callSPMVDiagUpdate(IRBuilder<> &B, Module &M, BlasInfo blas, |
| IntegerType *IT, Type *BlasCT, Type *BlasFPT, |
| Type *BlasPT, Type *BlasIT, Type *fpTy, |
| ArrayRef<Value *> args, |
| ArrayRef<OperandBundleDef> bundles, bool byRef, |
| bool julia_decl) { |
| // add spmv diag update call if not already present |
| std::string fnc_name = |
| ("__enzyme_spmv_diag" + blas.floatType + blas.suffix).str(); |
| |
| // spmvDiagHelper(uplo, n, alpha, x, incx, ya, incy, APa) |
| auto FDiagUpdateT = FunctionType::get( |
| B.getVoidTy(), |
| {BlasCT, BlasIT, BlasFPT, BlasPT, BlasIT, BlasPT, BlasIT, BlasPT}, false); |
| Function *F = |
| cast<Function>(M.getOrInsertFunction(fnc_name, FDiagUpdateT).getCallee()); |
| |
| if (!F->empty()) { |
| B.CreateCall(F, args, bundles); |
| return; |
| } |
| |
| // now add the implementation for the call |
| F->setLinkage(Function::LinkageTypes::InternalLinkage); |
| #if LLVM_VERSION_MAJOR >= 16 |
| F->setOnlyAccessesArgMemory(); |
| #else |
| F->addFnAttr(Attribute::ArgMemOnly); |
| #endif |
| F->addFnAttr(Attribute::NoUnwind); |
| F->addFnAttr(Attribute::AlwaysInline); |
| if (!julia_decl) { |
| F->addParamAttr(3, Attribute::NoCapture); |
| F->addParamAttr(5, Attribute::NoCapture); |
| F->addParamAttr(7, Attribute::NoCapture); |
| F->addParamAttr(3, Attribute::NoAlias); |
| F->addParamAttr(5, Attribute::NoAlias); |
| F->addParamAttr(7, Attribute::NoAlias); |
| F->addParamAttr(3, Attribute::ReadOnly); |
| F->addParamAttr(5, Attribute::ReadOnly); |
| if (byRef) { |
| F->addParamAttr(2, Attribute::NoCapture); |
| F->addParamAttr(2, Attribute::NoAlias); |
| F->addParamAttr(2, Attribute::ReadOnly); |
| } |
| } |
| |
| BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); |
| BasicBlock *init = BasicBlock::Create(M.getContext(), "init", F); |
| BasicBlock *uper_code = BasicBlock::Create(M.getContext(), "uper", F); |
| BasicBlock *lower_code = BasicBlock::Create(M.getContext(), "lower", F); |
| BasicBlock *end = BasicBlock::Create(M.getContext(), "for.end", F); |
| |
| // spmvDiagHelper(uplo, n, alpha, x, incx, ya, incy, APa) |
| auto blasuplo = F->arg_begin(); |
| blasuplo->setName("blasuplo"); |
| auto blasn = blasuplo + 1; |
| blasn->setName("blasn"); |
| auto blasalpha = blasn + 1; |
| blasalpha->setName("blasalpha"); |
| auto blasx = blasalpha + 1; |
| blasx->setName("blasx"); |
| auto blasincx = blasx + 1; |
| blasincx->setName("blasincx"); |
  auto blasdy = blasincx + 1;
| blasdy->setName("blasdy"); |
| auto blasincy = blasdy + 1; |
| blasincy->setName("blasincy"); |
| auto blasdAP = blasincy + 1; |
| blasdAP->setName("blasdAP"); |
| |
| // TODO: consider cblas_layout |
| |
| // https://dl.acm.org/doi/pdf/10.1145/3382191 |
  // The following example is Fortran-based, thus 1-indexed
| // if(uplo == 'u' .or. uplo == 'U') then |
| // k = 0 |
| // do i = 1,n |
| // k = k+i |
| // APa(k) = APa(k) - alpha*x(1 + (i-1)*incx)*ya(1 + (i-1)*incy) |
| // end do |
| // else |
| // k = 1 |
| // do i = 1,n |
| // APa(k) = APa(k) - alpha*x(1 + (i-1)*incx)*ya(1 + (i-1)*incy) |
| // k = k+n-i+1 |
| // end do |
| // end if |
| { |
| IRBuilder<> B1(entry); |
| Value *n = load_if_ref(B1, IT, blasn, byRef); |
| Value *incx = load_if_ref(B1, IT, blasincx, byRef); |
| Value *incy = load_if_ref(B1, IT, blasincy, byRef); |
| Value *alpha = blasalpha; |
| if (byRef) { |
| auto VP = B1.CreatePointerCast( |
| blasalpha, |
| PointerType::get( |
| fpTy, |
| cast<PointerType>(blasalpha->getType())->getAddressSpace())); |
| alpha = B1.CreateLoad(fpTy, VP); |
| } |
| Value *is_u = is_uper(B1, blasuplo, byRef); |
| Value *k = B1.CreateSelect(is_u, ConstantInt::get(IT, 0), |
| ConstantInt::get(IT, 1), "k"); |
| B1.CreateCondBr(B1.CreateICmpEQ(n, ConstantInt::get(IT, 0)), end, init); |
| |
| IRBuilder<> B2(init); |
| Value *xfloat = B2.CreatePointerCast( |
| blasx, |
| PointerType::get( |
| fpTy, cast<PointerType>(blasx->getType())->getAddressSpace())); |
| Value *dyfloat = B2.CreatePointerCast( |
| blasdy, |
| PointerType::get( |
| fpTy, cast<PointerType>(blasdy->getType())->getAddressSpace())); |
| Value *dAPfloat = B2.CreatePointerCast( |
| blasdAP, |
| PointerType::get( |
| fpTy, cast<PointerType>(blasdAP->getType())->getAddressSpace())); |
| B2.CreateCondBr(is_u, uper_code, lower_code); |
| |
| IRBuilder<> B3(uper_code); |
| B3.setFastMathFlags(getFast()); |
| { |
| PHINode *iter = B3.CreatePHI(IT, 2, "iteration"); |
| PHINode *kval = B3.CreatePHI(IT, 2, "k"); |
| iter->addIncoming(ConstantInt::get(IT, 0), init); |
| kval->addIncoming(ConstantInt::get(IT, 0), init); |
| Value *iternext = |
| B3.CreateAdd(iter, ConstantInt::get(IT, 1), "iter.next"); |
| // 0, 2, 5, 9, 14, 20, 27, 35, 44, 54, ... are diag elements |
      Value *kvalnext = B3.CreateAdd(
          B3.CreateAdd(kval, iternext), ConstantInt::get(IT, 1), "k.next");
| iter->addIncoming(iternext, uper_code); |
| kval->addIncoming(kvalnext, uper_code); |
| |
| Value *xidx = B3.CreateNUWMul(iter, incx, "x.idx"); |
| Value *yidx = B3.CreateNUWMul(iter, incy, "y.idx"); |
| Value *x = B3.CreateInBoundsGEP(fpTy, xfloat, xidx, "x.ptr"); |
| Value *y = B3.CreateInBoundsGEP(fpTy, dyfloat, yidx, "y.ptr"); |
| Value *xval = B3.CreateLoad(fpTy, x, "x.val"); |
| Value *yval = B3.CreateLoad(fpTy, y, "y.val"); |
| Value *xy = B3.CreateFMul(xval, yval, "xy"); |
| Value *xyalpha = B3.CreateFMul(xy, alpha, "xy.alpha"); |
| Value *kptr = B3.CreateInBoundsGEP(fpTy, dAPfloat, kval, "k.ptr"); |
| Value *kvalloaded = B3.CreateLoad(fpTy, kptr, "k.val"); |
| Value *kvalnew = B3.CreateFSub(kvalloaded, xyalpha, "k.val.new"); |
| B3.CreateStore(kvalnew, kptr); |
| |
| B3.CreateCondBr(B3.CreateICmpEQ(iternext, n), end, uper_code); |
| } |
| |
| IRBuilder<> B4(lower_code); |
| B4.setFastMathFlags(getFast()); |
| { |
| PHINode *iter = B4.CreatePHI(IT, 2, "iteration"); |
| PHINode *kval = B4.CreatePHI(IT, 2, "k"); |
| iter->addIncoming(ConstantInt::get(IT, 0), init); |
| kval->addIncoming(ConstantInt::get(IT, 0), init); |
| Value *iternext = |
| B4.CreateAdd(iter, ConstantInt::get(IT, 1), "iter.next"); |
| Value *ktmp = B4.CreateAdd(n, ConstantInt::get(IT, 1), "tmp.val"); |
| Value *ktmp2 = B4.CreateSub(ktmp, iternext, "tmp.val.other"); |
| Value *kvalnext = B4.CreateAdd(kval, ktmp2, "k.next"); |
| iter->addIncoming(iternext, lower_code); |
| kval->addIncoming(kvalnext, lower_code); |
| |
| Value *xidx = B4.CreateNUWMul(iter, incx, "x.idx"); |
| Value *yidx = B4.CreateNUWMul(iter, incy, "y.idx"); |
| Value *x = B4.CreateInBoundsGEP(fpTy, xfloat, xidx, "x.ptr"); |
| Value *y = B4.CreateInBoundsGEP(fpTy, dyfloat, yidx, "y.ptr"); |
| Value *xval = B4.CreateLoad(fpTy, x, "x.val"); |
| Value *yval = B4.CreateLoad(fpTy, y, "y.val"); |
| Value *xy = B4.CreateFMul(xval, yval, "xy"); |
| Value *xyalpha = B4.CreateFMul(xy, alpha, "xy.alpha"); |
| Value *kptr = B4.CreateInBoundsGEP(fpTy, dAPfloat, kval, "k.ptr"); |
| Value *kvalloaded = B4.CreateLoad(fpTy, kptr, "k.val"); |
| Value *kvalnew = B4.CreateFSub(kvalloaded, xyalpha, "k.val.new"); |
| B4.CreateStore(kvalnew, kptr); |
| |
| B4.CreateCondBr(B4.CreateICmpEQ(iternext, n), end, lower_code); |
| } |
| |
| IRBuilder<> B5(end); |
| B5.CreateRetVoid(); |
| } |
| B.CreateCall(F, args, bundles); |
| return; |
| } |
| |
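/// Emit a call to (and define on first use) a helper computing the Frobenius
/// inner product of an m x n matrix A with leading dimension lda and a
/// contiguous m x n matrix B, implemented as one BLAS dot call per column
/// with a single-dot fast path when lda == m.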
| llvm::CallInst * |
| getorInsertInnerProd(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, |
| IntegerType *IT, Type *BlasPT, Type *BlasIT, Type *fpTy, |
| llvm::ArrayRef<llvm::Value *> args, |
| const llvm::ArrayRef<llvm::OperandBundleDef> bundles, |
| bool byRef, bool julia_decl) { |
| assert(fpTy->isFloatingPointTy()); |
| |
| // add inner_prod call if not already present |
| std::string prod_name = |
| ("__enzyme_inner_prod" + blas.floatType + blas.suffix).str(); |
| auto FInnerProdT = |
| FunctionType::get(fpTy, {BlasIT, BlasIT, BlasPT, BlasIT, BlasPT}, false); |
| Function *F = |
| cast<Function>(M.getOrInsertFunction(prod_name, FInnerProdT).getCallee()); |
| |
| if (!F->empty()) |
| return B.CreateCall(F, args, bundles); |
| |
| // add dot call if not already present |
| std::string dot_name = |
| (blas.prefix + blas.floatType + "dot" + blas.suffix).str(); |
| auto FDotT = |
| FunctionType::get(fpTy, {BlasIT, BlasPT, BlasIT, BlasPT, BlasIT}, false); |
| Function *FDot = |
| cast<Function>(M.getOrInsertFunction(dot_name, FDotT).getCallee()); |
| |
| // now add the implementation for the inner_prod call |
| F->setLinkage(Function::LinkageTypes::InternalLinkage); |
| #if LLVM_VERSION_MAJOR >= 16 |
| F->setOnlyAccessesArgMemory(); |
| F->setOnlyReadsMemory(); |
| #else |
| F->addFnAttr(Attribute::ArgMemOnly); |
| F->addFnAttr(Attribute::ReadOnly); |
| #endif |
| F->addFnAttr(Attribute::NoUnwind); |
| F->addFnAttr(Attribute::AlwaysInline); |
| if (!julia_decl) { |
| F->addParamAttr(2, Attribute::NoCapture); |
| F->addParamAttr(4, Attribute::NoCapture); |
| F->addParamAttr(2, Attribute::NoAlias); |
| F->addParamAttr(4, Attribute::NoAlias); |
| F->addParamAttr(2, Attribute::ReadOnly); |
| F->addParamAttr(4, Attribute::ReadOnly); |
| } |
| |
| BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); |
| BasicBlock *init = BasicBlock::Create(M.getContext(), "init.idx", F); |
| BasicBlock *fastPath = BasicBlock::Create(M.getContext(), "fast.path", F); |
| BasicBlock *body = BasicBlock::Create(M.getContext(), "for.body", F); |
| BasicBlock *end = BasicBlock::Create(M.getContext(), "for.end", F); |
| |
| // This is the .td declaration which we need to match |
| // No need to support ld for the second matrix, as it will |
| // always be based on a matrix which we allocated (contiguous) |
| //(FrobInnerProd<> $m, $n, adj<"C">, $ldc, use<"AB">) |
| |
| auto blasm = F->arg_begin(); |
| blasm->setName("blasm"); |
| auto blasn = blasm + 1; |
| blasn->setName("blasn"); |
| auto matA = blasn + 1; |
| matA->setName("A"); |
| auto blaslda = matA + 1; |
| blaslda->setName("lda"); |
| auto matB = blaslda + 1; |
| matB->setName("B"); |
| |
| { |
| IRBuilder<> B1(entry); |
| Value *blasOne = to_blas_callconv(B1, ConstantInt::get(IT, 1), byRef, IT, |
| B1, "constant.one"); |
| Value *m = load_if_ref(B1, IT, blasm, byRef); |
| Value *n = load_if_ref(B1, IT, blasn, byRef); |
| Value *size = B1.CreateNUWMul(m, n, "mat.size"); |
| Value *blasSize = to_blas_callconv(B1, size, byRef, IT, B1, "mat.size"); |
| B1.CreateCondBr(B1.CreateICmpEQ(size, ConstantInt::get(IT, 0)), end, init); |
| |
| IRBuilder<> B2(init); |
| B2.setFastMathFlags(getFast()); |
| Value *lda = load_if_ref(B2, IT, blaslda, byRef); |
| Value *Afloat = B2.CreatePointerCast( |
| matA, PointerType::get( |
| fpTy, cast<PointerType>(matA->getType())->getAddressSpace())); |
| Value *Bfloat = B2.CreatePointerCast( |
| matB, PointerType::get( |
| fpTy, cast<PointerType>(matB->getType())->getAddressSpace())); |
| B2.CreateCondBr(B2.CreateICmpEQ(m, lda), fastPath, body); |
| |
    // Our second matrix is always contiguous, by construction.
    // If our first matrix is contiguous too (lda == m), then we can
    // use a single dot call.
| IRBuilder<> B3(fastPath); |
| B3.setFastMathFlags(getFast()); |
| Value *blasA = B3.CreatePointerCast(matA, BlasPT); |
| Value *blasB = B3.CreatePointerCast(matB, BlasPT); |
| Value *fastSum = B3.CreateCall( |
| FDot, {blasSize, blasA, blasOne, blasB, blasOne}, bundles); |
| B3.CreateBr(end); |
| |
| IRBuilder<> B4(body); |
| B4.setFastMathFlags(getFast()); |
| PHINode *Aidx = B4.CreatePHI(IT, 2, "Aidx"); |
| PHINode *Bidx = B4.CreatePHI(IT, 2, "Bidx"); |
| PHINode *iter = B4.CreatePHI(IT, 2, "iteration"); |
| PHINode *sum = B4.CreatePHI(fpTy, 2, "sum"); |
| Aidx->addIncoming(ConstantInt::get(IT, 0), init); |
| Bidx->addIncoming(ConstantInt::get(IT, 0), init); |
| iter->addIncoming(ConstantInt::get(IT, 0), init); |
| sum->addIncoming(ConstantFP::get(fpTy, 0.0), init); |
| |
| Value *Ai = B4.CreateInBoundsGEP(fpTy, Afloat, Aidx, "A.i"); |
| Value *Bi = B4.CreateInBoundsGEP(fpTy, Bfloat, Bidx, "B.i"); |
| Value *AiDot = B4.CreatePointerCast(Ai, BlasPT); |
| Value *BiDot = B4.CreatePointerCast(Bi, BlasPT); |
| Value *newDot = |
| B4.CreateCall(FDot, {blasm, AiDot, blasOne, BiDot, blasOne}, bundles); |
| |
| Value *Anext = B4.CreateNUWAdd(Aidx, lda, "Aidx.next"); |
    Value *Bnext = B4.CreateNUWAdd(Bidx, m, "Bidx.next");
| Value *iternext = B4.CreateAdd(iter, ConstantInt::get(IT, 1), "iter.next"); |
| Value *sumnext = B4.CreateFAdd(sum, newDot); |
| |
| iter->addIncoming(iternext, body); |
| Aidx->addIncoming(Anext, body); |
| Bidx->addIncoming(Bnext, body); |
| sum->addIncoming(sumnext, body); |
| |
    B4.CreateCondBr(B4.CreateICmpEQ(iternext, n), end, body);
| |
| IRBuilder<> B5(end); |
| PHINode *res = B5.CreatePHI(fpTy, 3, "res"); |
| res->addIncoming(ConstantFP::get(fpTy, 0.0), entry); |
    res->addIncoming(sumnext, body);
| res->addIncoming(fastSum, fastPath); |
| B5.CreateRet(res); |
| } |
| |
| return B.CreateCall(F, args, bundles); |
| } |
| |
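/// Create a function that copies num elements of the given floating point
/// type from a strided source into a contiguous destination; a negative
/// stride walks the source backwards starting from its end, following the
/// BLAS convention.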
| Function *getOrInsertMemcpyStrided(Module &M, Type *elementType, PointerType *T, |
| Type *IT, unsigned dstalign, |
| unsigned srcalign) { |
| assert(elementType->isFloatingPointTy()); |
| std::string name = "__enzyme_memcpy_" + tofltstr(elementType) + "_" + |
| std::to_string(cast<IntegerType>(IT)->getBitWidth()) + |
| "_da" + std::to_string(dstalign) + "sa" + |
| std::to_string(srcalign) + "stride"; |
| FunctionType *FT = |
| FunctionType::get(Type::getVoidTy(M.getContext()), {T, T, IT, IT}, false); |
| |
| Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee()); |
| |
| if (!F->empty()) |
| return F; |
| |
| F->setLinkage(Function::LinkageTypes::InternalLinkage); |
| #if LLVM_VERSION_MAJOR >= 16 |
| F->setOnlyAccessesArgMemory(); |
| #else |
| F->addFnAttr(Attribute::ArgMemOnly); |
| #endif |
| F->addFnAttr(Attribute::NoUnwind); |
| F->addFnAttr(Attribute::AlwaysInline); |
| F->addParamAttr(0, Attribute::NoCapture); |
| F->addParamAttr(0, Attribute::NoAlias); |
| F->addParamAttr(1, Attribute::NoCapture); |
| F->addParamAttr(1, Attribute::NoAlias); |
| F->addParamAttr(0, Attribute::WriteOnly); |
| F->addParamAttr(1, Attribute::ReadOnly); |
| |
| BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); |
| BasicBlock *init = BasicBlock::Create(M.getContext(), "init.idx", F); |
| BasicBlock *body = BasicBlock::Create(M.getContext(), "for.body", F); |
| BasicBlock *end = BasicBlock::Create(M.getContext(), "for.end", F); |
| |
| auto dst = F->arg_begin(); |
| dst->setName("dst"); |
| auto src = dst + 1; |
| src->setName("src"); |
| auto num = src + 1; |
| num->setName("num"); |
| auto stride = num + 1; |
| stride->setName("stride"); |
| |
| { |
| IRBuilder<> B(entry); |
| B.CreateCondBr(B.CreateICmpEQ(num, ConstantInt::get(num->getType(), 0)), |
| end, init); |
| } |
| |
| { |
| IRBuilder<> B2(init); |
| B2.setFastMathFlags(getFast()); |
| Value *a = B2.CreateNSWSub(ConstantInt::get(num->getType(), 1), num, "a"); |
| Value *negidx = B2.CreateNSWMul(a, stride, "negidx"); |
| // Value *negidx = |
| // B2.CreateNSWAdd(b, ConstantInt::get(num->getType(), 1), |
| // "negidx"); |
| Value *isneg = |
| B2.CreateICmpSLT(stride, ConstantInt::get(num->getType(), 0), "is.neg"); |
| Value *startidx = B2.CreateSelect( |
| isneg, negidx, ConstantInt::get(num->getType(), 0), "startidx"); |
| B2.CreateBr(body); |
| //} |
| |
| //{ |
| IRBuilder<> B(body); |
| B.setFastMathFlags(getFast()); |
| PHINode *idx = B.CreatePHI(num->getType(), 2, "idx"); |
| PHINode *sidx = B.CreatePHI(num->getType(), 2, "sidx"); |
| idx->addIncoming(ConstantInt::get(num->getType(), 0), init); |
| sidx->addIncoming(startidx, init); |
| |
| Value *dsti = B.CreateInBoundsGEP(elementType, dst, idx, "dst.i"); |
| Value *srci = B.CreateInBoundsGEP(elementType, src, sidx, "src.i"); |
| LoadInst *srcl = B.CreateLoad(elementType, srci, "src.i.l"); |
| StoreInst *dsts = B.CreateStore(srcl, dsti); |
| |
| if (dstalign) { |
| dsts->setAlignment(Align(dstalign)); |
| } |
| if (srcalign) { |
| srcl->setAlignment(Align(srcalign)); |
| } |
| |
| Value *next = |
| B.CreateNSWAdd(idx, ConstantInt::get(num->getType(), 1), "idx.next"); |
| Value *snext = B.CreateNSWAdd(sidx, stride, "sidx.next"); |
| idx->addIncoming(next, body); |
| sidx->addIncoming(snext, body); |
| B.CreateCondBr(B.CreateICmpEQ(num, next), end, body); |
| } |
| |
| { |
| IRBuilder<> B(end); |
| B.CreateRetVoid(); |
| } |
| |
| return F; |
| } |
| |
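/// Create a function that copies an M x N column-major matrix with leading
/// dimension LDA into a contiguous M x N destination buffer.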
| Function *getOrInsertMemcpyMat(Module &Mod, Type *elementType, PointerType *PT, |
| IntegerType *IT, unsigned dstalign, |
| unsigned srcalign) { |
| assert(elementType->isFloatingPointTy()); |
| #if LLVM_VERSION_MAJOR < 18 |
| #if LLVM_VERSION_MAJOR >= 15 |
| if (Mod.getContext().supportsTypedPointers()) { |
| #endif |
| assert(PT->getPointerElementType() == elementType); |
| #if LLVM_VERSION_MAJOR >= 15 |
| } |
| #endif |
| #endif |
| std::string name = "__enzyme_memcpy_" + tofltstr(elementType) + "_mat_" + |
| std::to_string(cast<IntegerType>(IT)->getBitWidth()); |
| FunctionType *FT = FunctionType::get(Type::getVoidTy(Mod.getContext()), |
| {PT, PT, IT, IT, IT}, false); |
| |
| Function *F = cast<Function>(Mod.getOrInsertFunction(name, FT).getCallee()); |
| |
| if (!F->empty()) |
| return F; |
| |
| F->setLinkage(Function::LinkageTypes::InternalLinkage); |
| #if LLVM_VERSION_MAJOR >= 16 |
| F->setOnlyAccessesArgMemory(); |
| #else |
| F->addFnAttr(Attribute::ArgMemOnly); |
| #endif |
| F->addFnAttr(Attribute::NoUnwind); |
| F->addFnAttr(Attribute::AlwaysInline); |
| F->addParamAttr(0, Attribute::NoCapture); |
| F->addParamAttr(0, Attribute::NoAlias); |
| F->addParamAttr(1, Attribute::NoCapture); |
| F->addParamAttr(1, Attribute::NoAlias); |
| F->addParamAttr(0, Attribute::WriteOnly); |
| F->addParamAttr(1, Attribute::ReadOnly); |
| |
| BasicBlock *entry = BasicBlock::Create(F->getContext(), "entry", F); |
| BasicBlock *init = BasicBlock::Create(F->getContext(), "init.idx", F); |
| BasicBlock *body = BasicBlock::Create(F->getContext(), "for.body", F); |
| BasicBlock *initend = BasicBlock::Create(F->getContext(), "init.end", F); |
| BasicBlock *end = BasicBlock::Create(F->getContext(), "for.end", F); |
| |
| auto dst = F->arg_begin(); |
| dst->setName("dst"); |
| auto src = dst + 1; |
| src->setName("src"); |
| auto M = src + 1; |
| M->setName("M"); |
| auto N = M + 1; |
| N->setName("N"); |
| auto LDA = N + 1; |
| LDA->setName("LDA"); |
| |
| { |
| IRBuilder<> B(entry); |
    Value *l = B.CreateMul(M, N, "mul", true, true);
    // Don't copy an empty matrix
| B.CreateCondBr(B.CreateICmpEQ(l, ConstantInt::get(IT, 0)), end, init); |
| } |
| |
| PHINode *j; |
| { |
| IRBuilder<> B(init); |
| j = B.CreatePHI(IT, 2, "j"); |
| j->addIncoming(ConstantInt::get(IT, 0), entry); |
| B.CreateBr(body); |
| } |
| |
| { |
| IRBuilder<> B(body); |
| PHINode *i = B.CreatePHI(IT, 2, "i"); |
| i->addIncoming(ConstantInt::get(IT, 0), init); |
| |
| Value *dsti = B.CreateInBoundsGEP( |
| elementType, dst, |
| B.CreateAdd(i, B.CreateMul(j, M, "", true, true), "", true, true), |
| "dst.i"); |
| Value *srci = B.CreateInBoundsGEP( |
| elementType, src, |
| B.CreateAdd(i, B.CreateMul(j, LDA, "", true, true), "", true, true), |
| "dst.i"); |
| LoadInst *srcl = B.CreateLoad(elementType, srci, "src.i.l"); |
| |
| StoreInst *dsts = B.CreateStore(srcl, dsti); |
| |
| if (dstalign) { |
| dsts->setAlignment(Align(dstalign)); |
| } |
| if (srcalign) { |
| srcl->setAlignment(Align(srcalign)); |
| } |
| |
| Value *nexti = |
| B.CreateAdd(i, ConstantInt::get(IT, 1), "i.next", true, true); |
| i->addIncoming(nexti, body); |
| B.CreateCondBr(B.CreateICmpEQ(nexti, M), initend, body); |
| } |
| |
| { |
| IRBuilder<> B(initend); |
| Value *nextj = |
| B.CreateAdd(j, ConstantInt::get(IT, 1), "j.next", true, true); |
| j->addIncoming(nextj, initend); |
| B.CreateCondBr(B.CreateICmpEQ(nextj, N), end, init); |
| } |
| |
| { |
| IRBuilder<> B(end); |
| B.CreateRetVoid(); |
| } |
| |
| return F; |
| } |
| |
| // TODO implement differential memmove |
| Function * |
| getOrInsertDifferentialFloatMemmove(Module &M, Type *T, unsigned dstalign, |
| unsigned srcalign, unsigned dstaddr, |
| unsigned srcaddr, unsigned bitwidth) { |
| llvm::errs() << "warning: didn't implement memmove, using memcpy as fallback " |
| "which can result in errors\n"; |
| return getOrInsertDifferentialFloatMemcpy(M, T, dstalign, srcalign, dstaddr, |
| srcaddr, bitwidth); |
| } |
| |
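/// Create a helper that frees the shadow pointer(s) of a width-wide shadow
/// allocation, forwarding to the original free callee only for shadows that
/// differ from the primal (and, in vector mode, from one another) so that
/// aliasing shadows are not freed twice.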
| Function *getOrInsertCheckedFree(Module &M, CallInst *call, Type *Ty, |
| unsigned width) { |
| FunctionType *FreeTy = call->getFunctionType(); |
| Value *Free = call->getCalledOperand(); |
| AttributeList FreeAttributes = call->getAttributes(); |
| CallingConv::ID CallingConvention = call->getCallingConv(); |
| DebugLoc DebugLoc = call->getDebugLoc(); |
| |
| std::string name = "__enzyme_checked_free_" + std::to_string(width); |
| |
| SmallVector<Type *, 3> types; |
| types.push_back(Ty); |
| for (unsigned i = 0; i < width; i++) { |
| types.push_back(Ty); |
| } |
| |
| FunctionType *FT = |
| FunctionType::get(Type::getVoidTy(M.getContext()), types, false); |
| |
| Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee()); |
| |
| if (!F->empty()) |
| return F; |
| |
| F->setLinkage(Function::LinkageTypes::InternalLinkage); |
| #if LLVM_VERSION_MAJOR >= 16 |
| F->setOnlyAccessesArgMemory(); |
| #else |
| F->addFnAttr(Attribute::ArgMemOnly); |
| #endif |
| F->addFnAttr(Attribute::NoUnwind); |
| F->addFnAttr(Attribute::AlwaysInline); |
| |
| BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); |
| BasicBlock *free0 = BasicBlock::Create(M.getContext(), "free0", F); |
| BasicBlock *end = BasicBlock::Create(M.getContext(), "end", F); |
| |
| IRBuilder<> EntryBuilder(entry); |
| IRBuilder<> Free0Builder(free0); |
| IRBuilder<> EndBuilder(end); |
| |
| auto primal = F->arg_begin(); |
| Argument *first_shadow = F->arg_begin() + 1; |
| F->addParamAttr(0, Attribute::NoCapture); |
| F->addParamAttr(1, Attribute::NoCapture); |
| |
| Value *isNotEqual = EntryBuilder.CreateICmpNE(primal, first_shadow); |
| EntryBuilder.CreateCondBr(isNotEqual, free0, end); |
| |
| CallInst *CI = Free0Builder.CreateCall(FreeTy, Free, {first_shadow}); |
| CI->setAttributes(FreeAttributes); |
| CI->setCallingConv(CallingConvention); |
| CI->setDebugLoc(DebugLoc); |
| |
| if (width > 1) { |
| Value *checkResult = nullptr; |
| BasicBlock *free1 = BasicBlock::Create(M.getContext(), "free1", F); |
| IRBuilder<> Free1Builder(free1); |
| |
| for (unsigned i = 0; i < width; i++) { |
| F->addParamAttr(i + 1, Attribute::NoCapture); |
| Argument *shadow = F->arg_begin() + i + 1; |
| |
| if (i < width - 1) { |
| Argument *nextShadow = F->arg_begin() + i + 2; |
| Value *isNotEqual = Free0Builder.CreateICmpNE(shadow, nextShadow); |
| checkResult = checkResult |
| ? Free0Builder.CreateAnd(isNotEqual, checkResult) |
| : isNotEqual; |
| |
| CallInst *CI = Free1Builder.CreateCall(FreeTy, Free, {nextShadow}); |
| CI->setAttributes(FreeAttributes); |
| CI->setCallingConv(CallingConvention); |
| CI->setDebugLoc(DebugLoc); |
| } |
| } |
| Free0Builder.CreateCondBr(checkResult, free1, end); |
| Free1Builder.CreateBr(end); |
| } else { |
| Free0Builder.CreateBr(end); |
| } |
| |
| EndBuilder.CreateRetVoid(); |
| |
| return F; |
| } |
| |
/// Emit code computing the next power of two greater than or equal to V
| llvm::Value *nextPowerOfTwo(llvm::IRBuilder<> &B, llvm::Value *V) { |
| assert(V->getType()->isIntegerTy()); |
| IntegerType *T = cast<IntegerType>(V->getType()); |
| V = B.CreateAdd(V, ConstantInt::get(T, -1)); |
| for (size_t i = 1; i < T->getBitWidth(); i *= 2) { |
| V = B.CreateOr(V, B.CreateLShr(V, ConstantInt::get(T, i))); |
| } |
| V = B.CreateAdd(V, ConstantInt::get(T, 1)); |
| return V; |
| } |
| |
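/// Create a helper used in the reverse pass of MPI_Waitall: it allocates and
/// returns an array holding the shadow request for each of the count
/// requests, substituting a null entry whenever the primal request equals
/// ompi_request_null (when that OpenMPI symbol is present in the module).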
| llvm::Function *getOrInsertDifferentialWaitallSave(llvm::Module &M, |
| ArrayRef<llvm::Type *> T, |
| PointerType *reqType) { |
| std::string name = "__enzyme_differential_waitall_save"; |
| FunctionType *FT = |
| FunctionType::get(PointerType::getUnqual(reqType), T, false); |
| Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee()); |
| |
| if (!F->empty()) |
| return F; |
| |
| F->setLinkage(Function::LinkageTypes::InternalLinkage); |
| F->addFnAttr(Attribute::NoUnwind); |
| F->addFnAttr(Attribute::AlwaysInline); |
| |
| BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); |
| |
| auto buff = F->arg_begin(); |
| buff->setName("count"); |
| Value *count = buff; |
| Value *req = buff + 1; |
| req->setName("req"); |
| Value *dreq = buff + 2; |
| dreq->setName("dreq"); |
| |
| IRBuilder<> B(entry); |
| count = B.CreateZExtOrTrunc(count, Type::getInt64Ty(entry->getContext())); |
| |
| auto ret = CreateAllocation(B, reqType, count); |
| |
| BasicBlock *loopBlock = BasicBlock::Create(M.getContext(), "loop", F); |
| BasicBlock *endBlock = BasicBlock::Create(M.getContext(), "end", F); |
| |
| B.CreateCondBr(B.CreateICmpEQ(count, ConstantInt::get(count->getType(), 0)), |
| endBlock, loopBlock); |
| |
| B.SetInsertPoint(loopBlock); |
| auto idx = B.CreatePHI(count->getType(), 2); |
| idx->addIncoming(ConstantInt::get(count->getType(), 0), entry); |
| auto inc = B.CreateAdd(idx, ConstantInt::get(count->getType(), 1)); |
| idx->addIncoming(inc, loopBlock); |
| |
| Type *reqT = reqType; // req->getType()->getPointerElementType(); |
| Value *idxs[] = {idx}; |
| Value *ireq = B.CreateInBoundsGEP(reqT, req, idxs); |
| Value *idreq = B.CreateInBoundsGEP(reqT, dreq, idxs); |
| Value *iout = B.CreateInBoundsGEP(reqType, ret, idxs); |
| Value *isNull = nullptr; |
| if (auto GV = M.getNamedValue("ompi_request_null")) { |
| Value *reql = |
| B.CreatePointerCast(ireq, PointerType::getUnqual(GV->getType())); |
| reql = B.CreateLoad(GV->getType(), reql); |
| isNull = B.CreateICmpEQ(reql, GV); |
| } |
| |
| idreq = B.CreatePointerCast(idreq, PointerType::getUnqual(reqType)); |
| Value *d_reqp = B.CreateLoad(reqType, idreq); |
| if (isNull) |
| d_reqp = B.CreateSelect(isNull, Constant::getNullValue(d_reqp->getType()), |
| d_reqp); |
| |
| B.CreateStore(d_reqp, iout); |
| |
| B.CreateCondBr(B.CreateICmpEQ(inc, count), endBlock, loopBlock); |
| |
| B.SetInsertPoint(endBlock); |
| B.CreateRet(ret); |
| return F; |
| } |
| |
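/// Create a helper used in the reverse pass of MPI_Wait: depending on whether
/// the completed request originated from an Isend or an Irecv, it issues the
/// opposite operation (an Irecv or an Isend, respectively) on the shadow
/// buffer using the saved communication arguments.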
| llvm::Function *getOrInsertDifferentialMPI_Wait(llvm::Module &M, |
| ArrayRef<llvm::Type *> T, |
| Type *reqType) { |
| llvm::SmallVector<llvm::Type *, 4> types(T.begin(), T.end()); |
| types.push_back(reqType); |
| std::string name = "__enzyme_differential_mpi_wait"; |
| FunctionType *FT = |
| FunctionType::get(Type::getVoidTy(M.getContext()), types, false); |
| Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee()); |
| |
| if (!F->empty()) |
| return F; |
| |
| F->setLinkage(Function::LinkageTypes::InternalLinkage); |
| F->addFnAttr(Attribute::NoUnwind); |
| F->addFnAttr(Attribute::AlwaysInline); |
| |
| BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); |
| BasicBlock *isend = BasicBlock::Create(M.getContext(), "invertISend", F); |
| BasicBlock *irecv = BasicBlock::Create(M.getContext(), "invertIRecv", F); |
| |
| #if 0 |
| /*0 */Type::getInt8PtrTy(call.getContext()) |
| /*1 */i64 |
| /*2 */Type::getInt8PtrTy(call.getContext()) |
| /*3 */i64 |
| /*4 */i64 |
| /*5 */Type::getInt8PtrTy(call.getContext()) |
| /*6 */Type::getInt8Ty(call.getContext()) |
| #endif |
| |
| auto buff = F->arg_begin(); |
| buff->setName("buf"); |
| Value *buf = buff; |
| Value *count = buff + 1; |
| count->setName("count"); |
| Value *datatype = buff + 2; |
| datatype->setName("datatype"); |
| Value *source = buff + 3; |
| source->setName("source"); |
| Value *tag = buff + 4; |
| tag->setName("tag"); |
| Value *comm = buff + 5; |
| comm->setName("comm"); |
| Value *fn = buff + 6; |
| fn->setName("fn"); |
| Value *d_req = buff + 7; |
| d_req->setName("d_req"); |
| |
| bool pmpi = true; |
| auto isendfn = M.getFunction("PMPI_Isend"); |
| if (!isendfn) { |
| isendfn = M.getFunction("MPI_Isend"); |
| pmpi = false; |
| } |
| assert(isendfn); |
| auto irecvfn = M.getFunction("PMPI_Irecv"); |
| if (!irecvfn) |
| irecvfn = M.getFunction("MPI_Irecv"); |
| if (!irecvfn) { |
| FunctionType *FuT = isendfn->getFunctionType(); |
| std::string name = pmpi ? "PMPI_Irecv" : "MPI_Irecv"; |
| irecvfn = cast<Function>(M.getOrInsertFunction(name, FuT).getCallee()); |
| } |
| assert(irecvfn); |
| |
| IRBuilder<> B(entry); |
| auto arg = isendfn->arg_begin(); |
| if (arg->getType()->isIntegerTy()) |
| buf = B.CreatePtrToInt(buf, arg->getType()); |
| arg++; |
| count = B.CreateZExtOrTrunc(count, arg->getType()); |
| arg++; |
| datatype = B.CreatePointerCast(datatype, arg->getType()); |
| arg++; |
| source = B.CreateZExtOrTrunc(source, arg->getType()); |
| arg++; |
| tag = B.CreateZExtOrTrunc(tag, arg->getType()); |
| arg++; |
| comm = B.CreatePointerCast(comm, arg->getType()); |
| arg++; |
| if (arg->getType()->isIntegerTy()) |
| d_req = B.CreatePtrToInt(d_req, arg->getType()); |
| Value *args[] = { |
| buf, count, datatype, source, tag, comm, d_req, |
| }; |
| |
| B.CreateCondBr(B.CreateICmpEQ(fn, ConstantInt::get(fn->getType(), |
| (int)MPI_CallType::ISEND)), |
| isend, irecv); |
| |
| { |
| B.SetInsertPoint(isend); |
| auto fcall = B.CreateCall(irecvfn, args); |
| fcall->setCallingConv(isendfn->getCallingConv()); |
| B.CreateRetVoid(); |
| } |
| |
| { |
| B.SetInsertPoint(irecv); |
| auto fcall = B.CreateCall(isendfn, args); |
| fcall->setCallingConv(isendfn->getCallingConv()); |
| B.CreateRetVoid(); |
| } |
| return F; |
| } |
| |
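/// Return an MPI_Op that sums buffers of the floating point type described by
/// CT, creating it on first use: a reduction function adding src into dst is
/// emitted, registered lazily through MPI_Op_create, and guarded by an
/// internal "initialized" flag so registration happens only once.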
| llvm::Value *getOrInsertOpFloatSum(llvm::Module &M, llvm::Type *OpPtr, |
| llvm::Type *OpType, ConcreteType CT, |
| llvm::Type *intType, IRBuilder<> &B2) { |
| std::string name = "__enzyme_mpi_sum" + CT.str(); |
| assert(CT.isFloat()); |
| auto FlT = CT.isFloat(); |
| |
| if (auto Glob = M.getGlobalVariable(name)) { |
| return B2.CreateLoad(Glob->getValueType(), Glob); |
| } |
| |
| llvm::Type *types[] = {PointerType::getUnqual(FlT), |
| PointerType::getUnqual(FlT), |
| PointerType::getUnqual(intType), OpPtr}; |
| FunctionType *FuT = |
| FunctionType::get(Type::getVoidTy(M.getContext()), types, false); |
| Function *F = |
| cast<Function>(M.getOrInsertFunction(name + "_run", FuT).getCallee()); |
| |
| F->setLinkage(Function::LinkageTypes::InternalLinkage); |
| #if LLVM_VERSION_MAJOR >= 16 |
| F->setOnlyAccessesArgMemory(); |
| #else |
| F->addFnAttr(Attribute::ArgMemOnly); |
| #endif |
| F->addFnAttr(Attribute::NoUnwind); |
| F->addFnAttr(Attribute::AlwaysInline); |
| F->addParamAttr(0, Attribute::NoCapture); |
| F->addParamAttr(0, Attribute::ReadOnly); |
| F->addParamAttr(1, Attribute::NoCapture); |
| F->addParamAttr(2, Attribute::NoCapture); |
| F->addParamAttr(2, Attribute::ReadOnly); |
| F->addParamAttr(3, Attribute::NoCapture); |
| F->addParamAttr(3, Attribute::ReadNone); |
| |
| BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); |
| BasicBlock *body = BasicBlock::Create(M.getContext(), "for.body", F); |
| BasicBlock *end = BasicBlock::Create(M.getContext(), "for.end", F); |
| |
| auto src = F->arg_begin(); |
| src->setName("src"); |
| auto dst = src + 1; |
| dst->setName("dst"); |
| auto lenp = dst + 1; |
| lenp->setName("lenp"); |
| Value *len; |
| // TODO consider using datatype arg and asserting same size as assumed |
| // by type analysis |
| |
| { |
| IRBuilder<> B(entry); |
| len = B.CreateLoad(intType, lenp); |
| B.CreateCondBr(B.CreateICmpEQ(len, ConstantInt::get(len->getType(), 0)), |
| end, body); |
| } |
| |
| { |
| IRBuilder<> B(body); |
| B.setFastMathFlags(getFast()); |
| PHINode *idx = B.CreatePHI(len->getType(), 2, "idx"); |
| idx->addIncoming(ConstantInt::get(len->getType(), 0), entry); |
| |
| Value *dsti = B.CreateInBoundsGEP(FlT, dst, idx, "dst.i"); |
| LoadInst *dstl = B.CreateLoad(FlT, dsti, "dst.i.l"); |
| |
| Value *srci = B.CreateInBoundsGEP(FlT, src, idx, "src.i"); |
| LoadInst *srcl = B.CreateLoad(FlT, srci, "src.i.l"); |
| B.CreateStore(B.CreateFAdd(srcl, dstl), dsti); |
| |
| Value *next = |
| B.CreateNUWAdd(idx, ConstantInt::get(len->getType(), 1), "idx.next"); |
| idx->addIncoming(next, body); |
| B.CreateCondBr(B.CreateICmpEQ(len, next), end, body); |
| } |
| |
| { |
| IRBuilder<> B(end); |
| B.CreateRetVoid(); |
| } |
| |
| llvm::Type *rtypes[] = {Type::getInt8PtrTy(M.getContext()), intType, OpPtr}; |
| FunctionType *RFT = FunctionType::get(intType, rtypes, false); |
| |
| Constant *RF = M.getNamedValue("MPI_Op_create"); |
| if (!RF) { |
| RF = |
| cast<Function>(M.getOrInsertFunction("MPI_Op_create", RFT).getCallee()); |
| } else { |
| RF = ConstantExpr::getBitCast(RF, PointerType::getUnqual(RFT)); |
| } |
| |
| GlobalVariable *GV = |
| new GlobalVariable(M, OpType, false, GlobalVariable::InternalLinkage, |
| UndefValue::get(OpType), name); |
| |
| Type *i1Ty = Type::getInt1Ty(M.getContext()); |
| GlobalVariable *initD = new GlobalVariable( |
| M, i1Ty, false, GlobalVariable::InternalLinkage, |
| ConstantInt::getFalse(M.getContext()), name + "_initd"); |
| |
| // Finish initializing mpi sum |
| // https://www.mpich.org/static/docs/v3.2/www3/MPI_Op_create.html |
| FunctionType *IFT = FunctionType::get(Type::getVoidTy(M.getContext()), |
| ArrayRef<Type *>(), false); |
| Function *initializerFunction = cast<Function>( |
| M.getOrInsertFunction(name + "initializer", IFT).getCallee()); |
| |
| initializerFunction->setLinkage(Function::LinkageTypes::InternalLinkage); |
| initializerFunction->addFnAttr(Attribute::NoUnwind); |
| |
| { |
| BasicBlock *entry = |
| BasicBlock::Create(M.getContext(), "entry", initializerFunction); |
| BasicBlock *run = |
| BasicBlock::Create(M.getContext(), "run", initializerFunction); |
| BasicBlock *end = |
| BasicBlock::Create(M.getContext(), "end", initializerFunction); |
| IRBuilder<> B(entry); |
| |
| B.CreateCondBr(B.CreateLoad(initD->getValueType(), initD), end, run); |
| |
| B.SetInsertPoint(run); |
| Value *args[] = {ConstantExpr::getPointerCast(F, rtypes[0]), |
| ConstantInt::get(rtypes[1], 1, false), |
| ConstantExpr::getPointerCast(GV, rtypes[2])}; |
| B.CreateCall(RFT, RF, args); |
| B.CreateStore(ConstantInt::getTrue(M.getContext()), initD); |
| B.CreateBr(end); |
| B.SetInsertPoint(end); |
| B.CreateRetVoid(); |
| } |
| |
| B2.CreateCall(M.getFunction(name + "initializer")); |
| return B2.CreateLoad(GV->getValueType(), GV); |
| } |
| |
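| /// Collect into `results` every instruction from `stores` that may execute |
| /// after `inst`: either later within the same basic block, or in a block |
| /// reachable from a successor of `inst`'s block without passing through the |
| /// header of `region` (when a region is provided). |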
| void mayExecuteAfter(llvm::SmallVectorImpl<llvm::Instruction *> &results, |
| llvm::Instruction *inst, |
| const llvm::SmallPtrSetImpl<Instruction *> &stores, |
| const llvm::Loop *region) { |
| using namespace llvm; |
| std::map<BasicBlock *, SmallVector<Instruction *, 1>> maybeBlocks; |
| BasicBlock *instBlk = inst->getParent(); |
| for (auto store : stores) { |
| BasicBlock *storeBlk = store->getParent(); |
| if (instBlk == storeBlk) { |
| // If the store does not come before inst in this block, it may |
| // execute after it. |
| |
| if (store != inst) { |
| BasicBlock::const_iterator It = storeBlk->begin(); |
| for (; &*It != store && &*It != inst; ++It) |
| /*empty*/; |
| // if inst comes first (i.e. before the store) in the block, the |
| // store may execute after it |
| if (&*It == inst) { |
| results.push_back(store); |
| } |
| } |
| maybeBlocks[storeBlk].push_back(store); |
| } else { |
| maybeBlocks[storeBlk].push_back(store); |
| } |
| } |
| |
| if (maybeBlocks.size() == 0) |
| return; |
| |
| llvm::SmallVector<BasicBlock *, 2> todo; |
| for (auto B : successors(instBlk)) { |
| if (region && region->getHeader() == B) { |
| continue; |
| } |
| todo.push_back(B); |
| } |
| |
| SmallPtrSet<BasicBlock *, 2> seen; |
| while (todo.size()) { |
| auto cur = todo.back(); |
| todo.pop_back(); |
| if (seen.count(cur)) |
| continue; |
| seen.insert(cur); |
| auto found = maybeBlocks.find(cur); |
| if (found != maybeBlocks.end()) { |
| for (auto store : found->second) |
| results.push_back(store); |
| maybeBlocks.erase(found); |
| } |
| for (auto B : successors(cur)) { |
| if (region && region->getHeader() == B) { |
| continue; |
| } |
| todo.push_back(B); |
| } |
| } |
| } |
| |
| bool overwritesToMemoryReadByLoop( |
| llvm::ScalarEvolution &SE, llvm::LoopInfo &LI, llvm::DominatorTree &DT, |
| llvm::Instruction *maybeReader, const llvm::SCEV *LoadStart, |
| const llvm::SCEV *LoadEnd, llvm::Instruction *maybeWriter, |
| const llvm::SCEV *StoreStart, const llvm::SCEV *StoreEnd, |
| llvm::Loop *scope) { |
| // The store may either occur directly after the load in the current loop |
| // nest, or prior to the load in a subsequent iteration of the loop nest. |
| // Generally: |
| // L0 -> scope -> L1 -> L2 -> L3 -> load_L4 -> load_L5 ... Load |
| // \-> store_L4 -> store_L5 ... Store |
| // We begin by finding the common ancestor of the two loops, which may |
| // be none. |
| Loop *anc = getAncestor(LI.getLoopFor(maybeReader->getParent()), |
| LI.getLoopFor(maybeWriter->getParent())); |
| |
| // The surrounding scope must contain the ancestor |
| if (scope) { |
| assert(anc); |
| assert(scope == anc || scope->contains(anc)); |
| } |
| |
| // Consider the case where the load and store don't share any common loops. |
| // That is to say, there are no loops in [scope, ancestor) for which we need |
| // to consider a store in a later iteration overwriting the load of a |
| // previous iteration. |
| // |
| // An example of such overwriting would be a "left shift": |
| // for (int j = 1; j<N; j++) { |
| // load A[j] |
| // store A[j-1] |
| // } |
| // |
| // Ignoring such ancestors, if the two regions have no direct overlap at any |
| // level of region expansion, we can report that the store does not |
| // overwrite the read memory. That is to say, for any loop we may expand the |
| // start or end of a region to its worst case given the loop bounds. |
| // |
| // However, now let us consider the case where there are surrounding common |
| // loops. If the storing boundary is represented by an induction variable of |
| // one of these loops, we must conservatively expand it all the way to the |
| // end, and we mark each loop expanded in this fashion. If we encounter all |
| // intervening common loops this way and the expanded regions are proven |
| // disjoint, the regions do not overlap. However, if we do not encounter |
| // every surrounding loop during the induction expansion, the write may |
| // simply repeat across iterations, in which case we must still report that |
| // the regions may overlap (due to the repetition). |
| // |
| // Since we also have a Loop scope, we can ignore any common loops at the |
| // scope level or above. |
| |
| /// We force the ranges of all loops in the chain ... [scope, anc] ... cur |
| /// to be expanded over their full number of iterations. |
| |
| SmallPtrSet<const Loop *, 1> visitedAncestors; |
| auto skipLoop = [&](const Loop *L) { |
| assert(L); |
| if (scope && L->contains(scope)) |
| return false; |
| |
| if (anc && (anc == L || anc->contains(L))) { |
| visitedAncestors.insert(L); |
| return true; |
| } |
| return false; |
| }; |
| |
| // Check the bounds of an [... EndPrev][StartNext ...] pair for potential |
| // overlap. The boolean EndIsStore is true if EndPrev represents the store |
| // (and should have its loops expanded); otherwise the expansion applies |
| // to StartNext. |
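| // Each bound is relaxed to a loop-invariant worst case: StartNext toward |
| // its minimum over the loop iterations and EndPrev toward its maximum |
| // (using the loop's backedge-taken count), so proving EndPrev <= StartNext |
| // on the relaxed bounds proves it for every iteration. |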
| auto hasOverlap = [&](const SCEV *EndPrev, const SCEV *StartNext, |
| bool EndIsStore) { |
| for (auto slim = StartNext; slim != SE.getCouldNotCompute();) { |
| bool sskip = false; |
| if (!EndIsStore) |
| if (auto startL = dyn_cast<SCEVAddRecExpr>(slim)) |
| if (skipLoop(startL->getLoop()) && |
| SE.isKnownNonPositive(startL->getStepRecurrence(SE))) { |
| sskip = true; |
| } |
| |
| if (!sskip) |
| for (auto elim = EndPrev; elim != SE.getCouldNotCompute();) { |
| { |
| |
| bool eskip = false; |
| if (EndIsStore) |
| if (auto endL = dyn_cast<SCEVAddRecExpr>(elim)) { |
| if (skipLoop(endL->getLoop()) && |
| SE.isKnownNonNegative(endL->getStepRecurrence(SE))) { |
| eskip = true; |
| } |
| } |
| |
| // Moreover, because SE otherwise cannot "groupScevByComplexity", we need |
| // to ensure that if both slim and elim are AddRec expressions, they are |
| // either in the same loop or one loop's header dominates the other's. |
| if (!eskip) { |
| |
| if (auto endL = dyn_cast<SCEVAddRecExpr>(elim)) { |
| auto EH = endL->getLoop()->getHeader(); |
| if (auto startL = dyn_cast<SCEVAddRecExpr>(slim)) { |
| auto SH = startL->getLoop()->getHeader(); |
| if (EH != SH && !DT.dominates(EH, SH) && |
| !DT.dominates(SH, EH)) |
| eskip = true; |
| } |
| } |
| } |
| if (!eskip) { |
| auto sub = SE.getMinusSCEV(slim, elim); |
| if (sub != SE.getCouldNotCompute() && SE.isKnownNonNegative(sub)) |
| return false; |
| } |
| } |
| |
| if (auto endL = dyn_cast<SCEVAddRecExpr>(elim)) { |
| if (SE.isKnownNonPositive(endL->getStepRecurrence(SE))) { |
| elim = endL->getStart(); |
| continue; |
| } else if (SE.isKnownNonNegative(endL->getStepRecurrence(SE))) { |
| #if LLVM_VERSION_MAJOR >= 12 |
| auto ebd = SE.getSymbolicMaxBackedgeTakenCount(endL->getLoop()); |
| #else |
| auto ebd = SE.getBackedgeTakenCount(endL->getLoop()); |
| #endif |
| if (ebd == SE.getCouldNotCompute()) |
| break; |
| elim = endL->evaluateAtIteration(ebd, SE); |
| continue; |
| } |
| } |
| break; |
| } |
| |
| if (auto startL = dyn_cast<SCEVAddRecExpr>(slim)) { |
| if (SE.isKnownNonNegative(startL->getStepRecurrence(SE))) { |
| slim = startL->getStart(); |
| continue; |
| } else if (SE.isKnownNonPositive(startL->getStepRecurrence(SE))) { |
| #if LLVM_VERSION_MAJOR >= 12 |
| auto sbd = SE.getSymbolicMaxBackedgeTakenCount(startL->getLoop()); |
| #else |
| auto sbd = SE.getBackedgeTakenCount(startL->getLoop()); |
| #endif |
| if (sbd == SE.getCouldNotCompute()) |
| break; |
| slim = startL->evaluateAtIteration(sbd, SE); |
| continue; |
| } |
| } |
| break; |
| } |
| return true; |
| }; |
| |
| // There is no overwrite if the stores all occur before the loads: |
| // [S, S+Size][start load, L+Size] |
| visitedAncestors.clear(); |
| if (!hasOverlap(StoreEnd, LoadStart, /*EndIsStore*/ true)) { |
| // We must have seen all common loops as induction variables |
| // to be legal, lest we have a repetition of the store. |
| bool legal = true; |
| for (const Loop *L = anc; L != scope; L = L->getParentLoop()) { |
| if (!visitedAncestors.count(L)) |
| legal = false; |
| } |
| if (legal) |
| return false; |
| } |
| |
| // Likewise, there is no overwrite if the loads all occur before the |
| // stores: [start load, L+Size] [S, S+Size] |
| visitedAncestors.clear(); |
| if (!hasOverlap(LoadEnd, StoreStart, /*EndIsStore*/ false)) { |
| // We must have seen all common loops as induction variables |
| // to be legal, lest we have a repetition of the store. |
| bool legal = true; |
| for (const Loop *L = anc; L != scope; L = L->getParentLoop()) { |
| if (!visitedAncestors.count(L)) |
| legal = false; |
| } |
| if (legal) |
| return false; |
| } |
| return true; |
| } |
| |
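| /// Return whether `maybeWriter` may overwrite memory read by `maybeReader`, |
| /// refining the basic aliasing answer of writesToMemoryReadBy with |
| /// SCEV-derived access ranges for loads, stores, memsets, and memcpys when |
| /// the two instructions sit inside a loop nest (optionally bounded by |
| /// `scope`). |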
| bool overwritesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, |
| ScalarEvolution &SE, llvm::LoopInfo &LI, |
| llvm::DominatorTree &DT, |
| llvm::Instruction *maybeReader, |
| llvm::Instruction *maybeWriter, |
| llvm::Loop *scope) { |
| using namespace llvm; |
| if (!writesToMemoryReadBy(AA, TLI, maybeReader, maybeWriter)) |
| return false; |
| const SCEV *LoadBegin = SE.getCouldNotCompute(); |
| const SCEV *LoadEnd = SE.getCouldNotCompute(); |
| |
| const SCEV *StoreBegin = SE.getCouldNotCompute(); |
| const SCEV *StoreEnd = SE.getCouldNotCompute(); |
| |
| if (auto LI = dyn_cast<LoadInst>(maybeReader)) { |
| LoadBegin = SE.getSCEV(LI->getPointerOperand()); |
| if (LoadBegin != SE.getCouldNotCompute()) { |
| auto &DL = maybeWriter->getModule()->getDataLayout(); |
| auto width = cast<IntegerType>(DL.getIndexType(LoadBegin->getType())) |
| ->getBitWidth(); |
| auto TS = SE.getConstant( |
| APInt(width, DL.getTypeStoreSize(LI->getType()).getFixedSize())); |
| LoadEnd = SE.getAddExpr(LoadBegin, TS); |
| } |
| } |
| if (auto SI = dyn_cast<StoreInst>(maybeWriter)) { |
| StoreBegin = SE.getSCEV(SI->getPointerOperand()); |
| if (StoreBegin != SE.getCouldNotCompute()) { |
| auto &DL = maybeWriter->getModule()->getDataLayout(); |
| auto width = cast<IntegerType>(DL.getIndexType(StoreBegin->getType())) |
| ->getBitWidth(); |
| auto TS = SE.getConstant( |
| APInt(width, DL.getTypeStoreSize(SI->getValueOperand()->getType()) |
| .getFixedSize())); |
| StoreEnd = SE.getAddExpr(StoreBegin, TS); |
| } |
| } |
| if (auto MS = dyn_cast<MemSetInst>(maybeWriter)) { |
| StoreBegin = SE.getSCEV(MS->getArgOperand(0)); |
| if (StoreBegin != SE.getCouldNotCompute()) { |
| if (auto Len = dyn_cast<ConstantInt>(MS->getArgOperand(2))) { |
| auto &DL = MS->getModule()->getDataLayout(); |
| auto width = cast<IntegerType>(DL.getIndexType(StoreBegin->getType())) |
| ->getBitWidth(); |
| auto TS = |
| SE.getConstant(APInt(width, Len->getValue().getLimitedValue())); |
| StoreEnd = SE.getAddExpr(StoreBegin, TS); |
| } |
| } |
| } |
| if (auto MS = dyn_cast<MemTransferInst>(maybeWriter)) { |
| StoreBegin = SE.getSCEV(MS->getArgOperand(0)); |
| if (StoreBegin != SE.getCouldNotCompute()) { |
| if (auto Len = dyn_cast<ConstantInt>(MS->getArgOperand(2))) { |
| auto &DL = MS->getModule()->getDataLayout(); |
| auto width = cast<IntegerType>(DL.getIndexType(StoreBegin->getType())) |
| ->getBitWidth(); |
| auto TS = |
| SE.getConstant(APInt(width, Len->getValue().getLimitedValue())); |
| StoreEnd = SE.getAddExpr(StoreBegin, TS); |
| } |
| } |
| } |
| if (auto MS = dyn_cast<MemTransferInst>(maybeReader)) { |
| LoadBegin = SE.getSCEV(MS->getArgOperand(1)); |
| if (LoadBegin != SE.getCouldNotCompute()) { |
| if (auto Len = dyn_cast<ConstantInt>(MS->getArgOperand(2))) { |
| auto &DL = MS->getModule()->getDataLayout(); |
| auto width = cast<IntegerType>(DL.getIndexType(LoadBegin->getType())) |
| ->getBitWidth(); |
| auto TS = |
| SE.getConstant(APInt(width, Len->getValue().getLimitedValue())); |
| LoadEnd = SE.getAddExpr(LoadBegin, TS); |
| } |
| } |
| } |
| |
| if (!overwritesToMemoryReadByLoop(SE, LI, DT, maybeReader, LoadBegin, LoadEnd, |
| maybeWriter, StoreBegin, StoreEnd, scope)) |
| return false; |
| |
| return true; |
| } |
| |
| /// Return whether maybeReader can read from memory written to by maybeWriter |
| bool writesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, |
| llvm::Instruction *maybeReader, |
| llvm::Instruction *maybeWriter) { |
| assert(maybeReader->getParent()->getParent() == |
| maybeWriter->getParent()->getParent()); |
| using namespace llvm; |
| if (isa<StoreInst>(maybeReader)) |
| return false; |
| if (auto call = dyn_cast<CallInst>(maybeWriter)) { |
| StringRef funcName = getFuncNameFromCall(call); |
| |
| if (isDebugFunction(call->getCalledFunction())) |
| return false; |
| |
| if (isCertainPrint(funcName) || isAllocationFunction(funcName, TLI) || |
| isDeallocationFunction(funcName, TLI)) { |
| return false; |
| } |
| |
| if (isMemFreeLibMFunction(funcName)) { |
| return false; |
| } |
| if (funcName == "jl_array_copy" || funcName == "ijl_array_copy") |
| return false; |
| |
| // Send only writes to memory inaccessible to the caller. |
| if (funcName == "MPI_Send" || funcName == "PMPI_Send") { |
| return false; |
| } |
| // Wait only overwrites memory in the status and request. |
| if (funcName == "MPI_Wait" || funcName == "PMPI_Wait" || |
| funcName == "MPI_Waitall" || funcName == "PMPI_Waitall") { |
| #if LLVM_VERSION_MAJOR > 11 |
| auto loc = LocationSize::afterPointer(); |
| #else |
| auto loc = MemoryLocation::UnknownSize; |
| #endif |
| size_t off = (funcName == "MPI_Wait" || funcName == "PMPI_Wait") ? 0 : 1; |
| // No alias with status |
| if (!isRefSet(AA.getModRefInfo(maybeReader, call->getArgOperand(off + 1), |
| loc))) { |
| // No alias with request |
| if (!isRefSet(AA.getModRefInfo(maybeReader, |
| call->getArgOperand(off + 0), loc))) |
| return false; |
| auto R = parseTBAA(*maybeReader, maybeReader->getParent() |
| ->getParent() |
| ->getParent() |
| ->getDataLayout())[{-1}]; |
| // Could still conflict with the mpi_request unless it is a non-pointer |
| // type. |
| if (R != BaseType::Unknown && R != BaseType::Anything && |
| R != BaseType::Pointer) |
| return false; |
| } |
| } |
| // Isend only writes to inaccessible memory and the request. |
| if (funcName == "MPI_Isend" || funcName == "PMPI_Isend") { |
| auto R = parseTBAA(*maybeReader, maybeReader->getParent() |
| ->getParent() |
| ->getParent() |
| ->getDataLayout())[{-1}]; |
| // Could still conflict with the mpi_request, unless it is either |
| // synchronous or a non-pointer type. |
| if (R != BaseType::Unknown && R != BaseType::Anything && |
| R != BaseType::Pointer) |
| return false; |
| #if LLVM_VERSION_MAJOR > 11 |
| if (!isRefSet(AA.getModRefInfo(maybeReader, call->getArgOperand(6), |
| LocationSize::afterPointer()))) |
| return false; |
| #else |
| if (!isRefSet(AA.getModRefInfo(maybeReader, call->getArgOperand(6), |
| MemoryLocation::UnknownSize))) |
| return false; |
| #endif |
| return false; |
| } |
| if (funcName == "MPI_Irecv" || funcName == "PMPI_Irecv" || |
| funcName == "MPI_Recv" || funcName == "PMPI_Recv") { |
| ConcreteType type(BaseType::Unknown); |
| if (Constant *C = dyn_cast<Constant>(call->getArgOperand(2))) { |
| while (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) { |
| C = CE->getOperand(0); |
| } |
| if (auto GV = dyn_cast<GlobalVariable>(C)) { |
| if (GV->getName() == "ompi_mpi_double") { |
| type = ConcreteType(Type::getDoubleTy(C->getContext())); |
| } else if (GV->getName() == "ompi_mpi_float") { |
| type = ConcreteType(Type::getFloatTy(C->getContext())); |
| } |
| } |
| } |
| if (type.isKnown()) { |
| auto R = parseTBAA(*maybeReader, maybeReader->getParent() |
| ->getParent() |
| ->getParent() |
| ->getDataLayout())[{-1}]; |
| if (R.isKnown() && type != R) { |
| // Could still conflict with the mpi_request, unless it is either |
| // synchronous or a non-pointer type. |
| if (funcName == "MPI_Recv" || funcName == "PMPI_Recv" || |
| (R != BaseType::Anything && R != BaseType::Pointer)) |
| return false; |
| #if LLVM_VERSION_MAJOR > 11 |
| if (!isRefSet(AA.getModRefInfo(maybeReader, call->getArgOperand(6), |
| LocationSize::afterPointer()))) |
| return false; |
| #else |
| if (!isRefSet(AA.getModRefInfo(maybeReader, call->getArgOperand(6), |
| MemoryLocation::UnknownSize))) |
| return false; |
| #endif |
| } |
| } |
| } |
| if (auto II = dyn_cast<IntrinsicInst>(call)) { |
| if (II->getIntrinsicID() == Intrinsic::stacksave) |
| return false; |
| if (II->getIntrinsicID() == Intrinsic::stackrestore) |
| return false; |
| if (II->getIntrinsicID() == Intrinsic::trap) |
| return false; |
| #if LLVM_VERSION_MAJOR >= 13 |
| if (II->getIntrinsicID() == Intrinsic::experimental_noalias_scope_decl) |
| return false; |
| #endif |
| } |
| |
| if (auto iasm = dyn_cast<InlineAsm>(call->getCalledOperand())) { |
| if (StringRef(iasm->getAsmString()).contains("exit")) |
| return false; |
| } |
| } |
| if (auto call = dyn_cast<CallInst>(maybeReader)) { |
| StringRef funcName = getFuncNameFromCall(call); |
| |
| if (isDebugFunction(call->getCalledFunction())) |
| return false; |
| |
| if (isAllocationFunction(funcName, TLI) || |
| isDeallocationFunction(funcName, TLI)) { |
| return false; |
| } |
| |
| if (isMemFreeLibMFunction(funcName)) { |
| return false; |
| } |
| |
| if (auto II = dyn_cast<IntrinsicInst>(call)) { |
| if (II->getIntrinsicID() == Intrinsic::stacksave) |
| return false; |
| if (II->getIntrinsicID() == Intrinsic::stackrestore) |
| return false; |
| if (II->getIntrinsicID() == Intrinsic::trap) |
| return false; |
| #if LLVM_VERSION_MAJOR >= 13 |
| if (II->getIntrinsicID() == Intrinsic::experimental_noalias_scope_decl) |
| return false; |
| #endif |
| } |
| } |
| if (auto call = dyn_cast<InvokeInst>(maybeWriter)) { |
| StringRef funcName = getFuncNameFromCall(call); |
| |
| if (isDebugFunction(call->getCalledFunction())) |
| return false; |
| |
| if (isAllocationFunction(funcName, TLI) || |
| isDeallocationFunction(funcName, TLI)) { |
| return false; |
| } |
| |
| if (isMemFreeLibMFunction(funcName)) { |
| return false; |
| } |
| if (funcName == "jl_array_copy" || funcName == "ijl_array_copy") |
| return false; |
| |
| if (auto iasm = dyn_cast<InlineAsm>(call->getCalledOperand())) { |
| if (StringRef(iasm->getAsmString()).contains("exit")) |
| return false; |
| } |
| } |
| if (auto call = dyn_cast<InvokeInst>(maybeReader)) { |
| StringRef funcName = getFuncNameFromCall(call); |
| |
| if (isDebugFunction(call->getCalledFunction())) |
| return false; |
| |
| if (isAllocationFunction(funcName, TLI) || |
| isDeallocationFunction(funcName, TLI)) { |
| return false; |
| } |
| |
| if (isMemFreeLibMFunction(funcName)) { |
| return false; |
| } |
| } |
| assert(maybeWriter->mayWriteToMemory()); |
| assert(maybeReader->mayReadFromMemory()); |
| |
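| // Fall back to alias-analysis queries against the precise memory location |
| // read or written by whichever side is a simple memory instruction. |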
| if (auto li = dyn_cast<LoadInst>(maybeReader)) { |
| return isModSet(AA.getModRefInfo(maybeWriter, MemoryLocation::get(li))); |
| } |
| if (auto rmw = dyn_cast<AtomicRMWInst>(maybeReader)) { |
| return isModSet(AA.getModRefInfo(maybeWriter, MemoryLocation::get(rmw))); |
| } |
| if (auto xch = dyn_cast<AtomicCmpXchgInst>(maybeReader)) { |
| return isModSet(AA.getModRefInfo(maybeWriter, MemoryLocation::get(xch))); |
| } |
| if (auto mti = dyn_cast<MemTransferInst>(maybeReader)) { |
| return isModSet( |
| AA.getModRefInfo(maybeWriter, MemoryLocation::getForSource(mti))); |
| } |
| |
| if (auto si = dyn_cast<StoreInst>(maybeWriter)) { |
| return isRefSet(AA.getModRefInfo(maybeReader, MemoryLocation::get(si))); |
| } |
| if (auto rmw = dyn_cast<AtomicRMWInst>(maybeWriter)) { |
| return isRefSet(AA.getModRefInfo(maybeReader, MemoryLocation::get(rmw))); |
| } |
| if (auto xch = dyn_cast<AtomicCmpXchgInst>(maybeWriter)) { |
| return isRefSet(AA.getModRefInfo(maybeReader, MemoryLocation::get(xch))); |
| } |
| if (auto mti = dyn_cast<MemIntrinsic>(maybeWriter)) { |
| return isRefSet( |
| AA.getModRefInfo(maybeReader, MemoryLocation::getForDest(mti))); |
| } |
| |
| if (auto cb = dyn_cast<CallInst>(maybeReader)) { |
| return isModOrRefSet(AA.getModRefInfo(maybeWriter, cb)); |
| } |
| if (auto cb = dyn_cast<InvokeInst>(maybeReader)) { |
| return isModOrRefSet(AA.getModRefInfo(maybeWriter, cb)); |
| } |
| llvm::errs() << " maybeReader: " << *maybeReader |
| << " maybeWriter: " << *maybeWriter << "\n"; |
| llvm_unreachable("unknown inst2"); |
| } |
| |
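| /// Attempt to resolve the llvm::Function ultimately referred to by `fn`, |
| /// looking through casts, aliases, block addresses, calls to trivial |
| /// wrappers that return one of their arguments or a constant, and loads |
| /// from allocas with a single stored value. If `lastVal` is non-null it is |
| /// set to the last value inspected. Returns null if no function is found. |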
| Function *GetFunctionFromValue(Value *fn, llvm::AAResults &AA, |
| llvm::TargetLibraryInfo &TLI, |
| llvm::Value **lastVal) { |
| while (!isa<Function>(fn)) { |
| if (auto ci = dyn_cast<CastInst>(fn)) { |
| fn = ci->getOperand(0); |
| continue; |
| } |
| if (auto ci = dyn_cast<ConstantExpr>(fn)) { |
| if (ci->isCast()) { |
| fn = ci->getOperand(0); |
| continue; |
| } |
| } |
| if (auto ci = dyn_cast<BlockAddress>(fn)) { |
| fn = ci->getFunction(); |
| continue; |
| } |
| if (auto *GA = dyn_cast<GlobalAlias>(fn)) { |
| fn = GA->getAliasee(); |
| continue; |
| } |
| if (auto *Call = dyn_cast<CallInst>(fn)) { |
| if (auto F = Call->getCalledFunction()) { |
| SmallPtrSet<Value *, 1> ret; |
| for (auto &BB : *F) { |
| if (auto RI = dyn_cast<ReturnInst>(BB.getTerminator())) { |
| ret.insert(RI->getReturnValue()); |
| } |
| } |
| if (ret.size() == 1) { |
| auto val = *ret.begin(); |
| if (isa<Constant>(val)) { |
| fn = val; |
| continue; |
| } |
| if (auto arg = dyn_cast<Argument>(val)) { |
| fn = Call->getArgOperand(arg->getArgNo()); |
| continue; |
| } |
| } |
| } |
| } |
| if (auto *Call = dyn_cast<InvokeInst>(fn)) { |
| if (auto F = Call->getCalledFunction()) { |
| SmallPtrSet<Value *, 1> ret; |
| for (auto &BB : *F) { |
| if (auto RI = dyn_cast<ReturnInst>(BB.getTerminator())) { |
| ret.insert(RI->getReturnValue()); |
| } |
| } |
| if (ret.size() == 1) { |
| auto val = *ret.begin(); |
| if (isa<Constant>(val)) { |
| fn = val; |
| continue; |
| } |
| if (auto arg = dyn_cast<Argument>(val)) { |
| fn = Call->getArgOperand(arg->getArgNo()); |
| continue; |
| } |
| } |
| } |
| } |
| if (auto LI = dyn_cast<LoadInst>(fn)) { |
| auto obj = getBaseObject(LI->getPointerOperand()); |
| if (isa<AllocaInst>(obj)) { |
| std::set<std::pair<Instruction *, Value *>> done; |
| SmallVector<std::pair<Instruction *, Value *>, 1> todo; |
| Value *stored = nullptr; |
| bool legal = true; |
| for (auto U : obj->users()) { |
| if (auto I = dyn_cast<Instruction>(U)) |
| todo.push_back(std::make_pair(I, obj)); |
| else { |
| legal = false; |
| break; |
| } |
| } |
| while (legal && todo.size()) { |
| auto tup = todo.pop_back_val(); |
| if (done.count(tup)) |
| continue; |
| done.insert(tup); |
| auto cur = tup.first; |
| auto prev = tup.second; |
| if (auto SI = dyn_cast<StoreInst>(cur)) |
| if (SI->getPointerOperand() == prev) { |
| if (stored == SI->getValueOperand()) |
| continue; |
| else if (stored == nullptr) { |
| stored = SI->getValueOperand(); |
| continue; |
| } else { |
| legal = false; |
| break; |
| } |
| } |
| |
| if (isPointerArithmeticInst(cur, /*includephi*/ true)) { |
| for (auto U : cur->users()) { |
| if (auto I = dyn_cast<Instruction>(U)) |
| todo.push_back(std::make_pair(I, cur)); |
| else { |
| legal = false; |
| break; |
| } |
| } |
| continue; |
| } |
| |
| if (isa<LoadInst>(cur)) |
| continue; |
| |
| if (cur->getType()->isVoidTy()) { |
| if (!cur->mayWriteToMemory()) |
| continue; |
| if (!writesToMemoryReadBy(AA, TLI, LI, cur)) |
| continue; |
| } |
| |
| legal = false; |
| break; |
| } |
| |
| if (legal && stored) { |
| fn = stored; |
| continue; |
| } |
| } |
| } |
| break; |
| } |
| |
| if (lastVal) |
| *lastVal = fn; |
| return dyn_cast<Function>(fn); |
| } |
| |
| #if LLVM_VERSION_MAJOR >= 16 |
| std::optional<BlasInfo> extractBLAS(llvm::StringRef in) |
| #else |
| llvm::Optional<BlasInfo> extractBLAS(llvm::StringRef in) |
| #endif |
| { |
| llvm::Twine floatType[] = {"s", "d"}; // c, z |
| llvm::Twine extractable[] = {"dot", "scal", "axpy", "gemv", "gemm", "spmv"}; |
| llvm::Twine prefixes[] = {"" /*Fortran*/, "cblas_", "cublas_"}; |
| llvm::Twine suffixes[] = {"", "_", "64_", "_64_"}; |
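| // e.g. "ddot", "ddot_", "cblas_sgemm", and "dgemv_64_" are all recognized. |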
| for (auto t : floatType) { |
| for (auto f : extractable) { |
| for (auto p : prefixes) { |
| for (auto s : suffixes) { |
| if (in == (p + t + f + s).str()) { |
| return BlasInfo{ |
| t.getSingleStringRef(), |
| p.getSingleStringRef(), |
| s.getSingleStringRef(), |
| f.getSingleStringRef(), |
| }; |
| } |
| } |
| } |
| } |
| } |
| return {}; |
| } |
| |
| llvm::Constant *getUndefinedValueForType(llvm::Type *T, bool forceZero) { |
| if (EnzymeUndefinedValueForType) |
| return cast<Constant>( |
| unwrap(EnzymeUndefinedValueForType(wrap(T), forceZero))); |
| else if (EnzymeZeroCache || forceZero) |
| return Constant::getNullValue(T); |
| else |
| return UndefValue::get(T); |
| } |
| |
| llvm::Value *SanitizeDerivatives(llvm::Value *val, llvm::Value *toset, |
| llvm::IRBuilder<> &BuilderM, |
| llvm::Value *mask) { |
| if (EnzymeSanitizeDerivatives) |
| return unwrap(EnzymeSanitizeDerivatives(wrap(val), wrap(toset), |
| wrap(&BuilderM), wrap(mask))); |
| return toset; |
| } |
| |
| llvm::FastMathFlags getFast() { |
| llvm::FastMathFlags f; |
| if (EnzymeFastMath) |
| f.set(); |
| return f; |
| } |
| |
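| // Append the value to cache for `arg`: non-pointer arguments are cached |
| // directly, while pointer arguments are loaded as `ty` first (and skipped |
| // entirely when cache_arg is false). |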
| void addValueToCache(llvm::Value *arg, bool cache_arg, llvm::Type *ty, |
| llvm::SmallVectorImpl<llvm::Value *> &cacheValues, |
| llvm::IRBuilder<> &BuilderZ, const Twine &name) { |
| if (!arg->getType()->isPointerTy()) { |
| assert(arg->getType() == ty); |
| cacheValues.push_back(arg); |
| return; |
| } |
| if (!cache_arg) |
| return; |
| #if LLVM_VERSION_MAJOR < 18 |
| auto PT = cast<PointerType>(arg->getType()); |
| #if LLVM_VERSION_MAJOR <= 14 |
| if (PT->getElementType() != ty) |
| arg = BuilderZ.CreatePointerCast( |
| arg, PointerType::get(ty, PT->getAddressSpace()), "pcld." + name); |
| #else |
| auto PT2 = PointerType::get(ty, PT->getAddressSpace()); |
| if (!PT->isOpaqueOrPointeeTypeMatches(PT2)) |
| arg = BuilderZ.CreatePointerCast( |
| arg, PointerType::get(ty, PT->getAddressSpace()), "pcld." + name); |
| #endif |
| #endif |
| arg = BuilderZ.CreateLoad(ty, arg, "avld." + name); |
| cacheValues.push_back(arg); |
| } |
| |
| // A null julia_decl means this is not a Julia declaration; otherwise it is |
| // the integer type to cast to. |
| llvm::Value *to_blas_callconv(IRBuilder<> &B, llvm::Value *V, bool byRef, |
| IntegerType *julia_decl, |
| IRBuilder<> &entryBuilder, |
| llvm::Twine const &name) { |
| if (!byRef) |
| return V; |
| |
| Value *allocV = |
| entryBuilder.CreateAlloca(V->getType(), nullptr, "byref." + name); |
| B.CreateStore(V, allocV); |
| |
| if (julia_decl) |
| allocV = B.CreatePointerCast(allocV, Type::getInt8PtrTy(V->getContext()), |
| "intcast." + name); |
| |
| return allocV; |
| } |
| llvm::Value *to_blas_fp_callconv(IRBuilder<> &B, llvm::Value *V, bool byRef, |
| Type *fpTy, IRBuilder<> &entryBuilder, |
| llvm::Twine const &name) { |
| if (!byRef) |
| return V; |
| |
| Value *allocV = |
| entryBuilder.CreateAlloca(V->getType(), nullptr, "byref." + name); |
| B.CreateStore(V, allocV); |
| |
| if (fpTy) |
| allocV = B.CreatePointerCast(allocV, fpTy, "fpcast." + name); |
| |
| return allocV; |
| } |
| |
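| // Select the logical vector length: dim1 when trans denotes no |
| // transposition ('n'/'N'), otherwise dim2. |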
| llvm::Value *select_vec_dims(IRBuilder<> &B, llvm::Value *trans, |
| llvm::Value *dim1, llvm::Value *dim2, bool byRef) { |
| Value *width = B.CreateSelect(is_normal(B, trans, byRef), dim1, dim2); |
| |
| return width; |
| } |
| |
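| // Return an i1 that is true when the uplo-style argument is 'u' or 'U'. |
| // When byRef, the 8-bit char is first loaded from the pointer. |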
| Value *is_uper(IRBuilder<> &B, Value *trans, bool byRef) { |
| IntegerType *charTy; |
| if (byRef) { |
| // can't inspect opaque ptr, so assume 8 (Julia) |
| charTy = IntegerType::get(trans->getContext(), 8); |
| trans = B.CreateLoad(charTy, trans, "loaded.trans"); |
| } else { |
| // we can inspect scalars |
| unsigned int len = trans->getType()->getScalarSizeInBits(); |
| charTy = IntegerType::get(trans->getContext(), len); |
| } |
| |
| Value *isUper = |
| B.CreateOr(B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'u')), |
| B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'U'))); |
| return isUper; |
| } |
| |
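| // Return an i1 that is true when the trans argument is 'n' or 'N' (no |
| // transposition). When byRef, the 8-bit char is first loaded. |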
| llvm::Value *is_normal(IRBuilder<> &B, llvm::Value *trans, bool byRef) { |
| IntegerType *charTy; |
| if (byRef) { |
| // can't inspect opaque ptr, so assume 8 (Julia) |
| charTy = IntegerType::get(trans->getContext(), 8); |
| trans = B.CreateLoad(charTy, trans, "loaded.trans"); |
| } else { |
| // we can inspect scalars |
| unsigned int len = trans->getType()->getScalarSizeInBits(); |
| charTy = IntegerType::get(trans->getContext(), len); |
| } |
| |
| Value *isNormal = |
| B.CreateOr(B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'n')), |
| B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'N'))); |
| return isNormal; |
| } |
| |
| // netlib BLAS declares trans arguments as one of |
| // N,n,T,t,C,c, represented as 8-bit chars. |
| // However, the CBLAS C ABI (e.g. OpenBLAS) instead uses |
| // one of the following 32-bit integer values: |
| // enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113}; |
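| // Accordingly, the transpose of a trans argument swaps 'N'/'n' with 'T'/'t' |
| // in the char ABI and 111 with 112 in the CBLAS ABI; any other value |
| // (including conjugate transpose) maps to 0. |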
| llvm::Value *transpose(IRBuilder<> &B, llvm::Value *V) { |
| llvm::Type *T = V->getType(); |
| Value *out = nullptr; |
| if (T->isIntegerTy(8)) { |
| out = B.CreateSelect( |
| B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'T')), |
| ConstantInt::get(V->getType(), 'N'), |
| B.CreateSelect( |
| B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 't')), |
| ConstantInt::get(V->getType(), 'n'), |
| B.CreateSelect( |
| B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'N')), |
| ConstantInt::get(V->getType(), 'T'), |
| B.CreateSelect( |
| B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'n')), |
| ConstantInt::get(V->getType(), 't'), |
| ConstantInt::get(V->getType(), 0))))); |
| } else if (T->isIntegerTy(32)) { |
| out = B.CreateSelect( |
| B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 111)), |
| ConstantInt::get(V->getType(), 112), |
| B.CreateSelect(B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 112)), |
| ConstantInt::get(V->getType(), 111), |
| ConstantInt::get(V->getType(), 0))); |
| } else { |
| std::string s; |
| llvm::raw_string_ostream ss(s); |
| ss << "cannot handle unknown trans blas value\n" << V; |
| if (CustomErrorHandler) { |
| CustomErrorHandler(ss.str().c_str(), nullptr, ErrorType::NoDerivative, |
| nullptr, nullptr, nullptr); |
| } else { |
| EmitFailure("unknown trans blas value", nullptr, nullptr, ss.str()); |
| } |
| } |
| return out; |
| } |
| |
| // Implement the following logic to get the width of a matrix |
| // if (cache_A) { |
| // ld_A = (arg_transa == 'N') ? arg_m : arg_k; |
| // } else { |
| // ld_A = arg_lda; |
| // } |
| llvm::Value *get_cached_mat_width(llvm::IRBuilder<> &B, llvm::Value *trans, |
| llvm::Value *arg_ld, llvm::Value *dim1, |
| llvm::Value *dim2, bool cacheMat, |
| bool byRef) { |
| if (!cacheMat) |
| return arg_ld; |
| |
| Value *width = B.CreateSelect(is_normal(B, trans, byRef), dim1, dim2); |
| |
| return width; |
| } |
| |
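| // Transpose a (possibly by-reference) trans argument and repackage the |
| // result according to the BLAS calling convention in use. |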
| llvm::Value *transpose(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef, |
| llvm::IntegerType *julia_decl, |
| llvm::IRBuilder<> &entryBuilder, |
| const llvm::Twine &name) { |
| |
| if (byRef) { |
| auto charType = IntegerType::get(V->getContext(), 8); |
| V = B.CreateLoad(charType, V, "ld." + name); |
| } |
| |
| V = transpose(B, V); |
| |
| return to_blas_callconv(B, V, byRef, julia_decl, entryBuilder, |
| "transpose." + name); |
| } |
| |
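| // If the integer is passed by reference (Fortran convention), load it as |
| // intType; otherwise return the value unchanged. |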
| llvm::Value *load_if_ref(llvm::IRBuilder<> &B, llvm::IntegerType *intType, |
| llvm::Value *V, bool byRef) { |
| if (!byRef) |
| return V; |
| |
| auto VP = B.CreatePointerCast( |
| V, PointerType::get(intType, |
| cast<PointerType>(V->getType())->getAddressSpace())); |
| return B.CreateLoad(intType, VP); |
| } |
| |
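| // Select the row count as seen by BLAS: `row` when trans is 'N'/'n', |
| // otherwise `col`. When byRef, the trans char is loaded first. |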
| llvm::Value *get_blas_row(llvm::IRBuilder<> &B, llvm::Value *trans, |
| llvm::Value *row, llvm::Value *col, bool byRef) { |
| |
| if (byRef) { |
| auto charType = IntegerType::get(trans->getContext(), 8); |
| trans = B.CreateLoad(charType, trans, "ld.row.trans"); |
| } |
| |
| return B.CreateSelect( |
| B.CreateOr( |
| B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N')), |
| B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n'))), |
| row, col); |
| } |
| |
| // Return how many special pointers (isSpecialPtr) are in T (count > 0), |
| // and whether T contains anything else (all == false). |
| CountTrackedPointers::CountTrackedPointers(Type *T) { |
| if (isa<PointerType>(T)) { |
| if (isSpecialPtr(T)) { |
| count++; |
| if (T->getPointerAddressSpace() != AddressSpace::Tracked) |
| derived = true; |
| } |
| } else if (isa<StructType>(T) || isa<ArrayType>(T) || isa<VectorType>(T)) { |
| for (Type *ElT : T->subtypes()) { |
| auto sub = CountTrackedPointers(ElT); |
| count += sub.count; |
| all &= sub.all; |
| derived |= sub.derived; |
| } |
| if (isa<ArrayType>(T)) |
| count *= cast<ArrayType>(T)->getNumElements(); |
| else if (isa<VectorType>(T)) { |
| #if LLVM_VERSION_MAJOR >= 12 |
| count *= cast<VectorType>(T)->getElementCount().getKnownMinValue(); |
| #else |
| count *= cast<VectorType>(T)->getNumElements(); |
| #endif |
| } |
| } |
| if (count == 0) |
| all = false; |
| } |