Only throw fwd mode runtime activity blas error if something is runti⦠(#2318)
* Only throw fwd mode runtime activity blas error if something is runtime inactive
* fmt
diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp
index be6e34a..46cc44c 100644
--- a/enzyme/Enzyme/Utils.cpp
+++ b/enzyme/Enzyme/Utils.cpp
@@ -3659,11 +3659,12 @@
llvm::Value *EmitNoDerivativeError(const std::string &message,
llvm::Instruction &inst,
GradientUtils *gutils,
- llvm::IRBuilder<> &Builder2) {
+ llvm::IRBuilder<> &Builder2,
+ llvm::Value *condition) {
if (CustomErrorHandler) {
return unwrap(CustomErrorHandler(message.c_str(), wrap(&inst),
- ErrorType::NoDerivative, gutils, nullptr,
- wrap(&Builder2)));
+ ErrorType::NoDerivative, gutils,
+ wrap(condition), wrap(&Builder2)));
} else if (EnzymeRuntimeError) {
auto &M = *inst.getParent()->getParent()->getParent();
FunctionType *FT = FunctionType::get(Type::getInt32Ty(M.getContext()),
diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h
index 998b665..8538f55 100644
--- a/enzyme/Enzyme/Utils.h
+++ b/enzyme/Enzyme/Utils.h
@@ -209,7 +209,8 @@
struct RequestContext;
llvm::Value *EmitNoDerivativeError(const std::string &message,
llvm::Instruction &inst,
- GradientUtils *gutils, llvm::IRBuilder<> &B);
+ GradientUtils *gutils, llvm::IRBuilder<> &B,
+ llvm::Value *condition = nullptr);
bool EmitNoDerivativeError(const std::string &message, llvm::Value *todiff,
RequestContext &ctx);
diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
index ad2cc72..ef145b3 100644
--- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
+++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
@@ -370,7 +370,8 @@
<< " // returns true, or if runtimeActivity is on and the\n"
<< " // shadow points to the primal arg.\n";
- os << " if(gutils->runtimeActivity && cacheMode) {\n";
+ os << " if(gutils->runtimeActivity) {\n";
+ os << " Value *anyRuntimeActivity = nullptr;\n";
for (size_t i = 0; i < actArgs.size(); i++) {
auto name = nameVec[actArgs[i]];
@@ -385,6 +386,11 @@
<< " rt_inactive_" << name << " = BuilderZ.CreateICmpEQ(shadow_"
<< name << ", arg_" << name << ", \"rt.tmp.inactive.\" \"" << name
<< "\");\n"
+ << " if (Mode == DerivativeMode::ForwardMode || Mode == "
+ "DerivativeMode::ForwardModeSplit) anyRuntimeActivity = "
+ "anyRuntimeActivity ? BuilderZ.CreateOr(anyRuntimeActivity, "
+ "rt_inactive_"
+ << name << ") : rt_inactive_" << name << ";\n"
<< " }\n";
}
// Blas functions return one float XOR modify one output arg.
@@ -412,10 +418,27 @@
os << "active_" << name << ") {\n"
<< " rt_inactive_" << name << " = BuilderZ.CreateOr(rt_inactive_"
<< name << ", rt_inactive_out, \"rt.inactive.\" \"" << name << "\");\n"
+ << " if (Mode == DerivativeMode::ForwardMode || Mode == "
+ "DerivativeMode::ForwardModeSplit) anyRuntimeActivity = "
+ "anyRuntimeActivity ? BuilderZ.CreateOr(anyRuntimeActivity, "
+ "rt_inactive_"
+ << name << ") : rt_inactive_" << name << ";\n"
<< " }\n";
}
}
+ os << " if ((Mode == DerivativeMode::ForwardMode || Mode == "
+ "DerivativeMode::ForwardModeSplit) && anyRuntimeActivity) {\n"
+ << " std::string s;\n"
+ << " llvm::raw_string_ostream ss(s);\n"
+ << " ss << \"" << pattern.getName() << "\" << \"\\n\";\n"
+ << " ss << \"Runtime Activity not yet implemented for Forward-Mode "
+ "BLAS calls\" << "
+ "\"\\n\";\n"
+ << " EmitNoDerivativeError(ss.str(), call, gutils, BuilderZ, "
+ "anyRuntimeActivity);\n"
+ << " }\n";
+
os << " }\n";
bool hasFP = false;
@@ -2048,18 +2071,6 @@
<< " \n"
<< " auto callval = call.getCalledOperand(); \n\n";
- os << " if (gutils->runtimeActivity) {\n"
- << " std::string s;\n"
- << " llvm::raw_string_ostream ss(s);\n"
- << " ss << \"" << pattern.getName() << "\" << \"\\n\";\n"
- << " ss << call << \"\\n\";\n"
- << " ss << \"Runtime Activity not yet implemented for Forward-Mode "
- "BLAS calls\" << "
- "\"\\n\";\n"
- << " EmitNoDerivativeError(ss.str(), call, gutils, BuilderZ);\n"
- << " return false;\n"
- << " }\n";
-
// just make this const one available now to have less variable name repition
os << "Value * const_one = to_blas_callconv(Builder2, "
"ConstantInt::get(intType, 1), "