blob: 669e6d0c5aa9d943848e06c73a832d493efc598b [file] [log] [blame] [edit]
//===- Enzyme.cpp - Automatic Differentiation Transformation Pass -------===//
//
// Enzyme Project
//
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// If using this code in an academic setting, please cite the following:
// @incollection{enzymeNeurips,
// title = {Instead of Rewriting Foreign Code for Machine Learning,
// Automatically Synthesize Fast Gradients},
// author = {Moses, William S. and Churavy, Valentin},
// booktitle = {Advances in Neural Information Processing Systems 33},
// year = {2020},
// note = {To appear in},
// }
//
//===----------------------------------------------------------------------===//
//
// This file contains Enzyme, a transformation pass that replaces calls to
// *__enzyme_autodiff* with a call to the derivative of the function passed
// as the first argument.
//
//===----------------------------------------------------------------------===//
#include <llvm/Config/llvm-config.h>
#if LLVM_VERSION_MAJOR >= 16
#define private public
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#undef private
#else
#include "SCEV/ScalarEvolution.h"
#include "SCEV/ScalarEvolutionExpander.h"
#endif
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/MapVector.h"
#if LLVM_VERSION_MAJOR <= 16
#include "llvm/ADT/Optional.h"
#else
#include <optional>
#endif
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/InlineAdvisor.h"
#include "llvm/Analysis/InlineCost.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/AbstractCallSite.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "ActivityAnalysis.h"
#include "DiffeGradientUtils.h"
#include "EnzymeLogic.h"
#include "GradientUtils.h"
#include "TraceInterface.h"
#include "TraceUtils.h"
#include "Utils.h"
#include "InstructionBatcher.h"
#include "llvm/Transforms/Utils.h"
#if LLVM_VERSION_MAJOR >= 13
#include "llvm/Transforms/IPO/Attributor.h"
#include "llvm/Transforms/IPO/OpenMPOpt.h"
#include "llvm/Transforms/Utils/Mem2Reg.h"
#endif
#include "BlasAttributor.inc"
#include "CApi.h"
using namespace llvm;
#ifdef DEBUG_TYPE
#undef DEBUG_TYPE
#endif
#define DEBUG_TYPE "lower-enzyme-intrinsic"
// Hidden command-line flag "enzyme-postopt" (default: false). EnzymeBase's
// constructor prefers this flag over its PostOpt argument whenever the flag
// was given explicitly on the command line.
llvm::cl::opt<bool>
EnzymePostOpt("enzyme-postopt", cl::init(false), cl::Hidden,
cl::desc("Run enzymepostprocessing optimizations"));
// Hidden command-line flag "enzyme-attributor" (default: false).
llvm::cl::opt<bool> EnzymeAttributor("enzyme-attributor", cl::init(false),
cl::Hidden,
cl::desc("Run attributor post Enzyme"));
// Hidden command-line flag "enzyme-omp-opt" (default: false).
llvm::cl::opt<bool> EnzymeOMPOpt("enzyme-omp-opt", cl::init(false), cl::Hidden,
cl::desc("Whether to enable openmp opt"));
// For LLVM 14 and newer, textually redirect the index-based attribute
// accessors written in this file (addAttribute / getAttribute) to their
// *AtIndex spellings.
#if LLVM_VERSION_MAJOR >= 14
#define addAttribute addAttributeAtIndex
#define getAttribute getAttributeAtIndex
#endif
/// Attach LLVM attributes to functions that are recognised purely by name.
///
/// Covered families: the `__enzyme_*` helper intrinsics, `memcmp`, one
/// libstdc++ string helper, several MPI entry points (and their `PMPI_`
/// aliases), two OpenMP queries, `frexp*`, and a few `__fd_*` / `__mth_*`
/// math helpers. `attributeTablegen` is invoked for every function as well.
///
/// \param F function whose name is inspected and whose attributes are updated
///          in place.
void attributeKnownFunctions(llvm::Function &F) {
  if (F.getName().contains("__enzyme_float") ||
      F.getName().contains("__enzyme_double") ||
      F.getName().contains("__enzyme_integer") ||
      F.getName().contains("__enzyme_pointer") ||
      F.getName().contains("__enzyme_todense") ||
      F.getName().contains("__enzyme_iter") ||
      F.getName().contains("__enzyme_virtualreverse")) {
#if LLVM_VERSION_MAJOR >= 16
    // Restricting to "only reads" and then to "only writes" leaves the
    // function with no memory effects, matching ReadNone below.
    F.setOnlyReadsMemory();
    F.setOnlyWritesMemory();
#else
    F.addFnAttr(Attribute::ReadNone);
#endif
    // __enzyme_todense keeps its pointer arguments unannotated.
    if (!F.getName().contains("__enzyme_todense"))
      for (auto &arg : F.args()) {
        if (arg.getType()->isPointerTy()) {
          arg.addAttr(Attribute::ReadNone);
          arg.addAttr(Attribute::NoCapture);
        }
      }
  }
  if (F.getName() == "memcmp") {
#if LLVM_VERSION_MAJOR >= 16
    F.setOnlyAccessesArgMemory();
    F.setOnlyReadsMemory();
#else
    F.addFnAttr(Attribute::ArgMemOnly);
    F.addFnAttr(Attribute::ReadOnly);
#endif
    F.addFnAttr(Attribute::NoUnwind);
    F.addFnAttr(Attribute::NoRecurse);
    F.addFnAttr(Attribute::WillReturn);
    F.addFnAttr(Attribute::NoFree);
    F.addFnAttr(Attribute::NoSync);
    // memcmp only reads its two buffers. Marking them WriteOnly on a
    // read-only function would claim that the buffers are never accessed.
    for (int i = 0; i < 2; i++)
      if (F.getFunctionType()->getParamType(i)->isPointerTy()) {
        F.addParamAttr(i, Attribute::NoCapture);
        F.addParamAttr(i, Attribute::ReadOnly);
      }
  }
  attributeTablegen(F);
  if (F.getName() ==
      "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE9_M_createERmm") {
    F.addFnAttr(Attribute::NoFree);
  }
  if (F.getName() == "MPI_Irecv" || F.getName() == "PMPI_Irecv") {
#if LLVM_VERSION_MAJOR >= 16
    F.setOnlyAccessesInaccessibleMemOrArgMem();
#else
    F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
#endif
    F.addFnAttr(Attribute::NoUnwind);
    F.addFnAttr(Attribute::NoRecurse);
    F.addFnAttr(Attribute::WillReturn);
    F.addFnAttr(Attribute::NoFree);
    F.addFnAttr(Attribute::NoSync);
    F.addParamAttr(0, Attribute::WriteOnly);
    // Parameter 2 is only annotated when the MPI implementation models it as
    // a pointer.
    if (F.getFunctionType()->getParamType(2)->isPointerTy()) {
      F.addParamAttr(2, Attribute::NoCapture);
      F.addParamAttr(2, Attribute::WriteOnly);
    }
    F.addParamAttr(6, Attribute::WriteOnly);
  }
  if (F.getName() == "MPI_Isend" || F.getName() == "PMPI_Isend") {
#if LLVM_VERSION_MAJOR >= 16
    F.setOnlyAccessesInaccessibleMemOrArgMem();
#else
    F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
#endif
    F.addFnAttr(Attribute::NoUnwind);
    F.addFnAttr(Attribute::NoRecurse);
    F.addFnAttr(Attribute::WillReturn);
    F.addFnAttr(Attribute::NoFree);
    F.addFnAttr(Attribute::NoSync);
    F.addParamAttr(0, Attribute::ReadOnly);
    if (F.getFunctionType()->getParamType(2)->isPointerTy()) {
      F.addParamAttr(2, Attribute::NoCapture);
      F.addParamAttr(2, Attribute::ReadOnly);
    }
    F.addParamAttr(6, Attribute::WriteOnly);
  }
  if (F.getName() == "MPI_Comm_rank" || F.getName() == "PMPI_Comm_rank" ||
      F.getName() == "MPI_Comm_size" || F.getName() == "PMPI_Comm_size") {
#if LLVM_VERSION_MAJOR >= 16
    F.setOnlyAccessesInaccessibleMemOrArgMem();
#else
    F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
#endif
    F.addFnAttr(Attribute::NoUnwind);
    F.addFnAttr(Attribute::NoRecurse);
    F.addFnAttr(Attribute::WillReturn);
    F.addFnAttr(Attribute::NoFree);
    F.addFnAttr(Attribute::NoSync);
    if (F.getFunctionType()->getParamType(0)->isPointerTy()) {
      F.addParamAttr(0, Attribute::NoCapture);
      F.addParamAttr(0, Attribute::ReadOnly);
    }
    if (F.getFunctionType()->getParamType(1)->isPointerTy()) {
      F.addParamAttr(1, Attribute::WriteOnly);
      F.addParamAttr(1, Attribute::NoCapture);
    }
  }
  if (F.getName() == "MPI_Wait" || F.getName() == "PMPI_Wait") {
    F.addFnAttr(Attribute::NoUnwind);
    F.addFnAttr(Attribute::NoRecurse);
    F.addFnAttr(Attribute::WillReturn);
    F.addFnAttr(Attribute::NoFree);
    F.addFnAttr(Attribute::NoSync);
    F.addParamAttr(0, Attribute::NoCapture);
    F.addParamAttr(1, Attribute::WriteOnly);
    F.addParamAttr(1, Attribute::NoCapture);
  }
  if (F.getName() == "MPI_Waitall" || F.getName() == "PMPI_Waitall") {
    F.addFnAttr(Attribute::NoUnwind);
    F.addFnAttr(Attribute::NoRecurse);
    F.addFnAttr(Attribute::WillReturn);
    F.addFnAttr(Attribute::NoFree);
    F.addFnAttr(Attribute::NoSync);
    F.addParamAttr(1, Attribute::NoCapture);
    F.addParamAttr(2, Attribute::WriteOnly);
    F.addParamAttr(2, Attribute::NoCapture);
  }
  if (F.getName() == "omp_get_max_threads" ||
      F.getName() == "omp_get_thread_num") {
#if LLVM_VERSION_MAJOR >= 16
    F.setOnlyAccessesInaccessibleMemory();
    F.setOnlyReadsMemory();
#else
    F.addFnAttr(Attribute::InaccessibleMemOnly);
    F.addFnAttr(Attribute::ReadOnly);
#endif
  }
  if (F.getName() == "frexp" || F.getName() == "frexpf" ||
      F.getName() == "frexpl") {
#if LLVM_VERSION_MAJOR >= 16
    F.setOnlyAccessesArgMemory();
#else
    F.addFnAttr(Attribute::ArgMemOnly);
#endif
    F.addParamAttr(1, Attribute::WriteOnly);
  }
  if (F.getName() == "__fd_sincos_1" || F.getName() == "__fd_cos_1" ||
      F.getName() == "__mth_i_ipowi") {
#if LLVM_VERSION_MAJOR >= 16
    // Same "reads-only then writes-only" intersection as above: no memory
    // effects at all.
    F.setOnlyReadsMemory();
    F.setOnlyWritesMemory();
#else
    F.addFnAttr(Attribute::ReadNone);
#endif
  }
}
namespace {
/// Convert a shadow argument of an `__enzyme_autodiff` call to the parameter
/// type expected by the function being differentiated.
///
/// If both the value and \p destType are pointers in different address
/// spaces, an address-space cast is emitted (and a warning is printed to
/// stderr). The result is then bit-cast to \p destType, exactly like the
/// primal-argument path in handleArguments, so that the returned value always
/// has type \p destType.
///
/// \param Builder  insertion point for any cast instructions.
/// \param CI       the `__enzyme_autodiff` call (used for diagnostics).
/// \param FT       type of the differentiated function (diagnostics only).
/// \param destType type the shadow must end up with.
/// \param i        operand index of the shadow in \p CI (diagnostics only).
/// \param mode     derivative mode (currently unused).
/// \param value    the shadow value to convert.
/// \param truei    parameter index in \p FT (diagnostics only).
/// \returns the converted value, or nullptr after emitting an
///          "IllegalArgCast" failure when no lossless cast exists.
static Value *
castToDiffeFunctionArgType(IRBuilder<> &Builder, llvm::CallInst *CI,
                           llvm::FunctionType *FT, llvm::Type *destType,
                           unsigned int i, DerivativeMode mode,
                           llvm::Value *value, unsigned int truei) {
  // Validate inputs before the first dereference.
  assert(value);
  assert(value->getType());
  assert(destType);
  assert(FT);
  auto res = value;
  if (auto ptr = dyn_cast<PointerType>(res->getType())) {
    if (auto PT = dyn_cast<PointerType>(destType)) {
      if (ptr->getAddressSpace() != PT->getAddressSpace()) {
#if LLVM_VERSION_MAJOR < 18
#if LLVM_VERSION_MAJOR >= 15
        if (CI->getContext().supportsTypedPointers()) {
#endif
          // Typed pointers: only the address space changes here; the pointee
          // type is fixed up by the bitcast below.
          res = Builder.CreateAddrSpaceCast(
              res, PointerType::get(ptr->getPointerElementType(),
                                    PT->getAddressSpace()));
#if LLVM_VERSION_MAJOR >= 15
        } else {
          res = Builder.CreateAddrSpaceCast(res, PT);
        }
#endif
#else
        res = Builder.CreateAddrSpaceCast(res, PT);
#endif
        llvm::errs() << "Warning cast(2) __enzyme_autodiff argument " << i
                     << " " << *res << "|" << *res->getType() << " to argument "
                     << truei << " " << *destType << "\n"
                     << "orig: " << *FT << "\n";
        // Intentionally no early return: fall through so the pointee type is
        // reconciled with destType as well.
      }
    }
  }
  if (!res->getType()->canLosslesslyBitCastTo(destType)) {
    auto loc = CI->getDebugLoc();
    if (auto arg = dyn_cast<Instruction>(res)) {
      loc = arg->getDebugLoc();
    }
    EmitFailure("IllegalArgCast", loc, CI,
                "Cannot cast __enzyme_autodiff shadow argument ", i, ", found ",
                *res, ", type ", *res->getType(), " - to arg ", truei, " ",
                *destType);
    return nullptr;
  }
  // A bitcast to the identical type is folded away by IRBuilder.
  return Builder.CreateBitCast(res, destType);
}
// Forward declaration: recursePhiReads (below) calls getMetadataName, and
// getMetadataName in turn calls recursePhiReads for PHI nodes.
#if LLVM_VERSION_MAJOR > 16
static std::optional<StringRef> getMetadataName(llvm::Value *res);
#else
static Optional<StringRef> getMetadataName(llvm::Value *res);
#endif
// Resolve the metadata name of a PHI node: walk the web of PHIs reachable
// through incoming values and, if every non-PHI arm that has a name agrees on
// the same one, report it. Arms without a name are ignored; two different
// names yield "no name".
#if LLVM_VERSION_MAJOR > 16
static std::optional<StringRef> recursePhiReads(PHINode *val)
#else
static Optional<StringRef> recursePhiReads(PHINode *val)
#endif
{
  // Same optional flavour as getMetadataName's return type.
  decltype(getMetadataName(val)) agreed;
  SmallVector<PHINode *, 1> worklist = {val};
  SmallSet<PHINode *, 1> visited;
  while (!worklist.empty()) {
    PHINode *current = worklist.pop_back_val();
    // Each PHI is expanded at most once, which also breaks cycles.
    if (!visited.insert(current).second)
      continue;
    for (Value *incoming : current->incoming_values()) {
      if (auto nested = dyn_cast<PHINode>(incoming)) {
        worklist.push_back(nested);
        continue;
      }
      auto name = getMetadataName(incoming);
      if (!name)
        continue;
      if (!agreed) {
        agreed = name;
        continue;
      }
      if (*agreed != *name)
        return {};
    }
  }
  return agreed;
}
// Try to recover the "marker name" that an operand of an __enzyme_* call
// stands for. Recognised shapes, in order:
//   * a metadata string wrapped as a value,
//   * a load or cast instruction whose operand is a global variable,
//   * a load whose operand is a constant cast of a global variable,
//   * a global variable, or a constant cast of one,
//   * a cast instruction whose operand is an alloca, or an alloca itself,
//   * a PHI node whose named arms all agree (see recursePhiReads).
// Returns an empty optional when none of these apply, including metadata
// that is not a string.
#if LLVM_VERSION_MAJOR > 16
std::optional<StringRef> getMetadataName(llvm::Value *res)
#else
Optional<StringRef> getMetadataName(llvm::Value *res)
#endif
{
  if (auto av = dyn_cast<MetadataAsValue>(res)) {
    // Only string metadata carries a name; anything else is "no name" rather
    // than an invalid cast.
    if (auto str = dyn_cast<MDString>(av->getMetadata()))
      return str->getString();
    return {};
  }
  if (isa<LoadInst>(res) || isa<CastInst>(res)) {
    Value *src = cast<Instruction>(res)->getOperand(0);
    if (auto gv = dyn_cast<GlobalVariable>(src))
      return gv->getName();
    if (isa<LoadInst>(res)) {
      // load (cast-constant-expression @global)
      if (auto ce = dyn_cast<ConstantExpr>(src))
        if (ce->isCast())
          if (auto gv = dyn_cast<GlobalVariable>(ce->getOperand(0)))
            return gv->getName();
    } else {
      // cast (alloca)
      if (auto ai = dyn_cast<AllocaInst>(src))
        return ai->getName();
    }
    return {};
  }
  if (auto gv = dyn_cast<GlobalVariable>(res))
    return gv->getName();
  if (auto ce = dyn_cast<ConstantExpr>(res)) {
    if (ce->isCast())
      if (auto gv = dyn_cast<GlobalVariable>(ce->getOperand(0)))
        return gv->getName();
    return {};
  }
  if (auto ai = dyn_cast<AllocaInst>(res))
    return ai->getName();
  if (auto phi = dyn_cast<PHINode>(res))
    return recursePhiReads(phi);
  return {};
}
/// Reshape a vector-mode derivative result into the struct type of \p ret.
///
/// When \p ret has struct type, the first \p width elements of \p diffret are
/// extracted. A fixed-vector element is scattered lane by lane into
/// consecutive struct slots; any other element is copied into slot `i`.
/// When \p ret is not a struct, \p diffret is returned unchanged.
///
/// \param ret     value whose type determines the target layout.
/// \param diffret aggregate produced by the derivative.
/// \param Builder insertion point for extract/insert instructions.
/// \param width   number of elements to take from \p diffret.
/// \returns the reshaped aggregate (or \p diffret itself).
static Value *adaptReturnedVector(Value *ret, Value *diffret,
                                  IRBuilder<> &Builder, unsigned width) {
  Type *returnType = ret->getType();
  if (StructType *sty = dyn_cast<StructType>(returnType)) {
    Value *agg = ConstantAggregateZero::get(sty);
    for (unsigned int i = 0; i < width; i++) {
      Value *elem = Builder.CreateExtractValue(diffret, {i});
      if (auto vty = dyn_cast<FixedVectorType>(elem->getType())) {
        const unsigned numElems = vty->getNumElements();
        for (unsigned j = 0; j < numElems; ++j) {
          Value *vecelem = Builder.CreateExtractElement(elem, j);
          // Row-major flattening: lane j of element i. (A product i * j would
          // send every lane of element 0, and lane 0 of every element, to
          // slot 0.)
          const unsigned flat = i * numElems + j;
          agg = Builder.CreateInsertValue(agg, vecelem, {flat});
        }
      } else {
        agg = Builder.CreateInsertValue(agg, elem, {i});
      }
    }
    diffret = agg;
  }
  return diffret;
}
/// Substitute the result of the generated derivative call for the original
/// `__enzyme_*` call instruction and erase that instruction.
///
/// Strategies are tried in order: void/empty types, identical types,
/// layout-identical structs, storing through a pointer-typed return slot,
/// a stack round-trip (ReverseModePrimal only), and finally using the first
/// element of an aggregate result. If none applies an "IllegalReturnCast"
/// failure is emitted.
///
/// \returns true when \p CI was erased, false after emitting the failure.
static bool ReplaceOriginalCall(IRBuilder<> &Builder, Value *ret,
                                Type *retElemType, Value *diffret,
                                Instruction *CI, DerivativeMode mode) {
  Type *wantedTy = ret->getType();
  Type *const gradTy = diffret->getType();
  const DataLayout &Layout = CI->getModule()->getDataLayout();

  // Replace every use of the call with V, then drop the call.
  auto replaceCallWith = [CI](Value *V) {
    CI->replaceAllUsesWith(V);
    CI->eraseFromParent();
    return true;
  };
  // Drop the call without touching its uses (result was stored to memory).
  auto dropCall = [CI]() {
    CI->eraseFromParent();
    return true;
  };

  const bool nothingToReturn = gradTy->isEmptyTy() || gradTy->isVoidTy() ||
                               wantedTy->isEmptyTy() || wantedTy->isVoidTy();
  if (nothingToReturn)
    return replaceCallWith(UndefValue::get(CI->getType()));

  if (wantedTy == gradTy)
    return replaceCallWith(diffret);

  {
    auto *wantedST = dyn_cast<StructType>(wantedTy);
    auto *gradST = dyn_cast<StructType>(gradTy);
    if (wantedST && gradST && wantedST->isLayoutIdentical(gradST)) {
      // Rebuild the struct field by field under the expected type.
      Value *rebuilt = UndefValue::get(wantedST);
      for (unsigned int idx = 0; idx < wantedST->getStructNumElements();
           idx++) {
        Value *field = Builder.CreateExtractValue(diffret, {idx});
        rebuilt = Builder.CreateInsertValue(rebuilt, field, {idx});
      }
      return replaceCallWith(rebuilt);
    }
  }

  if (isa<PointerType>(wantedTy)) {
    // From here on the "wanted" type is what the return slot points to.
    wantedTy = retElemType;
    auto *wantedST = dyn_cast<StructType>(wantedTy);
    auto *gradST = dyn_cast<StructType>(gradTy);
    if (wantedST && gradST && wantedST->isLayoutIdentical(gradST)) {
      for (unsigned int idx = 0; idx < wantedST->getStructNumElements();
           idx++) {
        Value *slot = Builder.CreateStructGEP(wantedTy, ret, idx);
        Value *field = Builder.CreateExtractValue(diffret, {idx});
        Builder.CreateStore(field, slot);
      }
      return dropCall();
    }
    if (Layout.getTypeSizeInBits(wantedTy) >=
        Layout.getTypeSizeInBits(gradTy)) {
      Value *slot =
          Builder.CreatePointerCast(ret, PointerType::getUnqual(gradTy));
      Builder.CreateStore(diffret, slot);
      return dropCall();
    }
  }

  const bool primalOnly = mode == DerivativeMode::ReverseModePrimal;
  if (primalOnly) {
    if (Layout.getTypeSizeInBits(wantedTy) >=
        Layout.getTypeSizeInBits(gradTy)) {
      // Round-trip through a stack slot created in the entry block.
      IRBuilder<> EntryBuilder(
          CI->getFunction()->getEntryBlock().getFirstNonPHI());
      auto scratch = EntryBuilder.CreateAlloca(wantedTy);
      Value *slot =
          Builder.CreatePointerCast(scratch, PointerType::getUnqual(gradTy));
      Builder.CreateStore(diffret, slot);
      Value *reloaded = Builder.CreateLoad(wantedTy, scratch);
      return replaceCallWith(reloaded);
    }
  } else if (gradTy->isAggregateType()) {
    auto first = Builder.CreateExtractValue(diffret, {0});
    if (first->getType() == wantedTy)
      return replaceCallWith(first);
    if (gradTy == wantedTy)
      return replaceCallWith(diffret);
  }

  EmitFailure("IllegalReturnCast", CI->getDebugLoc(), CI,
              "Cannot cast return type of gradient ", *gradTy, *diffret,
              ", to desired type ", *wantedTy);
  return false;
}
class EnzymeBase {
public:
EnzymeLogic Logic;
// Builds the EnzymeLogic member. If "enzyme-postopt" was given on the command
// line (getNumOccurrences() != 0) its value is used; otherwise the PostOpt
// argument is used.
EnzymeBase(bool PostOpt)
: Logic(EnzymePostOpt.getNumOccurrences() ? EnzymePostOpt : PostOpt) {
// initializeLowerAutodiffIntrinsicPass(*PassRegistry::getPassRegistry());
}
// Determine which function an __enzyme_* call asks to differentiate.
// The candidate is call operand 0, or operand 1 when the call carries an
// sret attribute. It is resolved through GetFunctionFromValue. Emits a
// failure and returns nullptr if no function is found or if the function has
// no body.
Function *parseFunctionParameter(CallInst *CI, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI) {
  const unsigned fnOperand = CI->hasStructRetAttr() ? 1 : 0;
  Value *requested = CI->getArgOperand(fnOperand);
  Value *resolved = GetFunctionFromValue(requested, AA, TLI);

  auto *target = dyn_cast_or_null<Function>(resolved);
  if (target == nullptr) {
    assert(requested);
    EmitFailure("NoFunctionToDifferentiate", CI->getDebugLoc(), CI,
                "failed to find fn to differentiate", *CI, " - found - ",
                *requested);
    return nullptr;
  }

  // A declaration without a body cannot be differentiated.
  if (target->empty()) {
    EmitFailure("EmptyFunctionToDifferentiate", CI->getDebugLoc(), CI,
                "failed to find fn to differentiate", *CI, " - found - ",
                *resolved);
    return nullptr;
  }
  return target;
}
// Scan the operands of CI for the "enzyme_width" marker and return the
// constant vector width that follows it. Returns 1 when no marker is present.
// Emits a failure and returns an empty optional when the marker appears more
// than once, is the last operand, is not followed by a constant integer, or
// is followed by a width of zero.
#if LLVM_VERSION_MAJOR > 16
static std::optional<unsigned> parseWidthParameter(CallInst *CI)
#else
static Optional<unsigned> parseWidthParameter(CallInst *CI)
#endif
{
  unsigned width = 1;
  bool found = false;
#if LLVM_VERSION_MAJOR >= 14
  const unsigned numArgs = CI->arg_size();
#else
  const unsigned numArgs = CI->getNumArgOperands();
#endif
  for (unsigned i = 0; i < numArgs; ++i) {
    Value *arg = CI->getArgOperand(i);
    auto MDName = getMetadataName(arg);
    if (!MDName || *MDName != "enzyme_width")
      continue;

    if (found) {
      EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI,
                  "vector width declared more than once",
                  *CI->getArgOperand(i), " in", *CI);
      return {};
    }
    if (i + 1 >= numArgs) {
      EmitFailure("MissingVectorWidth", CI->getDebugLoc(), CI,
                  "constant integer followong enzyme_width is missing",
                  *CI->getArgOperand(i), " in", *CI);
      return {};
    }

    Value *width_arg = CI->getArgOperand(i + 1);
    auto cint = dyn_cast<ConstantInt>(width_arg);
    if (!cint) {
      EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI,
                  "enzyme_width must be a constant integer",
                  *CI->getArgOperand(i), " in", *CI);
      return {};
    }
    width = cint->getZExtValue();
    // A width of zero would produce no shadow values at all downstream, so
    // it is rejected here.
    if (width == 0) {
      EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI,
                  "illegal enzyme vector argument width ", *width_arg, " in",
                  *CI);
      return {};
    }
    found = true;
  }
  return width;
}
// Settings extracted from the operands of an __enzyme_* call by
// handleArguments.
struct Options {
// Return-derivative seed taken from an extra trailing operand (only in
// ReverseModeGradient / ForwardModeSplit with a differential return);
// nullptr otherwise.
Value *differet;
// Tape value: the operand after "enzyme_tape", or an extra trailing operand;
// nullptr if none was given.
Value *tape;
// Operand following "enzyme_interface", or nullptr.
Value *dynamic_interface;
// Operand following "enzyme_trace" / "enzyme_duptrace", or nullptr.
Value *trace;
// Operand following "enzyme_observations", or nullptr.
Value *observations;
// Operand following "enzyme_likelihood" / "enzyme_duplikelihood", or nullptr.
Value *likelihood;
// Second operand following "enzyme_duplikelihood", or nullptr.
Value *diffeLikelihood;
// Vector width from "enzyme_width" (1 when not specified).
unsigned width;
// Constant following "enzyme_allocated", or -1 when not specified.
int allocatedTapeSize;
// True unless "enzyme_nofree" was passed.
bool freeMemory;
// True when the function returns a non-void, non-empty type and
// "enzyme_noret" was not passed.
bool returnUsed;
// True when the tape was supplied via "enzyme_tape".
bool tapeIsPointer;
// True in ReverseModeCombined / ReverseModeGradient when the return activity
// is OUT_DIFF.
bool differentialReturn;
// True when the trace was supplied via "enzyme_duptrace".
bool diffeTrace;
// Activity of the return value: whatType(...) of the return type unless
// overridden by an "enzyme_*_return" marker.
DIFFE_TYPE retType;
// True when "enzyme_primal_return" was passed.
bool primalReturn;
// Constant strings collected from "enzyme_active_rand_var" operands.
StringSet<> ActiveRandomVariables;
};
// Parse the operands of an __enzyme_* call into the argument list, activity
// list, and option set needed to build the derivative call.
//
// Operands are read left to right after the function operand (and after the
// sret slot, if any). "enzyme_*" markers select activity or options and may
// consume the operands that follow them; every other operand is matched
// against the next parameter of `fn`, cast to the parameter type if needed,
// and followed by its shadow operand(s) when its activity is DUP_ARG or
// DUP_NONEED.
//
//   Builder   - insertion point for any casts/loads that must be emitted.
//   CI        - the __enzyme_* call being lowered.
//   fn        - the function to differentiate.
//   mode      - derivative mode.
//   sizeOnly  - when true, activity markers only record the activity and the
//               marked operand itself is not consumed.
//   constants - [out] activity of each primal parameter, in order.
//   args      - [out] primal and shadow values for the derivative call.
//   byVal     - [out] maps an index in `args` to the byval type of the
//               corresponding call operand.
//
// Returns the collected Options, or an empty optional after emitting a
// failure diagnostic.
#if LLVM_VERSION_MAJOR > 16
static std::optional<Options>
handleArguments(IRBuilder<> &Builder, CallInst *CI, Function *fn,
                DerivativeMode mode, bool sizeOnly,
                std::vector<DIFFE_TYPE> &constants,
                SmallVectorImpl<Value *> &args, std::map<int, Type *> &byVal)
#else
static Optional<Options>
handleArguments(IRBuilder<> &Builder, CallInst *CI, Function *fn,
                DerivativeMode mode, bool sizeOnly,
                std::vector<DIFFE_TYPE> &constants,
                SmallVectorImpl<Value *> &args, std::map<int, Type *> &byVal)
#endif
{
  FunctionType *FT = fn->getFunctionType();
  Value *differet = nullptr;
  Value *tape = nullptr;
  Value *dynamic_interface = nullptr;
  Value *trace = nullptr;
  Value *observations = nullptr;
  Value *likelihood = nullptr;
  Value *diffeLikelihood = nullptr;
  unsigned width = 1;
  int allocatedTapeSize = -1;
  bool freeMemory = true;
  bool tapeIsPointer = false;
  bool diffeTrace = false;
  // Index of the next parameter of `fn` to be matched (as opposed to `i`,
  // which indexes the operands of CI).
  unsigned truei = 0;
  // Non-zero while an "enzyme_byref" size is pending for the next argument.
  unsigned byRefSize = 0;
  bool primalReturn = false;
  StringSet<> ActiveRandomVariables;
  DIFFE_TYPE retType = whatType(fn->getReturnType(), mode);
  bool returnUsed =
      !fn->getReturnType()->isVoidTy() && !fn->getReturnType()->isEmptyTy();
  bool sret = CI->hasStructRetAttr() ||
              fn->hasParamAttribute(0, Attribute::StructRet);
  // First pass: only look for explicit return-activity markers, because
  // differentialReturn (computed below) is needed by the main pass.
#if LLVM_VERSION_MAJOR >= 14
  for (unsigned i = 1 + sret; i < CI->arg_size(); ++i)
#else
  for (unsigned i = 1 + sret; i < CI->getNumArgOperands(); ++i)
#endif
  {
    Value *res = CI->getArgOperand(i);
    auto metaString = getMetadataName(res);
    // handle metadata
    if (metaString && metaString->startswith("enzyme_")) {
      if (*metaString == "enzyme_const_return") {
        retType = DIFFE_TYPE::CONSTANT;
        continue;
      } else if (*metaString == "enzyme_active_return") {
        retType = DIFFE_TYPE::OUT_DIFF;
        continue;
      } else if (*metaString == "enzyme_dup_return") {
        retType = DIFFE_TYPE::DUP_ARG;
        continue;
      }
    }
  }
  bool differentialReturn = (mode == DerivativeMode::ReverseModeCombined ||
                             mode == DerivativeMode::ReverseModeGradient) &&
                            (retType == DIFFE_TYPE::OUT_DIFF);
  // find and handle enzyme_width
  if (auto parsedWidth = parseWidthParameter(CI)) {
    width = *parsedWidth;
  } else {
    return {};
  }
  // handle different argument order for struct return.
  if (fn->hasParamAttribute(0, Attribute::StructRet)) {
    truei = 1;
    const DataLayout &DL = CI->getParent()->getModule()->getDataLayout();
    Type *Ty = nullptr;
#if LLVM_VERSION_MAJOR >= 12
    Ty = fn->getParamAttribute(0, Attribute::StructRet).getValueAsType();
#else
    Type *fnsrety = cast<PointerType>(FT->getParamType(0));
    Ty = fnsrety->getPointerElementType();
#endif
    Type *CTy = nullptr;
#if LLVM_VERSION_MAJOR >= 12
    CTy = CI->getAttribute(AttributeList::FirstArgIndex, Attribute::StructRet)
              .getValueAsType();
#else
    CTy = cast<PointerType>(CI->getArgOperand(0)->getType())
              ->getPointerElementType();
#endif
    // Fresh stack slot that receives the primal struct return.
    AllocaInst *primal = new AllocaInst(Ty, DL.getAllocaAddrSpace(), nullptr,
                                        DL.getPrefTypeAlign(Ty));
    primal->insertBefore(CI);
    // Initialised so that a mode not covered by the switch below cannot read
    // an indeterminate pointer.
    Value *shadow = nullptr;
    switch (mode) {
    case DerivativeMode::ForwardModeSplit:
    case DerivativeMode::ForwardMode: {
      Value *sretPt = CI->getArgOperand(0);
      if (width > 1) {
        PointerType *pty = cast<PointerType>(sretPt->getType());
        if (auto sty = dyn_cast<StructType>(CTy)) {
          // One pointer per vector lane, each addressing a field of the
          // caller's sret struct.
          Value *acc = UndefValue::get(
              ArrayType::get(PointerType::get(sty->getElementType(0),
                                              pty->getAddressSpace()),
                             width));
          for (size_t i = 0; i < width; ++i) {
            Value *elem = Builder.CreateStructGEP(sty, sretPt, i);
            acc = Builder.CreateInsertValue(acc, elem, i);
          }
          shadow = acc;
        } else {
          EmitFailure(
              "IllegalReturnType", CI->getDebugLoc(), CI,
              "Return type of __enzyme_autodiff has to be a struct with",
              width, "elements of the same type.");
          return {};
        }
      } else {
        shadow = sretPt;
      }
      break;
    }
    case DerivativeMode::ReverseModePrimal:
    case DerivativeMode::ReverseModeCombined:
    case DerivativeMode::ReverseModeGradient: {
      shadow = CI->getArgOperand(1);
      sret = true;
      break;
    }
    }
    args.push_back(primal);
    args.push_back(shadow);
    constants.push_back(DIFFE_TYPE::DUP_ARG);
  }
  // Main pass over the remaining operands.
#if LLVM_VERSION_MAJOR >= 14
  for (unsigned i = 1 + sret; i < CI->arg_size(); ++i)
#else
  for (unsigned i = 1 + sret; i < CI->getNumArgOperands(); ++i)
#endif
  {
    Value *res = CI->getArgOperand(i);
    auto metaString = getMetadataName(res);
#if LLVM_VERSION_MAJOR > 16
    std::optional<Value *> batchOffset;
    std::optional<DIFFE_TYPE> opt_ty;
#else
    Optional<Value *> batchOffset;
    Optional<DIFFE_TYPE> opt_ty;
#endif
    // handle metadata
    if (metaString && metaString->startswith("enzyme_")) {
      if (*metaString == "enzyme_byref") {
        ++i;
        if (!isa<ConstantInt>(CI->getArgOperand(i))) {
          EmitFailure("IllegalAllocatedSize", CI->getDebugLoc(), CI,
                      "illegal enzyme byref size ", *CI->getArgOperand(i),
                      "in", *CI);
          return {};
        }
        byRefSize = cast<ConstantInt>(CI->getArgOperand(i))->getZExtValue();
        assert(byRefSize > 0);
        continue;
      }
      if (*metaString == "enzyme_dup") {
        opt_ty = DIFFE_TYPE::DUP_ARG;
      } else if (*metaString == "enzyme_dupv") {
        opt_ty = DIFFE_TYPE::DUP_ARG;
        ++i;
        Value *offset_arg = CI->getArgOperand(i);
        if (offset_arg->getType()->isIntegerTy()) {
          batchOffset = offset_arg;
        } else {
          EmitFailure("IllegalVectorOffset", CI->getDebugLoc(), CI,
                      "enzyme_batch must be followd by an integer "
                      "offset.",
                      *CI->getArgOperand(i), " in", *CI);
          return {};
        }
      } else if (*metaString == "enzyme_dupnoneed") {
        opt_ty = DIFFE_TYPE::DUP_NONEED;
      } else if (*metaString == "enzyme_dupnoneedv") {
        opt_ty = DIFFE_TYPE::DUP_NONEED;
        ++i;
        Value *offset_arg = CI->getArgOperand(i);
        if (offset_arg->getType()->isIntegerTy()) {
          batchOffset = offset_arg;
        } else {
          EmitFailure("IllegalVectorOffset", CI->getDebugLoc(), CI,
                      "enzyme_batch must be followd by an integer "
                      "offset.",
                      *CI->getArgOperand(i), " in", *CI);
          return {};
        }
      } else if (*metaString == "enzyme_out") {
        opt_ty = DIFFE_TYPE::OUT_DIFF;
      } else if (*metaString == "enzyme_const") {
        opt_ty = DIFFE_TYPE::CONSTANT;
      } else if (*metaString == "enzyme_noret") {
        returnUsed = false;
        continue;
      } else if (*metaString == "enzyme_allocated") {
        assert(!sizeOnly);
        ++i;
        if (!isa<ConstantInt>(CI->getArgOperand(i))) {
          EmitFailure("IllegalAllocatedSize", CI->getDebugLoc(), CI,
                      "illegal enzyme allocated size ", *CI->getArgOperand(i),
                      "in", *CI);
          return {};
        }
        allocatedTapeSize =
            cast<ConstantInt>(CI->getArgOperand(i))->getZExtValue();
        continue;
      } else if (*metaString == "enzyme_tape") {
        assert(!sizeOnly);
        ++i;
        tape = CI->getArgOperand(i);
        tapeIsPointer = true;
        continue;
      } else if (*metaString == "enzyme_nofree") {
        assert(!sizeOnly);
        freeMemory = false;
        continue;
      } else if (*metaString == "enzyme_primal_return") {
        primalReturn = true;
        continue;
      } else if (*metaString == "enzyme_const_return") {
        // Return-activity markers were consumed by the first pass.
        continue;
      } else if (*metaString == "enzyme_active_return") {
        continue;
      } else if (*metaString == "enzyme_dup_return") {
        continue;
      } else if (*metaString == "enzyme_width") {
        // Already parsed by parseWidthParameter; skip the width constant.
        ++i;
        continue;
      } else if (*metaString == "enzyme_interface") {
        ++i;
        dynamic_interface = CI->getArgOperand(i);
        continue;
      } else if (*metaString == "enzyme_trace") {
        trace = CI->getArgOperand(++i);
        opt_ty = DIFFE_TYPE::CONSTANT;
        continue;
      } else if (*metaString == "enzyme_duptrace") {
        trace = CI->getArgOperand(++i);
        diffeTrace = true;
        opt_ty = DIFFE_TYPE::CONSTANT;
        continue;
      } else if (*metaString == "enzyme_likelihood") {
        likelihood = CI->getArgOperand(++i);
        opt_ty = DIFFE_TYPE::CONSTANT;
        continue;
      } else if (*metaString == "enzyme_duplikelihood") {
        likelihood = CI->getArgOperand(++i);
        diffeLikelihood = CI->getArgOperand(++i);
        opt_ty = DIFFE_TYPE::DUP_ARG;
        continue;
      } else if (*metaString == "enzyme_observations") {
        observations = CI->getArgOperand(++i);
        opt_ty = DIFFE_TYPE::CONSTANT;
        continue;
      } else if (*metaString == "enzyme_active_rand_var") {
        Value *string = CI->getArgOperand(++i);
        StringRef const_string;
        if (getConstantStringInfo(string, const_string)) {
          ActiveRandomVariables.insert(const_string);
        } else {
          EmitFailure(
              "IllegalStringType", CI->getDebugLoc(), CI,
              "active variable address must be a compile-time constant", *CI,
              *metaString);
        }
        continue;
      } else {
        EmitFailure("IllegalDiffeType", CI->getDebugLoc(), CI,
                    "illegal enzyme metadata classification ", *CI,
                    *metaString);
        return {};
      }
      // Only activity markers reach this point.
      if (sizeOnly) {
        assert(opt_ty);
        constants.push_back(*opt_ty);
        truei++;
        continue;
      }
      // Advance from the marker to the operand it annotates.
      ++i;
      res = CI->getArgOperand(i);
    }
    if (byRefSize) {
      Type *subTy = nullptr;
      if (truei < FT->getNumParams()) {
        // Index by the parameter position (truei), not the call-operand
        // position (i): the two diverge as soon as a marker or shadow
        // operand has been consumed.
        subTy = FT->getParamType(truei);
      } else if ((mode == DerivativeMode::ReverseModeGradient ||
                  mode == DerivativeMode::ForwardModeSplit)) {
        if (differentialReturn && differet == nullptr) {
          subTy = FT->getReturnType();
        }
      }
      if (!subTy) {
        EmitFailure("IllegalByVal", CI->getDebugLoc(), CI,
                    "illegal enzyme byval arg", truei, " ", *res);
        return {};
      }
      auto &DL = fn->getParent()->getDataLayout();
      auto BitSize = DL.getTypeSizeInBits(subTy);
      if (BitSize / 8 != byRefSize) {
        EmitFailure("IllegalByRefSize", CI->getDebugLoc(), CI,
                    "illegal enzyme pointer type size ", *res, " expected ",
                    byRefSize, " (bytes) actual size ", BitSize,
                    " (bits) in ", *CI);
      }
      // Load the by-reference argument so the rest of the loop sees a value.
      res = Builder.CreateBitCast(
          res,
          PointerType::get(
              subTy, cast<PointerType>(res->getType())->getAddressSpace()));
      res = Builder.CreateLoad(subTy, res);
      byRefSize = 0;
    }
    if (truei >= FT->getNumParams()) {
      // All parameters are matched: an extra operand may only be the return
      // seed or the tape, and only in the split modes.
      if (!isa<MetadataAsValue>(res) &&
          (mode == DerivativeMode::ReverseModeGradient ||
           mode == DerivativeMode::ForwardModeSplit)) {
        if (differentialReturn && differet == nullptr) {
          differet = res;
          if (CI->paramHasAttr(i, Attribute::ByVal)) {
            Type *T = nullptr;
#if LLVM_VERSION_MAJOR > 12
            T = CI->getParamAttr(i, Attribute::ByVal).getValueAsType();
#else
            T = differet->getType()->getPointerElementType();
#endif
            differet = Builder.CreateLoad(T, differet);
          }
          // Layout-identical structs are reinterpreted through a stack slot.
          if (differet->getType() != fn->getReturnType())
            if (auto ST0 = dyn_cast<StructType>(differet->getType()))
              if (auto ST1 = dyn_cast<StructType>(fn->getReturnType()))
                if (ST0->isLayoutIdentical(ST1)) {
                  IRBuilder<> B(&Builder.GetInsertBlock()
                                     ->getParent()
                                     ->getEntryBlock()
                                     .front());
                  auto AI = B.CreateAlloca(ST1);
                  Builder.CreateStore(differet,
                                      Builder.CreatePointerCast(
                                          AI, PointerType::getUnqual(ST0)));
                  differet = Builder.CreateLoad(ST1, AI);
                }
          if (differet->getType() != fn->getReturnType()) {
            EmitFailure("BadDiffRet", CI->getDebugLoc(), CI,
                        "Bad DiffRet type ", *differet, " expected ",
                        *fn->getReturnType());
            return {};
          }
          continue;
        } else if (tape == nullptr) {
          tape = res;
          if (CI->paramHasAttr(i, Attribute::ByVal)) {
            Type *T = nullptr;
#if LLVM_VERSION_MAJOR > 12
            T = CI->getParamAttr(i, Attribute::ByVal).getValueAsType();
#else
            T = tape->getType()->getPointerElementType();
#endif
            tape = Builder.CreateLoad(T, tape);
          }
          continue;
        }
      }
      EmitFailure("TooManyArgs", CI->getDebugLoc(), CI,
                  "Had too many arguments to __enzyme_autodiff", *CI,
                  " - extra arg - ", *res);
      return {};
    }
    assert(truei < FT->getNumParams());
    auto PTy = FT->getParamType(truei);
    // Explicit marker wins; otherwise derive the activity from the type.
    DIFFE_TYPE ty = opt_ty ? *opt_ty : whatType(PTy, mode);
    constants.push_back(ty);
    assert(truei < FT->getNumParams());
    // cast primal
    if (PTy != res->getType()) {
      if (auto ptr = dyn_cast<PointerType>(res->getType())) {
        if (auto PT = dyn_cast<PointerType>(PTy)) {
          if (ptr->getAddressSpace() != PT->getAddressSpace()) {
#if LLVM_VERSION_MAJOR < 18
#if LLVM_VERSION_MAJOR >= 15
            if (CI->getContext().supportsTypedPointers()) {
#endif
              res = Builder.CreateAddrSpaceCast(
                  res, PointerType::get(ptr->getPointerElementType(),
                                        PT->getAddressSpace()));
#if LLVM_VERSION_MAJOR >= 15
            } else {
              res = Builder.CreateAddrSpaceCast(res, PT);
            }
#endif
#else
            res = Builder.CreateAddrSpaceCast(res, PT);
#endif
            assert(res);
            assert(PTy);
            assert(FT);
            llvm::errs() << "Warning cast(1) __enzyme_autodiff argument " << i
                         << " " << *res << "|" << *res->getType()
                         << " to argument " << truei << " " << *PTy << "\n"
                         << "orig: " << *FT << "\n";
          }
        }
      }
      if (res->getType()->canLosslesslyBitCastTo(PTy)) {
        res = Builder.CreateBitCast(res, PTy);
      }
      // Wider integers passed for an i1 parameter are truncated.
      if (res->getType() != PTy && res->getType()->isIntegerTy() &&
          PTy->isIntegerTy(1)) {
        res = Builder.CreateTrunc(res, PTy);
      }
      if (res->getType() != PTy) {
        auto loc = CI->getDebugLoc();
        if (auto arg = dyn_cast<Instruction>(res)) {
          loc = arg->getDebugLoc();
        }
        EmitFailure("IllegalArgCast", loc, CI,
                    "Cannot cast __enzyme_autodiff primal argument ", i,
                    ", found ", *res, ", type ", *res->getType(),
                    " - to arg ", truei, " ", *PTy);
        return {};
      }
    }
    if (CI->isByValArgument(i)) {
      byVal[args.size()] = CI->getParamByValType(i);
    }
    args.push_back(res);
    if (ty == DIFFE_TYPE::DUP_ARG || ty == DIFFE_TYPE::DUP_NONEED) {
      // Collect the shadow: a single value for width 1, otherwise an array of
      // `width` values taken from consecutive operands, or (batched) from one
      // pointer operand offset by v * batchOffset bytes.
      ++i;
      Value *res = nullptr;
#if LLVM_VERSION_MAJOR >= 16
      bool batch = batchOffset.has_value();
#else
      bool batch = batchOffset.hasValue();
#endif
      for (unsigned v = 0; v < width; ++v) {
#if LLVM_VERSION_MAJOR >= 14
        if (i >= CI->arg_size())
#else
        if (i >= CI->getNumArgOperands())
#endif
        {
          EmitFailure("MissingArgShadow", CI->getDebugLoc(), CI,
                      "__enzyme_autodiff missing argument shadow at index ",
                      i, ", need shadow of type ", *PTy,
                      " to shadow primal argument ", *args.back(),
                      " at call ", *CI);
          return {};
        }
        // cast diffe
        Value *element = CI->getArgOperand(i);
        if (batch) {
          if (auto elementPtrTy = dyn_cast<PointerType>(element->getType())) {
            element = Builder.CreateBitCast(
                element, PointerType::get(Type::getInt8Ty(CI->getContext()),
                                          elementPtrTy->getAddressSpace()));
            element = Builder.CreateGEP(
                Type::getInt8Ty(CI->getContext()), element,
                Builder.CreateMul(
                    *batchOffset,
                    ConstantInt::get((*batchOffset)->getType(), v)));
            element = Builder.CreateBitCast(element, elementPtrTy);
          } else {
            EmitFailure(
                "NonPointerBatch", CI->getDebugLoc(), CI,
                "Batched argument at index ", i,
                " must be of pointer type, found: ", *element->getType());
            return {};
          }
        }
        if (PTy != element->getType()) {
          element = castToDiffeFunctionArgType(Builder, CI, FT, PTy, i, mode,
                                               element, truei);
          if (!element) {
            return {};
          }
        }
        if (width > 1) {
          res =
              res ? Builder.CreateInsertValue(res, element, {v})
                  : Builder.CreateInsertValue(UndefValue::get(ArrayType::get(
                                                  element->getType(), width)),
                                              element, {v});
          // Non-batched shadows occupy one call operand per lane.
          if (v < width - 1 && !batch) {
            ++i;
          }
        } else {
          res = element;
        }
      }
      args.push_back(res);
    }
    ++truei;
  }
  if (truei < FT->getNumParams()) {
    auto numParams = FT->getNumParams();
    EmitFailure(
        "EnzymeInsufficientArgs", CI->getDebugLoc(), CI,
        "Insufficient number of args passed to derivative call required ",
        numParams, " primal args, found ", truei);
    return {};
  }
  return Options({differet, tape, dynamic_interface, trace, observations,
                  likelihood, diffeLikelihood, width, allocatedTapeSize,
                  freeMemory, returnUsed, tapeIsPointer, differentialReturn,
                  diffeTrace, retType, primalReturn, ActiveRandomVariables});
}
/// Build the initial type information for \p fn from the LLVM types of its
/// arguments and return value, run type analysis on it, and return the
/// analyzed result.
///
/// One flag per argument is appended to \p overwritten_args: `false` when
/// \p mode is ReverseModeCombined and `true` for every other mode.
static FnTypeInfo
populate_overwritten_args(TypeAnalysis &TA, llvm::Function *fn,
                          DerivativeMode mode,
                          std::vector<bool> &overwritten_args) {
  FnTypeInfo type_args(fn);
  // The flag does not depend on the individual argument, only on the mode.
  const bool markOverwritten = mode != DerivativeMode::ReverseModeCombined;

  for (auto &arg : type_args.Function->args()) {
    overwritten_args.push_back(markOverwritten);

    auto *argTy = arg.getType();
    TypeTree argTree;
    if (argTy->isFPOrFPVectorTy()) {
      argTree = ConcreteType(argTy->getScalarType());
    } else if (argTy->isPointerTy()) {
      // With typed pointers (only possible before LLVM 18) the pointee type
      // gives an extra hint about what the pointer refers to.
#if LLVM_VERSION_MAJOR < 18
#if LLVM_VERSION_MAJOR >= 15
      if (arg.getContext().supportsTypedPointers()) {
#endif
        auto pointeeTy = argTy->getPointerElementType();
        if (pointeeTy->isFPOrFPVectorTy()) {
          argTree = TypeTree(ConcreteType(pointeeTy->getScalarType()))
                        .Only(-1, nullptr);
        } else if (pointeeTy->isPointerTy()) {
          argTree = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr);
        }
#if LLVM_VERSION_MAJOR >= 15
      }
#endif
#endif
      // The argument itself is always recorded as a pointer.
      argTree.insert({}, BaseType::Pointer);
    } else if (argTy->isIntOrIntVectorTy()) {
      argTree = ConcreteType(BaseType::Integer);
    }

    type_args.Arguments.insert(
        std::pair<Argument *, TypeTree>(&arg, argTree.Only(-1, nullptr)));
    // TODO note that here we do NOT propagate constants in type info (and
    // should consider whether we should)
    type_args.KnownValues.insert(
        std::pair<Argument *, std::set<int64_t>>(&arg, {}));
  }

  // Only floating-point return types are seeded; everything else starts
  // from an empty tree.
  TypeTree retTree;
  if (fn->getReturnType()->isFPOrFPVectorTy()) {
    retTree = ConcreteType(fn->getReturnType()->getScalarType());
  }
  type_args.Return = retTree.Only(-1, nullptr);

  return TA.analyzeFunction(type_args).getAnalyzedTypeInfo();
}
/// Lower one `__enzyme_batch` call: build the batched variant of the target
/// function through Logic.CreateBatch and replace \p CI with a call to it.
///
/// The operands following the function (and the optional sret pointer) are
/// walked in order. An operand may be introduced by a marker:
///   - "enzyme_scalar": the next operand is passed through unchanged.
///   - "enzyme_vector": the next `width` operands are packed into one array
///                      (a single operand is used as-is when width == 1).
///   - "enzyme_buffer": the next operand must be an integer; it is recorded
///                      as a byte offset used to derive the `width` elements
///                      from one pointer operand.
///   - "enzyme_width":  the marker and its value are skipped here (the width
///                      is read up front by parseWidthParameter).
/// Unmarked operands are SCALAR when width == 1 and VECTOR otherwise.
///
/// \param CI  the `__enzyme_batch` call to lower.
/// \param AA  alias analysis, forwarded to parseFunctionParameter.
/// \param TLI target library info, forwarded to parseFunctionParameter.
/// \returns true when \p CI was replaced, false on any failure.
bool HandleBatch(CallInst *CI, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI) {
  unsigned width = 1;
  unsigned truei = 0;
  std::map<unsigned, Value *> batchOffset;
  SmallVector<Value *, 4> args;
  SmallVector<BATCH_TYPE, 4> arg_types;
  IRBuilder<> Builder(CI);

  Function *F = parseFunctionParameter(CI, AA, TLI);
  if (!F)
    return false;
  FunctionType *FT = F->getFunctionType();

  // find and handle enzyme_width
  if (auto parsedWidth = parseWidthParameter(CI)) {
    width = *parsedWidth;
  } else {
    return false;
  }

  // The operand count is fixed for the lifetime of this function; query it
  // once instead of repeating the version switch at every use.
#if LLVM_VERSION_MAJOR >= 14
  const unsigned numArgs = CI->arg_size();
#else
  const unsigned numArgs = CI->getNumArgOperands();
#endif

  // handle different argument order for struct return.
  bool sret =
      CI->hasStructRetAttr() || F->hasParamAttribute(0, Attribute::StructRet);
  if (F->hasParamAttribute(0, Attribute::StructRet)) {
    truei = 1;
    Value *sretPt = CI->getArgOperand(0);
    args.push_back(sretPt);
    arg_types.push_back(BATCH_TYPE::VECTOR);
  }

  for (unsigned i = 1 + sret; i < numArgs; ++i) {
    Value *res = CI->getArgOperand(i);
    if (truei >= FT->getNumParams()) {
      EmitFailure("TooManyArgs", CI->getDebugLoc(), CI,
                  "Had too many arguments to __enzyme_batch", *CI,
                  " - extra arg - ", *res);
      return false;
    }
    auto PTy = FT->getParamType(truei);
    BATCH_TYPE ty = width == 1 ? BATCH_TYPE::SCALAR : BATCH_TYPE::VECTOR;
    auto metaString = getMetadataName(res);

    // handle metadata
    if (metaString && metaString->startswith("enzyme_")) {
      if (*metaString == "enzyme_scalar") {
        ty = BATCH_TYPE::SCALAR;
      } else if (*metaString == "enzyme_vector") {
        ty = BATCH_TYPE::VECTOR;
      } else if (*metaString == "enzyme_buffer") {
        ++i;
        // A trailing "enzyme_buffer" has no offset operand to read.
        if (i >= numArgs) {
          EmitFailure("MissingVectorOffset", CI->getDebugLoc(), CI,
                      "__enzyme_batch missing integer offset after "
                      "enzyme_buffer at index ",
                      i, " at call ", *CI);
          return false;
        }
        Value *offset_arg = CI->getArgOperand(i);
        if (offset_arg->getType()->isIntegerTy()) {
          // NOTE(review): the offset is stored under the index *after* the
          // operand that follows it, while the lookup below uses `i - 1`
          // of the operand being packed. As written, the two only meet when
          // one more marker operand (e.g. "enzyme_vector") sits between the
          // offset and the pointer. Confirm this is the intended layout.
          batchOffset[i + 1] = offset_arg;
        } else {
          EmitFailure("IllegalVectorOffset", CI->getDebugLoc(), CI,
                      "enzyme_batch must be followd by an integer "
                      "offset.",
                      *CI->getArgOperand(i), " in", *CI);
          return false;
        }
        continue;
      } else if (*metaString == "enzyme_width") {
        // Skip the width value as well; it was consumed before the loop.
        ++i;
        continue;
      } else {
        EmitFailure("IllegalDiffeType", CI->getDebugLoc(), CI,
                    "illegal enzyme metadata classification ", *CI,
                    *metaString);
        return false;
      }

      // Step over the marker to the operand it describes. A trailing marker
      // would otherwise read past the end of the operand list.
      ++i;
      if (i >= numArgs) {
        EmitFailure("MissingVectorArg", CI->getDebugLoc(), CI,
                    "__enzyme_batch missing argument after ", *metaString,
                    " at index ", i, " at call ", *CI);
        return false;
      }
      res = CI->getArgOperand(i);
    }

    arg_types.push_back(ty);

    // wrap vector
    if (ty == BATCH_TYPE::VECTOR) {
      Value *packed = nullptr;
      auto offsetIt = batchOffset.find(i - 1);
      const bool batch = offsetIt != batchOffset.end();

      for (unsigned v = 0; v < width; ++v) {
        if (i >= numArgs) {
          EmitFailure("MissingVectorArg", CI->getDebugLoc(), CI,
                      "__enzyme_batch missing vector argument at index ", i,
                      ", need argument of type ", *PTy, " at call ", *CI);
          return false;
        }

        // vectorize pointer
        Value *element = CI->getArgOperand(i);
        if (batch) {
          auto elementPtrTy = dyn_cast<PointerType>(element->getType());
          if (!elementPtrTy) {
            // Report the problem instead of failing silently, matching the
            // diagnostic used for batched shadows in autodiff calls.
            EmitFailure("NonPointerBatch", CI->getDebugLoc(), CI,
                        "Batched argument at index ", i,
                        " must be of pointer type, found: ",
                        *element->getType());
            return false;
          }
          // Element v lives at byte offset v * offset from the base pointer:
          // go through i8* for the byte-wise GEP, then cast back.
          Value *offset = offsetIt->second;
          Type *i8Ty = Type::getInt8Ty(CI->getContext());
          element = Builder.CreateBitCast(
              element,
              PointerType::get(i8Ty, elementPtrTy->getAddressSpace()));
          element = Builder.CreateGEP(
              i8Ty, element,
              Builder.CreateMul(offset,
                                ConstantInt::get(offset->getType(), v)));
          element = Builder.CreateBitCast(element, elementPtrTy);
        }

        if (width > 1) {
          packed =
              packed
                  ? Builder.CreateInsertValue(packed, element, {v})
                  : Builder.CreateInsertValue(
                        UndefValue::get(
                            ArrayType::get(element->getType(), width)),
                        element, {v});
          // Without a batch offset each element is its own call operand.
          if (v < width - 1 && !batch) {
            ++i;
          }
        } else {
          packed = element;
        }
      }
      args.push_back(packed);
    } else if (ty == BATCH_TYPE::SCALAR) {
      args.push_back(res);
    }
    truei++;
  }

  BATCH_TYPE ret_type = (F->getReturnType()->isVoidTy() || width == 1)
                            ? BATCH_TYPE::SCALAR
                            : BATCH_TYPE::VECTOR;
  auto newFunc = Logic.CreateBatch(F, width, arg_types, ret_type);
  if (!newFunc)
    return false;

  Value *batch =
      Builder.CreateCall(newFunc->getFunctionType(), newFunc, args);
  batch = adaptReturnedVector(CI, batch, Builder, width);

  // With an sret call the result is written through operand 0 rather than
  // returned by the call itself.
  Value *ret = CI;
  Type *retElemType = nullptr;
  if (CI->hasStructRetAttr()) {
    ret = CI->getArgOperand(0);
#if LLVM_VERSION_MAJOR >= 12
    retElemType =
        CI->getAttribute(AttributeList::FirstArgIndex, Attribute::StructRet)
            .getValueAsType();
#else
    retElemType = ret->getType()->getPointerElementType();
#endif
  }
  ReplaceOriginalCall(Builder, ret, retElemType, batch, CI,
                      DerivativeMode::ForwardMode);
  return true;
}
/// Generate the derivative of `fn` for the requested `mode` and replace the
/// Enzyme call `CI` with a call to the generated function.
///
/// \param CI          the Enzyme call being lowered.
/// \param CallingConv calling convention set on the generated call.
/// \param ret         value handed to ReplaceOriginalCall (and to
///                    adaptReturnedVector for vectorized forward mode).
/// \param retElemType handed to ReplaceOriginalCall together with `ret`.
/// \param args        arguments for the generated call; the differential
///                    return seed and the tape may be appended here.
/// \param byVal       argument index -> byval type; re-applied as byval
///                    attributes on the generated call.
/// \param constants   per-argument activity, forwarded to the Logic.Create*
///                    calls.
/// \param fn          the function to differentiate.
/// \param mode        which derivative to build.
/// \param options     parsed call options (width, tape, return handling...).
/// \param sizeOnly    in the split modes, replace `CI` by the tape size in
///                    bytes (0 when there is no tape) instead of emitting a
///                    derivative call.
/// \returns false when no derivative function was produced or the argument
///          count does not match the generated function; true otherwise.
bool HandleAutoDiff(Instruction *CI, CallingConv::ID CallingConv, Value *ret,
Type *retElemType, SmallVectorImpl<Value *> &args,
const std::map<int, Type *> &byVal,
const std::vector<DIFFE_TYPE> &constants, Function *fn,
DerivativeMode mode, Options &options, bool sizeOnly) {
// Local aliases for the option fields used below. `tape` is a reference and
// is reassigned further down, which also updates `options.tape`.
auto &differet = options.differet;
auto &tape = options.tape;
auto &width = options.width;
auto &allocatedTapeSize = options.allocatedTapeSize;
auto &freeMemory = options.freeMemory;
auto &returnUsed = options.returnUsed;
auto &tapeIsPointer = options.tapeIsPointer;
auto &differentialReturn = options.differentialReturn;
auto &retType = options.retType;
auto primalReturn = options.primalReturn;
// Atomic accumulation is requested on the NVPTX and AMDGCN targets.
auto Arch = Triple(CI->getModule()->getTargetTriple()).getArch();
bool AtomicAdd = Arch == Triple::nvptx || Arch == Triple::nvptx64 ||
Arch == Triple::amdgcn;
TypeAnalysis TA(Logic.PPC.FAM);
std::vector<bool> overwritten_args;
FnTypeInfo type_args =
populate_overwritten_args(TA, fn, mode, overwritten_args);
IRBuilder Builder(CI);
// differentiate fn
Function *newFunc = nullptr;
Type *tapeType = nullptr;
// Only assigned in the split-mode cases below; it is read again after the
// switch solely when mode == ReverseModePrimal.
const AugmentedReturn *aug;
switch (mode) {
case DerivativeMode::ForwardMode:
newFunc = Logic.CreateForwardDiff(
fn, retType, constants, TA,
/*should return*/ primalReturn, mode, freeMemory, width,
/*addedType*/ nullptr, type_args, overwritten_args,
/*augmented*/ nullptr);
break;
case DerivativeMode::ForwardModeSplit: {
// An i8* tape is used when this is not a size query and
// allocatedTapeSize is -1 (presumably: no tape size was supplied).
bool forceAnonymousTape = !sizeOnly && allocatedTapeSize == -1;
aug = &Logic.CreateAugmentedPrimal(
fn, retType, constants, TA,
/*returnUsed*/ false, /*shadowReturnUsed*/ false, type_args,
overwritten_args, forceAnonymousTape, width, /*atomicAdd*/ AtomicAdd);
auto &DL = fn->getParent()->getDataLayout();
if (!forceAnonymousTape) {
assert(!aug->tapeType);
// The tape is either the whole return value of the augmented function
// (index -1) or one element of its returned struct.
if (aug->returns.find(AugmentedStruct::Tape) != aug->returns.end()) {
auto tapeIdx = aug->returns.find(AugmentedStruct::Tape)->second;
tapeType = (tapeIdx == -1)
? aug->fn->getReturnType()
: cast<StructType>(aug->fn->getReturnType())
->getElementType(tapeIdx);
} else {
// No tape at all: a size query evaluates to 0.
if (sizeOnly) {
CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0, false));
CI->eraseFromParent();
return true;
}
}
// Size query with a tape: replace the call by the tape size in bytes.
if (sizeOnly) {
auto size = DL.getTypeSizeInBits(tapeType) / 8;
CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), size, false));
CI->eraseFromParent();
return true;
}
// Diagnose a tape that does not fit the caller-provided size. There is no
// return after the diagnostic, so lowering continues if EmitFailure
// returns.
if (tapeType &&
DL.getTypeSizeInBits(tapeType) > 8 * (size_t)allocatedTapeSize) {
auto bytes = DL.getTypeSizeInBits(tapeType) / 8;
EmitFailure("Insufficient tape allocation size", CI->getDebugLoc(),
CI, "need ", bytes, " bytes have ", allocatedTapeSize,
" bytes");
}
} else {
tapeType = PointerType::getInt8PtrTy(fn->getContext());
}
newFunc = Logic.CreateForwardDiff(
fn, retType, constants, TA,
/*should return*/ primalReturn, mode, freeMemory, width,
/*addedType*/ tapeType, type_args, overwritten_args, aug);
break;
}
case DerivativeMode::ReverseModeCombined:
assert(freeMemory);
newFunc = Logic.CreatePrimalAndGradient(
(ReverseCacheKey){.todiff = fn,
.retType = retType,
.constant_args = constants,
.overwritten_args = overwritten_args,
.returnUsed = primalReturn,
.shadowReturnUsed = false,
.mode = mode,
.width = width,
.freeMemory = freeMemory,
.AtomicAdd = AtomicAdd,
.additionalType = nullptr,
.forceAnonymousTape = false,
.typeInfo = type_args},
TA, /*augmented*/ nullptr);
break;
case DerivativeMode::ReverseModePrimal:
case DerivativeMode::ReverseModeGradient: {
// Diagnostic only: there is no return here, so lowering continues if
// EmitFailure returns.
if (primalReturn) {
EmitFailure(
"SplitPrimalRet", CI->getDebugLoc(), CI,
"Option enzyme_primal_return not available in reverse split mode");
}
bool forceAnonymousTape = !sizeOnly && allocatedTapeSize == -1;
bool shadowReturnUsed = returnUsed && (retType == DIFFE_TYPE::DUP_ARG ||
retType == DIFFE_TYPE::DUP_NONEED);
aug = &Logic.CreateAugmentedPrimal(
fn, retType, constants, TA, returnUsed, shadowReturnUsed, type_args,
overwritten_args, forceAnonymousTape, width,
/*atomicAdd*/ AtomicAdd);
auto &DL = fn->getParent()->getDataLayout();
// Same tape-type resolution and size-query handling as in the
// ForwardModeSplit case above.
if (!forceAnonymousTape) {
assert(!aug->tapeType);
if (aug->returns.find(AugmentedStruct::Tape) != aug->returns.end()) {
auto tapeIdx = aug->returns.find(AugmentedStruct::Tape)->second;
tapeType = (tapeIdx == -1)
? aug->fn->getReturnType()
: cast<StructType>(aug->fn->getReturnType())
->getElementType(tapeIdx);
} else {
if (sizeOnly) {
CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0, false));
CI->eraseFromParent();
return true;
}
}
if (sizeOnly) {
auto size = DL.getTypeSizeInBits(tapeType) / 8;
CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), size, false));
CI->eraseFromParent();
return true;
}
if (tapeType &&
DL.getTypeSizeInBits(tapeType) > 8 * (size_t)allocatedTapeSize) {
auto bytes = DL.getTypeSizeInBits(tapeType) / 8;
EmitFailure("Insufficient tape allocation size", CI->getDebugLoc(),
CI, "need ", bytes, " bytes have ", allocatedTapeSize,
" bytes");
}
} else {
tapeType = PointerType::getInt8PtrTy(fn->getContext());
}
// The primal half is the augmented function itself; the gradient half is
// generated against it.
if (mode == DerivativeMode::ReverseModePrimal)
newFunc = aug->fn;
else
newFunc = Logic.CreatePrimalAndGradient(
(ReverseCacheKey){.todiff = fn,
.retType = retType,
.constant_args = constants,
.overwritten_args = overwritten_args,
.returnUsed = false,
.shadowReturnUsed = false,
.mode = mode,
.width = width,
.freeMemory = freeMemory,
.AtomicAdd = AtomicAdd,
.additionalType = tapeType,
.forceAnonymousTape = forceAnonymousTape,
.typeInfo = type_args},
TA, aug);
}
}
if (!newFunc) {
StringRef n = fn->getName();
EmitFailure("FailedToDifferentiate", fn->getSubprogram(),
&*fn->getEntryBlock().begin(),
"Could not generate derivative function of ", n);
return false;
}
// Append the differential of the return value: the caller-provided one if
// present, otherwise a seed of 1.0 (an array of `width` seeds when
// vectorized, or a struct of 1.0 constants for struct returns).
if (differentialReturn) {
if (differet)
args.push_back(differet);
else if (fn->getReturnType()->isFPOrFPVectorTy()) {
Constant *seed = ConstantFP::get(fn->getReturnType(), 1.0);
if (width == 1) {
args.push_back(seed);
} else {
ArrayType *arrayType = ArrayType::get(fn->getReturnType(), width);
args.push_back(ConstantArray::get(
arrayType, SmallVector<Constant *, 3>(width, seed)));
}
} else if (auto ST = dyn_cast<StructType>(fn->getReturnType())) {
SmallVector<Constant *, 2> csts;
for (auto e : ST->elements()) {
csts.push_back(ConstantFP::get(e, 1.0));
}
args.push_back(ConstantStruct::get(ST, csts));
}
}
// The second half of a split derivative takes the tape as its last
// argument; convert the supplied tape value to the expected tape type.
if ((mode == DerivativeMode::ReverseModeGradient ||
mode == DerivativeMode::ForwardModeSplit) &&
tape && tapeType) {
auto &DL = fn->getParent()->getDataLayout();
if (tapeIsPointer) {
// The tape was passed by pointer: load it as tapeType.
tape = Builder.CreateBitCast(
tape, PointerType::get(
tapeType,
cast<PointerType>(tape->getType())->getAddressSpace()));
tape = Builder.CreateLoad(tapeType, tape);
} else if (tapeType != tape->getType() &&
DL.getTypeSizeInBits(tapeType) <=
DL.getTypeSizeInBits(tape->getType())) {
// Different type that is no larger than the supplied value: spill to an
// entry-block alloca and reload the bytes as tapeType.
IRBuilder<> EB(&CI->getParent()->getParent()->getEntryBlock().front());
auto AL = EB.CreateAlloca(tape->getType());
Builder.CreateStore(tape, AL);
tape = Builder.CreateLoad(
tapeType,
Builder.CreatePointerCast(AL, PointerType::getUnqual(tapeType)));
}
assert(tape->getType() == tapeType);
args.push_back(tape);
}
if (EnzymePrint) {
llvm::errs() << "postfn:\n" << *newFunc << "\n";
}
Builder.setFastMathFlags(getFast());
// call newFunc with the provided arguments.
// Any argument-count mismatch (too few or too many) is dumped to stderr and
// reported under the "TooFewArguments" remark.
if (args.size() != newFunc->getFunctionType()->getNumParams()) {
llvm::errs() << *CI << "\n";
llvm::errs() << *newFunc << "\n";
for (auto arg : args) {
llvm::errs() << " + " << *arg << "\n";
}
auto modestr = to_string(mode);
EmitFailure(
"TooFewArguments", CI->getDebugLoc(), CI,
"Too few arguments passed to __enzyme_autodiff mode=", modestr);
return false;
}
assert(args.size() == newFunc->getFunctionType()->getNumParams());
CallInst *diffretc = cast<CallInst>(Builder.CreateCall(newFunc, args));
diffretc->setCallingConv(CallingConv);
diffretc->setDebugLoc(CI->getDebugLoc());
// Carry the byval attributes of the original call over to the new call.
for (auto &&[attr, ty] : byVal) {
diffretc->addParamAttr(
attr, Attribute::getWithByValType(diffretc->getContext(), ty));
}
Value *diffret = diffretc;
// Augmented forward pass with a caller-provided tape pointer: store the tape
// produced by the call through that pointer and strip it from the result.
if (mode == DerivativeMode::ReverseModePrimal && tape) {
if (aug->returns.find(AugmentedStruct::Tape) != aug->returns.end()) {
auto tapeIdx = aug->returns.find(AugmentedStruct::Tape)->second;
tapeType = (tapeIdx == -1) ? aug->fn->getReturnType()
: cast<StructType>(aug->fn->getReturnType())
->getElementType(tapeIdx);
unsigned idxs[] = {(unsigned)tapeIdx};
Value *tapeRes = (tapeIdx == -1)
? diffret
: Builder.CreateExtractValue(diffret, idxs);
Builder.CreateStore(
tapeRes,
Builder.CreateBitCast(
tape,
PointerType::get(
tapeRes->getType(),
cast<PointerType>(tape->getType())->getAddressSpace())));
if (tapeIdx != -1) {
// Rebuild the result struct without its first element.
// NOTE(review): this drops element 0 regardless of tapeIdx, so it
// appears to assume the tape is the first struct element - confirm.
auto ST = cast<StructType>(diffret->getType());
SmallVector<Type *, 2> tys(ST->elements().begin(),
ST->elements().end());
tys.erase(tys.begin());
auto ST0 = StructType::get(ST->getContext(), tys);
Value *out = UndefValue::get(ST0);
for (unsigned i = 0; i < tys.size(); i++) {
out = Builder.CreateInsertValue(
out, Builder.CreateExtractValue(diffret, {i + 1}), {i});
}
diffret = out;
} else {
// The whole return value was the tape: nothing is left to return.
auto ST0 = StructType::get(tape->getContext(), {});
diffret = UndefValue::get(ST0);
}
}
}
// Adapt the returned vector type to the struct type expected by our calling
// convention.
if (width > 1 && !diffret->getType()->isEmptyTy() &&
!diffret->getType()->isVoidTy() &&
(mode == DerivativeMode::ForwardMode ||
mode == DerivativeMode::ForwardModeSplit)) {
diffret = adaptReturnedVector(ret, diffret, Builder, width);
}
ReplaceOriginalCall(Builder, ret, retElemType, diffret, CI, mode);
// Optional post-optimization: try to inline the generated call, and after
// each successful inline queue the direct calls to the enclosing function.
if (Logic.PostOpt) {
auto Params = llvm::getInlineParams();
llvm::SetVector<CallInst *> Q;
Q.insert(diffretc);
while (Q.size()) {
auto cur = *Q.begin();
Function *outerFunc = cur->getParent()->getParent();
llvm::OptimizationRemarkEmitter ORE(outerFunc);
Q.erase(Q.begin());
if (auto F = cur->getCalledFunction()) {
if (!F->empty()) {
// Garbage collect AC's created
SmallVector<AssumptionCache *, 2> ACAlloc;
auto getAC = [&](Function &F) -> llvm::AssumptionCache & {
auto AC = new AssumptionCache(F);
ACAlloc.push_back(AC);
return *AC;
};
auto GetTLI =
[&](llvm::Function &F) -> const llvm::TargetLibraryInfo & {
return Logic.PPC.FAM.getResult<TargetLibraryAnalysis>(F);
};
auto GetInlineCost = [&](CallBase &CB) {
TargetTransformInfo TTI(F->getParent()->getDataLayout());
auto cst = llvm::getInlineCost(CB, Params, TTI, getAC, GetTLI);
return cst;
};
if (llvm::shouldInline(*cur, GetInlineCost, ORE)) {
InlineFunctionInfo IFI;
InlineResult IR = InlineFunction(*cur, IFI);
if (IR.isSuccess()) {
LowerSparsification(outerFunc, /*replaceAll*/ false);
// Note: this `CI` shadows the function parameter of the same name.
for (auto U : outerFunc->users()) {
if (auto CI = dyn_cast<CallInst>(U)) {
if (CI->getCalledFunction() == outerFunc) {
Q.insert(CI);
}
}
}
}
}
// Free the assumption caches allocated by getAC for this candidate.
for (auto AC : ACAlloc) {
delete AC;
}
}
}
}
}
return true;
}
/// Resolve the function to differentiate from \p CI, parse the remaining
/// call operands with handleArguments, and forward everything to
/// HandleAutoDiff.
///
/// \returns whether lowering succeeded; false when the target function or
///          the arguments could not be parsed.
bool HandleAutoDiffArguments(CallInst *CI, llvm::AAResults &AA,
                             llvm::TargetLibraryInfo &TLI,
                             DerivativeMode mode, bool sizeOnly) {
  // determine function to differentiate
  Function *fn = parseFunctionParameter(CI, AA, TLI);
  if (fn == nullptr) {
    return false;
  }

  IRBuilder<> Builder(CI);
  if (EnzymePrint) {
    llvm::errs() << "prefn:\n" << *fn << "\n";
  }

  // Filled in by handleArguments.
  std::map<int, Type *> byVal;
  std::vector<DIFFE_TYPE> constants;
  SmallVector<Value *, 2> args;
  auto options = handleArguments(Builder, CI, fn, mode, sizeOnly, constants,
                                 args, byVal);
  if (!options) {
    return false;
  }

  // By default the result is the call itself; with an sret call it is
  // written through the first operand instead.
  Type *retElemType = nullptr;
  Value *ret = CI;
  if (CI->hasStructRetAttr()) {
    ret = CI->getArgOperand(0);
#if LLVM_VERSION_MAJOR >= 12
    auto sretAttr =
        CI->getAttribute(AttributeList::FirstArgIndex, Attribute::StructRet);
    retElemType = sretAttr.getValueAsType();
#else
    retElemType = ret->getType()->getPointerElementType();
#endif
  }

  // `options` was checked above, so dereferencing is safe; operator* is
  // available on the optional type of every supported LLVM version.
  return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args,
                        byVal, constants, fn, mode, *options, sizeOnly);
}
/// Lower a probabilistic-programming call (likelihood / trace / condition).
///
/// The target function is rewritten by Logic.CreateTrace. When neither a
/// differential trace nor a differential likelihood was requested, \p CI is
/// replaced by a plain call to the rewritten function. Otherwise the
/// rewritten function is differentiated in ReverseModeCombined through
/// HandleAutoDiff.
///
/// \param CI   the call to lower.
/// \param AA   alias analysis, forwarded to parseFunctionParameter.
/// \param TLI  target library info, forwarded to parseFunctionParameter.
/// \param mode which probabilistic-programming transformation to apply.
/// \returns false when the target function or the arguments could not be
///          parsed, or when HandleAutoDiff fails; true otherwise.
bool HandleProbProg(CallInst *CI, llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,
                    ProbProgMode mode) {
  IRBuilder<> Builder(CI);
  Function *F = parseFunctionParameter(CI, AA, TLI);
  if (!F)
    return false;

  std::vector<DIFFE_TYPE> constants;
  std::map<int, Type *> byVal;
  SmallVector<Value *, 4> args;
  auto diffeMode = DerivativeMode::ReverseModeCombined;
  auto opt = handleArguments(Builder, CI, F, diffeMode, false, constants,
                             args, byVal);
  if (!opt)
    return false;

  // `args` feeds the plain traced call, `dargs` the differentiated one; they
  // start out identical and diverge below.
  SmallVector<Value *, 6> dargs(args.begin(), args.end());

  auto dynamic_interface = opt->dynamic_interface;
  auto trace = opt->trace;
  auto dtrace = opt->diffeTrace;
  auto observations = opt->observations;
  auto likelihood = opt->likelihood;
  auto dlikelihood = opt->diffeLikelihood;

  // Interface
  bool has_dynamic_interface = dynamic_interface != nullptr;
  bool needs_interface =
      mode == ProbProgMode::Trace || mode == ProbProgMode::Condition;
  // Owned by this function: every return path below must delete it.
  TraceInterface *interface = nullptr;
  if (has_dynamic_interface) {
    interface =
        new DynamicTraceInterface(dynamic_interface, CI->getFunction());
  } else if (needs_interface) {
    interface = new StaticTraceInterface(F->getParent());
  }

  // Find sample function
  SmallPtrSet<Function *, 4> sampleFunctions;
  SmallPtrSet<Function *, 4> observeFunctions;
  for (auto &func : F->getParent()->functions()) {
    if (func.getName().contains("__enzyme_sample")) {
      assert(func.getFunctionType()->getNumParams() >= 3);
      sampleFunctions.insert(&func);
    } else if (func.getName().contains("__enzyme_observe")) {
      assert(func.getFunctionType()->getNumParams() >= 3);
      observeFunctions.insert(&func);
    }
  }
  assert(!sampleFunctions.empty() || !observeFunctions.empty());

  bool autodiff = dtrace || dlikelihood;

  // When the caller did not pass a likelihood accumulator, create a local
  // one initialized to zero.
  IRBuilder<> AllocaBuilder(CI->getParent()->getFirstNonPHI());
  if (!likelihood) {
    likelihood = AllocaBuilder.CreateAlloca(AllocaBuilder.getDoubleTy(),
                                            nullptr, "likelihood");
    Builder.CreateStore(ConstantFP::getNullValue(Builder.getDoubleTy()),
                        likelihood);
  }
  args.push_back(likelihood);

  // Likewise for its shadow, which is initialized to 1.0.
  if (autodiff && !dlikelihood) {
    dlikelihood = AllocaBuilder.CreateAlloca(AllocaBuilder.getDoubleTy(),
                                             nullptr, "dlikelihood");
    Builder.CreateStore(ConstantFP::get(Builder.getDoubleTy(), 1.0),
                        dlikelihood);
  }

  if (autodiff) {
    dargs.push_back(likelihood);
    dargs.push_back(dlikelihood);
    constants.push_back(DIFFE_TYPE::DUP_ARG);
  } else {
    constants.push_back(DIFFE_TYPE::CONSTANT);
  }

  if (mode == ProbProgMode::Condition) {
    args.push_back(observations);
    dargs.push_back(observations);
    constants.push_back(DIFFE_TYPE::CONSTANT);
  }

  // Trace and Condition are exactly the modes that take a trace argument.
  if (needs_interface) {
    args.push_back(trace);
    dargs.push_back(trace);
    constants.push_back(DIFFE_TYPE::CONSTANT);
  }

  auto newFunc = Logic.CreateTrace(F, sampleFunctions, observeFunctions,
                                   opt->ActiveRandomVariables, mode, autodiff,
                                   interface);

  if (!autodiff) {
    auto call = CallInst::Create(newFunc->getFunctionType(), newFunc, args);
    ReplaceInstWithInst(CI, call);
    // Release the interface on this path too; it was previously leaked here
    // while the autodiff path below deleted it.
    delete interface;
    return true;
  }

  // With an sret call the result is written through operand 0 rather than
  // returned by the call itself.
  Value *ret = CI;
  Type *retElemType = nullptr;
  if (CI->hasStructRetAttr()) {
    ret = CI->getArgOperand(0);
#if LLVM_VERSION_MAJOR >= 12
    retElemType =
        CI->getAttribute(AttributeList::FirstArgIndex, Attribute::StructRet)
            .getValueAsType();
#else
    retElemType = ret->getType()->getPointerElementType();
#endif
  }

  bool status = HandleAutoDiff(
      CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, constants,
      newFunc, DerivativeMode::ReverseModeCombined, *opt, false);

  delete interface;
  return status;
}
bool lowerEnzymeCalls(Function &F, std::set<Function *> &done) {
if (done.count(&F))
return false;
done.insert(&F);
if (F.empty())
return false;
bool Changed = false;
for (BasicBlock &BB : F)
if (InvokeInst *II = dyn_cast<InvokeInst>(BB.getTerminator())) {
Function *Fn = II->getCalledFunction();
if (auto castinst = dyn_cast<ConstantExpr>(II->getCalledOperand())) {
if (castinst->isCast())
if (auto fn = dyn_cast<Function>(castinst->getOperand(0)))
Fn = fn;
}
if (!Fn)
continue;
if (!(Fn->getName().contains("__enzyme_float") ||
Fn->getName().contains("__enzyme_double") ||
Fn->getName().contains("__enzyme_integer") ||
Fn->getName().contains("__enzyme_pointer") ||
Fn->getName().contains("__enzyme_virtualreverse") ||
Fn->getName().contains("__enzyme_call_inactive") ||
Fn->getName().contains("__enzyme_autodiff") ||
Fn->getName().contains("__enzyme_fwddiff") ||
Fn->getName().contains("__enzyme_fwdsplit") ||
Fn->getName().contains("__enzyme_augmentfwd") ||
Fn->getName().contains("__enzyme_augmentsize") ||
Fn->getName().contains("__enzyme_reverse") ||
Fn->getName().contains("__enzyme_batch") ||
Fn->getName().contains("__enzyme_trace") ||
Fn->getName().contains("__enzyme_condition")))
continue;
SmallVector<Value *, 16> CallArgs(II->arg_begin(), II->arg_end());
SmallVector<OperandBundleDef, 1> OpBundles;
II->getOperandBundlesAsDefs(OpBundles);
// Insert a normal call instruction...
CallInst *NewCall =
CallInst::Create(II->getFunctionType(), II->getCalledOperand(),
CallArgs, OpBundles, "", II);
NewCall->takeName(II);
NewCall->setCallingConv(II->getCallingConv());
NewCall->setAttributes(II->getAttributes());
NewCall->setDebugLoc(II->getDebugLoc());
II->replaceAllUsesWith(NewCall);
// Insert an unconditional branch to the normal destination.
BranchInst::Create(II->getNormalDest(), II);
// Remove any PHI node entries from the exception destination.
II->getUnwindDest()->removePredecessor(&BB);
II->eraseFromParent();
Changed = true;
}
MapVector<CallInst *, DerivativeMode> toLower;
MapVector<CallInst *, DerivativeMode> toVirtual;
MapVector<CallInst *, DerivativeMode> toSize;
SmallVector<CallInst *, 4> toBatch;
MapVector<CallInst *, ProbProgMode> toProbProg;
SetVector<CallInst *> InactiveCalls;
SetVector<CallInst *> IterCalls;
auto &TLI = Logic.PPC.FAM.getResult<TargetLibraryAnalysis>(F);
auto &AA = Logic.PPC.getAAResultsFromFunction(&F);
retry:;
for (BasicBlock &BB : F) {
for (Instruction &I : BB) {
CallInst *CI = dyn_cast<CallInst>(&I);
if (!CI)
continue;
Function *Fn = CI->getCalledFunction();
if (auto castinst = dyn_cast<ConstantExpr>(CI->getCalledOperand())) {
if (castinst->isCast())
if (auto fn = dyn_cast<Function>(castinst->getOperand(0)))
Fn = fn;
}
if (!Fn)
continue;
#if LLVM_VERSION_MAJOR >= 14
size_t num_args = CI->arg_size();
#else
size_t num_args = CI->getNumArgOperands();
#endif
if (Fn->getName().contains("__enzyme_todense")) {
#if LLVM_VERSION_MAJOR >= 16
CI->setOnlyReadsMemory();
CI->setOnlyWritesMemory();
#else
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
#endif
}
if (Fn->getName().contains("__enzyme_float")) {
#if LLVM_VERSION_MAJOR >= 16
CI->setOnlyReadsMemory();
CI->setOnlyWritesMemory();
#else
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
#endif
for (size_t i = 0; i < num_args; ++i) {
if (CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::ReadNone);
CI->addParamAttr(i, Attribute::NoCapture);
}
}
}
if (Fn->getName().contains("__enzyme_integer")) {
#if LLVM_VERSION_MAJOR >= 16
CI->setOnlyReadsMemory();
CI->setOnlyWritesMemory();
#else
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
#endif
for (size_t i = 0; i < num_args; ++i) {
if (CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::ReadNone);
CI->addParamAttr(i, Attribute::NoCapture);
}
}
}
if (Fn->getName().contains("__enzyme_double")) {
#if LLVM_VERSION_MAJOR >= 16
CI->setOnlyReadsMemory();
CI->setOnlyWritesMemory();
#else
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
#endif
for (size_t i = 0; i < num_args; ++i) {
if (CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::ReadNone);
CI->addParamAttr(i, Attribute::NoCapture);
}
}
}
if (Fn->getName().contains("__enzyme_pointer")) {
#if LLVM_VERSION_MAJOR >= 16
CI->setOnlyReadsMemory();
CI->setOnlyWritesMemory();
#else
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
#endif
for (size_t i = 0; i < num_args; ++i) {
if (CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::ReadNone);
CI->addParamAttr(i, Attribute::NoCapture);
}
}
}
if (Fn->getName().contains("__enzyme_virtualreverse")) {
#if LLVM_VERSION_MAJOR >= 16
CI->setOnlyReadsMemory();
CI->setOnlyWritesMemory();
#else
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
#endif
}
if (Fn->getName().contains("__enzyme_iter")) {
#if LLVM_VERSION_MAJOR >= 16
CI->setOnlyReadsMemory();
CI->setOnlyWritesMemory();
#else
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
#endif
}
if (Fn->getName().contains("__enzyme_call_inactive")) {
InactiveCalls.insert(CI);
}
if (Fn->getName() == "omp_get_max_threads" ||
Fn->getName() == "omp_get_thread_num") {
#if LLVM_VERSION_MAJOR >= 16
Fn->setOnlyAccessesInaccessibleMemory();
CI->setOnlyAccessesInaccessibleMemory();
Fn->setOnlyReadsMemory();
CI->setOnlyReadsMemory();
#else
Fn->addFnAttr(Attribute::InaccessibleMemOnly);
CI->addAttribute(AttributeList::FunctionIndex,
Attribute::InaccessibleMemOnly);
Fn->addFnAttr(Attribute::ReadOnly);
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly);
#endif
}
if ((Fn->getName() == "cblas_ddot" || Fn->getName() == "cblas_sdot") &&
Fn->isDeclaration()) {
#if LLVM_VERSION_MAJOR >= 16
Fn->setOnlyAccessesArgMemory();
Fn->setOnlyReadsMemory();
CI->setOnlyReadsMemory();
#else
Fn->addFnAttr(Attribute::ArgMemOnly);
Fn->addFnAttr(Attribute::ReadOnly);
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly);
#endif
CI->addParamAttr(1, Attribute::ReadOnly);
CI->addParamAttr(1, Attribute::NoCapture);
CI->addParamAttr(3, Attribute::ReadOnly);
CI->addParamAttr(3, Attribute::NoCapture);
}
if (Fn->getName() == "frexp" || Fn->getName() == "frexpf" ||
Fn->getName() == "frexpl") {
#if LLVM_VERSION_MAJOR >= 16
CI->setOnlyAccessesArgMemory();
#else
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ArgMemOnly);
#endif
CI->addParamAttr(1, Attribute::WriteOnly);
}
if (Fn->getName() == "__fd_sincos_1" || Fn->getName() == "__fd_cos_1" ||
Fn->getName() == "__mth_i_ipowi") {
#if LLVM_VERSION_MAJOR >= 16
CI->setOnlyReadsMemory();
CI->setOnlyWritesMemory();
#else
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
#endif
}
if (Fn->getName().contains("strcmp")) {
Fn->addParamAttr(0, Attribute::ReadOnly);
Fn->addParamAttr(1, Attribute::ReadOnly);
#if LLVM_VERSION_MAJOR >= 16
Fn->setOnlyReadsMemory();
CI->setOnlyReadsMemory();
#else
Fn->addFnAttr(Attribute::ReadOnly);
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly);
#endif
}
if (Fn->getName() == "f90io_fmtw_end" ||
Fn->getName() == "f90io_unf_end") {
#if LLVM_VERSION_MAJOR >= 16
Fn->setOnlyAccessesInaccessibleMemory();
CI->setOnlyAccessesInaccessibleMemory();
#else
Fn->addFnAttr(Attribute::InaccessibleMemOnly);
CI->addAttribute(AttributeList::FunctionIndex,
Attribute::InaccessibleMemOnly);
#endif
}
if (Fn->getName() == "f90io_open2003a") {
#if LLVM_VERSION_MAJOR >= 16
Fn->setOnlyAccessesInaccessibleMemOrArgMem();
CI->setOnlyAccessesInaccessibleMemOrArgMem();
#else
Fn->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
CI->addAttribute(AttributeList::FunctionIndex,
Attribute::InaccessibleMemOrArgMemOnly);
#endif
for (size_t i : {0, 1, 2, 3, 4, 5, 6, 7, /*8, */ 9, 10, 11, 12, 13}) {
if (i < num_args &&
CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::ReadOnly);
}
}
// todo more
for (size_t i : {0, 1}) {
if (i < num_args &&
CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::NoCapture);
}
}
}
if (Fn->getName() == "f90io_fmtw_inita") {
#if LLVM_VERSION_MAJOR >= 16
Fn->setOnlyAccessesInaccessibleMemOrArgMem();
CI->setOnlyAccessesInaccessibleMemOrArgMem();
#else
Fn->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
CI->addAttribute(AttributeList::FunctionIndex,
Attribute::InaccessibleMemOrArgMemOnly);
#endif
// todo more
for (size_t i : {0, 2}) {
if (i < num_args &&
CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::ReadOnly);
}
}
// todo more
for (size_t i : {0, 2}) {
if (i < num_args &&
CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::NoCapture);
}
}
}
if (Fn->getName() == "f90io_unf_init") {
#if LLVM_VERSION_MAJOR >= 16
Fn->setOnlyAccessesInaccessibleMemOrArgMem();
CI->setOnlyAccessesInaccessibleMemOrArgMem();
#else
Fn->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
CI->addAttribute(AttributeList::FunctionIndex,
Attribute::InaccessibleMemOrArgMemOnly);
#endif
// todo more
for (size_t i : {0, 1, 2, 3}) {
if (i < num_args &&
CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::ReadOnly);
}
}
// todo more
for (size_t i : {0, 1, 2, 3}) {
if (i < num_args &&
CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::NoCapture);
}
}
}
if (Fn->getName() == "f90io_src_info03a") {
#if LLVM_VERSION_MAJOR >= 16
Fn->setOnlyAccessesInaccessibleMemOrArgMem();
CI->setOnlyAccessesInaccessibleMemOrArgMem();
#else
Fn->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
CI->addAttribute(AttributeList::FunctionIndex,
Attribute::InaccessibleMemOrArgMemOnly);
#endif
// todo more
for (size_t i : {0, 1}) {
if (i < num_args &&
CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::ReadOnly);
}
}
// todo more
for (size_t i : {0}) {
if (i < num_args &&
CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::NoCapture);
}
}
}
if (Fn->getName() == "f90io_sc_d_fmt_write" ||
Fn->getName() == "f90io_sc_i_fmt_write" ||
Fn->getName() == "ftnio_fmt_write64" ||
Fn->getName() == "f90io_fmt_write64_aa" ||
Fn->getName() == "f90io_fmt_writea" ||
Fn->getName() == "f90io_unf_writea" ||
Fn->getName() == "f90_pausea") {
#if LLVM_VERSION_MAJOR >= 16
Fn->setOnlyAccessesInaccessibleMemOrArgMem();
CI->setOnlyAccessesInaccessibleMemOrArgMem();
#else
Fn->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
CI->addAttribute(AttributeList::FunctionIndex,
Attribute::InaccessibleMemOrArgMemOnly);
#endif
for (size_t i = 0; i < num_args; ++i) {
if (CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::ReadOnly);
CI->addParamAttr(i, Attribute::NoCapture);
}
}
}
bool enableEnzyme = false;
bool virtualCall = false;
bool sizeOnly = false;
bool batch = false;
bool probProg = false;
DerivativeMode derivativeMode;
ProbProgMode probProgMode;
if (Fn->getName().contains("__enzyme_autodiff")) {
enableEnzyme = true;
derivativeMode = DerivativeMode::ReverseModeCombined;
} else if (Fn->getName().contains("__enzyme_fwddiff")) {
enableEnzyme = true;
derivativeMode = DerivativeMode::ForwardMode;
} else if (Fn->getName().contains("__enzyme_fwdsplit")) {
enableEnzyme = true;
derivativeMode = DerivativeMode::ForwardModeSplit;
} else if (Fn->getName().contains("__enzyme_augmentfwd")) {
enableEnzyme = true;
derivativeMode = DerivativeMode::ReverseModePrimal;
} else if (Fn->getName().contains("__enzyme_augmentsize")) {
enableEnzyme = true;
sizeOnly = true;
derivativeMode = DerivativeMode::ReverseModePrimal;
} else if (Fn->getName().contains("__enzyme_reverse")) {
enableEnzyme = true;
derivativeMode = DerivativeMode::ReverseModeGradient;
} else if (Fn->getName().contains("__enzyme_virtualreverse")) {
enableEnzyme = true;
virtualCall = true;
derivativeMode = DerivativeMode::ReverseModeCombined;
} else if (Fn->getName().contains("__enzyme_batch")) {
enableEnzyme = true;
batch = true;
} else if (Fn->getName().contains("__enzyme_likelihood")) {
enableEnzyme = true;
probProgMode = ProbProgMode::Likelihood;
probProg = true;
} else if (Fn->getName().contains("__enzyme_trace")) {
enableEnzyme = true;
probProgMode = ProbProgMode::Trace;
probProg = true;
} else if (Fn->getName().contains("__enzyme_condition")) {
enableEnzyme = true;
probProgMode = ProbProgMode::Condition;
probProg = true;
}
if (enableEnzyme) {
Value *fn = CI->getArgOperand(0);
GetFunctionFromValue(fn, AA, TLI, &fn);
if (auto si = dyn_cast<SelectInst>(fn)) {
BasicBlock *post = BB.splitBasicBlock(CI);
BasicBlock *sel1 = BasicBlock::Create(BB.getContext(), "sel1", &F);
BasicBlock *sel2 = BasicBlock::Create(BB.getContext(), "sel2", &F);
BB.getTerminator()->eraseFromParent();
IRBuilder<> PB(&BB);
PB.CreateCondBr(si->getCondition(), sel1, sel2);
IRBuilder<> S1(sel1);
auto B1 = S1.CreateBr(post);
CallInst *cloned = cast<CallInst>(CI->clone());
cloned->insertBefore(B1);
cloned->setOperand(0, si->getTrueValue());
IRBuilder<> S2(sel2);
auto B2 = S2.CreateBr(post);
CI->moveBefore(B2);
CI->setOperand(0, si->getFalseValue());
if (CI->getNumUses() != 0) {
IRBuilder<> P(post->getFirstNonPHI());
auto merge = P.CreatePHI(CI->getType(), 2);
merge->addIncoming(cloned, sel1);
merge->addIncoming(CI, sel2);
CI->replaceAllUsesWith(merge);
}
goto retry;
}
if (virtualCall)
toVirtual[CI] = derivativeMode;
else if (sizeOnly)
toSize[CI] = derivativeMode;
else if (batch)
toBatch.push_back(CI);
else if (probProg) {
toProbProg[CI] = probProgMode;
} else
toLower[CI] = derivativeMode;
if (auto dc = dyn_cast<Function>(fn)) {
// Force postopt on any inner functions in the nested
// AD case.
bool tmp = Logic.PostOpt;
Logic.PostOpt = true;
Changed |= lowerEnzymeCalls(*dc, done);
Logic.PostOpt = tmp;
}
}
}
}
for (auto CI : InactiveCalls) {
IRBuilder<> B(CI);
Value *fn = CI->getArgOperand(0);
SmallVector<Value *, 4> Args;
SmallVector<Type *, 4> ArgTypes;
#if LLVM_VERSION_MAJOR >= 14
for (size_t i = 1; i < CI->arg_size(); ++i)
#else
for (size_t i = 1; i < CI->getNumArgOperands(); ++i)
#endif
{
Args.push_back(CI->getArgOperand(i));
ArgTypes.push_back(CI->getArgOperand(i)->getType());
}
auto FT = FunctionType::get(CI->getType(), ArgTypes, /*varargs*/ false);
if (fn->getType() != FT) {
fn = B.CreatePointerCast(fn, PointerType::getUnqual(FT));
}
auto Rep = B.CreateCall(FT, fn, Args);
Rep->addAttribute(AttributeList::FunctionIndex,
Attribute::get(Rep->getContext(), "enzyme_inactive"));
CI->replaceAllUsesWith(Rep);
CI->eraseFromParent();
Changed = true;
}
// Perform all the size replacements first to create constants
for (auto pair : toSize) {
bool successful = HandleAutoDiffArguments(pair.first, AA, TLI, pair.second,
/*sizeOnly*/ true);
Changed = true;
if (!successful)
break;
}
for (auto pair : toLower) {
bool successful = HandleAutoDiffArguments(pair.first, AA, TLI, pair.second,
/*sizeOnly*/ false);
Changed = true;
if (!successful)
break;
}
for (auto pair : toVirtual) {
auto CI = pair.first;
Constant *fn = dyn_cast<Constant>(CI->getArgOperand(0));
if (!fn) {
EmitFailure("IllegalVirtual", CI->getDebugLoc(), CI,
"Cannot create virtual version of non-constant value ", *CI,
*CI->getArgOperand(0));
return false;
}
TypeAnalysis TA(Logic.PPC.FAM);
auto Arch =
llvm::Triple(
CI->getParent()->getParent()->getParent()->getTargetTriple())
.getArch();
bool AtomicAdd = Arch == Triple::nvptx || Arch == Triple::nvptx64 ||
Arch == Triple::amdgcn;
auto val = GradientUtils::GetOrCreateShadowConstant(
Logic, Logic.PPC.FAM.getResult<TargetLibraryAnalysis>(F), TA, fn,
pair.second, /*width*/ 1, AtomicAdd);
CI->replaceAllUsesWith(ConstantExpr::getPointerCast(val, CI->getType()));
CI->eraseFromParent();
Changed = true;
}
for (auto call : toBatch) {
HandleBatch(call, AA, TLI);
}
for (auto &&[call, mode] : toProbProg) {
HandleProbProg(call, AA, TLI, mode);
}
if (Changed && EnzymeAttributor) {
// TODO consider enabling when attributor does not delete
// dead internal functions, which invalidates Enzyme's cache
// code left here to re-enable upon Attributor patch
#if LLVM_VERSION_MAJOR >= 13 && !defined(FLANG) && !defined(ROCM)
AnalysisGetter AG(Logic.PPC.FAM);
SetVector<Function *> Functions;
for (Function &F2 : *F.getParent()) {
Functions.insert(&F2);
}
CallGraphUpdater CGUpdater;
BumpPtrAllocator Allocator;
InformationCache InfoCache(*F.getParent(), AG, Allocator,
/* CGSCC */ nullptr);
DenseSet<const char *> Allowed = {
&AAHeapToStack::ID,
&AANoCapture::ID,
&AAMemoryBehavior::ID,
&AAMemoryLocation::ID,
&AANoUnwind::ID,
&AANoSync::ID,
&AANoRecurse::ID,
&AAWillReturn::ID,
&AANoReturn::ID,
&AANonNull::ID,
&AANoAlias::ID,
&AADereferenceable::ID,
&AAAlign::ID,
#if LLVM_VERSION_MAJOR < 18
&AAReturnedValues::ID,
#endif
&AANoFree::ID,
&AANoUndef::ID,
//&AAValueSimplify::ID,
//&AAReachability::ID,
//&AAValueConstantRange::ID,
//&AAUndefinedBehavior::ID,
//&AAPotentialValues::ID,
};
#if LLVM_VERSION_MAJOR >= 15
AttributorConfig aconfig(CGUpdater);
aconfig.Allowed = &Allowed;
aconfig.DeleteFns = false;
Attributor A(Functions, InfoCache, aconfig);
#else
Attributor A(Functions, InfoCache, CGUpdater, &Allowed,
/*DeleteFns*/ false);
#endif
for (Function *F : Functions) {
// Populate the Attributor with abstract attribute opportunities in
// the function and the information cache with IR information.
A.identifyDefaultAbstractAttributes(*F);
}
A.run();
#endif
}
return Changed;
}
/// Run the Enzyme transformation over every function of \p M.
///
/// The work is done in this order:
///  1. attribute known functions and rewrite calls to `f90_mzero8` into
///     `memset` intrinsic calls;
///  2. optionally run OpenMPOpt / Attributor / mem2reg (PostOpt + OMP only);
///  3. lower all Enzyme entry-point calls via lowerEnzymeCalls;
///  4. erase the `__enzyme_float/double/integer/pointer` marker calls and
///     forward `__enzyme_iter` to its first argument;
///  5. validate and then replace `__enzyme_sample` / `__enzyme_observe` calls;
///  6. erase the cached values held in Logic.PPC.cache and reset Logic;
///  7. optionally run the O2 module simplification pipeline (PostOpt);
///  8. run LowerSparsification on every function with a body.
///
/// \returns true if the module was modified.
bool run(Module &M) {
  Logic.clear();

  // Returns the function a call targets, looking through a constant cast
  // expression wrapped around the callee. May return null.
  auto resolveCallee = [](CallInst *CI) -> Function * {
    Function *Callee = CI->getCalledFunction();
    if (auto *CastExpr = dyn_cast<ConstantExpr>(CI->getCalledOperand()))
      if (CastExpr->isCast())
        if (auto *Fn = dyn_cast<Function>(CastExpr->getOperand(0)))
          Callee = Fn;
    return Callee;
  };

  bool changed = false;

  for (Function &F : M) {
    attributeKnownFunctions(F);
    if (F.empty())
      continue;

    SmallVector<Instruction *, 4> zeroCalls;
    for (BasicBlock &BB : F) {
      for (Instruction &I : BB) {
        auto *CI = dyn_cast<CallInst>(&I);
        if (!CI)
          continue;
        Function *Callee = resolveCallee(CI);
        if (!Callee || Callee->getName() != "f90_mzero8")
          continue;

        // Replace f90_mzero8(ptr, count) by memset(ptr, 0, count * 8, false).
        zeroCalls.push_back(CI);
        IRBuilder<> B(CI);
        SmallVector<Value *, 4> args;
        args.push_back(CI->getArgOperand(0));
        args.push_back(ConstantInt::get(Type::getInt8Ty(M.getContext()), 0));
        args.push_back(B.CreateMul(
            CI->getArgOperand(1),
            ConstantInt::get(CI->getArgOperand(1)->getType(), 8)));
        args.push_back(ConstantInt::getFalse(M.getContext()));
        Type *tys[] = {args[0]->getType(), args[2]->getType()};
        auto memsetIntr = Intrinsic::getDeclaration(&M, Intrinsic::memset, tys);
        B.CreateCall(memsetIntr, args);
      }
    }
    // Erasure is deferred so the instruction iteration above stays valid.
    for (Instruction *I : zeroCalls) {
      I->eraseFromParent();
      // The IR was rewritten, so the module must be reported as modified.
      changed = true;
    }
  }

#if LLVM_VERSION_MAJOR >= 13
  if (Logic.PostOpt && EnzymeOMPOpt) {
    OpenMPOptPass().run(M, Logic.PPC.MAM);
    /// Attributor is run second time for promoted args to get attributes.
    AttributorPass().run(M, Logic.PPC.MAM);
    for (auto &F : M)
      if (!F.empty())
        PromotePass().run(F, Logic.PPC.FAM);
    changed = true;
  }
#endif

  std::set<Function *> done;
  for (Function &F : M) {
    if (F.empty())
      continue;
    changed |= lowerEnzymeCalls(F, done);
  }

  // Collect the type-marker and iteration helper calls; they are erased only
  // after the walk so that no instruction is removed while iterating.
  SmallVector<CallInst *, 4> markerCalls;
  for (Function &F : M) {
    if (F.empty())
      continue;
    for (BasicBlock &BB : F) {
      for (Instruction &I : BB) {
        auto *CI = dyn_cast<CallInst>(&I);
        if (!CI)
          continue;
        Function *Callee = resolveCallee(CI);
        if (!Callee)
          continue;
        if (Callee->getName().contains("__enzyme_float") ||
            Callee->getName().contains("__enzyme_double") ||
            Callee->getName().contains("__enzyme_integer") ||
            Callee->getName().contains("__enzyme_pointer")) {
          markerCalls.push_back(CI);
        }
        if (Callee->getName() == "__enzyme_iter") {
          CI->replaceAllUsesWith(CI->getArgOperand(0));
          markerCalls.push_back(CI);
        }
      }
    }
  }
  for (auto I : markerCalls) {
    I->eraseFromParent();
    changed = true;
  }

  // Validate every __enzyme_sample / __enzyme_observe call. In both forms
  // operand 1 is the density function and the forwarded arguments start at
  // operand index 3.
  SmallPtrSet<CallInst *, 16> sample_calls;
  SmallPtrSet<CallInst *, 16> observe_calls;
  for (auto &&func : M) {
    for (auto &&BB : func) {
      for (auto &&Inst : BB) {
        if (auto CI = dyn_cast<CallInst>(&Inst)) {
          Function *fun = CI->getCalledFunction();
          if (!fun)
            continue;

          if (fun->getName().contains("__enzyme_sample")) {
            if (CI->getNumOperands() < 3) {
              EmitFailure(
                  "IllegalNumberOfArguments", CI->getDebugLoc(), CI,
                  "Not enough arguments passed to call to __enzyme_sample");
            }
            Function *samplefn =
                GetFunctionFromValue(CI->getOperand(0), AA, TLI);
            unsigned expected =
                samplefn->getFunctionType()->getNumParams() + 3;
#if LLVM_VERSION_MAJOR >= 14
            unsigned actual = CI->arg_size();
#else
            unsigned actual = CI->getNumArgOperands();
#endif
            if (actual - 3 != samplefn->getFunctionType()->getNumParams()) {
              EmitFailure("IllegalNumberOfArguments", CI->getDebugLoc(), CI,
                          "Illegal number of arguments passed to call to "
                          "__enzyme_sample.",
                          " Expected: ", expected, " got: ", actual);
            }
            Function *pdf = GetFunctionFromValue(CI->getArgOperand(1), AA, TLI);
            // Each forwarded argument must match the corresponding parameter
            // of both the sample function and the density function.
            for (unsigned i = 0;
                 i < samplefn->getFunctionType()->getNumParams(); ++i) {
              Value *ci_arg = CI->getArgOperand(i + 3);
              Value *sample_arg = samplefn->arg_begin() + i;
              Value *pdf_arg = pdf->arg_begin() + i;
              if (ci_arg->getType() != sample_arg->getType()) {
                EmitFailure(
                    "IllegalSampleType", CI->getDebugLoc(), CI,
                    "Type of: ", *ci_arg, " (", *ci_arg->getType(), ")",
                    " does not match the argument type of the sample "
                    "function: ",
                    *samplefn, " at: ", i, " (", *sample_arg->getType(), ")");
              }
              if (ci_arg->getType() != pdf_arg->getType()) {
                EmitFailure("IllegalSampleType", CI->getDebugLoc(), CI,
                            "Type of: ", *ci_arg, " (", *ci_arg->getType(),
                            ")",
                            " does not match the argument type of the "
                            "density function: ",
                            *pdf, " at: ", i, " (", *pdf_arg->getType(), ")");
              }
            }
            // The last density-function parameter receives the sampled value.
            if ((pdf->arg_end() - 1)->getType() !=
                samplefn->getReturnType()) {
              EmitFailure(
                  "IllegalSampleType", CI->getDebugLoc(), CI,
                  "Return type of ", *samplefn, " (",
                  *samplefn->getReturnType(), ")",
                  " does not match the last argument type of the density "
                  "function: ",
                  *pdf, " (", *(pdf->arg_end() - 1)->getType(), ")");
            }
            sample_calls.insert(CI);
          } else if (fun->getName().contains("__enzyme_observe")) {
            if (CI->getNumOperands() < 3) {
              EmitFailure(
                  "IllegalNumberOfArguments", CI->getDebugLoc(), CI,
                  "Not enough arguments passed to call to __enzyme_observe");
            }
            Value *observed = CI->getOperand(0);
            Function *pdf = GetFunctionFromValue(CI->getArgOperand(1), AA, TLI);
            // All density-function parameters except the last one (which is
            // compared against the observed value below) are forwarded.
            unsigned numForwarded = pdf->getFunctionType()->getNumParams() - 1;
            // Total operand count, so it is comparable with `actual` in the
            // diagnostic (same convention as the __enzyme_sample branch).
            unsigned expected = numForwarded + 3;
#if LLVM_VERSION_MAJOR >= 14
            unsigned actual = CI->arg_size();
#else
            unsigned actual = CI->getNumArgOperands();
#endif
            if (actual - 3 != numForwarded) {
              EmitFailure("IllegalNumberOfArguments", CI->getDebugLoc(), CI,
                          "Illegal number of arguments passed to call to "
                          "__enzyme_observe.",
                          " Expected: ", expected, " got: ", actual);
            }
            for (unsigned i = 0; i < numForwarded; ++i) {
              Value *ci_arg = CI->getArgOperand(i + 3);
              Value *pdf_arg = pdf->arg_begin() + i;
              if (ci_arg->getType() != pdf_arg->getType()) {
                EmitFailure("IllegalSampleType", CI->getDebugLoc(), CI,
                            "Type of: ", *ci_arg, " (", *ci_arg->getType(),
                            ")",
                            " does not match the argument type of the "
                            "density function: ",
                            *pdf, " at: ", i, " (", *pdf_arg->getType(), ")");
              }
            }
            if ((pdf->arg_end() - 1)->getType() != observed->getType()) {
              EmitFailure(
                  "IllegalSampleType", CI->getDebugLoc(), CI,
                  "Return type of ", *observed, " (", *observed->getType(),
                  ")",
                  " does not match the last argument type of the density "
                  "function: ",
                  *pdf, " (", *(pdf->arg_end() - 1)->getType(), ")");
            }
            observe_calls.insert(CI);
          }
        }
      }
    }
  }

  // Replace calls to __enzyme_sample with the actual sample calls after
  // running prob prog
  for (auto call : sample_calls) {
    Function *samplefn = GetFunctionFromValue(call->getArgOperand(0));
    SmallVector<Value *, 2> args;
    for (auto it = call->arg_begin() + 3; it != call->arg_end(); ++it) {
      args.push_back(*it);
    }
    CallInst *choice =
        CallInst::Create(samplefn->getFunctionType(), samplefn, args);
    ReplaceInstWithInst(call, choice);
  }

  // An observe call is replaced by the value it observes.
  for (auto call : observe_calls) {
    Value *observed = call->getArgOperand(0);
    if (!call->getType()->isVoidTy())
      call->replaceAllUsesWith(observed);
    call->eraseFromParent();
  }

  for (const auto &pair : Logic.PPC.cache)
    pair.second->eraseFromParent();
  Logic.clear();

  if (changed && Logic.PostOpt) {
    // Fresh analysis managers are built for the post-optimization pipeline.
    PassBuilder PB;
    LoopAnalysisManager LAM;
    FunctionAnalysisManager FAM;
    CGSCCAnalysisManager CGAM;
    ModuleAnalysisManager MAM;
    PB.registerModuleAnalyses(MAM);
    PB.registerFunctionAnalyses(FAM);
    PB.registerLoopAnalyses(LAM);
    PB.registerCGSCCAnalyses(CGAM);
    PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
#if LLVM_VERSION_MAJOR >= 14
    auto PM = PB.buildModuleSimplificationPipeline(OptimizationLevel::O2,
                                                   ThinOrFullLTOPhase::None);
#elif LLVM_VERSION_MAJOR >= 12
    auto PM = PB.buildModuleSimplificationPipeline(
        PassBuilder::OptimizationLevel::O2, ThinOrFullLTOPhase::None);
#else
    auto PM = PB.buildModuleSimplificationPipeline(
        PassBuilder::OptimizationLevel::O2, PassBuilder::ThinLTOPhase::None);
#endif
    PM.run(M, MAM);
#if LLVM_VERSION_MAJOR >= 13
    if (EnzymeOMPOpt) {
      OpenMPOptPass().run(M, MAM);
      /// Attributor is run second time for promoted args to get attributes.
      AttributorPass().run(M, MAM);
      for (auto &F : M)
        if (!F.empty())
          PromotePass().run(F, FAM);
    }
#endif
  }

  for (auto &F : M) {
    if (!F.empty())
      changed |= LowerSparsification(&F);
  }
  return changed;
}
};
/// Legacy pass manager front end: a ModulePass whose runOnModule simply
/// forwards to the run(Module &) implementation inherited from EnzymeBase.
class EnzymeOldPM : public EnzymeBase, public ModulePass {
public:
  /// Pass identifier handed to the ModulePass constructor.
  static char ID;

  EnzymeOldPM(bool PostOpt = false) : EnzymeBase(PostOpt), ModulePass(ID) {}

  /// Declares TargetLibraryInfoWrapperPass as the only required analysis.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetLibraryInfoWrapperPass>();

    // AU.addRequiredID(LCSSAID);

    // LoopInfo is required to ensure that all loops have preheaders
    // AU.addRequired<LoopInfoWrapperPass>();

    // AU.addRequiredID(llvm::LoopSimplifyID);//<LoopSimplifyWrapperPass>();
  }

  /// Returns whatever run(M) reports about the module having been modified.
  bool runOnModule(Module &M) override {
    const bool Modified = run(M);
    return Modified;
  }
};
} // namespace
// Definition of the legacy pass identifier declared in EnzymeOldPM.
char EnzymeOldPM::ID = 0;

// Registers the legacy pass under the command-line name "enzyme".
static RegisterPass<EnzymeOldPM> X("enzyme", "Enzyme Pass");

/// Allocates a new legacy Enzyme pass; the caller receives the raw pointer.
ModulePass *createEnzymePass(bool PostOpt) {
  auto *LegacyPass = new EnzymeOldPM(PostOpt);
  return LegacyPass;
}
#include <llvm-c/Core.h>
#include <llvm-c/Types.h>
#include "llvm/IR/LegacyPassManager.h"
/// C entry point: creates a legacy Enzyme pass with PostOpt disabled and adds
/// it to the pass manager wrapped by \p PM.
extern "C" void AddEnzymePass(LLVMPassManagerRef PM) {
  ModulePass *EnzymePass = createEnzymePass(/*PostOpt*/ false);
  unwrap(PM)->add(EnzymePass);
}
#include "llvm/Passes/PassPlugin.h"
/// New pass manager front end for Enzyme. It forwards to EnzymeBase::run and
/// translates the returned "changed" flag into a PreservedAnalyses value.
class EnzymeNewPM final : public EnzymeBase,
                          public AnalysisInfoMixin<EnzymeNewPM> {
  friend struct llvm::AnalysisInfoMixin<EnzymeNewPM>;

private:
  static llvm::AnalysisKey Key;

public:
  using Result = llvm::PreservedAnalyses;

  EnzymeNewPM(bool PostOpt = false) : EnzymeBase(PostOpt) {}

  /// Runs Enzyme on \p M. \p MAM is not consulted by this wrapper.
  /// \returns PreservedAnalyses::none() when the module changed, otherwise
  /// PreservedAnalyses::all().
  Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) {
    if (EnzymeBase::run(M))
      return PreservedAnalyses::none();
    return PreservedAnalyses::all();
  }

  /// Reports this pass as required to the pass manager.
  static bool isRequired() { return true; }
};
#undef DEBUG_TYPE
// Out-of-class definition of the static key declared in EnzymeNewPM.
AnalysisKey EnzymeNewPM::Key;
#include "ActivityAnalysisPrinter.h"
#include "PreserveNVVM.h"
#include "TypeAnalysis/TypeAnalysisPrinter.h"
#include "llvm/Passes/PassBuilder.h"
#if LLVM_VERSION_MAJOR >= 15
#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
#include "llvm/Transforms/IPO/CalledValuePropagation.h"
#include "llvm/Transforms/IPO/ConstantMerge.h"
#include "llvm/Transforms/IPO/CrossDSOCFI.h"
#include "llvm/Transforms/IPO/DeadArgumentElimination.h"
#include "llvm/Transforms/IPO/FunctionAttrs.h"
#include "llvm/Transforms/IPO/GlobalDCE.h"
#endif
#include "llvm/Transforms/IPO/GlobalOpt.h"
#if LLVM_VERSION_MAJOR >= 15
#include "llvm/Transforms/IPO/GlobalSplit.h"
#include "llvm/Transforms/IPO/InferFunctionAttrs.h"
#include "llvm/Transforms/IPO/SCCP.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar/CallSiteSplitting.h"
#endif
#include "llvm/Transforms/Scalar/EarlyCSE.h"
#include "llvm/Transforms/Scalar/Float2Int.h"
#include "llvm/Transforms/Scalar/GVN.h"
#include "llvm/Transforms/Scalar/LoopDeletion.h"
#include "llvm/Transforms/Scalar/LoopRotation.h"
#include "llvm/Transforms/Scalar/LoopUnrollPass.h"
#include "llvm/Transforms/Scalar/SROA.h"
#if LLVM_VERSION_MAJOR >= 12
// #include "llvm/Transforms/IPO/MemProfContextDisambiguation.h"
#include "llvm/Transforms/IPO/ArgumentPromotion.h"
#include "llvm/Transforms/Scalar/ConstraintElimination.h"
#include "llvm/Transforms/Scalar/DeadStoreElimination.h"
#include "llvm/Transforms/Scalar/JumpThreading.h"
#include "llvm/Transforms/Scalar/MemCpyOptimizer.h"
#include "llvm/Transforms/Scalar/NewGVN.h"
#include "llvm/Transforms/Scalar/TailRecursionElimination.h"
#if LLVM_VERSION_MAJOR >= 17
#include "llvm/Transforms/Utils/MoveAutoInit.h"
#endif
#include "llvm/Transforms/Scalar/IndVarSimplify.h"
#include "llvm/Transforms/Scalar/LICM.h"
#include "llvm/Transforms/Scalar/LoopFlatten.h"
#include "llvm/Transforms/Scalar/MergedLoadStoreMotion.h"
/// Builds the inliner parameters matching the speed-up and size levels of
/// \p Level. Only the spelling of the OptimizationLevel type depends on the
/// LLVM version.
#if LLVM_VERSION_MAJOR < 14
static InlineParams
getInlineParamsFromOptLevel(llvm::PassBuilder::OptimizationLevel Level)
#else
static InlineParams getInlineParamsFromOptLevel(OptimizationLevel Level)
#endif
{
  const auto SpeedupLevel = Level.getSpeedupLevel();
  const auto SizeLevel = Level.getSizeLevel();
  return getInlineParams(SpeedupLevel, SizeLevel);
}
#if LLVM_VERSION_MAJOR >= 12
#include "llvm/Transforms/Scalar/LowerConstantIntrinsics.h"
#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
// Declarations and compile-time switches used by the pass pipelines built
// later in this file.
namespace llvm {
// Command-line options defined outside this file; they are passed to LICMPass
// when the LTO pipeline is assembled.
extern cl::opt<unsigned> SetLicmMssaNoAccForPromotionCap;
extern cl::opt<unsigned> SetLicmMssaOptCap;
// The names below are preprocessor macros, so they are not scoped by the
// enclosing namespace: every later use in this file expands to the constant.
// NOTE(review): they appear to stand in for LLVM-internal options of the same
// name (see the commented-out extern declarations) — confirm.
#define EnableLoopFlatten false
#define EagerlyInvalidateAnalyses false
#define RunNewGVN false
#define EnableConstraintElimination true
#define UseInlineAdvisor InliningAdvisorMode::Default
#define EnableMemProfContextDisambiguation false
// extern cl::opt<bool> EnableMatrix;
#define EnableMatrix false
#define EnableModuleInliner false
// These two switches are only defined for LLVM 14 and older.
#if LLVM_VERSION_MAJOR <= 14
// extern cl::opt<bool> EnableFunctionSpecialization;
#define EnableFunctionSpecialization false
// extern cl::opt<bool> RunPartialInlining;
#define RunPartialInlining false
#endif
} // namespace llvm
#if LLVM_VERSION_MAJOR <= 14
#include "llvm/Transforms/IPO/CalledValuePropagation.h"
#include "llvm/Transforms/IPO/DeadArgumentElimination.h"
#include "llvm/Transforms/IPO/SCCP.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar/SimplifyCFG.h"
#if LLVM_VERSION_MAJOR >= 12
#include "llvm/Transforms/Coroutines/CoroCleanup.h"
#endif
#include "llvm/Transforms/IPO/FunctionAttrs.h"
#include "llvm/Transforms/IPO/GlobalDCE.h"
#include "llvm/Transforms/IPO/PartialInlining.h"
#if LLVM_VERSION_MAJOR <= 12
#include "llvm/Transforms/Utils/Mem2Reg.h"
#endif
#endif
#endif
#endif
void augmentPassBuilder(llvm::PassBuilder &PB) {
#if LLVM_VERSION_MAJOR < 14
using OptimizationLevel = llvm::PassBuilder::OptimizationLevel;
#endif
auto PB0 = new llvm::PassBuilder(PB);
#if LLVM_VERSION_MAJOR >= 12
auto prePass = [PB0](ModulePassManager &MPM, OptimizationLevel Level)
#else
auto prePass = [PB0](ModulePassManager &MPM)
#endif
{
#if LLVM_VERSION_MAJOR < 12
llvm_unreachable("New Pass manager pipeline unsupported at version <= 11");
#else
#if LLVM_VERSION_MAJOR < 15
////// End of Module simplification
// Specialize functions with IPSCCP.
#if LLVM_VERSION_MAJOR >= 13
if (EnableFunctionSpecialization && Level == OptimizationLevel::O3)
MPM.addPass(FunctionSpecializationPass());
#endif
// Interprocedural constant propagation now that basic cleanup has
// occurred and prior to optimizing globals.
// FIXME: This position in the pipeline hasn't been carefully
// considered in years, it should be re-analyzed.
MPM.addPass(IPSCCPPass());
// Attach metadata to indirect call sites indicating the set of
// functions they may target at run-time. This should follow
// IPSCCP.
MPM.addPass(CalledValuePropagationPass());
// Optimize globals to try and fold them into constants.
MPM.addPass(GlobalOptPass());
// Promote any localized globals to SSA registers.
// FIXME: Should this instead by a run of SROA?
// FIXME: We should probably run instcombine and simplifycfg
// afterward to delete control flows that are dead once globals
// have been folded to constants.
MPM.addPass(createModuleToFunctionPassAdaptor(PromotePass()));
// Remove any dead arguments exposed by cleanups and constant
// folding globals.
MPM.addPass(DeadArgumentEliminationPass());
// Create a small function pass pipeline to cleanup after all the
// global optimizations.
FunctionPassManager GlobalCleanupPM;
GlobalCleanupPM.addPass(InstCombinePass());
#if LLVM_VERSION_MAJOR >= 14
GlobalCleanupPM.addPass(
SimplifyCFGPass(SimplifyCFGOptions().convertSwitchRangeToICmp(true)));
#else
GlobalCleanupPM.addPass(SimplifyCFGPass(SimplifyCFGOptions()));
#endif
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(GlobalCleanupPM)));
ThinOrFullLTOPhase Phase = ThinOrFullLTOPhase::None;
#if LLVM_VERSION >= 13
bool EnableModuleInliner = false;
if (EnableModuleInliner)
MPM.addPass(PB0->buildModuleInlinerPipeline(Level, Phase));
else
#endif
MPM.addPass(PB0->buildInlinerPipeline(Level, Phase));
FunctionPassManager CoroCleanupPM;
CoroCleanupPM.addPass(CoroCleanupPass());
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(CoroCleanupPM)));
////// Finished Module simplification, starting ModuleOptimization
//
// Optimize globals now that the module is fully simplified.
MPM.addPass(GlobalOptPass());
MPM.addPass(GlobalDCEPass());
// Run partial inlining pass to partially inline functions that
// have large bodies.
if (RunPartialInlining)
MPM.addPass(PartialInlinerPass());
// Do RPO function attribute inference across the module to
// forward-propagate attributes where applicable.
// FIXME: Is this really an optimization rather than a
// canonicalization?
MPM.addPass(ReversePostOrderFunctionAttrsPass());
#endif
FunctionPassManager OptimizePM;
OptimizePM.addPass(Float2IntPass());
OptimizePM.addPass(LowerConstantIntrinsicsPass());
if (EnableMatrix) {
OptimizePM.addPass(LowerMatrixIntrinsicsPass());
OptimizePM.addPass(EarlyCSEPass());
}
LoopPassManager LPM;
bool LTOPreLink = false;
// First rotate loops that may have been un-rotated by prior passes.
// Disable header duplication at -Oz.
LPM.addPass(LoopRotatePass(Level != OptimizationLevel::Oz, LTOPreLink));
// Some loops may have become dead by now. Try to delete them.
// FIXME: see discussion in https://reviews.llvm.org/D112851,
// this may need to be revisited once we run GVN before
// loop deletion in the simplification pipeline.
LPM.addPass(LoopDeletionPass());
LPM.addPass(llvm::LoopFullUnrollPass());
OptimizePM.addPass(createFunctionToLoopPassAdaptor(std::move(LPM)));
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizePM)));
#endif
};
#if LLVM_VERSION_MAJOR >= 12
auto loadPass = [prePass](ModulePassManager &MPM, OptimizationLevel Level)
#else
auto loadPass = [prePass](ModulePassManager &MPM)
#endif
{
MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true));
#if LLVM_VERSION_MAJOR >= 12
if (Level != OptimizationLevel::O0)
prePass(MPM, Level);
#else
prePass(MPM);
#endif
FunctionPassManager OptimizerPM;
FunctionPassManager OptimizerPM2;
#if LLVM_VERSION_MAJOR >= 16
OptimizerPM.addPass(llvm::GVNPass());
OptimizerPM.addPass(llvm::SROAPass(llvm::SROAOptions::PreserveCFG));
#elif LLVM_VERSION_MAJOR >= 14
OptimizerPM.addPass(llvm::GVNPass());
OptimizerPM.addPass(llvm::SROAPass());
#else
OptimizerPM.addPass(llvm::GVN());
OptimizerPM.addPass(llvm::SROA());
#endif
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizerPM)));
MPM.addPass(EnzymeNewPM(/*PostOpt=*/true));
MPM.addPass(PreserveNVVMNewPM(/*Begin*/ false));
#if LLVM_VERSION_MAJOR >= 16
OptimizerPM2.addPass(llvm::GVNPass());
OptimizerPM2.addPass(llvm::SROAPass(llvm::SROAOptions::PreserveCFG));
#elif LLVM_VERSION_MAJOR >= 14
OptimizerPM2.addPass(llvm::GVNPass());
OptimizerPM2.addPass(llvm::SROAPass());
#else
OptimizerPM2.addPass(llvm::GVN());
OptimizerPM2.addPass(llvm::SROA());
#endif
LoopPassManager LPM1;
LPM1.addPass(LoopDeletionPass());
OptimizerPM2.addPass(createFunctionToLoopPassAdaptor(std::move(LPM1)));
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizerPM2)));
MPM.addPass(GlobalOptPass());
};
// TODO need for perf reasons to move Enzyme pass to the pre vectorization.
#if LLVM_VERSION_MAJOR >= 15
PB.registerOptimizerEarlyEPCallback(loadPass);
#elif LLVM_VERSION_MAJOR >= 12
PB.registerPipelineEarlySimplificationEPCallback(loadPass);
#else
PB.registerPipelineStartEPCallback(loadPass);
#endif
#if LLVM_VERSION_MAJOR >= 12
auto loadNVVM = [](ModulePassManager &MPM, OptimizationLevel)
#else
auto loadNVVM = [](ModulePassManager &MPM)
#endif
{ MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true)); };
// We should register at vectorizer start for consistency, however,
// that requires a functionpass, and we have a modulepass.
// PB.registerVectorizerStartEPCallback(loadPass);
PB.registerPipelineStartEPCallback(loadNVVM);
#if LLVM_VERSION_MAJOR >= 15
PB.registerFullLinkTimeOptimizationEarlyEPCallback(loadNVVM);
auto preLTOPass = [](ModulePassManager &MPM, OptimizationLevel Level) {
// Create a function that performs CFI checks for cross-DSO calls with
// targets in the current module.
MPM.addPass(CrossDSOCFIPass());
if (Level == OptimizationLevel::O0) {
return;
}
// Try to run OpenMP optimizations, quick no-op if no OpenMP metadata
// present.
#if LLVM_VERSION_MAJOR >= 16
MPM.addPass(OpenMPOptPass(ThinOrFullLTOPhase::FullLTOPostLink));
#else
MPM.addPass(OpenMPOptPass());
#endif
// Remove unused virtual tables to improve the quality of code
// generated by whole-program devirtualization and bitset lowering.
MPM.addPass(GlobalDCEPass());
// Do basic inference of function attributes from known properties of
// system libraries and other oracles.
MPM.addPass(InferFunctionAttrsPass());
if (Level.getSpeedupLevel() > 1) {
MPM.addPass(createModuleToFunctionPassAdaptor(CallSiteSplittingPass(),
EagerlyInvalidateAnalyses));
// Indirect call promotion. This should promote all the targets that
// are left by the earlier promotion pass that promotes intra-module
// targets. This two-step promotion is to save the compile time. For
// LTO, it should produce the same result as if we only do promotion
// here.
// MPM.addPass(PGOIndirectCallPromotion(
// true /* InLTO */, PGOOpt && PGOOpt->Action ==
// PGOOptions::SampleUse));
// Propagate constants at call sites into the functions they call.
// This opens opportunities for globalopt (and inlining) by
// substituting function pointers passed as arguments to direct uses
// of functions.
#if LLVM_VERSION_MAJOR >= 16
MPM.addPass(IPSCCPPass(IPSCCPOptions(/*AllowFuncSpec=*/
Level != OptimizationLevel::Os &&
Level != OptimizationLevel::Oz)));
#else
MPM.addPass(IPSCCPPass());
#endif
// Attach metadata to indirect call sites indicating the set of
// functions they may target at run-time. This should follow IPSCCP.
MPM.addPass(CalledValuePropagationPass());
}
// Now deduce any function attributes based in the current code.
MPM.addPass(
createModuleToPostOrderCGSCCPassAdaptor(PostOrderFunctionAttrsPass()));
// Do RPO function attribute inference across the module to
// forward-propagate attributes where applicable.
// FIXME: Is this really an optimization rather than a
// canonicalization?
MPM.addPass(ReversePostOrderFunctionAttrsPass());
// Use in-range annotations on GEP indices to split globals where
// beneficial.
MPM.addPass(GlobalSplitPass());
// Run whole program optimization of virtual call when the list of
// callees is fixed. MPM.addPass(WholeProgramDevirtPass(ExportSummary,
// nullptr));
// Stop here at -O1.
if (Level == OptimizationLevel::O1) {
return;
}
// Optimize globals to try and fold them into constants.
MPM.addPass(GlobalOptPass());
// Promote any localized globals to SSA registers.
MPM.addPass(createModuleToFunctionPassAdaptor(PromotePass()));
// Linking modules together can lead to duplicate global constant,
// only keep one copy of each constant.
MPM.addPass(ConstantMergePass());
// Remove unused arguments from functions.
MPM.addPass(DeadArgumentEliminationPass());
// Reduce the code after globalopt and ipsccp. Both can open up
// significant simplification opportunities, and both can propagate
// functions through function pointers. When this happens, we often
// have to resolve varargs calls, etc, so let instcombine do this.
FunctionPassManager PeepholeFPM;
PeepholeFPM.addPass(InstCombinePass());
if (Level.getSpeedupLevel() > 1)
PeepholeFPM.addPass(AggressiveInstCombinePass());
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(PeepholeFPM),
EagerlyInvalidateAnalyses));
// Note: historically, the PruneEH pass was run first to deduce
// nounwind and generally clean up exception handling overhead. It
// isn't clear this is valuable as the inliner doesn't currently care
// whether it is inlining an invoke or a call. Run the inliner now.
if (EnableModuleInliner) {
MPM.addPass(ModuleInlinerPass(getInlineParamsFromOptLevel(Level),
UseInlineAdvisor,
ThinOrFullLTOPhase::FullLTOPostLink));
} else {
MPM.addPass(ModuleInlinerWrapperPass(
getInlineParamsFromOptLevel(Level),
/* MandatoryFirst */ true,
InlineContext{ThinOrFullLTOPhase::FullLTOPostLink,
InlinePass::CGSCCInliner}));
}
// Perform context disambiguation after inlining, since that would
// reduce the amount of additional cloning required to distinguish the
// allocation contexts. if (EnableMemProfContextDisambiguation)
// MPM.addPass(MemProfContextDisambiguation());
// Optimize globals again after we ran the inliner.
MPM.addPass(GlobalOptPass());
// Run the OpenMPOpt pass again after global optimizations.
#if LLVM_VERSION_MAJOR >= 16
MPM.addPass(OpenMPOptPass(ThinOrFullLTOPhase::FullLTOPostLink));
#else
MPM.addPass(OpenMPOptPass());
#endif
// Garbage collect dead functions.
MPM.addPass(GlobalDCEPass());
// If we didn't decide to inline a function, check to see if we can
// transform it to pass arguments by value instead of by reference.
MPM.addPass(
createModuleToPostOrderCGSCCPassAdaptor(ArgumentPromotionPass()));
FunctionPassManager FPM;
// The IPO Passes may leave cruft around. Clean up after them.
FPM.addPass(InstCombinePass());
if (EnableConstraintElimination)
FPM.addPass(ConstraintEliminationPass());
FPM.addPass(JumpThreadingPass());
// Do a post inline PGO instrumentation and use pass. This is a context
// sensitive PGO pass.
#if 0
if (PGOOpt) {
if (PGOOpt->CSAction == PGOOptions::CSIRInstr)
addPGOInstrPasses(MPM, Level, /* RunProfileGen */ true,
/* IsCS */ true, PGOOpt->CSProfileGenFile,
PGOOpt->ProfileRemappingFile,
ThinOrFullLTOPhase::FullLTOPostLink, PGOOpt->FS);
else if (PGOOpt->CSAction == PGOOptions::CSIRUse)
addPGOInstrPasses(MPM, Level, /* RunProfileGen */ false,
/* IsCS */ true, PGOOpt->ProfileFile,
PGOOpt->ProfileRemappingFile,
ThinOrFullLTOPhase::FullLTOPostLink, PGOOpt->FS);
}
#endif
// Break up allocas
#if LLVM_VERSION_MAJOR >= 16
FPM.addPass(SROAPass(SROAOptions::ModifyCFG));
#else
FPM.addPass(SROAPass());
#endif
// LTO provides additional opportunities for tailcall elimination due
// to link-time inlining, and visibility of nocapture attribute.
FPM.addPass(TailCallElimPass());
// Run a few AA driver optimizations here and now to cleanup the code.
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM),
EagerlyInvalidateAnalyses));
MPM.addPass(
createModuleToPostOrderCGSCCPassAdaptor(PostOrderFunctionAttrsPass()));
// Require the GlobalsAA analysis for the module so we can query it
// within MainFPM.
MPM.addPass(RequireAnalysisPass<GlobalsAA, Module>());
};
// Callback that extends the module pipeline `MPM` for optimization level
// `Level`. It is handed to
// PB.registerFullLinkTimeOptimizationEarlyEPCallback right after this
// definition. In order, it:
//   1. runs the captured `preLTOPass` callback on (MPM, Level),
//   2. adds PostOrderFunctionAttrsPass, requires GlobalsAA and invalidates
//      AAManager,
//   3. builds the local `MainFPM` and `LPM` pass managers, and
//   4. runs the captured `loadPass` callback on (MPM, Level).
//
// NOTE(review): within this lambda neither `MainFPM` nor `LPM` is ever added
// to `MPM` (there is no createModuleToFunctionPassAdaptor(std::move(MainFPM))
// or createFunctionToLoopPassAdaptor(std::move(LPM)) call), so the passes
// collected in step 3 appear to be discarded when the lambda returns. Confirm
// whether that is intentional or whether the managers should be wired in
// before `loadPass`.
auto loadLTO = [preLTOPass, loadPass](ModulePassManager &MPM,
OptimizationLevel Level) {
// Shared setup captured from the enclosing scope.
preLTOPass(MPM, Level);
MPM.addPass(
createModuleToPostOrderCGSCCPassAdaptor(PostOrderFunctionAttrsPass()));
// Require the GlobalsAA analysis for the module so we can query it
// within MainFPM.
MPM.addPass(RequireAnalysisPass<GlobalsAA, Module>());
// Invalidate AAManager so it can be recreated and pick up the newly
// available GlobalsAA.
MPM.addPass(
createModuleToFunctionPassAdaptor(InvalidateAnalysisPass<AAManager>()));
// Function-level pipeline: LICM, then GVN, memcpy/dead-store cleanup and
// load/store motion. See the NOTE(review) above: it is not added to MPM here.
FunctionPassManager MainFPM;
MainFPM.addPass(createFunctionToLoopPassAdaptor(
LICMPass(SetLicmMssaOptCap, SetLicmMssaNoAccForPromotionCap,
/*AllowSpeculation=*/true),
/*UseMemorySSA=*/true, /*UseBlockFrequencyInfo=*/false));
// Pick the GVN implementation based on the RunNewGVN option.
if (RunNewGVN)
MainFPM.addPass(NewGVNPass());
else
MainFPM.addPass(GVNPass());
// Remove dead memcpy()'s.
MainFPM.addPass(MemCpyOptPass());
// Nuke dead stores.
MainFPM.addPass(DSEPass());
// MoveAutoInitPass is only added when building against LLVM 17 or newer.
#if LLVM_VERSION_MAJOR >= 17
MainFPM.addPass(MoveAutoInitPass());
#endif
MainFPM.addPass(MergedLoadStoreMotionPass());
// Loop-level pipeline. Loop flattening is only added when the option is set
// and the speedup level is above 1. See the NOTE(review) above: LPM is not
// added to MainFPM or MPM here.
LoopPassManager LPM;
if (EnableLoopFlatten && Level.getSpeedupLevel() > 1)
LPM.addPass(LoopFlattenPass());
LPM.addPass(IndVarSimplifyPass());
LPM.addPass(LoopDeletionPass());
// FIXME: Add loop interchange.
// Hand over to the captured `loadPass` callback (presumably the one that
// schedules the Enzyme pass itself -- confirm against its definition).
loadPass(MPM, Level);
};
PB.registerFullLinkTimeOptimizationEarlyEPCallback(loadLTO);
#endif
}
/// Plugin entry point for the new pass manager.
///
/// Returns the plugin descriptor (API version, the name "EnzymeNewPM",
/// version "v0.1") together with a hook that registers Enzyme's pipeline
/// element names with a PassBuilder:
///   - module pipelines:   "enzyme", "preserve-nvvm", "print-type-analysis"
///   - function pipelines: "print-activity-analysis"
/// When ENZYME_RUNPASS is defined, the hook first calls augmentPassBuilder.
extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
llvmGetPassPluginInfo() {
  using PipelineElements = llvm::ArrayRef<llvm::PassBuilder::PipelineElement>;

  // Kept captureless: the descriptor stores this hook as a plain function
  // pointer, exactly as the inline lambda it replaces.
  auto registerEnzymeCallbacks = [](llvm::PassBuilder &PB) {
#ifdef ENZYME_RUNPASS
    augmentPassBuilder(PB);
#endif

    // Maps module-level pipeline names onto the corresponding Enzyme passes.
    // Returns true iff the name was recognized and a pass was added.
    auto parseModuleElement = [](llvm::StringRef Name,
                                 llvm::ModulePassManager &MPM,
                                 PipelineElements) -> bool {
      if (Name == "enzyme")
        MPM.addPass(EnzymeNewPM());
      else if (Name == "preserve-nvvm")
        MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true));
      else if (Name == "print-type-analysis")
        MPM.addPass(TypeAnalysisPrinterNewPM());
      else
        return false;
      return true;
    };

    // Maps function-level pipeline names onto the corresponding Enzyme
    // passes. Returns true iff the name was recognized and a pass was added.
    auto parseFunctionElement = [](llvm::StringRef Name,
                                   llvm::FunctionPassManager &FPM,
                                   PipelineElements) -> bool {
      if (Name != "print-activity-analysis")
        return false;
      FPM.addPass(ActivityAnalysisPrinterNewPM());
      return true;
    };

    // Module-level parser is registered first, then the function-level one.
    PB.registerPipelineParsingCallback(parseModuleElement);
    PB.registerPipelineParsingCallback(parseFunctionElement);
  };

  return {LLVM_PLUGIN_API_VERSION, "EnzymeNewPM", "v0.1",
          registerEnzymeCallbacks};
}