Commit 3bf72469 by Tianqi Chen Committed by GitHub

[BUILD/CODEGEN] Allow combine multiple functions in build stage. (#169)

* [BUILD/CODEGEN] Allow combine multiple functions in build stage.

* Enhance code module

* fix compile
parent 5912ed03
......@@ -37,6 +37,16 @@ class LoweredFunc : public FunctionRef {
using ContainerType = LoweredFuncNode;
};
/*! \brief specific type of lowered function */
enum LoweredFuncType : int {
/*! \brief Function that can mix device and host calls */
kMixedFunc = 0,
/*! \brief Only contains host code */
kHostFunc = 1,
/*! \brief Only contains device code */
kDeviceFunc = 2
};
/*! \brief Node container of LoweredFunc */
class LoweredFuncNode : public FunctionBaseNode {
public:
......@@ -72,6 +82,8 @@ class LoweredFuncNode : public FunctionBaseNode {
* constant Expr of given type is used.
*/
Map<Var, Expr> handle_data_type;
/*! \brief The type of the function */
LoweredFuncType func_type{kMixedFunc};
/*! \brief Whether this function is packed function */
bool is_packed_func{true};
/*! \brief The body statment of the function */
......@@ -90,6 +102,7 @@ class LoweredFuncNode : public FunctionBaseNode {
v->Visit("args", &args);
v->Visit("thread_axis", &thread_axis);
v->Visit("handle_data_type", &handle_data_type);
v->Visit("func_type", &func_type);
v->Visit("is_packed_func", &is_packed_func);
v->Visit("body", &body);
}
......
......@@ -29,13 +29,15 @@
#define TVM_DLL
#endif
#include <stdint.h>
#include <stddef.h>
// TVM Runtime is DLPack compatible.
#include <dlpack/dlpack.h>
#ifdef __cplusplus
TVM_EXTERN_C {
#endif
#include <stdint.h>
#include <stddef.h>
/*! \brief type of array index. */
typedef int64_t tvm_index_t;
......@@ -405,6 +407,7 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
#ifdef __cplusplus
} // TVM_EXTERN_C
#endif
#endif // TVM_RUNTIME_C_RUNTIME_API_H_
......@@ -4,7 +4,7 @@ from __future__ import absolute_import
import sys
import ctypes
from .base import _LIB, check_call, py_str, c_str, _FFI_MODE
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
......
......@@ -220,30 +220,57 @@ def build(sch,
if isinstance(sch, schedule.Schedule):
if args is None:
raise ValueError("args must be given for build from schedule")
fapi = lower(sch, args,
name=name,
binds=binds)
flist = lower(sch, args,
name=name,
binds=binds)
if isinstance(flist, collections.LoweredFunc):
flist = [flist]
elif isinstance(sch, collections.LoweredFunc):
if args:
raise ValueError("args must be done when build from LoweredFunc")
fapi = sch
flist = [sch]
elif isinstance(sch, (list, tuple, collections.Array)):
flist = sch
else:
raise ValueError("sch have to be Schedule or LoweredFunc")
# device related lowering
if BuildConfig.current.detect_global_barrier:
fapi = ir_pass.StorageSync(fapi, "global")
fapi = ir_pass.StorageSync(fapi, "shared")
warp_size = 32 if target == "cuda" else 1
fapi = ir_pass.LowerThreadAllreduce(fapi, warp_size)
fsplits = [s for s in ir_pass.SplitHostDevice(fapi)]
fsplits[0] = ir_pass.LowerPackedCall(fsplits[0])
if len(fsplits) > 1:
raise ValueError("sch have to be Schedule, LoweredFunc or list of LoweredFunc")
fname_set = set()
for x in flist:
if not isinstance(x, collections.LoweredFunc):
raise ValueError("sch have to be Schedule, LoweredFunc or list of LoweredFunc")
if x.name in fname_set:
raise ValueError("Duplicate function name %s" % x.name)
fhost = []
fdevice = []
for func in flist:
if func.func_type == collections.LoweredFunc.MixedFunc:
if BuildConfig.current.detect_global_barrier:
func = ir_pass.StorageSync(func, "global")
func = ir_pass.StorageSync(func, "shared")
warp_size = 32 if target == "cuda" else 1
func = ir_pass.LowerThreadAllreduce(func, warp_size)
fsplits = [s for s in ir_pass.SplitHostDevice(func)]
fhost.append(fsplits[0])
for x in fsplits[1:]:
fdevice.append(x)
elif func.func_type == collections.LoweredFunc.HostFunc:
fhost.append(func)
elif func.func_type == collections.LoweredFunc.DeviceFunc:
fdevice.append(func)
else:
raise ValueError("unknown function type %d" % func.func_type)
fhost = [ir_pass.LowerPackedCall(x) for x in fhost]
if not target.startswith("llvm") and target != "stackvm" and not fdevice:
raise ValueError(
"Specified target %s, but cannot find device code, did you do bind?" % target)
if fdevice:
if not target_host:
target_host = "llvm" if module.enabled("llvm") else "stackvm"
mhost = codegen.build_module(fsplits[0], target_host)
mhost = codegen.build_module(fhost, target_host)
if target:
mdev = codegen.build_module(fsplits[1:], target)
mdev = codegen.build_module(fdevice, target)
mhost.import_module(mdev)
return mhost
else:
return codegen.build_module(fsplits[0], target)
return codegen.build_module(fhost, target)
......@@ -68,4 +68,6 @@ class Range(NodeBase):
@register_node
class LoweredFunc(NodeBase):
"""Represent a LoweredFunc in TVM."""
pass
MixedFunc = 0
HostFunc = 1
DeviceFunc = 2
......@@ -96,7 +96,7 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
void CodeGenLLVM::InitGlobalContext() {
gv_mod_ctx_ = new llvm::GlobalVariable(
*module_, t_void_p_, false,
llvm::GlobalValue::LinkOnceODRLinkage, 0, "__tvm_module_ctx");
llvm::GlobalValue::LinkOnceAnyLinkage, 0, "__tvm_module_ctx");
gv_mod_ctx_->setAlignment(data_layout_->getTypeAllocSize(t_void_p_));
gv_mod_ctx_->setInitializer(llvm::Constant::getNullValue(t_void_p_));
}
......@@ -142,21 +142,12 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
llvm::Function* f = module_->getFunction(entry_func_name);
CHECK(f) << "Function " << entry_func_name << "does not in module";
CHECK(!module_->getFunction(runtime::symbol::tvm_module_main));
llvm::FunctionType* ftype = f->getFunctionType();
function_ = llvm::cast<llvm::Function>(
module_->getOrInsertFunction(runtime::symbol::tvm_module_main, ftype));
function_->setCallingConv(llvm::CallingConv::C);
std::vector<llvm::Value*> args;
for (auto it = function_->arg_begin();
it != function_->arg_end(); ++it) {
args.push_back(&(*it));
}
llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(block);
llvm::CallInst* call = builder_->CreateCall(f, args);
call->setTailCall(true);
builder_->CreateRet(call);
llvm::Type* type = llvm::ArrayType::get(t_char_, entry_func_name.length() + 1);
llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, true, llvm::GlobalValue::WeakAnyLinkage, 0,
runtime::symbol::tvm_module_main);
global->setAlignment(1);
global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, entry_func_name));
}
class FPassManager : public llvm::legacy::FunctionPassManager {
......@@ -424,7 +415,7 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
// create the function handle
hptr = new llvm::GlobalVariable(
*module_, t_tvm_func_handle_, false,
llvm::GlobalValue::PrivateLinkage, 0, ".tvm_func");
llvm::GlobalValue::LinkOnceAnyLinkage, 0, ".tvm_func." + fname);
hptr->setAlignment(align);
hptr->setInitializer(llvm::Constant::getNullValue(t_tvm_func_handle_));
func_handle_map_[fname] = hptr;
......
......@@ -36,8 +36,10 @@ class LLVMModuleNode final : public runtime::ModuleNode {
void PreCompile(const std::string& name, TVMContext ctx) final {
if (ee_ == nullptr) LazyInitJIT();
std::lock_guard<std::mutex> lock(mutex_);
const std::string& fname = (name == runtime::symbol::tvm_module_main ?
entry_func_ : name);
BackendPackedCFunc faddr =
reinterpret_cast<BackendPackedCFunc>(ee_->getFunctionAddress(name));
reinterpret_cast<BackendPackedCFunc>(ee_->getFunctionAddress(fname));
CHECK(faddr != nullptr)
<< "Failed to Precompile function " << name;
}
......@@ -47,8 +49,10 @@ class LLVMModuleNode final : public runtime::ModuleNode {
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
if (ee_ == nullptr) LazyInitJIT();
std::lock_guard<std::mutex> lock(mutex_);
const std::string& fname = (name == runtime::symbol::tvm_module_main ?
entry_func_ : name);
BackendPackedCFunc faddr =
reinterpret_cast<BackendPackedCFunc>(ee_->getFunctionAddress(name));
reinterpret_cast<BackendPackedCFunc>(ee_->getFunctionAddress(fname));
if (faddr == nullptr) return PackedFunc();
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
int ret = (*faddr)(
......@@ -103,6 +107,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
CHECK_NE(funcs.size(), 0U);
ctx_ = std::make_shared<llvm::LLVMContext>();
CodeGenLLVM cg;
entry_func_ = funcs[0]->name;
cg.Init(funcs[0]->name, tm_, ctx_.get());
for (LoweredFunc f : funcs) {
cg.AddFunction(f);
......@@ -147,6 +152,8 @@ class LLVMModuleNode final : public runtime::ModuleNode {
}
// The target configuration string
std::string target_;
// Name of entry function.
std::string entry_func_;
// JIT lock
std::mutex mutex_;
// execution engine
......
......@@ -54,5 +54,10 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
std::make_shared<SourceModuleNode>(code, fmt);
return runtime::Module(n);
}
TVM_REGISTER_GLOBAL("module.source_module_create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = SourceModuleCreate(args[0], args[1]);
});
} // namespace codegen
} // namespace tvm
......@@ -280,6 +280,7 @@ class ThreadAllreduceBuilder : public IRMutator {
LoweredFunc
LowerThreadAllreduce(LoweredFunc f, int warp_size) {
CHECK_NE(f->func_type, kHostFunc);
auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
n->body = ThreadAllreduceBuilder(warp_size).Mutate(n->body);
return LoweredFunc(n);
......
......@@ -155,6 +155,7 @@ class HostDeviceSplitter : public IRMutator {
}
Array<LoweredFunc> Split(LoweredFunc f) {
CHECK_EQ(f->func_type, kMixedFunc);
for (auto kv : f->handle_data_type) {
handle_data_type_[kv.first.get()] = kv.second;
}
......@@ -162,6 +163,7 @@ class HostDeviceSplitter : public IRMutator {
std::shared_ptr<LoweredFuncNode> n =
std::make_shared<LoweredFuncNode>(*f.operator->());
n->body = this->Mutate(f->body);
n->func_type = kHostFunc;
Array<LoweredFunc> ret{LoweredFunc(n)};
for (LoweredFunc x : device_funcs_) {
ret.push_back(x);
......@@ -179,6 +181,7 @@ class HostDeviceSplitter : public IRMutator {
m.visit_thread_extent_ = false;
n->body = m.Mutate(body);
n->name = os.str();
n->func_type = kDeviceFunc;
n->thread_axis = m.thread_axis_;
// Strictly order the arguments: Var pointers, positional arguments.
for (Var v : m.undefined_) {
......
......@@ -397,6 +397,7 @@ Stmt StorageSync(Stmt stmt, std::string storage_scope) {
}
LoweredFunc StorageSync(LoweredFunc f, std::string storage_scope) {
CHECK_NE(f->func_type, kHostFunc);
auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
n->body = StorageSync(f->body, storage_scope);
return LoweredFunc(n);
......
......@@ -62,11 +62,17 @@ class CUDAModuleNode : public runtime::ModuleNode {
void SaveToFile(const std::string& file_name,
const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
CHECK_EQ(fmt, fmt_)
<< "Can only save to format=" << fmt_;
std::string meta_file = GetMetaFilePath(file_name);
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, data_);
if (fmt == "cu") {
CHECK_NE(cuda_source_.length(), 0);
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, cuda_source_);
} else {
CHECK_EQ(fmt, fmt_)
<< "Can only save to format=" << fmt_;
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, data_);
}
}
void SaveToBinary(dmlc::Stream* stream) final {
......
......@@ -101,6 +101,17 @@ class DSOModuleNode final : public ModuleNode {
}
private:
BackendPackedCFunc GetFuncPtr(const std::string& name) {
if (name == runtime::symbol::tvm_module_main) {
const char* entry_name = reinterpret_cast<const char*>(
GetGlobalVPtr(runtime::symbol::tvm_module_main));
CHECK(entry_name!= nullptr)
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
return GetFuncPtr_(entry_name);
} else {
return GetFuncPtr_(name);
}
}
// Platform dependent handling.
#if defined(_WIN32)
// library handle
......@@ -111,7 +122,7 @@ class DSOModuleNode final : public ModuleNode {
std::wstring wname(name.begin(), name.end());
lib_handle_ = LoadLibraryW(wname.c_str());
}
BackendPackedCFunc GetFuncPtr(const std::string& name) {
BackendPackedCFunc GetFuncPtr_(const std::string& name) {
return reinterpret_cast<BackendPackedCFunc>(
GetProcAddress(lib_handle_, (LPCSTR)name.c_str())); // NOLINT(*)
}
......@@ -129,7 +140,7 @@ class DSOModuleNode final : public ModuleNode {
void Load(const std::string& name) {
lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
}
BackendPackedCFunc GetFuncPtr(const std::string& name) {
BackendPackedCFunc GetFuncPtr_(const std::string& name) {
return reinterpret_cast<BackendPackedCFunc>(
dlsym(lib_handle_, name.c_str()));
}
......
......@@ -103,7 +103,42 @@ def test_llvm_temp_space():
c.asnumpy(), a.asnumpy() + 1 + 1)
check_llvm()
def test_multiple_func():
nn = 1024
n = tvm.convert(nn)
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.create_schedule(C.op)
xo, xi = s[C].split(C.op.axis[0], factor=4)
s[C].parallel(xo)
s[C].vectorize(xi)
def check_llvm():
if not tvm.module.enabled("llvm"):
return
# build two functions
f2 = tvm.lower(s, [A, B, C], name="fadd1")
f1 = tvm.lower(s, [A, B, C], name="fadd2")
m = tvm.build([f1, f2], "llvm")
fadd1 = m['fadd1']
fadd2 = m['fadd2']
ctx = tvm.cpu(0)
# launch the kernel.
n = nn
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
fadd1(a, b, c)
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy())
fadd2(a, b, c)
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy())
check_llvm()
if __name__ == "__main__":
test_multiple_func()
test_llvm_add_pipeline()
test_llvm_flip_pipeline()
test_llvm_madd_pipeline()
......
......@@ -101,6 +101,44 @@ def test_device_module_dump():
check_device("opencl")
check_device("metal")
def test_combine_module_llvm():
"""Test combine multiple module into one shared lib."""
# graph
nn = 12
n = tvm.convert(nn)
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op)
def check_llvm():
ctx = tvm.cpu(0)
if not tvm.module.enabled("llvm"):
print("Skip because llvm is not enabled" )
return
temp = util.tempdir()
fadd1 = tvm.build(s, [A, B], "llvm", name="myadd1")
fadd2 = tvm.build(s, [A, B], "llvm", name="myadd2")
path1 = temp.relpath("myadd1.o")
path2 = temp.relpath("myadd2.o")
path_dso = temp.relpath("mylib.so")
fadd1.save(path1)
fadd2.save(path2)
# create shared library with multiple functions
cc.create_shared(path_dso, [path1, path2])
m = tvm.module.load(path_dso)
fadd1 = m['myadd1']
fadd2 = m['myadd2']
a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(nn, dtype=A.dtype), ctx)
fadd1(a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
fadd2(a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_llvm()
if __name__ == "__main__":
test_combine_module_llvm()
test_device_module_dump()
test_dso_module_load()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment