blob: 24ec50183f4f551aa94da247afc273ae8034e1be [file] [log] [blame] [edit]
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include <map>
#include <set>
#include <string>
using namespace llvm;
#define DATA EnzymeBlasBC
#include "blas_headers.h"
#undef DATA
bool provideDefinitions(Module &M) {
std::vector<StringRef> todo;
bool seen32 = false;
bool seen64 = false;
bool seenGemm = false;
for (auto &F : M) {
if (!F.empty())
continue;
int index = 0;
for (auto postfix : {"", "_", "_64_"}) {
std::string str;
if (strlen(postfix) == 0)
str = F.getName().str();
else if (F.getName().endswith(postfix)) {
str = "cblas_" +
F.getName().substr(0, F.getName().size() - strlen(postfix)).str();
}
auto found = EnzymeBlasBC.find(str);
if (found != EnzymeBlasBC.end()) {
todo.push_back(found->second);
if (index == 1)
seen32 = true;
if (index == 2)
seen64 = true;
if (StringRef(str).endswith("gemm"))
seenGemm = true;
break;
}
index++;
}
}
// Push fortran wrapper libs before all the other blas
// to ensure the fortran injections have their code
// replaced
if (seen32)
todo.insert(todo.begin(), __data_fblas32);
if (seen64)
todo.insert(todo.begin(), __data_fblas64);
if (seenGemm) {
todo.push_back(__data_xerbla);
}
bool changed = false;
for (size_t i = 0; i < todo.size(); i++) {
SMDiagnostic Err;
MemoryBufferRef buf(todo[i], StringRef("bcloader"));
#if LLVM_VERSION_MAJOR <= 10
auto BC = llvm::parseIR(buf, Err, M.getContext(), true,
M.getDataLayout().getStringRepresentation());
#else
auto BC = llvm::parseIR(buf, Err, M.getContext(), [&](StringRef) {
return Optional<std::string>(M.getDataLayout().getStringRepresentation());
});
#endif
if (!BC)
Err.print("bcloader", llvm::errs());
assert(BC);
SmallVector<std::string, 1> toReplace;
for (auto &F : *BC) {
if (F.empty()) {
if (todo[i] != __data_fblas32 && todo[i] != __data_fblas64) {
int index = 0;
for (auto postfix : {"", "_", "_64_"}) {
std::string str;
if (strlen(postfix) == 0)
str = F.getName().str();
else if (F.getName().endswith(postfix)) {
str = "cblas_" +
F.getName()
.substr(0, F.getName().size() - strlen(postfix))
.str();
}
auto found = EnzymeBlasBC.find(str);
if (found != EnzymeBlasBC.end()) {
if (index == 1)
todo.push_back(__data_fblas32);
if (index == 2)
todo.push_back(__data_fblas64);
todo.push_back(found->second);
break;
}
index++;
}
}
continue;
}
toReplace.push_back(F.getName().str());
}
Linker L(M);
L.linkInModule(std::move(BC));
for (auto name : toReplace) {
if (auto F = M.getFunction(name)) {
F->setLinkage(Function::LinkageTypes::InternalLinkage);
F->addFnAttr(Attribute::AlwaysInline);
}
}
changed = true;
}
return changed;
}
/// C-linkage entry point: unwraps the LLVM-C module handle and runs
/// provideDefinitions on it.
///
/// \param M LLVM-C handle of the module to process.
/// \returns the result of provideDefinitions converted to uint8_t
///          (1 if it reported a change, 0 otherwise).
extern "C" uint8_t EnzymeBitcodeReplacement(LLVMModuleRef M) {
  Module &module = *unwrap(M);
  const bool changed = provideDefinitions(module);
  return changed;
}
namespace {

/// Module pass that forwards to provideDefinitions.
class BCLoader : public ModulePass {
public:
  /// Pass identifier; its address is what the pass base class receives.
  static char ID;

  BCLoader() : ModulePass(ID) {}

  /// Runs provideDefinitions on \p M and reports whether it changed anything.
  bool runOnModule(Module &M) override {
    const bool changed = provideDefinitions(M);
    return changed;
  }
};

} // namespace
// Storage for the pass identifier declared in BCLoader.
char BCLoader::ID = 0;
// Registers BCLoader under the command-line name "bcloader" with the given
// description.
static RegisterPass<BCLoader> X("bcloader",
"Link bitcode files for known functions");
/// Factory for the BCLoader pass.
///
/// \returns a newly heap-allocated BCLoader; the caller receives the raw
///          pointer.
ModulePass *createBCLoaderPass() {
  ModulePass *pass = new BCLoader();
  return pass;
}