Commit b8c8aadf by Tianqi Chen Committed by GitHub

[BACKEND] Allow nvptx to pass ll ir to CUDAModule (#404)

parent 50c7a01b
......@@ -131,8 +131,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
};
runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
CHECK(target.length(
) >= 5 &&
CHECK(target.length() >= 5 &&
target.substr(0, 5) == "nvptx");
llvm::TargetMachine* tm = GetLLVMTargetMachine(
"-mtriple=nvptx64-nvidia-cuda -mcpu=sm_20" +
......@@ -144,16 +143,19 @@ runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
cg->AddFunction(f);
}
std::unique_ptr<llvm::Module> module = cg->Finish();
llvm::SmallString<8> data;
llvm::raw_svector_ostream dest(data);
dest.SetUnbuffered();
llvm::SmallString<8> data_ptx, data_ll;
llvm::raw_svector_ostream dest_ptx(data_ptx), dest_ll(data_ll);
dest_ptx.SetUnbuffered();
dest_ll.SetUnbuffered();
llvm::legacy::PassManager pass;
CHECK(tm->addPassesToEmitFile(
pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
pass, dest_ptx, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
<< "Cannot emit target CGFT_ObjectFile";
pass.run(*module);
std::string ptx(data.begin(), data.end());
return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(funcs), "");
module->print(dest_ll, nullptr);
std::string ptx(data_ptx.begin(), data_ptx.end());
std::string ll(data_ll.begin(), data_ll.end());
return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(funcs), ll);
}
TVM_REGISTER_API("codegen.build_nvptx")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment