// blob: f1f388ccf5644c5ec0dadace5e5702a4a003d89a [file] [log] [blame] [edit]
//===- DiffeGradientUtils.h - Helper class and utilities for AD ---------===//
//
// Enzyme Project
//
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// If using this code in an academic setting, please cite the following:
// @incollection{enzymeNeurips,
// title = {Instead of Rewriting Foreign Code for Machine Learning,
// Automatically Synthesize Fast Gradients},
// author = {Moses, William S. and Churavy, Valentin},
// booktitle = {Advances in Neural Information Processing Systems 33},
// year = {2020},
// note = {To appear in},
// }
//
//===----------------------------------------------------------------------===//
//
// This file declares the helper class DiffeGradientUtils, a subclass of
// GradientUtils (declared in GradientUtils.h). These classes contain utilities
// for managing the cache, recomputing statements, and in the case of
// DiffeGradientUtils, managing adjoint values and shadow pointers.
//
//===----------------------------------------------------------------------===//
#ifndef ENZYME_DIFFEGRADIENTUTILS_H_
#define ENZYME_DIFFEGRADIENTUTILS_H_
#include "GradientUtils.h"
#include <llvm/Config/llvm-config.h>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/ValueTracking.h"
#include "ActivityAnalysis.h"
#include "EnzymeLogic.h"
#include "Utils.h"
#include "llvm-c/Core.h"
#if LLVM_VERSION_MAJOR <= 16
#include "llvm/ADT/Triple.h"
#endif
/// Helper class layered on top of GradientUtils. Per this file's header
/// comment, it adds management of adjoint values and shadow pointers to the
/// base class's cache and recomputation utilities. The class is `final`, so it
/// cannot be derived from.
class DiffeGradientUtils final : public GradientUtils {
// No access specifier precedes the constructor, so it is private (the default
// for `class`).
// NOTE(review): presumably instances are meant to be obtained through the
// static CreateFromClone factory below -- confirm in the implementation file.
DiffeGradientUtils(
EnzymeLogic &Logic, llvm::Function *newFunc_, llvm::Function *oldFunc_,
llvm::TargetLibraryInfo &TLI, TypeAnalysis &TA, TypeResults TR,
llvm::ValueToValueMapTy &invertedPointers_,
const llvm::SmallPtrSetImpl<llvm::Value *> &constantvalues_,
const llvm::SmallPtrSetImpl<llvm::Value *> &returnvals_,
DIFFE_TYPE ActiveReturn, bool shadowReturnUsed,
llvm::ArrayRef<DIFFE_TYPE> constant_values,
llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> &origToNew_,
DerivativeMode mode, bool runtimeActivity, bool strongZero,
unsigned width, bool omp);
public:
/// Whether to free memory in reverse pass or split forward.
bool FreeMemory;
/// Map from a value to a tracking handle on an alloca instruction.
/// NOTE(review): presumably the per-value stack slot in which that value's
/// derivative is accumulated (see getDifferential) -- confirm in the
/// implementation file.
llvm::ValueMap<const llvm::Value *, llvm::TrackingVH<llvm::AllocaInst>>
differentials;
/// Static factory returning a pointer to a DiffeGradientUtils.
/// NOTE(review): judging by the name, this presumably clones `todiff` and
/// builds the utilities around the clone; ownership of the returned raw
/// pointer is not visible from this header -- confirm with the definition.
static DiffeGradientUtils *
CreateFromClone(EnzymeLogic &Logic, DerivativeMode mode, bool runtimeActivity,
bool strongZero, unsigned width, llvm::Function *todiff,
llvm::TargetLibraryInfo &TLI, TypeAnalysis &TA,
FnTypeInfo &oldTypeInfo, DIFFE_TYPE retType,
bool shadowReturnArg, bool diffeReturnArg,
llvm::ArrayRef<DIFFE_TYPE> constant_args, bool returnTape,
bool returnPrimal, llvm::Type *additionalArg, bool omp);
/// Returns an alloca instruction associated with `val`.
/// NOTE(review): presumably looks up (or creates) the entry for `val` in
/// `differentials` -- confirm in the implementation file.
llvm::AllocaInst *getDifferential(llvm::Value *val);
// NOTE(review): this access specifier is redundant; the preceding members are
// already in the public section opened above.
public:
/// Returns a value for `val`, with `BuilderM` available for emitting IR.
/// NOTE(review): presumably the current derivative (adjoint) of `val` --
/// confirm in the implementation file.
llvm::Value *diffe(llvm::Value *val, llvm::IRBuilder<> &BuilderM);
/// Returns created select instructions, if any
/// NOTE(review): presumably accumulates `dif` into the derivative of `val`,
/// with `idxs` selecting a sub-element and `mask` (nullptr by default)
/// optionally masking the update -- confirm in the implementation file.
llvm::SmallVector<llvm::SelectInst *, 4>
addToDiffe(llvm::Value *val, llvm::Value *dif, llvm::IRBuilder<> &BuilderM,
llvm::Type *addingType, llvm::ArrayRef<llvm::Value *> idxs = {},
llvm::Value *mask = nullptr);
/// Returns created select instructions, if any
/// Overload that additionally takes `start`, `size`, and
/// `ignoreFirstSlicesToDiff` (0 by default).
/// NOTE(review): `start`/`size` presumably delimit the range within `val`
/// that is updated; units are not visible here -- confirm in the
/// implementation file.
llvm::SmallVector<llvm::SelectInst *, 4>
addToDiffe(llvm::Value *val, llvm::Value *dif, llvm::IRBuilder<> &BuilderM,
llvm::Type *addingType, unsigned start, unsigned size,
llvm::ArrayRef<llvm::Value *> idxs = {},
llvm::Value *mask = nullptr, size_t ignoreFirstSlicesToDiff = 0);
/// NOTE(review): presumably overwrites the derivative of `val` with `toset`,
/// emitting any needed IR through `BuilderM` -- confirm in the implementation
/// file.
void setDiffe(llvm::Value *val, llvm::Value *toset,
llvm::IRBuilder<> &BuilderM);
/// Overrides the base-class freeCache virtual; returns a call instruction.
/// NOTE(review): presumably the returned call is the emitted deallocation of
/// the cache held in `alloc` -- confirm against the base-class documentation
/// in GradientUtils.h.
llvm::CallInst *
freeCache(llvm::BasicBlock *forwardPreheader, const SubLimitType &sublimits,
int i, llvm::AllocaInst *alloc, llvm::ConstantInt *byteSizeOfType,
llvm::Value *storeInto, llvm::MDNode *InvariantMD) override;
/// align is the alignment that should be specified for load/store to pointer
/// `align` defaults to an empty llvm::MaybeAlign and `mask` to nullptr.
/// NOTE(review): presumably accumulates `dif` into the shadow (inverted
/// pointer) memory corresponding to `origptr` -- confirm in the implementation
/// file.
void addToInvertedPtrDiffe(llvm::Instruction *orig, llvm::Value *origVal,
llvm::Type *addingType, unsigned start,
unsigned size, llvm::Value *origptr,
llvm::Value *dif, llvm::IRBuilder<> &BuilderM,
llvm::MaybeAlign align = llvm::MaybeAlign(),
llvm::Value *mask = nullptr);
/// Overload taking a TypeTree `vd` in place of an explicit adding type and
/// start offset. `align` defaults to an empty llvm::MaybeAlign and `premask`
/// to nullptr.
/// NOTE(review): presumably uses `vd` to decide which parts of the `size`
/// region are differentiated -- confirm in the implementation file.
void addToInvertedPtrDiffe(llvm::Instruction *orig, llvm::Value *origVal,
TypeTree vd, unsigned size, llvm::Value *origptr,
llvm::Value *prediff, llvm::IRBuilder<> &Builder2,
llvm::MaybeAlign align = llvm::MaybeAlign(),
llvm::Value *premask = nullptr);
};
#endif