Commit 326edd76 by masahi Committed by Tianqi Chen

[ROCM] Working math function support for ROCm backend, a bug fix in LLVM based codegen (#570)

* added math function support

* bug fix extern func call in llvm based codegen

lint fix

fix build

bug fix extern func call in llvm based codegen

* moved rocm bitcodes detection to python
parent ab858e3f
"""Utility for ROCm backend"""
import subprocess
from os.path import join
from . import util
from ..api import register_func
from ..api import register_func, convert
def rocm_link(in_file, out_file):
"""Link relocatable ELF object to shared ELF object using lld
......@@ -49,3 +50,32 @@ def callback_rocm_link(obj_bin):
rocm_link(tmp_obj, tmp_cobj)
cobj_bin = bytearray(open(tmp_cobj, "rb").read())
return cobj_bin
@register_func("tvm_callback_rocm_bitcode_path")
def callback_rocm_bitcode_path(rocdl_dir="/opt/rocm/lib/"):
    """Build the paths of the ROCm device library bitcode files.

    Parameters
    ----------
    rocdl_dir : str
        Directory holding the ROCm device library bitcodes.
        Defaults to "/opt/rocm/lib/".

    Returns
    -------
    paths
        The result of ``convert`` applied to the list of joined paths,
        one entry per bitcode file, in the order listed below.
    """
    # Keep this order as is: link order seems to matter.
    bitcode_names = (
        "oclc_daz_opt_on.amdgcn.bc",
        "ocml.amdgcn.bc",
        "hc.amdgcn.bc",
        "irif.amdgcn.bc",
        "ockl.amdgcn.bc",
        "oclc_correctly_rounded_sqrt_off.amdgcn.bc",
        "oclc_correctly_rounded_sqrt_on.amdgcn.bc",
        "oclc_daz_opt_off.amdgcn.bc",
        "oclc_finite_only_off.amdgcn.bc",
        "oclc_finite_only_on.amdgcn.bc",
        "oclc_isa_version_803.amdgcn.bc",
        "oclc_isa_version_900.amdgcn.bc",
        "oclc_unsafe_math_off.amdgcn.bc",
        "oclc_unsafe_math_on.amdgcn.bc",
    )
    bitcode_paths = []
    for name in bitcode_names:
        bitcode_paths.append(join(rocdl_dir, name))
    return convert(bitcode_paths)
......@@ -161,6 +161,24 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
cg->AddFunction(f);
}
const auto *find_rocm_bitcodes =
tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path");
Array<Expr> bitcode_files = (*find_rocm_bitcodes)();
for (auto &bitcode : bitcode_files) {
std::string path = bitcode.as<StringImm>()->value;
llvm::SMDiagnostic err;
std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, *ctx);
if (mlib.get() == nullptr) {
std::string msg = err.getMessage();
LOG(FATAL) << "Fail to load bitcode file " << path << "\n"
<< "line " << err.getLineNo() << ":" << msg;
}
mlib->setTargetTriple(tm->getTargetTriple().str());
mlib->setDataLayout(tm->createDataLayout());
cg->AddLinkModule(std::move(mlib));
}
std::unique_ptr<llvm::Module> module = cg->Finish();
llvm::SmallString<8> dataObj, data_ll, dataAsm;
llvm::raw_svector_ostream destObj(dataObj), dest_ll(data_ll), destAsm(dataAsm);
......
......@@ -516,11 +516,10 @@ llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
}
llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
CHECK_GE(op->args.size(), 1U);
std::vector<llvm::Value*> arg_value;
std::vector<llvm::Type*> arg_type;
for (size_t i = 1; i < op->args.size(); ++i) {
arg_value.push_back(MakeValue(op->args[i + 1]));
for (size_t i = 0; i < op->args.size(); ++i) {
arg_value.push_back(MakeValue(op->args[i]));
arg_type.push_back(arg_value.back()->getType());
}
llvm::FunctionType* ftype = llvm::FunctionType::get(
......
/*!
* Copyright (c) 2017 by Contributors
* \file intrin_rule_llvm.cc
*/
#ifdef TVM_LLVM_VERSION
#include "./intrin_rule_llvm.h"
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/api_registry.h>
#include <sstream>
namespace tvm {
namespace codegen {
inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) {
Expr e = args[0];
using namespace ir;
const Call* call = e.as<Call>();
CHECK(call != nullptr);
std::ostringstream intrinsic_name;
intrinsic_name << "__ocml_" << call->name << "_f" << call->type.bits();
*rv = Call::make(call->type, intrinsic_name.str(), call->args,
Call::PureExtern);
}
// Global registrations of the ROCm intrinsic lowering rules.
// Each entry registers a global function named "tvm.intrin.rule.rocm.<op>"
// whose body is DispatchExternOCML, for the ops exp, fma, log, sqrt and pow.
namespace llvm {
// exp
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp")
.set_body(DispatchExternOCML);
// fma
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma")
.set_body(DispatchExternOCML);
// log
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log")
.set_body(DispatchExternOCML);
// sqrt
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt")
.set_body(DispatchExternOCML);
// pow
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow")
.set_body(DispatchExternOCML);
} // namespace llvm
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment