| //===- GradientUtils.cpp - Helper class and utilities for AD ---------===// |
| // |
| // Enzyme Project |
| // |
| // Part of the Enzyme Project, under the Apache License v2.0 with LLVM |
| // Exceptions. See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| // If using this code in an academic setting, please cite the following: |
| // @incollection{enzymeNeurips, |
| // title = {Instead of Rewriting Foreign Code for Machine Learning, |
| // Automatically Synthesize Fast Gradients}, |
| // author = {Moses, William S. and Churavy, Valentin}, |
| // booktitle = {Advances in Neural Information Processing Systems 33}, |
| // year = {2020}, |
| // note = {To appear in}, |
| // } |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file defines two helper classes, GradientUtils and its subclass |
| // DiffeGradientUtils. These classes contain utilities for managing the cache, |
| // recomputing statements, and, in the case of DiffeGradientUtils, managing |
| // adjoint values and shadow pointers. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include <algorithm> |
| |
| #include <llvm/Config/llvm-config.h> |
| |
| #include "DifferentialUseAnalysis.h" |
| #include "EnzymeLogic.h" |
| #include "FunctionUtils.h" |
| #include "GradientUtils.h" |
| #include "LibraryFuncs.h" |
| #include "TypeAnalysis/TBAA.h" |
| |
| #include "llvm/IR/GlobalValue.h" |
| |
| #include "llvm/IR/Constants.h" |
| |
| #include "llvm/Analysis/ValueTracking.h" |
| #include "llvm/IR/InstrTypes.h" |
| #include "llvm/Support/AMDGPUMetadata.h" |
| #include "llvm/Transforms/Utils/SimplifyIndVar.h" |
| |
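| // Registries keyed by function name for user-provided handlers. Judging from |
| // the signatures: shadowHandlers construct the shadow (duplicated memory) for |
| // a recognized allocation-like call, shadowErasers free such a shadow, and |
| // customCallHandlers / customFwdCallHandlers supply custom augmented-forward |
| // and reverse (respectively forward-mode) derivatives for a named call. These |
| // are assumed to be populated externally (e.g. via Enzyme's registration |
| // mechanism) before differentiation runs. |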
| std::map<std::string, |
| std::function<llvm::Value *(IRBuilder<> &, CallInst *, |
| ArrayRef<Value *>, GradientUtils *)>> |
| shadowHandlers; |
| std::map<std::string, std::function<llvm::CallInst *(IRBuilder<> &, Value *)>> |
| shadowErasers; |
| |
| std::map< |
| std::string, |
| std::pair<std::function<void(IRBuilder<> &, CallInst *, GradientUtils &, |
| Value *&, Value *&, Value *&)>, |
| std::function<void(IRBuilder<> &, CallInst *, |
| DiffeGradientUtils &, Value *)>>> |
| customCallHandlers; |
| |
| std::map<std::string, std::function<void(IRBuilder<> &, CallInst *, |
| GradientUtils &, Value *&, Value *&)>> |
| customFwdCallHandlers; |
| |
| extern "C" { |
| llvm::cl::opt<bool> |
| EnzymeNewCache("enzyme-new-cache", cl::init(true), cl::Hidden, |
| cl::desc("Use new cache decision algorithm")); |
| |
| llvm::cl::opt<bool> EnzymeMinCutCache("enzyme-mincut-cache", cl::init(true), |
| cl::Hidden, |
| cl::desc("Use Enzyme Mincut algorithm")); |
| |
| llvm::cl::opt<bool> EnzymeLoopInvariantCache( |
| "enzyme-loop-invariant-cache", cl::init(true), cl::Hidden, |
| cl::desc("Attempt to hoist cache outside of loop")); |
| |
| llvm::cl::opt<bool> EnzymeInactiveDynamic( |
| "enzyme-inactive-dynamic", cl::init(true), cl::Hidden, |
|     cl::desc("Force wholly inactive dynamic loops to have 0 iter reverse pass")); |
| |
| llvm::cl::opt<bool> |
| EnzymeRuntimeActivityCheck("enzyme-runtime-activity", cl::init(false), |
| cl::Hidden, |
| cl::desc("Perform runtime activity checks")); |
| |
| llvm::cl::opt<bool> |
| EnzymeSharedForward("enzyme-shared-forward", cl::init(false), cl::Hidden, |
| cl::desc("Forward Shared Memory from definitions")); |
| |
| llvm::cl::opt<bool> |
| EnzymeRegisterReduce("enzyme-register-reduce", cl::init(false), cl::Hidden, |
|                          cl::desc("Reduce the amount of register reduction")); |
| llvm::cl::opt<bool> |
| EnzymeSpeculatePHIs("enzyme-speculate-phis", cl::init(false), cl::Hidden, |
| cl::desc("Speculatively execute phi computations")); |
| llvm::cl::opt<bool> EnzymeFreeInternalAllocations( |
| "enzyme-free-internal-allocations", cl::init(true), cl::Hidden, |
| cl::desc("Always free internal allocations (disable if allocation needs " |
| "access outside)")); |
| |
| llvm::cl::opt<bool> |
| EnzymeRematerialize("enzyme-rematerialize", cl::init(true), cl::Hidden, |
| cl::desc("Rematerialize allocations/shadows in the " |
| "reverse rather than caching")); |
| |
| llvm::cl::opt<bool> |
| EnzymeVectorSplitPhi("enzyme-vector-split-phi", cl::init(true), cl::Hidden, |
| cl::desc("Split phis according to vector size")); |
| |
| llvm::cl::opt<bool> |
| EnzymePrintDiffUse("enzyme-print-diffuse", cl::init(false), cl::Hidden, |
| cl::desc("Print differential use analysis")); |
| } |
| |
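| // Metadata kinds copied from original instructions onto their unwrapped or |
| // cloned counterparts (see e.g. the load handling in unwrapM below). |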
| SmallVector<unsigned int, 9> MD_ToCopy = { |
| LLVMContext::MD_dbg, |
| LLVMContext::MD_tbaa, |
| LLVMContext::MD_tbaa_struct, |
| LLVMContext::MD_range, |
| LLVMContext::MD_nonnull, |
| LLVMContext::MD_dereferenceable, |
| LLVMContext::MD_dereferenceable_or_null}; |
| |
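| // unwrapM attempts to recompute ("unwrap") `val` at the insertion point of |
| // BuilderM, substituting entries of `available` for values that already have |
| // replacements in this context. Depending on `unwrapMode` it may fall back to |
| // lookupM (loading from the cache), emit a placeholder phi to be replaced by |
| // a cache load later, or return nullptr if recomputation is not possible. |
| // Results are memoized in unwrap_cache per (insert block, value, scope) when |
| // `permitCache` is set. |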
| Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM, |
| const ValueToValueMapTy &available, |
| UnwrapMode unwrapMode, BasicBlock *scope, |
| bool permitCache) { |
| assert(val); |
| assert(val->getName() != "<badref>"); |
| assert(val->getType()); |
| |
| for (auto pair : available) { |
| assert(pair.first); |
| assert(pair.first->getType()); |
| if (pair.second) { |
| assert(pair.second->getType()); |
| assert(pair.first->getType() == pair.second->getType()); |
| } |
| } |
| |
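|   // Loads explicitly marked "enzyme_mustcache" are never recomputed here; |
|   // return the original instruction, which is expected to be replaced by a |
|   // cache load elsewhere. |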
| if (isa<LoadInst>(val) && |
| cast<LoadInst>(val)->getMetadata("enzyme_mustcache")) { |
| return val; |
| } |
| |
| if (available.count(val)) { |
| auto avail = available.lookup(val); |
| assert(avail->getType()); |
| if (avail->getType() != val->getType()) { |
| llvm::errs() << "val: " << *val << "\n"; |
| llvm::errs() << "available[val]: " << *available.lookup(val) << "\n"; |
| } |
| assert(available.lookup(val)->getType() == val->getType()); |
| return available.lookup(val); |
| } |
| |
| if (auto inst = dyn_cast<Instruction>(val)) { |
| if (inversionAllocs && inst->getParent() == inversionAllocs) { |
| return val; |
| } |
| // if (inst->getParent() == &newFunc->getEntryBlock()) { |
| // return inst; |
| //} |
| if (inst->getParent()->getParent() == newFunc && |
| isOriginalBlock(*BuilderM.GetInsertBlock())) { |
| if (BuilderM.GetInsertBlock()->size() && |
| BuilderM.GetInsertPoint() != BuilderM.GetInsertBlock()->end()) { |
| if (DT.dominates(inst, &*BuilderM.GetInsertPoint())) { |
| // llvm::errs() << "allowed " << *inst << "from domination\n"; |
| assert(inst->getType() == val->getType()); |
| return inst; |
| } |
| } else { |
| if (DT.dominates(inst, &*BuilderM.GetInsertPoint())) { |
| // llvm::errs() << "allowed " << *inst << "from block domination\n"; |
| assert(inst->getType() == val->getType()); |
| return inst; |
| } |
| } |
| } |
| assert(!TapesToPreventRecomputation.count(inst)); |
| } |
| |
| std::pair<Value *, BasicBlock *> idx = std::make_pair(val, scope); |
| // assert(!val->getName().startswith("$tapeload")); |
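|   // unwrap_cache memoizes prior unwraps, keyed by insertion block and then by |
|   // the (value, scope) pair. A null entry marks a placeholder that was since |
|   // erased; drop it and recompute rather than returning it. |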
| if (permitCache) { |
| auto found0 = unwrap_cache.find(BuilderM.GetInsertBlock()); |
| if (found0 != unwrap_cache.end()) { |
| auto found1 = found0->second.find(idx.first); |
| if (found1 != found0->second.end()) { |
| auto found2 = found1->second.find(idx.second); |
| if (found2 != found1->second.end()) { |
| |
| auto cachedValue = found2->second; |
| if (cachedValue == nullptr) { |
| found1->second.erase(idx.second); |
| if (found1->second.size() == 0) { |
| found0->second.erase(idx.first); |
| } |
| } else { |
| if (cachedValue->getType() != val->getType()) { |
| llvm::errs() << "newFunc: " << *newFunc << "\n"; |
| llvm::errs() << "val: " << *val << "\n"; |
| llvm::errs() << "unwrap_cache[cidx]: " << *cachedValue << "\n"; |
| } |
| assert(cachedValue->getType() == val->getType()); |
| return cachedValue; |
| } |
| } |
| } |
| } |
| } |
| |
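|   // In the reverse/split modes, values that the caching heuristic decided to |
|   // cache (knownRecomputeHeuristic[orig] == false) are not recomputed from |
|   // their operands. Depending on the unwrap mode we either emit a placeholder |
|   // phi (annotated with the currently available values) to be swapped for a |
|   // cache load later, or bail out of unwrapping entirely. |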
| if (this->mode == DerivativeMode::ReverseModeGradient || |
| this->mode == DerivativeMode::ForwardModeSplit || |
| this->mode == DerivativeMode::ReverseModeCombined) |
| if (auto inst = dyn_cast<Instruction>(val)) { |
| if (inst->getParent()->getParent() == newFunc) { |
| if (unwrapMode == UnwrapMode::LegalFullUnwrap && |
| this->mode != DerivativeMode::ReverseModeCombined) { |
|           // TODO: this isOriginal call is a bottleneck; the knownRecompute |
|           // mapping should be precomputed and maintained for lookup instead. |
|           Instruction *orig = isOriginal(inst); |
|           // If a given value has been chosen to be cached, do not compute the |
|           // operands to unwrap it; instead simply emit a placeholder to be |
|           // replaced by the cache load later. This placeholder should only be |
|           // returned when the original value would be recomputed (i.e. this |
|           // function would not return null). Since this case assumes |
|           // everything can be recomputed, simply return the placeholder. |
| if (orig && knownRecomputeHeuristic.find(orig) != |
| knownRecomputeHeuristic.end()) { |
| if (!knownRecomputeHeuristic[orig]) { |
| assert(inst->getParent()->getParent() == newFunc); |
| auto placeholder = BuilderM.CreatePHI( |
| val->getType(), 0, val->getName() + "_krcLFUreplacement"); |
| unwrappedLoads[placeholder] = inst; |
| SmallVector<Metadata *, 1> avail; |
| for (auto pair : available) |
| if (pair.second) |
| avail.push_back(MDNode::get( |
| placeholder->getContext(), |
| {ValueAsMetadata::get(const_cast<Value *>(pair.first)), |
| ValueAsMetadata::get(pair.second)})); |
| placeholder->setMetadata( |
| "enzyme_available", |
| MDNode::get(placeholder->getContext(), avail)); |
| if (!permitCache) |
| return placeholder; |
| return unwrap_cache[BuilderM.GetInsertBlock()][idx.first] |
| [idx.second] = placeholder; |
| } |
| } |
| } else if (unwrapMode == UnwrapMode::AttemptFullUnwrapWithLookup) { |
|           // TODO: this isOriginal call is a bottleneck; the knownRecompute |
|           // mapping should be precomputed and maintained for lookup instead. |
|           Instruction *orig = isOriginal(inst); |
|           // If a given value has been chosen to be cached, do not compute the |
|           // operands to unwrap it; instead simply emit a placeholder to be |
|           // replaced by the cache load later. This placeholder should only be |
|           // returned when the original value would be recomputed (i.e. this |
|           // function would not return null). See the note below about the |
|           // condition as applied to this case. |
| if (orig && knownRecomputeHeuristic.find(orig) != |
| knownRecomputeHeuristic.end()) { |
| if (!knownRecomputeHeuristic[orig]) { |
| if (mode == DerivativeMode::ReverseModeCombined) { |
| // Don't unnecessarily cache a value if the caching |
| // heuristic says we should preserve this precise (and not |
| // an lcssa wrapped) value |
| if (!isOriginalBlock(*BuilderM.GetInsertBlock())) { |
| Value *nval = inst; |
| if (scope) |
| nval = fixLCSSA(inst, scope); |
| if (nval == inst) |
| goto endCheck; |
| } |
| } else { |
|               // Note that this logic (original load must dominate or |
|               // alternatively be in the reverse block) is only valid if, when |
|               // applicable (here, in split mode), an uncacheable load cannot |
|               // be hoisted outside of a loop to be used as a loop limit. That |
|               // optimization is currently only done in the combined mode |
|               // (e.g. if a load isn't modified between a prior insertion |
|               // point and the actual load, it is legal to recompute). |
| if (!isOriginalBlock(*BuilderM.GetInsertBlock()) || |
| DT.dominates(inst, &*BuilderM.GetInsertPoint())) { |
| assert(inst->getParent()->getParent() == newFunc); |
| auto placeholder = BuilderM.CreatePHI( |
| val->getType(), 0, |
| val->getName() + "_krcAFUWLreplacement"); |
| unwrappedLoads[placeholder] = inst; |
| SmallVector<Metadata *, 1> avail; |
| for (auto pair : available) |
| if (pair.second) |
| avail.push_back( |
| MDNode::get(placeholder->getContext(), |
| {ValueAsMetadata::get( |
| const_cast<Value *>(pair.first)), |
| ValueAsMetadata::get(pair.second)})); |
| placeholder->setMetadata( |
| "enzyme_available", |
| MDNode::get(placeholder->getContext(), avail)); |
| if (!permitCache) |
| return placeholder; |
| return unwrap_cache[BuilderM.GetInsertBlock()][idx.first] |
| [idx.second] = placeholder; |
| } |
| } |
| } |
| } |
| } else if (unwrapMode != UnwrapMode::LegalFullUnwrapNoTapeReplace && |
| mode != DerivativeMode::ReverseModeCombined) { |
|           // TODO: this isOriginal call is a bottleneck; the knownRecompute |
|           // mapping should be precomputed and maintained for lookup instead. |
| |
|           // If a given value has been chosen to be cached, do not compute the |
|           // operands to unwrap it if it is not legal to do so. This prevents |
|           // the creation of unused versions of the instruction's operands, |
|           // which may be assumed to never be used and thus cause an error |
|           // when they are inadvertently cached. |
| Value *orig = isOriginal(val); |
| if (orig && knownRecomputeHeuristic.find(orig) != |
| knownRecomputeHeuristic.end()) { |
| if (!knownRecomputeHeuristic[orig]) { |
| if (!legalRecompute(orig, available, &BuilderM)) |
| return nullptr; |
| |
| assert(isa<LoadInst>(orig) == isa<LoadInst>(val)); |
| } |
| } |
| } |
| } |
| } |
| |
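| // Helper macros for unwrapping an operand of `val`. getOpFullest either |
| // recurses into unwrapM or, in lookup-capable modes, falls back to lookupM |
| // after fixing LCSSA form for values that may be loop-carried; getOpFull and |
| // getOp default the scope to the parent block of `val`, while getOpUnchecked |
| // passes the scope through unchanged. |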
| #define getOpFullest(Builder, vtmp, frominst, lookupInst, check) \ |
| ({ \ |
| Value *v = vtmp; \ |
| BasicBlock *origParent = frominst; \ |
| Value *___res; \ |
| if (unwrapMode == UnwrapMode::LegalFullUnwrap || \ |
| unwrapMode == UnwrapMode::LegalFullUnwrapNoTapeReplace || \ |
| unwrapMode == UnwrapMode::AttemptFullUnwrap || \ |
| unwrapMode == UnwrapMode::AttemptFullUnwrapWithLookup) { \ |
| if (v == val) \ |
| ___res = nullptr; \ |
| else \ |
| ___res = unwrapM(v, Builder, available, unwrapMode, origParent, \ |
| permitCache); \ |
| if (!___res && unwrapMode == UnwrapMode::AttemptFullUnwrapWithLookup) { \ |
| bool noLookup = false; \ |
| auto found = available.find(v); \ |
| if (found != available.end() && !found->second) \ |
| noLookup = true; \ |
| if (auto opinst = dyn_cast<Instruction>(v)) \ |
| if (isOriginalBlock(*Builder.GetInsertBlock())) { \ |
| if (!DT.dominates(opinst, &*Builder.GetInsertPoint())) \ |
| noLookup = true; \ |
| } \ |
| origParent = lookupInst; \ |
| if (BasicBlock *forwardBlock = origParent) \ |
| if (auto opinst = dyn_cast<Instruction>(v)) { \ |
| if (!isOriginalBlock(*forwardBlock)) { \ |
| forwardBlock = originalForReverseBlock(*forwardBlock); \ |
| } \ |
| if (isPotentialLastLoopValue(opinst, forwardBlock, LI)) { \ |
| v = fixLCSSA(opinst, forwardBlock); \ |
| origParent = nullptr; \ |
| } \ |
| } \ |
| if (!noLookup) \ |
| ___res = lookupM(v, Builder, available, v != val, origParent); \ |
| } \ |
| if (___res) \ |
| assert(___res->getType() == v->getType() && "uw"); \ |
| } else { \ |
| origParent = lookupInst; \ |
| if (BasicBlock *forwardBlock = origParent) \ |
| if (auto opinst = dyn_cast<Instruction>(v)) { \ |
| if (!isOriginalBlock(*forwardBlock)) { \ |
| forwardBlock = originalForReverseBlock(*forwardBlock); \ |
| } \ |
| if (isPotentialLastLoopValue(opinst, forwardBlock, LI)) { \ |
| v = fixLCSSA(opinst, forwardBlock); \ |
| origParent = nullptr; \ |
| } \ |
| } \ |
| assert(unwrapMode == UnwrapMode::AttemptSingleUnwrap); \ |
| auto found = available.find(v); \ |
| assert(found == available.end() || found->second); \ |
| ___res = lookupM(v, Builder, available, v != val, origParent); \ |
| if (___res && ___res->getType() != v->getType()) { \ |
| llvm::errs() << *newFunc << "\n"; \ |
| llvm::errs() << " v = " << *v << " res = " << *___res << "\n"; \ |
| } \ |
| if (___res) \ |
| assert(___res->getType() == v->getType() && "lu"); \ |
| } \ |
| ___res; \ |
| }) |
| #define getOpFull(Builder, vtmp, frominst) \ |
| ({\ |
| BasicBlock *parent = scope; \ |
| if (parent == nullptr) \ |
| if (auto originst = dyn_cast<Instruction>(val)) \ |
| parent = originst->getParent(); \ |
| getOpFullest(Builder, vtmp, frominst, parent, true);\ |
| }) |
| #define getOpUnchecked(vtmp) \ |
| ({ \ |
| BasicBlock *parent = scope; \ |
| getOpFullest(BuilderM, vtmp, parent, parent, false); \ |
| }) |
| #define getOp(vtmp) \ |
| ({ \ |
| BasicBlock *parent = scope; \ |
| if (parent == nullptr) \ |
| if (auto originst = dyn_cast<Instruction>(val)) \ |
| parent = originst->getParent(); \ |
| getOpFullest(BuilderM, vtmp, parent, parent, true); \ |
| }) |
| |
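|   // Recreate `val` at the new insertion point by unwrapping each operand and |
|   // re-emitting an equivalent instruction, copying IR flags and (where still |
|   // meaningful) debug locations, and caching the result when permitted. |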
| if (isa<Argument>(val) || isa<Constant>(val)) { |
| return val; |
| #if LLVM_VERSION_MAJOR >= 10 |
| } else if (auto op = dyn_cast<FreezeInst>(val)) { |
| auto op0 = getOp(op->getOperand(0)); |
| if (op0 == nullptr) |
| goto endCheck; |
| auto toreturn = BuilderM.CreateFreeze(op0, op->getName() + "_unwrap"); |
| if (auto newi = dyn_cast<Instruction>(toreturn)) { |
| newi->copyIRFlags(op); |
| unwrappedLoads[newi] = val; |
| } |
| if (permitCache) |
| unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; |
| assert(val->getType() == toreturn->getType()); |
| return toreturn; |
| #endif |
| } else if (auto op = dyn_cast<CastInst>(val)) { |
| auto op0 = getOp(op->getOperand(0)); |
| if (op0 == nullptr) |
| goto endCheck; |
| auto toreturn = BuilderM.CreateCast(op->getOpcode(), op0, op->getDestTy(), |
| op->getName() + "_unwrap"); |
| if (auto newi = dyn_cast<Instruction>(toreturn)) { |
| newi->copyIRFlags(op); |
| unwrappedLoads[newi] = val; |
| if (newi->getParent()->getParent() != op->getParent()->getParent()) |
| newi->setDebugLoc(nullptr); |
| } |
| if (permitCache) |
| unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; |
| assert(val->getType() == toreturn->getType()); |
| return toreturn; |
| } else if (auto op = dyn_cast<ExtractValueInst>(val)) { |
| auto op0 = getOp(op->getAggregateOperand()); |
| if (op0 == nullptr) |
| goto endCheck; |
| auto toreturn = BuilderM.CreateExtractValue(op0, op->getIndices(), |
| op->getName() + "_unwrap"); |
| if (permitCache) |
| unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; |
| if (auto newi = dyn_cast<Instruction>(toreturn)) { |
| newi->copyIRFlags(op); |
| unwrappedLoads[newi] = val; |
| if (newi->getParent()->getParent() != op->getParent()->getParent()) |
| newi->setDebugLoc(nullptr); |
| } |
| assert(val->getType() == toreturn->getType()); |
| return toreturn; |
| } else if (auto op = dyn_cast<InsertValueInst>(val)) { |
| // Unwrapped Aggregate, Indices, parent |
| SmallVector<std::tuple<Value *, ArrayRef<unsigned>, InsertValueInst *>, 1> |
| insertElements; |
| |
| Value *agg = op; |
| while (auto op1 = dyn_cast<InsertValueInst>(agg)) { |
| if (Value *orig = isOriginal(op1)) { |
| if (knownRecomputeHeuristic.count(orig)) { |
| if (!knownRecomputeHeuristic[orig]) { |
| break; |
| } |
| } |
| } |
| Value *valOp = op1->getInsertedValueOperand(); |
| valOp = getOp(valOp); |
| if (valOp == nullptr) |
| goto endCheck; |
| insertElements.push_back({valOp, op1->getIndices(), op1}); |
| agg = op1->getAggregateOperand(); |
| } |
| |
| Value *toreturn = getOp(agg); |
| if (toreturn == nullptr) |
| goto endCheck; |
| for (auto &&[valOp, idcs, parent] : reverse(insertElements)) { |
| toreturn = BuilderM.CreateInsertValue(toreturn, valOp, idcs, |
| parent->getName() + "_unwrap"); |
| |
| if (permitCache) |
| unwrap_cache[BuilderM.GetInsertBlock()][parent][idx.second] = toreturn; |
| if (auto newi = dyn_cast<Instruction>(toreturn)) { |
| newi->copyIRFlags(parent); |
| unwrappedLoads[newi] = val; |
| if (newi->getParent()->getParent() != parent->getParent()->getParent()) |
| newi->setDebugLoc(nullptr); |
| } |
| } |
| |
| assert(val->getType() == toreturn->getType()); |
| return toreturn; |
| } else if (auto op = dyn_cast<ExtractElementInst>(val)) { |
| auto op0 = getOp(op->getOperand(0)); |
| if (op0 == nullptr) |
| goto endCheck; |
| auto op1 = getOp(op->getOperand(1)); |
| if (op1 == nullptr) |
| goto endCheck; |
| auto toreturn = |
| BuilderM.CreateExtractElement(op0, op1, op->getName() + "_unwrap"); |
| if (permitCache) |
| unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; |
| if (auto newi = dyn_cast<Instruction>(toreturn)) { |
| newi->copyIRFlags(op); |
| unwrappedLoads[newi] = val; |
| if (newi->getParent()->getParent() != op->getParent()->getParent()) |
| newi->setDebugLoc(nullptr); |
| } |
| assert(val->getType() == toreturn->getType()); |
| return toreturn; |
| } else if (auto op = dyn_cast<InsertElementInst>(val)) { |
| auto op0 = getOp(op->getOperand(0)); |
| if (op0 == nullptr) |
| goto endCheck; |
| auto op1 = getOp(op->getOperand(1)); |
| if (op1 == nullptr) |
| goto endCheck; |
| auto op2 = getOp(op->getOperand(2)); |
| if (op2 == nullptr) |
| goto endCheck; |
| auto toreturn = |
| BuilderM.CreateInsertElement(op0, op1, op2, op->getName() + "_unwrap"); |
| if (permitCache) |
| unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; |
| if (auto newi = dyn_cast<Instruction>(toreturn)) { |
| newi->copyIRFlags(op); |
| unwrappedLoads[newi] = val; |
| if (newi->getParent()->getParent() != op->getParent()->getParent()) |
| newi->setDebugLoc(nullptr); |
| } |
| assert(val->getType() == toreturn->getType()); |
| return toreturn; |
| } else if (auto op = dyn_cast<ShuffleVectorInst>(val)) { |
| auto op0 = getOp(op->getOperand(0)); |
| if (op0 == nullptr) |
| goto endCheck; |
| auto op1 = getOp(op->getOperand(1)); |
| if (op1 == nullptr) |
| goto endCheck; |
| #if LLVM_VERSION_MAJOR >= 11 |
| auto toreturn = BuilderM.CreateShuffleVector( |
| op0, op1, op->getShuffleMaskForBitcode(), op->getName() + "'_unwrap"); |
| #else |
| auto toreturn = BuilderM.CreateShuffleVector(op0, op1, op->getOperand(2), |
| op->getName() + "'_unwrap"); |
| #endif |
| if (permitCache) |
| unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; |
| if (auto newi = dyn_cast<Instruction>(toreturn)) { |
| newi->copyIRFlags(op); |
| unwrappedLoads[newi] = val; |
| if (newi->getParent()->getParent() != op->getParent()->getParent()) |
| newi->setDebugLoc(nullptr); |
| } |
| assert(val->getType() == toreturn->getType()); |
| return toreturn; |
| } else if (auto op = dyn_cast<BinaryOperator>(val)) { |
| auto op0 = getOp(op->getOperand(0)); |
| if (op0 == nullptr) |
| goto endCheck; |
| auto op1 = getOp(op->getOperand(1)); |
| if (op1 == nullptr) |
| goto endCheck; |
| if (op0->getType() != op1->getType()) { |
| llvm::errs() << " op: " << *op << " op0: " << *op0 << " op1: " << *op1 |
| << " p0: " << *op->getOperand(0) |
| << " p1: " << *op->getOperand(1) << "\n"; |
| } |
| assert(op0->getType() == op1->getType()); |
| auto toreturn = BuilderM.CreateBinOp(op->getOpcode(), op0, op1, |
| op->getName() + "_unwrap"); |
| if (auto newi = dyn_cast<Instruction>(toreturn)) { |
| newi->copyIRFlags(op); |
| unwrappedLoads[newi] = val; |
| if (newi->getParent()->getParent() != op->getParent()->getParent()) |
| newi->setDebugLoc(nullptr); |
| } |
| if (permitCache) |
| unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; |
| assert(val->getType() == toreturn->getType()); |
| return toreturn; |
| } else if (auto op = dyn_cast<ICmpInst>(val)) { |
| auto op0 = getOp(op->getOperand(0)); |
| if (op0 == nullptr) |
| goto endCheck; |
| auto op1 = getOp(op->getOperand(1)); |
| if (op1 == nullptr) |
| goto endCheck; |
| auto toreturn = BuilderM.CreateICmp(op->getPredicate(), op0, op1, |
| op->getName() + "_unwrap"); |
| if (auto newi = dyn_cast<Instruction>(toreturn)) { |
| newi->copyIRFlags(op); |
| unwrappedLoads[newi] = val; |
| if (newi->getParent()->getParent() != op->getParent()->getParent()) |
| newi->setDebugLoc(nullptr); |
| } |
| if (permitCache) |
| unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; |
| assert(val->getType() == toreturn->getType()); |
| return toreturn; |
| } else if (auto op = dyn_cast<FCmpInst>(val)) { |
| auto op0 = getOp(op->getOperand(0)); |
| if (op0 == nullptr) |
| goto endCheck; |
| auto op1 = getOp(op->getOperand(1)); |
| if (op1 == nullptr) |
| goto endCheck; |
| auto toreturn = BuilderM.CreateFCmp(op->getPredicate(), op0, op1, |
| op->getName() + "_unwrap"); |
| if (auto newi = dyn_cast<Instruction>(toreturn)) { |
| newi->copyIRFlags(op); |
| unwrappedLoads[newi] = val; |
| if (newi->getParent()->getParent() != op->getParent()->getParent()) |
| newi->setDebugLoc(nullptr); |
| } |
| if (permitCache) |
| unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; |
| assert(val->getType() == toreturn->getType()); |
| return toreturn; |
| #if LLVM_VERSION_MAJOR >= 9 |
| } else if (isa<FPMathOperator>(val) && |
| cast<FPMathOperator>(val)->getOpcode() == Instruction::FNeg) { |
| auto op = cast<FPMathOperator>(val); |
| auto op0 = getOp(op->getOperand(0)); |
| if (op0 == nullptr) |
| goto endCheck; |
| auto toreturn = BuilderM.CreateFNeg(op0, op->getName() + "_unwrap"); |
| if (auto newi = dyn_cast<Instruction>(toreturn)) { |
| newi->copyIRFlags(op); |
| unwrappedLoads[newi] = val; |
| if (newi->getParent()->getParent() != |
| cast<Instruction>(val)->getParent()->getParent()) |
| newi->setDebugLoc(nullptr); |
| } |
| if (permitCache) |
| unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; |
| assert(val->getType() == toreturn->getType()); |
| return toreturn; |
| #endif |
| } else if (auto op = dyn_cast<SelectInst>(val)) { |
| auto op0 = getOp(op->getOperand(0)); |
| if (op0 == nullptr) |
| goto endCheck; |
| auto op1 = getOp(op->getOperand(1)); |
| if (op1 == nullptr) |
| goto endCheck; |
| auto op2 = getOp(op->getOperand(2)); |
| if (op2 == nullptr) |
| goto endCheck; |
| auto toreturn = |
| BuilderM.CreateSelect(op0, op1, op2, op->getName() + "_unwrap"); |
| if (auto newi = dyn_cast<Instruction>(toreturn)) { |
| newi->copyIRFlags(op); |
| unwrappedLoads[newi] = val; |
| if (newi->getParent()->getParent() != op->getParent()->getParent()) |
| newi->setDebugLoc(nullptr); |
| } |
| if (permitCache) |
| unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; |
| assert(val->getType() == toreturn->getType()); |
| return toreturn; |
| } else if (auto inst = dyn_cast<GetElementPtrInst>(val)) { |
| auto ptr = getOp(inst->getPointerOperand()); |
| if (ptr == nullptr) |
| goto endCheck; |
| SmallVector<Value *, 4> ind; |
| // llvm::errs() << "inst: " << *inst << "\n"; |
| for (unsigned i = 0; i < inst->getNumIndices(); ++i) { |
| Value *a = inst->getOperand(1 + i); |
| auto op = getOp(a); |
| if (op == nullptr) |
| goto endCheck; |
| ind.push_back(op); |
| } |
| #if LLVM_VERSION_MAJOR > 7 |
| auto toreturn = BuilderM.CreateGEP(inst->getSourceElementType(), ptr, ind, |
| inst->getName() + "_unwrap"); |
| #else |
| auto toreturn = BuilderM.CreateGEP(ptr, ind, inst->getName() + "_unwrap"); |
| #endif |
| if (isa<GetElementPtrInst>(toreturn)) |
| cast<GetElementPtrInst>(toreturn)->setIsInBounds(inst->isInBounds()); |
| if (auto newi = dyn_cast<Instruction>(toreturn)) { |
| newi->copyIRFlags(inst); |
| unwrappedLoads[newi] = val; |
| if (newi->getParent()->getParent() != inst->getParent()->getParent()) |
| newi->setDebugLoc(nullptr); |
| } |
| if (permitCache) |
| unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; |
| assert(val->getType() == toreturn->getType()); |
| return toreturn; |
| } else if (auto load = dyn_cast<LoadInst>(val)) { |
| if (load->getMetadata("enzyme_noneedunwrap")) |
| return load; |
| |
| bool legalMove = unwrapMode == UnwrapMode::LegalFullUnwrap || |
| unwrapMode == UnwrapMode::LegalFullUnwrapNoTapeReplace; |
| if (!legalMove) { |
| BasicBlock *parent = nullptr; |
| if (isOriginalBlock(*BuilderM.GetInsertBlock())) |
| parent = BuilderM.GetInsertBlock(); |
| if (!parent || |
| LI.getLoopFor(parent) == LI.getLoopFor(load->getParent()) || |
| DT.dominates(load, parent)) { |
| legalMove = legalRecompute(load, available, &BuilderM); |
| } else { |
| legalMove = |
| legalRecompute(load, available, &BuilderM, /*reverse*/ false, |
| /*legalRecomputeCache*/ false); |
| } |
| } |
| if (!legalMove) { |
| auto &warnMap = UnwrappedWarnings[load]; |
| if (!warnMap.count(BuilderM.GetInsertBlock())) { |
| EmitWarning("UncacheableUnwrap", *load, "Load cannot be unwrapped ", |
| *load, " in ", BuilderM.GetInsertBlock()->getName(), " - ", |
| BuilderM.GetInsertBlock()->getParent()->getName(), " mode ", |
| unwrapMode); |
| warnMap.insert(BuilderM.GetInsertBlock()); |
| } |
| goto endCheck; |
| } |
| |
| Value *pidx = getOp(load->getOperand(0)); |
| |
| if (pidx == nullptr) { |
| goto endCheck; |
| } |
| |
| if (pidx->getType() != load->getOperand(0)->getType()) { |
| llvm::errs() << "load: " << *load << "\n"; |
| llvm::errs() << "load->getOperand(0): " << *load->getOperand(0) << "\n"; |
| llvm::errs() << "idx: " << *pidx << " unwrapping: " << *val |
| << " mode=" << unwrapMode << "\n"; |
| } |
| assert(pidx->getType() == load->getOperand(0)->getType()); |
| |
| #if LLVM_VERSION_MAJOR > 7 |
| auto toreturn = |
| BuilderM.CreateLoad(load->getType(), pidx, load->getName() + "_unwrap"); |
| #else |
| auto toreturn = BuilderM.CreateLoad(pidx, load->getName() + "_unwrap"); |
| #endif |
| llvm::SmallVector<unsigned int, 9> ToCopy2(MD_ToCopy); |
| ToCopy2.push_back(LLVMContext::MD_noalias); |
| ToCopy2.push_back(LLVMContext::MD_alias_scope); |
| toreturn->copyMetadata(*load, ToCopy2); |
| toreturn->copyIRFlags(load); |
| unwrappedLoads[toreturn] = load; |
| if (toreturn->getParent()->getParent() != load->getParent()->getParent()) |
| toreturn->setDebugLoc(nullptr); |
| else |
| toreturn->setDebugLoc(getNewFromOriginal(load->getDebugLoc())); |
| #if LLVM_VERSION_MAJOR >= 10 |
| toreturn->setAlignment(load->getAlign()); |
| #else |
| toreturn->setAlignment(load->getAlignment()); |
| #endif |
| toreturn->setVolatile(load->isVolatile()); |
| toreturn->setOrdering(load->getOrdering()); |
| toreturn->setSyncScopeID(load->getSyncScopeID()); |
| if (toreturn->getParent()->getParent() != load->getParent()->getParent()) |
| toreturn->setDebugLoc(nullptr); |
| else |
| toreturn->setDebugLoc(getNewFromOriginal(load->getDebugLoc())); |
| toreturn->setMetadata(LLVMContext::MD_tbaa, |
| load->getMetadata(LLVMContext::MD_tbaa)); |
| toreturn->setMetadata(LLVMContext::MD_invariant_group, |
| load->getMetadata(LLVMContext::MD_invariant_group)); |
| // TODO adding to cache only legal if no alias of any future writes |
| if (permitCache) |
| unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; |
| assert(val->getType() == toreturn->getType()); |
| return toreturn; |
| } else if (auto op = dyn_cast<CallInst>(val)) { |
| |
| bool legalMove = unwrapMode == UnwrapMode::LegalFullUnwrap || |
| unwrapMode == UnwrapMode::LegalFullUnwrapNoTapeReplace; |
| if (!legalMove) { |
| legalMove = legalRecompute(op, available, &BuilderM); |
| } |
| if (!legalMove) |
| goto endCheck; |
| |
| SmallVector<Value *, 4> args; |
| #if LLVM_VERSION_MAJOR >= 14 |
| for (unsigned i = 0; i < op->arg_size(); ++i) |
| #else |
| for (unsigned i = 0; i < op->getNumArgOperands(); ++i) |
| #endif |
| { |
| args.emplace_back(getOp(op->getArgOperand(i))); |
| if (args[i] == nullptr) |
| goto endCheck; |
| } |
| |
| #if LLVM_VERSION_MAJOR >= 11 |
| Value *fn = getOp(op->getCalledOperand()); |
| #else |
| Value *fn = getOp(op->getCalledValue()); |
| #endif |
| if (fn == nullptr) |
| goto endCheck; |
| |
| auto toreturn = |
| cast<CallInst>(BuilderM.CreateCall(op->getFunctionType(), fn, args)); |
| toreturn->copyIRFlags(op); |
| toreturn->setAttributes(op->getAttributes()); |
| toreturn->setCallingConv(op->getCallingConv()); |
| toreturn->setTailCallKind(op->getTailCallKind()); |
| if (toreturn->getParent()->getParent() == op->getParent()->getParent()) |
| toreturn->setDebugLoc(getNewFromOriginal(op->getDebugLoc())); |
| else |
| toreturn->setDebugLoc(nullptr); |
| if (permitCache) |
| unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn; |
| unwrappedLoads[toreturn] = val; |
| return toreturn; |
| } else if (auto phi = dyn_cast<PHINode>(val)) { |
| if (phi->getNumIncomingValues() == 0) { |
|       // This is a placeholder shadow for a load. Rather than falling back |
|       // to the uncached variant, use the proper procedure for an inverted |
|       // load. |
| if (auto dli = dyn_cast_or_null<LoadInst>(hasUninverted(phi))) { |
| // Almost identical code to unwrap load (replacing use of shadow |
| // where appropriate) |
| if (dli->getMetadata("enzyme_noneedunwrap")) |
| return dli; |
| |
| bool legalMove = unwrapMode == UnwrapMode::LegalFullUnwrap || |
| unwrapMode == UnwrapMode::LegalFullUnwrapNoTapeReplace; |
| if (!legalMove) { |
| // TODO actually consider whether this is legal to move to the new |
| // location, rather than recomputable anywhere |
| legalMove = legalRecompute(dli, available, &BuilderM); |
| } |
| if (!legalMove) { |
| auto &warnMap = UnwrappedWarnings[phi]; |
| if (!warnMap.count(BuilderM.GetInsertBlock())) { |
| EmitWarning("UncacheableUnwrap", *dli, |
| "Differential Load cannot be unwrapped ", *dli, " in ", |
| BuilderM.GetInsertBlock()->getName(), " mode ", |
| unwrapMode); |
| warnMap.insert(BuilderM.GetInsertBlock()); |
| } |
| return nullptr; |
| } |
| |
| Value *pidx = nullptr; |
| |
| if (isOriginalBlock(*BuilderM.GetInsertBlock())) { |
| pidx = invertPointerM(dli->getOperand(0), BuilderM); |
| } else { |
| pidx = lookupM(invertPointerM(dli->getOperand(0), BuilderM), BuilderM, |
| available); |
| } |
| |
| if (pidx == nullptr) |
| goto endCheck; |
| |
| if (pidx->getType() != getShadowType(dli->getOperand(0)->getType())) { |
| llvm::errs() << "dli: " << *dli << "\n"; |
| llvm::errs() << "dli->getOperand(0): " << *dli->getOperand(0) << "\n"; |
| llvm::errs() << "pidx: " << *pidx << "\n"; |
| } |
| assert(pidx->getType() == getShadowType(dli->getOperand(0)->getType())); |
| |
| Value *toreturn = applyChainRule( |
| dli->getType(), BuilderM, |
| [&](Value *pidx) { |
| #if LLVM_VERSION_MAJOR > 7 |
| auto toreturn = BuilderM.CreateLoad(dli->getType(), pidx, |
| phi->getName() + "_unwrap"); |
| #else |
| auto toreturn = |
| BuilderM.CreateLoad(pidx, phi->getName() + "_unwrap"); |
| #endif |
| if (auto newi = dyn_cast<Instruction>(toreturn)) { |
| newi->copyIRFlags(dli); |
| unwrappedLoads[toreturn] = dli; |
| } |
| #if LLVM_VERSION_MAJOR >= 10 |
| toreturn->setAlignment(dli->getAlign()); |
| #else |
| toreturn->setAlignment(dli->getAlignment()); |
| #endif |
| toreturn->setVolatile(dli->isVolatile()); |
| toreturn->setOrdering(dli->getOrdering()); |
| toreturn->setSyncScopeID(dli->getSyncScopeID()); |
| llvm::SmallVector<unsigned int, 9> ToCopy2(MD_ToCopy); |
| ToCopy2.push_back(LLVMContext::MD_noalias); |
| toreturn->copyMetadata(*dli, ToCopy2); |
| toreturn->setDebugLoc(getNewFromOriginal(dli->getDebugLoc())); |
| return toreturn; |
| }, |
| pidx); |
| |
| // TODO adding to cache only legal if no alias of any future writes |
| if (permitCache) |
| unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = |
| toreturn; |
| assert(val->getType() == toreturn->getType()); |
| return toreturn; |
| } |
| goto endCheck; |
| } |
| assert(phi->getNumIncomingValues() != 0); |
| |
|     // If this phi is the loop induction variable and the total loop bound is |
|     // being requested, use the existing exact loop bound rather than |
|     // generating a new lcssa variable. |
| BasicBlock *ivctx = scope; |
| if (!ivctx) |
| ivctx = BuilderM.GetInsertBlock(); |
| if (newFunc == ivctx->getParent() && !isOriginalBlock(*ivctx)) { |
| ivctx = originalForReverseBlock(*ivctx); |
| } |
| if ((ivctx == phi->getParent() || DT.dominates(phi, ivctx)) && |
| (!isOriginalBlock(*BuilderM.GetInsertBlock()) || |
| DT.dominates(phi, &*BuilderM.GetInsertPoint()))) { |
| LoopContext lc; |
| bool loopVar = false; |
| if (getContext(phi->getParent(), lc) && lc.var == phi) { |
| loopVar = true; |
| } else { |
| Value *V = nullptr; |
| bool legal = true; |
| for (auto &val : phi->incoming_values()) { |
| if (isa<UndefValue>(val)) |
| continue; |
| if (V == nullptr) |
| V = val; |
| else if (V != val) { |
| legal = false; |
| break; |
| } |
| } |
| if (legal) { |
| if (auto I = dyn_cast_or_null<PHINode>(V)) { |
| if (getContext(I->getParent(), lc) && lc.var == I) { |
| loopVar = true; |
| } |
| } |
| } |
| } |
| if (loopVar) { |
| if (!lc.dynamic) { |
| Value *lim = getOp(lc.trueLimit); |
| if (lim) { |
| unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = |
| lim; |
| return lim; |
| } |
| } else if (unwrapMode == UnwrapMode::AttemptFullUnwrapWithLookup && |
| reverseBlocks.size() > 0) { |
|           // Must be in the reverse pass for a lookup of the loop bound to be |
|           // legal. |
| assert(/*ReverseLimit*/ reverseBlocks.size() > 0); |
| LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, |
| lc.preheader); |
| Value *lim = lookupValueFromCache( |
| /*forwardPass*/ false, BuilderM, lctx, |
| getDynamicLoopLimit(LI.getLoopFor(lc.header)), |
| /*isi1*/ false, available); |
| unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = lim; |
| return lim; |
| } |
| } |
| } |
| |
| auto parent = phi->getParent(); |
| |
| // Don't attempt to unroll a loop induction variable in other |
| // circumstances |
| auto &LLI = Logic.PPC.FAM.getResult<LoopAnalysis>(*parent->getParent()); |
| std::set<BasicBlock *> prevIteration; |
| if (LLI.isLoopHeader(parent)) { |
| if (phi->getNumIncomingValues() != 2) { |
| assert(unwrapMode != UnwrapMode::LegalFullUnwrap); |
| goto endCheck; |
| } |
| auto L = LLI.getLoopFor(parent); |
| for (auto PH : predecessors(parent)) { |
| if (L->contains(PH)) |
| prevIteration.insert(PH); |
| } |
| if (prevIteration.size() && !legalRecompute(phi, available, &BuilderM)) { |
| assert(unwrapMode != UnwrapMode::LegalFullUnwrap); |
| goto endCheck; |
| } |
| } |
| for (auto &val : phi->incoming_values()) { |
| if (isPotentialLastLoopValue(val, parent, LLI)) { |
| assert(unwrapMode != UnwrapMode::LegalFullUnwrap); |
| goto endCheck; |
| } |
| } |
| |
| if (phi->getNumIncomingValues() == 1) { |
| assert(phi->getIncomingValue(0) != phi); |
| auto toreturn = getOpUnchecked(phi->getIncomingValue(0)); |
| if (toreturn == nullptr || toreturn == phi) { |
| assert(unwrapMode != UnwrapMode::LegalFullUnwrap); |
| goto endCheck; |
| } |
| assert(val->getType() == toreturn->getType()); |
| return toreturn; |
| } |
| |
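|     // Walk backwards from the phi's incoming edges. For every traversed CFG |
|     // edge (pred, successor), `done` records which of the phi's original |
|     // incoming blocks remain reachable along that edge; a later pass uses |
|     // this to find a dominating branch whose outcome determines the incoming |
|     // value. |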
| std::set<BasicBlock *> targetToPreds; |
|     // Map from CFG edges to the set of the phi's incoming blocks reachable |
|     // through that edge |
| std::map<std::pair</*pred*/ BasicBlock *, /*successor*/ BasicBlock *>, |
| std::set<BasicBlock *>> |
| done; |
| { |
| std::deque<std::tuple< |
| std::pair</*pred*/ BasicBlock *, /*successor*/ BasicBlock *>, |
| BasicBlock *>> |
| Q; // newblock, target |
| |
| for (unsigned i = 0; i < phi->getNumIncomingValues(); ++i) { |
| Q.push_back( |
| std::make_pair(std::make_pair(phi->getIncomingBlock(i), parent), |
| phi->getIncomingBlock(i))); |
| targetToPreds.insert(phi->getIncomingBlock(i)); |
| } |
| |
| for (std::tuple< |
| std::pair</*pred*/ BasicBlock *, /*successor*/ BasicBlock *>, |
| BasicBlock *> |
| trace; |
| Q.size() > 0;) { |
| trace = Q.front(); |
| Q.pop_front(); |
| auto edge = std::get<0>(trace); |
| auto block = edge.first; |
| auto target = std::get<1>(trace); |
| |
| if (done[edge].count(target)) |
| continue; |
| done[edge].insert(target); |
| |
| if (DT.dominates(block, phi->getParent())) |
| continue; |
| |
| Loop *blockLoop = LI.getLoopFor(block); |
| |
| for (BasicBlock *Pred : predecessors(block)) { |
| // Don't go up the backedge as we can use the last value if desired |
| // via lcssa |
| if (blockLoop && blockLoop->getHeader() == block && |
| blockLoop == LI.getLoopFor(Pred)) |
| continue; |
| |
| Q.push_back( |
| std::tuple<std::pair<BasicBlock *, BasicBlock *>, BasicBlock *>( |
| std::make_pair(Pred, block), target)); |
| } |
| } |
| } |
| |
| std::set<BasicBlock *> blocks; |
| for (auto pair : done) { |
| const auto &edge = pair.first; |
| blocks.insert(edge.first); |
| } |
| |
| BasicBlock *oldB = BuilderM.GetInsertBlock(); |
| if (BuilderM.GetInsertPoint() != oldB->end()) { |
| assert(unwrapMode != UnwrapMode::LegalFullUnwrap); |
| goto endCheck; |
| } |
| |
| BasicBlock *fwd = oldB; |
| bool inReverseBlocks = false; |
| if (!isOriginalBlock(*fwd)) { |
| auto found = reverseBlockToPrimal.find(oldB); |
| if (found == reverseBlockToPrimal.end()) { |
| assert(unwrapMode != UnwrapMode::LegalFullUnwrap); |
| goto endCheck; |
| } |
| fwd = found->second; |
| inReverseBlocks = |
| std::find(reverseBlocks[fwd].begin(), reverseBlocks[fwd].end(), |
| oldB) != reverseBlocks[fwd].end(); |
| } |
| |
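|     // Helper to discard partially built recompute blocks if unwrapping one |
|     // of the incoming values fails: erase their instructions in reverse |
|     // topological order and purge them from the unwrap/lookup caches and the |
|     // reverse-block bookkeeping. |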
| auto eraseBlocks = [&](ArrayRef<BasicBlock *> blocks, BasicBlock *bret) { |
| SmallVector<BasicBlock *, 2> revtopo; |
| { |
| SmallPtrSet<BasicBlock *, 2> seen; |
| std::function<void(BasicBlock *)> dfs = [&](BasicBlock *B) { |
| if (seen.count(B)) |
| return; |
| seen.insert(B); |
| if (B->getTerminator()) |
| for (auto S : successors(B)) |
| if (!seen.count(S)) |
| dfs(S); |
| revtopo.push_back(B); |
| }; |
| for (auto B : blocks) |
| dfs(B); |
| if (!seen.count(bret)) |
| revtopo.insert(revtopo.begin(), bret); |
| } |
| |
| SmallVector<Instruction *, 4> toErase; |
| for (auto B : revtopo) { |
| if (B == bret) |
| continue; |
| for (auto &I : llvm::reverse(*B)) { |
| toErase.push_back(&I); |
| } |
| unwrap_cache.erase(B); |
| lookup_cache.erase(B); |
| if (reverseBlocks.size() > 0) { |
| auto tfwd = reverseBlockToPrimal[B]; |
| assert(tfwd); |
| auto rfound = reverseBlocks.find(tfwd); |
| assert(rfound != reverseBlocks.end()); |
| auto &tlst = rfound->second; |
| auto found = std::find(tlst.begin(), tlst.end(), B); |
| if (found != tlst.end()) |
| tlst.erase(found); |
| reverseBlockToPrimal.erase(B); |
| } |
| } |
| for (auto I : toErase) { |
| erase(I); |
| } |
| for (auto B : revtopo) |
| B->eraseFromParent(); |
| }; |
| |
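|     // Special case: three incoming targets realized by two nested |
|     // conditional branches. Look for an outer dominating block whose branch |
|     // isolates exactly one target and a second "subblock" that decides |
|     // between the remaining two; if found, replay both branches in fresh |
|     // blocks and merge the recomputed incoming values with a new phi. |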
| if (targetToPreds.size() == 3) { |
| for (auto block : blocks) { |
| if (!DT.dominates(block, phi->getParent())) |
| continue; |
| std::set<BasicBlock *> foundtargets; |
| std::set<BasicBlock *> uniqueTargets; |
| for (BasicBlock *succ : successors(block)) { |
| auto edge = std::make_pair(block, succ); |
| for (BasicBlock *target : done[edge]) { |
| if (foundtargets.find(target) != foundtargets.end()) { |
| goto rnextpair; |
| } |
| foundtargets.insert(target); |
| if (done[edge].size() == 1) |
| uniqueTargets.insert(target); |
| } |
| } |
| if (foundtargets.size() != 3) |
| goto rnextpair; |
| if (uniqueTargets.size() != 1) |
| goto rnextpair; |
| |
| { |
| BasicBlock *subblock = nullptr; |
| for (auto block2 : blocks) { |
| { |
|               // The second split block must not have a predecessor with an |
|               // edge to a block other than itself from which either of its |
|               // two targets can be reached. |
|               // TODO verify this |
| for (auto P : predecessors(block2)) { |
| for (auto S : successors(P)) { |
| if (S == block2) |
| continue; |
| auto edge = std::make_pair(P, S); |
| if (done.find(edge) != done.end()) { |
| for (auto target : done[edge]) { |
| if (foundtargets.find(target) != foundtargets.end() && |
| uniqueTargets.find(target) == uniqueTargets.end()) |
| goto nextblock; |
| } |
| } |
| } |
| } |
| std::set<BasicBlock *> seen2; |
| for (BasicBlock *succ : successors(block2)) { |
| auto edge = std::make_pair(block2, succ); |
| if (done[edge].size() != 1) { |
| // llvm::errs() << " -- failed from noonesize\n"; |
| goto nextblock; |
| } |
| for (BasicBlock *target : done[edge]) { |
| if (seen2.find(target) != seen2.end()) { |
| // llvm::errs() << " -- failed from not uniqueTargets\n"; |
| goto nextblock; |
| } |
| seen2.insert(target); |
| if (foundtargets.find(target) == foundtargets.end()) { |
| // llvm::errs() << " -- failed from not unknown target\n"; |
| goto nextblock; |
| } |
| if (uniqueTargets.find(target) != uniqueTargets.end()) { |
| // llvm::errs() << " -- failed from not same target\n"; |
| goto nextblock; |
| } |
| } |
| } |
| if (seen2.size() != 2) { |
| // llvm::errs() << " -- failed from not 2 seen\n"; |
| goto nextblock; |
| } |
| subblock = block2; |
| break; |
| } |
| nextblock:; |
| } |
| |
| if (subblock == nullptr) |
| goto rnextpair; |
| |
| { |
| auto bi1 = cast<BranchInst>(block->getTerminator()); |
| |
| auto cond1 = getOp(bi1->getCondition()); |
| if (cond1 == nullptr) { |
| assert(unwrapMode != UnwrapMode::LegalFullUnwrap); |
| goto endCheck; |
| } |
| auto bi2 = cast<BranchInst>(subblock->getTerminator()); |
| auto cond2 = getOp(bi2->getCondition()); |
| if (cond2 == nullptr) { |
| assert(unwrapMode != UnwrapMode::LegalFullUnwrap); |
| goto endCheck; |
| } |
| |
| SmallVector<BasicBlock *, 3> predBlocks = {bi2->getSuccessor(0), |
| bi2->getSuccessor(1)}; |
| for (int i = 0; i < 2; i++) { |
| auto edge = std::make_pair(block, bi1->getSuccessor(i)); |
| if (done[edge].size() == 1) { |
| predBlocks.push_back(bi1->getSuccessor(i)); |
| } |
| } |
| |
| SmallVector<Value *, 2> vals; |
| |
| SmallVector<BasicBlock *, 2> blocks; |
| SmallVector<BasicBlock *, 2> endingBlocks; |
| |
| BasicBlock *last = oldB; |
| |
| BasicBlock *bret = BasicBlock::Create( |
| val->getContext(), oldB->getName() + "_phimerge", newFunc); |
| |
| for (size_t i = 0; i < predBlocks.size(); i++) { |
| BasicBlock *valparent = (i < 2) ? subblock : block; |
| assert(done.find(std::make_pair(valparent, predBlocks[i])) != |
| done.end()); |
| assert(done[std::make_pair(valparent, predBlocks[i])].size() == |
| 1); |
| blocks.push_back(BasicBlock::Create( |
| val->getContext(), oldB->getName() + "_phirc", newFunc)); |
| blocks[i]->moveAfter(last); |
| last = blocks[i]; |
| if (inReverseBlocks) |
| reverseBlocks[fwd].push_back(blocks[i]); |
| reverseBlockToPrimal[blocks[i]] = fwd; |
| IRBuilder<> B(blocks[i]); |
| |
| for (auto pair : unwrap_cache[oldB]) |
| unwrap_cache[blocks[i]].insert(pair); |
| for (auto pair : lookup_cache[oldB]) |
| lookup_cache[blocks[i]].insert(pair); |
| auto PB = *done[std::make_pair(valparent, predBlocks[i])].begin(); |
| |
| if (auto inst = dyn_cast<Instruction>( |
| phi->getIncomingValueForBlock(PB))) { |
| // Recompute the phi computation with the conditional if: |
| // 1) the instruction may read from memory AND does not |
| // dominate the current insertion point (thereby |
| // potentially making such recomputation without the |
| // condition illegal) |
| // 2) the value is a call or load and option is set to not |
| // speculatively recompute values within a phi |
| // OR |
| // 3) the value comes from a previous iteration. |
| BasicBlock *nextScope = PB; |
| // if (inst->getParent() == nextScope) nextScope = |
| // phi->getParent(); |
| if (prevIteration.count(PB)) { |
| assert(0 && "tri block prev iteration unhandled"); |
| } else if (!DT.dominates(inst->getParent(), phi->getParent()) || |
| (!EnzymeSpeculatePHIs && |
| (isa<CallInst>(inst) || isa<LoadInst>(inst)))) |
| vals.push_back(getOpFull(B, inst, nextScope)); |
| else |
| vals.push_back(getOpFull(BuilderM, inst, nextScope)); |
| } else |
| vals.push_back( |
| getOpFull(BuilderM, phi->getIncomingValueForBlock(PB), PB)); |
| |
| if (!vals[i]) { |
| eraseBlocks(blocks, bret); |
| assert(unwrapMode != UnwrapMode::LegalFullUnwrap); |
| goto endCheck; |
| } |
| assert(val->getType() == vals[i]->getType()); |
| B.CreateBr(bret); |
| endingBlocks.push_back(B.GetInsertBlock()); |
| } |
| |
| bret->moveAfter(last); |
| |
| BasicBlock *bsplit = BasicBlock::Create( |
| val->getContext(), oldB->getName() + "_phisplt", newFunc); |
| bsplit->moveAfter(oldB); |
| if (inReverseBlocks) |
| reverseBlocks[fwd].push_back(bsplit); |
| reverseBlockToPrimal[bsplit] = fwd; |
| BuilderM.CreateCondBr( |
| cond1, |
| (done[std::make_pair(block, bi1->getSuccessor(0))].size() == 1) |
| ? blocks[2] |
| : bsplit, |
| (done[std::make_pair(block, bi1->getSuccessor(1))].size() == 1) |
| ? blocks[2] |
| : bsplit); |
| |
| BuilderM.SetInsertPoint(bsplit); |
| BuilderM.CreateCondBr(cond2, blocks[0], blocks[1]); |
| |
| BuilderM.SetInsertPoint(bret); |
| if (inReverseBlocks) |
| reverseBlocks[fwd].push_back(bret); |
| reverseBlockToPrimal[bret] = fwd; |
| auto toret = BuilderM.CreatePHI(val->getType(), vals.size()); |
| for (size_t i = 0; i < vals.size(); i++) |
| toret->addIncoming(vals[i], endingBlocks[i]); |
| assert(val->getType() == toret->getType()); |
| if (permitCache) { |
| unwrap_cache[bret][idx.first][idx.second] = toret; |
| } |
| unwrappedLoads[toret] = val; |
| for (auto pair : unwrap_cache[oldB]) |
| unwrap_cache[bret].insert(pair); |
| for (auto pair : lookup_cache[oldB]) |
| lookup_cache[bret].insert(pair); |
| return toret; |
| } |
| } |
| rnextpair:; |
| } |
| } |
| |
| Instruction *equivalentTerminator = nullptr; |
| |
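|     // Case: one of the two incoming edges is the loop backedge. The phi |
|     // takes its preheader value on the first iteration and otherwise the |
|     // value from the previous iteration; recompute the latter with the |
|     // previous induction values made available, and select between the two |
|     // based on whether the previous increment is zero. |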
| if (prevIteration.size() == 1) { |
| if (phi->getNumIncomingValues() == 2) { |
| |
| ValueToValueMapTy prevAvailable; |
| for (const auto &pair : available) |
| prevAvailable.insert(pair); |
| LoopContext ctx; |
| getContext(parent, ctx); |
| Value *prevIdx; |
| if (prevAvailable.count(ctx.var)) |
| prevIdx = prevAvailable[ctx.var]; |
| else { |
| if (!isOriginalBlock(*BuilderM.GetInsertBlock())) { |
|             // If we are using the phi in the reverse pass of a block inside |
|             // the loop itself, the previous index variable (aka the previous |
|             // inc) is equivalent to the current load of antivaralloc. |
| if (LI.getLoopFor(ctx.header)->contains(fwd)) { |
| #if LLVM_VERSION_MAJOR > 7 |
| prevIdx = |
| BuilderM.CreateLoad(ctx.var->getType(), ctx.antivaralloc); |
| #else |
| prevIdx = BuilderM.CreateLoad(ctx.antivaralloc); |
| #endif |
| } else { |
|               // However, if we are using the phi in the reverse pass of a |
|               // block outside the loop, we must be in the reverse pass of a |
|               // block after the loop, in which case the previous index |
|               // variable (aka the previous inc) is the total loop iteration |
|               // count minus 1, aka the trueLimit. |
| Value *lim = nullptr; |
| if (ctx.dynamic) { |
|                 // Must be in the reverse pass for a lookup of the loop bound |
|                 // to be legal. |
| assert(/*ReverseLimit*/ reverseBlocks.size() > 0); |
| LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, |
| ctx.preheader); |
| lim = lookupValueFromCache( |
| /*forwardPass*/ false, BuilderM, lctx, |
| getDynamicLoopLimit(LI.getLoopFor(ctx.header)), |
| /*isi1*/ false, /*available*/ prevAvailable); |
| } else { |
| lim = lookupM(ctx.trueLimit, BuilderM, prevAvailable); |
| } |
| prevIdx = lim; |
| } |
| } else { |
| prevIdx = ctx.var; |
| } |
| } |
| // Prevent recursive unroll. |
| prevAvailable[phi] = nullptr; |
| SmallVector<Value *, 2> vals; |
| |
| SmallVector<BasicBlock *, 2> blocks; |
| SmallVector<BasicBlock *, 2> endingBlocks; |
| BasicBlock *last = oldB; |
| |
| BasicBlock *bret = BasicBlock::Create( |
| val->getContext(), oldB->getName() + "_phimerge", newFunc); |
| |
| SmallVector<BasicBlock *, 2> preds(predecessors(phi->getParent())); |
| |
| for (auto tup : llvm::enumerate(preds)) { |
| auto i = tup.index(); |
| BasicBlock *PB = tup.value(); |
| blocks.push_back(BasicBlock::Create( |
| val->getContext(), oldB->getName() + "_phirc", newFunc)); |
| blocks[i]->moveAfter(last); |
| last = blocks[i]; |
| if (reverseBlocks.size() > 0) { |
| if (inReverseBlocks) |
| reverseBlocks[fwd].push_back(blocks[i]); |
| reverseBlockToPrimal[blocks[i]] = fwd; |
| } |
| IRBuilder<> B(blocks[i]); |
| |
| if (!prevIteration.count(PB)) { |
| for (auto pair : unwrap_cache[oldB]) |
| unwrap_cache[blocks[i]].insert(pair); |
| for (auto pair : lookup_cache[oldB]) |
| lookup_cache[blocks[i]].insert(pair); |
| } |
| |
| if (auto inst = |
| dyn_cast<Instruction>(phi->getIncomingValueForBlock(PB))) { |
| // Recompute the phi computation with the conditional if: |
| // 1) the instruction may read from memory AND does not dominate |
| // the current insertion point (thereby potentially making such |
| // recomputation without the condition illegal) |
| // 2) the value is a call or load and option is set to not |
| // speculatively recompute values within a phi |
| // OR |
| // 3) the value comes from a previous iteration. |
| BasicBlock *nextScope = PB; |
| // if (inst->getParent() == nextScope) nextScope = phi->getParent(); |
| if (prevIteration.count(PB)) { |
| prevAvailable[ctx.incvar] = prevIdx; |
| prevAvailable[ctx.var] = |
| B.CreateSub(prevIdx, ConstantInt::get(prevIdx->getType(), 1), |
| "", /*NUW*/ true, /*NSW*/ false); |
| Value *___res; |
| if (unwrapMode == UnwrapMode::LegalFullUnwrap || |
| unwrapMode == UnwrapMode::LegalFullUnwrapNoTapeReplace || |
| unwrapMode == UnwrapMode::AttemptFullUnwrap || |
| unwrapMode == UnwrapMode::AttemptFullUnwrapWithLookup) { |
| ___res = unwrapM(inst, B, prevAvailable, unwrapMode, nextScope, |
| /*permitCache*/ false); |
| if (!___res && |
| unwrapMode == UnwrapMode::AttemptFullUnwrapWithLookup) { |
| bool noLookup = false; |
| if (isOriginalBlock(*B.GetInsertBlock())) { |
| if (!DT.dominates(inst, &*B.GetInsertPoint())) |
| noLookup = true; |
| } |
| if (!noLookup) { |
| BasicBlock *nS2 = nextScope; |
| Value *v = inst; |
| if (BasicBlock *forwardBlock = nextScope) |
| if (auto opinst = dyn_cast<Instruction>(v)) { |
| if (!isOriginalBlock(*forwardBlock)) { |
| forwardBlock = originalForReverseBlock(*forwardBlock); |
| } |
| if (isPotentialLastLoopValue(opinst, forwardBlock, |
| LI)) { |
| v = fixLCSSA(opinst, forwardBlock); |
| nS2 = nullptr; |
| } |
| } |
| ___res = lookupM(v, B, prevAvailable, v != val, nS2); |
| } |
| } |
| if (___res) |
| assert(___res->getType() == inst->getType() && "uw"); |
| } else { |
| BasicBlock *nS2 = nextScope; |
| Value *v = inst; |
| if (BasicBlock *forwardBlock = nextScope) |
| if (auto opinst = dyn_cast<Instruction>(v)) { |
| if (!isOriginalBlock(*forwardBlock)) { |
| forwardBlock = originalForReverseBlock(*forwardBlock); |
| } |
| if (isPotentialLastLoopValue(opinst, forwardBlock, LI)) { |
| v = fixLCSSA(opinst, forwardBlock); |
| nS2 = nullptr; |
| } |
| } |
| ___res = lookupM(v, B, prevAvailable, v != val, nS2); |
| if (___res && ___res->getType() != v->getType()) { |
| llvm::errs() << *newFunc << "\n"; |
| llvm::errs() << " v = " << *v << " res = " << *___res << "\n"; |
| } |
| if (___res) |
| assert(___res->getType() == inst->getType() && "lu"); |
| } |
| vals.push_back(___res); |
| } else if (!DT.dominates(inst->getParent(), phi->getParent()) || |
| (!EnzymeSpeculatePHIs && |
| (isa<CallInst>(inst) || isa<LoadInst>(inst)))) |
| vals.push_back(getOpFull(B, inst, nextScope)); |
| else |
| vals.push_back(getOpFull(BuilderM, inst, nextScope)); |
| } else |
| vals.push_back(phi->getIncomingValueForBlock(PB)); |
| |
| if (!vals[i]) { |
| eraseBlocks(blocks, bret); |
| assert(unwrapMode != UnwrapMode::LegalFullUnwrap); |
| goto endCheck; |
| } |
| assert(val->getType() == vals[i]->getType()); |
| B.CreateBr(bret); |
| endingBlocks.push_back(B.GetInsertBlock()); |
| } |
| |
|         // The value from a previous iteration is only taken when the |
|         // previous index is nonzero; at zero the phi takes its initial |
|         // (preheader) value. |
| Value *cond; |
| if (prevIteration.count(preds[0])) |
| cond = BuilderM.CreateICmpNE(prevIdx, |
| ConstantInt::get(prevIdx->getType(), 0)); |
| else |
| cond = BuilderM.CreateICmpEQ(prevIdx, |
| ConstantInt::get(prevIdx->getType(), 0)); |
| |
| if (blocks[0]->size() == 1 && blocks[1]->size() == 1) { |
| if (auto B1 = dyn_cast<BranchInst>(blocks[0]->getTerminator())) |
| if (auto B2 = dyn_cast<BranchInst>(blocks[1]->getTerminator())) |
| if (B1->isUnconditional() && B2->isUnconditional() && |
| B1->getSuccessor(0) == bret && B2->getSuccessor(0) == bret) { |
| eraseBlocks(blocks, bret); |
| Value *toret = BuilderM.CreateSelect( |
| cond, vals[0], vals[1], phi->getName() + "_unwrap"); |
| if (permitCache) { |
| unwrap_cache[BuilderM.GetInsertBlock()][idx.first] |
| [idx.second] = toret; |
| } |
| if (auto instRet = dyn_cast<Instruction>(toret)) { |
| unwrappedLoads[instRet] = val; |
| } |
| return toret; |
| } |
| } |
| |
| bret->moveAfter(last); |
| BuilderM.CreateCondBr(cond, blocks[0], blocks[1]); |
| |
| BuilderM.SetInsertPoint(bret); |
| if (inReverseBlocks) |
| reverseBlocks[fwd].push_back(bret); |
| reverseBlockToPrimal[bret] = fwd; |
| auto toret = BuilderM.CreatePHI(val->getType(), vals.size()); |
| for (size_t i = 0; i < vals.size(); i++) |
| toret->addIncoming(vals[i], endingBlocks[i]); |
| assert(val->getType() == toret->getType()); |
| if (permitCache) { |
| unwrap_cache[bret][idx.first][idx.second] = toret; |
| } |
| for (auto pair : unwrap_cache[oldB]) |
| unwrap_cache[bret].insert(pair); |
| for (auto pair : lookup_cache[oldB]) |
| lookup_cache[bret].insert(pair); |
| unwrappedLoads[toret] = val; |
| return toret; |
| } |
| } |
| if (prevIteration.size() != 0) { |
| llvm::errs() << "prev iteration: " << *phi << "\n"; |
| assert(unwrapMode != UnwrapMode::LegalFullUnwrap); |
| goto endCheck; |
| } |
| |
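|   // General case: find a block dominating the phi whose successor edges map |
|   // one-to-one onto the phi's incoming predecessors; its terminator (branch |
|   // or switch) can then be replayed at the current insertion point to choose |
|   // among the recomputed incoming values. |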
| for (auto block : blocks) { |
| if (!DT.dominates(block, phi->getParent())) |
| continue; |
| std::set<BasicBlock *> foundtargets; |
| for (BasicBlock *succ : successors(block)) { |
| auto edge = std::make_pair(block, succ); |
| if (done[edge].size() != 1) { |
| goto nextpair; |
| } |
| BasicBlock *target = *done[edge].begin(); |
| if (foundtargets.find(target) != foundtargets.end()) { |
| goto nextpair; |
| } |
| foundtargets.insert(target); |
| } |
| if (foundtargets.size() != targetToPreds.size()) { |
| goto nextpair; |
| } |
| |
| if (DT.dominates(block, parent)) { |
| equivalentTerminator = block->getTerminator(); |
| goto fast; |
| } |
| nextpair:; |
| } |
| goto endCheck; |
| |
| fast:; |
| assert(equivalentTerminator); |
| |
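| // Replay the equivalent branch/switch on the unwrapped condition, |
| // recomputing each incoming value in its own block and merging the |
| // results with a phi. |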
| if (isa<BranchInst>(equivalentTerminator) || |
| isa<SwitchInst>(equivalentTerminator)) { |
| BasicBlock *oldB = BuilderM.GetInsertBlock(); |
| |
| SmallVector<BasicBlock *, 2> predBlocks; |
| Value *cond = nullptr; |
| if (auto branch = dyn_cast<BranchInst>(equivalentTerminator)) { |
| cond = branch->getCondition(); |
| predBlocks.push_back(branch->getSuccessor(0)); |
| predBlocks.push_back(branch->getSuccessor(1)); |
| } else { |
| auto SI = cast<SwitchInst>(equivalentTerminator); |
| cond = SI->getCondition(); |
| predBlocks.push_back(SI->getDefaultDest()); |
| for (auto scase : SI->cases()) { |
| predBlocks.push_back(scase.getCaseSuccessor()); |
| } |
| } |
| cond = getOp(cond); |
| if (!cond) { |
| assert(unwrapMode != UnwrapMode::LegalFullUnwrap); |
| goto endCheck; |
| } |
| |
| SmallVector<Value *, 2> vals; |
| |
| SmallVector<BasicBlock *, 2> blocks; |
| SmallVector<BasicBlock *, 2> endingBlocks; |
| |
| BasicBlock *last = oldB; |
| |
| assert(prevIteration.size() == 0); |
| |
| BasicBlock *bret = BasicBlock::Create( |
| val->getContext(), oldB->getName() + "_phimerge", newFunc); |
| |
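| // For each successor of the equivalent terminator, build a block that |
| // recomputes the phi's incoming value from the corresponding |
| // predecessor. |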
| for (size_t i = 0; i < predBlocks.size(); i++) { |
| assert(done.find(std::make_pair(equivalentTerminator->getParent(), |
| predBlocks[i])) != done.end()); |
| assert(done[std::make_pair(equivalentTerminator->getParent(), |
| predBlocks[i])] |
| .size() == 1); |
| BasicBlock *PB = *done[std::make_pair(equivalentTerminator->getParent(), |
| predBlocks[i])] |
| .begin(); |
| blocks.push_back(BasicBlock::Create( |
| val->getContext(), oldB->getName() + "_phirc", newFunc)); |
| blocks[i]->moveAfter(last); |
| last = blocks[i]; |
| if (reverseBlocks.size() > 0) { |
| if (inReverseBlocks) |
| reverseBlocks[fwd].push_back(blocks[i]); |
| reverseBlockToPrimal[blocks[i]] = fwd; |
| } |
| IRBuilder<> B(blocks[i]); |
| |
| for (auto pair : unwrap_cache[oldB]) |
| unwrap_cache[blocks[i]].insert(pair); |
| for (auto pair : lookup_cache[oldB]) |
| lookup_cache[blocks[i]].insert(pair); |
| |
| if (auto inst = |
| dyn_cast<Instruction>(phi->getIncomingValueForBlock(PB))) { |
| // Recompute the phi's incoming value under the conditional if: |
| // 1) the instruction may read from memory AND does not dominate |
| //    the current insertion point (thereby potentially making such |
| //    recomputation without the condition illegal), OR |
| // 2) the value is a call or load and the option is set to not |
| //    speculatively recompute values within a phi, OR |
| // 3) the value comes from a previous iteration. |
| BasicBlock *nextScope = PB; |
| // if (inst->getParent() == nextScope) nextScope = phi->getParent(); |
| if (!DT.dominates(inst->getParent(), phi->getParent()) || |
| (!EnzymeSpeculatePHIs && |
| (isa<CallInst>(inst) || isa<LoadInst>(inst)))) |
| vals.push_back(getOpFull(B, inst, nextScope)); |
| else |
| vals.push_back(getOpFull(BuilderM, inst, nextScope)); |
| } else |
| vals.push_back(phi->getIncomingValueForBlock(PB)); |
| |
| if (!vals[i]) { |
| eraseBlocks(blocks, bret); |
| assert(unwrapMode != UnwrapMode::LegalFullUnwrap); |
| goto endCheck; |
| } |
| assert(val->getType() == vals[i]->getType()); |
| B.CreateBr(bret); |
| endingBlocks.push_back(B.GetInsertBlock()); |
| } |
| |
| // Fast path: avoid creating a split block if no additional |
| // instructions were generated in the two predecessor blocks. |
| if (isa<BranchInst>(equivalentTerminator) && blocks[0]->size() == 1 && |
| blocks[1]->size() == 1) { |
| if (auto B1 = dyn_cast<BranchInst>(blocks[0]->getTerminator())) |
| if (auto B2 = dyn_cast<BranchInst>(blocks[1]->getTerminator())) |
| if (B1->isUnconditional() && B2->isUnconditional() && |
| B1->getSuccessor(0) == bret && B2->getSuccessor(0) == bret) { |
| eraseBlocks(blocks, bret); |
| Value *toret = BuilderM.CreateSelect(cond, vals[0], vals[1], |
| phi->getName() + "_unwrap"); |
| if (permitCache) { |
| unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = |
| toret; |
| } |
| if (auto instRet = dyn_cast<Instruction>(toret)) { |
| unwrappedLoads[instRet] = val; |
| } |
| return toret; |
| } |
| } |
| |
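| // Emitting the conditional branch requires the builder to sit at the |
| // end of its block; otherwise give up on unwrapping. |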
| if (BuilderM.GetInsertPoint() != oldB->end()) { |
| eraseBlocks(blocks, bret); |
| assert(unwrapMode != UnwrapMode::LegalFullUnwrap); |
| goto endCheck; |
| } |
| |
| bret->moveAfter(last); |
| if (isa<BranchInst>(equivalentTerminator)) { |
| BuilderM.CreateCondBr(cond, blocks[0], blocks[1]); |
| } else { |
| auto SI = cast<SwitchInst>(equivalentTerminator); |
| auto NSI = BuilderM.CreateSwitch(cond, blocks[0], SI->getNumCases()); |
| size_t idx = 1; |
| for (auto scase : SI->cases()) { |
| NSI->addCase(scase.getCaseValue(), blocks[idx]); |
| idx++; |
| } |
| } |
| BuilderM.SetInsertPoint(bret); |
| if (inReverseBlocks) |
| reverseBlocks[fwd].push_back(bret); |
| reverseBlockToPrimal[bret] = fwd; |
| auto toret = BuilderM.CreatePHI(val->getType(), vals.size()); |
| for (size_t i = 0; i < vals.size(); i++) |
| toret->addIncoming(vals[i], endingBlocks[i]); |
| assert(val->getType() == toret->getType()); |
| if (permitCache) { |
| unwrap_cache[bret][idx.first][idx.second] = toret; |
| } |
| for (auto pair : unwrap_cache[oldB]) |
| unwrap_cache[bret].insert(pair); |
| for (auto pair : lookup_cache[oldB]) |
| lookup_cache[bret].insert(pair); |
| unwrappedLoads[toret] = val; |
| return toret; |
| } |
| assert(unwrapMode != UnwrapMode::LegalFullUnwrap); |
| goto endCheck; |
| } |
| |
| endCheck: |
| assert(val); |
| if (unwrapMode == UnwrapMode::LegalFullUnwrap || |
| unwrapMode == UnwrapMode::LegalFullUnwrapNoTapeReplace || |
| unwrapMode == UnwrapMode::AttemptFullUnwrapWithLookup) { |
| assert(val->getName() != "<badref>"); |
| Value *nval = val; |
| if (auto opinst = dyn_cast<Instruction>(nval)) |
| if (isOriginalBlock(*BuilderM.GetInsertBlock())) { |
| if (!DT.dominates(opinst, &*BuilderM.GetInsertPoint())) { |
| if (unwrapMode != UnwrapMode::AttemptFullUnwrapWithLookup) { |
| llvm::errs() << " oldF: " << *oldFunc << "\n"; |
| llvm::errs() << " opParen: " << *opinst->getParent()->getParent() |
| << "\n"; |
| llvm::errs() << " newF: " << *newFunc << "\n"; |
| llvm::errs() << " - blk: " << *BuilderM.GetInsertBlock(); |
| llvm::errs() << " opInst: " << *opinst << " mode=" << unwrapMode |
| << "\n"; |
| } |
| assert(unwrapMode == UnwrapMode::AttemptFullUnwrapWithLookup); |
| return nullptr; |
| } |
| } |
| BasicBlock *nS2 = scope; |
| if (BasicBlock *forwardBlock = scope) |
| if (auto opinst = dyn_cast<Instruction>(nval)) { |
| if (!isOriginalBlock(*forwardBlock)) { |
| forwardBlock = originalForReverseBlock(*forwardBlock); |
| } |
| if (isPotentialLastLoopValue(opinst, forwardBlock, LI)) { |
| nval = fixLCSSA(opinst, forwardBlock); |
| nS2 = nullptr; |
| } |
| } |
| auto toreturn = lookupM(nval, BuilderM, available, |
| /*tryLegalRecomputeCheck*/ false, nS2); |
| assert(val->getType() == toreturn->getType()); |
| return toreturn; |
| } |
| |
| if (auto inst = dyn_cast<Instruction>(val)) { |
| if (isOriginalBlock(*BuilderM.GetInsertBlock())) { |
| if (BuilderM.GetInsertBlock()->size() && |
| BuilderM.GetInsertPoint() != BuilderM.GetInsertBlock()->end()) { |
| if (DT.dominates(inst, &*BuilderM.GetInsertPoint())) { |
| assert(inst->getType() == val->getType()); |
| return inst; |
| } |
| } else { |
| if (DT.dominates(inst, BuilderM.GetInsertBlock())) { |
| assert(inst->getType() == val->getType()); |
| return inst; |
| } |
| } |
| } |
| assert(val->getName() != "<badref>"); |
| auto &warnMap = UnwrappedWarnings[inst]; |
| if (!warnMap.count(BuilderM.GetInsertBlock())) { |
| EmitWarning("NoUnwrap", *inst, "Cannot unwrap ", *val, " in ", |
| BuilderM.GetInsertBlock()->getName()); |
| warnMap.insert(BuilderM.GetInsertBlock()); |
| } |
| } |
| return nullptr; |
| } |
| |
| Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc, |
| int idx, bool ignoreType, bool replace) { |
| assert(malloc); |
| assert(BuilderQ.GetInsertBlock()->getParent() == newFunc); |
| assert(isOriginalBlock(*BuilderQ.GetInsertBlock())); |
| if (mode == DerivativeMode::ReverseModeCombined) { |
| assert(!tape); |
| return malloc; |
| } |
| |
| if (auto CI = dyn_cast<CallInst>(malloc)) { |
| if (auto F = CI->getCalledFunction()) { |
| assert(F->getName() != "omp_get_thread_num"); |
| } |
| } |
| |
| if (malloc->getType()->isTokenTy()) { |
| llvm::errs() << " oldFunc: " << *oldFunc << "\n"; |
| llvm::errs() << " newFunc: " << *newFunc << "\n"; |
| llvm::errs() << " malloc: " << *malloc << "\n"; |
| } |
| assert(!malloc->getType()->isTokenTy()); |
| |
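| // If a tape was provided, the cached value is extracted from it |
| // rather than being recomputed or re-cached here. |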
| if (tape) { |
| if (idx >= 0 && !tape->getType()->isStructTy()) { |
| llvm::errs() << "cacheForReverse incorrect tape type: " << *tape |
| << " idx: " << idx << "\n"; |
| } |
| assert(idx < 0 || tape->getType()->isStructTy()); |
| if (idx >= 0 && |
| (unsigned)idx >= cast<StructType>(tape->getType())->getNumElements()) { |
| llvm::errs() << "oldFunc: " << *oldFunc << "\n"; |
| llvm::errs() << "newFunc: " << *newFunc << "\n"; |
| if (malloc) |
| llvm::errs() << "malloc: " << *malloc << "\n"; |
| llvm::errs() << "tape: " << *tape << "\n"; |
| llvm::errs() << "idx: " << idx << "\n"; |
| } |
| assert(idx < 0 || |
| (unsigned)idx < cast<StructType>(tape->getType())->getNumElements()); |
| Value *ret = |
| (idx < 0) ? tape : BuilderQ.CreateExtractValue(tape, {(unsigned)idx}); |
| |
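| // Values of empty type carry no data: erase the original instruction |
| // (when replacing) and return an undef of the same type. |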
| if (ret->getType()->isEmptyTy()) { |
| if (auto inst = dyn_cast_or_null<Instruction>(malloc)) { |
| if (!ignoreType) { |
| if (inst->getType() != ret->getType()) { |
| llvm::errs() << "oldFunc: " << *oldFunc << "\n"; |
| llvm::errs() << "newFunc: " << *newFunc << "\n"; |
| llvm::errs() << "inst==malloc: " << *inst << "\n"; |
| llvm::errs() << "ret: " << *ret << "\n"; |
| } |
| assert(inst->getType() == ret->getType()); |
| if (replace) |
| inst->replaceAllUsesWith(UndefValue::get(ret->getType())); |
| } |
| if (replace) |
| erase(inst); |
| } |
| Type *retType = ret->getType(); |
| if (replace) |
| if (auto ri = dyn_cast<Instruction>(ret)) |
| erase(ri); |
| return UndefValue::get(retType); |
| } |
| |
| LimitContext ctx(/*ReverseLimit*/ reverseBlocks.size() > 0, |
| BuilderQ.GetInsertBlock()); |
| if (auto inst = dyn_cast<Instruction>(malloc)) |
| ctx = LimitContext(/*ReverseLimit*/ reverseBlocks.size() > 0, |
| inst->getParent()); |
| if (auto found = findInMap(scopeMap, malloc)) { |
| ctx = found->second; |
| } |
| assert(isOriginalBlock(*ctx.Block)); |
| |
| bool inLoop; |
| if (ctx.ForceSingleIteration) { |
| inLoop = true; |
| ctx.ForceSingleIteration = false; |
| } else { |
| LoopContext lc; |
| inLoop = getContext(ctx.Block, lc); |
| } |
| |
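| // Outside a loop the extracted tape slot is the value itself (indexed |
| // by thread id under OpenMP); inside a loop it is per-iteration |
| // storage that is wrapped in a cache for indexed lookup. |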
| if (!inLoop) { |
| ret->setName(malloc->getName() + "_fromtape"); |
| if (omp) { |
| Value *tid = ompThreadId(); |
| #if LLVM_VERSION_MAJOR > 7 |
| Value *tPtr = BuilderQ.CreateInBoundsGEP(malloc->getType(), ret, |
| ArrayRef<Value *>(tid)); |
| #else |
| Value *tPtr = BuilderQ.CreateInBoundsGEP(ret, ArrayRef<Value *>(tid)); |
| #endif |
| ret = BuilderQ.CreateLoad(malloc->getType(), tPtr); |
| } |
| } else { |
| if (idx >= 0) |
| erase(cast<Instruction>(ret)); |
| IRBuilder<> entryBuilder(inversionAllocs); |
| entryBuilder.setFastMathFlags(getFast()); |
| ret = (idx < 0) ? tape |
| : entryBuilder.CreateExtractValue(tape, {(unsigned)idx}); |
| |
| Type *innerType = ret->getType(); |
| for (size_t i = 0, |
| limit = getSubLimits( |
| /*inForwardPass*/ true, nullptr, |
| LimitContext( |
| /*ReverseLimit*/ reverseBlocks.size() > 0, |
| BuilderQ.GetInsertBlock())) |
| .size(); |
| i < limit; ++i) { |
| if (!isa<PointerType>(innerType)) { |
| llvm::errs() << "mod: " |
| << *BuilderQ.GetInsertBlock()->getParent()->getParent() |
| << "\n"; |
| llvm::errs() << "fn: " << *BuilderQ.GetInsertBlock()->getParent() |
| << "\n"; |
| llvm::errs() << "bq insertblock: " << *BuilderQ.GetInsertBlock() |
| << "\n"; |
| llvm::errs() << "ret: " << *ret << " type: " << *ret->getType() |
| << "\n"; |
| llvm::errs() << "innerType: " << *innerType << "\n"; |
| if (malloc) |
| llvm::errs() << " malloc: " << *malloc << " i=" << i |
| << " / lim = " << limit << "\n"; |
| } |
| assert(isa<PointerType>(innerType)); |
| innerType = innerType->getPointerElementType(); |
| } |
| |
| assert(malloc); |
| if (!ignoreType) { |
| if (EfficientBoolCache && malloc->getType()->isIntegerTy() && |
| cast<IntegerType>(malloc->getType())->getBitWidth() == 1 && |
| innerType != ret->getType()) { |
| assert(innerType == Type::getInt8Ty(malloc->getContext())); |
| } else { |
| if (innerType != malloc->getType()) { |
| llvm::errs() << *oldFunc << "\n"; |
| llvm::errs() << *newFunc << "\n"; |
| llvm::errs() << "innerType: " << *innerType << "\n"; |
| llvm::errs() << "malloc->getType(): " << *malloc->getType() << "\n"; |
| llvm::errs() << "ret: " << *ret << " - " << *ret->getType() << "\n"; |
| llvm::errs() << "malloc: " << *malloc << "\n"; |
| assert(0 && "illegal loop cache type"); |
| llvm_unreachable("illegal loop cache type"); |
| } |
| } |
| } |
| |
| LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, |
| BuilderQ.GetInsertBlock()); |
| AllocaInst *cache = |
| createCacheForScope(lctx, innerType, "mdyncache_fromtape", |
| ((DiffeGradientUtils *)this)->FreeMemory, false); |
| assert(malloc); |
| bool isi1 = !ignoreType && malloc->getType()->isIntegerTy() && |
| cast<IntegerType>(malloc->getType())->getBitWidth() == 1; |
| assert(isa<PointerType>(cache->getType())); |
| assert(cache->getType()->getPointerElementType() == ret->getType()); |
| entryBuilder.CreateStore(ret, cache); |
| |
| auto v = lookupValueFromCache(/*forwardPass*/ true, BuilderQ, lctx, cache, |
| isi1, /*available*/ ValueToValueMapTy()); |
| if (!ignoreType && malloc) { |
| assert(v->getType() == malloc->getType()); |
| } |
| insert_or_assign(scopeMap, v, |
| std::make_pair(AssertingVH<AllocaInst>(cache), ctx)); |
| ret = cast<Instruction>(v); |
| } |
| |
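| // Rewire the original computation to the tape-derived value, removing |
| // any cache allocations, stores, and frees the forward pass would |
| // otherwise have kept. |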
| if (malloc && !isa<UndefValue>(malloc)) { |
| if (!ignoreType) { |
| if (malloc->getType() != ret->getType()) { |
| llvm::errs() << *oldFunc << "\n"; |
| llvm::errs() << *newFunc << "\n"; |
| llvm::errs() << *malloc << "\n"; |
| llvm::errs() << *ret << "\n"; |
| } |
| assert(malloc->getType() == ret->getType()); |
| } |
| |
| if (replace) { |
| auto found = newToOriginalFn.find(malloc); |
| if (found != newToOriginalFn.end()) { |
| Value *orig = found->second; |
| originalToNewFn[orig] = ret; |
| newToOriginalFn.erase(malloc); |
| newToOriginalFn[ret] = orig; |
| } |
| } |
| |
| if (auto found = findInMap(scopeMap, malloc)) { |
| // There already exists an allocation for this; we should fully |
| // remove it |
| if (!inLoop) { |
| |
| // Remove stores into the cache |
| SmallVector<Instruction *, 3> stores( |
| scopeInstructions[found->first].begin(), |
| scopeInstructions[found->first].end()); |
| scopeInstructions.erase(found->first); |
| for (int i = stores.size() - 1; i >= 0; i--) { |
| erase(stores[i]); |
| } |
| |
| SmallVector<User *, 4> users; |
| for (auto u : found->first->users()) { |
| users.push_back(u); |
| } |
| for (auto u : users) { |
| if (auto li = dyn_cast<LoadInst>(u)) { |
| IRBuilder<> lb(li); |
| if (replace) { |
| |
| Value *replacewith = |
| (idx < 0) ? tape |
| : lb.CreateExtractValue(tape, {(unsigned)idx}); |
| if (!inLoop && omp) { |
| Value *tid = ompThreadId(); |
| #if LLVM_VERSION_MAJOR > 7 |
| Value *tPtr = lb.CreateInBoundsGEP( |
| replacewith->getType()->getPointerElementType(), |
| replacewith, ArrayRef<Value *>(tid)); |
| #else |
| Value *tPtr = |
| lb.CreateInBoundsGEP(replacewith, ArrayRef<Value *>(tid)); |
| #endif |
| replacewith = lb.CreateLoad( |
| replacewith->getType()->getPointerElementType(), tPtr); |
| } |
| if (li->getType() != replacewith->getType()) { |
| llvm::errs() << " oldFunc: " << *oldFunc << "\n"; |
| llvm::errs() << " newFunc: " << *newFunc << "\n"; |
| llvm::errs() << " malloc: " << *malloc << "\n"; |
| llvm::errs() << " li: " << *li << "\n"; |
| llvm::errs() << " u: " << *u << "\n"; |
| llvm::errs() << " replacewith: " << *replacewith |
| << " idx=" << idx << " - tape=" << *tape << "\n"; |
| } |
| assert(li->getType() == replacewith->getType()); |
| li->replaceAllUsesWith(replacewith); |
| } else { |
| auto phi = |
| lb.CreatePHI(li->getType(), 0, li->getName() + "_cfrphi"); |
| unwrappedLoads[phi] = malloc; |
| li->replaceAllUsesWith(phi); |
| } |
| erase(li); |
| } else { |
| llvm::errs() << "newFunc: " << *newFunc << "\n"; |
| llvm::errs() << "malloc: " << *malloc << "\n"; |
| llvm::errs() << "scopeMap[malloc]: " << *found->first << "\n"; |
| llvm::errs() << "u: " << *u << "\n"; |
| assert(0 && "illegal use for out of loop scopeMap1"); |
| } |
| } |
| |
| { |
| AllocaInst *preerase = found->first; |
| scopeMap.erase(malloc); |
| erase(preerase); |
| } |
| } else { |
| // Remove allocations for scopealloc since it is already allocated |
| // by the augmented forward pass |
| // Remove stores into the cache |
| SmallVector<Instruction *, 3> stores( |
| scopeInstructions[found->first].begin(), |
| scopeInstructions[found->first].end()); |
| scopeInstructions.erase(found->first); |
| scopeAllocs.erase(found->first); |
| for (int i = stores.size() - 1; i >= 0; i--) { |
| erase(stores[i]); |
| } |
| |
| // Remove frees |
| SmallVector<CallInst *, 3> tofree(scopeFrees[found->first].begin(), |
| scopeFrees[found->first].end()); |
| scopeFrees.erase(found->first); |
| for (auto freeinst : tofree) { |
| // This deque contains a list of operations |
| // we may erase upon erasing the free (and so on). |
| // Since multiple operations can have the same operand, |
| // this deque can contain the same value multiple times. |
| // To remedy this we use a tracking value handle which will |
| // be set to null when erased. |
| std::deque<WeakTrackingVH> ops = {freeinst->getArgOperand(0)}; |
| erase(freeinst); |
| |
| while (ops.size()) { |
| auto z = dyn_cast_or_null<Instruction>(ops[0]); |
| ops.pop_front(); |
| if (z && z->getNumUses() == 0 && !z->isUsedByMetadata()) { |
| for (unsigned i = 0; i < z->getNumOperands(); ++i) { |
| ops.push_back(z->getOperand(i)); |
| } |
| erase(z); |
| } |
| } |
| } |
| |
| // uses of the alloc |
| SmallVector<User *, 4> users; |
| for (auto u : found->first->users()) { |
| users.push_back(u); |
| } |
| for (auto u : users) { |
| if (auto li = dyn_cast<LoadInst>(u)) { |
| // Even with replace off, this can be replaced: since we're in a |
| // loop, this load reads the cache rather than the final value |
| // (which would otherwise overwrite the new instruction). |
| IRBuilder<> lb(li); |
| auto replacewith = |
| (idx < 0) ? tape |
| : lb.CreateExtractValue(tape, {(unsigned)idx}); |
| li->replaceAllUsesWith(replacewith); |
| erase(li); |
| } else { |
| llvm::errs() << "newFunc: " << *newFunc << "\n"; |
| llvm::errs() << "malloc: " << *malloc << "\n"; |
| llvm::errs() << "scopeMap[malloc]: " << *found->first << "\n"; |
| llvm::errs() << "u: " << *u << "\n"; |
| assert(0 && "illegal use for out of loop scopeMap2"); |
| } |
| } |
| |
| AllocaInst *preerase = found->first; |
| scopeMap.erase(malloc); |
| if (replace) |
| erase(preerase); |
| } |
| } |
| if (!ignoreType && replace) |
| cast<Instruction>(malloc)->replaceAllUsesWith(ret); |
| ret->takeName(malloc); |
| if (replace) { |
| auto malloci = cast<Instruction>(malloc); |
| if (malloci == &*BuilderQ.GetInsertPoint()) { |
| BuilderQ.SetInsertPoint(malloci->getNextNode()); |
| } |
| erase(malloci); |
| } |
| } |
| return ret; |
| } else { |
| assert(malloc); |
| assert(!ignoreType); |
| |
| assert(idx >= 0 && (unsigned)idx == addedTapeVals.size()); |
| |
| if (isa<UndefValue>(malloc)) { |
| addedTapeVals.push_back(malloc); |
| return malloc; |
| } |
| |
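| // No tape was provided, so we are building it: determine whether the |
| // value lives inside a loop. Loop-invariant values go into the tape |
| // directly; loop values must be cached per iteration. |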
| LimitContext ctx(/*ReverseLimit*/ reverseBlocks.size() > 0, |
| BuilderQ.GetInsertBlock()); |
| if (auto inst = dyn_cast<Instruction>(malloc)) |
| ctx = LimitContext(/*ReverseLimit*/ reverseBlocks.size() > 0, |
| inst->getParent()); |
| if (auto found = findInMap(scopeMap, malloc)) { |
| ctx = found->second; |
| } |
| |
| bool inLoop; |
| |
| if (ctx.ForceSingleIteration) { |
| inLoop = true; |
| ctx.ForceSingleIteration = false; |
| } else { |
| LoopContext lc; |
| inLoop = getContext(ctx.Block, lc); |
| } |
| |
| if (!inLoop) { |
| Value *toStoreInTape = malloc; |
| if (omp) { |
| Value *numThreads = ompNumThreads(); |
| Value *tid = ompThreadId(); |
| IRBuilder<> entryBuilder(inversionAllocs); |
| |
| auto firstallocation = |
| CreateAllocation(entryBuilder, malloc->getType(), numThreads, |
| malloc->getName() + "_malloccache"); |
| #if LLVM_VERSION_MAJOR > 7 |
| Value *tPtr = entryBuilder.CreateInBoundsGEP( |
| firstallocation->getType()->getPointerElementType(), |
| firstallocation, ArrayRef<Value *>(tid)); |
| #else |
| Value *tPtr = entryBuilder.CreateInBoundsGEP(firstallocation, |
| ArrayRef<Value *>(tid)); |
| #endif |
| if (auto inst = dyn_cast<Instruction>(malloc)) { |
| entryBuilder.SetInsertPoint(inst->getNextNode()); |
| } |
| entryBuilder.CreateStore(malloc, tPtr); |
| toStoreInTape = firstallocation; |
| } |
| addedTapeVals.push_back(toStoreInTape); |
| return malloc; |
| } |
| |
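| // Inside a loop: cache every iteration's value and add the backing |
| // allocation (or its user cast) to the tape. |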
| ensureLookupCached( |
| cast<Instruction>(malloc), |
| /*shouldFree=*/reverseBlocks.size() > 0, |
| /*scope*/ nullptr, |
| cast<Instruction>(malloc)->getMetadata(LLVMContext::MD_tbaa)); |
| auto found2 = scopeMap.find(malloc); |
| assert(found2 != scopeMap.end()); |
| assert(found2->second.first); |
| |
| Value *toadd; |
| toadd = scopeAllocs[found2->second.first][0]; |
| for (auto u : toadd->users()) { |
| if (auto ci = dyn_cast<CastInst>(u)) { |
| toadd = ci; |
| break; |
| } |
| } |
| |
| // llvm::errs() << " malloc: " << *malloc << "\n"; |
| // llvm::errs() << " toadd: " << *toadd << "\n"; |
| Type *innerType = toadd->getType(); |
| for (size_t |
| i = 0, |
| limit = getSubLimits( |
| /*inForwardPass*/ true, nullptr, |
| LimitContext(/*ReverseLimit*/ reverseBlocks.size() > 0, |
| BuilderQ.GetInsertBlock())) |
| .size(); |
| i < limit; ++i) { |
| innerType = innerType->getPointerElementType(); |
| } |
| assert(!ignoreType); |
| if (EfficientBoolCache && malloc->getType()->isIntegerTy() && |
| toadd->getType() != innerType && |
| cast<IntegerType>(malloc->getType())->getBitWidth() == 1) { |
| assert(innerType == Type::getInt8Ty(toadd->getContext())); |
| } else { |
| if (innerType != malloc->getType()) { |
| llvm::errs() << "oldFunc:" << *oldFunc << "\n"; |
| llvm::errs() << "newFunc: " << *newFunc << "\n"; |
| llvm::errs() << " toadd: " << *toadd << "\n"; |
| llvm::errs() << "innerType: " << *innerType << "\n"; |
| llvm::errs() << "malloc: " << *malloc << "\n"; |
| } |
| assert(innerType == malloc->getType()); |
| } |
| addedTapeVals.push_back(toadd); |
| return malloc; |
| } |
| llvm::errs() |
| << "Fell through on cacheForReverse. This should never happen.\n"; |
| assert(false); |
| } |
| |
| /// Given an edge from BB to branchingBlock get the corresponding block to |
| /// branch to in the reverse pass |
| BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB, |
| BasicBlock *branchingBlock) { |
| assert(BB); |
| // BB should be a forward pass block, assert that |
| if (reverseBlocks.find(BB) == reverseBlocks.end()) { |
| llvm::errs() << *oldFunc << "\n"; |
| llvm::errs() << *newFunc << "\n"; |
| llvm::errs() << "BB: " << *BB << "\n"; |
| llvm::errs() << "branchingBlock: " << *branchingBlock << "\n"; |
| } |
| assert(reverseBlocks.find(BB) != reverseBlocks.end()); |
| assert(reverseBlocks.find(branchingBlock) != reverseBlocks.end()); |
| LoopContext lc; |
| bool inLoop = getContext(BB, lc); |
| |
| LoopContext branchingContext; |
| bool inLoopContext = getContext(branchingBlock, branchingContext); |
| |
| if (!inLoop) |
| return reverseBlocks[BB].front(); |
| |
| auto tup = std::make_tuple(BB, branchingBlock); |
| if (newBlocksForLoop_cache.find(tup) != newBlocksForLoop_cache.end()) |
| return newBlocksForLoop_cache[tup]; |
| |
| if (inLoop) { |
| // If we're reversing a latch edge. |
| bool incEntering = inLoopContext && branchingBlock == lc.header && |
| lc.header == branchingContext.header; |
| |
| auto L = LI.getLoopFor(BB); |
| auto latches = getLatches(L, lc.exitBlocks); |
| // If we're reversing a loop exit. |
| bool exitEntering = |
| std::find(latches.begin(), latches.end(), BB) != latches.end() && |
| std::find(lc.exitBlocks.begin(), lc.exitBlocks.end(), branchingBlock) != |
| lc.exitBlocks.end(); |
| |
| // If we're re-entering a loop, prepare a loop-level forward pass to |
| // rematerialize any loop-scoped allocations and stores. |
| if (incEntering || exitEntering) { |
| SmallPtrSet<Instruction *, 1> loopRematerializations; |
| SmallPtrSet<Instruction *, 1> loopReallocations; |
| SmallPtrSet<Instruction *, 1> loopShadowReallocations; |
| SmallPtrSet<Instruction *, 1> loopShadowZeroInits; |
| SmallPtrSet<Instruction *, 1> loopShadowRematerializations; |
| Loop *origLI = nullptr; |
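| // Collect the allocations, stores, and shadow operations that must be |
| // replayed when re-entering this loop from the reverse pass. |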
| for (auto pair : rematerializableAllocations) { |
| if (pair.second.LI && |
| getNewFromOriginal(pair.second.LI->getHeader()) == L->getHeader()) { |
| bool rematerialized = false; |
| std::map<UsageKey, bool> Seen; |
| for (auto pair : knownRecomputeHeuristic) |
| if (!pair.second) |
| Seen[UsageKey(pair.first, ValueType::Primal)] = false; |
| |
| if (is_value_needed_in_reverse<ValueType::Primal>( |
| this, pair.first, mode, Seen, notForAnalysis)) { |
| rematerialized = true; |
| } |
| if (rematerialized) { |
| if (auto inst = dyn_cast<Instruction>(pair.first)) |
| if (pair.second.LI->contains(inst->getParent())) { |
| loopReallocations.insert(inst); |
| } |
| for (auto I : pair.second.stores) |
| loopRematerializations.insert(I); |
| origLI = pair.second.LI; |
| } |
| } |
| } |
| for (auto pair : backwardsOnlyShadows) { |
| if (pair.second.LI && |
| getNewFromOriginal(pair.second.LI->getHeader()) == L->getHeader()) { |
| if (auto inst = dyn_cast<Instruction>(pair.first)) { |
| bool restoreStores = false; |
| if (pair.second.LI->contains(inst->getParent())) { |
| // TODO later make it so primalInitialize can be restored |
| // rather than cached from primal |
| if (!pair.second.primalInitialize) { |
| loopShadowReallocations.insert(inst); |
| restoreStores = true; |
| } |
| } else { |
| // if (pair.second.primalInitialize) { |
| // loopShadowZeroInits.insert(inst); |
| //} |
| restoreStores = true; |
| } |
| if (restoreStores) { |
| for (auto I : pair.second.stores) { |
| loopShadowRematerializations.insert(I); |
| } |
| } |
| origLI = pair.second.LI; |
| } |
| } |
| } |
| BasicBlock *resumeblock = reverseBlocks[BB].front(); |
| if (loopRematerializations.size() != 0 || loopReallocations.size() != 0 || |
| loopShadowRematerializations.size() != 0 || |
| loopShadowReallocations.size() != 0 || |
| loopShadowZeroInits.size() != 0) { |
| auto found = rematerializedLoops_cache.find(L); |
| if (found != rematerializedLoops_cache.end()) { |
| resumeblock = found->second; |
| } else { |
| BasicBlock *enterB = BasicBlock::Create( |
| BB->getContext(), "remat_enter", BB->getParent()); |
| rematerializedLoops_cache[L] = enterB; |
| std::map<BasicBlock *, BasicBlock *> origToNewForward; |
| for (auto B : origLI->getBlocks()) { |
| BasicBlock *newB = BasicBlock::Create( |
| B->getContext(), |
| "remat_" + lc.header->getName() + "_" + B->getName(), |
| BB->getParent()); |
| origToNewForward[B] = newB; |
| reverseBlockToPrimal[newB] = getNewFromOriginal(B); |
| if (B == origLI->getHeader()) { |
| IRBuilder<> NB(newB); |
| for (auto inst : loopShadowZeroInits) { |
| auto anti = lookupM(invertPointerM(inst, NB), NB); |
| StringRef funcName; |
| SmallVector<Value *, 8> args; |
| if (auto orig = dyn_cast<CallInst>(inst)) { |
| #if LLVM_VERSION_MAJOR >= 14 |
| for (auto &arg : orig->args()) |
| #else |
| for (auto &arg : orig->arg_operands()) |
| #endif |
| { |
| args.push_back(lookupM(getNewFromOriginal(arg), NB)); |
| } |
| funcName = getFuncNameFromCall(orig); |
| } else if (auto AI = dyn_cast<AllocaInst>(inst)) { |
| funcName = "malloc"; |
| Value *sz = |
| lookupM(getNewFromOriginal(AI->getArraySize()), NB); |
| |
| auto ci = ConstantInt::get( |
| sz->getType(), |
| B->getParent() |
| ->getParent() |
| ->getDataLayout() |
| .getTypeAllocSizeInBits(AI->getAllocatedType()) / |
| 8); |
| sz = NB.CreateMul(sz, ci); |
| args.push_back(sz); |
| } |
| assert(funcName.size()); |
| |
| applyChainRule( |
| NB, |
| [&](Value *anti) { |
| zeroKnownAllocation(NB, anti, args, funcName, TLI, |
| dyn_cast<CallInst>(inst)); |
| }, |
| anti); |
| } |
| } |
| } |
| |
| ValueToValueMapTy available; |
| |
| { |
| IRBuilder<> NB(enterB); |
| NB.CreateBr(origToNewForward[origLI->getHeader()]); |
| } |
| |
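| // Recreate induction variables for the rematerialized copy of the |
| // loop nest so that value lookups see the correct iteration indices. |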
| std::function<void(Loop *, bool)> handleLoop = [&](Loop *OL, |
| bool subLoop) { |
| if (subLoop) { |
| auto Header = OL->getHeader(); |
| IRBuilder<> NB(origToNewForward[Header]); |
| LoopContext flc; |
| getContext(getNewFromOriginal(Header), flc); |
| |
| auto iv = NB.CreatePHI(flc.var->getType(), 2, "fiv"); |
| auto inc = NB.CreateAdd(iv, ConstantInt::get(iv->getType(), 1)); |
| |
| for (auto PH : predecessors(Header)) { |
| if (notForAnalysis.count(PH)) |
| continue; |
| |
| if (OL->contains(PH)) |
| iv->addIncoming(inc, origToNewForward[PH]); |
| else |
| iv->addIncoming(ConstantInt::get(iv->getType(), 0), |
| origToNewForward[PH]); |
| } |
| available[flc.var] = iv; |
| available[flc.incvar] = inc; |
| } |
| for (auto SL : OL->getSubLoops()) |
| handleLoop(SL, /*subLoop*/ true); |
| }; |
| handleLoop(origLI, /*subLoop*/ false); |
| |
| for (auto B : origLI->getBlocks()) { |
| auto newB = origToNewForward[B]; |
| IRBuilder<> NB(newB); |
| |
| // TODO fill available with the relevant surrounding IVs and |
| // the IVs of inner loop phis |
| |
| for (auto &I : *B) { |
| // Only handle store, memset, memcpy/memmove, and julia.write_barrier |
| if (loopRematerializations.count(&I)) { |
| if (auto SI = dyn_cast<StoreInst>(&I)) { |
| auto ts = NB.CreateStore( |
| lookupM(getNewFromOriginal(SI->getValueOperand()), NB, |
| available), |
| lookupM(getNewFromOriginal(SI->getPointerOperand()), NB, |
| available)); |
| llvm::SmallVector<unsigned int, 9> ToCopy2(MD_ToCopy); |
| ToCopy2.push_back(LLVMContext::MD_noalias); |
| ToCopy2.push_back(LLVMContext::MD_alias_scope); |
| ts->copyMetadata(*SI, ToCopy2); |
| #if LLVM_VERSION_MAJOR >= 10 |
| ts->setAlignment(SI->getAlign()); |
| #else |
| ts->setAlignment(SI->getAlignment()); |
| #endif |
| ts->setVolatile(SI->isVolatile()); |
| ts->setOrdering(SI->getOrdering()); |
| ts->setSyncScopeID(SI->getSyncScopeID()); |
| ts->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); |
| } else if (auto CI = dyn_cast<CallInst>(&I)) { |
| StringRef funcName = getFuncNameFromCall(CI); |
| if (funcName == "julia.write_barrier" || |
| isa<MemSetInst>(&I) || isa<MemTransferInst>(&I)) { |
| |
| // TODO |
| SmallVector<Value *, 2> args; |
| #if LLVM_VERSION_MAJOR >= 14 |
| for (auto &arg : CI->args()) |
| #else |
| for (auto &arg : CI->arg_operands()) |
| #endif |
| args.push_back( |
| lookupM(getNewFromOriginal(arg), NB, available)); |
| |
| SmallVector<ValueType, 2> BundleTypes(args.size(), |
| ValueType::Primal); |
| |
| auto Defs = getInvertedBundles(CI, BundleTypes, NB, |
| /*lookup*/ true, available); |
| #if LLVM_VERSION_MAJOR >= 11 |
| auto cal = |
| NB.CreateCall(CI->getFunctionType(), |
| CI->getCalledOperand(), args, Defs); |
| #else |
| auto cal = NB.CreateCall(CI->getCalledValue(), args, Defs); |
| #endif |
| cal->setAttributes(CI->getAttributes()); |
| cal->setCallingConv(CI->getCallingConv()); |
| cal->setTailCallKind(CI->getTailCallKind()); |
| cal->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); |
| } else { |
| assert(isDeallocationFunction(funcName, TLI)); |
| continue; |
| } |
| } else { |
| assert(0 && "unhandlable loop rematerialization instruction"); |
| } |
| } else if (loopReallocations.count(&I)) { |
| LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, |
| &newFunc->getEntryBlock()); |
| |
| auto inst = getNewFromOriginal((Value *)&I); |
| |
| auto found = scopeMap.find(inst); |
| if (found == scopeMap.end()) { |
| AllocaInst *cache = |
| createCacheForScope(lctx, inst->getType(), |
| inst->getName(), /*shouldFree*/ true); |
| assert(cache); |
| found = insert_or_assign( |
| scopeMap, inst, |
| std::pair<AssertingVH<AllocaInst>, LimitContext>(cache, |
| lctx)); |
| } |
| auto cache = found->second.first; |
| if (auto MD = hasMetadata(&I, "enzyme_fromstack")) { |
| auto replacement = NB.CreateAlloca( |
| Type::getInt8Ty(I.getContext()), |
| lookupM(getNewFromOriginal(I.getOperand(0)), NB, |
| available)); |
| auto Alignment = cast<ConstantInt>(cast<ConstantAsMetadata>( |
| MD->getOperand(0)) |
| ->getValue()) |
| ->getLimitedValue(); |
| #if LLVM_VERSION_MAJOR >= 10 |
| replacement->setAlignment(Align(Alignment)); |
| #else |
| replacement->setAlignment(Alignment); |
| #endif |
| replacement->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); |
| storeInstructionInCache(lctx, NB, replacement, cache); |
| } else if (auto CI = dyn_cast<CallInst>(&I)) { |
| SmallVector<Value *, 2> args; |
| #if LLVM_VERSION_MAJOR >= 14 |
| for (auto &arg : CI->args()) |
| #else |
| for (auto &arg : CI->arg_operands()) |
| #endif |
| args.push_back( |
| lookupM(getNewFromOriginal(arg), NB, available)); |
| |
| SmallVector<ValueType, 2> BundleTypes(args.size(), |
| ValueType::Primal); |
| |
| auto Defs = getInvertedBundles(CI, BundleTypes, NB, |
| /*lookup*/ true, available); |
| auto cal = NB.CreateCall(CI->getCalledFunction(), args, Defs); |
| llvm::SmallVector<unsigned int, 9> ToCopy2(MD_ToCopy); |
| ToCopy2.push_back(LLVMContext::MD_noalias); |
| ToCopy2.push_back(LLVMContext::MD_alias_scope); |
| cal->copyMetadata(*CI, ToCopy2); |
| cal->setName("remat_" + CI->getName()); |
| cal->setAttributes(CI->getAttributes()); |
| cal->setCallingConv(CI->getCallingConv()); |
| cal->setTailCallKind(CI->getTailCallKind()); |
| cal->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); |
| storeInstructionInCache(lctx, NB, cal, cache); |
| } else { |
| llvm::errs() << " realloc: " << I << "\n"; |
| llvm_unreachable("Unknown loop reallocation"); |
| } |
| } |
| if (loopShadowRematerializations.count(&I)) { |
| if (auto SI = dyn_cast<StoreInst>(&I)) { |
| Value *orig_ptr = SI->getPointerOperand(); |
| Value *orig_val = SI->getValueOperand(); |
| Type *valType = orig_val->getType(); |
| assert(!isConstantValue(orig_ptr)); |
| |
| auto &DL = newFunc->getParent()->getDataLayout(); |
| |
| bool constantval = isConstantValue(orig_val) || |
| parseTBAA(I, DL).Inner0().isIntegral(); |
| |
| // TODO allow recognition of other types that could contain |
| // pointers [e.g. {void*, void*} or <2 x i64> ] |
| auto storeSize = DL.getTypeSizeInBits(valType) / 8; |
| |
| //! Storing a floating point value |
| Type *FT = nullptr; |
| if (valType->isFPOrFPVectorTy()) { |
| FT = valType->getScalarType(); |
| } else if (!valType->isPointerTy()) { |
| if (looseTypeAnalysis) { |
| auto fp = TR.firstPointer(storeSize, orig_ptr, &I, |
| /*errifnotfound*/ false, |
| /*pointerIntSame*/ true); |
| if (fp.isKnown()) { |
| FT = fp.isFloat(); |
| } else if (isa<ConstantInt>(orig_val) || |
| valType->isIntOrIntVectorTy()) { |
| llvm::errs() |
| << "assuming type as integral for store: " << I |
| << "\n"; |
| FT = nullptr; |
| } else { |
| TR.firstPointer(storeSize, orig_ptr, &I, |
| /*errifnotfound*/ true, |
| /*pointerIntSame*/ true); |
| llvm::errs() |
| << "cannot deduce type of store " << I << "\n"; |
| assert(0 && "cannot deduce"); |
| } |
| } else { |
| FT = TR.firstPointer(storeSize, orig_ptr, &I, |
| /*errifnotfound*/ true, |
| /*pointerIntSame*/ true) |
| .isFloat(); |
| } |
| } |
| if (!FT) { |
| Value *valueop = nullptr; |
| if (constantval) { |
| Value *val = |
| lookupM(getNewFromOriginal(orig_val), NB, available); |
| valueop = val; |
| if (getWidth() > 1) { |
| Value *array = |
| UndefValue::get(getShadowType(val->getType())); |
| for (unsigned i = 0; i < getWidth(); ++i) { |
| array = NB.CreateInsertValue(array, val, {i}); |
| } |
| valueop = array; |
| } |
| } else { |
| valueop = |
| lookupM(invertPointerM(orig_val, NB), NB, available); |
| } |
| SmallVector<Metadata *, 1> prevNoAlias; |
| if (auto prev = SI->getMetadata(LLVMContext::MD_noalias)) { |
| for (auto &M : cast<MDNode>(prev)->operands()) { |
| prevNoAlias.push_back(M); |
| } |
| } |
| #if LLVM_VERSION_MAJOR >= 10 |
| auto align = SI->getAlign(); |
| #else |
| auto align = SI->getAlignment(); |
| #endif |
| setPtrDiffe(SI, orig_ptr, valueop, NB, align, |
| SI->isVolatile(), SI->getOrdering(), |
| SI->getSyncScopeID(), |
| /*mask*/ nullptr, prevNoAlias); |
| } |
| // TODO shadow memtransfer |
| } else if (auto MS = dyn_cast<MemSetInst>(&I)) { |
| if (!isConstantValue(MS->getArgOperand(0))) { |
| Value *args[4] = { |
| lookupM(invertPointerM(MS->getArgOperand(0), NB), NB, |
| available), |
| lookupM(getNewFromOriginal(MS->getArgOperand(1)), NB, |
| available), |
| lookupM(getNewFromOriginal(MS->getArgOperand(2)), NB, |
| available), |
| lookupM(getNewFromOriginal(MS->getArgOperand(3)), NB, |
| available)}; |
| |
| ValueType BundleTypes[4] = { |
| ValueType::Shadow, ValueType::Primal, ValueType::Primal, |
| ValueType::Primal}; |
| auto Defs = getInvertedBundles(MS, BundleTypes, NB, |
| /*lookup*/ true, available); |
| auto cal = |
| NB.CreateCall(MS->getCalledFunction(), args, Defs); |
| llvm::SmallVector<unsigned int, 9> ToCopy2(MD_ToCopy); |
| ToCopy2.push_back(LLVMContext::MD_noalias); |
| ToCopy2.push_back(LLVMContext::MD_alias_scope); |
| cal->copyMetadata(*MS, ToCopy2); |
| cal->setAttributes(MS->getAttributes()); |
| cal->setCallingConv(MS->getCallingConv()); |
| cal->setTailCallKind(MS->getTailCallKind()); |
| cal->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); |
| } |
| } else if (auto CI = dyn_cast<CallInst>(&I)) { |
| StringRef funcName = getFuncNameFromCall(CI); |
| if (funcName == "julia.write_barrier") { |
| |
| // TODO |
| SmallVector<Value *, 2> args; |
| #if LLVM_VERSION_MAJOR >= 14 |
| for (auto &arg : CI->args()) |
| #else |
| for (auto &arg : CI->arg_operands()) |
| #endif |
| if (!isConstantValue(arg)) |
| args.push_back( |
| lookupM(invertPointerM(arg, NB), NB, available)); |
| |
| if (args.size()) { |
| SmallVector<ValueType, 2> BundleTypes(args.size(), |
| ValueType::Primal); |
| |
| auto Defs = |
| getInvertedBundles(CI, BundleTypes, NB, |
| /*lookup*/ true, available); |
| #if LLVM_VERSION_MAJOR >= 11 |
| auto cal = |
| NB.CreateCall(CI->getFunctionType(), |
| CI->getCalledOperand(), args, Defs); |
| #else |
| auto cal = |
| NB.CreateCall(CI->getCalledValue(), args, Defs); |
| #endif |
| cal->setAttributes(CI->getAttributes()); |
| cal->setCallingConv(CI->getCallingConv()); |
| cal->setTailCallKind(CI->getTailCallKind()); |
| cal->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); |
| } |
| } else { |
| assert(isDeallocationFunction(funcName, TLI)); |
| continue; |
| } |
| } else { |
| assert( |
| 0 && |
| "unhandlable loop shadow rematerialization instruction"); |
| } |
| } else if (loopShadowReallocations.count(&I)) { |
| |
| LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, |
| &newFunc->getEntryBlock()); |
| auto ipfound = invertedPointers.find(&I); |
| PHINode *placeholder = cast<PHINode>(&*ipfound->second); |
| |
| auto found = scopeMap.find(placeholder); |
| if (found == scopeMap.end()) { |
| AllocaInst *cache = createCacheForScope( |
| lctx, placeholder->getType(), placeholder->getName(), |
| /*shouldFree*/ true); |
| assert(cache); |
| found = insert_or_assign( |
| scopeMap, (Value *&)placeholder, |
| std::pair<AssertingVH<AllocaInst>, LimitContext>(cache, |
| lctx)); |
| } |
| auto cache = found->second.first; |
| Value *anti = nullptr; |
| |
| if (auto orig = dyn_cast<CallInst>(&I)) { |
| StringRef funcName = getFuncNameFromCall(orig); |
| assert(funcName.size()); |
| |
| auto dbgLoc = getNewFromOriginal(orig)->getDebugLoc(); |
| |
| SmallVector<Value *, 8> args; |
| #if LLVM_VERSION_MAJOR >= 14 |
| for (auto &arg : orig->args()) |
| #else |
| for (auto &arg : orig->arg_operands()) |
| #endif |
| { |
| args.push_back(lookupM(getNewFromOriginal(arg), NB)); |
| } |
| |
| placeholder->setName(""); |
| if (shadowHandlers.find(funcName.str()) != |
| shadowHandlers.end()) { |
| |
| anti = shadowHandlers[funcName.str()](NB, orig, args, this); |
| } else { |
| auto rule = [&]() { |
| #if LLVM_VERSION_MAJOR >= 11 |
| Value *anti = NB.CreateCall( |
| orig->getFunctionType(), orig->getCalledOperand(), |
| args, orig->getName() + "'mi"); |
| #else |
| Value *anti = NB.CreateCall(orig->getCalledValue(), args, |
| orig->getName() + "'mi"); |
| #endif |
| cast<CallInst>(anti)->setAttributes( |
| orig->getAttributes()); |
| cast<CallInst>(anti)->setCallingConv( |
| orig->getCallingConv()); |
| cast<CallInst>(anti)->setTailCallKind( |
| orig->getTailCallKind()); |
| cast<CallInst>(anti)->setDebugLoc( |
| getNewFromOriginal(I.getDebugLoc())); |
| |
| #if LLVM_VERSION_MAJOR >= 14 |
| cast<CallInst>(anti)->addAttributeAtIndex( |
| AttributeList::ReturnIndex, Attribute::NoAlias); |
| cast<CallInst>(anti)->addAttributeAtIndex( |
| AttributeList::ReturnIndex, Attribute::NonNull); |
| #else |
| cast<CallInst>(anti)->addAttribute( |
| AttributeList::ReturnIndex, Attribute::NoAlias); |
| cast<CallInst>(anti)->addAttribute( |
| AttributeList::ReturnIndex, Attribute::NonNull); |
| #endif |
| return anti; |
| }; |
| |
| anti = applyChainRule(orig->getType(), NB, rule); |
| |
| if (auto MD = hasMetadata(orig, "enzyme_fromstack")) { |
| auto rule = [&](Value *anti) { |
| AllocaInst *replacement = NB.CreateAlloca( |
| Type::getInt8Ty(orig->getContext()), args[0]); |
| replacement->takeName(anti); |
| auto Alignment = |
| cast<ConstantInt>( |
| cast<ConstantAsMetadata>(MD->getOperand(0)) |
| ->getValue()) |
| ->getLimitedValue(); |
| #if LLVM_VERSION_MAJOR >= 10 |
| replacement->setAlignment(Align(Alignment)); |
| #else |
| replacement->setAlignment(Alignment); |
| #endif |
| replacement->setDebugLoc( |
| getNewFromOriginal(I.getDebugLoc())); |
| return replacement; |
| }; |
| |
| Value *replacement = applyChainRule( |
| Type::getInt8Ty(orig->getContext()), NB, rule, anti); |
| |
| replaceAWithB(cast<Instruction>(anti), replacement); |
| erase(cast<Instruction>(anti)); |
| anti = replacement; |
| } |
| |
| applyChainRule( |
| NB, |
| [&](Value *anti) { |
| zeroKnownAllocation(NB, anti, args, funcName, TLI, |
| orig); |
| }, |
| anti); |
| } |
| } else { |
| llvm_unreachable("Unknown shadow rematerialization value"); |
| } |
| assert(anti); |
| storeInstructionInCache(lctx, NB, anti, cache); |
| } |
| } |
| |
| llvm::SmallPtrSet<llvm::BasicBlock *, 8> origExitBlocks; |
| getExitBlocks(origLI, origExitBlocks); |
| // Remap a branch to the header to enter the incremented |
| // reverse of that block. |
| auto remap = [&](BasicBlock *rB) { |
| // An exit branch is remapped to the reverse of the |
| // exiting block. |
| if (origExitBlocks.count(rB)) { |
| return reverseBlocks[getNewFromOriginal(B)].front(); |
| } |
| // An incrementing branch (back to the header) is remapped to |
| // the reverse of the branching block. |
| if (rB == origLI->getHeader()) |
| return reverseBlocks[getNewFromOriginal(B)].front(); |
| return origToNewForward[rB]; |
| }; |
| |
| // TODO clone terminator |
| auto TI = B->getTerminator(); |
| assert(TI); |
| if (notForAnalysis.count(B)) { |
| NB.CreateUnreachable(); |
| } else if (auto BI = dyn_cast<BranchInst>(TI)) { |
| if (BI->isUnconditional()) |
| NB.CreateBr(remap(BI->getSuccessor(0))); |
| else |
| NB.CreateCondBr(lookupM(getNewFromOriginal(BI->getCondition()), |
| NB, available), |
| remap(BI->getSuccessor(0)), |
| remap(BI->getSuccessor(1))); |
| } else if (auto SI = dyn_cast<SwitchInst>(TI)) { |
| auto NSI = NB.CreateSwitch( |
| lookupM(getNewFromOriginal(SI->getCondition()), NB, |
| available), |
| remap(SI->getDefaultDest())); |
| for (auto cas : SI->cases()) { |
| NSI->addCase(cas.getCaseValue(), remap(cas.getCaseSuccessor())); |
| } |
| } else { |
| assert(isa<UnreachableInst>(TI)); |
| NB.CreateUnreachable(); |
| } |
| // Fixup phi nodes that may have their predecessors now changed by |
| // the phi unwrapping |
| if (!notForAnalysis.count(B) && |
| NB.GetInsertBlock() != origToNewForward[B]) { |
| for (auto S0 : successors(B)) { |
| if (!origToNewForward.count(S0)) |
| continue; |
| auto S = origToNewForward[S0]; |
| assert(S); |
| for (auto I = S->begin(), E = S->end(); I != E; ++I) { |
| PHINode *orig = dyn_cast<PHINode>(&*I); |
| if (orig == nullptr) |
| break; |
| for (unsigned Op = 0, NumOps = orig->getNumOperands(); |
| Op != NumOps; ++Op) |
| if (orig->getIncomingBlock(Op) == origToNewForward[B]) |
| orig->setIncomingBlock(Op, NB.GetInsertBlock()); |
| } |
| } |
| } |
| } |
| resumeblock = enterB; |
| } |
| } |
| |
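| // A latch (increment) edge decrements the reverse induction variable |
| // before resuming; an exit edge initializes it to the loop limit. |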
| if (incEntering) { |
| BasicBlock *incB = BasicBlock::Create( |
| BB->getContext(), |
| "inc" + reverseBlocks[lc.header].front()->getName(), |
| BB->getParent()); |
| incB->moveAfter(reverseBlocks[lc.header].back()); |
| |
| IRBuilder<> tbuild(incB); |
| |
| #if LLVM_VERSION_MAJOR > 7 |
| Value *av = tbuild.CreateLoad(lc.var->getType(), lc.antivaralloc); |
| #else |
| Value *av = tbuild.CreateLoad(lc.antivaralloc); |
| #endif |
| Value *sub = |
| tbuild.CreateAdd(av, ConstantInt::get(av->getType(), -1), "", |
| /*NUW*/ false, /*NSW*/ true); |
| tbuild.CreateStore(sub, lc.antivaralloc); |
| tbuild.CreateBr(resumeblock); |
| return newBlocksForLoop_cache[tup] = incB; |
| } else { |
| assert(exitEntering); |
| BasicBlock *incB = BasicBlock::Create( |
| BB->getContext(), |
| "merge" + reverseBlocks[lc.header].front()->getName() + "_" + |
| branchingBlock->getName(), |
| BB->getParent()); |
| incB->moveAfter(reverseBlocks[branchingBlock].back()); |
| |
| IRBuilder<> tbuild(reverseBlocks[branchingBlock].back()); |
| |
| Value *lim = nullptr; |
| if (lc.dynamic && assumeDynamicLoopOfSizeOne(L)) { |
| lim = ConstantInt::get(lc.var->getType(), 0); |
| } else if (lc.dynamic) { |
| // Must be in a reverse-pass setting for a lookup of the index bound |
| // to be legal |
| assert(/*ReverseLimit*/ reverseBlocks.size() > 0); |
| LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, |
| lc.preheader); |
| lim = lookupValueFromCache( |
| /*forwardPass*/ false, tbuild, lctx, |
| getDynamicLoopLimit(LI.getLoopFor(lc.header)), |
| /*isi1*/ false, /*available*/ ValueToValueMapTy()); |
| } else { |
| lim = lookupM(lc.trueLimit, tbuild); |
| } |
| |
| tbuild.SetInsertPoint(incB); |
| tbuild.CreateStore(lim, lc.antivaralloc); |
| tbuild.CreateBr(resumeblock); |
| |
| return newBlocksForLoop_cache[tup] = incB; |
| } |
| } |
| } |
| |
| return newBlocksForLoop_cache[tup] = reverseBlocks[BB].front(); |
| } |
| |
| void GradientUtils::forceContexts() { |
| for (auto BB : originalBlocks) { |
| LoopContext lc; |
| getContext(BB, lc); |
| } |
| } |
| |
| bool GradientUtils::legalRecompute(const Value *val, |
| const ValueToValueMapTy &available, |
| IRBuilder<> *BuilderM, bool reverse, |
| bool legalRecomputeCache) const { |
| { |
| auto found = available.find(val); |
| if (found != available.end()) { |
| if (found->second) |
| return true; |
| else { |
| return false; |
| } |
| } |
| } |
| |
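| // A phi is only legal to recompute if none of its incoming values may |
| // change on the last loop iteration and, for loop headers, the |
| // recurrence does not feed back into the phi itself. |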
| if (auto phi = dyn_cast<PHINode>(val)) { |
| if (auto uiv = hasUninverted(val)) { |
| if (auto dli = dyn_cast_or_null<LoadInst>(uiv)) { |
| return legalRecompute( |
| dli, available, BuilderM, |
| reverse); // TODO ADD && !TR.intType(getOriginal(dli), |
| // /*mustfind*/false).isPossibleFloat(); |
| } |
| if (phi->getNumIncomingValues() == 0) { |
| return false; |
| } |
| } |
| |
| if (phi->getNumIncomingValues() == 0) { |
| llvm::errs() << *oldFunc << "\n"; |
| llvm::errs() << *newFunc << "\n"; |
| llvm::errs() << *phi << "\n"; |
| } |
| assert(phi->getNumIncomingValues() != 0); |
| auto parent = phi->getParent(); |
| struct { |
| Function *func; |
| const LoopInfo &FLI; |
| } options[2] = {{newFunc, LI}, {oldFunc, OrigLI}}; |
| for (const auto &tup : options) { |
| if (parent->getParent() == tup.func) { |
| for (auto &val : phi->incoming_values()) { |
| if (isPotentialLastLoopValue(val, parent, tup.FLI)) { |
| return false; |
| } |
| } |
| if (tup.FLI.isLoopHeader(parent)) { |
| // Currently can only recompute header |
| // with two incoming values |
| if (phi->getNumIncomingValues() != 2) |
| return false; |
| auto L = tup.FLI.getLoopFor(parent); |
| |
| // Only recomputable if non-recursive. |
| SmallPtrSet<Instruction *, 2> seen; |
| SmallVector<Instruction *, 1> todo; |
| for (auto PH : predecessors(parent)) { |
| // Prior iterations must be recomputable without |
| // this value. |
| if (L->contains(PH)) { |
| if (auto I = |
| dyn_cast<Instruction>(phi->getIncomingValueForBlock(PH))) |
| if (L->contains(I->getParent())) |
| todo.push_back(I); |
| } |
| } |
| |
| while (todo.size()) { |
| auto cur = todo.back(); |
| todo.pop_back(); |
| if (seen.count(cur)) |
| continue; |
| seen.insert(cur); |
| if (cur == phi) |
| return false; |
| for (auto &op : cur->operands()) { |
| if (auto I = dyn_cast<Instruction>(op)) { |
| if (L->contains(I->getParent())) |
| todo.push_back(I); |
| } |
| } |
| } |
| } |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| if (isa<Instruction>(val) && |
| cast<Instruction>(val)->getMetadata("enzyme_mustcache")) { |
| return false; |
| } |
| |
| // If this is a load from cache already, don't force a cache of this |
| if (legalRecomputeCache && isa<LoadInst>(val) && |
| CacheLookups.count(cast<LoadInst>(val))) { |
| return true; |
| } |
| |
| // TODO consider callinst here |
| |
| if (auto li = dyn_cast<Instruction>(val)) { |
| |
| const IntrinsicInst *II; |
| if (isa<LoadInst>(li) || |
| ((II = dyn_cast<IntrinsicInst>(li)) && |
| (II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_i || |
| II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_p || |
| II->getIntrinsicID() == Intrinsic::nvvm_ldu_global_f || |
| II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_i || |
| II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_p || |
| II->getIntrinsicID() == Intrinsic::nvvm_ldg_global_f || |
| II->getIntrinsicID() == Intrinsic::masked_load))) { |
| // If this is an already unwrapped value, legal to recompute again. |
| if (unwrappedLoads.find(li) != unwrappedLoads.end()) |
| return legalRecompute(unwrappedLoads.find(li)->second, available, |
| BuilderM, reverse); |
| |
| const Instruction *orig = nullptr; |
| if (li->getParent()->getParent() == oldFunc) { |
| orig = li; |
| } else if (li->getParent()->getParent() == newFunc) { |
| orig = isOriginal(li); |
| // TODO consider when we pass non-original queries |
| if (orig && !isa<LoadInst>(orig)) { |
| return legalRecompute(orig, available, BuilderM, reverse, |
| legalRecomputeCache); |
| } |
| } else { |
| llvm::errs() << " newFunc: " << *newFunc << "\n"; |
| llvm::errs() << " parent: " << *li->getParent()->getParent() << "\n"; |
| llvm::errs() << " li: " << *li << "\n"; |
| assert(0 && "illegal load legalRecompute query"); |
| } |
| |
| if (orig) { |
| assert(can_modref_map); |
| auto found = can_modref_map->find(const_cast<Instruction *>(orig)); |
| if (found == can_modref_map->end()) { |
| llvm::errs() << *newFunc << "\n"; |
| llvm::errs() << *oldFunc << "\n"; |
| llvm::errs() << "can_modref_map:\n"; |
| for (auto &pair : *can_modref_map) { |
| llvm::errs() << " + " << *pair.first << ": " << pair.second |
| << " of func " |
| << pair.first->getParent()->getParent()->getName() |
| << "\n"; |
| } |
| llvm::errs() << "couldn't find in can_modref_map: " << *li << " - " |
| << *orig << " in fn: " |
| << orig->getParent()->getParent()->getName(); |
| } |
| assert(found != can_modref_map->end()); |
| if (!found->second) |
| return true; |
| // Determine the forward-pass block corresponding to this insertion point: |
| BasicBlock *fwdBlockIfReverse = nullptr; |
| if (BuilderM) { |
| fwdBlockIfReverse = BuilderM->GetInsertBlock(); |
| if (!reverse) { |
| auto found = reverseBlockToPrimal.find(BuilderM->GetInsertBlock()); |
| if (found != reverseBlockToPrimal.end()) { |
| fwdBlockIfReverse = found->second; |
| reverse = true; |
| } |
| } |
| if (fwdBlockIfReverse->getParent() != oldFunc) |
| fwdBlockIfReverse = |
| cast_or_null<BasicBlock>(isOriginal(fwdBlockIfReverse)); |
| } |
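| // In combined mode a potentially clobbered load may still be |
| // recomputed if no instruction between the original load and the |
| // point of use writes to the memory it reads. |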
| if (mode == DerivativeMode::ReverseModeCombined && fwdBlockIfReverse) { |
| if (reverse) { |
| bool failed = false; |
| allFollowersOf( |
| const_cast<Instruction *>(orig), [&](Instruction *I) -> bool { |
| if (I->mayWriteToMemory() && |
| writesToMemoryReadBy( |
| OrigAA, TLI, |
| /*maybeReader*/ const_cast<Instruction *>(orig), |
| /*maybeWriter*/ I)) { |
| failed = true; |
| EmitWarning( |
| "UncacheableLoad", *orig, "Load must be recomputed ", |
| *orig, " in reverse_", |
| BuilderM->GetInsertBlock()->getName(), " due to ", *I); |
| return /*earlyBreak*/ true; |
| } |
| return /*earlyBreak*/ false; |
| }); |
| if (!failed) |
| return true; |
| } else { |
| Instruction *origStart = &*BuilderM->GetInsertPoint(); |
| do { |
| if (Instruction *og = isOriginal(origStart)) { |
| origStart = og; |
| break; |
| } |
| origStart = origStart->getNextNode(); |
| } while (true); |
| if (OrigDT.dominates(origStart, const_cast<Instruction *>(orig))) { |
| bool failed = false; |
| |
| allInstructionsBetween( |
| const_cast<GradientUtils *>(this)->LI, origStart, |
| const_cast<Instruction *>(orig), [&](Instruction *I) -> bool { |
| if (I->mayWriteToMemory() && |
| writesToMemoryReadBy( |
| OrigAA, TLI, |
| /*maybeReader*/ const_cast<Instruction *>(orig), |
| /*maybeWriter*/ I)) { |
| failed = true; |
| EmitWarning("UncacheableLoad", *orig, |
| "Load must be recomputed ", *orig, " in ", |
| BuilderM->GetInsertBlock()->getName(), |
| " due to ", *I); |
| return /*earlyBreak*/ true; |
| } |
| return /*earlyBreak*/ false; |
| }); |
| if (!failed) |
| return true; |
| } |
| } |
| } |
| return false; |
| } else { |
| if (auto dli = dyn_cast_or_null<LoadInst>(hasUninverted(li))) { |
| return legalRecompute(dli, available, BuilderM, reverse); |
| } |
| |
| // TODO mark all the explicitly legal nodes (caches, etc) |
| return true; |
| llvm::errs() << *li << " orig: " << orig |
| << " parent: " << li->getParent()->getParent()->getName() |
| << "\n"; |
| llvm_unreachable("unknown load to redo!"); |
| } |
| } |
| } |
| |
| if (auto ci = dyn_cast<CallInst>(val)) { |
| auto n = getFuncNameFromCall(const_cast<CallInst *>(ci)); |
| auto called = ci->getCalledFunction(); |
| Intrinsic::ID ID = Intrinsic::not_intrinsic; |
| |
| if (ci->hasFnAttr("enzyme_shouldrecompute") || |
| (called && called->hasFnAttribute("enzyme_shouldrecompute")) || |
| isMemFreeLibMFunction(n, &ID) || n == "lgamma_r" || n == "lgammaf_r" || |
| n == "lgammal_r" || n == "__lgamma_r_finite" || |
| n == "__lgammaf_r_finite" || n == "__lgammal_r_finite" || n == "tanh" || |
| n == "tanhf" || n == "__pow_finite" || n == "__fd_sincos_1" || |
| n == "julia.pointer_from_objref" || n.startswith("enzyme_wrapmpi$$") || |
| n == "omp_get_thread_num" || n == "omp_get_max_threads") { |
| return true; |
| } |
| } |
| |
| if (auto inst = dyn_cast<Instruction>(val)) { |
| if (inst->mayReadOrWriteMemory()) { |
| return false; |
| } |
| } |
| |
| return true; |
| } |
| |
| //! Given the option to recompute a value or re-use an old one, return true if |
| //! it is faster to recompute this value from scratch |
| bool GradientUtils::shouldRecompute(const Value *val, |
| const ValueToValueMapTy &available, |
| IRBuilder<> *BuilderM) { |
| if (available.count(val)) |
| return true; |
| // TODO: remake such that this returns whether a load to a cache is more |
| // expensive than redoing the computation. |
| |
| // If this is a load from cache already, just reload this |
| if (isa<LoadInst>(val) && |
| cast<LoadInst>(val)->getMetadata("enzyme_fromcache")) |
| return true; |
| |
| if (!isa<Instruction>(val)) |
| return true; |
| |
| const Instruction *inst = cast<Instruction>(val); |
| |
| if (TapesToPreventRecomputation.count(inst)) |
| return false; |
| |
| if (knownRecomputeHeuristic.find(inst) != knownRecomputeHeuristic.end()) { |
| return knownRecomputeHeuristic[inst]; |
| } |
| if (auto OrigInst = isOriginal(inst)) { |
| if (knownRecomputeHeuristic.find(OrigInst) != |
| knownRecomputeHeuristic.end()) { |
| return knownRecomputeHeuristic[OrigInst]; |
| } |
| } |
| |
| if (isa<CastInst>(val) || isa<GetElementPtrInst>(val)) |
| return true; |
| |
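| // If an operand of this instruction must itself be cached, prefer |
| // caching this value instead, unless the operand is already cached or |
| // lives in a different loop scope. |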
| if (EnzymeNewCache && !EnzymeMinCutCache) { |
| // if this has operands that need to be loaded and haven't already been |
| // loaded |
| // TODO, just cache this |
| for (auto &op : inst->operands()) { |
| if (!legalRecompute(op, available, BuilderM)) { |
| |
| // If this is a load from cache already, don't force a cache of this |
| if (isa<LoadInst>(op) && CacheLookups.count(cast<LoadInst>(op))) |
| continue; |
| |
| // If we previously cached this operand, don't let it trigger the |
| // heuristic for caching this value instead. |
| if (scopeMap.find(op) != scopeMap.end()) |
| continue; |
| |
| // If the uncacheable operand lives in a different loop scope, don't |
| // cache this value instead, as doing so may require more memory. |
| LoopContext lc1; |
| LoopContext lc2; |
| bool inLoop1 = |
| getContext(const_cast<Instruction *>(inst)->getParent(), lc1); |
| bool inLoop2 = getContext(cast<Instruction>(op)->getParent(), lc2); |
| if (inLoop1 != inLoop2 || (inLoop1 && (lc1.header != lc2.header))) { |
| continue; |
| } |
| |
| // If this is a placeholder phi for inversion (which we know from above |
| // is not recomputable), force a cache of this value. |
| if (!isa<PHINode>(op) && |
| dyn_cast_or_null<LoadInst>(hasUninverted(op))) { |
| goto forceCache; |
| } |
| |
| // Even if the operand cannot be recomputed (say, a phi node), don't |
| // force a reload if this instruction can simply be reused from the |
| // forward pass without issue. |
| if (auto i2 = dyn_cast<Instruction>(op)) { |
| if (!i2->mayReadOrWriteMemory()) { |
| LoopContext lc; |
| bool inLoop = const_cast<GradientUtils *>(this)->getContext( |
| i2->getParent(), lc); |
| if (!inLoop) { |
| // TODO: upgrade this to handle all returns that could be reached from |
| // here. |
| BasicBlock *orig = isOriginal(i2->getParent()); |
| assert(orig); |
| bool legal = BlocksDominatingAllReturns.count(orig); |
| if (legal) { |
| continue; |
| } |
| } |
| } |
| } |
| forceCache:; |
| EmitWarning("ChosenCache", *inst, "Choosing to cache use ", *inst, |
| " due to ", *op); |
| return false; |
| } |
| } |
| } |
| |
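| // Cheap math intrinsics (sin, cos, exp, log) and read-only NVVM global |
| // loads are considered cheaper to recompute than to cache. |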
| if (auto op = dyn_cast<IntrinsicInst>(val)) { |
| if (!op->mayReadOrWriteMemory()) |
| return true; |
| switch (op->getIntrinsicID()) { |
| case Intrinsic::sin: |
| case Intrinsic::cos: |
| case Intrinsic::exp: |
| case Intrinsic::log: |
| case Intrinsic::nvvm_ldu_global_i: |
| case Intrinsic::nvvm_ldu_global_p: |
| case Intrinsic::nvvm_ldu_global_f: |
| case Intrinsic::nvvm_ldg_global_i: |
| case Intrinsic::nvvm_ldg_global_p: |
| case Intrinsic::nvvm_ldg_global_f: |
| return true; |
| default: |
| return false; |
| } |
| } |
| |
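| // Known memory-free libm-style calls (lgamma variants, tanh, etc.) and |
| // OpenMP thread queries are cheap and safe to rerun, so recompute them |
| // rather than caching their results. |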
| if (auto ci = dyn_cast<CallInst>(val)) { |
| auto called = ci->getCalledFunction(); |
| auto n = getFuncNameFromCall(const_cast<CallInst *>(ci)); |
| Intrinsic::ID ID = Intrinsic::not_intrinsic; |
| if ((called && called->hasFnAttribute("enzyme_shouldrecompute")) || |
| isMemFreeLibMFunction(n, &ID) || n == "lgamma_r" || n == "lgammaf_r" || |
| n == "lgammal_r" || n == "__lgamma_r_finite" || |
| n == "__lgammaf_r_finite" || n == "__lgammal_r_finite" || n == "tanh" || |
| n == "tanhf" || n == "__pow_finite" || n == "__fd_sincos_1" || |
| n == "julia.pointer_from_objref" || n.startswith("enzyme_wrapmpi$$") || |
| n == "omp_get_thread_num" || n == "omp_get_max_threads") { |
| return true; |
| } |
| } |
| |
| // Cache a call, assuming it takes longer to rerun it than to reload its |
| // result from the cache. |
| if (isa<CallInst>(val)) { |
| llvm::errs() << " caching call: " << *val << "\n"; |
| // cast<CallInst>(val)->getCalledFunction()->dump(); |
| return false; |
| } |
| |
| return true; |
| } |
| |
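| // Construct a GradientUtils by cloning `todiff` into its augmented |
| // forward-pass ("fakeaugmented") form. The augmented return struct always |
| // holds the tape at index 0, followed by the primal return (if used) and |
| // then the shadow return (if used), e.g. {tape, return, shadow return} |
| // when both are requested. |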
| GradientUtils *GradientUtils::CreateFromClone( |
| EnzymeLogic &Logic, unsigned width, Function *todiff, |
| TargetLibraryInfo &TLI, TypeAnalysis &TA, FnTypeInfo &oldTypeInfo, |
| DIFFE_TYPE retType, ArrayRef<DIFFE_TYPE> constant_args, bool returnUsed, |
| bool shadowReturnUsed, std::map<AugmentedStruct, int> &returnMapping, |
| bool omp) { |
| assert(!todiff->empty()); |
| Function *oldFunc = todiff; |
| |
| // Since this is the forward pass, it should always return the tape (at |
| // index 0). |
| returnMapping[AugmentedStruct::Tape] = 0; |
| |
| int returnCount = 0; |
| |
| if (returnUsed) { |
| assert(!todiff->getReturnType()->isEmptyTy()); |
| assert(!todiff->getReturnType()->isVoidTy()); |
| returnMapping[AugmentedStruct::Return] = returnCount + 1; |
| ++returnCount; |
| } |
| |
| // We don't need to return a shadow for something that we know is not a |
| // pointer (nor otherwise needed for shadow analysis). |
| if (shadowReturnUsed) { |
| assert(retType == DIFFE_TYPE::DUP_ARG || retType == DIFFE_TYPE::DUP_NONEED); |
| assert(!todiff->getReturnType()->isEmptyTy()); |
| assert(!todiff->getReturnType()->isVoidTy()); |
| returnMapping[AugmentedStruct::DifferentialReturn] = returnCount + 1; |
| ++returnCount; |
| } |
| |
| ReturnType returnValue; |
| if (returnCount == 0) |
| returnValue = ReturnType::Tape; |
| else if (returnCount == 1) |
| returnValue = ReturnType::TapeAndReturn; |
| else if (returnCount == 2) |
| returnValue = ReturnType::TapeAndTwoReturns; |
| else |
| llvm_unreachable("illegal number of elements in augmented return struct"); |
| |
| ValueToValueMapTy invertedPointers; |
| SmallPtrSet<Instruction *, 4> constants; |
| SmallPtrSet<Instruction *, 20> nonconstant; |
| SmallPtrSet<Value *, 2> returnvals; |
| ValueToValueMapTy originalToNew; |
| |
| SmallPtrSet<Value *, 4> constant_values; |
| SmallPtrSet<Value *, 4> nonconstant_values; |
| |
| std::string prefix = "fakeaugmented"; |
| if (width > 1) |
| prefix += std::to_string(width); |
| prefix += "_"; |
| prefix += todiff->getName().str(); |
| |
| auto newFunc = Logic.PPC.CloneFunctionWithReturns( |
| DerivativeMode::ReverseModePrimal, /* width */ width, oldFunc, |
| invertedPointers, constant_args, constant_values, nonconstant_values, |
| returnvals, |
| /*returnValue*/ returnValue, retType, prefix, &originalToNew, |
| /*diffeReturnArg*/ false, /*additionalArg*/ nullptr); |
| |
| // Convert uncacheable args from the input function to the preprocessed |
| // function |
| |
| FnTypeInfo typeInfo(oldFunc); |
| { |
| auto toarg = todiff->arg_begin(); |
| auto olarg = oldFunc->arg_begin(); |
| for (; toarg != todiff->arg_end(); ++toarg, ++olarg) { |
| |
| { |
| auto fd = oldTypeInfo.Arguments.find(toarg); |
| assert(fd != oldTypeInfo.Arguments.end()); |
| typeInfo.Arguments.insert( |
| std::pair<Argument *, TypeTree>(olarg, fd->second)); |
| } |
| |
| { |
| auto cfd = oldTypeInfo.KnownValues.find(toarg); |
| assert(cfd != oldTypeInfo.KnownValues.end()); |
| typeInfo.KnownValues.insert( |
| std::pair<Argument *, std::set<int64_t>>(olarg, cfd->second)); |
| } |
| } |
| typeInfo.Return = oldTypeInfo.Return; |
| } |
| |
| TypeResults TR = TA.analyzeFunction(typeInfo); |
| assert(TR.getFunction() == oldFunc); |
| |
| auto res = new GradientUtils( |
| Logic, newFunc, oldFunc, TLI, TA, TR, invertedPointers, constant_values, |
| nonconstant_values, retType, constant_args, originalToNew, |
| DerivativeMode::ReverseModePrimal, /* width */ width, omp); |
| return res; |
| } |
| |
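| // Determine the activity type of a call's return value: CONSTANT for |
| // inactive values, DUP_ARG for possible pointers whose shadow is needed |
| // (and always in forward modes), and OUT_DIFF for active floating-point |
| // returns in reverse modes. Optionally reports whether the primal and/or |
| // shadow return is actually used. |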
| DIFFE_TYPE GradientUtils::getReturnDiffeType(llvm::CallInst *orig, |
| bool *primalReturnUsedP, |
| bool *shadowReturnUsedP) { |
| bool shadowReturnUsed = false; |
| |
| DIFFE_TYPE subretType; |
| if (isConstantValue(orig)) { |
| subretType = DIFFE_TYPE::CONSTANT; |
| } else { |
| if (mode == DerivativeMode::ForwardMode || |
| mode == DerivativeMode::ForwardModeSplit) { |
| subretType = DIFFE_TYPE::DUP_ARG; |
| shadowReturnUsed = true; |
| } else { |
| if (!orig->getType()->isFPOrFPVectorTy() && |
| TR.query(orig).Inner0().isPossiblePointer()) { |
| if (is_value_needed_in_reverse<ValueType::Shadow>( |
| this, orig, DerivativeMode::ReverseModePrimal, |
| notForAnalysis)) { |
| subretType = DIFFE_TYPE::DUP_ARG; |
| shadowReturnUsed = true; |
| } else |
| subretType = DIFFE_TYPE::CONSTANT; |
| } else { |
| subretType = DIFFE_TYPE::OUT_DIFF; |
| } |
| } |
| } |
| |
| if (primalReturnUsedP) { |
| bool subretused = |
| unnecessaryValuesP->find(orig) == unnecessaryValuesP->end(); |
| if (knownRecomputeHeuristic.find(orig) != knownRecomputeHeuristic.end()) { |
| if (!knownRecomputeHeuristic[orig]) { |
| subretused = true; |
| } |
| } |
| *primalReturnUsedP = subretused; |
| } |
| |
| if (shadowReturnUsedP) |
| *shadowReturnUsedP = shadowReturnUsed; |
| return subretType; |
| } |
| |
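| // Determine the activity type of a call operand: CONSTANT if inactive, |
| // DUP_ARG for possible pointers (DUP_NONEED when the underlying argument |
| // is DUP_NONEED), and otherwise OUT_DIFF in reverse modes or DUP_ARG in |
| // forward modes. |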
| DIFFE_TYPE GradientUtils::getDiffeType(Value *v, bool foreignFunction) { |
| if (isConstantValue(v) && !foreignFunction) { |
| return DIFFE_TYPE::CONSTANT; |
| } |
| |
| auto argType = v->getType(); |
| |
| if (!argType->isFPOrFPVectorTy() && |
| (TR.query(v).Inner0().isPossiblePointer() || foreignFunction)) { |
| if (argType->isPointerTy()) { |
| #if LLVM_VERSION_MAJOR >= 12 |
| auto at = getUnderlyingObject(v, 100); |
| #else |
| auto at = |
| GetUnderlyingObject(v, oldFunc->getParent()->getDataLayout(), 100); |
| #endif |
| if (auto arg = dyn_cast<Argument>(at)) { |
| if (ArgDiffeTypes[arg->getArgNo()] == DIFFE_TYPE::DUP_NONEED) { |
| return DIFFE_TYPE::DUP_NONEED; |
| } |
| } |
| } |
| return DIFFE_TYPE::DUP_ARG; |
| } else { |
| if (foreignFunction) |
| assert(!argType->isIntOrIntVectorTy()); |
| if (mode == DerivativeMode::ForwardMode || |
| mode == DerivativeMode::ForwardModeSplit) |
| return DIFFE_TYPE::DUP_ARG; |
| else |
| return DIFFE_TYPE::OUT_DIFF; |
| } |
| } |
| |
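| // As a naming sketch (following the prefixes chosen below): a width-4 |
| // forward derivative of `foo` is emitted as "fwddiffe4foo", while a |
| // width-1 reverse-mode gradient is emitted as "diffefoo". |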
| DiffeGradientUtils *DiffeGradientUtils::CreateFromClone( |
| EnzymeLogic &Logic, DerivativeMode mode, unsigned width, Function *todiff, |
| TargetLibraryInfo &TLI, TypeAnalysis &TA, FnTypeInfo &oldTypeInfo, |
| DIFFE_TYPE retType, bool diffeReturnArg, ArrayRef<DIFFE_TYPE> constant_args, |
| ReturnType returnValue, Type *additionalArg, bool omp) { |
| assert(!todiff->empty()); |
| Function *oldFunc = todiff; |
| assert(mode == DerivativeMode::ReverseModeGradient || |
| mode == DerivativeMode::ReverseModeCombined || |
| mode == DerivativeMode::ForwardMode || |
| mode == DerivativeMode::ForwardModeSplit); |
| ValueToValueMapTy invertedPointers; |
| SmallPtrSet<Instruction *, 4> constants; |
| SmallPtrSet<Instruction *, 20> nonconstant; |
| SmallPtrSet<Value *, 2> returnvals; |
| ValueToValueMapTy originalToNew; |
| |
| SmallPtrSet<Value *, 4> constant_values; |
| SmallPtrSet<Value *, 4> nonconstant_values; |
| |
| std::string prefix; |
| |
| switch (mode) { |
| case DerivativeMode::ForwardMode: |
| case DerivativeMode::ForwardModeSplit: |
| prefix = "fwddiffe"; |
| break; |
| case DerivativeMode::ReverseModeCombined: |
| case DerivativeMode::ReverseModeGradient: |
| prefix = "diffe"; |
| break; |
| case DerivativeMode::ReverseModePrimal: |
| llvm_unreachable("invalid DerivativeMode: ReverseModePrimal\n"); |
| } |
| |
| if (width > 1) |
| prefix += std::to_string(width); |
| |
| auto newFunc = Logic.PPC.CloneFunctionWithReturns( |
| mode, width, oldFunc, invertedPointers, constant_args, constant_values, |
| nonconstant_values, returnvals, returnValue, retType, |
| prefix + oldFunc->getName(), &originalToNew, |
| /*diffeReturnArg*/ diffeReturnArg, additionalArg); |
| |
| // Convert uncacheable args from the input function to the preprocessed |
| // function |
| |
| FnTypeInfo typeInfo(oldFunc); |
| { |
| auto toarg = todiff->arg_begin(); |
| auto olarg = oldFunc->arg_begin(); |
| for (; toarg != todiff->arg_end(); ++toarg, ++olarg) { |
| |
| { |
| auto fd = oldTypeInfo.Arguments.find(toarg); |
| assert(fd != oldTypeInfo.Arguments.end()); |
| typeInfo.Arguments.insert( |
| std::pair<Argument *, TypeTree>(olarg, fd->second)); |
| } |
| |
| { |
| auto cfd = oldTypeInfo.KnownValues.find(toarg); |
| assert(cfd != oldTypeInfo.KnownValues.end()); |
| typeInfo.KnownValues.insert( |
| std::pair<Argument *, std::set<int64_t>>(olarg, cfd->second)); |
| } |
| } |
| typeInfo.Return = oldTypeInfo.Return; |
| } |
| |
| TypeResults TR = TA.analyzeFunction(typeInfo); |
| assert(TR.getFunction() == oldFunc); |
| |
| auto res = new DiffeGradientUtils(Logic, newFunc, oldFunc, TLI, TA, TR, |
| invertedPointers, constant_values, |
| nonconstant_values, retType, constant_args, |
| originalToNew, mode, width, omp); |
| |
| return res; |
| } |
| |
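| // Recursively build a shadow for a constant: scalar constants are reused |
| // as-is, aggregates are rebuilt element-wise from shadowed elements, |
| // functions are shadowed via GetOrCreateShadowFunction, and global |
| // variables either reuse an "enzyme_shadow"-marked global or receive a |
| // locally created zero-initialized twin. |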
| Constant *GradientUtils::GetOrCreateShadowConstant( |
| EnzymeLogic &Logic, TargetLibraryInfo &TLI, TypeAnalysis &TA, |
| Constant *oval, DerivativeMode mode, unsigned width, bool AtomicAdd) { |
| if (isa<ConstantPointerNull>(oval)) { |
| return oval; |
| } else if (isa<UndefValue>(oval)) { |
| return oval; |
| } else if (isa<ConstantInt>(oval)) { |
| return oval; |
| } else if (auto CD = dyn_cast<ConstantDataArray>(oval)) { |
| SmallVector<Constant *, 1> Vals; |
| for (size_t i = 0, len = CD->getNumElements(); i < len; i++) { |
| Vals.push_back(GetOrCreateShadowConstant( |
| Logic, TLI, TA, CD->getElementAsConstant(i), mode, width, AtomicAdd)); |
| } |
| return ConstantArray::get(CD->getType(), Vals); |
| } else if (auto CD = dyn_cast<ConstantArray>(oval)) { |
| SmallVector<Constant *, 1> Vals; |
| for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) { |
| Vals.push_back(GetOrCreateShadowConstant( |
| Logic, TLI, TA, CD->getOperand(i), mode, width, AtomicAdd)); |
| } |
| return ConstantArray::get(CD->getType(), Vals); |
| } else if (auto CD = dyn_cast<ConstantStruct>(oval)) { |
| SmallVector<Constant *, 1> Vals; |
| for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) { |
| Vals.push_back(GetOrCreateShadowConstant( |
| Logic, TLI, TA, CD->getOperand(i), mode, width, AtomicAdd)); |
| } |
| return ConstantStruct::get(CD->getType(), Vals); |
| } else if (auto CD = dyn_cast<ConstantVector>(oval)) { |
| SmallVector<Constant *, 1> Vals; |
| for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) { |
| Vals.push_back(GetOrCreateShadowConstant( |
| Logic, TLI, TA, CD->getOperand(i), mode, width, AtomicAdd)); |
| } |
| return ConstantVector::get(Vals); |
| } else if (auto F = dyn_cast<Function>(oval)) { |
| return GetOrCreateShadowFunction(Logic, TLI, TA, F, mode, width, AtomicAdd); |
| } else if (auto arg = dyn_cast<ConstantExpr>(oval)) { |
| auto C = GetOrCreateShadowConstant(Logic, TLI, TA, arg->getOperand(0), mode, |
| width, AtomicAdd); |
| if (arg->isCast() || arg->getOpcode() == Instruction::GetElementPtr || |
| arg->getOpcode() == Instruction::Add) { |
| SmallVector<Constant *, 8> NewOps; |
| for (unsigned i = 0, e = arg->getNumOperands(); i != e; ++i) |
| NewOps.push_back(i == 0 ? C : arg->getOperand(i)); |
| return arg->getWithOperands(NewOps); |
| } |
| } else if (auto arg = dyn_cast<GlobalVariable>(oval)) { |
| if (arg->getName() == "_ZTVN10__cxxabiv120__si_class_type_infoE" || |
| arg->getName() == "_ZTVN10__cxxabiv117__class_type_infoE" || |
| arg->getName() == "_ZTVN10__cxxabiv121__vmi_class_type_infoE") |
| return arg; |
| |
| if (hasMetadata(arg, "enzyme_shadow")) { |
| auto md = arg->getMetadata("enzyme_shadow"); |
| if (!isa<MDTuple>(md)) { |
| llvm::errs() << *arg << "\n"; |
| llvm::errs() << *md << "\n"; |
| assert(0 && "cannot compute with global variable that doesn't have " |
| "marked shadow global"); |
| report_fatal_error( |
| "cannot compute with global variable that doesn't " |
| "have marked shadow global (metadata incorrect type)"); |
| } |
| auto md2 = cast<MDTuple>(md); |
| assert(md2->getNumOperands() == 1); |
| auto gvemd = cast<ConstantAsMetadata>(md2->getOperand(0)); |
| return gvemd->getValue(); |
| } |
| |
| auto Arch = llvm::Triple(arg->getParent()->getTargetTriple()).getArch(); |
| int SharedAddrSpace = Arch == Triple::amdgcn |
| ? (int)AMDGPU::HSAMD::AddressSpaceQualifier::Local |
| : 3; |
| int AddrSpace = cast<PointerType>(arg->getType())->getAddressSpace(); |
| if ((Arch == Triple::nvptx || Arch == Triple::nvptx64 || |
| Arch == Triple::amdgcn) && |
| AddrSpace == SharedAddrSpace) { |
| assert(0 && "shared memory not handled in meta global"); |
| } |
| |
| // Create the shadow global variable locally if the original is not |
| // externally visible. |
| if (arg->isConstant() || arg->hasInternalLinkage() || |
| arg->hasPrivateLinkage() || |
| (arg->hasExternalLinkage() && arg->hasInitializer())) { |
| Type *type = arg->getType()->getPointerElementType(); |
| auto shadow = new GlobalVariable( |
| *arg->getParent(), type, arg->isConstant(), arg->getLinkage(), |
| Constant::getNullValue(type), arg->getName() + "_shadow", arg, |
| arg->getThreadLocalMode(), arg->getType()->getAddressSpace(), |
| arg->isExternallyInitialized()); |
| arg->setMetadata("enzyme_shadow", |
| MDTuple::get(shadow->getContext(), |
| {ConstantAsMetadata::get(shadow)})); |
| #if LLVM_VERSION_MAJOR >= 11 |
| shadow->setAlignment(arg->getAlign()); |
| #else |
| shadow->setAlignment(arg->getAlignment()); |
| #endif |
| shadow->setUnnamedAddr(arg->getUnnamedAddr()); |
| if (arg->hasInitializer()) |
| shadow->setInitializer(GetOrCreateShadowConstant( |
| Logic, TLI, TA, cast<Constant>(arg->getOperand(0)), mode, width, |
| AtomicAdd)); |
| return shadow; |
| } |
| } |
| llvm::errs() << " unknown constant to create shadow of: " << *oval << "\n"; |
| llvm_unreachable("unknown constant to create shadow of"); |
| } |
| |
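| // Build (or reuse) the "shadow" of a function pointer: the derivative |
| // function(s) are generated via EnzymeLogic and stashed in an internal |
| // module-level constant global (e.g. named "_enzyme_reverse_<fn>'" in |
| // reverse mode), whose address is returned cast to the original function |
| // pointer type. |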
| Constant *GradientUtils::GetOrCreateShadowFunction( |
| EnzymeLogic &Logic, TargetLibraryInfo &TLI, TypeAnalysis &TA, Function *fn, |
| DerivativeMode mode, unsigned width, bool AtomicAdd) { |
| //! TODO: allow tape propagation. |
| // Note that, strictly speaking, this should _not_ be called with |
| // topLevel=true (since it may not always be valid to assume we can |
| // recompute the augmented primal). However, in the absence of a way to |
| // pass tape data from an indirect augmented call (and since we don't |
| // presently allow indirect augmented calls), topLevel MUST be true; |
| // otherwise subcalls will not be able to look up the |
| // augmenteddata/subdata (triggering an assertion failure, among other, |
| // worse outcomes). |
| bool isRealloc = false; |
| if (fn->empty()) { |
| if (hasMetadata(fn, "enzyme_callwrapper")) { |
| auto md = fn->getMetadata("enzyme_callwrapper"); |
| if (!isa<MDTuple>(md)) { |
| llvm::errs() << *fn << "\n"; |
| llvm::errs() << *md << "\n"; |
| assert(0 && "callwrapper of incorrect type"); |
| report_fatal_error("callwrapper of incorrect type"); |
| } |
| auto md2 = cast<MDTuple>(md); |
| assert(md2->getNumOperands() == 1); |
| auto gvemd = cast<ConstantAsMetadata>(md2->getOperand(0)); |
| fn = cast<Function>(gvemd->getValue()); |
| } else { |
| auto oldfn = fn; |
| fn = Function::Create(oldfn->getFunctionType(), Function::InternalLinkage, |
| "callwrap_" + oldfn->getName(), oldfn->getParent()); |
| BasicBlock *entry = BasicBlock::Create(fn->getContext(), "entry", fn); |
| IRBuilder<> B(entry); |
| SmallVector<Value *, 4> args; |
| for (auto &a : fn->args()) |
| args.push_back(&a); |
| auto res = B.CreateCall(oldfn, args); |
| if (fn->getReturnType()->isVoidTy()) |
| B.CreateRetVoid(); |
| else |
| B.CreateRet(res); |
| oldfn->setMetadata( |
| "enzyme_callwrapper", |
| MDTuple::get(oldfn->getContext(), {ConstantAsMetadata::get(fn)})); |
| if (oldfn->getName() == "realloc") |
| isRealloc = true; |
| } |
| } |
| std::map<Argument *, bool> uncacheable_args; |
| FnTypeInfo type_args(fn); |
| if (isRealloc) { |
| llvm::errs() << "warning: assuming realloc only creates pointers\n"; |
| type_args.Return.insert({-1, -1}, BaseType::Pointer); |
| } |
| |
| // Conservatively assume that we can only cache existing floating-point |
| // values (i.e. that all non-floating-point args are uncacheable). |
| std::vector<DIFFE_TYPE> types; |
| for (auto &a : fn->args()) { |
| uncacheable_args[&a] = !a.getType()->isFPOrFPVectorTy(); |
| TypeTree TT; |
| if (a.getType()->isFPOrFPVectorTy()) |
| TT.insert({-1}, ConcreteType(a.getType()->getScalarType())); |
| type_args.Arguments.insert(std::pair<Argument *, TypeTree>(&a, TT)); |
| type_args.KnownValues.insert( |
| std::pair<Argument *, std::set<int64_t>>(&a, {})); |
| DIFFE_TYPE typ; |
| if (a.getType()->isFPOrFPVectorTy()) { |
| typ = mode == DerivativeMode::ForwardMode ? DIFFE_TYPE::DUP_ARG |
| : DIFFE_TYPE::OUT_DIFF; |
| } else if (a.getType()->isIntegerTy() && |
| cast<IntegerType>(a.getType())->getBitWidth() < 16) { |
| typ = DIFFE_TYPE::CONSTANT; |
| } else if (a.getType()->isVoidTy() || a.getType()->isEmptyTy()) { |
| typ = DIFFE_TYPE::CONSTANT; |
| } else { |
| typ = DIFFE_TYPE::DUP_ARG; |
| } |
| types.push_back(typ); |
| } |
| |
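| // Deduce the activity of the return: floating-point returns are OUT_DIFF |
| // in reverse modes (DUP_ARG in forward mode); void, empty, and small |
| // integer returns are CONSTANT; structs made entirely of floats are |
| // likewise treated as OUT_DIFF in reverse modes. |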
| DIFFE_TYPE retType = fn->getReturnType()->isFPOrFPVectorTy() && |
| mode != DerivativeMode::ForwardMode |
| ? DIFFE_TYPE::OUT_DIFF |
| : DIFFE_TYPE::DUP_ARG; |
| |
| if (fn->getReturnType()->isVoidTy() || fn->getReturnType()->isEmptyTy() || |
| (fn->getReturnType()->isIntegerTy() && |
| cast<IntegerType>(fn->getReturnType())->getBitWidth() < 16)) |
| retType = DIFFE_TYPE::CONSTANT; |
| |
| if (mode != DerivativeMode::ForwardMode && retType == DIFFE_TYPE::DUP_ARG) { |
| if (auto ST = dyn_cast<StructType>(fn->getReturnType())) { |
| size_t numflt = 0; |
| |
| for (unsigned i = 0; i < ST->getNumElements(); ++i) { |
| auto midTy = ST->getElementType(i); |
| if (midTy->isFPOrFPVectorTy()) |
| numflt++; |
| } |
| if (numflt == ST->getNumElements()) |
| retType = DIFFE_TYPE::OUT_DIFF; |
| } |
| } |
| |
| switch (mode) { |
| case DerivativeMode::ForwardMode: { |
| Constant *newf = Logic.CreateForwardDiff( |
| fn, retType, types, TA, false, mode, /*freeMemory*/ true, width, |
| nullptr, type_args, uncacheable_args, /*augmented*/ nullptr); |
| |
| assert(newf); |
| |
| std::string prefix = "_enzyme_forward"; |
| |
| if (width > 1) { |
| prefix += std::to_string(width); |
| } |
| |
| std::string globalname = (prefix + "_" + fn->getName() + "'").str(); |
| auto GV = fn->getParent()->getNamedValue(globalname); |
| |
| if (GV == nullptr) { |
| GV = new GlobalVariable(*fn->getParent(), newf->getType(), true, |
| GlobalValue::LinkageTypes::InternalLinkage, newf, |
| globalname); |
| } |
| |
| return ConstantExpr::getPointerCast(GV, fn->getType()); |
| } |
| case DerivativeMode::ForwardModeSplit: { |
| auto &augdata = Logic.CreateAugmentedPrimal( |
| fn, retType, /*constant_args*/ types, TA, |
| /*returnUsed*/ !fn->getReturnType()->isEmptyTy() && |
| !fn->getReturnType()->isVoidTy(), |
| /*shadowReturnUsed*/ false, type_args, uncacheable_args, |
| /*forceAnonymousTape*/ true, width, AtomicAdd); |
| Constant *newf = Logic.CreateForwardDiff( |
| fn, retType, types, TA, false, mode, /*freeMemory*/ true, width, |
| nullptr, type_args, uncacheable_args, /*augmented*/ &augdata); |
| |
| assert(newf); |
| |
| std::string prefix = "_enzyme_forwardsplit"; |
| |
| if (width > 1) { |
| prefix += std::to_string(width); |
| } |
| |
| auto cdata = ConstantStruct::get( |
| StructType::get(newf->getContext(), |
| {augdata.fn->getType(), newf->getType()}), |
| {augdata.fn, newf}); |
| |
| std::string globalname = (prefix + "_" + fn->getName() + "'").str(); |
| auto GV = fn->getParent()->getNamedValue(globalname); |
| |
| if (GV == nullptr) { |
| GV = new GlobalVariable(*fn->getParent(), cdata->getType(), true, |
| GlobalValue::LinkageTypes::InternalLinkage, cdata, |
| globalname); |
| } |
| |
| return ConstantExpr::getPointerCast(GV, fn->getType()); |
| } |
| case DerivativeMode::ReverseModeCombined: |
| case DerivativeMode::ReverseModeGradient: |
| case DerivativeMode::ReverseModePrimal: { |
| // TODO (re: atomic add): consider always forcing atomic updates as a |
| // fallback when used in a parallel context. |
| bool returnUsed = |
| !fn->getReturnType()->isEmptyTy() && !fn->getReturnType()->isVoidTy(); |
| bool shadowReturnUsed = returnUsed && (retType == DIFFE_TYPE::DUP_ARG || |
| retType == DIFFE_TYPE::DUP_NONEED); |
| auto &augdata = Logic.CreateAugmentedPrimal( |
| fn, retType, /*constant_args*/ types, TA, returnUsed, shadowReturnUsed, |
| type_args, uncacheable_args, /*forceAnonymousTape*/ true, width, |
| AtomicAdd); |
| Constant *newf = Logic.CreatePrimalAndGradient( |
| (ReverseCacheKey){.todiff = fn, |
| .retType = retType, |
| .constant_args = types, |
| .uncacheable_args = uncacheable_args, |
| .returnUsed = false, |
| .shadowReturnUsed = false, |
| .mode = DerivativeMode::ReverseModeGradient, |
| .width = width, |
| .freeMemory = true, |
| .AtomicAdd = AtomicAdd, |
| .additionalType = |
| Type::getInt8PtrTy(fn->getContext()), |
| .typeInfo = type_args}, |
| TA, |
| /*map*/ &augdata); |
| assert(newf); |
| auto cdata = ConstantStruct::get( |
| StructType::get(newf->getContext(), |
| {augdata.fn->getType(), newf->getType()}), |
| {augdata.fn, newf}); |
| std::string globalname = ("_enzyme_reverse_" + fn->getName() + "'").str(); |
| auto GV = fn->getParent()->getNamedValue(globalname); |
| |
| if (GV == nullptr) { |
| GV = new GlobalVariable(*fn->getParent(), cdata->getType(), true, |
| GlobalValue::LinkageTypes::InternalLinkage, cdata, |
| globalname); |
| } |
| return ConstantExpr::getPointerCast(GV, fn->getType()); |
| } |
| } |
| } |
| |
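| // Return the shadow ("inverted pointer") of an original value, creating |
| // it on demand and memoizing the result in invertedPointers. Constants |
| // and inactive values generally shadow to themselves, or to a null value |
| // of the same type when nullShadow requests zeroed float data. |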
| Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, |
| bool nullShadow) { |
| assert(oval); |
| if (auto inst = dyn_cast<Instruction>(oval)) { |
| assert(inst->getParent()->getParent() == oldFunc); |
| } |
| if (auto arg = dyn_cast<Argument>(oval)) { |
| assert(arg->getParent() == oldFunc); |
| } |
| |
| if (isa<ConstantPointerNull>(oval)) { |
| return applyChainRule(oval->getType(), BuilderM, [&]() { return oval; }); |
| } else if (isa<UndefValue>(oval)) { |
| if (nullShadow) |
| return Constant::getNullValue(getShadowType(oval->getType())); |
| return applyChainRule(oval->getType(), BuilderM, [&]() { return oval; }); |
| } else if (isa<ConstantInt>(oval)) { |
| if (nullShadow) |
| return Constant::getNullValue(getShadowType(oval->getType())); |
| return applyChainRule(oval->getType(), BuilderM, [&]() { return oval; }); |
| } else if (auto CD = dyn_cast<ConstantDataArray>(oval)) { |
| SmallVector<Constant *, 1> Vals; |
| for (size_t i = 0, len = CD->getNumElements(); i < len; i++) { |
| Value *val = invertPointerM(CD->getElementAsConstant(i), BuilderM); |
| Vals.push_back(cast<Constant>(val)); |
| } |
| auto rule = [&CD](ArrayRef<Constant *> Vals) { |
| return ConstantArray::get(CD->getType(), Vals); |
| }; |
| return applyChainRule(CD->getType(), Vals, BuilderM, rule); |
| } else if (auto CD = dyn_cast<ConstantArray>(oval)) { |
| SmallVector<Constant *, 1> Vals; |
| for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) { |
| Value *val = invertPointerM(CD->getOperand(i), BuilderM); |
| Vals.push_back(cast<Constant>(val)); |
| } |
| |
| auto rule = [&CD](ArrayRef<Constant *> Vals) { |
| return ConstantArray::get(CD->getType(), Vals); |
| }; |
| |
| return applyChainRule(CD->getType(), Vals, BuilderM, rule); |
| } else if (auto CD = dyn_cast<ConstantStruct>(oval)) { |
| SmallVector<Constant *, 1> Vals; |
| for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) { |
| Vals.push_back( |
| cast<Constant>(invertPointerM(CD->getOperand(i), BuilderM))); |
| } |
| |
| auto rule = [&CD](ArrayRef<Constant *> Vals) { |
| return ConstantStruct::get(CD->getType(), Vals); |
| }; |
| return applyChainRule(CD->getType(), Vals, BuilderM, rule); |
| } else if (auto CD = dyn_cast<ConstantVector>(oval)) { |
| SmallVector<Constant *, 1> Vals; |
| for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) { |
| Vals.push_back( |
| cast<Constant>(invertPointerM(CD->getOperand(i), BuilderM))); |
| } |
| |
| auto rule = [](ArrayRef<Constant *> Vals) { |
| return ConstantVector::get(Vals); |
| }; |
| |
| return applyChainRule(CD->getType(), Vals, BuilderM, rule); |
| } else if (isa<ConstantData>(oval) && nullShadow) { |
| auto rule = [&oval]() { return Constant::getNullValue(oval->getType()); }; |
| |
| return applyChainRule(oval->getType(), BuilderM, rule); |
| } |
| |
| if (isConstantValue(oval) && !isa<InsertValueInst>(oval) && |
| !isa<ExtractValueInst>(oval)) { |
| // NOTE: this is legal and the correct resolution; however, our |
| // activity-analysis honeypot no longer exists. |
| |
| // Nulling the shadow of a constant is only necessary if any of its data |
| // could contain a float (i.e. it should not be applied to pointers). |
| if (nullShadow) { |
| auto CT = TR.query(oval)[{-1}]; |
| if (!CT.isKnown() || CT.isFloat()) { |
| return Constant::getNullValue(getShadowType(oval->getType())); |
| } |
| } |
| |
| if (isa<ConstantExpr>(oval)) { |
| auto rule = [&oval]() { return oval; }; |
| return applyChainRule(oval->getType(), BuilderM, rule); |
| } |
| |
| Value *newval = getNewFromOriginal(oval); |
| |
| auto rule = [&]() { return newval; }; |
| |
| return applyChainRule(oval->getType(), BuilderM, rule); |
| } |
| |
| auto M = oldFunc->getParent(); |
| assert(oval); |
| |
| { |
| auto ifound = invertedPointers.find(oval); |
| if (ifound != invertedPointers.end()) { |
| return &*ifound->second; |
| } |
| } |
| |
| if (isa<Argument>(oval) && cast<Argument>(oval)->hasByValAttr()) { |
| IRBuilder<> bb(inversionAllocs); |
| |
| auto rule1 = [&]() { |
| AllocaInst *antialloca = bb.CreateAlloca( |
| oval->getType()->getPointerElementType(), |
| cast<PointerType>(oval->getType())->getPointerAddressSpace(), nullptr, |
| oval->getName() + "'ipa"); |
| |
| auto dst_arg = |
| bb.CreateBitCast(antialloca, Type::getInt8PtrTy(oval->getContext())); |
| auto val_arg = ConstantInt::get(Type::getInt8Ty(oval->getContext()), 0); |
| auto len_arg = |
| ConstantInt::get(Type::getInt64Ty(oval->getContext()), |
| M->getDataLayout().getTypeAllocSizeInBits( |
| oval->getType()->getPointerElementType()) / |
| 8); |
| auto volatile_arg = ConstantInt::getFalse(oval->getContext()); |
| |
| #if LLVM_VERSION_MAJOR == 6 |
| auto align_arg = ConstantInt::get(Type::getInt32Ty(oval->getContext()), |
| antialloca->getAlignment()); |
| Value *args[] = {dst_arg, val_arg, len_arg, align_arg, volatile_arg}; |
| #else |
| Value *args[] = {dst_arg, val_arg, len_arg, volatile_arg}; |
| #endif |
| Type *tys[] = {dst_arg->getType(), len_arg->getType()}; |
| cast<CallInst>(bb.CreateCall( |
| Intrinsic::getDeclaration(M, Intrinsic::memset, tys), args)); |
| |
| return antialloca; |
| }; |
| |
| Value *antialloca = applyChainRule(oval->getType(), bb, rule1); |
| |
| invertedPointers.insert(std::make_pair( |
| (const Value *)oval, InvertedPointerVH(this, antialloca))); |
| |
| return antialloca; |
| } else if (auto arg = dyn_cast<GlobalVariable>(oval)) { |
| if (!hasMetadata(arg, "enzyme_shadow")) { |
| |
| if ((mode == DerivativeMode::ReverseModeCombined || |
| mode == DerivativeMode::ForwardMode) && |
| arg->getType()->getPointerAddressSpace() == 0) { |
| auto CT = TR.query(arg)[{-1, -1}]; |
| // We can only locally replace a global variable if it is known not to |
| // contain a pointer, since such a pointer may be initialized outside of |
| // this function to point to other memory that will not have a shadow |
| // within the current function. |
| if (CT.isKnown() && CT != BaseType::Pointer) { |
| bool seen = false; |
| MemoryLocation |
| #if LLVM_VERSION_MAJOR >= 12 |
| Loc = MemoryLocation(oval, LocationSize::beforeOrAfterPointer()); |
| #elif LLVM_VERSION_MAJOR >= 9 |
| Loc = MemoryLocation(oval, LocationSize::unknown()); |
| #else |
| Loc = MemoryLocation(oval, MemoryLocation::UnknownSize); |
| #endif |
| for (CallInst *CI : originalCalls) { |
| if (isa<IntrinsicInst>(CI)) |
| continue; |
| if (!isConstantInstruction(CI)) { |
| Function *F = getFunctionFromCall(CI); |
| if (F && (isMemFreeLibMFunction(F->getName()) || |
| F->getName() == "__fd_sincos_1")) { |
| continue; |
| } |
| if (llvm::isModOrRefSet(OrigAA.getModRefInfo(CI, Loc))) { |
| seen = true; |
| llvm::errs() << " cannot shadow-inline global " << *oval |
| << " due to " << *CI << "\n"; |
| goto endCheck; |
| } |
| } |
| } |
| endCheck:; |
| if (!seen) { |
| IRBuilder<> bb(inversionAllocs); |
| Type *allocaTy = arg->getValueType(); |
| |
| auto rule1 = [&]() { |
| AllocaInst *antialloca = bb.CreateAlloca( |
| allocaTy, arg->getType()->getPointerAddressSpace(), nullptr, |
| arg->getName() + "'ipa"); |
| if (arg->getAlignment()) { |
| #if LLVM_VERSION_MAJOR >= 10 |
| antialloca->setAlignment(Align(arg->getAlignment())); |
| #else |
| antialloca->setAlignment(arg->getAlignment()); |
| #endif |
| } |
| return antialloca; |
| }; |
| |
| Value *antialloca = applyChainRule(arg->getType(), bb, rule1); |
| |
| invertedPointers.insert(std::make_pair( |
| (const Value *)oval, InvertedPointerVH(this, antialloca))); |
| |
| auto rule2 = [&](Value *antialloca) { |
| auto dst_arg = bb.CreateBitCast( |
| antialloca, Type::getInt8PtrTy(arg->getContext())); |
| auto val_arg = |
| ConstantInt::get(Type::getInt8Ty(arg->getContext()), 0); |
| auto len_arg = |
| ConstantInt::get(Type::getInt64Ty(arg->getContext()), |
| M->getDataLayout().getTypeAllocSizeInBits( |
| arg->getValueType()) / |
| 8); |
| auto volatile_arg = ConstantInt::getFalse(oval->getContext()); |
| |
| #if LLVM_VERSION_MAJOR == 6 |
| auto align_arg = |
| ConstantInt::get(Type::getInt32Ty(oval->getContext()), |
| antialloca->getAlignment()); |
| Value *args[] = {dst_arg, val_arg, len_arg, align_arg, |
| volatile_arg}; |
| #else |
| Value *args[] = {dst_arg, val_arg, len_arg, volatile_arg}; |
| #endif |
| Type *tys[] = {dst_arg->getType(), len_arg->getType()}; |
| auto memset = cast<CallInst>(bb.CreateCall( |
| Intrinsic::getDeclaration(M, Intrinsic::memset, tys), args)); |
| #if LLVM_VERSION_MAJOR >= 10 |
| if (arg->getAlignment()) { |
| memset->addParamAttr( |
| 0, Attribute::getWithAlignment(arg->getContext(), |
| Align(arg->getAlignment()))); |
| } |
| #else |
| if (arg->getAlignment() != 0) { |
| memset->addParamAttr( |
| 0, Attribute::getWithAlignment(arg->getContext(), |
| arg->getAlignment())); |
| } |
| #endif |
| memset->addParamAttr(0, Attribute::NonNull); |
| assert((width > 1 && antialloca->getType() == |
| ArrayType::get(arg->getType(), width)) || |
| antialloca->getType() == arg->getType()); |
| return antialloca; |
| }; |
| |
| return applyChainRule(arg->getType(), bb, rule2, antialloca); |
| } |
| } |
| } |
| |
| auto Arch = |
| llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch(); |
| int SharedAddrSpace = |
| Arch == Triple::amdgcn |
| ? (int)AMDGPU::HSAMD::AddressSpaceQualifier::Local |
| : 3; |
| int AddrSpace = cast<PointerType>(arg->getType())->getAddressSpace(); |
| if ((Arch == Triple::nvptx || Arch == Triple::nvptx64 || |
| Arch == Triple::amdgcn) && |
| AddrSpace == SharedAddrSpace) { |
| llvm::errs() << "warning found shared memory\n"; |
| Type *type = arg->getType()->getPointerElementType(); |
| // TODO this needs initialization by entry |
| auto shadow = new GlobalVariable( |
| *arg->getParent(), type, arg->isConstant(), arg->getLinkage(), |
| UndefValue::get(type), arg->getName() + "_shadow", arg, |
| arg->getThreadLocalMode(), arg->getType()->getAddressSpace(), |
| arg->isExternallyInitialized()); |
| arg->setMetadata("enzyme_shadow", |
| MDTuple::get(shadow->getContext(), |
| {ConstantAsMetadata::get(shadow)})); |
| shadow->setMetadata("enzyme_internalshadowglobal", |
| MDTuple::get(shadow->getContext(), {})); |
| #if LLVM_VERSION_MAJOR >= 11 |
| shadow->setAlignment(arg->getAlign()); |
| #else |
| shadow->setAlignment(arg->getAlignment()); |
| #endif |
| shadow->setUnnamedAddr(arg->getUnnamedAddr()); |
| invertedPointers.insert(std::make_pair( |
| (const Value *)oval, InvertedPointerVH(this, shadow))); |
| return shadow; |
| } |
| |
| // Create the shadow global variable locally if the original is not |
| // externally visible. |
| // If a variable is constant, in forward mode it will also only be read, |
| // so inverting its initializer is fine. |
| // In reverse mode, any floats will be +='d into but never read, and any |
| // pointers will be used as expected. Since the floats are never read, |
| // having two separate globals for them is fine. |
| // As long as the pointers point to equivalent places (which they should, |
| // coming from the same initialization), it is also ok. |
| if (arg->hasInternalLinkage() || arg->hasPrivateLinkage() || |
| (arg->hasExternalLinkage() && arg->hasInitializer()) || |
| arg->isConstant()) { |
| Type *elemTy = arg->getType()->getPointerElementType(); |
| IRBuilder<> B(inversionAllocs); |
| |
| auto rule = [&]() { |
| auto shadow = new GlobalVariable( |
| *arg->getParent(), elemTy, arg->isConstant(), arg->getLinkage(), |
| Constant::getNullValue(elemTy), arg->getName() + "_shadow", arg, |
| arg->getThreadLocalMode(), arg->getType()->getAddressSpace(), |
| arg->isExternallyInitialized()); |
| arg->setMetadata("enzyme_shadow", |
| MDTuple::get(shadow->getContext(), |
| {ConstantAsMetadata::get(shadow)})); |
| #if LLVM_VERSION_MAJOR >= 11 |
| shadow->setAlignment(arg->getAlign()); |
| #else |
| shadow->setAlignment(arg->getAlignment()); |
| #endif |
| shadow->setUnnamedAddr(arg->getUnnamedAddr()); |
| |
| return shadow; |
| }; |
| |
| Value *shadow = applyChainRule(oval->getType(), BuilderM, rule); |
| |
| if (arg->hasInitializer()) { |
| applyChainRule( |
| BuilderM, |
| [&](Value *shadow, Value *ip) { |
| cast<GlobalVariable>(shadow)->setInitializer( |
| cast<Constant>(ip)); |
| }, |
| shadow, |
| invertPointerM(arg->getInitializer(), B, /*nullShadow*/ true)); |
| } |
| |
| invertedPointers.insert(std::make_pair( |
| (const Value *)oval, InvertedPointerVH(this, shadow))); |
| return shadow; |
| } |
| |
| llvm::errs() << *oldFunc->getParent() << "\n"; |
| llvm::errs() << *oldFunc << "\n"; |
| llvm::errs() << *newFunc << "\n"; |
| llvm::errs() << *arg << "\n"; |
| assert(0 && "cannot compute with global variable that doesn't have " |
| "marked shadow global"); |
| report_fatal_error("cannot compute with global variable that doesn't " |
| "have marked shadow global"); |
| } |
| auto md = arg->getMetadata("enzyme_shadow"); |
| if (!isa<MDTuple>(md)) { |
| llvm::errs() << *arg << "\n"; |
| llvm::errs() << *md << "\n"; |
| assert(0 && "cannot compute with global variable that doesn't have " |
| "marked shadow global"); |
| report_fatal_error("cannot compute with global variable that doesn't " |
| "have marked shadow global (metadata incorrect type)"); |
| } |
| auto md2 = cast<MDTuple>(md); |
| assert(md2->getNumOperands() == 1); |
| auto gvemd = cast<ConstantAsMetadata>(md2->getOperand(0)); |
| auto cs = cast<Constant>(gvemd->getValue()); |
| |
| if (width > 1) { |
| SmallVector<Constant *, 2> Vals; |
| for (unsigned i = 0; i < width; ++i) { |
| |
| Constant *idxs[] = { |
| ConstantInt::get(Type::getInt32Ty(cs->getContext()), 0), |
| ConstantInt::get(Type::getInt32Ty(cs->getContext()), i)}; |
| Constant *elem = ConstantExpr::getInBoundsGetElementPtr( |
| cs->getType()->getPointerElementType(), cs, idxs); |
| Vals.push_back(elem); |
| } |
| |
| auto agg = ConstantArray::get( |
| cast<ArrayType>(getShadowType(arg->getType())), Vals); |
| |
| invertedPointers.insert( |
| std::make_pair((const Value *)oval, InvertedPointerVH(this, agg))); |
| return agg; |
| } else { |
| invertedPointers.insert( |
| std::make_pair((const Value *)oval, InvertedPointerVH(this, cs))); |
| return cs; |
| } |
| } else if (auto fn = dyn_cast<Function>(oval)) { |
| Constant *shadow = |
| GetOrCreateShadowFunction(Logic, TLI, TA, fn, mode, width, AtomicAdd); |
| if (width > 1) { |
| SmallVector<Constant *, 3> arr; |
| for (unsigned i = 0; i < width; ++i) { |
| arr.push_back(shadow); |
| } |
| ArrayType *arrTy = ArrayType::get(shadow->getType(), width); |
| shadow = ConstantArray::get(arrTy, arr); |
| } |
| return shadow; |
| } else if (auto arg = dyn_cast<CastInst>(oval)) { |
| IRBuilder<> bb(getNewFromOriginal(arg)); |
| Value *invertOp = invertPointerM(arg->getOperand(0), bb); |
| Type *shadowTy = arg->getDestTy(); |
| |
| auto rule = [&](Value *invertOp) { |
| return bb.CreateCast(arg->getOpcode(), invertOp, shadowTy, |
| arg->getName() + "'ipc"); |
| }; |
| |
| Value *shadow = applyChainRule(shadowTy, bb, rule, invertOp); |
| |
| invertedPointers.insert( |
| std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow))); |
| return shadow; |
| } else if (auto arg = dyn_cast<ConstantExpr>(oval)) { |
| IRBuilder<> bb(inversionAllocs); |
| auto ip = invertPointerM(arg->getOperand(0), bb); |
| |
| if (arg->isCast()) { |
| if (auto PT = dyn_cast<PointerType>(arg->getType())) { |
| if (isConstantValue(arg->getOperand(0)) && |
| PT->getPointerElementType()->isFunctionTy()) { |
| goto end; |
| } |
| } |
| if (isa<Constant>(ip)) { |
| auto rule = [&arg](Value *ip) { |
| return ConstantExpr::getCast(arg->getOpcode(), cast<Constant>(ip), |
| arg->getType()); |
| }; |
| |
| return applyChainRule(arg->getType(), bb, rule, ip); |
| |
| } else { |
| auto rule = [&](Value *ip) { |
| return bb.CreateCast((Instruction::CastOps)arg->getOpcode(), ip, |
| arg->getType(), arg->getName() + "'ipc"); |
| }; |
| |
| Value *shadow = applyChainRule(arg->getType(), bb, rule, ip); |
| |
| invertedPointers.insert(std::make_pair( |
| (const Value *)oval, InvertedPointerVH(this, shadow))); |
| |
| return shadow; |
| } |
| } else if (arg->getOpcode() == Instruction::GetElementPtr) { |
| if (auto C = dyn_cast<Constant>(ip)) { |
| auto rule = [&arg, &C]() { |
| SmallVector<Constant *, 8> NewOps; |
| for (unsigned i = 0, e = arg->getNumOperands(); i != e; ++i) |
| NewOps.push_back(i == 0 ? C : arg->getOperand(i)); |
| return cast<Value>(arg->getWithOperands(NewOps)); |
| }; |
| |
| return applyChainRule(arg->getType(), bb, rule); |
| } else { |
| SmallVector<Value *, 4> invertargs; |
| for (unsigned i = 0; i < arg->getNumOperands() - 1; ++i) { |
| Value *b = getNewFromOriginal(arg->getOperand(1 + i)); |
| invertargs.push_back(b); |
| } |
| |
| auto rule = [&bb, &arg, &invertargs](Value *ip) { |
| // TODO: mark this with the same inbounds flag as the original. |
| #if LLVM_VERSION_MAJOR > 7 |
| return bb.CreateGEP(ip->getType()->getPointerElementType(), ip, |
| invertargs, arg->getName() + "'ipg"); |
| #else |
| return bb.CreateGEP(ip, invertargs, arg->getName() + "'ipg"); |
| #endif |
| }; |
| |
| Value *shadow = applyChainRule(arg->getType(), bb, rule, ip); |
| |
| invertedPointers.insert(std::make_pair( |
| (const Value *)oval, InvertedPointerVH(this, shadow))); |
| return shadow; |
| } |
| } else { |
| llvm::errs() << *arg << "\n"; |
| assert(0 && "unhandled"); |
| } |
| goto end; |
| } else if (auto arg = dyn_cast<ExtractValueInst>(oval)) { |
| IRBuilder<> bb(getNewFromOriginal(arg)); |
| auto ip = invertPointerM(arg->getOperand(0), bb, nullShadow); |
| |
| auto rule = [&bb, &arg](Value *ip) { |
| return bb.CreateExtractValue(ip, arg->getIndices(), |
| arg->getName() + "'ipev"); |
| }; |
| |
| Value *shadow = applyChainRule(arg->getType(), bb, rule, ip); |
| |
| invertedPointers.insert( |
| std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow))); |
| return shadow; |
| } else if (auto arg = dyn_cast<InsertValueInst>(oval)) { |
| IRBuilder<> bb(getNewFromOriginal(arg)); |
| auto ip0 = invertPointerM(arg->getOperand(0), bb, nullShadow); |
| auto ip1 = invertPointerM(arg->getOperand(1), bb, nullShadow); |
| |
| auto rule = [&bb, &arg](Value *ip0, Value *ip1) { |
| return bb.CreateInsertValue(ip0, ip1, arg->getIndices(), |
| arg->getName() + "'ipiv"); |
| }; |
| |
| Value *shadow = applyChainRule(arg->getType(), bb, rule, ip0, ip1); |
| |
| invertedPointers.insert( |
| std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow))); |
| return shadow; |
| } else if (auto arg = dyn_cast<ExtractElementInst>(oval)) { |
| IRBuilder<> bb(getNewFromOriginal(arg)); |
| auto ip = invertPointerM(arg->getVectorOperand(), bb); |
| |
| auto rule = [&](Value *ip) { |
| return bb.CreateExtractElement(ip, |
| getNewFromOriginal(arg->getIndexOperand()), |
| arg->getName() + "'ipee"); |
| }; |
| |
| Value *shadow = applyChainRule(arg->getType(), bb, rule, ip); |
| |
| invertedPointers.insert( |
| std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow))); |
| return shadow; |
| } else if (auto arg = dyn_cast<InsertElementInst>(oval)) { |
| IRBuilder<> bb(getNewFromOriginal(arg)); |
| Value *op0 = arg->getOperand(0); |
| Value *op1 = arg->getOperand(1); |
| Value *op2 = arg->getOperand(2); |
| auto ip0 = invertPointerM(op0, bb, nullShadow); |
| auto ip1 = invertPointerM(op1, bb, nullShadow); |
| |
| auto rule = [&](Value *ip0, Value *ip1) { |
| return bb.CreateInsertElement(ip0, ip1, getNewFromOriginal(op2), |
| arg->getName() + "'ipie"); |
| }; |
| |
| Value *shadow = applyChainRule(arg->getType(), bb, rule, ip0, ip1); |
| |
| invertedPointers.insert( |
| std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow))); |
| return shadow; |
| } else if (auto arg = dyn_cast<ShuffleVectorInst>(oval)) { |
| IRBuilder<> bb(getNewFromOriginal(arg)); |
| Value *op0 = arg->getOperand(0); |
| Value *op1 = arg->getOperand(1); |
| auto ip0 = invertPointerM(op0, bb); |
| auto ip1 = invertPointerM(op1, bb); |
| |
| auto rule = [&bb, &arg](Value *ip0, Value *ip1) { |
| #if LLVM_VERSION_MAJOR >= 11 |
| return bb.CreateShuffleVector(ip0, ip1, arg->getShuffleMaskForBitcode(), |
| arg->getName() + "'ipsv"); |
| #else |
| return bb.CreateShuffleVector(ip0, ip1, arg->getOperand(2), |
| arg->getName() + "'ipsv"); |
| #endif |
| }; |
| |
| Value *shadow = applyChainRule(arg->getType(), bb, rule, ip0, ip1); |
| |
| invertedPointers.insert( |
| std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow))); |
| return shadow; |
| } else if (auto arg = dyn_cast<SelectInst>(oval)) { |
| IRBuilder<> bb(getNewFromOriginal(arg)); |
| bb.setFastMathFlags(getFast()); |
| Value *shadow = applyChainRule( |
| arg->getType(), bb, |
| [&](Value *tv, Value *fv) { |
| return bb.CreateSelect(getNewFromOriginal(arg->getCondition()), tv, |
| fv, arg->getName() + "'ipse"); |
| }, |
| invertPointerM(arg->getTrueValue(), bb, nullShadow), |
| invertPointerM(arg->getFalseValue(), bb, nullShadow)); |
| invertedPointers.insert( |
| std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow))); |
| return shadow; |
| } else if (auto arg = dyn_cast<LoadInst>(oval)) { |
| IRBuilder<> bb(getNewFromOriginal(arg)); |
| Value *op0 = arg->getOperand(0); |
| Value *ip = invertPointerM(op0, bb); |
| |
| auto rule = [&](Value *ip) { |
| #if LLVM_VERSION_MAJOR > 7 |
| auto li = bb.CreateLoad(arg->getType(), ip, arg->getName() + "'ipl"); |
| #else |
| auto li = bb.CreateLoad(ip, arg->getName() + "'ipl"); |
| #endif |
| llvm::SmallVector<unsigned int, 9> ToCopy2(MD_ToCopy); |
| ToCopy2.push_back(LLVMContext::MD_noalias); |
| li->copyMetadata(*arg, ToCopy2); |
| li->copyIRFlags(arg); |
| #if LLVM_VERSION_MAJOR >= 10 |
| li->setAlignment(arg->getAlign()); |
| #else |
| li->setAlignment(arg->getAlignment()); |
| #endif |
| li->setDebugLoc(getNewFromOriginal(arg->getDebugLoc())); |
| li->setVolatile(arg->isVolatile()); |
| li->setOrdering(arg->getOrdering()); |
| li->setSyncScopeID(arg->getSyncScopeID()); |
| return li; |
| }; |
| |
| Value *li = applyChainRule(arg->getType(), bb, rule, ip); |
| |
| invertedPointers.insert( |
| std::make_pair((const Value *)oval, InvertedPointerVH(this, li))); |
| return li; |
| |
| } else if (auto arg = dyn_cast<BinaryOperator>(oval)) { |
| if (arg->getOpcode() == Instruction::FAdd) |
| return getNewFromOriginal(arg); |
| |
| if (!arg->getType()->isIntOrIntVectorTy()) { |
| llvm::errs() << *oval << "\n"; |
| } |
| assert(arg->getType()->isIntOrIntVectorTy()); |
| IRBuilder<> bb(getNewFromOriginal(arg)); |
| Value *val0 = nullptr; |
| Value *val1 = nullptr; |
| |
| val0 = invertPointerM(arg->getOperand(0), bb); |
| val1 = invertPointerM(arg->getOperand(1), bb); |
| assert(val0->getType() == val1->getType()); |
| |
| auto rule = [&bb, &arg](Value *val0, Value *val1) { |
| auto li = bb.CreateBinOp(arg->getOpcode(), val0, val1, arg->getName()); |
| if (auto BI = dyn_cast<BinaryOperator>(li)) |
| BI->copyIRFlags(arg); |
| return li; |
| }; |
| |
| Value *li = applyChainRule(arg->getType(), bb, rule, val0, val1); |
| |
| invertedPointers.insert( |
| std::make_pair((const Value *)oval, InvertedPointerVH(this, li))); |
| return li; |
| } else if (auto arg = dyn_cast<GetElementPtrInst>(oval)) { |
| IRBuilder<> bb(getNewFromOriginal(arg)); |
| SmallVector<Value *, 4> invertargs; |
| for (unsigned i = 0; i < arg->getNumIndices(); ++i) { |
| Value *b = getNewFromOriginal(arg->getOperand(1 + i)); |
| invertargs.push_back(b); |
| } |
| Value *ip = invertPointerM(arg->getPointerOperand(), bb); |
| |
| auto rule = [&](Value *ip) { |
| #if LLVM_VERSION_MAJOR > 7 |
| auto shadow = bb.CreateGEP(arg->getSourceElementType(), ip, invertargs, |
| arg->getName() + "'ipg"); |
| #else |
| auto shadow = bb.CreateGEP(ip, invertargs, arg->getName() + "'ipg"); |
| #endif |
| |
| if (auto gep = dyn_cast<GetElementPtrInst>(shadow)) |
| gep->setIsInBounds(arg->isInBounds()); |
| |
| return shadow; |
| }; |
| |
| Value *shadow = applyChainRule(arg->getType(), bb, rule, ip); |
| |
| invertedPointers.insert( |
| std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow))); |
| return shadow; |
| } else if (auto inst = dyn_cast<AllocaInst>(oval)) { |
| IRBuilder<> bb(getNewFromOriginal(inst)); |
| Value *asize = getNewFromOriginal(inst->getArraySize()); |
| |
| auto rule1 = [&]() { |
| AllocaInst *antialloca = bb.CreateAlloca( |
| inst->getAllocatedType(), inst->getType()->getPointerAddressSpace(), |
| asize, inst->getName() + "'ipa"); |
| #if LLVM_VERSION_MAJOR >= 11 |
| antialloca->setAlignment(inst->getAlign()); |
| #elif LLVM_VERSION_MAJOR == 10 |
| if (inst->getAlignment()) { |
| antialloca->setAlignment(Align(inst->getAlignment())); |
| } |
| #else |
| if (inst->getAlignment()) { |
| antialloca->setAlignment(inst->getAlignment()); |
| } |
| #endif |
| return antialloca; |
| }; |
| |
| Value *antialloca = applyChainRule(oval->getType(), bb, rule1); |
| |
| invertedPointers.insert(std::make_pair( |
| (const Value *)oval, InvertedPointerVH(this, antialloca))); |
| |
| if (auto ci = dyn_cast<ConstantInt>(asize)) { |
| if (ci->isOne()) { |
| |
| auto rule = [&](Value *antialloca) { |
| StoreInst *st = bb.CreateStore( |
| Constant::getNullValue(inst->getAllocatedType()), antialloca); |
| #if LLVM_VERSION_MAJOR >= 11 |
| cast<StoreInst>(st)->setAlignment(inst->getAlign()); |
| #elif LLVM_VERSION_MAJOR == 10 |
| if (inst->getAlignment()) { |
| cast<StoreInst>(st)->setAlignment(Align(inst->getAlignment())); |
| } |
| #else |
| if (inst->getAlignment()) { |
| cast<StoreInst>(st)->setAlignment(inst->getAlignment()); |
| } |
| #endif |
| }; |
| |
| applyChainRule(bb, rule, antialloca); |
| |
| return antialloca; |
| } else { |
| // TODO handle alloca of size > 1 |
| } |
| } |
| |
| auto rule2 = [&](Value *antialloca) { |
| auto dst_arg = |
| bb.CreateBitCast(antialloca, Type::getInt8PtrTy(oval->getContext())); |
| auto val_arg = ConstantInt::get(Type::getInt8Ty(oval->getContext()), 0); |
| auto len_arg = bb.CreateMul( |
| bb.CreateZExtOrTrunc(asize, Type::getInt64Ty(oval->getContext())), |
| ConstantInt::get(Type::getInt64Ty(oval->getContext()), |
| M->getDataLayout().getTypeAllocSizeInBits( |
| inst->getAllocatedType()) / |
| 8), |
| "", true, true); |
| auto volatile_arg = ConstantInt::getFalse(oval->getContext()); |
| |
| #if LLVM_VERSION_MAJOR == 6 |
| auto align_arg = ConstantInt::get(Type::getInt32Ty(oval->getContext()), |
| antialloca->getAlignment()); |
| Value *args[] = {dst_arg, val_arg, len_arg, align_arg, volatile_arg}; |
| #else |
| Value *args[] = {dst_arg, val_arg, len_arg, volatile_arg}; |
| #endif |
| Type *tys[] = {dst_arg->getType(), len_arg->getType()}; |
| auto memset = cast<CallInst>(bb.CreateCall( |
| Intrinsic::getDeclaration(M, Intrinsic::memset, tys), args)); |
| #if LLVM_VERSION_MAJOR >= 11 |
| memset->addParamAttr( |
| 0, Attribute::getWithAlignment(inst->getContext(), inst->getAlign())); |
| #elif LLVM_VERSION_MAJOR == 10 |
| if (inst->getAlignment() != 0) { |
| memset->addParamAttr( |
| 0, Attribute::getWithAlignment(inst->getContext(), |
| Align(inst->getAlignment()))); |
| } |
| #else |
| if (inst->getAlignment() != 0) { |
| memset->addParamAttr(0, Attribute::getWithAlignment( |
| inst->getContext(), inst->getAlignment())); |
| } |
| #endif |
| memset->addParamAttr(0, Attribute::NonNull); |
| }; |
| |
| applyChainRule(bb, rule2, antialloca); |
| |
| return antialloca; |
| } else if (auto II = dyn_cast<IntrinsicInst>(oval)) { |
| IRBuilder<> bb(getNewFromOriginal(II)); |
| bb.setFastMathFlags(getFast()); |
| switch (II->getIntrinsicID()) { |
| default: |
| goto end; |
| case Intrinsic::nvvm_ldu_global_i: |
| case Intrinsic::nvvm_ldu_global_p: |
| case Intrinsic::nvvm_ldu_global_f: |
| case Intrinsic::nvvm_ldg_global_i: |
| case Intrinsic::nvvm_ldg_global_p: |
| case Intrinsic::nvvm_ldg_global_f: { |
| return applyChainRule( |
| II->getType(), bb, |
| [&](Value *ptr) { |
| Value *args[] = {ptr}; |
| auto li = bb.CreateCall(II->getCalledFunction(), args); |
| llvm::SmallVector<unsigned int, 9> ToCopy2(MD_ToCopy); |
| ToCopy2.push_back(LLVMContext::MD_noalias); |
| li->copyMetadata(*II, ToCopy2); |
| li->setDebugLoc(getNewFromOriginal(II->getDebugLoc())); |
| return li; |
| }, |
| invertPointerM(II->getArgOperand(0), bb)); |
| case Intrinsic::masked_load: |
| return applyChainRule( |
| II->getType(), bb, |
| [&](Value *ptr, Value *defaultV) { |
| Value *args[] = {ptr, getNewFromOriginal(II->getArgOperand(1)), |
| getNewFromOriginal(II->getArgOperand(2)), |
| defaultV}; |
| llvm::SmallVector<unsigned int, 9> ToCopy2(MD_ToCopy); |
| ToCopy2.push_back(LLVMContext::MD_noalias); |
| auto li = bb.CreateCall(II->getCalledFunction(), args); |
| li->copyMetadata(*II, ToCopy2); |
| li->setDebugLoc(getNewFromOriginal(II->getDebugLoc())); |
| return li; |
| }, |
| invertPointerM(II->getArgOperand(0), bb), |
| invertPointerM(II->getArgOperand(3), bb, nullShadow)); |
| } |
| } |
| } else if (auto phi = dyn_cast<PHINode>(oval)) { |
| |
| if (phi->getNumIncomingValues() == 0) { |
| dumpMap(invertedPointers); |
| assert(0 && "illegal iv of phi"); |
| } |
| std::map<Value *, std::set<BasicBlock *>> mapped; |
| for (unsigned int i = 0; i < phi->getNumIncomingValues(); ++i) { |
| mapped[phi->getIncomingValue(i)].insert(phi->getIncomingBlock(i)); |
| } |
| |
| if (false && mapped.size() == 1) { |
| return invertPointerM(phi->getIncomingValue(0), BuilderM); |
| } |
| #if 0 |
| else if (false && mapped.size() == 2) { |
| IRBuilder <> bb(phi); |
| auto which = bb.CreatePHI(Type::getInt1Ty(phi->getContext()), phi->getNumIncomingValues()); |
| //TODO this is not recursive |
| |
| int cnt = 0; |
| Value* vals[2]; |
| for(auto v : mapped) { |
| assert( cnt <= 1 ); |
| vals[cnt] = v.first; |
| for (auto b : v.second) { |
| which->addIncoming(ConstantInt::get(which->getType(), cnt), b); |
| } |
| ++cnt; |
| } |
| auto result = BuilderM.CreateSelect(which, invertPointerM(vals[1], BuilderM), invertPointerM(vals[0], BuilderM)); |
| return result; |
| } |
| #endif |
| |
| else { |
| auto NewV = getNewFromOriginal(phi); |
| IRBuilder<> bb(NewV); |
| bb.setFastMathFlags(getFast()); |
| // Note: if the original phi node gets SCEV'd in NewF, it may no longer |
| // be a phi, and we need a new place to insert this phi. |
| // Note that even if SCEV'd it can still be a phi with 0 incoming values, |
| // indicating an unnecessary value to be replaced. |
| // TODO: consider allowing the inverted pointer to become a SCEV. |
| if (!isa<PHINode>(NewV) || |
| cast<PHINode>(NewV)->getNumIncomingValues() == 0) { |
| bb.SetInsertPoint(bb.GetInsertBlock(), bb.GetInsertBlock()->begin()); |
| } |
| |
| if (EnzymeVectorSplitPhi && width > 1) { |
| IRBuilder<> postPhi(NewV->getParent()->getFirstNonPHI()); |
| Type *shadowTy = getShadowType(phi->getType()); |
| PHINode *tmp = bb.CreatePHI(shadowTy, phi->getNumIncomingValues()); |
| |
| invertedPointers.insert( |
| std::make_pair((const Value *)oval, InvertedPointerVH(this, tmp))); |
| |
| Type *wrappedType = ArrayType::get(phi->getType(), width); |
| Value *res = UndefValue::get(wrappedType); |
| |
| for (unsigned int i = 0; i < getWidth(); ++i) { |
| PHINode *which = |
| bb.CreatePHI(phi->getType(), phi->getNumIncomingValues()); |
| which->setDebugLoc(getNewFromOriginal(phi->getDebugLoc())); |
| |
| for (unsigned int j = 0; j < phi->getNumIncomingValues(); ++j) { |
| IRBuilder<> pre( |
| cast<BasicBlock>(getNewFromOriginal(phi->getIncomingBlock(j))) |
| ->getTerminator()); |
| Value *val = |
| invertPointerM(phi->getIncomingValue(j), pre, nullShadow); |
| auto extracted_diff = extractMeta(pre, val, i); |
| which->addIncoming( |
| extracted_diff, |
| cast<BasicBlock>(getNewFromOriginal(phi->getIncomingBlock(j)))); |
| } |
| |
| res = postPhi.CreateInsertValue(res, which, {i}); |
| } |
| invertedPointers.erase((const Value *)oval); |
| replaceAWithB(tmp, res); |
| erase(tmp); |
| |
| invertedPointers.insert( |
| std::make_pair((const Value *)oval, InvertedPointerVH(this, res))); |
| |
| return res; |
| } else { |
| Type *shadowTy = getShadowType(phi->getType()); |
| PHINode *which = bb.CreatePHI(shadowTy, phi->getNumIncomingValues()); |
| which->setDebugLoc(getNewFromOriginal(phi->getDebugLoc())); |
| |
| invertedPointers.insert(std::make_pair((const Value *)oval, |
| InvertedPointerVH(this, which))); |
| |
| for (unsigned int i = 0; i < phi->getNumIncomingValues(); ++i) { |
| IRBuilder<> pre( |
| cast<BasicBlock>(getNewFromOriginal(phi->getIncomingBlock(i))) |
| ->getTerminator()); |
| Value *val = |
| invertPointerM(phi->getIncomingValue(i), pre, nullShadow); |
| which->addIncoming(val, cast<BasicBlock>(getNewFromOriginal( |
| phi->getIncomingBlock(i)))); |
| } |
| return which; |
| } |
| } |
| } |
| |
| end:; |
| assert(BuilderM.GetInsertBlock()); |
| assert(BuilderM.GetInsertBlock()->getParent()); |
| assert(oval); |
| |
| if (CustomErrorHandler) { |
| std::string str; |
| raw_string_ostream ss(str); |
| ss << "cannot find shadow for " << *oval; |
| CustomErrorHandler(str.c_str(), wrap(oval), ErrorType::NoShadow, this); |
| } |
| |
| llvm::errs() << *newFunc->getParent() << "\n"; |
| llvm::errs() << "fn:" << *newFunc << "\noval=" << *oval |
| << " icv=" << isConstantValue(oval) << "\n"; |
| for (auto z : invertedPointers) { |
| llvm::errs() << "available inversion for " << *z.first << " of " |
| << *z.second << "\n"; |
| } |
| assert(0 && "cannot deal with pointer that is not an argument"); |
| report_fatal_error("cannot deal with pointer that is not an argument"); |
| } |
| |
| Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, |
| const ValueToValueMapTy &incoming_available, |
| bool tryLegalRecomputeCheck, BasicBlock *scope) { |
| |
| assert(mode == DerivativeMode::ReverseModePrimal || |
| mode == DerivativeMode::ReverseModeGradient || |
| mode == DerivativeMode::ReverseModeCombined); |
| |
| assert(val->getName() != "<badref>"); |
| if (isa<Constant>(val)) { |
| return val; |
| } |
| if (isa<BasicBlock>(val)) { |
| return val; |
| } |
| if (isa<Function>(val)) { |
| return val; |
| } |
| if (isa<UndefValue>(val)) { |
| return val; |
| } |
| if (isa<Argument>(val)) { |
| return val; |
| } |
| if (isa<MetadataAsValue>(val)) { |
| return val; |
| } |
| if (isa<InlineAsm>(val)) { |
| return val; |
| } |
| |
| if (!isa<Instruction>(val)) { |
| llvm::errs() << *val << "\n"; |
| } |
| |
| auto inst = cast<Instruction>(val); |
| assert(inst->getName() != "<badref>"); |
| if (inversionAllocs && inst->getParent() == inversionAllocs) { |
| return val; |
| } |
| assert(inst->getParent()->getParent() == newFunc); |
| assert(BuilderM.GetInsertBlock()->getParent() == newFunc); |
| if (scope == nullptr) |
| scope = BuilderM.GetInsertBlock(); |
| assert(scope->getParent() == newFunc); |
| |
| bool reduceRegister = false; |
| |
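| // Register-reduction heuristic (GPU): prefer recomputing ldu/ldg intrinsics, |
| // loads from shared memory, and side-effect-free instructions over reusing |
| // or caching the forward-pass value, to lower register pressure. |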
| if (EnzymeRegisterReduce) { |
| if (auto II = dyn_cast<IntrinsicInst>(inst)) { |
| switch (II->getIntrinsicID()) { |
| case Intrinsic::nvvm_ldu_global_i: |
| case Intrinsic::nvvm_ldu_global_p: |
| case Intrinsic::nvvm_ldu_global_f: |
| case Intrinsic::nvvm_ldg_global_i: |
| case Intrinsic::nvvm_ldg_global_p: |
| case Intrinsic::nvvm_ldg_global_f: |
| reduceRegister = true; |
| break; |
| default: |
| break; |
| } |
| } |
| if (auto LI = dyn_cast<LoadInst>(inst)) { |
| auto Arch = |
| llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch(); |
| unsigned int SharedAddrSpace = |
| Arch == Triple::amdgcn |
| ? (int)AMDGPU::HSAMD::AddressSpaceQualifier::Local |
| : 3; |
| if (cast<PointerType>(LI->getPointerOperand()->getType()) |
| ->getAddressSpace() == SharedAddrSpace) { |
| reduceRegister |= tryLegalRecomputeCheck && |
| legalRecompute(LI, incoming_available, &BuilderM) && |
| shouldRecompute(LI, incoming_available, &BuilderM); |
| } |
| } |
| if (!inst->mayReadOrWriteMemory()) { |
| reduceRegister |= tryLegalRecomputeCheck && |
| legalRecompute(inst, incoming_available, &BuilderM) && |
| shouldRecompute(inst, incoming_available, &BuilderM); |
| } |
| if (this->isOriginalBlock(*BuilderM.GetInsertBlock())) |
| reduceRegister = false; |
| } |
| |
| if (!reduceRegister) { |
| if (isOriginalBlock(*BuilderM.GetInsertBlock())) { |
| if (BuilderM.GetInsertBlock()->size() && |
| BuilderM.GetInsertPoint() != BuilderM.GetInsertBlock()->end()) { |
| Instruction *use = &*BuilderM.GetInsertPoint(); |
| while (isa<PHINode>(use)) |
| use = use->getNextNode(); |
| if (DT.dominates(inst, use)) { |
| return inst; |
| } else { |
| llvm::errs() << *BuilderM.GetInsertBlock()->getParent() << "\n"; |
| llvm::errs() << "didn't dominate inst: " << *inst |
| << " point: " << *BuilderM.GetInsertPoint() |
| << "\nbb: " << *BuilderM.GetInsertBlock() << "\n"; |
| } |
| } else { |
| if (inst->getParent() == BuilderM.GetInsertBlock() || |
| DT.dominates(inst, BuilderM.GetInsertBlock())) { |
| // allowed from block domination |
| return inst; |
| } else { |
| llvm::errs() << *BuilderM.GetInsertBlock()->getParent() << "\n"; |
| llvm::errs() << "didn't dominate inst: " << *inst |
| << "\nbb: " << *BuilderM.GetInsertBlock() << "\n"; |
| } |
| } |
| // This is a reverse block |
| } else if (BuilderM.GetInsertBlock() != inversionAllocs) { |
| // Something in the entry (or anything that dominates all returns) doesn't |
| // need caching |
| BasicBlock *orig = isOriginal(inst->getParent()); |
| if (!orig) { |
| llvm::errs() << "oldFunc: " << *oldFunc << "\n"; |
| llvm::errs() << "newFunc: " << *newFunc << "\n"; |
| llvm::errs() << "insertBlock: " << *BuilderM.GetInsertBlock() << "\n"; |
| llvm::errs() << "instP: " << *inst->getParent() << "\n"; |
| llvm::errs() << "inst: " << *inst << "\n"; |
| } |
| assert(orig); |
| |
| // TODO upgrade this to be all returns that this could enter from |
| bool legal = BlocksDominatingAllReturns.count(orig); |
| if (legal) { |
| |
| BasicBlock *forwardBlock = |
| isOriginal(originalForReverseBlock(*BuilderM.GetInsertBlock())); |
| assert(forwardBlock); |
| |
| // Only allow this if we're definitely using the last iteration of |
| // this value: |
| // + either because the value isn't in a loop |
| // + or because the forward of the block usage location isn't in a |
| // loop (thus last iteration) |
| // + or because the loop nests share no ancestry |
| |
| bool loopLegal = true; |
| for (Loop *idx = OrigLI.getLoopFor(orig); idx != nullptr; |
| idx = idx->getParentLoop()) { |
| for (Loop *fdx = OrigLI.getLoopFor(forwardBlock); fdx != nullptr; |
| fdx = fdx->getParentLoop()) { |
| if (idx == fdx) { |
| loopLegal = false; |
| break; |
| } |
| } |
| } |
| |
| if (loopLegal) { |
| return inst; |
| } |
| } |
| } |
| } |
| |
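| // Consult the per-block lookup cache; a null entry marks a stale slot and is |
| // dropped, otherwise the cached value is bitcast to the requested type. |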
| if (lookup_cache[BuilderM.GetInsertBlock()].find(val) != |
| lookup_cache[BuilderM.GetInsertBlock()].end()) { |
| auto result = lookup_cache[BuilderM.GetInsertBlock()][val]; |
| if (result == nullptr) { |
| lookup_cache[BuilderM.GetInsertBlock()].erase(val); |
| } else { |
| assert(result); |
| assert(result->getType()); |
| result = BuilderM.CreateBitCast(result, val->getType()); |
| assert(result->getType() == inst->getType()); |
| return result; |
| } |
| } |
| |
| ValueToValueMapTy available; |
| for (auto pair : incoming_available) { |
| if (pair.second) |
| assert(pair.first->getType() == pair.second->getType()); |
| available[pair.first] = pair.second; |
| } |
| |
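| // Make the induction variable of every enclosing loop available: in reverse |
| // blocks it is reloaded from its antivaralloc, in original blocks the |
| // variable itself is used directly. |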
| { |
| BasicBlock *forwardPass = BuilderM.GetInsertBlock(); |
| if (forwardPass != inversionAllocs && !isOriginalBlock(*forwardPass)) { |
| forwardPass = originalForReverseBlock(*forwardPass); |
| } |
| LoopContext lc; |
| bool inLoop = getContext(forwardPass, lc); |
| |
| if (inLoop) { |
| bool first = true; |
| for (LoopContext idx = lc;; getContext(idx.parent->getHeader(), idx)) { |
| if (available.count(idx.var) == 0) { |
| if (!isOriginalBlock(*BuilderM.GetInsertBlock())) { |
| #if LLVM_VERSION_MAJOR > 7 |
| available[idx.var] = |
| BuilderM.CreateLoad(idx.var->getType(), idx.antivaralloc); |
| #else |
| available[idx.var] = BuilderM.CreateLoad(idx.antivaralloc); |
| #endif |
| } else { |
| available[idx.var] = idx.var; |
| } |
| } |
| if (!first && idx.var == inst) |
| return available[idx.var]; |
| if (first) { |
| first = false; |
| } |
| if (idx.parent == nullptr) |
| break; |
| } |
| } |
| } |
| |
| if (available.count(inst)) { |
| assert(available[inst]->getType() == inst->getType()); |
| return available[inst]; |
| } |
| |
| // If requesting loop bound and not available from index per above |
| // we must be requesting the total size. Rather than generating |
| // a new lcssa variable, use the existing loop exact bound var |
| { |
| LoopContext lc; |
| bool loopVar = false; |
| if (getContext(inst->getParent(), lc) && lc.var == inst) { |
| loopVar = true; |
| } else if (auto phi = dyn_cast<PHINode>(inst)) { |
| Value *V = nullptr; |
| bool legal = true; |
| for (auto &val : phi->incoming_values()) { |
| if (isa<UndefValue>(val)) |
| continue; |
| if (V == nullptr) |
| V = val; |
| else if (V != val) { |
| legal = false; |
| break; |
| } |
| } |
| if (legal) { |
| if (auto I = dyn_cast_or_null<PHINode>(V)) { |
| if (getContext(I->getParent(), lc) && lc.var == I) { |
| loopVar = true; |
| } |
| } |
| } |
| } |
| if (loopVar) { |
| Value *lim = nullptr; |
| if (lc.dynamic) { |
| // Must be in the reverse pass for a lookup of the index bound to be |
| // legal |
| assert(/*ReverseLimit*/ reverseBlocks.size() > 0); |
| LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, |
| lc.preheader); |
| lim = |
| lookupValueFromCache(/*forwardPass*/ false, BuilderM, lctx, |
| getDynamicLoopLimit(LI.getLoopFor(lc.header)), |
| /*isi1*/ false, available); |
| } else { |
| lim = lookupM(lc.trueLimit, BuilderM); |
| } |
| lookup_cache[BuilderM.GetInsertBlock()][val] = lim; |
| return lim; |
| } |
| } |
| |
| Instruction *prelcssaInst = inst; |
| |
| assert(inst->getName() != "<badref>"); |
| val = fixLCSSA(inst, scope); |
| if (isa<UndefValue>(val)) { |
| llvm::errs() << *oldFunc << "\n"; |
| llvm::errs() << *newFunc << "\n"; |
| llvm::errs() << *BuilderM.GetInsertBlock() << "\n"; |
| llvm::errs() << *scope << "\n"; |
| llvm::errs() << *val << " inst " << *inst << "\n"; |
| assert(0 && "undef value upon lcssa"); |
| } |
| inst = cast<Instruction>(val); |
| assert(prelcssaInst->getType() == inst->getType()); |
| assert(!this->isOriginalBlock(*BuilderM.GetInsertBlock())); |
| |
| // Update index and caching per lcssa |
| if (lookup_cache[BuilderM.GetInsertBlock()].find(val) != |
| lookup_cache[BuilderM.GetInsertBlock()].end()) { |
| auto result = lookup_cache[BuilderM.GetInsertBlock()][val]; |
| if (result == nullptr) { |
| lookup_cache[BuilderM.GetInsertBlock()].erase(val); |
| } else { |
| assert(result); |
| assert(result->getType()); |
| result = BuilderM.CreateBitCast(result, val->getType()); |
| assert(result->getType() == inst->getType()); |
| return result; |
| } |
| } |
| |
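| // If the (pre-lcssa) instruction may legally and profitably be recomputed |
| // here, unwrap it in place instead of caching it. |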
| // TODO consider call as part of |
| bool lrc = false, src = false; |
| if (tryLegalRecomputeCheck && |
| (lrc = legalRecompute(prelcssaInst, available, &BuilderM))) { |
| if ((src = shouldRecompute(prelcssaInst, available, &BuilderM))) { |
| auto op = unwrapM(prelcssaInst, BuilderM, available, |
| UnwrapMode::AttemptSingleUnwrap, scope); |
| if (op) { |
| assert(op); |
| assert(op->getType()); |
| if (op->getType() != inst->getType()) { |
| llvm::errs() << " op: " << *op << " inst: " << *inst << "\n"; |
| } |
| assert(op->getType() == inst->getType()); |
| if (!reduceRegister) |
| lookup_cache[BuilderM.GetInsertBlock()][val] = op; |
| return op; |
| } |
| } else { |
| if (isa<LoadInst>(prelcssaInst)) { |
| } |
| } |
| } |
| |
| if (auto li = dyn_cast<LoadInst>(inst)) |
| if (auto origInst = dyn_cast_or_null<LoadInst>(isOriginal(inst))) { |
| #if LLVM_VERSION_MAJOR >= 12 |
| auto liobj = getUnderlyingObject(li->getPointerOperand(), 100); |
| #else |
| auto liobj = GetUnderlyingObject( |
| li->getPointerOperand(), oldFunc->getParent()->getDataLayout(), 100); |
| #endif |
| |
| #if LLVM_VERSION_MAJOR >= 12 |
| auto orig_liobj = getUnderlyingObject(origInst->getPointerOperand(), 100); |
| #else |
| auto orig_liobj = |
| GetUnderlyingObject(origInst->getPointerOperand(), |
| oldFunc->getParent()->getDataLayout(), 100); |
| #endif |
| |
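| // Look for an already-seen dominating load of the same underlying object |
| // whose SCEV (same start, step, and trip count) shows it reads the same |
| // values, with no intervening clobbering writes; if found, reuse that load |
| // for the lookup instead of caching this one separately. |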
| if (scopeMap.find(inst) == scopeMap.end()) { |
| for (auto pair : scopeMap) { |
| if (auto li2 = dyn_cast<LoadInst>(const_cast<Value *>(pair.first))) { |
| |
| #if LLVM_VERSION_MAJOR >= 12 |
| auto li2obj = getUnderlyingObject(li2->getPointerOperand(), 100); |
| #else |
| auto li2obj = |
| GetUnderlyingObject(li2->getPointerOperand(), |
| oldFunc->getParent()->getDataLayout(), 100); |
| #endif |
| |
| if (liobj == li2obj && DT.dominates(li2, li)) { |
| auto orig2 = isOriginal(li2); |
| if (!orig2) |
| continue; |
| |
| bool failed = false; |
| |
| // llvm::errs() << "found potential candidate loads: oli:" |
| // << *origInst << " oli2: " << *orig2 << "\n"; |
| |
| auto scev1 = SE.getSCEV(li->getPointerOperand()); |
| auto scev2 = SE.getSCEV(li2->getPointerOperand()); |
| // llvm::errs() << " scev1: " << *scev1 << " scev2: " << *scev2 |
| // << "\n"; |
| |
| allInstructionsBetween( |
| OrigLI, orig2, origInst, [&](Instruction *I) -> bool { |
| if (I->mayWriteToMemory() && |
| writesToMemoryReadBy(OrigAA, TLI, |
| /*maybeReader*/ origInst, |
| /*maybeWriter*/ I)) { |
| failed = true; |
| // llvm::errs() << "FAILED: " << *I << "\n"; |
| return /*earlyBreak*/ true; |
| } |
| return /*earlyBreak*/ false; |
| }); |
| if (failed) |
| continue; |
| |
| if (auto ar1 = dyn_cast<SCEVAddRecExpr>(scev1)) { |
| if (auto ar2 = dyn_cast<SCEVAddRecExpr>(scev2)) { |
| if (ar1->getStart() != SE.getCouldNotCompute() && |
| ar1->getStart() == ar2->getStart() && |
| ar1->getStepRecurrence(SE) != SE.getCouldNotCompute() && |
| ar1->getStepRecurrence(SE) == |
| ar2->getStepRecurrence(SE)) { |
| |
| LoopContext l1; |
| getContext(ar1->getLoop()->getHeader(), l1); |
| LoopContext l2; |
| getContext(ar2->getLoop()->getHeader(), l2); |
| if (l1.dynamic || l2.dynamic) |
| continue; |
| |
| // TODO IF len(ar2) >= len(ar1) then we can replace li with |
| // li2 |
| if (SE.getSCEV(l1.trueLimit) != SE.getCouldNotCompute() && |
| SE.getSCEV(l1.trueLimit) == SE.getSCEV(l2.trueLimit)) { |
| // llvm::errs() |
| // << " step1: " << *ar1->getStepRecurrence(SE) |
| // << " step2: " << *ar2->getStepRecurrence(SE) << |
| // "\n"; |
| |
| inst = li2; |
| break; |
| } |
| } |
| } |
| } |
| } |
| } |
| } |
| |
| auto scev1 = OrigSE.getSCEV(origInst->getPointerOperand()); |
| |
| auto Arch = |
| llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch(); |
| unsigned int SharedAddrSpace = |
| Arch == Triple::amdgcn |
| ? (int)AMDGPU::HSAMD::AddressSpaceQualifier::Local |
| : 3; |
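| // Shared-memory forwarding: instead of caching a load from GPU shared |
| // memory, try to find the dominating store to the same location (with no |
| // intervening clobber or unsynchronized writer across a barrier) and |
| // recompute the stored value here, remapping thread-index intrinsics when |
| // the store and load used different thread indices. |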
| if (EnzymeSharedForward && scev1 != OrigSE.getCouldNotCompute() && |
| cast<PointerType>(orig_liobj->getType())->getAddressSpace() == |
| SharedAddrSpace) { |
| Value *resultValue = nullptr; |
| ValueToValueMapTy newavail; |
| for (const auto &pair : available) { |
| assert(pair.first->getType() == pair.second->getType()); |
| newavail[pair.first] = pair.second; |
| } |
| allDomPredecessorsOf(origInst, OrigDT, [&](Instruction *pred) { |
| if (auto SI = dyn_cast<StoreInst>(pred)) { |
| // auto NewSI = cast<StoreInst>(getNewFromOriginal(SI)); |
| #if LLVM_VERSION_MAJOR >= 12 |
| auto si2obj = getUnderlyingObject(SI->getPointerOperand(), 100); |
| #else |
| auto si2obj = |
| GetUnderlyingObject(SI->getPointerOperand(), |
| oldFunc->getParent()->getDataLayout(), 100); |
| #endif |
| |
| if (si2obj != orig_liobj) |
| return false; |
| |
| bool lastStore = true; |
| bool interveningSync = false; |
| allInstructionsBetween( |
| OrigLI, SI, origInst, [&](Instruction *potentialAlias) { |
| if (!potentialAlias->mayWriteToMemory()) |
| return false; |
| if (!writesToMemoryReadBy(OrigAA, TLI, origInst, |
| potentialAlias)) |
| return false; |
| |
| if (auto II = dyn_cast<IntrinsicInst>(potentialAlias)) { |
| if (II->getIntrinsicID() == Intrinsic::nvvm_barrier0 || |
| II->getIntrinsicID() == Intrinsic::amdgcn_s_barrier) { |
| interveningSync = |
| DT.dominates(SI, II) && DT.dominates(II, origInst); |
| allUnsyncdPredecessorsOf( |
| II, |
| [&](Instruction *mid) { |
| if (!mid->mayWriteToMemory()) |
| return false; |
| |
| if (mid == SI) |
| return false; |
| |
| if (!writesToMemoryReadBy(OrigAA, TLI, origInst, |
| mid)) { |
| return false; |
| } |
| lastStore = false; |
| return true; |
| }, |
| [&]() { |
| // if gone past entry |
| if (mode != DerivativeMode::ReverseModeCombined) { |
| lastStore = false; |
| } |
| }); |
| if (!lastStore) |
| return true; |
| else |
| return false; |
| } |
| } |
| |
| lastStore = false; |
| return true; |
| }); |
| |
| if (!lastStore) |
| return false; |
| |
| auto scev2 = OrigSE.getSCEV(SI->getPointerOperand()); |
| bool legal = scev1 == scev2; |
| if (auto ar2 = dyn_cast<SCEVAddRecExpr>(scev2)) { |
| if (auto ar1 = dyn_cast<SCEVAddRecExpr>(scev1)) { |
| if (ar2->getStart() != OrigSE.getCouldNotCompute() && |
| ar1->getStart() == ar2->getStart() && |
| ar2->getStepRecurrence(OrigSE) != |
| OrigSE.getCouldNotCompute() && |
| ar1->getStepRecurrence(OrigSE) == |
| ar2->getStepRecurrence(OrigSE)) { |
| |
| LoopContext l1; |
| getContext(getNewFromOriginal(ar1->getLoop()->getHeader()), |
| l1); |
| LoopContext l2; |
| getContext(getNewFromOriginal(ar2->getLoop()->getHeader()), |
| l2); |
| if (!l1.dynamic && !l2.dynamic) { |
| // TODO IF len(ar2) >= len(ar1) then we can replace li |
| // with li2 |
| if (l1.trueLimit == l2.trueLimit) { |
| const Loop *L1 = ar1->getLoop(); |
| while (L1) { |
| if (L1 == ar2->getLoop()) |
| return false; |
| L1 = L1->getParentLoop(); |
| } |
| newavail[l2.var] = available[l1.var]; |
| legal = true; |
| } |
| } |
| } |
| } |
| } |
| if (!legal) { |
| Value *sval = SI->getPointerOperand(); |
| Value *lval = origInst->getPointerOperand(); |
| while (auto CI = dyn_cast<CastInst>(sval)) |
| sval = CI->getOperand(0); |
| while (auto CI = dyn_cast<CastInst>(lval)) |
| lval = CI->getOperand(0); |
| if (auto sgep = dyn_cast<GetElementPtrInst>(sval)) { |
| if (auto lgep = dyn_cast<GetElementPtrInst>(lval)) { |
| if (sgep->getPointerOperand() == |
| lgep->getPointerOperand()) { |
| SmallVector<Value *, 3> svals; |
| for (auto &v : sgep->indices()) { |
| Value *q = v; |
| while (auto CI = dyn_cast<CastInst>(q)) |
| q = CI->getOperand(0); |
| svals.push_back(q); |
| } |
| SmallVector<Value *, 3> lvals; |
| for (auto &v : lgep->indices()) { |
| Value *q = v; |
| while (auto CI = dyn_cast<CastInst>(q)) |
| q = CI->getOperand(0); |
| lvals.push_back(q); |
| } |
| ValueToValueMapTy ThreadLookup; |
| bool legal = true; |
| for (size_t i = 0; i < svals.size(); i++) { |
| auto ss = OrigSE.getSCEV(svals[i]); |
| auto ls = OrigSE.getSCEV(lvals[i]); |
| if (cast<IntegerType>(ss->getType())->getBitWidth() > |
| cast<IntegerType>(ls->getType())->getBitWidth()) { |
| ls = OrigSE.getZeroExtendExpr(ls, ss->getType()); |
| } |
| if (cast<IntegerType>(ss->getType())->getBitWidth() < |
| cast<IntegerType>(ls->getType())->getBitWidth()) { |
| ls = OrigSE.getTruncateExpr(ls, ss->getType()); |
| } |
| if (ls != ss) { |
| if (auto II = dyn_cast<IntrinsicInst>(svals[i])) { |
| switch (II->getIntrinsicID()) { |
| case Intrinsic::nvvm_read_ptx_sreg_tid_x: |
| case Intrinsic::nvvm_read_ptx_sreg_tid_y: |
| case Intrinsic::nvvm_read_ptx_sreg_tid_z: |
| case Intrinsic::amdgcn_workitem_id_x: |
| case Intrinsic::amdgcn_workitem_id_y: |
| case Intrinsic::amdgcn_workitem_id_z: |
| ThreadLookup[getNewFromOriginal(II)] = |
| BuilderM.CreateZExtOrTrunc( |
| lookupM(getNewFromOriginal(lvals[i]), |
| BuilderM, available), |
| II->getType()); |
| break; |
| default: |
| legal = false; |
| break; |
| } |
| } else { |
| legal = false; |
| break; |
| } |
| } |
| } |
| if (legal) { |
| for (auto pair : newavail) { |
| assert(pair.first->getType() == |
| pair.second->getType()); |
| ThreadLookup[pair.first] = pair.second; |
| } |
| Value *recomp = unwrapM( |
| getNewFromOriginal(SI->getValueOperand()), BuilderM, |
| ThreadLookup, UnwrapMode::AttemptFullUnwrap, scope, |
| /*permitCache*/ false); |
| if (recomp) { |
| resultValue = recomp; |
| return true; |
| } |
| } |
| } |
| } |
| } |
| } |
| if (!legal) |
| return false; |
| return true; |
| } |
| return false; |
| }); |
| |
| if (resultValue) { |
| if (resultValue->getType() != val->getType()) |
| resultValue = BuilderM.CreateBitCast(resultValue, val->getType()); |
| return resultValue; |
| } |
| } |
| } |
| |
| auto loadSize = (li->getParent() |
| ->getParent() |
| ->getParent() |
| ->getDataLayout() |
| .getTypeAllocSizeInBits(li->getType()) + |
| 7) / |
| 8; |
| |
| // this is guarded because we haven't told cacheForReverse how to move |
| if (mode == DerivativeMode::ReverseModeCombined) |
| if (!li->isVolatile() && EnzymeLoopInvariantCache) { |
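| // Loop-invariant alloca caching: if the load indexes into a local |
| // fixed-size array, copy the whole array into the cache once in the |
| // outermost legal preheader and index into that copy in the reverse pass, |
| // instead of caching each loaded element per iteration. |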
| if (auto AI = dyn_cast<AllocaInst>(liobj)) { |
| assert(isa<AllocaInst>(orig_liobj)); |
| if (auto AT = dyn_cast<ArrayType>(AI->getAllocatedType())) |
| if (auto GEP = |
| dyn_cast<GetElementPtrInst>(li->getPointerOperand())) { |
| if (GEP->getPointerOperand() == AI) { |
| LoopContext l1; |
| if (!getContext(li->getParent(), l1)) |
| goto noSpeedCache; |
| |
| BasicBlock *ctx = l1.preheader; |
| |
| auto origPH = cast_or_null<BasicBlock>(isOriginal(ctx)); |
| assert(origPH); |
| if (OrigPDT.dominates(origPH, origInst->getParent())) { |
| goto noSpeedCache; |
| } |
| |
| Instruction *origTerm = origPH->getTerminator(); |
| if (!origTerm) |
| llvm::errs() << *origPH << "\n"; |
| assert(origTerm); |
| IRBuilder<> OB(origTerm); |
| LoadInst *tmpload = OB.CreateLoad(AT, orig_liobj, "'tmpload"); |
| |
| bool failed = false; |
| allInstructionsBetween( |
| OrigLI, &*origTerm, origInst, |
| [&](Instruction *I) -> bool { |
| if (I->mayWriteToMemory() && |
| writesToMemoryReadBy(OrigAA, TLI, |
| /*maybeReader*/ tmpload, |
| /*maybeWriter*/ I)) { |
| failed = true; |
| return /*earlyBreak*/ true; |
| } |
| return /*earlyBreak*/ false; |
| }); |
| if (failed) { |
| tmpload->eraseFromParent(); |
| goto noSpeedCache; |
| } |
| while (Loop *L = LI.getLoopFor(ctx)) { |
| BasicBlock *nctx = L->getLoopPreheader(); |
| assert(nctx); |
| bool failed = false; |
| auto origPH = cast_or_null<BasicBlock>(isOriginal(nctx)); |
| assert(origPH); |
| if (OrigPDT.dominates(origPH, origInst->getParent())) { |
| break; |
| } |
| Instruction *origTerm = origPH->getTerminator(); |
| allInstructionsBetween( |
| OrigLI, &*origTerm, origInst, |
| [&](Instruction *I) -> bool { |
| if (I->mayWriteToMemory() && |
| writesToMemoryReadBy(OrigAA, TLI, |
| /*maybeReader*/ tmpload, |
| /*maybeWriter*/ I)) { |
| failed = true; |
| return /*earlyBreak*/ true; |
| } |
| return /*earlyBreak*/ false; |
| }); |
| if (failed) |
| break; |
| ctx = nctx; |
| } |
| |
| tmpload->eraseFromParent(); |
| |
| IRBuilder<> v(ctx->getTerminator()); |
| bool isi1 = false; |
| |
| AllocaInst *cache = nullptr; |
| |
| LoopContext tmp; |
| bool forceSingleIter = false; |
| if (!getContext(ctx, tmp)) { |
| forceSingleIter = true; |
| } |
| LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, |
| ctx, forceSingleIter); |
| |
| if (auto found = findInMap(scopeMap, (Value *)liobj)) { |
| cache = found->first; |
| } else { |
| // if freeing reverseblocks must exist |
| assert(reverseBlocks.size()); |
| cache = createCacheForScope(lctx, AT, li->getName(), |
| /*shouldFree*/ true, |
| /*allocate*/ true); |
| assert(cache); |
| scopeMap.insert( |
| std::make_pair(AI, std::make_pair(cache, lctx))); |
| |
| v.setFastMathFlags(getFast()); |
| assert(isOriginalBlock(*v.GetInsertBlock())); |
| Value *outer = getCachePointer( |
| /*inForwardPass*/ true, v, lctx, cache, isi1, |
| /*storeinstorecache*/ true, |
| /*available*/ ValueToValueMapTy(), |
| /*extraSize*/ nullptr); |
| |
| auto ld = v.CreateLoad(AT, AI); |
| #if LLVM_VERSION_MAJOR >= 11 |
| ld->setAlignment(AI->getAlign()); |
| #elif LLVM_VERSION_MAJOR == 10 |
| if (AI->getAlignment()) { |
| ld->setAlignment(Align(AI->getAlignment())); |
| } |
| #else |
| if (AI->getAlignment()) { |
| ld->setAlignment(AI->getAlignment()); |
| } |
| #endif |
| scopeInstructions[cache].push_back(ld); |
| auto st = v.CreateStore(ld, outer); |
| auto bsize = newFunc->getParent() |
| ->getDataLayout() |
| .getTypeAllocSizeInBits(AT) / |
| 8; |
| if ((bsize & (bsize - 1)) == 0) { |
| #if LLVM_VERSION_MAJOR >= 10 |
| st->setAlignment(Align(bsize)); |
| #else |
| st->setAlignment(bsize); |
| #endif |
| } |
| scopeInstructions[cache].push_back(st); |
| for (auto post : PostCacheStore(st, v)) { |
| scopeInstructions[cache].push_back(post); |
| } |
| } |
| |
| assert(!isOriginalBlock(*BuilderM.GetInsertBlock())); |
| Value *outer = getCachePointer( |
| /*inForwardPass*/ false, BuilderM, lctx, cache, isi1, |
| /*storeinstorecache*/ true, available, |
| /*extraSize*/ nullptr); |
| SmallVector<Value *, 2> idxs; |
| for (auto &idx : GEP->indices()) { |
| idxs.push_back(lookupM(idx, BuilderM, available, |
| tryLegalRecomputeCheck)); |
| } |
| |
| #if LLVM_VERSION_MAJOR > 7 |
| auto cptr = BuilderM.CreateGEP( |
| outer->getType()->getPointerElementType(), outer, idxs); |
| #else |
| auto cptr = BuilderM.CreateGEP(outer, idxs); |
| #endif |
| cast<GetElementPtrInst>(cptr)->setIsInBounds(true); |
| |
| // Retrieve the actual result |
| auto result = loadFromCachePointer(BuilderM, cptr, cache); |
| |
| assert(result->getType() == inst->getType()); |
| lookup_cache[BuilderM.GetInsertBlock()][val] = result; |
| return result; |
| } |
| } |
| } |
| |
| auto scev1 = SE.getSCEV(li->getPointerOperand()); |
| // memcpy cache optimization: if this load strides through memory with a |
| // constant step equal to the load size over a known (non-dynamic) trip |
| // count, cache the whole range with a single memcpy in the preheader and |
| // index into that buffer in the reverse pass. |
| Value *lim = nullptr; |
| BasicBlock *ctx = nullptr; |
| Value *start = nullptr; |
| Value *offset = nullptr; |
| if (auto ar1 = dyn_cast<SCEVAddRecExpr>(scev1)) { |
| if (auto step = |
| dyn_cast<SCEVConstant>(ar1->getStepRecurrence(SE))) { |
| if (step->getAPInt() != loadSize) |
| goto noSpeedCache; |
| |
| LoopContext l1; |
| getContext(ar1->getLoop()->getHeader(), l1); |
| |
| if (l1.dynamic) |
| goto noSpeedCache; |
| |
| offset = available[l1.var]; |
| ctx = l1.preheader; |
| |
| IRBuilder<> v(ctx->getTerminator()); |
| |
| auto origPH = cast_or_null<BasicBlock>(isOriginal(ctx)); |
| assert(origPH); |
| if (OrigPDT.dominates(origPH, origInst->getParent())) { |
| goto noSpeedCache; |
| } |
| |
| lim = unwrapM(l1.trueLimit, v, |
| /*available*/ ValueToValueMapTy(), |
| UnwrapMode::AttemptFullUnwrapWithLookup); |
| if (!lim) { |
| goto noSpeedCache; |
| } |
| lim = v.CreateAdd(lim, ConstantInt::get(lim->getType(), 1), "", |
| true, true); |
| |
| SmallVector<Instruction *, 4> toErase; |
| { |
| #if LLVM_VERSION_MAJOR >= 12 |
| SCEVExpander Exp(SE, |
| ctx->getParent()->getParent()->getDataLayout(), |
| "enzyme"); |
| #else |
| fake::SCEVExpander Exp( |
| SE, ctx->getParent()->getParent()->getDataLayout(), |
| "enzyme"); |
| #endif |
| Exp.setInsertPoint(l1.header->getTerminator()); |
| Value *start0 = Exp.expandCodeFor( |
| ar1->getStart(), li->getPointerOperand()->getType()); |
| start = unwrapM(start0, v, |
| /*available*/ ValueToValueMapTy(), |
| UnwrapMode::AttemptFullUnwrapWithLookup); |
| std::set<Value *> todo = {start0}; |
| while (todo.size()) { |
| Value *now = *todo.begin(); |
| todo.erase(now); |
| if (Instruction *inst = dyn_cast<Instruction>(now)) { |
| if (inst != start && inst->getNumUses() == 0 && |
| Exp.isInsertedInstruction(inst)) { |
| for (auto &op : inst->operands()) { |
| todo.insert(op); |
| } |
| toErase.push_back(inst); |
| } |
| } |
| } |
| } |
| for (auto a : toErase) |
| erase(a); |
| |
| if (!start) |
| goto noSpeedCache; |
| |
| Instruction *origTerm = origPH->getTerminator(); |
| |
| bool failed = false; |
| allInstructionsBetween( |
| OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { |
| if (I->mayWriteToMemory() && |
| writesToMemoryReadBy(OrigAA, TLI, |
| /*maybeReader*/ origInst, |
| /*maybeWriter*/ I)) { |
| failed = true; |
| return /*earlyBreak*/ true; |
| } |
| return /*earlyBreak*/ false; |
| }); |
| if (failed) |
| goto noSpeedCache; |
| } |
| } |
| |
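| // Hoist the memcpy-based cache to the outermost enclosing preheader where |
| // the limit and start remain recomputable and no intervening instruction |
| // may clobber the loaded memory. |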
| if (ctx && lim && start && offset) { |
| Value *firstLim = lim; |
| Value *firstStart = start; |
| while (Loop *L = LI.getLoopFor(ctx)) { |
| BasicBlock *nctx = L->getLoopPreheader(); |
| assert(nctx); |
| bool failed = false; |
| auto origPH = cast_or_null<BasicBlock>(isOriginal(nctx)); |
| assert(origPH); |
| if (OrigPDT.dominates(origPH, origInst->getParent())) { |
| break; |
| } |
| Instruction *origTerm = origPH->getTerminator(); |
| allInstructionsBetween( |
| OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { |
| if (I->mayWriteToMemory() && |
| writesToMemoryReadBy(OrigAA, TLI, |
| /*maybeReader*/ origInst, |
| /*maybeWriter*/ I)) { |
| failed = true; |
| return /*earlyBreak*/ true; |
| } |
| return /*earlyBreak*/ false; |
| }); |
| if (failed) |
| break; |
| IRBuilder<> nv(nctx->getTerminator()); |
| Value *nlim = unwrapM(firstLim, nv, |
| /*available*/ ValueToValueMapTy(), |
| UnwrapMode::AttemptFullUnwrapWithLookup); |
| if (!nlim) |
| break; |
| Value *nstart = unwrapM(firstStart, nv, |
| /*available*/ ValueToValueMapTy(), |
| UnwrapMode::AttemptFullUnwrapWithLookup); |
| if (!nstart) |
| break; |
| lim = nlim; |
| start = nstart; |
| ctx = nctx; |
| } |
| IRBuilder<> v(ctx->getTerminator()); |
| bool isi1 = val->getType()->isIntegerTy() && |
| cast<IntegerType>(li->getType())->getBitWidth() == 1; |
| |
| AllocaInst *cache = nullptr; |
| |
| LoopContext tmp; |
| bool forceSingleIter = false; |
| if (!getContext(ctx, tmp)) { |
| forceSingleIter = true; |
| } else if (auto inst = dyn_cast<Instruction>(lim)) { |
| if (inst->getParent() == ctx || |
| !DT.dominates(inst->getParent(), ctx)) { |
| forceSingleIter = true; |
| } |
| } |
| LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, ctx, |
| forceSingleIter); |
| |
| if (auto found = findInMap(scopeMap, (Value *)inst)) { |
| cache = found->first; |
| } else { |
| // if freeing reverseblocks must exist |
| assert(reverseBlocks.size()); |
| cache = createCacheForScope(lctx, li->getType(), li->getName(), |
| /*shouldFree*/ true, |
| /*allocate*/ true, /*extraSize*/ lim); |
| assert(cache); |
| scopeMap.insert( |
| std::make_pair(inst, std::make_pair(cache, lctx))); |
| |
| v.setFastMathFlags(getFast()); |
| assert(isOriginalBlock(*v.GetInsertBlock())); |
| Value *outer = |
| getCachePointer(/*inForwardPass*/ true, v, lctx, cache, isi1, |
| /*storeinstorecache*/ true, |
| /*available*/ ValueToValueMapTy(), |
| /*extraSize*/ lim); |
| |
| auto dst_arg = v.CreateBitCast( |
| outer, |
| Type::getInt8PtrTy( |
| inst->getContext(), |
| cast<PointerType>(outer->getType())->getAddressSpace())); |
| scopeInstructions[cache].push_back(cast<Instruction>(dst_arg)); |
| auto src_arg = v.CreateBitCast( |
| start, |
| Type::getInt8PtrTy( |
| inst->getContext(), |
| cast<PointerType>(start->getType())->getAddressSpace())); |
| auto len_arg = |
| v.CreateMul(ConstantInt::get(lim->getType(), loadSize), lim, |
| "", true, true); |
| if (Instruction *I = dyn_cast<Instruction>(len_arg)) |
| scopeInstructions[cache].push_back(I); |
| auto volatile_arg = ConstantInt::getFalse(inst->getContext()); |
| |
| Value *nargs[] = {dst_arg, src_arg, len_arg, volatile_arg}; |
| |
| Type *tys[] = {dst_arg->getType(), src_arg->getType(), |
| len_arg->getType()}; |
| |
| auto memcpyF = Intrinsic::getDeclaration(newFunc->getParent(), |
| Intrinsic::memcpy, tys); |
| auto mem = cast<CallInst>(v.CreateCall(memcpyF, nargs)); |
| |
| mem->addParamAttr(0, Attribute::NonNull); |
| mem->addParamAttr(1, Attribute::NonNull); |
| |
| auto bsize = |
| newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits( |
| li->getType()) / |
| 8; |
| if ((bsize & (bsize - 1)) == 0) { |
| #if LLVM_VERSION_MAJOR >= 10 |
| mem->addParamAttr(0, Attribute::getWithAlignment( |
| memcpyF->getContext(), Align(bsize))); |
| #else |
| mem->addParamAttr(0, Attribute::getWithAlignment( |
| memcpyF->getContext(), bsize)); |
| #endif |
| } |
| |
| #if LLVM_VERSION_MAJOR >= 11 |
| mem->addParamAttr(1, Attribute::getWithAlignment( |
| memcpyF->getContext(), li->getAlign())); |
| #elif LLVM_VERSION_MAJOR >= 10 |
| if (li->getAlign()) |
| mem->addParamAttr( |
| 1, Attribute::getWithAlignment(memcpyF->getContext(), |
| li->getAlign().getValue())); |
| #else |
| if (li->getAlignment()) |
| mem->addParamAttr( |
| 1, Attribute::getWithAlignment(memcpyF->getContext(), |
| li->getAlignment())); |
| #endif |
| |
| scopeInstructions[cache].push_back(mem); |
| } |
| |
| assert(!isOriginalBlock(*BuilderM.GetInsertBlock())); |
| Value *result = lookupValueFromCache( |
| /*isForwardPass*/ false, BuilderM, lctx, cache, isi1, available, |
| /*extraSize*/ lim, offset); |
| assert(result->getType() == inst->getType()); |
| lookup_cache[BuilderM.GetInsertBlock()][val] = result; |
| |
| EmitWarning("Uncacheable", *inst, "Caching instruction ", *inst, |
| " legalRecompute: ", lrc, " shouldRecompute: ", src, |
| " tryLegalRecomputeCheck: ", tryLegalRecomputeCheck); |
| return result; |
| } |
| } |
| noSpeedCache:; |
| } |
| |
| if (scopeMap.find(inst) == scopeMap.end()) { |
| EmitWarning("Uncacheable", *inst, "Caching instruction ", *inst, |
| " legalRecompute: ", lrc, " shouldRecompute: ", src, |
| " tryLegalRecomputeCheck: ", tryLegalRecomputeCheck); |
| } |
| |
| // Cache at the scope of the instruction's block, unless the logic below |
| // widens the scope to the entry block. |
| scope = inst->getParent(); |
| if (auto origInst = isOriginal(inst)) { |
| auto found = rematerializableAllocations.find(origInst); |
| if (found != rematerializableAllocations.end()) |
| if (found->second.LI && found->second.LI->contains(origInst)) { |
| bool cacheWholeAllocation = false; |
| if (knownRecomputeHeuristic.count(origInst)) { |
| if (!knownRecomputeHeuristic[origInst]) { |
| cacheWholeAllocation = true; |
| } |
| } |
| // If not caching whole allocation and rematerializing the allocation |
| // within the loop, force an entry-level scope so there is no need |
| // to cache. |
| if (!cacheWholeAllocation) |
| scope = &newFunc->getEntryBlock(); |
| } |
| } else { |
| for (auto pair : backwardsOnlyShadows) { |
| if (auto pinst = dyn_cast<Instruction>(pair.first)) |
| if (!pair.second.primalInitialize && pair.second.LI && |
| pair.second.LI->contains(pinst->getParent())) { |
| auto found = invertedPointers.find(pair.first); |
| if (found != invertedPointers.end() && found->second == inst) { |
| scope = &newFunc->getEntryBlock(); |
| |
| // Prevent the phi node from being stored into the cache by creating |
| // it before the ensureLookupCached. |
| if (scopeMap.find(inst) == scopeMap.end()) { |
| LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, |
| scope); |
| |
| AllocaInst *cache = createCacheForScope( |
| lctx, inst->getType(), inst->getName(), /*shouldFree*/ true); |
| assert(cache); |
| insert_or_assign(scopeMap, (Value *&)inst, |
| std::pair<AssertingVH<AllocaInst>, LimitContext>( |
| cache, lctx)); |
| } |
| break; |
| } |
| } |
| } |
| } |
| |
| ensureLookupCached(inst, /*shouldFree*/ true, scope, |
| inst->getMetadata(LLVMContext::MD_tbaa)); |
| bool isi1 = inst->getType()->isIntegerTy() && |
| cast<IntegerType>(inst->getType())->getBitWidth() == 1; |
| assert(!isOriginalBlock(*BuilderM.GetInsertBlock())); |
| auto found = findInMap(scopeMap, (Value *)inst); |
| Value *result = |
| lookupValueFromCache(/*isForwardPass*/ false, BuilderM, found->second, |
| found->first, isi1, available); |
| if (auto LI2 = dyn_cast<LoadInst>(result)) |
| if (auto LI1 = dyn_cast<LoadInst>(inst)) { |
| llvm::SmallVector<unsigned int, 9> ToCopy2(MD_ToCopy); |
| ToCopy2.push_back(LLVMContext::MD_noalias); |
| LI2->copyMetadata(*LI1, ToCopy2); |
| } |
| if (result->getType() != inst->getType()) { |
| llvm::errs() << "newFunc: " << *newFunc << "\n"; |
| llvm::errs() << "result: " << *result << "\n"; |
| llvm::errs() << "inst: " << *inst << "\n"; |
| llvm::errs() << "val: " << *val << "\n"; |
| } |
| assert(result->getType() == inst->getType()); |
| lookup_cache[BuilderM.GetInsertBlock()][val] = result; |
| assert(result); |
| if (result->getType() != val->getType()) { |
| result = BuilderM.CreateBitCast(result, val->getType()); |
| } |
| assert(result->getType() == val->getType()); |
| assert(result->getType()); |
| return result; |
| } |
| |
| //! Given a map of edges we could have taken to desired target, compute a value |
| //! that determines which target should be branched to |
| // This function attempts to determine an equivalent condition from earlier in |
| // the code and use that if possible, falling back to creating a phi node of |
| // which edge was taken if necessary. This function can be used in two ways: |
| // * If replacePHIs is null (usual case), this function does the branch |
| // * If replacePHIs isn't null, do not perform the branch and instead replace |
| // the PHI's with the derived condition as to whether we should branch to a |
| // particular target |
| void GradientUtils::branchToCorrespondingTarget( |
| BasicBlock *ctx, IRBuilder<> &BuilderM, |
| const std::map<BasicBlock *, |
| std::vector<std::pair</*pred*/ BasicBlock *, |
| /*successor*/ BasicBlock *>>> |
| &targetToPreds, |
| const std::map<BasicBlock *, PHINode *> *replacePHIs) { |
| assert(targetToPreds.size() > 0); |
| if (replacePHIs) { |
| if (replacePHIs->size() == 0) |
| return; |
| |
| for (auto x : *replacePHIs) { |
| assert(targetToPreds.find(x.first) != targetToPreds.end()); |
| } |
| } |
| |
| if (targetToPreds.size() == 1) { |
| if (replacePHIs == nullptr) { |
| if (!(BuilderM.GetInsertBlock()->size() == 0 || |
| !isa<BranchInst>(BuilderM.GetInsertBlock()->back()))) { |
| llvm::errs() << *oldFunc << "\n"; |
| llvm::errs() << *newFunc << "\n"; |
| llvm::errs() << *BuilderM.GetInsertBlock() << "\n"; |
| } |
| assert(BuilderM.GetInsertBlock()->size() == 0 || |
| !isa<BranchInst>(BuilderM.GetInsertBlock()->back())); |
| BuilderM.CreateBr(targetToPreds.begin()->first); |
| } else { |
| for (auto pair : *replacePHIs) { |
| pair.second->replaceAllUsesWith( |
| ConstantInt::getTrue(pair.second->getContext())); |
| pair.second->eraseFromParent(); |
| } |
| } |
| return; |
| } |
| |
| // Map from each (pred, successor) edge to the set of targets that edge can |
| // eventually branch to. |
| std::map<std::pair</*pred*/ BasicBlock *, /*successor*/ BasicBlock *>, |
| std::set<BasicBlock *>> |
| done; |
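| // Populate `done` by walking backwards from each (pred, successor) edge of |
| // a target block, propagating the set of targets reachable through every |
| // edge encountered, stopping at blocks that dominate ctx and skipping loop |
| // backedges. |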
| { |
| std::deque< |
| std::tuple<std::pair</*pred*/ BasicBlock *, /*successor*/ BasicBlock *>, |
| BasicBlock *>> |
| Q; // newblock, target |
| |
| for (auto pair : targetToPreds) { |
| for (auto pred_edge : pair.second) { |
| Q.push_back(std::make_pair(pred_edge, pair.first)); |
| } |
| } |
| |
| for (std::tuple< |
| std::pair</*pred*/ BasicBlock *, /*successor*/ BasicBlock *>, |
| BasicBlock *> |
| trace; |
| Q.size() > 0;) { |
| trace = Q.front(); |
| Q.pop_front(); |
| auto edge = std::get<0>(trace); |
| auto block = edge.first; |
| auto target = std::get<1>(trace); |
| |
| if (done[edge].count(target)) |
| continue; |
| done[edge].insert(target); |
| |
| // If this block dominates the context, don't go back up as any |
| // predecessors won't contain the conditions. |
| if (DT.dominates(block, ctx)) |
| continue; |
| |
| Loop *blockLoop = LI.getLoopFor(block); |
| |
| for (BasicBlock *Pred : predecessors(block)) { |
| // Don't go up the backedge as we can use the last value if desired via |
| // lcssa |
| if (blockLoop && blockLoop->getHeader() == block && |
| blockLoop == LI.getLoopFor(Pred)) |
| continue; |
| |
| Q.push_back( |
| std::tuple<std::pair<BasicBlock *, BasicBlock *>, BasicBlock *>( |
| std::make_pair(Pred, block), target)); |
| } |
| } |
| } |
| |
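| // Type used to encode which target was taken: i1 for two targets, i8 for up |
| // to 255 targets, i32 otherwise. |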
| IntegerType *T; |
| if (targetToPreds.size() == 2) |
| T = Type::getInt1Ty(BuilderM.getContext()); |
| else if (targetToPreds.size() < 256) |
| T = Type::getInt8Ty(BuilderM.getContext()); |
| else |
| T = Type::getInt32Ty(BuilderM.getContext()); |
| |
| Instruction *equivalentTerminator = nullptr; |
| |
| std::set<BasicBlock *> blocks; |
| |
| // llvm::errs() << "\n\n<DONE = " << ctx->getName() << ">\n"; |
| for (auto pair : done) { |
| const auto &edge = pair.first; |
| blocks.insert(edge.first); |
| // llvm::errs() << " edge (" << edge.first->getName() << ", " |
| // << edge.second->getName() << ") : ["; |
| // for (auto s : pair.second) |
| // llvm::errs() << s->getName() << ","; |
| // llvm::errs() << "]\n"; |
| } |
| // llvm::errs() << "</DONE>\n"; |
| |
| if (targetToPreds.size() == 3) { |
| // Try `block` as a potential first split point. |
| for (auto block : blocks) { |
| { |
| // The original split block must not have a predecessor with an edge to a |
| // block other than the split block itself that can still reach any of the |
| // targets. |
| if (!DT.dominates(block, ctx)) |
| continue; |
| |
| // For all successors and thus edges (block, succ): |
| // 1) Ensure that no successors have overlapping potential |
| // destinations (a list of destinations previously seen is in |
| // foundtargets). |
| // 2) The block branches to all 3 destinations (foundTargets==3) |
| std::set<BasicBlock *> foundtargets; |
| // 3) The unique target split off from the others is stored in |
| // uniqueTarget. |
| std::set<BasicBlock *> uniqueTargets; |
| for (BasicBlock *succ : successors(block)) { |
| auto edge = std::make_pair(block, succ); |
| for (BasicBlock *target : done[edge]) { |
| if (foundtargets.find(target) != foundtargets.end()) { |
| goto rnextpair; |
| } |
| foundtargets.insert(target); |
| if (done[edge].size() == 1) |
| uniqueTargets.insert(target); |
| } |
| } |
| if (foundtargets.size() != 3) |
| goto rnextpair; |
| if (uniqueTargets.size() != 1) |
| goto rnextpair; |
| |
| // Only handle cases where the split was due to a conditional |
| // branch. This branch, `bi`, splits off uniqueTargets[0] from |
| // the remainder of foundTargets. |
| auto bi1 = dyn_cast<BranchInst>(block->getTerminator()); |
| if (!bi1) |
| goto rnextpair; |
| |
| { |
| // Find a second block `subblock` which splits the two merged |
| // targets from each other. |
| BasicBlock *subblock = nullptr; |
| for (auto block2 : blocks) { |
| { |
| // The second split block must not have a predecessor with an edge to a |
| // block other than the split block itself that can reach either of its |
| // two targets. |
| // TODO verify this |
| for (auto P : predecessors(block2)) { |
| for (auto S : successors(P)) { |
| if (S == block2) |
| continue; |
| auto edge = std::make_pair(P, S); |
| if (done.find(edge) != done.end()) { |
| for (auto target : done[edge]) { |
| if (foundtargets.find(target) != foundtargets.end() && |
| uniqueTargets.find(target) == uniqueTargets.end()) { |
| goto nextblock; |
| } |
| } |
| } |
| } |
| } |
| |
| // Again, a successful split must have unique targets. |
| std::set<BasicBlock *> seen2; |
| for (BasicBlock *succ : successors(block2)) { |
| auto edge = std::make_pair(block2, succ); |
| // Since there are only two targets, a successful split |
| // condition has only 1 target per successor of block2. |
| if (done[edge].size() != 1) { |
| goto nextblock; |
| } |
| for (BasicBlock *target : done[edge]) { |
| // block2 has non-unique targets. |
| if (seen2.find(target) != seen2.end()) { |
| goto nextblock; |
| } |
| seen2.insert(target); |
| // block2 has a target which is not part of the two needing |
| // to be split. The two needing to be split is equal to |
| // foundtargets-uniqueTargets. |
| if (foundtargets.find(target) == foundtargets.end()) { |
| goto nextblock; |
| } |
| if (uniqueTargets.find(target) != uniqueTargets.end()) { |
| goto nextblock; |
| } |
| } |
| } |
| // If we didn't find two valid successors, continue. |
| if (seen2.size() != 2) { |
| // llvm::errs() << " -- failed from not 2 seen\n"; |
| goto nextblock; |
| } |
| subblock = block2; |
| break; |
| } |
| nextblock:; |
| } |
| |
| // If no split block was found, try again. |
| if (subblock == nullptr) |
| goto rnextpair; |
| |
| // This branch, `bi2`, splits off the two blocks in |
| // (foundTargets-uniqueTargets) from each other. |
| auto bi2 = dyn_cast<BranchInst>(subblock->getTerminator()); |
| if (!bi2) |
| goto rnextpair; |
| |
| // Condition cond1 splits off uniqueTargets[0] from |
| // the remainder of foundTargets. |
| auto cond1 = lookupM(bi1->getCondition(), BuilderM); |
| |
| // Condition cond2 splits off the two blocks in |
| // (foundTargets-uniqueTargets) from each other. |
| auto cond2 = lookupM(bi2->getCondition(), BuilderM); |
| |
| if (replacePHIs == nullptr) { |
| BasicBlock *staging = |
| BasicBlock::Create(oldFunc->getContext(), "staging", newFunc); |
| auto stagingIfNeeded = [&](BasicBlock *B) { |
| auto edge = std::make_pair(block, B); |
| if (done[edge].size() == 1) { |
| return *done[edge].begin(); |
| } else { |
| assert(done[edge].size() == 2); |
| return staging; |
| } |
| }; |
| BuilderM.CreateCondBr(cond1, stagingIfNeeded(bi1->getSuccessor(0)), |
| stagingIfNeeded(bi1->getSuccessor(1))); |
| BuilderM.SetInsertPoint(staging); |
| BuilderM.CreateCondBr( |
| cond2, |
| *done[std::make_pair(subblock, bi2->getSuccessor(0))].begin(), |
| *done[std::make_pair(subblock, bi2->getSuccessor(1))].begin()); |
| } else { |
| Value *otherBranch = nullptr; |
| for (unsigned i = 0; i < 2; ++i) { |
| Value *val = cond1; |
| if (i == 1) |
| val = BuilderM.CreateNot(val, "anot1_"); |
| auto edge = std::make_pair(block, bi1->getSuccessor(i)); |
| if (done[edge].size() == 1) { |
| auto found = replacePHIs->find(*done[edge].begin()); |
| if (found == replacePHIs->end()) |
| continue; |
| if (&*BuilderM.GetInsertPoint() == found->second) { |
| if (found->second->getNextNode()) |
| BuilderM.SetInsertPoint(found->second->getNextNode()); |
| else |
| BuilderM.SetInsertPoint(found->second->getParent()); |
| } |
| found->second->replaceAllUsesWith(val); |
| found->second->eraseFromParent(); |
| } else { |
| otherBranch = val; |
| } |
| } |
| |
| for (unsigned i = 0; i < 2; ++i) { |
| auto edge = std::make_pair(subblock, bi2->getSuccessor(i)); |
| auto found = replacePHIs->find(*done[edge].begin()); |
| if (found == replacePHIs->end()) |
| continue; |
| |
| Value *val = cond2; |
| if (i == 1) |
| val = BuilderM.CreateNot(val, "bnot1_"); |
| val = BuilderM.CreateAnd(val, otherBranch, |
| "andVal" + std::to_string(i)); |
| if (&*BuilderM.GetInsertPoint() == found->second) { |
| if (found->second->getNextNode()) |
| BuilderM.SetInsertPoint(found->second->getNextNode()); |
| else |
| BuilderM.SetInsertPoint(found->second->getParent()); |
| } |
| found->second->replaceAllUsesWith(val); |
| found->second->eraseFromParent(); |
| } |
| } |
| |
| return; |
| } |
| } |
| rnextpair:; |
| } |
| } |
| |
| BasicBlock *forwardBlock = BuilderM.GetInsertBlock(); |
| |
| if (!isOriginalBlock(*forwardBlock)) { |
| forwardBlock = originalForReverseBlock(*forwardBlock); |
| } |
| |
| for (auto block : blocks) { |
| { |
| // The original split block must not have a predecessor with an edge to a |
| // block other than the split block itself that can still reach any of the |
| // targets. |
| if (!DT.dominates(block, ctx)) |
| for (auto P : predecessors(block)) { |
| for (auto S : successors(P)) { |
| if (S == block) |
| continue; |
| auto edge = std::make_pair(P, S); |
| if (done.find(edge) != done.end() && done[edge].size()) |
| goto nextpair; |
| } |
| } |
| |
| std::set<BasicBlock *> foundtargets; |
| for (BasicBlock *succ : successors(block)) { |
| auto edge = std::make_pair(block, succ); |
| if (done[edge].size() != 1) { |
| goto nextpair; |
| } |
| BasicBlock *target = *done[edge].begin(); |
| if (foundtargets.find(target) != foundtargets.end()) { |
| goto nextpair; |
| } |
| foundtargets.insert(target); |
| } |
| if (foundtargets.size() != targetToPreds.size()) { |
| goto nextpair; |
| } |
| |
| if (forwardBlock == block || DT.dominates(block, forwardBlock)) { |
| equivalentTerminator = block->getTerminator(); |
| goto fast; |
| } |
| } |
| nextpair:; |
| } |
| goto nofast; |
| |
| fast:; |
| assert(equivalentTerminator); |
| |
| if (auto branch = dyn_cast<BranchInst>(equivalentTerminator)) { |
| BasicBlock *block = equivalentTerminator->getParent(); |
| assert(branch->getCondition()); |
| |
| assert(branch->getCondition()->getType() == T); |
| |
| if (replacePHIs == nullptr) { |
| if (!(BuilderM.GetInsertBlock()->size() == 0 || |
| !isa<BranchInst>(BuilderM.GetInsertBlock()->back()))) { |
| llvm::errs() << "newFunc : " << *newFunc << "\n"; |
| llvm::errs() << "blk : " << *BuilderM.GetInsertBlock() << "\n"; |
| } |
| assert(BuilderM.GetInsertBlock()->size() == 0 || |
| !isa<BranchInst>(BuilderM.GetInsertBlock()->back())); |
| BuilderM.CreateCondBr( |
| lookupM(branch->getCondition(), BuilderM), |
| *done[std::make_pair(block, branch->getSuccessor(0))].begin(), |
| *done[std::make_pair(block, branch->getSuccessor(1))].begin()); |
| } else { |
| for (auto pair : *replacePHIs) { |
| Value *phi = lookupM(branch->getCondition(), BuilderM); |
| Value *val = nullptr; |
| if (pair.first == |
| *done[std::make_pair(block, branch->getSuccessor(0))].begin()) { |
| val = phi; |
| } else if (pair.first == |
| *done[std::make_pair(block, branch->getSuccessor(1))] |
| .begin()) { |
| val = BuilderM.CreateNot(phi); |
| } else { |
| llvm::errs() << *pair.first->getParent() << "\n"; |
| llvm::errs() << *pair.first << "\n"; |
| llvm::errs() << *branch << "\n"; |
| llvm_unreachable("unknown successor for replacephi"); |
| } |
| if (&*BuilderM.GetInsertPoint() == pair.second) { |
| if (pair.second->getNextNode()) |
| BuilderM.SetInsertPoint(pair.second->getNextNode()); |
| else |
| BuilderM.SetInsertPoint(pair.second->getParent()); |
| } |
| pair.second->replaceAllUsesWith(val); |
| pair.second->eraseFromParent(); |
| } |
| } |
| } else if (auto si = dyn_cast<SwitchInst>(equivalentTerminator)) { |
| BasicBlock *block = equivalentTerminator->getParent(); |
| |
| IRBuilder<> pbuilder(equivalentTerminator); |
| pbuilder.setFastMathFlags(getFast()); |
| |
| if (replacePHIs == nullptr) { |
| SwitchInst *swtch = BuilderM.CreateSwitch( |
| lookupM(si->getCondition(), BuilderM), |
| *done[std::make_pair(block, si->getDefaultDest())].begin()); |
| for (auto switchcase : si->cases()) { |
| swtch->addCase( |
| switchcase.getCaseValue(), |
| *done[std::make_pair(block, switchcase.getCaseSuccessor())] |
| .begin()); |
| } |
| } else { |
| for (auto pair : *replacePHIs) { |
| Value *cas = nullptr; |
| for (auto c : si->cases()) { |
| if (pair.first == |
| *done[std::make_pair(block, c.getCaseSuccessor())].begin()) { |
| cas = c.getCaseValue(); |
| break; |
| } |
| } |
| if (cas == nullptr) { |
| assert(pair.first == |
| *done[std::make_pair(block, si->getDefaultDest())].begin()); |
| } |
| Value *val = nullptr; |
| Value *phi = lookupM(si->getCondition(), BuilderM); |
| |
| if (cas) { |
| val = BuilderM.CreateICmpEQ(cas, phi); |
| } else { |
| // default case |
| val = ConstantInt::getFalse(pair.second->getContext()); |
| for (auto switchcase : si->cases()) { |
| val = BuilderM.CreateOr( |
| val, BuilderM.CreateICmpEQ(switchcase.getCaseValue(), phi)); |
| } |
| val = BuilderM.CreateNot(val); |
| } |
| if (&*BuilderM.GetInsertPoint() == pair.second) { |
| if (pair.second->getNextNode()) |
| BuilderM.SetInsertPoint(pair.second->getNextNode()); |
| else |
| BuilderM.SetInsertPoint(pair.second->getParent()); |
| } |
| pair.second->replaceAllUsesWith(val); |
| pair.second->eraseFromParent(); |
| } |
| } |
| } else { |
| llvm::errs() << "unknown equivalent terminator\n"; |
| llvm::errs() << *equivalentTerminator << "\n"; |
| llvm_unreachable("unknown equivalent terminator"); |
| } |
| return; |
| |
| nofast:; |
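| // Fallback: in the forward pass store an index identifying the taken edge |
| // into a cache at each predecessor, then in the reverse pass load that index |
| // and branch on it (or compute the replacement PHI conditions from it). |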
| |
| // if freeing reverseblocks must exist |
| assert(reverseBlocks.size()); |
| LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, ctx); |
| AllocaInst *cache = createCacheForScope(lctx, T, "", /*shouldFree*/ true); |
| SmallVector<BasicBlock *, 4> targets; |
| { |
| size_t idx = 0; |
| std::map<BasicBlock * /*storingblock*/, |
| std::map<ConstantInt * /*target*/, |
| std::vector<BasicBlock *> /*predecessors*/>> |
| storing; |
| for (const auto &pair : targetToPreds) { |
| for (auto pred : pair.second) { |
| storing[pred.first][ConstantInt::get(T, idx)].push_back(pred.second); |
| } |
| targets.push_back(pair.first); |
| ++idx; |
| } |
| assert(targets.size() > 0); |
| |
| for (const auto &pair : storing) { |
| IRBuilder<> pbuilder(pair.first); |
| |
| if (pair.first->getTerminator()) |
| pbuilder.SetInsertPoint(pair.first->getTerminator()); |
| |
| pbuilder.setFastMathFlags(getFast()); |
| |
| Value *tostore = ConstantInt::get(T, 0); |
| |
| if (pair.second.size() == 1) { |
| tostore = pair.second.begin()->first; |
| } else { |
| assert(0 && "multi exit edges not supported"); |
| exit(1); |
| // for(auto targpair : pair.second) { |
| // tostore = pbuilder.CreateOr(tostore, pred); |
| //} |
| } |
| storeInstructionInCache(lctx, pbuilder, tostore, cache); |
| } |
| } |
| |
| bool isi1 = T->isIntegerTy() && cast<IntegerType>(T)->getBitWidth() == 1; |
| Value *which = lookupValueFromCache( |
| /*forwardPass*/ isOriginalBlock(*BuilderM.GetInsertBlock()), BuilderM, |
| LimitContext(/*reversePass*/ reverseBlocks.size() > 0, ctx), cache, isi1, |
| /*available*/ ValueToValueMapTy()); |
| assert(which); |
| assert(which->getType() == T); |
| |
| if (replacePHIs == nullptr) { |
| if (targetToPreds.size() == 2) { |
| assert(BuilderM.GetInsertBlock()->size() == 0 || |
| !isa<BranchInst>(BuilderM.GetInsertBlock()->back())); |
| BuilderM.CreateCondBr(which, /*true*/ targets[1], /*false*/ targets[0]); |
| } else { |
| assert(targets.size() > 0); |
| auto swit = |
| BuilderM.CreateSwitch(which, targets.back(), targets.size() - 1); |
| for (unsigned i = 0; i < targets.size() - 1; ++i) { |
| swit->addCase(ConstantInt::get(T, i), targets[i]); |
| } |
| } |
| } else { |
| for (unsigned i = 0; i < targets.size(); ++i) { |
| auto found = replacePHIs->find(targets[i]); |
| if (found == replacePHIs->end()) |
| continue; |
| |
| Value *val = nullptr; |
| if (targets.size() == 2 && i == 0) { |
| val = BuilderM.CreateNot(which); |
| } else if (targets.size() == 2 && i == 1) { |
| val = which; |
| } else { |
| val = BuilderM.CreateICmpEQ(ConstantInt::get(T, i), which); |
| } |
| if (&*BuilderM.GetInsertPoint() == found->second) { |
| if (found->second->getNextNode()) |
| BuilderM.SetInsertPoint(found->second->getNextNode()); |
| else |
| BuilderM.SetInsertPoint(found->second->getParent()); |
| } |
| found->second->replaceAllUsesWith(val); |
| found->second->eraseFromParent(); |
| } |
| } |
| return; |
| } |
| |
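| // Decide which intermediate values to cache for the reverse pass: gather the |
| // values that cannot legally be recomputed, the recomputable intermediates, |
| // and the values directly required by the reverse pass, then run a min-cut |
| // to pick the cheapest set to store, recording the decision in |
| // knownRecomputeHeuristic. |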
| void GradientUtils::computeMinCache() { |
| if (EnzymeMinCutCache) { |
| SmallPtrSet<Value *, 4> Recomputes; |
| |
| std::map<UsageKey, bool> FullSeen; |
| std::map<UsageKey, bool> OneLevelSeen; |
| |
| ValueToValueMapTy Available; |
| |
| std::map<Loop *, std::set<Instruction *>> LoopAvail; |
| |
| for (BasicBlock &BB : *oldFunc) { |
| if (notForAnalysis.count(&BB)) |
| continue; |
| auto L = OrigLI.getLoopFor(&BB); |
| |
| auto invariant = [&](Value *V) { |
| if (isa<Constant>(V)) |
| return true; |
| if (isa<Argument>(V)) |
| return true; |
| if (auto I = dyn_cast<Instruction>(V)) { |
| if (!L->contains(OrigLI.getLoopFor(I->getParent()))) |
| return true; |
| } |
| return false; |
| }; |
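| // Integer loop-header phis whose updates are add/sub by loop-invariant |
| // amounts behave like induction variables and are recorded as available |
| // within their loop; results of allocation calls are likewise marked |
| // available. |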
| for (Instruction &I : BB) { |
| if (auto PN = dyn_cast<PHINode>(&I)) { |
| if (!OrigLI.isLoopHeader(&BB)) |
| continue; |
| if (PN->getType()->isIntegerTy()) { |
| bool legal = true; |
| SmallPtrSet<Instruction *, 4> Increment; |
| for (auto B : PN->blocks()) { |
| if (OrigLI.getLoopFor(B) == L) { |
| if (auto BO = dyn_cast<BinaryOperator>( |
| PN->getIncomingValueForBlock(B))) { |
| if (BO->getOpcode() == BinaryOperator::Add) { |
| if ((BO->getOperand(0) == PN && |
| invariant(BO->getOperand(1))) || |
| (BO->getOperand(1) == PN && |
| invariant(BO->getOperand(0)))) { |
| Increment.insert(BO); |
| } else { |
| legal = false; |
| } |
| } else if (BO->getOpcode() == BinaryOperator::Sub) { |
| if (BO->getOperand(0) == PN && |
| invariant(BO->getOperand(1))) { |
| Increment.insert(BO); |
| } else { |
| legal = false; |
| } |
| } else { |
| legal = false; |
| } |
| } else { |
| legal = false; |
| } |
| } |
| } |
| if (legal) { |
| LoopAvail[L].insert(PN); |
| for (auto I : Increment) |
| LoopAvail[L].insert(I); |
| } |
| } |
| } else if (auto CI = dyn_cast<CallInst>(&I)) { |
| StringRef funcName = getFuncNameFromCall(CI); |
| if (isAllocationFunction(funcName, TLI)) |
| Available[CI] = CI; |
| } |
| } |
| } |
| |
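| // Instructions feeding a loop's trip count (maxLimit / trueLimit) must be |
| // kept as primal values, so mark their originals as needed. |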
| SmallPtrSet<Instruction *, 3> NewLoopBoundReq; |
| { |
| std::deque<Instruction *> LoopBoundRequirements; |
| |
| for (auto &context : loopContexts) { |
| for (auto val : {context.second.maxLimit, context.second.trueLimit}) { |
| if (val) |
| if (auto inst = dyn_cast<Instruction>(&*val)) { |
| LoopBoundRequirements.push_back(inst); |
| } |
| } |
| } |
| SmallPtrSet<Instruction *, 3> Seen; |
| while (LoopBoundRequirements.size()) { |
| Instruction *val = LoopBoundRequirements.front(); |
| LoopBoundRequirements.pop_front(); |
| if (NewLoopBoundReq.count(val)) |
| continue; |
| if (Seen.count(val)) |
| continue; |
| Seen.insert(val); |
| if (auto orig = isOriginal(val)) { |
| NewLoopBoundReq.insert(orig); |
| } else { |
| for (auto &op : val->operands()) { |
| if (auto inst = dyn_cast<Instruction>(op)) { |
| LoopBoundRequirements.push_back(inst); |
| } |
| } |
| } |
| } |
| for (auto inst : NewLoopBoundReq) { |
| OneLevelSeen[UsageKey(inst, ValueType::Primal)] = true; |
| FullSeen[UsageKey(inst, ValueType::Primal)] = true; |
| } |
| } |
| |
| auto minCutMode = (mode == DerivativeMode::ReverseModePrimal) |
| ? DerivativeMode::ReverseModeGradient |
| : mode; |
| |
| for (BasicBlock &BB : *oldFunc) { |
| if (notForAnalysis.count(&BB)) |
| continue; |
| ValueToValueMapTy Available2; |
| for (auto a : Available) |
| Available2[a.first] = a.second; |
| for (Loop *L = OrigLI.getLoopFor(&BB); L != nullptr; |
| L = L->getParentLoop()) { |
| for (auto v : LoopAvail[L]) { |
| Available2[v] = v; |
| } |
| } |
| for (Instruction &I : BB) { |
| if (!legalRecompute(&I, Available2, nullptr)) { |
| if (is_value_needed_in_reverse<ValueType::Primal>( |
| this, &I, minCutMode, FullSeen, notForAnalysis)) { |
| bool oneneed = is_value_needed_in_reverse<ValueType::Primal, |
| /*OneLevel*/ true>( |
| this, &I, minCutMode, OneLevelSeen, notForAnalysis); |
| if (oneneed) { |
| knownRecomputeHeuristic[&I] = false; |
| } else |
| Recomputes.insert(&I); |
| } |
| } |
| } |
| } |
| |
| SmallPtrSet<Value *, 4> Intermediates; |
| SmallPtrSet<Value *, 4> Required; |
| |
| std::deque<Value *> todo(Recomputes.begin(), Recomputes.end()); |
| |
| while (todo.size()) { |
| Value *V = todo.front(); |
| todo.pop_front(); |
| if (Intermediates.count(V)) |
| continue; |
| if (!is_value_needed_in_reverse<ValueType::Primal>( |
| this, V, minCutMode, FullSeen, notForAnalysis)) { |
| continue; |
| } |
| if (!Recomputes.count(V)) { |
| ValueToValueMapTy Available2; |
| for (auto a : Available) |
| Available2[a.first] = a.second; |
| for (Loop *L = OrigLI.getLoopFor(cast<Instruction>(V)->getParent()); |
| L != nullptr; L = L->getParentLoop()) { |
| for (auto v : LoopAvail[L]) { |
| Available2[v] = v; |
| } |
| } |
| if (!legalRecompute(V, Available2, nullptr)) { |
| // if not legal to recompute, we would've already explicitly marked |
| // this for caching if it was needed in reverse pass |
| continue; |
| } |
| } |
| Intermediates.insert(V); |
| if (is_value_needed_in_reverse<ValueType::Primal, /*OneLevel*/ true>( |
| this, V, minCutMode, OneLevelSeen, notForAnalysis)) { |
| Required.insert(V); |
| } else { |
| for (auto V2 : V->users()) { |
| if (auto Inst = dyn_cast<Instruction>(V2)) |
| for (auto pair : rematerializableAllocations) { |
| if (pair.second.stores.count(Inst)) { |
| todo.push_back(pair.first); |
| } |
| } |
| todo.push_back(V2); |
| } |
| } |
| } |
| |
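    // Min-cut step (informal sketch): Recomputes act as sources (values that
    // cannot themselves be recomputed), Required as sinks (values directly
    // used by the reverse pass), and Intermediates as the candidate nodes in
    // between. MinReq is the resulting cut, i.e. the values chosen to be
    // cached; other candidates are recomputed, and candidates unreachable in
    // the need graph below are marked as unnecessary intermediates.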
| SmallPtrSet<Value *, 5> MinReq; |
| minCut(oldFunc->getParent()->getDataLayout(), OrigLI, Recomputes, |
| Intermediates, Required, MinReq, rematerializableAllocations); |
| SmallPtrSet<Value *, 5> NeedGraph; |
| for (Value *V : MinReq) |
| NeedGraph.insert(V); |
| for (Value *V : Required) |
| todo.push_back(V); |
| while (todo.size()) { |
| Value *V = todo.front(); |
| todo.pop_front(); |
| if (NeedGraph.count(V)) |
| continue; |
| NeedGraph.insert(V); |
| if (auto I = dyn_cast<Instruction>(V)) |
| for (auto &V2 : I->operands()) { |
| if (Intermediates.count(V2)) |
| todo.push_back(V2); |
| } |
| } |
| |
| for (auto V : Intermediates) { |
| knownRecomputeHeuristic[V] = !MinReq.count(V); |
| if (!NeedGraph.count(V)) { |
| unnecessaryIntermediates.insert(cast<Instruction>(V)); |
| } |
| } |
| } |
| } |
| |
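// Callback fired if a value tracked in the invertedPointers map is erased out
// from under us; dump both functions and the pointer map to aid debugging,
// then abort.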
| void InvertedPointerVH::deleted() { |
| llvm::errs() << *gutils->oldFunc << "\n"; |
| llvm::errs() << *gutils->newFunc << "\n"; |
| gutils->dumpPointers(); |
| llvm::errs() << **this << "\n"; |
| assert(0 && "erasing something in invertedPointers map"); |
| } |
| |
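// Emit the shadow counterpart of a memcpy/memmove call MTI. If secretty is
// non-null the copied data is floating point, so the adjoint is propagated in
// the reverse (or split forward) pass; otherwise the data is pointer/integer
// valued and the copy only needs to be mirrored onto the shadow buffers in
// the forward pass. As an informal sketch of the float case, the reverse of
//   memcpy(dst, src, n)
// accumulates d_src += d_dst and zeros d_dst elementwise, which is what the
// getOrInsertDifferentialFloatMemcpy/Memmove helpers used below emit.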
| void SubTransferHelper(GradientUtils *gutils, DerivativeMode mode, |
| Type *secretty, Intrinsic::ID intrinsic, |
| unsigned dstalign, unsigned srcalign, unsigned offset, |
| bool dstConstant, Value *shadow_dst, bool srcConstant, |
| Value *shadow_src, Value *length, Value *isVolatile, |
| llvm::CallInst *MTI, bool allowForward, |
| bool shadowsLookedUp, bool backwardsShadow) { |
| // TODO offset |
| if (secretty) { |
    // No change to the forward pass is needed if the data represents floats.
| if (mode == DerivativeMode::ReverseModeGradient || |
| mode == DerivativeMode::ReverseModeCombined || |
| mode == DerivativeMode::ForwardModeSplit) { |
| IRBuilder<> Builder2(MTI); |
| if (mode == DerivativeMode::ForwardModeSplit) |
| gutils->getForwardBuilder(Builder2); |
| else |
| gutils->getReverseBuilder(Builder2); |
| |
      // If the src is constant, simply zero d_dst and don't propagate to
      // d_src (the shadow of a constant src is src itself, so writing to it
      // may be illegal).
| if (srcConstant) { |
| // Don't zero in forward mode. |
| if (mode != DerivativeMode::ForwardModeSplit) { |
| |
| Value *args[] = { |
| shadowsLookedUp ? shadow_dst |
| : gutils->lookupM(shadow_dst, Builder2), |
| ConstantInt::get(Type::getInt8Ty(MTI->getContext()), 0), |
| gutils->lookupM(length, Builder2), |
| #if LLVM_VERSION_MAJOR <= 6 |
| ConstantInt::get(Type::getInt32Ty(MTI->getContext()), |
| max(1U, dstalign)), |
| #endif |
| ConstantInt::getFalse(MTI->getContext()) |
| }; |
| |
| if (args[0]->getType()->isIntegerTy()) |
| args[0] = Builder2.CreateIntToPtr( |
| args[0], Type::getInt8PtrTy(MTI->getContext())); |
| |
| Type *tys[] = {args[0]->getType(), args[2]->getType()}; |
| auto memsetIntr = Intrinsic::getDeclaration( |
| MTI->getParent()->getParent()->getParent(), Intrinsic::memset, |
| tys); |
| auto cal = Builder2.CreateCall(memsetIntr, args); |
| cal->setCallingConv(memsetIntr->getCallingConv()); |
| if (dstalign != 0) { |
| #if LLVM_VERSION_MAJOR >= 10 |
| cal->addParamAttr(0, Attribute::getWithAlignment(MTI->getContext(), |
| Align(dstalign))); |
| #else |
| cal->addParamAttr( |
| 0, Attribute::getWithAlignment(MTI->getContext(), dstalign)); |
| #endif |
| } |
| } |
| |
| } else { |
| auto dsto = |
| (shadowsLookedUp || mode == DerivativeMode::ForwardModeSplit) |
| ? shadow_dst |
| : gutils->lookupM(shadow_dst, Builder2); |
| if (dsto->getType()->isIntegerTy()) |
| dsto = Builder2.CreateIntToPtr( |
| dsto, Type::getInt8PtrTy(dsto->getContext())); |
| unsigned dstaddr = |
| cast<PointerType>(dsto->getType())->getAddressSpace(); |
| if (offset != 0) { |
| #if LLVM_VERSION_MAJOR > 7 |
| dsto = Builder2.CreateConstInBoundsGEP1_64( |
| dsto->getType()->getPointerElementType(), dsto, offset); |
| #else |
| dsto = Builder2.CreateConstInBoundsGEP1_64(dsto, offset); |
| #endif |
| } |
| auto srco = |
| (shadowsLookedUp || mode == DerivativeMode::ForwardModeSplit) |
| ? shadow_src |
| : gutils->lookupM(shadow_src, Builder2); |
| if (mode != DerivativeMode::ForwardModeSplit) |
| dsto = Builder2.CreatePointerCast( |
| dsto, PointerType::get(secretty, dstaddr)); |
| if (srco->getType()->isIntegerTy()) |
| srco = Builder2.CreateIntToPtr( |
| srco, Type::getInt8PtrTy(srco->getContext())); |
| unsigned srcaddr = |
| cast<PointerType>(srco->getType())->getAddressSpace(); |
| if (offset != 0) { |
| #if LLVM_VERSION_MAJOR > 7 |
| srco = Builder2.CreateConstInBoundsGEP1_64( |
| srco->getType()->getPointerElementType(), srco, offset); |
| #else |
| srco = Builder2.CreateConstInBoundsGEP1_64(srco, offset); |
| #endif |
| } |
| if (mode != DerivativeMode::ForwardModeSplit) |
| srco = Builder2.CreatePointerCast( |
| srco, PointerType::get(secretty, srcaddr)); |
| |
| if (mode == DerivativeMode::ForwardModeSplit) { |
| #if LLVM_VERSION_MAJOR >= 11 |
| MaybeAlign dalign; |
| if (dstalign) |
| dalign = MaybeAlign(dstalign); |
| MaybeAlign salign; |
| if (srcalign) |
| salign = MaybeAlign(srcalign); |
| #else |
| auto dalign = dstalign; |
| auto salign = srcalign; |
| #endif |
| |
| if (intrinsic == Intrinsic::memmove) { |
| Builder2.CreateMemMove(dsto, dalign, srco, salign, length); |
| } else { |
| Builder2.CreateMemCpy(dsto, dalign, srco, salign, length); |
| } |
| } else { |
| Value *args[]{ |
| Builder2.CreatePointerCast(dsto, |
| PointerType::get(secretty, dstaddr)), |
| Builder2.CreatePointerCast(srco, |
| PointerType::get(secretty, srcaddr)), |
| Builder2.CreateUDiv( |
| gutils->lookupM(length, Builder2), |
| ConstantInt::get(length->getType(), |
| Builder2.GetInsertBlock() |
| ->getParent() |
| ->getParent() |
| ->getDataLayout() |
| .getTypeAllocSizeInBits(secretty) / |
| 8))}; |
| |
| auto dmemcpy = ((intrinsic == Intrinsic::memcpy) |
| ? getOrInsertDifferentialFloatMemcpy |
| : getOrInsertDifferentialFloatMemmove)( |
| *MTI->getParent()->getParent()->getParent(), secretty, dstalign, |
| srcalign, dstaddr, srcaddr); |
| Builder2.CreateCall(dmemcpy, args); |
| } |
| } |
| } |
| } else { |
| |
  // If the data copied represents a pointer or integer type, we only need to
  // mirror the copy onto the shadow in the forward pass.
| if ((allowForward && (mode == DerivativeMode::ReverseModePrimal || |
| mode == DerivativeMode::ReverseModeCombined)) || |
| (backwardsShadow && (mode == DerivativeMode::ReverseModeGradient || |
| mode == DerivativeMode::ForwardModeSplit))) { |
| assert(!shadowsLookedUp); |
| |
| // It is questionable how the following case would even occur, but if |
| // the dst is constant, we shouldn't do anything extra |
| if (dstConstant) { |
| return; |
| } |
| |
| IRBuilder<> BuilderZ(gutils->getNewFromOriginal(MTI)); |
| |
      // If src is inactive, we still copy from the regular (primal) pointer.
      // For example, suppose we are copying constant memory representing
      // dimensions into a tensor: to ensure that the differential tensor is
      // well formed for use OUTSIDE the derivative generation (Enzyme itself
      // does not need this), we should also perform the copy onto the
      // differential. Future optimization (not implemented): if dst can never
      // escape Enzyme code, we may omit this copy.
      // There is no need to update pointers, even if dst is active.
| auto dsto = shadow_dst; |
| if (dsto->getType()->isIntegerTy()) |
| dsto = BuilderZ.CreateIntToPtr(dsto, |
| Type::getInt8PtrTy(MTI->getContext())); |
| if (offset != 0) { |
| #if LLVM_VERSION_MAJOR > 7 |
| dsto = BuilderZ.CreateConstInBoundsGEP1_64( |
| dsto->getType()->getPointerElementType(), dsto, offset); |
| #else |
| dsto = BuilderZ.CreateConstInBoundsGEP1_64(dsto, offset); |
| #endif |
| } |
| auto srco = shadow_src; |
| if (srco->getType()->isIntegerTy()) |
| srco = BuilderZ.CreateIntToPtr(srco, |
| Type::getInt8PtrTy(MTI->getContext())); |
| if (offset != 0) { |
| #if LLVM_VERSION_MAJOR > 7 |
| srco = BuilderZ.CreateConstInBoundsGEP1_64( |
| srco->getType()->getPointerElementType(), srco, offset); |
| #else |
| srco = BuilderZ.CreateConstInBoundsGEP1_64(srco, offset); |
| #endif |
| } |
| Value *args[] = { |
| dsto, |
| srco, |
| length, |
| #if LLVM_VERSION_MAJOR <= 6 |
| ConstantInt::get(Type::getInt32Ty(MTI->getContext()), |
| max(1U, min(srcalign, dstalign))), |
| #endif |
| isVolatile |
| }; |
| |
      Type *tys[] = {args[0]->getType(), args[1]->getType(),
                     args[2]->getType()};
| |
| auto memtransIntr = Intrinsic::getDeclaration( |
| gutils->newFunc->getParent(), intrinsic, tys); |
| auto cal = BuilderZ.CreateCall(memtransIntr, args); |
| cal->setAttributes(MTI->getAttributes()); |
| cal->setCallingConv(memtransIntr->getCallingConv()); |
| cal->setTailCallKind(MTI->getTailCallKind()); |
| |
| if (dstalign != 0) { |
| #if LLVM_VERSION_MAJOR >= 10 |
| cal->addParamAttr( |
| 0, Attribute::getWithAlignment(MTI->getContext(), Align(dstalign))); |
| #else |
| cal->addParamAttr( |
| 0, Attribute::getWithAlignment(MTI->getContext(), dstalign)); |
| #endif |
| } |
| if (srcalign != 0) { |
| #if LLVM_VERSION_MAJOR >= 10 |
| cal->addParamAttr( |
| 1, Attribute::getWithAlignment(MTI->getContext(), Align(srcalign))); |
| #else |
| cal->addParamAttr( |
| 1, Attribute::getWithAlignment(MTI->getContext(), srcalign)); |
| #endif |
| } |
| } |
| } |
| } |
| |
| void GradientUtils::computeForwardingProperties(Instruction *V) { |
| if (!EnzymeRematerialize) |
| return; |
| |
  // The piece of memory V allocated within this scope will be initialized in
  // some way by the (augmented) forward pass. Loads and other load-like
  // operations will either require the allocation V itself to be preserved
  // for the reverse pass, or alternatively the tape for those operations.
  //
  // Instead, we ask here whether we can restore the memory state of V in the
  // reverse pass by recreating all of the stores and store-like operations
  // into V prior to their load-like uses.
| // |
| // Notably, we only need to preserve the ability to reload any values actually |
| // used in the reverse pass. |
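  //
  // Illustrative example (a hand-written sketch, not taken from any test):
  //   %buf = call i8* @malloc(i64 8)
  //   %p   = bitcast i8* %buf to double*
  //   store double %x, double* %p        ; store-like operation
  //   %v   = load double, double* %p     ; value needed by the reverse pass
  //   call void @free(i8* %buf)
  // Rather than caching %buf (or the loaded value), the reverse pass may
  // re-allocate the buffer and replay the store before %v's reverse-pass
  // user, provided no later store overwrites what the load observed.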
| |
| std::map<UsageKey, bool> Seen; |
| bool primalNeededInReverse = is_value_needed_in_reverse<ValueType::Primal>( |
| this, V, DerivativeMode::ReverseModeGradient, Seen, notForAnalysis); |
| |
| SmallVector<LoadInst *, 1> loads; |
| SmallVector<LoadLikeCall, 1> loadLikeCalls; |
| SmallPtrSet<Instruction *, 1> stores; |
| SmallPtrSet<Instruction *, 1> storingOps; |
| SmallPtrSet<Instruction *, 1> frees; |
| SmallPtrSet<IntrinsicInst *, 1> LifetimeStarts; |
| bool promotable = true; |
| bool shadowpromotable = true; |
| |
| SmallVector<Instruction *, 1> shadowPointerLoads; |
| |
| std::set<std::pair<Instruction *, Value *>> seen; |
| SmallVector<std::pair<Instruction *, Value *>, 1> todo; |
| for (auto U : V->users()) |
| if (auto I = dyn_cast<Instruction>(U)) |
| todo.push_back(std::make_pair(I, V)); |
| while (todo.size()) { |
| auto tup = todo.back(); |
| Instruction *cur = tup.first; |
| Value *prev = tup.second; |
| todo.pop_back(); |
| if (seen.count(tup)) |
| continue; |
| seen.insert(tup); |
| if (isa<CastInst>(cur) || isa<GetElementPtrInst>(cur)) { |
| for (auto u : cur->users()) { |
| if (auto I = dyn_cast<Instruction>(u)) |
| todo.push_back(std::make_pair(I, (Value *)cur)); |
| } |
| } else if (auto load = dyn_cast<LoadInst>(cur)) { |
| |
      // If the loaded value is an int or pointer, we may need to preserve
      // its initialization within the primal.
| auto TT = TR.query(load)[{-1}]; |
| if (!TT.isFloat()) { |
| shadowPointerLoads.push_back(cur); |
| } |
| loads.push_back(load); |
| } else if (auto store = dyn_cast<StoreInst>(cur)) { |
| // TODO only add store to shadow iff non float type |
| if (store->getValueOperand() == prev) { |
| EmitWarning("NotPromotable", *cur, " Could not promote allocation ", *V, |
| " due to capturing store ", *cur); |
| promotable = false; |
| shadowpromotable = false; |
| break; |
| } else { |
| stores.insert(store); |
| storingOps.insert(store); |
| } |
| } else if (auto II = dyn_cast<IntrinsicInst>(cur)) { |
| switch (II->getIntrinsicID()) { |
| case Intrinsic::lifetime_start: |
| LifetimeStarts.insert(II); |
| break; |
| case Intrinsic::dbg_declare: |
| case Intrinsic::dbg_value: |
| #if LLVM_VERSION_MAJOR > 6 |
| case Intrinsic::dbg_label: |
| #endif |
| case Intrinsic::dbg_addr: |
| case Intrinsic::lifetime_end: |
| break; |
| case Intrinsic::memset: { |
| stores.insert(II); |
| storingOps.insert(II); |
| break; |
| } |
| // TODO memtransfer(cpy/move) |
| case Intrinsic::memcpy: |
| case Intrinsic::memmove: |
| default: |
| promotable = false; |
| shadowpromotable = false; |
| EmitWarning("NotPromotable", *cur, " Could not promote allocation ", *V, |
| " due to unknown intrinsic ", *cur); |
| break; |
| } |
| } else if (auto CI = dyn_cast<CallInst>(cur)) { |
| StringRef funcName = getFuncNameFromCall(CI); |
| if (isDeallocationFunction(funcName, TLI)) { |
| frees.insert(CI); |
| continue; |
| } |
| if (funcName == "julia.write_barrier") { |
| stores.insert(CI); |
| continue; |
| } |
| |
| size_t idx = 0; |
| bool seenLoadLikeCall = false; |
| #if LLVM_VERSION_MAJOR >= 14 |
| for (auto &arg : CI->args()) |
| #else |
| for (auto &arg : CI->arg_operands()) |
| #endif |
| { |
| if (arg != prev) { |
| idx++; |
| continue; |
| } |
| auto F = getFunctionFromCall(CI); |
| auto TT = TR.query(prev)[{-1, -1}]; |
| |
| #if LLVM_VERSION_MAJOR >= 8 |
| bool NoCapture = CI->doesNotCapture(idx); |
| #else |
| bool NoCapture = |
| CI->dataOperandHasImpliedAttr(idx + 1, Attribute::NoCapture) || |
| (F && F->hasParamAttribute(idx, Attribute::NoCapture)); |
| #endif |
| |
| #if LLVM_VERSION_MAJOR >= 8 |
| bool ReadOnly = CI->onlyReadsMemory(idx); |
| #else |
| bool ReadOnly = |
| CI->dataOperandHasImpliedAttr(idx + 1, Attribute::ReadOnly) || |
| CI->dataOperandHasImpliedAttr(idx + 1, Attribute::ReadNone) || |
| (F && (F->hasParamAttribute(idx, Attribute::ReadOnly) || |
| F->hasParamAttribute(idx, Attribute::ReadNone))); |
| #endif |
| |
| #if LLVM_VERSION_MAJOR >= 14 |
| bool WriteOnly = CI->onlyWritesMemory(idx); |
| #else |
| bool WriteOnly = |
| CI->dataOperandHasImpliedAttr(idx + 1, Attribute::WriteOnly) || |
| CI->dataOperandHasImpliedAttr(idx + 1, Attribute::ReadNone) || |
| (F && (F->hasParamAttribute(idx, Attribute::WriteOnly) || |
| F->hasParamAttribute(idx, Attribute::ReadNone))); |
| #endif |
| |
| // If the pointer is captured, conservatively assume it is used in |
| // nontrivial ways that make both the primal and shadow not promotable. |
| if (!NoCapture) { |
| shadowpromotable = false; |
| promotable = false; |
| EmitWarning("NotPromotable", *cur, " Could not promote allocation ", |
| *V, " due to unknown capturing call ", *cur); |
| idx++; |
| continue; |
| } |
| |
| // From here on out we can assume the pointer is not captured, and only |
| // written to or read from. |
| |
| // If we may read from the memory, consider this a load-like call |
| // that must have all writes done in preparation for any reverse-pass |
| // users. |
| if (!WriteOnly) { |
| if (!seenLoadLikeCall) { |
| loadLikeCalls.push_back(LoadLikeCall(CI, prev)); |
| seenLoadLikeCall = true; |
| } |
| } |
| |
| // If we may write to memory, we cannot promote if any values |
| // need the allocation or any descendants for the reverse pass. |
| if (!ReadOnly) { |
| if (primalNeededInReverse) { |
| promotable = false; |
| EmitWarning("NotPromotable", *cur, " Could not promote allocation ", |
| *V, " due to unknown writing call ", *cur); |
| } |
| storingOps.insert(cur); |
| } |
| |
| // Consider shadow memory now. |
| // |
| // If the memory is all floats, there's no issue, since besides zero |
| // initialization nothing should occur for them in the forward pass |
| if (TT.isFloat()) { |
| } else if (WriteOnly) { |
          // Nothing is needed for write-only (store-like) uses of
          // int/pointer data: these are handled by the forward pass and, as
          // in the isFloat case above, they do not prevent shadow promotion.
| } else { |
| shadowPointerLoads.push_back(cur); |
| } |
| |
| idx++; |
| } |
| |
| } else { |
| promotable = false; |
| shadowpromotable = false; |
| EmitWarning("NotPromotable", *cur, " Could not promote allocation ", *V, |
| " due to unknown instruction ", *cur); |
| } |
| } |
| |
| // Find the outermost loop of all stores, and the allocation/lifetime |
| Loop *outer = OrigLI.getLoopFor(V->getParent()); |
| if (LifetimeStarts.size() == 1) { |
| outer = OrigLI.getLoopFor((*LifetimeStarts.begin())->getParent()); |
| } |
| |
| for (auto S : stores) { |
| outer = getAncestor(outer, OrigLI.getLoopFor(S->getParent())); |
| } |
| |
  // Pointers loaded from the allocation may later be stored into other
  // pointers, so the shadow must be pre-initialized from the primal rather
  // than left uninitialized.
| bool primalInitializationOfShadow = shadowPointerLoads.size() > 0; |
| |
| if (shadowpromotable && !isConstantValue(V)) { |
| for (auto LI : shadowPointerLoads) { |
      // Is there a store which could occur after the load?
      // In other words, collect the stores that may execute after this load
      // within the enclosing loop context.
| SmallVector<Instruction *, 2> results; |
| mayExecuteAfter(results, LI, storingOps, outer); |
| for (auto res : results) { |
| if (overwritesToMemoryReadBy(OrigAA, TLI, SE, OrigLI, OrigDT, LI, res, |
| outer)) { |
          EmitWarning("NotPromotable", *LI,
                      " Could not promote shadow allocation ", *V,
                      " due to pointer load ", *LI,
                      " which does not postdominate store ", *res);
| shadowpromotable = false; |
| goto exitL; |
| } |
| } |
| } |
| exitL:; |
| if (shadowpromotable) { |
| backwardsOnlyShadows[V] = ShadowRematerializer( |
| stores, frees, primalInitializationOfShadow, outer); |
| } |
| } |
| |
| if (!promotable) |
| return; |
| |
| SmallPtrSet<LoadInst *, 1> rematerializable; |
| |
  // We currently require that every load of a rematerializable allocation
  // can be performed again in the reverse pass. Thus, if an overwriting
  // store may occur after a load in its context, the allocation is no longer
  // considered rematerializable.
| for (auto LI : loads) { |
    // Is there a store which could occur after the load?
    // In other words, collect the stores that may execute after this load
    // within the enclosing loop context.
| SmallVector<Instruction *, 2> results; |
| mayExecuteAfter(results, LI, storingOps, outer); |
| for (auto res : results) { |
| if (overwritesToMemoryReadBy(OrigAA, TLI, SE, OrigLI, OrigDT, LI, res, |
| outer)) { |
        EmitWarning("NotPromotable", *LI, " Could not promote allocation ", *V,
                    " due to load ", *LI,
                    " which does not postdominate store ", *res);
| return; |
| } |
| } |
| rematerializable.insert(LI); |
| } |
| for (auto LI : loadLikeCalls) { |
    // Is there a store which could occur after the load-like call?
    // In other words, collect the stores that may execute after it within
    // the enclosing loop context.
| SmallVector<Instruction *, 2> results; |
| mayExecuteAfter(results, LI.loadCall, storingOps, outer); |
| for (auto res : results) { |
| if (overwritesToMemoryReadBy(OrigAA, TLI, SE, OrigLI, OrigDT, LI.loadCall, |
| res, outer)) { |
        EmitWarning("NotPromotable", *LI.loadCall,
                    " Could not promote allocation ", *V,
                    " due to load-like call ", *LI.loadCall,
                    " which does not postdominate store ", *res);
| return; |
| } |
| } |
| } |
| rematerializableAllocations[V] = |
| Rematerializer(loads, loadLikeCalls, stores, frees, outer); |
| } |
| |
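// Scan the original function for allocations that are guaranteed to be freed
// (a free of the allocation post-dominates it), recording them in
// allocationsWithGuaranteedFree, and compute forwarding/rematerialization
// properties for every alloca and allocation call encountered.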
| void GradientUtils::computeGuaranteedFrees() { |
| SmallPtrSet<CallInst *, 2> allocsToPromote; |
| for (auto &BB : *oldFunc) { |
| if (notForAnalysis.count(&BB)) |
| continue; |
| for (auto &I : BB) { |
| if (auto AI = dyn_cast<AllocaInst>(&I)) |
| computeForwardingProperties(AI); |
| |
| auto CI = dyn_cast<CallInst>(&I); |
| if (!CI) |
| continue; |
| |
| StringRef funcName = getFuncNameFromCall(CI); |
| |
| if (isDeallocationFunction(funcName, TLI)) { |
| |
| llvm::Value *val = CI->getArgOperand(0); |
| while (auto cast = dyn_cast<CastInst>(val)) |
| val = cast->getOperand(0); |
| |
| if (auto dc = dyn_cast<CallInst>(val)) { |
| StringRef sfuncName = getFuncNameFromCall(dc); |
| if (isAllocationFunction(sfuncName, TLI)) { |
| |
| bool hasPDFree = false; |
| if (dc->getParent() == CI->getParent() || |
| OrigPDT.dominates(CI->getParent(), dc->getParent())) { |
| hasPDFree = true; |
| } |
| |
| if (hasPDFree) { |
| allocationsWithGuaranteedFree[dc].insert(CI); |
| } |
| } |
| } |
| } |
| if (isAllocationFunction(funcName, TLI)) { |
| allocsToPromote.insert(CI); |
| if (hasMetadata(CI, "enzyme_fromstack")) { |
| allocationsWithGuaranteedFree[CI].insert(CI); |
| } |
| if (funcName == "jl_alloc_array_1d" || |
| funcName == "jl_alloc_array_2d" || |
| funcName == "jl_alloc_array_3d" || funcName == "jl_array_copy" || |
| funcName == "ijl_alloc_array_1d" || |
| funcName == "ijl_alloc_array_2d" || |
| funcName == "ijl_alloc_array_3d" || funcName == "ijl_array_copy" || |
| funcName == "julia.gc_alloc_obj" || |
| funcName == "jl_gc_alloc_typed" || |
| funcName == "ijl_gc_alloc_typed") { |
| } |
| } |
| } |
| } |
| for (CallInst *V : allocsToPromote) { |
    // TODO: determine whether an allocation is only loaded from and stored
    // to (never captured) by traversing its users. If so, mark all of its
    // loads and stores, since the loads could then be rematerialized without
    // caching the allocation itself, caching only the operands of the
    // stores. This information needs to be provided to minCutCache, the
    // derivative of a store needs to redo the store, isValueNeededInReverse
    // needs to know to preserve the store operands in this case, etc.
| computeForwardingProperties(V); |
| } |
| } |
| |
| /// Perform the corresponding deallocation of tofree, given it was allocated by |
| /// allocationfn |
// To update the cases below, consult MemoryBuiltins.cpp and
// TargetLibraryInfo.cpp.
| llvm::CallInst *freeKnownAllocation(llvm::IRBuilder<> &builder, |
| llvm::Value *tofree, |
| const llvm::StringRef allocationfn, |
| const llvm::DebugLoc &debuglocation, |
| const llvm::TargetLibraryInfo &TLI, |
| llvm::CallInst *orig, |
| GradientUtils *gutils) { |
| using namespace llvm; |
| assert(isAllocationFunction(allocationfn, TLI)); |
| |
| if (allocationfn == "__rust_alloc" || allocationfn == "__rust_alloc_zeroed") { |
| llvm_unreachable("todo - hook in rust allocation fns"); |
| } |
| if (allocationfn == "julia.gc_alloc_obj" || |
| allocationfn == "jl_gc_alloc_typed" || |
| allocationfn == "ijl_gc_alloc_typed") |
| return nullptr; |
| |
| if (allocationfn == "enzyme_allocator") { |
| auto inds = getDeallocationIndicesFromCall(orig); |
| SmallVector<Value *, 2> vals; |
| for (auto ind : inds) { |
| if (ind == -1) |
| vals.push_back(tofree); |
| else |
| vals.push_back(gutils->lookupM( |
| gutils->getNewFromOriginal(orig->getArgOperand(ind)), builder)); |
| } |
| auto tocall = getDeallocatorFnFromCall(orig); |
| auto freecall = builder.CreateCall(tocall, vals); |
| freecall->setDebugLoc(debuglocation); |
| return freecall; |
| } |
| |
| if (allocationfn == "swift_allocObject") { |
| Type *VoidTy = Type::getVoidTy(tofree->getContext()); |
| Type *IntPtrTy = Type::getInt8PtrTy(tofree->getContext()); |
| |
| auto FT = FunctionType::get(VoidTy, ArrayRef<Type *>(IntPtrTy), false); |
| #if LLVM_VERSION_MAJOR >= 9 |
| Value *freevalue = builder.GetInsertBlock() |
| ->getParent() |
| ->getParent() |
| ->getOrInsertFunction("swift_release", FT) |
| .getCallee(); |
| #else |
| Value *freevalue = |
| builder.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction( |
| "swift_release", FT); |
| #endif |
| CallInst *freecall = cast<CallInst>( |
| #if LLVM_VERSION_MAJOR >= 8 |
| CallInst::Create( |
| FT, freevalue, |
| ArrayRef<Value *>(builder.CreatePointerCast(tofree, IntPtrTy)), |
| #else |
| CallInst::Create( |
| freevalue, |
| ArrayRef<Value *>(builder.CreatePointerCast(tofree, IntPtrTy)), |
| #endif |
| "", builder.GetInsertBlock())); |
| freecall->setDebugLoc(debuglocation); |
| freecall->setTailCall(); |
| if (isa<CallInst>(tofree) && |
| #if LLVM_VERSION_MAJOR >= 14 |
| cast<CallInst>(tofree)->getAttributes().hasAttributeAtIndex( |
| AttributeList::ReturnIndex, Attribute::NonNull) |
| #else |
| cast<CallInst>(tofree)->getAttributes().hasAttribute( |
| AttributeList::ReturnIndex, Attribute::NonNull) |
| #endif |
| ) { |
| #if LLVM_VERSION_MAJOR >= 14 |
| freecall->addAttributeAtIndex(AttributeList::FirstArgIndex, |
| Attribute::NonNull); |
| #else |
| freecall->addAttribute(AttributeList::FirstArgIndex, Attribute::NonNull); |
| #endif |
| } |
| if (Function *F = dyn_cast<Function>(freevalue)) |
| freecall->setCallingConv(F->getCallingConv()); |
| if (freecall->getParent() == nullptr) |
| builder.Insert(freecall); |
| return freecall; |
| } |
| |
| if (shadowErasers.find(allocationfn.str()) != shadowErasers.end()) { |
| return shadowErasers[allocationfn.str()](builder, tofree); |
| } |
| |
| if (tofree->getType()->isIntegerTy()) |
| tofree = builder.CreateIntToPtr(tofree, |
| Type::getInt8PtrTy(tofree->getContext())); |
| |
| llvm::LibFunc libfunc; |
| if (allocationfn == "calloc" || allocationfn == "malloc") { |
| libfunc = LibFunc_malloc; |
| } else { |
| bool res = TLI.getLibFunc(allocationfn, libfunc); |
| assert(res && "ought find known allocation fn"); |
| } |
| |
| llvm::LibFunc freefunc; |
| |
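  // Map each allocation LibFunc onto its matching deallocator. A sketch of
  // the correspondence encoded below:
  //   malloc/valloc (and calloc, folded into malloc above) -> free
  //   operator new (Znw*)                                   -> operator delete (ZdlPv)
  //   operator new[] (Zna*)                                 -> operator delete[] (ZdaPv)
  //   MSVC new variants are currently unhandled.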
| switch (libfunc) { |
| case LibFunc_malloc: // malloc(unsigned int); |
| case LibFunc_valloc: // valloc(unsigned int); |
| freefunc = LibFunc_free; |
| break; |
| |
| case LibFunc_Znwj: // new(unsigned int); |
| case LibFunc_ZnwjRKSt9nothrow_t: // new(unsigned int, nothrow); |
| #if LLVM_VERSION_MAJOR > 6 |
| case LibFunc_ZnwjSt11align_val_t: // new(unsigned int, align_val_t) |
| case LibFunc_ZnwjSt11align_val_tRKSt9nothrow_t: // new(unsigned int, |
| // align_val_t, nothrow) |
| #endif |
| |
| case LibFunc_Znwm: // new(unsigned long); |
| case LibFunc_ZnwmRKSt9nothrow_t: // new(unsigned long, nothrow); |
| #if LLVM_VERSION_MAJOR > 6 |
| case LibFunc_ZnwmSt11align_val_t: // new(unsigned long, align_val_t) |
| case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t: // new(unsigned long, |
| // align_val_t, nothrow) |
| #endif |
| freefunc = LibFunc_ZdlPv; |
| break; |
| |
| case LibFunc_Znaj: // new[](unsigned int); |
| case LibFunc_ZnajRKSt9nothrow_t: // new[](unsigned int, nothrow); |
| #if LLVM_VERSION_MAJOR > 6 |
| case LibFunc_ZnajSt11align_val_t: // new[](unsigned int, align_val_t) |
| case LibFunc_ZnajSt11align_val_tRKSt9nothrow_t: // new[](unsigned int, |
| // align_val_t, nothrow) |
| #endif |
| |
| case LibFunc_Znam: // new[](unsigned long); |
| case LibFunc_ZnamRKSt9nothrow_t: // new[](unsigned long, nothrow); |
| #if LLVM_VERSION_MAJOR > 6 |
| case LibFunc_ZnamSt11align_val_t: // new[](unsigned long, align_val_t) |
| case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t: // new[](unsigned long, |
| // align_val_t, nothrow) |
| #endif |
| freefunc = LibFunc_ZdaPv; |
| break; |
| |
| case LibFunc_msvc_new_int: // new(unsigned int); |
| case LibFunc_msvc_new_int_nothrow: // new(unsigned int, nothrow); |
| case LibFunc_msvc_new_longlong: // new(unsigned long long); |
| case LibFunc_msvc_new_longlong_nothrow: // new(unsigned long long, nothrow); |
| case LibFunc_msvc_new_array_int: // new[](unsigned int); |
| case LibFunc_msvc_new_array_int_nothrow: // new[](unsigned int, nothrow); |
| case LibFunc_msvc_new_array_longlong: // new[](unsigned long long); |
| case LibFunc_msvc_new_array_longlong_nothrow: // new[](unsigned long long, |
| // nothrow); |
| llvm_unreachable("msvc deletion not handled"); |
| |
| default: |
| llvm_unreachable("unknown allocation function"); |
| } |
| llvm::StringRef freename = TLI.getName(freefunc); |
  // Always spell the C deallocator as "free", regardless of how TLI names it.
  if (freefunc == LibFunc_free)
    freename = "free";
| |
| Type *VoidTy = Type::getVoidTy(tofree->getContext()); |
| Type *IntPtrTy = Type::getInt8PtrTy(tofree->getContext()); |
| |
| auto FT = FunctionType::get(VoidTy, {IntPtrTy}, false); |
| #if LLVM_VERSION_MAJOR >= 9 |
| Value *freevalue = builder.GetInsertBlock() |
| ->getParent() |
| ->getParent() |
| ->getOrInsertFunction(freename, FT) |
| .getCallee(); |
| #else |
| Value *freevalue = |
| builder.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction( |
| freename, FT); |
| #endif |
| CallInst *freecall = cast<CallInst>( |
| #if LLVM_VERSION_MAJOR >= 8 |
| CallInst::Create(FT, freevalue, |
| {builder.CreatePointerCast(tofree, IntPtrTy)}, |
| #else |
| CallInst::Create(freevalue, {builder.CreatePointerCast(tofree, IntPtrTy)}, |
| #endif |
| "", builder.GetInsertBlock())); |
| freecall->setTailCall(); |
| freecall->setDebugLoc(debuglocation); |
| if (isa<CallInst>(tofree) && |
| #if LLVM_VERSION_MAJOR >= 14 |
| cast<CallInst>(tofree)->getAttributes().hasAttributeAtIndex( |
| AttributeList::ReturnIndex, Attribute::NonNull) |
| #else |
| cast<CallInst>(tofree)->getAttributes().hasAttribute( |
| AttributeList::ReturnIndex, Attribute::NonNull) |
| #endif |
| ) { |
| #if LLVM_VERSION_MAJOR >= 14 |
| freecall->addAttributeAtIndex(AttributeList::FirstArgIndex, |
| Attribute::NonNull); |
| #else |
| freecall->addAttribute(AttributeList::FirstArgIndex, Attribute::NonNull); |
| #endif |
| } |
| if (Function *F = dyn_cast<Function>(freevalue)) |
| freecall->setCallingConv(F->getCallingConv()); |
| if (freecall->getParent() == nullptr) |
| builder.Insert(freecall); |
| return freecall; |
| } |