tmp
diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index df3dc13..1b5eb9b 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp
@@ -338,7 +338,7 @@ } } -#define getOpFullest(Builder, vtmp, frominst, check) \ +#define getOpFullest(Builder, vtmp, frominst, lookupInst, check) \ ({ \ Value *v = vtmp; \ BasicBlock *origParent = frominst; \ @@ -362,12 +362,34 @@ if (!DT.dominates(opinst, &*Builder.GetInsertPoint())) \ noLookup = true; \ } \ + origParent = lookupInst; /* lookupInst may be nullptr; never dereference it unguarded */ \ + if (BasicBlock *forwardBlock = origParent) \ + if (auto opinst = dyn_cast<Instruction>(v)) { \ + if (!isOriginalBlock(*forwardBlock)) { \ + forwardBlock = originalForReverseBlock(*forwardBlock); \ + } \ + if (isPotentialLastLoopValue(opinst, forwardBlock, LI)) { \ + v = fixLCSSA(opinst, forwardBlock); /* look up the fixLCSSA value instead */ \ + origParent = nullptr; /* ...with no scope */ \ + } \ + } \ if (!noLookup) \ ___res = lookupM(v, Builder, available, v != val, origParent); \ } \ if (___res) \ assert(___res->getType() == v->getType() && "uw"); \ } else { \ + origParent = lookupInst; /* same last-loop-value handling as above */ \ + if (BasicBlock *forwardBlock = origParent) \ + if (auto opinst = dyn_cast<Instruction>(v)) { \ + if (!isOriginalBlock(*forwardBlock)) { \ + forwardBlock = originalForReverseBlock(*forwardBlock); \ + } \ + if (isPotentialLastLoopValue(opinst, forwardBlock, LI)) { \ + v = fixLCSSA(opinst, forwardBlock); \ + origParent = nullptr; \ + } \ + } \ assert(unwrapMode == UnwrapMode::AttemptSingleUnwrap); \ auto found = available.find(v); \ assert(found == available.end() || found->second); \ @@ -381,12 +403,18 @@ } \ ___res; \ }) -#define getOpFull(Builder, vtmp, frominst) \ - getOpFullest(Builder, vtmp, frominst, true) +#define getOpFull(Builder, vtmp, frominst) \ + ({ /* uniquely named local so it cannot capture names in vtmp/frominst */ \ + BasicBlock *___lookupParent = scope; \ + if (___lookupParent == nullptr) \ + if (auto originst = dyn_cast<Instruction>(val)) \ + ___lookupParent = originst->getParent(); \ + getOpFullest(Builder, vtmp, frominst, ___lookupParent, true); \ + }) #define getOpUnchecked(vtmp) \ ({ \ 
BasicBlock *parent = scope; \ - getOpFullest(BuilderM, vtmp, parent, false); \ + getOpFullest(BuilderM, vtmp, parent, parent, false); \ }) #define getOp(vtmp) \ ({ \ @@ -394,7 +422,7 @@ if (parent == nullptr) \ if (auto originst = dyn_cast<Instruction>(val)) \ parent = originst->getParent(); \ - getOpFullest(BuilderM, vtmp, parent, true); \ + getOpFullest(BuilderM, vtmp, parent, parent, true); \ }) if (isa<Argument>(val) || isa<Constant>(val)) { @@ -1462,19 +1490,42 @@ if (!DT.dominates(inst, &*B.GetInsertPoint())) noLookup = true; } - if (!noLookup) - ___res = - lookupM(inst, B, prevAvailable, inst != val, nextScope); + if (!noLookup) { + BasicBlock *nS2 = nextScope; /* set to nullptr once v is replaced by its fixLCSSA value */ + Value *v = inst; + if (BasicBlock *forwardBlock = nextScope) + if (auto opinst = dyn_cast<Instruction>(v)) { + if (!isOriginalBlock(*forwardBlock)) { + forwardBlock = originalForReverseBlock(*forwardBlock); + } + if (isPotentialLastLoopValue(opinst, forwardBlock, + LI)) { + v = fixLCSSA(opinst, forwardBlock); + nS2 = nullptr; + } + } + ___res = lookupM(v, B, prevAvailable, v != val, nS2); + } } if (___res) assert(___res->getType() == inst->getType() && "uw"); } else { - ___res = - lookupM(inst, B, prevAvailable, inst != val, nextScope); - if (___res && ___res->getType() != inst->getType()) { + BasicBlock *nS2 = nextScope; /* set to nullptr once v is replaced by its fixLCSSA value */ + Value *v = inst; + if (BasicBlock *forwardBlock = nextScope) + if (auto opinst = dyn_cast<Instruction>(v)) { + if (!isOriginalBlock(*forwardBlock)) { + forwardBlock = originalForReverseBlock(*forwardBlock); + } + if (isPotentialLastLoopValue(opinst, forwardBlock, LI)) { + v = fixLCSSA(opinst, forwardBlock); + nS2 = nullptr; + } + } + ___res = lookupM(v, B, prevAvailable, v != val, nS2); + if (___res && ___res->getType() != v->getType()) { llvm::errs() << *newFunc << "\n"; - llvm::errs() - << " inst = " << *inst << " res = " << *___res << "\n"; + llvm::errs() << " v = " << *v << " res = " << *___res << "\n"; } if (___res) assert(___res->getType() == inst->getType() && "lu"); @@ 
-1763,8 +1814,19 @@ return nullptr; } } + BasicBlock *nS2 = scope; /* set to nullptr once nval is replaced by its fixLCSSA value */ + if (BasicBlock *forwardBlock = scope) + if (auto opinst = dyn_cast<Instruction>(nval)) { + if (!isOriginalBlock(*forwardBlock)) { + forwardBlock = originalForReverseBlock(*forwardBlock); + } + if (isPotentialLastLoopValue(opinst, forwardBlock, LI)) { + nval = fixLCSSA(opinst, forwardBlock); + nS2 = nullptr; + } + } auto toreturn = lookupM(nval, BuilderM, available, - /*tryLegalRecomputeCheck*/ false, scope); + /*tryLegalRecomputeCheck*/ false, nS2); assert(val->getType() == toreturn->getType()); return toreturn; } @@ -5210,12 +5272,13 @@ assert(inst->getName() != "<badref>"); val = fixLCSSA(inst, scope); if (isa<UndefValue>(val)) { llvm::errs() << *oldFunc << "\n"; llvm::errs() << *newFunc << "\n"; llvm::errs() << *BuilderM.GetInsertBlock() << "\n"; llvm::errs() << *scope << "\n"; llvm::errs() << *val << " inst " << *inst << "\n"; + assert(0 && "undef value upon lcssa"); } inst = cast<Instruction>(val); assert(prelcssaInst->getType() == inst->getType());