Add activity, attribute, no-free and type-analysis handling for MPI_Type_commit and PMPI_Type_create_subarray
diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp
index a976a9a..0bd6755 100644
--- a/enzyme/Enzyme/ActivityAnalysis.cpp
+++ b/enzyme/Enzyme/ActivityAnalysis.cpp
@@ -254,6 +254,11 @@
"PMPI_Comm_size",
"MPI_Comm_rank",
"PMPI_Comm_rank",
+ // MPI datatype construction: list both the MPI_ and profiling PMPI_ names.
+ "MPI_Type_commit",
+ "PMPI_Type_commit",
+ "MPI_Type_create_subarray",
+ "PMPI_Type_create_subarray",
"MPI_Get_processor_name",
"MPI_Finalize",
"MPI_Test",
diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp
index 6e6e018..b0fcd77 100644
--- a/enzyme/Enzyme/Enzyme.cpp
+++ b/enzyme/Enzyme/Enzyme.cpp
@@ -486,6 +486,39 @@
F.addParamAttr(1, Attribute::NoCapture);
}
}
+ if (F.getName() == "MPI_Type_commit" || F.getName() == "PMPI_Type_commit") {
+ F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
+ F.addFnAttr(Attribute::NoUnwind);
+ F.addFnAttr(Attribute::NoRecurse);
+#if LLVM_VERSION_MAJOR >= 9
+ F.addFnAttr(Attribute::WillReturn);
+ F.addFnAttr(Attribute::NoFree);
+ F.addFnAttr(Attribute::NoSync);
+#endif
+ // The only parameter is the MPI_Datatype* being committed.
+ F.addParamAttr(0, Attribute::NoCapture);
+ }
+ if (F.getName() == "MPI_Type_create_subarray" ||
+ F.getName() == "PMPI_Type_create_subarray") {
+ F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
+ F.addFnAttr(Attribute::NoUnwind);
+ F.addFnAttr(Attribute::NoRecurse);
+#if LLVM_VERSION_MAJOR >= 9
+ F.addFnAttr(Attribute::WillReturn);
+ F.addFnAttr(Attribute::NoFree);
+ F.addFnAttr(Attribute::NoSync);
+#endif
+ // Arguments: (int ndims, const int sizes[], const int subsizes[],
+ // const int starts[], int order, MPI_Datatype oldtype, MPI_Datatype *newtype).
+ // nocapture is only valid on pointer parameters, so ndims (0) and order (4)
+ // must not carry it. oldtype (5) is an implementation-defined handle (an int
+ // in MPICH, a pointer in Open MPI) that the new datatype may retain, so it is
+ // left unannotated too. Only the three int arrays and newtype get nocapture.
+ F.addParamAttr(1, Attribute::NoCapture);
+ F.addParamAttr(2, Attribute::NoCapture);
+ F.addParamAttr(3, Attribute::NoCapture);
+ F.addParamAttr(6, Attribute::NoCapture);
+ }
if (F.getName() == "MPI_Wait" || F.getName() == "PMPI_Wait") {
F.addFnAttr(Attribute::NoUnwind);
F.addFnAttr(Attribute::NoRecurse);
diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp
index 94eea72..b717381 100644
--- a/enzyme/Enzyme/EnzymeLogic.cpp
+++ b/enzyme/Enzyme/EnzymeLogic.cpp
@@ -4878,7 +4878,12 @@
"_ZNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEE9underflowEv",
"_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE9_M_assignERKS4_",
"_ZNSaIcED1Ev",
- "_ZNSaIcEC1Ev"};
+ "_ZNSaIcEC1Ev",
+ // MPI datatype construction calls (both MPI_ and profiling PMPI_ names).
+ "MPI_Type_commit",
+ "PMPI_Type_commit",
+ "MPI_Type_create_subarray",
+ "PMPI_Type_create_subarray"};
if (F->getName().startswith("_ZNSolsE") || NoFrees.count(F->getName().str()))
return F;
diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
index ba2be99..53d0e6e 100644
--- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
+++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
@@ -3762,6 +3762,33 @@
updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
return;
}
+ if (funcName == "MPI_Type_commit" || funcName == "PMPI_Type_commit") {
+ // MPI_Type_commit(MPI_Datatype *datatype) takes a single argument, which is
+ // operand 0; operand 1 of this call is the callee itself. MPI_Datatype is
+ // implementation-defined (an int in MPICH, a pointer in Open MPI), so only
+ // the argument's pointer-ness is asserted, not the type of its pointee.
+ updateAnalysis(call.getOperand(0), TypeTree(BaseType::Pointer).Only(-1, &call), &call);
+ updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
+ return;
+ }
+ if (funcName == "MPI_Type_create_subarray" ||
+ funcName == "PMPI_Type_create_subarray") {
+ // Arguments: (int ndims, const int sizes[], const int subsizes[],
+ // const int starts[], int order, MPI_Datatype oldtype, MPI_Datatype *newtype).
+ TypeTree ptrint;
+ ptrint.insert({-1}, BaseType::Pointer);
+ ptrint.insert({-1, 0}, BaseType::Integer);
+ updateAnalysis(call.getOperand(0), TypeTree(BaseType::Integer).Only(-1, &call), &call);
+ updateAnalysis(call.getOperand(1), ptrint, &call);
+ updateAnalysis(call.getOperand(2), ptrint, &call);
+ updateAnalysis(call.getOperand(3), ptrint, &call);
+ updateAnalysis(call.getOperand(4), TypeTree(BaseType::Integer).Only(-1, &call), &call);
+ // oldtype (5) has an implementation-defined representation and is left
+ // untyped; newtype (6) is only known to be a pointer.
+ updateAnalysis(call.getOperand(6), TypeTree(BaseType::Pointer).Only(-1, &call), &call);
+ updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
+ return;
+ }
if (funcName == "MPI_Barrier" || funcName == "MPI_Finalize") {
updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
return;