blob: 70a220a9ea33e6d4e8778533237e84fbd7f103fb [file] [log] [blame] [edit]
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include <map>
#include <set>
#include <string>
using namespace llvm;
#define DATA EnzymeBlasBC
#include "blas_headers.h"
#undef DATA
/// Return true when \p string ends with \p suffix.
///
/// Hides the StringRef method rename between LLVM releases: `endswith` is
/// called when LLVM_VERSION_MAJOR is below 18, `ends_with` otherwise.
static inline bool endsWith(llvm::StringRef string, llvm::StringRef suffix) {
#if LLVM_VERSION_MAJOR < 18
  const bool hasSuffix = string.endswith(suffix);
#else
  const bool hasSuffix = string.ends_with(suffix);
#endif // LLVM_VERSION_MAJOR
  return hasSuffix;
}
/// Return the name under which \p called is looked up by this file.
///
/// Precedence:
///  1. the string value of the "enzyme_math" function attribute, if present;
///  2. the literal "enzyme_allocator" if that attribute is present;
///  3. the function's own symbol name.
static inline llvm::StringRef getFuncName(llvm::Function *called) {
  if (called->hasFnAttribute("enzyme_math")) {
    auto mathAttr = called->getFnAttribute("enzyme_math");
    return mathAttr.getValueAsString();
  }

  if (called->hasFnAttribute("enzyme_allocator"))
    return "enzyme_allocator";

  return called->getName();
}
/// Link bitcode definitions from EnzymeBlasBC into \p M for the functions that
/// \p M only declares.
///
/// Every function of \p M without a body is looked up in EnzymeBlasBC under
/// the name returned by getFuncName(): first verbatim, then -- if that name
/// ends in "_" or "_64_" -- as "cblas_" followed by the name with the suffix
/// removed. The bitcode of each hit is parsed using M's data layout and linked
/// into M. Functions defined by a linked module (and not listed in
/// \p ignoreFunctions) are afterwards given internal linkage and marked
/// alwaysinline.
///
/// \param M               module that is extended in place.
/// \param ignoreFunctions names that are skipped when scanning \p M and whose
///                        bodies are stripped from the bitcode before linking.
/// \param replaced        receives the getFuncName() name of every function of
///                        \p M for which a bitcode entry was found.
/// \returns true if \p M may have been modified.
bool provideDefinitions(Module &M, std::set<std::string> ignoreFunctions,
                        std::vector<std::string> &replaced) {
  std::vector<StringRef> todo;
  bool seen32 = false;
  bool seen64 = false;
  std::vector<std::pair<StringRef, llvm::Function *>> name_rewrites;

  for (auto &F : M) {
    if (!F.empty())
      continue;
    auto name = getFuncName(&F);
    if (ignoreFunctions.count(name.str()))
      continue;

    // index 0: exact name, 1: "_" suffix, 2: "_64_" suffix.
    int index = -1;
    for (auto postfix : {"", "_", "_64_"}) {
      ++index;
      const size_t postfixLen = strlen(postfix);
      std::string str;
      if (postfixLen == 0) {
        str = name.str();
      } else if (endsWith(name, postfix)) {
        auto blasName = name.substr(0, name.size() - postfixLen).str();
        str = "cblas_" + blasName;
      } else {
        // The suffix does not apply, so there is no key to look up.
        continue;
      }

      auto found = EnzymeBlasBC.find(str);
      if (found == EnzymeBlasBC.end())
        continue;

      replaced.push_back(name.str());
      todo.push_back(found->second);
      // The symbol name differs from the lookup name (attribute-provided
      // name): remember it so a forwarding body can be emitted below.
      if (name != F.getName())
        name_rewrites.emplace_back(name, &F);
      if (index == 1)
        seen32 = true;
      if (index == 2)
        seen64 = true;
      break;
    }
  }

  // Give every renamed declaration a body that forwards all of its arguments
  // to a function carrying the lookup name.
  for (auto &&[realname, F] : name_rewrites) {
    auto decl = M.getOrInsertFunction(realname, F->getFunctionType());
    auto entry = BasicBlock::Create(F->getContext(), "entry", F);
    IRBuilder<> B(entry);
    SmallVector<Value *, 1> vals;
    for (auto &arg : F->args())
      vals.push_back(&arg);
    auto rt = B.CreateCall(decl, vals);
    if (rt->getType()->isVoidTy())
      B.CreateRetVoid();
    else
      B.CreateRet(rt);
  }

  // Push fortran wrapper libs before all the other blas
  // to ensure the fortran injections have their code
  // replaced
  if (seen32)
    todo.insert(todo.begin(), __data_fblas32);
  if (seen64)
    todo.insert(todo.begin(), __data_fblas64);

  // The forwarding bodies created above already modified M, independent of
  // whether any bitcode module is linked successfully below.
  bool changed = !name_rewrites.empty();

  for (auto mod : todo) {
    SMDiagnostic Err;
    MemoryBufferRef buf(mod, StringRef("bcloader"));

#if LLVM_VERSION_MAJOR >= 16
    auto BC = llvm::parseIR(buf, Err, M.getContext(),
                            llvm::ParserCallbacks([&](StringRef, StringRef) {
                              return std::optional<std::string>(
                                  M.getDataLayout().getStringRepresentation());
                            }));
#else
    auto BC = llvm::parseIR(buf, Err, M.getContext(), [&](StringRef) {
      return Optional<std::string>(M.getDataLayout().getStringRepresentation());
    });
#endif

    if (!BC) {
      Err.print("bcloader", llvm::errs());
      continue;
    }

    // Collect the functions this module defines; strip the bodies of ignored
    // ones so that only their declarations are linked.
    SmallVector<std::string, 1> toReplace;
    for (auto &F : *BC) {
      if (F.empty())
        continue;
      auto name = getFuncName(&F);
      if (ignoreFunctions.count(name.str())) {
        F.dropAllReferences();
#if LLVM_VERSION_MAJOR >= 16
        F.erase(F.begin(), F.end());
#else
        F.getBasicBlockList().erase(F.begin(), F.end());
#endif
        continue;
      }
      toReplace.push_back(name.str());
    }

    BC->setTargetTriple("");
    Linker L(M);
    // Report a modification even when linking fails: it is not visible from
    // here whether a failed link leaves M untouched.
    changed = true;
    // linkInModule returns true on error.
    if (L.linkInModule(std::move(BC))) {
      llvm::errs() << "bcloader: failed to link bitcode module\n";
      continue;
    }

    for (const auto &name : toReplace) {
      auto F = M.getFunction(name);
      // Only definitions may be internalized; a declaration with internal
      // linkage is invalid IR.
      if (!F || F->empty())
        continue;
      F->setLinkage(Function::LinkageTypes::InternalLinkage);
      F->addFnAttr(Attribute::AlwaysInline);
    }
  }
  return changed;
}
extern "C" {
/// C entry point wrapping provideDefinitions().
///
/// \param M                 module handed to provideDefinitions().
/// \param FncsNamesToIgnore array of \p numFncNames NUL-terminated names that
///                          form the ignore set.
/// \param numFncNames       number of entries in \p FncsNamesToIgnore.
/// \param foundP            receives a malloc'ed array of malloc'ed,
///                          NUL-terminated copies of the replaced names, or
///                          nullptr if there are none (or the array could not
///                          be allocated).
/// \param foundLen          receives the number of entries stored in *foundP.
/// \returns the result of provideDefinitions() converted to uint8_t.
uint8_t EnzymeBitcodeReplacement(LLVMModuleRef M, char **FncsNamesToIgnore,
                                 size_t numFncNames, const char ***foundP,
                                 size_t *foundLen) {
  std::set<std::string> ignoreFunctions = {};
  for (size_t i = 0; i < numFncNames; i++) {
    ignoreFunctions.insert(std::string(FncsNamesToIgnore[i]));
  }

  std::vector<std::string> replaced;
  auto res = provideDefinitions(*unwrap(M), ignoreFunctions, replaced);

  const char **found = nullptr;
  size_t numFound = 0;
  if (!replaced.empty()) {
    // Element type of the array is `const char *`, so size it by *found.
    found =
        static_cast<const char **>(malloc(replaced.size() * sizeof(*found)));
    if (found) {
      for (const auto &name : replaced) {
        char *data = static_cast<char *>(malloc(name.size() + 1));
        // Out of memory: hand back only the entries copied so far.
        if (!data)
          break;
        // c_str() is NUL-terminated, so size() + 1 copies the terminator too.
        memcpy(data, name.c_str(), name.size() + 1);
        found[numFound++] = data;
      }
    }
  }

  *foundP = found;
  *foundLen = numFound;
  return res;
}
}
namespace {
/// Module pass that runs provideDefinitions() on each module it is given.
class BCLoader final : public ModulePass {
public:
  /// Pass identifier handed to the ModulePass constructor.
  static char ID;

  BCLoader() : ModulePass(ID) {}

  /// Invoke provideDefinitions() on \p M with an empty ignore set. The list of
  /// replaced names is collected but not used.
  /// \returns whatever provideDefinitions() reports.
  bool runOnModule(Module &M) override {
    std::vector<std::string> discardedNames;
    const bool modified = provideDefinitions(M, {}, discardedNames);
    return modified;
  }
};
} // namespace
// Storage for the pass identifier that BCLoader's constructor passes to
// ModulePass.
char BCLoader::ID = 0;
// Static registration object: registers BCLoader under the argument name
// "bcloader" with the description below.
static RegisterPass<BCLoader> X("bcloader",
"Link bitcode files for known functions");
/// Create a new heap-allocated BCLoader pass instance and return it as a
/// ModulePass pointer.
ModulePass *createBCLoaderPass() {
  ModulePass *loader = new BCLoader();
  return loader;
}