blob: 7ac416cdb941f9f573e9d1ce52de22e9de29fb47 [file] [log] [blame] [edit]
//===- FunctionUtils.h - Declaration of function utilities ---------------===//
//
// Enzyme Project
//
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// If using this code in an academic setting, please cite the following:
// @incollection{enzymeNeurips,
// title = {Instead of Rewriting Foreign Code for Machine Learning,
// Automatically Synthesize Fast Gradients},
// author = {Moses, William S. and Churavy, Valentin},
// booktitle = {Advances in Neural Information Processing Systems 33},
// year = {2020},
// note = {To appear in},
// }
//
//===----------------------------------------------------------------------===//
//
// This file declares utilities on LLVM Functions that are used as part of the
// AD process.
//
//===----------------------------------------------------------------------===//
#ifndef ENZYME_FUNCTION_UTILS_H
#define ENZYME_FUNCTION_UTILS_H
#include <deque>
#include <set>
#include "SCEV/ScalarEvolution.h"
#include "SCEV/ScalarEvolutionExpander.h"
#include "Utils.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
/// Cache of preprocessed function clones, together with the LLVM function-
/// and module-level analysis managers used while preparing them.
///
/// The type is move-only. Copying is deleted. Moving uses a custom
/// constructor because the FAM/MAM proxy passes must refer to the analysis
/// managers at their final address (see the move constructor).
class PreProcessCache {
public:
  PreProcessCache();

  // Copying is disallowed. The canonical `const PreProcessCache &` form makes
  // the deleted overload match const sources as well.
  PreProcessCache(const PreProcessCache &) = delete;
  PreProcessCache &operator=(const PreProcessCache &) = delete;

  // Using the default move constructor will botch the FAM/MAM proxy passes
  // since now the new location of FAM/MAM will not be used. Therefore, use a
  // custom move constructor that delegates to the default constructor (so the
  // analysis managers are freshly initialized at this object's address) and
  // only moves the cache/origin maps.
  PreProcessCache(PreProcessCache &&prev) : PreProcessCache() {
    cache = std::move(prev.cache);
    CloneOrigin = std::move(prev.CloneOrigin);
  }

  /// Function-level analysis manager.
  llvm::FunctionAnalysisManager FAM;
  /// Module-level analysis manager.
  llvm::ModuleAnalysisManager MAM;

  /// Map keyed by (function, derivative mode) to a function. Presumably the
  /// key function is the original and the value its preprocessed clone, as
  /// produced by preprocessForClone -- TODO confirm.
  std::map<std::pair<llvm::Function *, DerivativeMode>, llvm::Function *> cache;
  /// Function-to-function map. Presumably maps each clone back to the
  /// function it originated from -- TODO confirm.
  std::map<llvm::Function *, llvm::Function *> CloneOrigin;

  /// Return a preprocessed version of `F` for the given derivative mode
  /// (presumably memoized in `cache` -- confirm against the definition).
  llvm::Function *preprocessForClone(llvm::Function *F, DerivativeMode mode);

  /// Alias-analysis results for `NewF`.
  llvm::AAResults &getAAResultsFromFunction(llvm::Function *NewF);

  /// Clone `F` according to the given mode / vector-width configuration.
  /// `F` is passed as a reference to a pointer, so the callee may replace it.
  /// `ptrInputs`, `constants`, `nonconstant` and `returnvals` are non-const
  /// references and are presumably filled in by the callee -- TODO confirm.
  /// `VMapO`, when non-null, presumably receives the old-to-new value map.
  llvm::Function *CloneFunctionWithReturns(
      DerivativeMode mode, VectorModeMemoryLayout memoryLayout, unsigned width,
      llvm::Function *&F, llvm::ValueToValueMapTy &ptrInputs,
      llvm::ArrayRef<DIFFE_TYPE> constant_args,
      llvm::SmallPtrSetImpl<llvm::Value *> &constants,
      llvm::SmallPtrSetImpl<llvm::Value *> &nonconstant,
      llvm::SmallPtrSetImpl<llvm::Value *> &returnvals, ReturnType returnValue,
      DIFFE_TYPE returnType, llvm::Twine name, llvm::ValueToValueMapTy *VMapO,
      bool diffeReturnArg, llvm::Type *additionalArg = nullptr);

  // In-place transformations on a function. Their definitions live outside
  // this header; see them for exact semantics.
  void ReplaceReallocs(llvm::Function *NewF, bool mem2reg = false);
  void LowerAllocAddr(llvm::Function *NewF);
  void AlwaysInline(llvm::Function *NewF);
  void optimizeIntermediate(llvm::Function *F);

  /// Presumably resets the cached state -- TODO confirm.
  void clear();
};

class GradientUtils;
/// Insert into `ExitBlocks` those exit blocks of loop `L` that may really
/// leave the loop.
///
/// Each block reported by `Loop::getExitBlocks` is examined by walking
/// forward through branch instructions, never re-entering `L`. The block is
/// dropped only if every path from it ends in an `unreachable`. It is kept
/// (conservatively) as soon as the walk meets any terminator that is neither
/// a branch nor `unreachable`, or reaches some block a second time.
static inline void
getExitBlocks(const llvm::Loop *L,
              llvm::SmallPtrSetImpl<llvm::BasicBlock *> &ExitBlocks) {
  // True when the forward walk from `Start` may escape to a real exit.
  auto reachesRealExit = [L](llvm::BasicBlock *Start) -> bool {
    llvm::SmallVector<llvm::BasicBlock *, 4> worklist;
    llvm::SmallPtrSet<llvm::BasicBlock *, 4> visited;
    worklist.push_back(Start);
    while (!worklist.empty()) {
      llvm::BasicBlock *cur = worklist.pop_back_val();
      // Reaching a block twice (a cycle or a join) counts as a real exit.
      if (!visited.insert(cur).second)
        return true;
      auto *term = cur->getTerminator();
      if (llvm::isa<llvm::UnreachableInst>(term))
        continue;
      auto *br = llvm::dyn_cast<llvm::BranchInst>(term);
      if (!br)
        return true;
      for (llvm::BasicBlock *succ : br->successors()) {
        if (!L->contains(succ))
          worklist.push_back(succ);
      }
    }
    return false;
  };

  llvm::SmallVector<llvm::BasicBlock *, 8> candidates;
  L->getExitBlocks(candidates);
  for (llvm::BasicBlock *exitBB : candidates) {
    if (reachesRealExit(exitBB))
      ExitBlocks.insert(exitBB);
  }
}
/// Return the blocks of `L` that branch to one of `ExitBlocks`, i.e. the
/// in-loop predecessors of those exit blocks. Duplicates are removed and
/// first-seen order is kept. Here a "latch" means a block in the loop that
/// branches to an exit block, not necessarily the back-edge source.
///
/// `L` must have a preheader. If it has none, the enclosing function, the
/// loop header and the loop are printed to llvm::errs() before the assertion
/// fires.
static inline llvm::SmallVector<llvm::BasicBlock *, 3>
getLatches(const llvm::Loop *L,
           const llvm::SmallPtrSetImpl<llvm::BasicBlock *> &ExitBlocks) {
  llvm::BasicBlock *Preheader = L->getLoopPreheader();
  if (Preheader == nullptr) {
    llvm::BasicBlock *Header = L->getHeader();
    llvm::errs() << *Header->getParent() << "\n";
    llvm::errs() << *Header << "\n";
    llvm::errs() << *L << "\n";
  }
  assert(Preheader && "requires preheader");

  llvm::SmallVector<llvm::BasicBlock *, 3> Exiting;
  for (llvm::BasicBlock *ExitBlock : ExitBlocks) {
    for (llvm::BasicBlock *Pred : llvm::predecessors(ExitBlock)) {
      if (!L->contains(Pred))
        continue;
      bool alreadyRecorded =
          std::find(Exiting.begin(), Exiting.end(), Pred) != Exiting.end();
      if (!alreadyRecorded)
        Exiting.push_back(Pred);
    }
  }
  return Exiting;
}
/// Compute the blocks of `F` from which execution is guaranteed to end in an
/// `unreachable` (or, by assumption, a `resume`) instead of returning.
///
/// A block is marked when its terminator is `unreachable` or `resume`, or
/// when all of its successors are already marked. Blocks ending in `ret` are
/// never marked. Each newly marked block re-enqueues its predecessors, so
/// marking propagates backwards until a fixed point is reached.
///
/// TODO: this does not see through cycles that only lead to `unreachable`
/// (e.g. [loop, unreachable]). Every block on such a cycle keeps an unmarked
/// successor, so none of them is ever marked. Handling this, possibly with
/// some dominator-tree reasoning, could give more performance.
static inline llvm::SmallPtrSet<llvm::BasicBlock *, 4>
getGuaranteedUnreachable(llvm::Function *F) {
  llvm::SmallPtrSet<llvm::BasicBlock *, 4> knownUnreachables;
  std::deque<llvm::BasicBlock *> todo;
  for (llvm::BasicBlock &BB : *F)
    todo.push_back(&BB);

  while (!todo.empty()) {
    llvm::BasicBlock *next = todo.front();
    todo.pop_front();
    if (knownUnreachables.count(next))
      continue;

    auto *term = next->getTerminator();
    if (llvm::isa<llvm::ReturnInst>(term))
      continue;

    bool unreachable = true;
    if (llvm::isa<llvm::UnreachableInst>(term)) {
      // Trivially unreachable.
    } else if (llvm::isa<llvm::ResumeInst>(term)) {
      // Assume resumes don't happen.
      // TODO consider EH
    } else {
      // Unreachable only if every successor is already known unreachable.
      for (llvm::BasicBlock *Succ : llvm::successors(next)) {
        if (!knownUnreachables.count(Succ)) {
          unreachable = false;
          break;
        }
      }
    }
    if (!unreachable)
      continue;

    knownUnreachables.insert(next);
    // A predecessor may now have only unreachable successors; revisit it.
    for (llvm::BasicBlock *Pred : llvm::predecessors(next))
      todo.push_back(Pred);
  }
  return knownUnreachables;
}
/// Verdict returned by the `instneeded` callback of calculateUnusedValues.
///
/// When asked about a (transitive) user of a candidate instruction:
///   - Need:   that user needs the candidate's value, so the candidate stays.
///   - Recur:  not decisive; the user's own users are inspected instead.
///   - Cached: the user does not need the value and is not looked through.
/// When asked about the candidate itself (after its value was found unused),
/// Need means the instruction must still be kept even though its value is
/// not needed.
enum class UseReq { Need = 0, Recur = 1, Cached = 2 };
/// Determine which values and instructions of `oldFunc` are not needed.
///
/// \param oldFunc                 Function being analyzed.
/// \param unnecessaryValues       Receives instructions whose result value is
///                                not needed (it may be pre-populated).
/// \param unnecessaryInstructions Receives instructions that can be dropped
///                                entirely. When `returnValue` is false, every
///                                `ret` instruction is added here as well.
/// \param returnValue             When false, `ret` instructions are marked
///                                unnecessary.
/// \param valneeded               Returns true if a candidate's value must be
///                                kept; such candidates are never recorded.
/// \param instneeded              Classifies instructions (see UseReq). It is
///                                applied to each transitive user of a
///                                candidate, and to the candidate itself once
///                                its value is known to be unused: Need then
///                                keeps the instruction, which is recorded
///                                only in `unnecessaryValues`.
///
/// Terminators are never seeded as candidates. Users that are not
/// instructions, or are already in `unnecessaryInstructions`, are ignored.
/// Each instruction found unnecessary enqueues its instruction operands, so
/// values feeding only dead code are discovered too.
static inline void calculateUnusedValues(
    const llvm::Function &oldFunc,
    llvm::SmallPtrSetImpl<const llvm::Value *> &unnecessaryValues,
    llvm::SmallPtrSetImpl<const llvm::Instruction *> &unnecessaryInstructions,
    bool returnValue, std::function<bool(const llvm::Value *)> valneeded,
    std::function<UseReq(const llvm::Instruction *)> instneeded) {
  std::deque<const llvm::Instruction *> todo;
  for (const llvm::BasicBlock &BB : oldFunc) {
    if (auto ri = llvm::dyn_cast<llvm::ReturnInst>(BB.getTerminator())) {
      if (!returnValue)
        unnecessaryInstructions.insert(ri);
    }
    for (auto &inst : BB) {
      if (&inst == BB.getTerminator())
        continue;
      todo.push_back(&inst);
    }
  }

  // Scratch state for the per-candidate user walk. It is hoisted out of the
  // loop and cleared for each candidate to avoid reallocating every time.
  llvm::SmallPtrSet<const llvm::Instruction *, 4> seen;
  std::deque<const llvm::Instruction *> users;

  while (!todo.empty()) {
    auto inst = todo.front();
    todo.pop_front();
    if (unnecessaryInstructions.count(inst)) {
      assert(unnecessaryValues.count(inst));
      continue;
    }
    if (unnecessaryValues.count(inst))
      continue;
    if (valneeded(inst))
      continue;

    // Breadth-first walk over the users of `inst`, looking for one that
    // needs its value.
    bool necessaryUse = false;
    seen.clear();
    users.clear();
    for (auto user_dtx : inst->users()) {
      if (auto cst = llvm::dyn_cast<llvm::Instruction>(user_dtx))
        users.push_back(cst);
    }
    while (!necessaryUse && !users.empty()) {
      auto val = users.front();
      users.pop_front();
      if (!seen.insert(val).second)
        continue;
      if (unnecessaryInstructions.count(val))
        continue;
      switch (instneeded(val)) {
      case UseReq::Need:
        necessaryUse = true;
        break;
      case UseReq::Recur:
        for (auto user_dtx : val->users()) {
          if (auto cst = llvm::dyn_cast<llvm::Instruction>(user_dtx))
            users.push_back(cst);
        }
        break;
      case UseReq::Cached:
        break;
      }
    }
    if (necessaryUse)
      continue;

    unnecessaryValues.insert(inst);
    // The value is unused, but the instruction itself may still be required.
    if (instneeded(inst) == UseReq::Need)
      continue;
    unnecessaryInstructions.insert(inst);
    // Operands of a dropped instruction may have just lost their last user.
    for (auto &operand : inst->operands()) {
      if (auto usedinst = llvm::dyn_cast<llvm::Instruction>(operand.get()))
        todo.push_back(usedinst);
    }
  }
}
/// Add to `unnecessaryStores` every non-terminator instruction of `oldFunc`
/// for which `needStore` returns false.
///
/// `needStore` is consulted for every non-terminator instruction, not only
/// stores, in block/instruction order. Instructions already present in
/// `unnecessaryStores` (e.g. pre-populated by the caller) are skipped without
/// consulting `needStore`.
static inline void calculateUnusedStores(
    const llvm::Function &oldFunc,
    llvm::SmallPtrSetImpl<const llvm::Instruction *> &unnecessaryStores,
    std::function<bool(const llvm::Instruction *)> needStore) {
  // Single pass: each instruction is visited once, so no worklist is needed.
  for (const llvm::BasicBlock &BB : oldFunc) {
    for (const llvm::Instruction &inst : BB) {
      if (&inst == BB.getTerminator())
        continue;
      if (unnecessaryStores.count(&inst))
        continue;
      if (!needStore(&inst))
        unnecessaryStores.insert(&inst);
    }
  }
}
/// Defined outside this header.
/// NOTE(review): the name suggests that uses of `AI` are rewritten to `rep`
/// (possibly across an address-space change), following users recursively.
/// What `legal` controls is not visible here -- confirm against the definition.
void RecursivelyReplaceAddressSpace(llvm::Value *AI, llvm::Value *rep,
bool legal);
/// Defined outside this header; operates on module `M`.
/// NOTE(review): presumably swaps in replacement implementations for some
/// functions of `M` -- confirm against the definition.
void ReplaceFunctionImplementation(llvm::Module &M);
/// Is the use of value val as an argument of call CI potentially captured
/// (i.e. may the callee retain `val` beyond the call)?
bool couldFunctionArgumentCapture(llvm::CallInst *CI, llvm::Value *val);
/// Compute a FunctionType from the function type `FTy` in module `M`.
/// The configuration parameters mirror those of
/// PreProcessCache::CloneFunctionWithReturns. Presumably this yields the type
/// of the clone that function would create -- TODO confirm.
llvm::FunctionType *getFunctionTypeForClone(
llvm::Module &M, llvm::FunctionType *FTy, DerivativeMode mode,
VectorModeMemoryLayout memoryLayout, unsigned width,
llvm::Type *additionalArg, llvm::ArrayRef<DIFFE_TYPE> constant_args,
bool diffeReturnArg, ReturnType returnValue, DIFFE_TYPE returnType);
#endif