Commit 6894d42b by Tianqi Chen Committed by GitHub

[CODEGEN] Allow link additional module (#559)

* [CODEGEN] Allow link additional module

* fix py3

* add register back
parent 163c4795
......@@ -30,4 +30,6 @@ from .ndarray import register_extension
from .schedule import create_schedule
from .build_module import build, lower, build_config
from .tag import tag_scope
from .contrib import rocm as _rocm
# Contrib initializers
from .contrib import rocm as _rocm, nvcc as _nvcc
......@@ -3,8 +3,12 @@
from __future__ import absolute_import as _abs
import subprocess
import os
import warnings
from . import util
from .. import ndarray as nd
from ..api import register_func
from .._ffi.base import py_str
def compile_cuda(code,
target="ptx",
......@@ -72,3 +76,60 @@ def compile_cuda(code,
raise RuntimeError(msg)
return bytearray(open(file_target, "rb").read())
def find_cuda_path():
    """Utility function to find cuda path

    The lookup order is: the ``CUDA_PATH`` environment variable, the
    location of ``nvcc`` as reported by ``which``, and finally the
    conventional ``/usr/local/cuda`` install directory.

    Returns
    -------
    path : str
        Path to cuda root.

    Raises
    ------
    RuntimeError
        If none of the lookups above yields a cuda root.
    """
    if "CUDA_PATH" in os.environ:
        return os.environ["CUDA_PATH"]
    cmd = ["which", "nvcc"]
    try:
        proc = subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        (out, _) = proc.communicate()
    except OSError:
        # The `which` executable itself could not be launched; treat this the
        # same as "nvcc not found" so the remaining fallbacks still run and
        # callers only ever see RuntimeError.
        proc = None
    if proc is not None and proc.returncode == 0:
        out = py_str(out)
        # `which` reports <root>/bin/nvcc; go up two levels to get <root>.
        return os.path.abspath(os.path.join(str(out).strip(), "../.."))
    cuda_path = "/usr/local/cuda"
    if os.path.exists(os.path.join(cuda_path, "bin/nvcc")):
        return cuda_path
    raise RuntimeError("Cannot find cuda path")
@register_func("tvm_callback_libdevice_path")
def find_libdevice_path(arch):
    """Utility function to find libdevice

    Parameters
    ----------
    arch : int
        The compute architecture in int

    Returns
    -------
    path : str
        Path to the selected libdevice bitcode file.

    Raises
    ------
    RuntimeError
        If the cuda root cannot be found, or no libdevice file usable
        for ``arch`` exists under ``<cuda root>/nvvm/libdevice``.
    """
    cuda_path = find_cuda_path()
    lib_path = os.path.join(cuda_path, "nvvm/libdevice")
    not_found_msg = "Cannot find libdevice for arch {}".format(arch)
    if not os.path.isdir(lib_path):
        # Report a missing directory as RuntimeError rather than letting
        # os.listdir raise OSError.
        raise RuntimeError(not_found_msg)
    # Newer CUDA toolkits ship a single architecture-independent file
    # (see the NVIDIA libdevice guide); prefer it when present.
    unified_path = os.path.join(lib_path, "libdevice.10.bc")
    if os.path.exists(unified_path):
        return unified_path
    selected_ver = 0
    selected_path = None
    for fn in os.listdir(lib_path):
        if not fn.startswith("libdevice"):
            continue
        # Expected name shape: libdevice.compute_<ver>.<rev>.bc
        try:
            ver = int(fn.split(".")[-3].split("_")[-1])
        except (IndexError, ValueError):
            # Not an architecture-versioned file name; ignore it.
            continue
        # Keep the highest version that does not exceed the requested arch.
        if selected_ver < ver <= arch:
            selected_ver = ver
            selected_path = fn
    if selected_path is None:
        raise RuntimeError(not_found_msg)
    return os.path.join(lib_path, selected_path)
def callback_libdevice_path(arch):
    """Best-effort lookup of the libdevice path.

    Parameters
    ----------
    arch : int
        The compute architecture in int

    Returns
    -------
    path : str
        Result of find_libdevice_path, or an empty string (after a
        warning) when that lookup raises RuntimeError.
    """
    try:
        libdevice = find_libdevice_path(arch)
    except RuntimeError:
        warnings.warn("Cannot find libdevice path")
        libdevice = ""
    return libdevice
......@@ -26,6 +26,7 @@ def rocm_link(in_file, out_file):
msg += str(out)
raise RuntimeError(msg)
@register_func("tvm_callback_rocm_link")
def callback_rocm_link(obj_bin):
"""Links object file generated from LLVM to HSA Code Object
......
......@@ -147,10 +147,21 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
this->AddStartupFunction();
// link modules
for (size_t i = 0; i < link_modules_.size(); ++i) {
CHECK(!llvm::Linker::linkModules(*module_, std::move(link_modules_[i])))
<< "Failed to link modules";
}
link_modules_.clear();
// optimize
this->Optimize();
return std::move(module_);
}
// Queue mod so that Finish() links it into the generated module.
void CodeGenLLVM::AddLinkModule(std::unique_ptr<llvm::Module>&& mod) {
  link_modules_.push_back(std::move(mod));
}
// This implementation does not use entry_func_name; it unconditionally
// reports a fatal "not implemented" error.
void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
LOG(FATAL) << "not implemented";
}
......
......@@ -67,6 +67,11 @@ class CodeGenLLVM :
*/
virtual std::unique_ptr<llvm::Module> Finish();
/*!
* \brief Add mod to be linked with the generated module
* \param mod The module to be linked.
*/
void AddLinkModule(std::unique_ptr<llvm::Module>&& mod);
/*!
* \brief Create Value for expression e
* \param e The expression to be created value for.
* \return created value.
......@@ -227,7 +232,8 @@ class CodeGenLLVM :
llvm::MDNode* md_very_likely_branch_{nullptr};
llvm::MDNode* md_tbaa_root_{nullptr};
llvm::MDNode* md_tbaa_alias_set_{nullptr};
// modules to be linked.
std::vector<std::unique_ptr<llvm::Module> > link_modules_;
/*! \brief native vector bits of current target */
int native_vector_bits_{0};
/*! \brief the storage scope of allocation */
......
......@@ -153,9 +153,10 @@ inline int DetectCUDAComputeVersion() {
runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
CHECK(target.length() >= 5 &&
target.substr(0, 5) == "nvptx");
int compute_ver = DetectCUDAComputeVersion();
std::ostringstream config;
config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_"
<< DetectCUDAComputeVersion()
<< compute_ver
<< target.substr(5, target.length() - 5);
llvm::TargetMachine* tm = GetLLVMTargetMachine(config.str());
std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());
......@@ -164,6 +165,25 @@ runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
for (LoweredFunc f : funcs) {
cg->AddFunction(f);
}
const auto* flibdevice_path =
tvm::runtime::Registry::Get("tvm_callback_libdevice_path");
if (flibdevice_path != nullptr) {
std::string path = (*flibdevice_path)(compute_ver);
if (path.length() != 0) {
llvm::SMDiagnostic err;
std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, *ctx);
if (mlib.get() == nullptr) {
std::string msg = err.getMessage();
LOG(FATAL) << "Fail to load bitcode file " << path << "\n"
<< "line " << err.getLineNo() << ":" << msg;
}
mlib->setTargetTriple(tm->getTargetTriple().str());
mlib->setDataLayout(tm->createDataLayout());
// TODO(tqchen) libdevice linking not yet working.
// cg->AddLinkModule(std::move(mlib));
}
}
std::unique_ptr<llvm::Module> module = cg->Finish();
llvm::SmallString<8> data_ptx, data_ll;
llvm::raw_svector_ostream dest_ptx(data_ptx), dest_ll(data_ll);
......
......@@ -4,52 +4,12 @@
*/
#ifdef TVM_LLVM_VERSION
#include <tvm/ir.h>
#include <tvm/api_registry.h>
#include <tvm/codegen.h>
#include <string>
#include "./llvm_common.h"
#include "./intrin_rule_llvm.h"
namespace tvm {
namespace codegen {
namespace llvm {
using namespace ir;
// num_signature means number of arguments used to query signature
template<unsigned id, int num_signature>
inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
const Call* call = e.as<Call>();
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
cargs.push_back(UIntImm::make(UInt(32), id));
cargs.push_back(UIntImm::make(UInt(32), num_signature));
for (Expr arg : call->args) {
cargs.push_back(arg);
}
*rv = Call::make(
call->type, "llvm_intrin", cargs, Call::PureIntrinsic);
}
template<unsigned id, int num_signature>
inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
const Call* call = e.as<Call>();
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
cargs.push_back(UIntImm::make(UInt(32), id));
cargs.push_back(UIntImm::make(UInt(32), num_signature));
for (Expr arg : call->args) {
cargs.push_back(arg);
}
*rv = Call::make(
call->type, "llvm_intrin", cargs, Call::Intrinsic);
}
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch")
.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 0>);
......
/*!
* Copyright (c) 2017 by Contributors
* \file intrin_rule_llvm.h
* \brief Common utilities for llvm intrinsics.
*/
#ifndef TVM_CODEGEN_LLVM_INTRIN_RULE_LLVM_H_
#define TVM_CODEGEN_LLVM_INTRIN_RULE_LLVM_H_
#ifdef TVM_LLVM_VERSION
#include <tvm/ir.h>
#include <tvm/api_registry.h>
#include <tvm/codegen.h>
#include <string>
#include "./llvm_common.h"
namespace tvm {
namespace codegen {
// num_signature means number of arguments used to query signature
template<unsigned id, int num_signature>
inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
const ir::Call* call = e.as<ir::Call>();
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
cargs.push_back(ir::UIntImm::make(UInt(32), id));
cargs.push_back(ir::UIntImm::make(UInt(32), num_signature));
for (Expr arg : call->args) {
cargs.push_back(arg);
}
*rv = ir::Call::make(
call->type, "llvm_intrin", cargs, ir::Call::PureIntrinsic);
}
template<unsigned id, int num_signature>
inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
const ir::Call* call = e.as<ir::Call>();
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
cargs.push_back(ir::UIntImm::make(UInt(32), id));
cargs.push_back(ir::UIntImm::make(UInt(32), num_signature));
for (Expr arg : call->args) {
cargs.push_back(arg);
}
*rv = ir::Call::make(
call->type, "llvm_intrin", cargs, ir::Call::Intrinsic);
}
} // namespace codegen
} // namespace tvm
#endif  // TVM_LLVM_VERSION
#endif // TVM_CODEGEN_LLVM_INTRIN_RULE_LLVM_H_
......@@ -43,6 +43,8 @@
#include <llvm/IRReader/IRReader.h>
#include <llvm/CodeGen/TargetLoweringObjectFileImpl.h>
#include <llvm/Linker/Linker.h>
#include <utility>
#include <string>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment