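Add enzyme_leaf flag: select the VectorizeAtLeafNodes vector-mode memory layout per call instead of via the global -enzyme-vectorize-at-leaf-nodes option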
diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 6280b31..407ae40 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp
@@ -91,10 +91,6 @@ llvm::cl::opt<bool> EnzymeOMPOpt("enzyme-omp-opt", cl::init(false), cl::Hidden, cl::desc("Whether to enable openmp opt")); -llvm::cl::opt<bool> EnzymeVectorizeAtLeafNodes( - "enzyme-vectorize-at-leaf-nodes", cl::init(false), cl::Hidden, - cl::desc("Run enzyme with an optimized memory layout for vector mode")); - #if LLVM_VERSION_MAJOR >= 14 #define addAttribute addAttributeAtIndex #endif @@ -988,6 +984,7 @@ IRBuilder<> Builder(CI); unsigned truei = 0; unsigned width = 1; + VectorModeMemoryLayout MemoryLayout = VectorModeMemoryLayout::VectorizeAtRootNode; std::map<unsigned, Value *> batchOffset; bool returnUsed = !fn->getReturnType()->isVoidTy() && !fn->getReturnType()->isEmptyTy(); @@ -1027,7 +1024,7 @@ case DerivativeMode::ForwardModeSplit: case DerivativeMode::ForwardMode: { Value *sretPt = CI->getArgOperand(0); - if (width > 1 && !EnzymeVectorizeAtLeafNodes) { + if (width > 1 && MemoryLayout != VectorModeMemoryLayout::VectorizeAtLeafNodes) { PointerType *pty = cast<PointerType>(sretPt->getType()); if (auto sty = dyn_cast<StructType>(pty->getPointerElementType())) { Value *acc = UndefValue::get( @@ -1203,6 +1200,9 @@ } else if (*metaString == "enzyme_width") { ++i; continue; + } else if (*metaString == "enzyme_leaf") { + MemoryLayout = VectorModeMemoryLayout::VectorizeAtLeafNodes; + continue; } else { EmitFailure("IllegalDiffeType", CI->getDebugLoc(), CI, "illegal enzyme metadata classification ", *CI, @@ -1351,7 +1351,7 @@ Value *res = nullptr; bool batch = batchOffset.count(i - 1) != 0; - unsigned actual_width = EnzymeVectorizeAtLeafNodes ? 1 : width; + unsigned actual_width = MemoryLayout == VectorModeMemoryLayout::VectorizeAtLeafNodes ? 
1 : width; for (unsigned v = 0; v < actual_width; ++v) { #if LLVM_VERSION_MAJOR >= 14 @@ -1404,7 +1404,7 @@ auto expectedType = GradientUtils::getShadowType( *CI->getModule(), PTy, width, VectorModeMemoryLayout::VectorizeAtLeafNodes); - if (EnzymeVectorizeAtLeafNodes && + if (MemoryLayout == VectorModeMemoryLayout::VectorizeAtLeafNodes && expectedType != element->getType()) { element = castToDiffeFunctionArgType(Builder, CI, FT, expectedType, i, mode, element, truei); @@ -1412,7 +1412,7 @@ if (!element) { return false; } - } else if (!EnzymeVectorizeAtLeafNodes && PTy != element->getType()) { + } else if (MemoryLayout != VectorModeMemoryLayout::VectorizeAtLeafNodes && PTy != element->getType()) { element = castToDiffeFunctionArgType(Builder, CI, FT, PTy, i, mode, element, truei); if (!element) { @@ -1420,7 +1420,7 @@ } } - if (width > 1 && !EnzymeVectorizeAtLeafNodes) { + if (width > 1 && MemoryLayout != VectorModeMemoryLayout::VectorizeAtLeafNodes) { res = res ? Builder.CreateInsertValue(res, element, {v}) : Builder.CreateInsertValue(UndefValue::get(ArrayType::get( @@ -1489,15 +1489,11 @@ Function *newFunc = nullptr; Type *tapeType = nullptr; const AugmentedReturn *aug; - VectorModeMemoryLayout memoryLayout = - EnzymeVectorizeAtLeafNodes - ? 
VectorModeMemoryLayout::VectorizeAtLeafNodes - : VectorModeMemoryLayout::VectorizeAtRootNode; switch (mode) { case DerivativeMode::ForwardMode: newFunc = Logic.CreateForwardDiff( fn, retType, constants, TA, - /*should return*/ false, mode, memoryLayout, freeMemory, width, + /*should return*/ false, mode, MemoryLayout, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, /*augmented*/ nullptr); break; @@ -1541,7 +1537,7 @@ } newFunc = Logic.CreateForwardDiff( fn, retType, constants, TA, - /*should return*/ false, mode, memoryLayout, freeMemory, width, + /*should return*/ false, mode, MemoryLayout, freeMemory, width, /*addedType*/ tapeType, type_args, volatile_args, aug); break; } @@ -1757,7 +1753,7 @@ // convention. if (width > 1 && !diffret->getType()->isEmptyTy() && !diffret->getType()->isVoidTy() && - memoryLayout == VectorModeMemoryLayout::VectorizeAtRootNode && + MemoryLayout == VectorModeMemoryLayout::VectorizeAtRootNode && (mode == DerivativeMode::ForwardMode || mode == DerivativeMode::ForwardModeSplit)) {
diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 3195391..c15b488 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp
@@ -4056,9 +4056,21 @@ SmallPtrSet<Value *, 4> nonconstant_values; std::string prefix = "fakeaugmented"; - if (width > 1) + if (width > 1) { prefix += std::to_string(width); - prefix += "_"; + prefix += "_"; + switch (memoryLayout) { + case VectorModeMemoryLayout::VectorizeAtRootNode: + prefix += ".root"; + break; + case VectorModeMemoryLayout::VectorizeAtLeafNodes: + prefix += ".leaf"; + break; + } + prefix += "_"; + } + + prefix += todiff->getName().str(); auto newFunc = Logic.PPC.CloneFunctionWithReturns( @@ -4218,9 +4230,28 @@ case DerivativeMode::ReverseModePrimal: llvm_unreachable("invalid DerivativeMode: ReverseModePrimal\n"); } + + switch (memoryLayout) { + case VectorModeMemoryLayout::VectorizeAtRootNode: + prefix += ".root"; + break; + case VectorModeMemoryLayout::VectorizeAtLeafNodes: + prefix += ".leaf"; + break; + } - if (width > 1) + if (width > 1) { prefix += std::to_string(width); + switch (memoryLayout) { + case VectorModeMemoryLayout::VectorizeAtRootNode: + prefix += ".root"; + break; + case VectorModeMemoryLayout::VectorizeAtLeafNodes: + prefix += ".leaf"; + break; + } + } + auto newFunc = Logic.PPC.CloneFunctionWithReturns( mode, memoryLayout, width, oldFunc, invertedPointers, constant_args, @@ -4502,6 +4533,17 @@ if (width > 1) { prefix += std::to_string(width); + prefix += "_"; + + switch (memoryLayout) { + case VectorModeMemoryLayout::VectorizeAtRootNode: + prefix += ".root"; + break; + case VectorModeMemoryLayout::VectorizeAtLeafNodes: + prefix += ".leaf"; + break; + } + prefix += "_"; } std::string globalname = (prefix + "_" + fn->getName() + "'").str(); @@ -4533,6 +4575,17 @@ if (width > 1) { prefix += std::to_string(width); + prefix += "_"; + + switch (memoryLayout) { + case VectorModeMemoryLayout::VectorizeAtRootNode: + prefix += ".root"; + break; + case VectorModeMemoryLayout::VectorizeAtLeafNodes: + prefix += ".leaf"; + break; + } + prefix += "_"; } auto cdata = ConstantStruct::get(
diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index b5595b8..ef1d5ea 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp
@@ -767,6 +767,15 @@ DebugLoc DebugLoc = call->getDebugLoc(); std::string name = "__enzyme_checked_free_" + std::to_string(width); + + switch (memoryLayout) { + case VectorModeMemoryLayout::VectorizeAtRootNode: + name += ".root"; + break; + case VectorModeMemoryLayout::VectorizeAtLeafNodes: + name += ".leaf"; + break; + } SmallVector<Type *, 3> types; types.push_back(Ty);