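WIP: support LLVM < 13 in the Clang plugin; add a per-function PreserveNVVM pass

- EnzymeClang.cpp: only set CodeGenOptions::ClearASTBeforeBackend on
  LLVM >= 13. On older LLVM, `enzyme_function_like` and `enzyme_inactive`
  are implemented by emitting used, private-extern
  `__enzyme_function_like_autoreg_<fn>`, `__enzyme_inactivefn_autoreg_<fn>`
  and `__enzyme_inactive_global_autoreg_<var>` variables instead of
  attaching an AnnotateAttr. Both attributes report an error when used in
  a templated (value-dependent) context.
- PreserveNVVM.cpp/.h: preserveNVVM() takes an optional `Function *SF`
  that restricts per-function processing to that function. Add the legacy
  FunctionPass `preserve-nvvm-function` and
  createPreserveNVVMFunctionPass(). Guard CodeGenerator use behind
  LLVM >= 13. Create a Sema explicitly before ParseAST, and include
  clang/Sema/Sema.h. Attach `enzyme_deallocator` inside the Function branch
  of the `__enzyme_allocation_like` handling.
- EnzymePassLoader.cpp: loadNVVMPass() now adds the function pass instead
  of the module pass.
- TypeAnalysis.cpp: guard CodeGenAction::getCodeGenerator() behind
  LLVM >= 13.

Open review items (verify before landing):
- The 3-argument ParseAST(Preprocessor&, ASTConsumer*, ASTContext&)
  overload appears to construct its own Sema internally. If so, the Sema
  installed with CI.setSema() is not the one used for this parse. Consider
  calling ParseAST(CI.getSema()) instead.
- The return value of preserveLinkage() no longer contributes to
  `changed`. Confirm that the pass still reports every IR modification.
- Under preserve-nvvm-function, globals are still pushed onto toErase. The
  legacy FunctionPass contract does not allow adding or removing module
  globals. In the visible hunk, the `__enzyme_allocation_like` branch also
  does not check SF, unlike the inactivefn and function_like branches.
- UnaryOperator::Create(..., FPOptionsOverride()) in the LLVM < 13
  `enzyme_inactive` path has no version guard. The neighbouring
  ImplicitCastExpr::Create calls are split at LLVM 12. Check this against
  the oldest supported LLVM.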
diff --git a/enzyme/Enzyme/Clang/EnzymeClang.cpp b/enzyme/Enzyme/Clang/EnzymeClang.cpp index f48ecc2..02c64e1 100644 --- a/enzyme/Enzyme/Clang/EnzymeClang.cpp +++ b/enzyme/Enzyme/Clang/EnzymeClang.cpp
@@ -76,7 +76,9 @@ public: EnzymePlugin(clang::CompilerInstance &CI) : CI(CI) { +#if LLVM_VERSION_MAJOR >= 13 CI.getCodeGenOpts().ClearASTBeforeBackend = false; +#endif FrontendOptions &Opts = CI.getFrontendOpts(); CodeGenOptions &CGOpts = CI.getCodeGenOpts(); auto PluginName = "ClangEnzyme-" + std::to_string(LLVM_VERSION_MAJOR); @@ -131,8 +133,10 @@ V->setInit(expr); S.MarkVariableReferenced(loc, V); S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V)); +#if LLVM_VERSION_MAJOR >= 13 CI.getCodeGenOpts().ClearASTBeforeBackend = false; assert(CI.getCodeGenOpts().ClearASTBeforeBackend == false); +#endif } bool HandleTopLevelDecl(clang::DeclGroupRef dg) override { using namespace clang; @@ -228,6 +232,7 @@ return AttributeNotApplied; } +#if LLVM_VERSION_MAJOR >= 13 SmallVector<Expr *, 16> ArgsBuf; ArgsBuf.push_back(Arg0); @@ -235,6 +240,94 @@ D->addAttr(AnnotateAttr::Create(S.Context, "enzyme_function_like", ArgsBuf.data(), ArgsBuf.size(), Attr.getRange())); +#else + auto &AST = S.getASTContext(); + auto FD = cast<FunctionDecl>(D); + DeclContext *declCtx = FD->getDeclContext(); + auto loc = FD->getLocation(); + RecordDecl *RD; + if (S.getLangOpts().CPlusPlus) + RD = CXXRecordDecl::Create(AST, clang::TagTypeKind::TTK_Struct, declCtx, + loc, loc, nullptr); // rId); + else + RD = RecordDecl::Create(AST, clang::TagTypeKind::TTK_Struct, declCtx, loc, + loc, nullptr); // rId); + RD->setAnonymousStructOrUnion(true); + RD->setImplicit(); + RD->startDefinition(); + auto Tinfo = nullptr; + auto Tinfo0 = nullptr; + auto FT = AST.getPointerType(FD->getType()); + auto CharTy = AST.getIntTypeForBitwidth(8, false); + auto FD0 = FieldDecl::Create(AST, RD, loc, loc, /*Ud*/ nullptr, FT, Tinfo0, + /*expr*/ nullptr, /*mutable*/ true, + /*inclassinit*/ ICIS_NoInit); + FD0->setAccess(AS_public); + RD->addDecl(FD0); + auto FD1 = FieldDecl::Create( + AST, RD, loc, loc, /*Ud*/ nullptr, AST.getPointerType(CharTy), Tinfo0, + /*expr*/ nullptr, /*mutable*/ true, /*inclassinit*/ ICIS_NoInit); + 
FD1->setAccess(AS_public); + RD->addDecl(FD1); + RD->completeDefinition(); + assert(RD->getDefinition()); + auto &Id = AST.Idents.get("__enzyme_function_like_autoreg_" + + FD->getNameAsString()); + auto T = AST.getRecordType(RD); + auto V = VarDecl::Create(AST, declCtx, loc, loc, &Id, T, Tinfo, SC_None); + V->setStorageClass(SC_PrivateExtern); + V->addAttr(clang::UsedAttr::CreateImplicit(AST)); + TemplateArgumentListInfo *TemplateArgs = nullptr; + auto DR = DeclRefExpr::Create(AST, NestedNameSpecifierLoc(), loc, FD, false, + loc, FD->getType(), ExprValueKind::VK_LValue, + FD, TemplateArgs); +#if LLVM_VERSION_MAJOR >= 13 + auto rval = ExprValueKind::VK_PRValue; +#else + auto rval = ExprValueKind::VK_RValue; +#endif +#if LLVM_VERSION_MAJOR >= 15 + auto stringkind = clang::StringLiteral::StringKind::Ordinary; +#else + auto stringkind = clang::StringLiteral::StringKind::Ascii; +#endif + StringRef cstr = Literal->getString(); + Expr *exprs[2] = { +#if LLVM_VERSION_MAJOR >= 12 + ImplicitCastExpr::Create(AST, FT, CastKind::CK_FunctionToPointerDecay, DR, + nullptr, rval, FPOptionsOverride()), + ImplicitCastExpr::Create( + AST, AST.getPointerType(CharTy), CastKind::CK_ArrayToPointerDecay, + StringLiteral::Create( + AST, cstr, stringkind, + /*Pascal*/ false, + AST.getStringLiteralArrayType(CharTy, cstr.size()), loc), + nullptr, rval, FPOptionsOverride()) +#else + ImplicitCastExpr::Create(AST, FT, CastKind::CK_FunctionToPointerDecay, DR, + nullptr, rval), + ImplicitCastExpr::Create( + AST, AST.getPointerType(CharTy), CastKind::CK_ArrayToPointerDecay, + StringLiteral::Create( + AST, cstr, stringkind, + /*Pascal*/ false, + AST.getStringLiteralArrayType(CharTy, cstr.size()), loc), + nullptr, rval) +#endif + }; + auto IL = new (AST) InitListExpr(AST, loc, exprs, loc); + V->setInit(IL); + IL->setType(T); + if (IL->isValueDependent()) { + unsigned ID = S.getDiagnostics().getCustomDiagID( + DiagnosticsEngine::Error, "use of attribute 'enzyme_function_like' " + "in a templated 
context not yet supported"); + S.Diag(Attr.getLoc(), ID); + return AttributeNotApplied; + } + S.MarkVariableReferenced(loc, V); + S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V)); +#endif return AttributeApplied; } }; @@ -279,9 +372,72 @@ return AttributeNotApplied; } +#if LLVM_VERSION_MAJOR >= 13 // Attach an annotate attribute to the Decl. D->addAttr(AnnotateAttr::Create(S.Context, "enzyme_inactive", nullptr, 0, Attr.getRange())); +#else + auto &AST = S.getASTContext(); + DeclContext *declCtx = D->getDeclContext(); + auto loc = D->getLocation(); + RecordDecl *RD; + if (S.getLangOpts().CPlusPlus) + RD = CXXRecordDecl::Create(AST, clang::TagTypeKind::TTK_Struct, declCtx, + loc, loc, nullptr); // rId); + else + RD = RecordDecl::Create(AST, clang::TagTypeKind::TTK_Struct, declCtx, loc, + loc, nullptr); // rId); + RD->setAnonymousStructOrUnion(true); + RD->setImplicit(); + RD->startDefinition(); + auto T = isa<FunctionDecl>(D) ? cast<FunctionDecl>(D)->getType() + : cast<VarDecl>(D)->getType(); + auto Name = isa<FunctionDecl>(D) ? cast<FunctionDecl>(D)->getNameAsString() + : cast<VarDecl>(D)->getNameAsString(); + auto FT = AST.getPointerType(T); + auto subname = isa<FunctionDecl>(D) ? 
"inactivefn" : "inactive_global"; + auto &Id = AST.Idents.get( + (StringRef("__enzyme_") + subname + "_autoreg_" + Name).str()); + auto V = VarDecl::Create(AST, declCtx, loc, loc, &Id, FT, nullptr, SC_None); + V->setStorageClass(SC_PrivateExtern); + V->addAttr(clang::UsedAttr::CreateImplicit(AST)); + TemplateArgumentListInfo *TemplateArgs = nullptr; + auto DR = DeclRefExpr::Create( + AST, NestedNameSpecifierLoc(), loc, cast<ValueDecl>(D), false, loc, T, + ExprValueKind::VK_LValue, cast<NamedDecl>(D), TemplateArgs); +#if LLVM_VERSION_MAJOR >= 13 + auto rval = ExprValueKind::VK_PRValue; +#else + auto rval = ExprValueKind::VK_RValue; +#endif + Expr *expr = nullptr; + if (isa<FunctionDecl>(D)) { +#if LLVM_VERSION_MAJOR >= 12 + expr = + ImplicitCastExpr::Create(AST, FT, CastKind::CK_FunctionToPointerDecay, + DR, nullptr, rval, FPOptionsOverride()); +#else + expr = ImplicitCastExpr::Create( + AST, FT, CastKind::CK_FunctionToPointerDecay, DR, nullptr, rval); +#endif + } else { + expr = + UnaryOperator::Create(AST, DR, UnaryOperatorKind::UO_AddrOf, FT, rval, + clang::ExprObjectKind ::OK_Ordinary, loc, + /*canoverflow*/ false, FPOptionsOverride()); + } + + if (expr->isValueDependent()) { + unsigned ID = S.getDiagnostics().getCustomDiagID( + DiagnosticsEngine::Error, "use of attribute 'enzyme_inactive' " + "in a templated context not yet supported"); + S.Diag(Attr.getLoc(), ID); + return AttributeNotApplied; + } + V->setInit(expr); + S.MarkVariableReferenced(loc, V); + S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V)); +#endif return AttributeApplied; } };
diff --git a/enzyme/Enzyme/Clang/EnzymePassLoader.cpp b/enzyme/Enzyme/Clang/EnzymePassLoader.cpp index b77608f..b7d0c07 100644 --- a/enzyme/Enzyme/Clang/EnzymePassLoader.cpp +++ b/enzyme/Enzyme/Clang/EnzymePassLoader.cpp
@@ -55,7 +55,7 @@ static void loadNVVMPass(const PassManagerBuilder &Builder, legacy::PassManagerBase &PM) { - PM.add(createPreserveNVVMPass(/*Begin=*/true)); + PM.add(createPreserveNVVMFunctionPass(/*Begin=*/true)); } // These constructors add our pass to a list of global extensions.
diff --git a/enzyme/Enzyme/PreserveNVVM.cpp b/enzyme/Enzyme/PreserveNVVM.cpp index fc4ab39..4e575a6 100644 --- a/enzyme/Enzyme/PreserveNVVM.cpp +++ b/enzyme/Enzyme/PreserveNVVM.cpp
@@ -50,6 +50,7 @@ #include "clang/Frontend/CompilerInstance.h" #include "clang/Lex/Preprocessor.h" #include "clang/CodeGen/ModuleBuilder.h" +#include "clang/Sema/Sema.h" #endif #include <map> @@ -310,7 +311,7 @@ return false; } -bool preserveNVVM(bool Begin, Module &M) { +bool preserveNVVM(bool Begin, Module &M, llvm::Function *SF) { bool changed = false; StringMap<std::pair<std::string, std::string>> Implements; for (std::string T : {"", "f"}) { @@ -343,6 +344,7 @@ } } for (auto &F : M) { + if (SF && &F != SF) continue; auto found = Implements.find(F.getName()); if (found != Implements.end()) { changed = true; @@ -435,24 +437,30 @@ Act->setCompilerInstance(&CI); auto cons = Act->CreateASTConsumer(CI, /*InFile*/"<tmpfile>"); +#if LLVM_VERSION_MAJOR >= 13 auto Gen = Act->getCodeGenerator(); assert(Gen); - cons->Initialize(CI.getASTContext()); +#endif auto oldPP = CI.getPreprocessorPtr(); CI.createPreprocessor(clang::TU_Complete); + cons->Initialize(CI.getASTContext()); + CI.setSema(new clang::Sema(CI.getPreprocessor(), CI.getASTContext(), *cons, + clang::TU_Complete, nullptr)); // Note that right here we reparse/recompile all of llvm things clang::ParseAST(CI.getPreprocessor(), cons.get(), CI.getASTContext()); CI.setPreprocessor(oldPP); +#if LLVM_VERSION_MAJOR >= 13 assert(&Gen->CGM()); assert(Gen->GetModule()); // auto &Ty = Gen->CGM().getTypes(); - +#endif for (auto &f : M) { f.addFnAttr("clang_compiler_instance", std::to_string((size_t)(void*)&CI)); if (f.empty()) continue; +#if LLVM_VERSION_MAJOR >= 13 if (auto FD = Gen->GetDeclForMangledName(f.getName())) { f.addFnAttr("clang_decl", std::to_string((size_t)(void*)FD)); f.addFnAttr("clang_codegen", std::to_string((size_t)(void*)Act)); @@ -470,9 +478,10 @@ } } } - // auto CGI = arrangeGlobalDeclaration(FD); } +#endif } +#if LLVM_VERSION_MAJOR >= 13 for (auto &g : M.globals()) { if (auto FD = Gen->GetDeclForMangledName(g.getName())) { for (auto attr : FD->getAttrs()) { @@ -484,6 +493,7 @@ } } } +#endif 
toErase.push_back(&g); changed = true; } @@ -531,10 +541,12 @@ break; } if (auto F = cast<Function>(V)) { + if (!SF || F == SF) { F->addAttribute(AttributeList::FunctionIndex, Attribute::get(g.getContext(), "enzyme_inactive")); toErase.push_back(&g); changed = true; + } } else { llvm::errs() << "Param of __enzyme_inactivefn must be a " "constant function" @@ -581,11 +593,13 @@ llvm_unreachable("enzyme_function_like"); } if (auto F = cast<Function>(V)) { + if (!SF || F == SF) { F->addAttribute( AttributeList::FunctionIndex, Attribute::get(g.getContext(), "enzyme_math", nameVal)); toErase.push_back(&g); changed = true; + } } else { llvm::errs() << "Param of __enzyme_function_like must be a " "constant function" @@ -661,6 +675,10 @@ F->addAttribute(AttributeList::FunctionIndex, Attribute::get(g.getContext(), "enzyme_allocator", std::to_string(index))); + F->addAttribute(AttributeList::FunctionIndex, + Attribute::get(g.getContext(), + "enzyme_deallocator", + deallocIndStr)); } else { llvm::errs() << "Param of __enzyme_allocation_like must be a " "function" @@ -668,17 +686,13 @@ << *V << "\n"; llvm_unreachable("__enzyme_allocation_like"); } - cast<Function>(V)->addAttribute(AttributeList::FunctionIndex, - Attribute::get(g.getContext(), - "enzyme_deallocator", - deallocIndStr)); if (auto F = dyn_cast<Function>(deallocfn)) { cast<Function>(V)->setMetadata( "enzyme_deallocator_fn", llvm::MDTuple::get(F->getContext(), {llvm::ValueAsMetadata::get(F)})); - changed |= preserveLinkage(Begin, *F); + preserveLinkage(Begin, *F); } else { llvm::errs() << "Free fn of __enzyme_allocation_like must be a " "function" @@ -747,6 +761,7 @@ #endif if (!Begin) { for (auto &F : M) { + if (SF && &F != SF) continue; if (F.hasFnAttribute("prev_fixup")) { changed = true; F.removeFnAttr("prev_fixup"); @@ -783,7 +798,7 @@ PreserveNVVM(bool Begin = true) : ModulePass(ID), Begin(Begin) {} void getAnalysisUsage(AnalysisUsage &AU) const override {} - bool runOnModule(Module &M) override { return 
preserveNVVM(Begin, M); } + bool runOnModule(Module &M) override { return preserveNVVM(Begin, M, nullptr); } }; } // namespace @@ -796,6 +811,27 @@ return new PreserveNVVM(Begin); } +namespace { + +class PreserveNVVMFunc final : public FunctionPass { +public: + static char ID; + bool Begin; + PreserveNVVMFunc(bool Begin = true) : FunctionPass(ID), Begin(Begin) {} + + void getAnalysisUsage(AnalysisUsage &AU) const override {} + bool runOnFunction(Function &F) override { return preserveNVVM(Begin, *F.getParent(), &F); } +}; + +} // namespace + +char PreserveNVVMFunc::ID = 0; + +static RegisterPass<PreserveNVVMFunc> X2("preserve-nvvm-function", "Preserve NVVM Pass"); +FunctionPass *createPreserveNVVMFunctionPass(bool Begin) { + return new PreserveNVVMFunc(Begin); +} + #include <llvm-c/Core.h> #include <llvm-c/Types.h> @@ -807,7 +843,7 @@ PreserveNVVMNewPM::Result PreserveNVVMNewPM::run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) { - bool changed = preserveNVVM(Begin, M); + bool changed = preserveNVVM(Begin, M, nullptr); return changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); } llvm::AnalysisKey PreserveNVVMNewPM::Key;
diff --git a/enzyme/Enzyme/PreserveNVVM.h b/enzyme/Enzyme/PreserveNVVM.h index 5db6d0c..0dc2a0d 100644 --- a/enzyme/Enzyme/PreserveNVVM.h +++ b/enzyme/Enzyme/PreserveNVVM.h
@@ -32,6 +32,7 @@ } llvm::ModulePass *createPreserveNVVMPass(bool Begin); +llvm::FunctionPass *createPreserveNVVMFunctionPass(bool Begin); class PreserveNVVMNewPM final : public llvm::AnalysisInfoMixin<PreserveNVVMNewPM> {
diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 721a200..eaaa284 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
@@ -207,7 +207,9 @@ auto err2 = F->getFnAttribute("clang_codegen").getValueAsString().getAsInteger(10, cgint); assert(!err2); llvm::errs() << " cgint" << cgint << "\n"; +#if LLVM_VERSION_MAJOR >= 13 auto Gen = ((clang::CodeGenAction *)(void*)cgint)->getCodeGenerator(); +#endif FD->dump(); } #endif