Add MPI_Type_commit and PMPI_Type_create_subarray
diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp
index a976a9a..0bd6755 100644
--- a/enzyme/Enzyme/ActivityAnalysis.cpp
+++ b/enzyme/Enzyme/ActivityAnalysis.cpp
@@ -254,6 +254,11 @@
     "PMPI_Comm_size",
     "MPI_Comm_rank",
     "PMPI_Comm_rank",
+    // MPI derived-datatype routines (both MPI_ and PMPI_ entry points).
+    "MPI_Type_commit",
+    "PMPI_Type_commit",
+    "MPI_Type_create_subarray",
+    "PMPI_Type_create_subarray",
     "MPI_Get_processor_name",
     "MPI_Finalize",
     "MPI_Test",
diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp
index 6e6e018..b0fcd77 100644
--- a/enzyme/Enzyme/Enzyme.cpp
+++ b/enzyme/Enzyme/Enzyme.cpp
@@ -486,6 +486,45 @@
       F.addParamAttr(1, Attribute::NoCapture);
     }
   }
+  if (F.getName() == "MPI_Type_commit" || F.getName() == "PMPI_Type_commit") {
+    F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
+    F.addFnAttr(Attribute::NoUnwind);
+    F.addFnAttr(Attribute::NoRecurse);
+#if LLVM_VERSION_MAJOR >= 9
+    F.addFnAttr(Attribute::WillReturn);
+    F.addFnAttr(Attribute::NoFree);
+    F.addFnAttr(Attribute::NoSync);
+#endif
+    // int MPI_Type_commit(MPI_Datatype *datatype). nocapture is only valid on
+    // pointer parameters, so check the declared type before adding it.
+    if (F.getFunctionType()->getNumParams() > 0 &&
+        F.getFunctionType()->getParamType(0)->isPointerTy())
+      F.addParamAttr(0, Attribute::NoCapture);
+  }
+  if (F.getName() == "MPI_Type_create_subarray" ||
+      F.getName() == "PMPI_Type_create_subarray") {
+    F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
+    F.addFnAttr(Attribute::NoUnwind);
+    F.addFnAttr(Attribute::NoRecurse);
+#if LLVM_VERSION_MAJOR >= 9
+    F.addFnAttr(Attribute::WillReturn);
+    F.addFnAttr(Attribute::NoFree);
+    F.addFnAttr(Attribute::NoSync);
+#endif
+    // int MPI_Type_create_subarray(int ndims, const int sizes[],
+    //                              const int subsizes[], const int starts[],
+    //                              int order, MPI_Datatype oldtype,
+    //                              MPI_Datatype *newtype)
+    // nocapture is only valid on pointer parameters, so ndims (0) and order
+    // (4) must not get it. oldtype (5) is also skipped: MPI_Datatype is an
+    // integer in some MPI ABIs, and where it is a pointer the new datatype
+    // may keep a reference to it.
+    const unsigned NoCaptureArgs[] = {1, 2, 3, 6};
+    for (unsigned Idx : NoCaptureArgs)
+      if (Idx < F.getFunctionType()->getNumParams() &&
+          F.getFunctionType()->getParamType(Idx)->isPointerTy())
+        F.addParamAttr(Idx, Attribute::NoCapture);
+  }
   if (F.getName() == "MPI_Wait" || F.getName() == "PMPI_Wait") {
     F.addFnAttr(Attribute::NoUnwind);
     F.addFnAttr(Attribute::NoRecurse);
diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp
index 94eea72..b717381 100644
--- a/enzyme/Enzyme/EnzymeLogic.cpp
+++ b/enzyme/Enzyme/EnzymeLogic.cpp
@@ -4878,7 +4878,12 @@
       "_ZNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEE9underflowEv",
       "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE9_M_assignERKS4_",
       "_ZNSaIcED1Ev",
-      "_ZNSaIcEC1Ev"};
+      "_ZNSaIcEC1Ev",
+      // MPI derived-datatype routines (both MPI_ and PMPI_ entry points).
+      "MPI_Type_commit",
+      "PMPI_Type_commit",
+      "MPI_Type_create_subarray",
+      "PMPI_Type_create_subarray"};
   if (F->getName().startswith("_ZNSolsE") ||
       NoFrees.count(F->getName().str()))
     return F;
diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
index ba2be99..53d0e6e 100644
--- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
+++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
@@ -3762,6 +3762,37 @@
       updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
       return;
     }
+    if (funcName == "MPI_Type_commit" || funcName == "PMPI_Type_commit") {
+      // int MPI_Type_commit(MPI_Datatype *datatype). MPI_Datatype is an
+      // integer in some MPI ABIs and a pointer in others, so only the
+      // argument itself is known to be a pointer; its pointee is left open.
+      updateAnalysis(call.getOperand(0),
+                     TypeTree(BaseType::Pointer).Only(-1, &call), &call);
+      updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
+      return;
+    }
+    if (funcName == "MPI_Type_create_subarray" ||
+        funcName == "PMPI_Type_create_subarray") {
+      // int MPI_Type_create_subarray(int ndims, const int sizes[],
+      //                              const int subsizes[], const int starts[],
+      //                              int order, MPI_Datatype oldtype,
+      //                              MPI_Datatype *newtype)
+      TypeTree intTy = TypeTree(BaseType::Integer).Only(-1, &call);
+      TypeTree ptrint;
+      ptrint.insert({-1}, BaseType::Pointer);
+      ptrint.insert({-1, 0}, BaseType::Integer);
+      updateAnalysis(call.getOperand(0), intTy, &call);
+      updateAnalysis(call.getOperand(1), ptrint, &call);
+      updateAnalysis(call.getOperand(2), ptrint, &call);
+      updateAnalysis(call.getOperand(3), ptrint, &call);
+      updateAnalysis(call.getOperand(4), intTy, &call);
+      // oldtype (5) is ABI-dependent and is left untyped; newtype (6) is
+      // always a pointer, but its pointee depends on the ABI as well.
+      updateAnalysis(call.getOperand(6),
+                     TypeTree(BaseType::Pointer).Only(-1, &call), &call);
+      updateAnalysis(&call, intTy, &call);
+      return;
+    }
     if (funcName == "MPI_Barrier" || funcName == "MPI_Finalize") {
       updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
       return;