MLIR: Add verifyPostPasses option (#2284)
* MLIR: Add verifyPostPasses option
* fmt
* Fmt
diff --git a/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp
index dab82d1..a9a824c 100644
--- a/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp
+++ b/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp
@@ -73,7 +73,8 @@
fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, mode,
freeMemory, width,
/* addedType */ nullptr, type_args, volatile_args,
- /* augmented */ nullptr, gutils->omp, gutils->postpasses);
+ /* augmented */ nullptr, gutils->omp, gutils->postpasses,
+ gutils->verifyPostPasses);
SmallVector<Value> fwdArguments;
@@ -173,7 +174,8 @@
auto revFn = gutils->Logic.CreateReverseDiff(
fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, returnShadow,
mode, freeMemory, width, /*addedType*/ nullptr, type_args,
- volatile_args, /*augmented*/ nullptr, gutils->omp, gutils->postpasses);
+ volatile_args, /*augmented*/ nullptr, gutils->omp, gutils->postpasses,
+ gutils->verifyPostPasses);
SmallVector<Value> revArguments;
diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp
index 60e6ca6..ac6b04c 100644
--- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp
+++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp
@@ -81,7 +81,7 @@
std::vector<bool> returnPrimals, DerivativeMode mode, bool freeMemory,
size_t width, mlir::Type addedType, MFnTypeInfo type_args,
std::vector<bool> volatile_args, void *augmented, bool omp,
- llvm::StringRef postpasses) {
+ llvm::StringRef postpasses, bool verifyPostPasses) {
if (fn.getFunctionBody().empty()) {
llvm::errs() << fn << "\n";
llvm_unreachable("Differentiating empty function");
@@ -109,7 +109,7 @@
auto gutils = MDiffeGradientUtils::CreateFromClone(
*this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP,
RetActivity, ArgActivity, addedType,
- /*omp*/ false, postpasses);
+ /*omp*/ false, postpasses, verifyPostPasses);
ForwardCachedFunctions[tup] = gutils->newFunc;
insert_or_assign2<MForwardCacheKey, FunctionOpInterface>(
@@ -201,6 +201,7 @@
if (postpasses != "") {
mlir::PassManager pm(nf->getContext());
+ pm.enableVerifier(verifyPostPasses);
std::string error_message;
// llvm::raw_string_ostream error_stream(error_message);
mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm);
diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h
index d7480ca..814a598 100644
--- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h
+++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h
@@ -209,16 +209,16 @@
std::vector<bool> returnPrimals, DerivativeMode mode,
bool freeMemory, size_t width, mlir::Type addedType,
MFnTypeInfo type_args, std::vector<bool> volatile_args,
- void *augmented, bool omp, llvm::StringRef postpasses);
+ void *augmented, bool omp, llvm::StringRef postpasses,
+ bool verifyPostPasses);
- FunctionOpInterface
- CreateReverseDiff(FunctionOpInterface fn, std::vector<DIFFE_TYPE> retType,
- std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA,
- std::vector<bool> returnPrimals,
- std::vector<bool> returnShadows, DerivativeMode mode,
- bool freeMemory, size_t width, mlir::Type addedType,
- MFnTypeInfo type_args, std::vector<bool> volatile_args,
- void *augmented, bool omp, llvm::StringRef postpasses);
+ FunctionOpInterface CreateReverseDiff(
+ FunctionOpInterface fn, std::vector<DIFFE_TYPE> retType,
+ std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA,
+ std::vector<bool> returnPrimals, std::vector<bool> returnShadows,
+ DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType,
+ MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented,
+ bool omp, llvm::StringRef postpasses, bool verifyPostPasses);
void
initializeShadowValues(SmallVector<mlir::Block *> &dominatorToposortBlocks,
diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
index 18e6390..8574136 100644
--- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
+++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
@@ -186,7 +186,7 @@
std::vector<bool> returnPrimals, std::vector<bool> returnShadows,
DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType,
MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented,
- bool omp, llvm::StringRef postpasses) {
+ bool omp, llvm::StringRef postpasses, bool verifyPostPasses) {
if (fn.getFunctionBody().empty()) {
llvm::errs() << fn << "\n";
@@ -217,7 +217,7 @@
MGradientUtilsReverse *gutils = MGradientUtilsReverse::CreateFromClone(
*this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP,
- retType, constants, addedType, omp, postpasses);
+ retType, constants, addedType, omp, postpasses, verifyPostPasses);
ReverseCachedFunctions[tup] = gutils->newFunc;
@@ -259,6 +259,7 @@
if (postpasses != "") {
mlir::PassManager pm(nf->getContext());
+ pm.enableVerifier(verifyPostPasses);
std::string error_message;
// llvm::raw_string_ostream error_stream(error_message);
mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm);
diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
index 72a0b4c..a8efa96 100644
--- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
+++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
@@ -37,15 +37,17 @@
ArrayRef<DIFFE_TYPE> ReturnActivity, ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
IRMapping &originalToNewFn_,
std::map<Operation *, Operation *> &originalToNewFnOps_,
- DerivativeMode mode, unsigned width, bool omp, llvm::StringRef postpasses)
+ DerivativeMode mode, unsigned width, bool omp, llvm::StringRef postpasses,
+ bool verifyPostPasses)
: newFunc(newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_),
invertedPointers(invertedPointers_), originalToNewFn(originalToNewFn_),
originalToNewFnOps(originalToNewFnOps_), blocksNotForAnalysis(),
activityAnalyzer(std::make_unique<enzyme::ActivityAnalyzer>(
blocksNotForAnalysis, constantvalues_, activevals_, ReturnActivity)),
- TA(TA_), TR(TR_), omp(omp), postpasses(postpasses),
- returnPrimals(returnPrimals), returnShadows(returnShadows), width(width),
- ArgDiffeTypes(ArgDiffeTypes_), RetDiffeTypes(ReturnActivity) {}
+ TA(TA_), TR(TR_), omp(omp), verifyPostPasses(verifyPostPasses),
+ postpasses(postpasses), returnPrimals(returnPrimals),
+ returnShadows(returnShadows), width(width), ArgDiffeTypes(ArgDiffeTypes_),
+ RetDiffeTypes(ReturnActivity) {}
mlir::Value mlir::enzyme::MGradientUtils::getNewFromOriginal(
const mlir::Value originst) const {
diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h
index 3910ba7..853a22a 100644
--- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h
+++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h
@@ -36,6 +36,7 @@
MTypeAnalysis &TA;
MTypeResults TR;
bool omp;
+ bool verifyPostPasses;
llvm::StringRef postpasses;
const llvm::ArrayRef<bool> returnPrimals;
const llvm::ArrayRef<bool> returnShadows;
@@ -61,7 +62,7 @@
IRMapping &originalToNewFn_,
std::map<Operation *, Operation *> &originalToNewFnOps_,
DerivativeMode mode, unsigned width, bool omp,
- llvm::StringRef postpasses);
+ llvm::StringRef postpasses, bool verifyPostPasses);
void erase(Operation *op) { op->erase(); }
void replaceOrigOpWith(Operation *op, ValueRange vals) {
for (auto &&[res, rep] : llvm::zip(op->getResults(), vals)) {
@@ -127,21 +128,24 @@
ArrayRef<DIFFE_TYPE> ArgActivity, IRMapping &origToNew_,
std::map<Operation *, Operation *> &origToNewOps_,
DerivativeMode mode, unsigned width, bool omp,
- llvm::StringRef postpasses)
+ llvm::StringRef postpasses, bool verifyPostPasses)
: MGradientUtils(Logic, newFunc_, oldFunc_, TA, TR, invertedPointers_,
returnPrimals, returnShadows, constantvalues_,
activevals_, RetActivity, ArgActivity, origToNew_,
- origToNewOps_, mode, width, omp, postpasses),
+ origToNewOps_, mode, width, omp, postpasses,
+ verifyPostPasses),
initializationBlock(&*(newFunc.getFunctionBody().begin())) {}
// Technically diffe constructor
- static MDiffeGradientUtils *CreateFromClone(
- MEnzymeLogic &Logic, DerivativeMode mode, unsigned width,
- FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo,
- const llvm::ArrayRef<bool> returnPrimals,
- const llvm::ArrayRef<bool> returnShadows,
- ArrayRef<DIFFE_TYPE> RetActivity, ArrayRef<DIFFE_TYPE> ArgActivity,
- mlir::Type additionalArg, bool omp, llvm::StringRef postpasses) {
+ static MDiffeGradientUtils *
+ CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode, unsigned width,
+ FunctionOpInterface todiff, MTypeAnalysis &TA,
+ MFnTypeInfo &oldTypeInfo,
+ const llvm::ArrayRef<bool> returnPrimals,
+ const llvm::ArrayRef<bool> returnShadows,
+ ArrayRef<DIFFE_TYPE> RetActivity,
+ ArrayRef<DIFFE_TYPE> ArgActivity, mlir::Type additionalArg,
+ bool omp, llvm::StringRef postpasses, bool verifyPostPasses) {
std::string prefix;
switch (mode) {
@@ -178,7 +182,7 @@
Logic, newFunc, todiff, TA, TR, invertedPointers, returnPrimals,
returnShadows, constant_values, nonconstant_values, RetActivity,
ArgActivity, originalToNew, originalToNewOps, mode, width, omp,
- postpasses);
+ postpasses, verifyPostPasses);
}
};
diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp
index 6dc7aca..e770198 100644
--- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp
+++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp
@@ -37,12 +37,13 @@
ArrayRef<DIFFE_TYPE> ReturnActivity, ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
IRMapping &originalToNewFn_,
std::map<Operation *, Operation *> &originalToNewFnOps_,
- DerivativeMode mode_, unsigned width, bool omp, StringRef postpasses)
+ DerivativeMode mode_, unsigned width, bool omp, StringRef postpasses,
+ bool verifyPostPasses)
: MDiffeGradientUtils(Logic, newFunc_, oldFunc_, TA_, /*MTypeResults*/ {},
invertedPointers_, returnPrimals, returnShadows,
constantvalues_, activevals_, ReturnActivity,
ArgDiffeTypes_, originalToNewFn_, originalToNewFnOps_,
- mode_, width, omp, postpasses) {}
+ mode_, width, omp, postpasses, verifyPostPasses) {}
Type mlir::enzyme::MGradientUtilsReverse::getIndexCacheType() {
Type indexType = getIndexType();
@@ -138,7 +139,8 @@
FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo,
const ArrayRef<bool> returnPrimals, const ArrayRef<bool> returnShadows,
ArrayRef<DIFFE_TYPE> retType, ArrayRef<DIFFE_TYPE> constant_args,
- mlir::Type additionalArg, bool omp, llvm::StringRef postpasses) {
+ mlir::Type additionalArg, bool omp, llvm::StringRef postpasses,
+ bool verifyPostPasses) {
std::string prefix;
switch (mode_) {
@@ -175,5 +177,5 @@
Logic, newFunc, todiff, TA, invertedPointers, returnPrimals,
returnShadows, constant_values, nonconstant_values, retType,
constant_args, originalToNew, originalToNewOps, mode_, width, omp,
- postpasses);
+ postpasses, verifyPostPasses);
}
diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h
index a201301..88ad615 100644
--- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h
+++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h
@@ -37,7 +37,7 @@
IRMapping &originalToNewFn_,
std::map<Operation *, Operation *> &originalToNewFnOps_,
DerivativeMode mode_, unsigned width, bool omp,
- llvm::StringRef postpasses);
+ llvm::StringRef postpasses, bool verifyPostPasses);
IRMapping mapReverseModeBlocks;
@@ -71,7 +71,7 @@
const ArrayRef<bool> returnPrimals, const ArrayRef<bool> returnShadows,
llvm::ArrayRef<DIFFE_TYPE> retType,
llvm::ArrayRef<DIFFE_TYPE> constant_args, mlir::Type additionalArg,
- bool omp, llvm::StringRef postpasses);
+ bool omp, llvm::StringRef postpasses, bool verifyPostPasses);
};
} // namespace enzyme
diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
index 5a109c5..6ad1c76 100644
--- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
+++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
@@ -164,7 +164,7 @@
FunctionOpInterface newFunc = Logic.CreateForwardDiff(
fn, retType, constants, TA, returnPrimals, mode, freeMemory, width,
/*addedType*/ nullptr, type_args, volatile_args,
- /*augmented*/ nullptr, omp, postpasses);
+ /*augmented*/ nullptr, omp, postpasses, verifyPostPasses);
if (!newFunc)
return failure();
@@ -286,11 +286,11 @@
volatile_args.push_back(!(mode == DerivativeMode::ReverseModeCombined));
}
- FunctionOpInterface newFunc =
- Logic.CreateReverseDiff(fn, retType, arg_activities, TA, returnPrimals,
- returnShadows, mode, freeMemory, width,
- /*addedType*/ nullptr, type_args, volatile_args,
- /*augmented*/ nullptr, omp, postpasses);
+ FunctionOpInterface newFunc = Logic.CreateReverseDiff(
+ fn, retType, arg_activities, TA, returnPrimals, returnShadows, mode,
+ freeMemory, width,
+ /*addedType*/ nullptr, type_args, volatile_args,
+ /*augmented*/ nullptr, omp, postpasses, verifyPostPasses);
if (!newFunc)
return failure();
diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp
index 1d5c9c4..2f93651 100644
--- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp
+++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp
@@ -76,6 +76,7 @@
auto fn = cast<FunctionOpInterface>(symbolOp);
bool omp = false;
std::string postpasses = "";
+ bool verifyPostPasses = true;
std::vector<DIFFE_TYPE> ArgActivity =
parseActivityString(argTys.getValue());
@@ -119,17 +120,17 @@
FunctionOpInterface newFunc;
if (mode == DerivativeMode::ForwardMode) {
- newFunc = Logic.CreateForwardDiff(fn, RetActivity, ArgActivity, TA,
- returnPrimal, mode, freeMemory, width,
- /*addedType*/ nullptr, type_args,
- volatile_args,
- /*augmented*/ nullptr, omp, postpasses);
+ newFunc = Logic.CreateForwardDiff(
+ fn, RetActivity, ArgActivity, TA, returnPrimal, mode, freeMemory,
+ width,
+ /*addedType*/ nullptr, type_args, volatile_args,
+ /*augmented*/ nullptr, omp, postpasses, verifyPostPasses);
} else {
newFunc = Logic.CreateReverseDiff(
fn, RetActivity, ArgActivity, TA, returnPrimal, returnShadow, mode,
freeMemory, width,
/*addedType*/ nullptr, type_args, volatile_args,
- /*augmented*/ nullptr, omp, postpasses);
+ /*augmented*/ nullptr, omp, postpasses, verifyPostPasses);
}
if (!newFunc) {
signalPassFailure();
diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td
index ebe0013..8ef0dae 100644
--- a/enzyme/Enzyme/MLIR/Passes/Passes.td
+++ b/enzyme/Enzyme/MLIR/Passes/Passes.td
@@ -28,6 +28,13 @@
/*default=*/"",
/*description=*/"Optimization passes to apply to generated derivative functions"
>,
+ Option<
+ /*C++ variable name=*/"verifyPostPasses",
+ /*CLI argument=*/"verifyPostPasses",
+ /*type=*/"bool",
+ /*default=*/"true",
+ /*description=*/"Whether or not to run verifier for the postpasses pass manager."
+ >,
];
let constructor = "mlir::enzyme::createDifferentiatePass()";
}