draft
diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index fb7f56e..38435c1 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h
@@ -1109,9 +1109,6 @@ Type *ptrType = orig_ptr->getType(); Type *shadowTy = orig_val->getType(); - if (MemoryLayout == VectorModeMemoryLayout::VectorizeAtLeafNodes) - shadowTy = gutils->getShadowType(orig_val); - auto &DL = gutils->newFunc->getParent()->getDataLayout(); if (unnecessaryStores.count(&I)) { @@ -1440,8 +1437,6 @@ assert(FT); Type *shadowTy = op0->getType(); - if (MemoryLayout == VectorModeMemoryLayout::VectorizeAtLeafNodes) - shadowTy = gutils->getShadowType(op0); auto rule = [&](Value *dif) { if (I.getOpcode() == CastInst::CastOps::FPTrunc || @@ -1957,9 +1952,6 @@ if (!gutils->isConstantValue(orig_agg)) { Type *shadowTy = orig_inserted->getType(); - if (MemoryLayout == VectorModeMemoryLayout::VectorizeAtLeafNodes) - shadowTy = gutils->getShadowType(orig_inserted); - auto rule = [&](Value *prediff) { return Builder2.CreateInsertValue( prediff, Constant::getNullValue(shadowTy), IVI.getIndices()); @@ -2000,7 +1992,7 @@ /// Unwraps a vector derivative from its internal representation and applies a /// function f to each element. Return values of f are collected and wrapped. - template <ResultType resTy = ResultType::WRAPPED, typename Func, + template <ResultType resTy = ResultType::UNWRAPPED, typename Func, typename... Args> Value *applyChainRule(Type *diffType, IRBuilder<> &Builder, Func rule, Args... args) { @@ -2010,7 +2002,7 @@ /// Unwraps a vector derivative from its internal representation and applies a /// function f to each element. - template <ResultType resTy = ResultType::WRAPPED, typename Func, + template <ResultType resTy = ResultType::UNWRAPPED, typename Func, typename... Args> void applyChainRule(IRBuilder<> &Builder, Func rule, Args... args) { ((GradientUtils *)gutils)->applyChainRule<resTy>(Builder, rule, args...); @@ -2344,8 +2336,6 @@ Builder2); Type *BOTy = BO.getType(); - if (MemoryLayout == VectorModeMemoryLayout::VectorizeAtLeafNodes) - BOTy = gutils->getShadowType(BO); auto rule = [&Builder2, &BOTy](Value *idiff, Type *FT) { auto neg = @@ -2825,8 +2815,6 @@ assert(dif[1 - i]); Type *BOTy = BO.getType(); - if (MemoryLayout == VectorModeMemoryLayout::VectorizeAtLeafNodes) - BOTy = gutils->getShadowType(BO); auto rule = [&Builder2, &BOTy](Value *difi, Type *FT) { auto neg = @@ -2912,7 +2900,7 @@ return V; }; - auto diffe = applyChainRule<ResultType::UNWRAPPED>( + auto diffe = applyChainRule( BO.getType(), Builder2, rule, Gradient(dif[1 - i]), Primal(CV), Primal(eFT)); setDiffe(&BO, diffe, Builder2); @@ -3123,7 +3111,7 @@ auto dbg = gutils->getNewFromOriginal(MS.getDebugLoc()); - applyChainRule<ResultType::UNWRAPPED>( + applyChainRule( BuilderZ, [&](Value *op0, Value *op1, Value *op2, Value *op3) { SmallVector<Value *, 4> args = {op0, op1, op2}; @@ -3387,7 +3375,7 @@ cal->setDebugLoc(gutils->getNewFromOriginal(MS.getDebugLoc())); }; - applyChainRule<ResultType::UNWRAPPED>( + applyChainRule( BuilderZ, rule, Gradient(shadow_dst), Primal(op1), Primal(length)); } if (secretty && (Mode == DerivativeMode::ReverseModeGradient || @@ -3427,7 +3415,7 @@ cal->setDebugLoc(gutils->getNewFromOriginal(MS.getDebugLoc())); }; - applyChainRule<ResultType::UNWRAPPED>( + applyChainRule( Builder2, rule, Gradient(gutils->lookupM(shadow_dst, Builder2)), Primal(op1l), Primal(op3l), Primal(length)); } @@ -3493,11 +3481,6 @@ auto ddst = gutils->invertPointerM(orig_dst, Builder2); auto dsrc = gutils->invertPointerM(orig_src, Builder2); - if (MemoryLayout == VectorModeMemoryLayout::VectorizeAtLeafNodes) - new_size = Builder2.CreateMul( - new_size, ConstantInt::get(new_size->getType(), gutils->getWidth()), - new_size->getName() + ".vecsize"); - auto rule = [&](Value *ddst, Value *dsrc) { if (ddst->getType()->isIntegerTy()) ddst = Builder2.CreateIntToPtr( @@ -3992,7 +3975,7 @@ Builder2.CreateInsertElement(und, vdiff, (uint64_t)0), und, mask); }; - auto vec = applyChainRule<ResultType::UNWRAPPED>( + auto vec = applyChainRule( orig_ops[1]->getType(), Builder2, rule, Gradient(vdiff), Primal(und), Primal(mask)); addToDiffe(orig_ops[1], vec, Builder2, orig_ops[0]->getType()); @@ -4503,7 +4486,7 @@ return cal; }; - Value *dif = applyChainRule<ResultType::UNWRAPPED>( + Value *dif = applyChainRule( I.getType(), Builder2, rule, Gradient(accdif), Gradient(vecdif)); setDiffe(&I, dif, Builder2); return; @@ -4541,7 +4524,7 @@ cmp, Constant::getNullValue(dif0->getType()), dif0); }; - Value *dif0 = applyChainRule<ResultType::UNWRAPPED>( + Value *dif0 = applyChainRule( I.getType(), Builder2, rule, Gradient(op)); setDiffe(&I, dif0, Builder2); return; @@ -4892,9 +4875,6 @@ Value *op = diffe(orig_ops[0], Builder2); Type *shadowTy = CI.getType(); - if (MemoryLayout == VectorModeMemoryLayout::VectorizeAtLeafNodes) - shadowTy = gutils->getShadowType(CI); - auto rule = [&Builder2, &DL, &shadowTy](Value *op, Value *res, Value *mul) { Value *out = Builder2.CreateFMul(mul, op); @@ -5930,7 +5910,7 @@ Value *ip = gutils->invertPointerM(call.getOperand(1), BuilderZ); - Value *dval = applyChainRule<ResultType::UNWRAPPED>( + Value *dval = applyChainRule( call.getType(), BuilderZ, rule, Gradient(ip), Primal(op0), Primal(op1), Primal(op2), Primal(norm)); setDiffe(&call, dval, BuilderZ); @@ -6381,7 +6361,7 @@ return dres; }; - Value *dres = applyChainRule<ResultType::UNWRAPPED>( + Value *dres = applyChainRule( call.getType(), Builder2, rule, Gradient(dx), Gradient(dy)); setDiffe(&call, dres, Builder2); } @@ -6465,7 +6445,7 @@ } }; - applyChainRule<ResultType::UNWRAPPED>(Builder2, rule, Gradient(dx), + applyChainRule(Builder2, rule, Gradient(dx), Gradient(dy), Gradient(dif)); setDiffe(&call, Constant::getNullValue(gutils->getShadowType(call)), @@ -10319,9 +10299,6 @@ Type *shadowTy = orig->getType(); - if (MemoryLayout == VectorModeMemoryLayout::VectorizeAtLeafNodes) - shadowTy = gutils->getShadowType(orig); - auto rule = [&Builder2, &shadowTy](Value *vdiff, Value *dsin, Value *dcos) { Value *res = UndefValue::get(shadowTy); @@ -10566,7 +10543,7 @@ return Builder2.CreateCall(called, {vdiff, exp}); }; - Value *darg = applyChainRule<ResultType::UNWRAPPED>( + Value *darg = applyChainRule( call.getType(), Builder2, rule, Gradient(vdiff), Primal(exp)); setDiffe(orig, darg, Builder2); return; @@ -10583,7 +10560,7 @@ return Builder2.CreateCall(called, {vdiff, exponent}); }; Value *vdiff = diffe(orig, Builder2); - Value *darg = applyChainRule<ResultType::UNWRAPPED>( + Value *darg = applyChainRule( orig->getArgOperand(0)->getType(), Builder2, rule, Gradient(vdiff), Primal(exponent)); setDiffe(orig, Constant::getNullValue(gutils->getShadowType(orig)), @@ -10719,7 +10696,7 @@ auto rule = [&]() { return shadowHandlers[funcName.str()](bb, orig, args, gutils); }; - anti = applyChainRule<ResultType::UNWRAPPED>(call.getType(), bb, + anti = applyChainRule(call.getType(), bb, rule); if (anti->getType() != placeholder->getType()) { llvm::errs() << "orig: " << *orig << "\n"; @@ -10812,7 +10789,7 @@ return anti; }; - anti = applyChainRule<ResultType::UNWRAPPED>(orig->getType(), bb, + anti = applyChainRule(orig->getType(), bb, rule); gutils->invertedPointers.erase(found); @@ -11287,7 +11264,7 @@ return BuilderZ.CreateCall(called, {v}); }; - Value *val = applyChainRule<ResultType::UNWRAPPED>( + Value *val = applyChainRule( call.getType(), BuilderZ, rule, Gradient(ptrshadow)); gutils->replaceAWithB(placeholder, val); @@ -11449,10 +11426,6 @@ args.push_back(newfree); auto rule = [&](Value *tofree) { - if (MemoryLayout == VectorModeMemoryLayout::VectorizeAtLeafNodes) - tofree = Builder2.CreatePointerCast( - tofree, PointerType::getInt8PtrTy(tofree->getContext())); - args.push_back(tofree); };
diff --git a/enzyme/Enzyme/Annotations.h b/enzyme/Enzyme/Annotations.h index da01eb6..6484bcf 100644 --- a/enzyme/Enzyme/Annotations.h +++ b/enzyme/Enzyme/Annotations.h
@@ -53,37 +53,8 @@ Value *getValue(IRBuilder<> &Builder, GradientUtils *gutils) { VectorModeMemoryLayout memoryLayout = gutils->memoryLayout; unsigned width = gutils->getWidth(); - - if (width == 1) - return value; - - switch (memoryLayout) { - case VectorModeMemoryLayout::VectorizeAtRootNode: - return value; - case VectorModeMemoryLayout::VectorizeAtLeafNodes: - if (auto vty = dyn_cast<VectorType>(value->getType())) { -#if LLVM_VERSION_MAJOR >= 12 - unsigned vector_width = vty->getElementCount().getKnownMinValue(); -#else - unsigned vector_width = vty->getNumElements(); -#endif - return Builder.CreateShuffleVector( - value, UndefValue::get(value->getType()), - GradientUtils::CreateVectorSplatMask(vector_width, width), - value->getName() + ".vecsplat"); - } else if (auto sty = dyn_cast<StructType>(value->getType())) { - auto vsty = GradientUtils::getShadowType( - *Builder.GetInsertBlock()->getModule(), sty, width, memoryLayout); - Value *vecstruct = UndefValue::get(vsty); - for (unsigned i = 0; i < sty->getNumElements(); ++i) { - auto elem = Builder.CreateExtractValue(value, {i}); - auto splat = Builder.CreateVectorSplat(width, elem); - vecstruct = Builder.CreateInsertValue(vecstruct, splat, {i}); - } - return vecstruct; - } - return Builder.CreateVectorSplat(width, value); - } + + return value; } Value *getValue(IRBuilder<> &Builder, GradientUtils *gutils, unsigned i) { @@ -102,16 +73,7 @@ VectorModeMemoryLayout memoryLayout = gutils->memoryLayout; unsigned width = gutils->getWidth(); - if (width == 1) - return type; - - switch (memoryLayout) { - case VectorModeMemoryLayout::VectorizeAtRootNode: - return type; - case VectorModeMemoryLayout::VectorizeAtLeafNodes: - return GradientUtils::getShadowType( - *Builder.GetInsertBlock()->getModule(), type, width, memoryLayout); - } + return type; } Type *getValue(IRBuilder<> &Builder, GradientUtils *gutils, unsigned i) { @@ -130,19 +92,7 @@ VectorModeMemoryLayout memoryLayout = gutils->memoryLayout; unsigned width = gutils->getWidth(); - if (width == 1) - return type; - - switch (memoryLayout) { - case VectorModeMemoryLayout::VectorizeAtRootNode: - return type; - case VectorModeMemoryLayout::VectorizeAtLeafNodes: { - Type *ty = GradientUtils::getShadowType( - *Builder.GetInsertBlock()->getModule(), type->getElementType(), width, - memoryLayout); - return ArrayType::get(ty, type->getNumElements()); - } - } + return type; } ArrayType *getValue(IRBuilder<> &Builder, GradientUtils *gutils, unsigned i) { @@ -160,17 +110,7 @@ Constant *getValue(IRBuilder<> &Builder, GradientUtils *gutils) { VectorModeMemoryLayout memoryLayout = gutils->memoryLayout; unsigned width = gutils->getWidth(); - - if (width == 1) - return c; - - switch (memoryLayout) { - case VectorModeMemoryLayout::VectorizeAtRootNode: - return c; - case VectorModeMemoryLayout::VectorizeAtLeafNodes: - std::vector<Constant *> cs(width, c); - return ConstantVector::get(cs); - } + return c; } Constant *getValue(IRBuilder<> &Builder, GradientUtils *gutils, unsigned i) { @@ -234,42 +174,10 @@ assert(cast<ArrayType>(value->getType())->getNumElements() == width); return GradientUtils::extractMeta(Builder, value, i); case VectorModeMemoryLayout::VectorizeAtLeafNodes: - if (auto vty = dyn_cast<VectorType>(value->getType())) { -#if LLVM_VERSION_MAJOR >= 12 - unsigned vector_width = vty->getElementCount().getKnownMinValue(); -#else - unsigned vector_width = vty->getNumElements(); -#endif - if (vector_width / width > 1) { - return Builder.CreateShuffleVector( - value, UndefValue::get(value->getType()), - GradientUtils::CreateExtractSubvectorMask(vector_width, width, i), - value->getName() + ".subvector." + Twine(i)); - } else { - return Builder.CreateExtractElement(value, i); - } - } else if (auto pty = dyn_cast<PointerType>(value->getType())) { + if (auto pty = dyn_cast<PointerType>(value->getType())) { #if LLVM_VERSION_MAJOR >= 15 return value; #else - if (auto vty = dyn_cast<VectorType>(pty->getElementType())) { -#if LLVM_VERSION_MAJOR >= 12 - unsigned vector_width = vty->getElementCount().getKnownMinValue(); - Type *res_type = - FixedVectorType::get(vty->getElementType(), vector_width / width); -#else - unsigned vector_width = vty->getNumElements(); - Type *res_type = - VectorType::get(vty->getElementType(), vector_width / width); -#endif - if (vector_width / width > 1) { - Type *gep_type = PointerType::get(res_type, pty->getAddressSpace()); - Value *idx[2] = {Builder.getInt32(0), - Builder.getInt32(i * vector_width / width)}; - auto gep = Builder.CreateInBoundsGEP(vty, value, idx); - return Builder.CreatePointerCast(gep, gep_type); - } - } Value *idx[2] = {Builder.getInt32(0), Builder.getInt32(i)}; return Builder.CreateInBoundsGEP(pty->getElementType(), value, idx); #endif @@ -284,15 +192,6 @@ VectorModeMemoryLayout memoryLayout = gutils->memoryLayout; unsigned width = gutils->getWidth(); - if (width == 1 || !value) - return value; - - if (!value) - return nullptr; - - if (value && memoryLayout == VectorModeMemoryLayout::VectorizeAtRootNode) - assert(cast<ArrayType>(value->getType())->getNumElements() == width); - return value; } }; @@ -311,63 +210,16 @@ if (width == 1 || !value) return value; - if (!value) - return nullptr; - switch (memoryLayout) { case VectorModeMemoryLayout::VectorizeAtRootNode: return value; case VectorModeMemoryLayout::VectorizeAtLeafNodes: - if (auto vty = dyn_cast<VectorType>(value->getType())) { -#if LLVM_VERSION_MAJOR >= 12 - unsigned vector_width = vty->getElementCount().getKnownMinValue(); -#else - unsigned vector_width = vty->getNumElements(); -#endif - if (vector_width / width > 1) { -#if LLVM_VERSION_MAJOR >= 11 - auto Mask = - GradientUtils::CreateExtractSubvectorMask(vector_width, width, i); -#else - auto MaskArray = - GradientUtils::CreateExtractSubvectorMask(vector_width, width, i); - SmallVector<Constant *, 8> ConstantArray; - for (auto elem : MaskArray) - ConstantArray.push_back(Builder.getInt32(elem)); - auto Mask = ConstantVector::get(ConstantArray); -#endif - Constant *res = ConstantExpr::getShuffleVector( - value, UndefValue::get(value->getType()), Mask); - res->setName(value->getName() + ".subvector." + Twine(i)); - return res; - } else { - return ConstantExpr::getExtractElement(value, Builder.getInt64(i)); - } - } else if (auto pty = dyn_cast<PointerType>(value->getType())) { + if (auto pty = dyn_cast<PointerType>(value->getType())) { #if LLVM_VERSION_MAJOR >= 15 return value; #else - if (auto vty = dyn_cast<VectorType>(pty->getElementType())) { -#if LLVM_VERSION_MAJOR >= 12 - unsigned vector_width = vty->getElementCount().getKnownMinValue(); - Type *res_type = - FixedVectorType::get(vty->getElementType(), vector_width / width); -#else - unsigned vector_width = vty->getNumElements(); - Type *res_type = - VectorType::get(vty->getElementType(), vector_width / width); -#endif - if (vector_width / width > 1) { - Type *gep_type = PointerType::get(res_type, pty->getAddressSpace()); - Constant *idx[2] = {Builder.getInt32(0), - Builder.getInt32(i * vector_width / width)}; - auto gep = ConstantExpr::getInBoundsGetElementPtr(vty, value, idx); - return ConstantExpr::getPointerCast(gep, gep_type); - } - } Constant *idx[2] = {Builder.getInt32(0), Builder.getInt32(i)}; - return ConstantExpr::getInBoundsGetElementPtr(pty->getElementType(), - value, idx); + return ConstantExpr::getInBoundsGetElementPtr(pty->getElementType(), value, idx); #endif } return value; @@ -380,15 +232,6 @@ VectorModeMemoryLayout memoryLayout = gutils->memoryLayout; unsigned width = gutils->getWidth(); - if (width == 1 || !value) - return value; - - if (!value) - return nullptr; - - if (value && memoryLayout == VectorModeMemoryLayout::VectorizeAtRootNode) - assert(cast<ArrayType>(value->getType())->getNumElements() == width); - return value; } };
diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 1151c4c..35c8973 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp
@@ -222,8 +222,7 @@ if (!mask) { applyChainRule(BuilderM, rule, Gradient(ptr), Gradient(newval)); } else { - applyChainRule<ResultType::UNWRAPPED>(BuilderM, rule_mask, Gradient(ptr), - Gradient(newval), Primal(mask)); + applyChainRule(BuilderM, rule_mask, Gradient(ptr), Gradient(newval), Primal(mask)); } } @@ -621,8 +620,7 @@ Value *sargs[] = {res, ptr, alignv, mask}; BuilderM.CreateCall(SF, sargs); }; - applyChainRule<ResultType::UNWRAPPED>(BuilderM, rule, Gradient(ptr), - Gradient(dif)); + applyChainRule(BuilderM, rule, Gradient(ptr), Gradient(dif)); } } @@ -1420,8 +1418,7 @@ return toreturn; }; - Value *toreturn = applyChainRule<ResultType::UNWRAPPED>( - dli->getType(), BuilderM, rule, Gradient(pidx)); + Value *toreturn = applyChainRule(dli->getType(), BuilderM, rule, Gradient(pidx)); // TODO adding to cache only legal if no alias of any future writes if (permitCache) @@ -3388,8 +3385,7 @@ return anti; }; - anti = applyChainRule<ResultType::UNWRAPPED>( - orig->getType(), NB, rule); + anti = applyChainRule(orig->getType(), NB, rule); if (auto MD = hasMetadata(orig, "enzyme_fromstack")) { auto rule1 = [&](Value *anti) { @@ -3412,9 +3408,7 @@ }; Value *replacement = - applyChainRule<ResultType::UNWRAPPED>( - Type::getInt8Ty(orig->getContext()), NB, rule1, - Gradient(anti)); + applyChainRule(Type::getInt8Ty(orig->getContext()), NB, rule1, Gradient(anti)); replaceAWithB(cast<Instruction>(anti), replacement); erase(cast<Instruction>(anti)); @@ -3426,8 +3420,7 @@ cast<CallInst>(orig)); }; - applyChainRule<ResultType::UNWRAPPED>( - NB, rule2, Primal(orig), Gradient(anti)); + applyChainRule(NB, rule2, Primal(orig), Gradient(anti)); } } else { llvm_unreachable("Unknown shadow rematerialization value"); @@ -4843,10 +4836,6 @@ M->getDataLayout().getTypeAllocSizeInBits(arg->getValueType()) / 8); - if (memoryLayout == VectorModeMemoryLayout::VectorizeAtLeafNodes) - len_arg = bb.CreateMul( - len_arg, ConstantInt::get(len_arg->getType(), width)); - auto rule2 = [&bb, &arg, &M, &oval, &len_arg, &val_arg, &ty](Value *antialloca) { auto dst_arg = bb.CreateBitCast(antialloca, ty); @@ -5034,9 +5023,6 @@ Value *invertOp = invertPointerM(arg->getOperand(0), bb); Type *shadowTy = arg->getDestTy(); - if (memoryLayout == VectorModeMemoryLayout::VectorizeAtLeafNodes) - shadowTy = getShadowType(arg); - auto rule = [&bb, &arg, &shadowTy](Value *invertOp) { return bb.CreateCast(arg->getOpcode(), invertOp, shadowTy, arg->getName() + "'ipc"); @@ -5157,8 +5143,7 @@ ; }; - Value *shadow = applyChainRule<ResultType::UNWRAPPED>( - arg->getType(), bb, rule, Gradient(ip), Primal(index)); + Value *shadow = applyChainRule(arg->getType(), bb, rule, Gradient(ip), Primal(index)); invertedPointers.insert( std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow))); @@ -5198,8 +5183,7 @@ return bb.CreateShuffleVector(ip0, ip1, mask, arg->getName() + "'ipsv"); }; - Value *shadow = applyChainRule<ResultType::UNWRAPPED>( - arg->getType(), bb, rule, Gradient(ip0), Gradient(ip1)); + Value *shadow = applyChainRule(arg->getType(), bb, rule, Gradient(ip0), Gradient(ip1)); invertedPointers.insert( std::make_pair((const Value *)oval, InvertedPointerVH(this, shadow))); @@ -5454,9 +5438,7 @@ return li; }; - return applyChainRule<ResultType::UNWRAPPED>( - II->getType(), bb, rule, Gradient(ptr), Primal(align), Primal(mask), - Gradient(defaultV)); + return applyChainRule(II->getType(), bb, rule, Gradient(ptr), Primal(align), Primal(mask), Gradient(defaultV)); } } } else if (auto phi = dyn_cast<PHINode>(oval)) {
diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index 0ddfd55..871e221 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h
@@ -1528,21 +1528,11 @@ width, under_construction, M), pty->getAddressSpace()); } else if (auto vty = dyn_cast<VectorType>(ty)) { -#if LLVM_VERSION_MAJOR >= 12 - return VectorType::get(vty->getElementType(), - vty->getElementCount() * width); -#else - return VectorType::get(vty->getElementType(), - vty->getNumElements() * width); -#endif + return ArrayType::get(vty, width); #endif } else { if (TT.Inner0().isPossibleFloat()) { -#if LLVM_VERSION_MAJOR >= 12 - return FixedVectorType::get(ty, width); -#else - return VectorType::get(ty, width); -#endif + return ArrayType::get(ty, width); } return ty; } @@ -1743,7 +1733,7 @@ /// Unwraps a vector derivative from its internal representation and applies a /// function f to each element. Return values of f are collected and wrapped. - template <ResultType resTy = ResultType::WRAPPED, typename Func, + template <ResultType resTy = ResultType::UNWRAPPED, typename Func, typename... Args> Value *applyChainRule(Type *diffType, IRBuilder<> &Builder, Func rule, Args... args) { @@ -1759,42 +1749,7 @@ return res; } else if (width > 1 && resTy == ResultType::UNWRAPPED && memoryLayout == VectorModeMemoryLayout::VectorizeAtLeafNodes) { - if (diffType->isVectorTy()) { - Value *res = nullptr; - for (unsigned int i = 0; i < width; ++i) { - Value *diff = - std::apply(rule, std::move(eval_tuple(Builder, i, args...))); - if (res) { - VectorType *rvty = cast<VectorType>(res->getType()); - VectorType *dvty = cast<VectorType>(diff->getType()); -#if LLVM_VERSION_MAJOR >= 12 - unsigned rvty_count = rvty->getElementCount().getKnownMinValue(); - unsigned dvty_count = dvty->getElementCount().getKnownMinValue(); -#else - unsigned rvty_count = rvty->getNumElements(); - unsigned dvty_count = dvty->getNumElements(); -#endif - if (dvty_count < rvty_count) { - auto PadMask = CreateVectorConcatenationMask(rvty_count, 0); - diff = Builder.CreateShuffleVector( - diff, UndefValue::get(diff->getType()), PadMask, - diff->getName() + ".vecpad"); - } - auto ConcatMask = - CreateVectorConcatenationMask(rvty_count, dvty_count); - res = Builder.CreateShuffleVector(res, diff, ConcatMask, - diff->getName() + ".vecconcat"); - } else { - res = diff; - } - } - return res; - } else { -#if LLVM_VERSION_MAJOR >= 12 - Type *wrappedType = FixedVectorType::get(diffType, width); -#else - Type *wrappedType = VectorType::get(diffType, width); -#endif + Type *wrappedType = ArrayType::get(diffType, width); Value *res = UndefValue::get(wrappedType); for (unsigned int i = 0; i < width; ++i) { auto diff = @@ -1803,13 +1758,12 @@ } return res; } - } return std::apply(rule, std::move(eval_tuple(Builder, nullptr, args...))); } /// Unwraps a vector derivative from its internal representation and applies a /// function f to each element. Return values of f are collected and wrapped. - template <ResultType resTy = ResultType::WRAPPED, typename Func, + template <ResultType resTy = ResultType::UNWRAPPED, typename Func, typename... Args> void applyChainRule(IRBuilder<> &Builder, Func rule, Args... args) {