Commit 3bf72469 by Tianqi Chen, committed by GitHub

[BUILD/CODEGEN] Allow combine multiple functions in build stage. (#169)

* [BUILD/CODEGEN] Allow combine multiple functions in build stage.

* Enhance code module

* fix compile
parent 5912ed03
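
What this enables, user-side: tvm.build can now take a single Schedule, a single LoweredFunc, or a list of LoweredFuncs, and compiles them together into one module. A minimal sketch of the new usage, mirroring the LLVM test added at the bottom of this diff (function names are illustrative):

import numpy as np
import tvm

n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op)
# Lower the same schedule twice under different names, then build both
# functions into a single module.
f1 = tvm.lower(s, [A, B], name="fadd1")
f2 = tvm.lower(s, [A, B], name="fadd2")
m = tvm.build([f1, f2], "llvm")
fadd1, fadd2 = m['fadd1'], m['fadd2']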
@@ -37,6 +37,16 @@ class LoweredFunc : public FunctionRef {
   using ContainerType = LoweredFuncNode;
 };
 
+/*! \brief specific type of lowered function */
+enum LoweredFuncType : int {
+  /*! \brief Function that can mix device and host calls */
+  kMixedFunc = 0,
+  /*! \brief Only contains host code */
+  kHostFunc = 1,
+  /*! \brief Only contains device code */
+  kDeviceFunc = 2
+};
+
 /*! \brief Node container of LoweredFunc */
 class LoweredFuncNode : public FunctionBaseNode {
  public:
@@ -72,6 +82,8 @@ class LoweredFuncNode : public FunctionBaseNode {
    * constant Expr of given type is used.
    */
   Map<Var, Expr> handle_data_type;
+  /*! \brief The type of the function */
+  LoweredFuncType func_type{kMixedFunc};
   /*! \brief Whether this function is packed function */
   bool is_packed_func{true};
   /*! \brief The body statment of the function */
@@ -90,6 +102,7 @@ class LoweredFuncNode : public FunctionBaseNode {
     v->Visit("args", &args);
     v->Visit("thread_axis", &thread_axis);
     v->Visit("handle_data_type", &handle_data_type);
+    v->Visit("func_type", &func_type);
     v->Visit("is_packed_func", &is_packed_func);
     v->Visit("body", &body);
   }
......
@@ -29,13 +29,15 @@
 #define TVM_DLL
 #endif
 
+#include <stdint.h>
+#include <stddef.h>
+
 // TVM Runtime is DLPack compatible.
 #include <dlpack/dlpack.h>
 
+#ifdef __cplusplus
 TVM_EXTERN_C {
+#endif
-#include <stdint.h>
-#include <stddef.h>
 /*! \brief type of array index. */
 typedef int64_t tvm_index_t;
@@ -405,6 +407,7 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
  * \return 0 when success, -1 when failure happens
  */
 TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
+#ifdef __cplusplus
 }  // TVM_EXTERN_C
+#endif
 #endif  // TVM_RUNTIME_C_RUNTIME_API_H_
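
The include and extern "C" shuffle above keeps this header consumable from plain C: <stdint.h> and <stddef.h> now sit outside the linkage block, and the block itself is guarded by #ifdef __cplusplus, so the C API symbols keep unmangled names. A hedged ctypes sketch of calling one of them directly (the library path and exact TVMContext field names are assumptions; the struct is two ints in either case):

import ctypes

class TVMContext(ctypes.Structure):
    # Assumed layout: a device type and a device id, both int.
    _fields_ = [("device_type", ctypes.c_int),
                ("device_id", ctypes.c_int)]

lib = ctypes.CDLL("libtvm.so")  # assumed library name/path
lib.TVMSynchronize.restype = ctypes.c_int
lib.TVMSynchronize.argtypes = [TVMContext, ctypes.c_void_p]
# 1 is the CPU device type in DLPack; a null stream means the default stream.
assert lib.TVMSynchronize(TVMContext(1, 0), None) == 0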
@@ -4,7 +4,7 @@ from __future__ import absolute_import
 import sys
 import ctypes
-from .base import _LIB, check_call, py_str, c_str, _FFI_MODE
+from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
 IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
......
@@ -220,30 +220,57 @@ def build(sch,
     if isinstance(sch, schedule.Schedule):
         if args is None:
             raise ValueError("args must be given for build from schedule")
-        fapi = lower(sch, args,
-                     name=name,
-                     binds=binds)
+        flist = lower(sch, args,
+                      name=name,
+                      binds=binds)
+        if isinstance(flist, collections.LoweredFunc):
+            flist = [flist]
     elif isinstance(sch, collections.LoweredFunc):
         if args:
             raise ValueError("args must be done when build from LoweredFunc")
-        fapi = sch
+        flist = [sch]
+    elif isinstance(sch, (list, tuple, collections.Array)):
+        flist = sch
     else:
-        raise ValueError("sch have to be Schedule or LoweredFunc")
-    # device related lowering
-    if BuildConfig.current.detect_global_barrier:
-        fapi = ir_pass.StorageSync(fapi, "global")
-    fapi = ir_pass.StorageSync(fapi, "shared")
-    warp_size = 32 if target == "cuda" else 1
-    fapi = ir_pass.LowerThreadAllreduce(fapi, warp_size)
-    fsplits = [s for s in ir_pass.SplitHostDevice(fapi)]
-    fsplits[0] = ir_pass.LowerPackedCall(fsplits[0])
-    if len(fsplits) > 1:
+        raise ValueError("sch have to be Schedule, LoweredFunc or list of LoweredFunc")
+    fname_set = set()
+    for x in flist:
+        if not isinstance(x, collections.LoweredFunc):
+            raise ValueError("sch have to be Schedule, LoweredFunc or list of LoweredFunc")
+        if x.name in fname_set:
+            raise ValueError("Duplicate function name %s" % x.name)
+    fhost = []
+    fdevice = []
+    for func in flist:
+        if func.func_type == collections.LoweredFunc.MixedFunc:
+            if BuildConfig.current.detect_global_barrier:
+                func = ir_pass.StorageSync(func, "global")
+            func = ir_pass.StorageSync(func, "shared")
+            warp_size = 32 if target == "cuda" else 1
+            func = ir_pass.LowerThreadAllreduce(func, warp_size)
+            fsplits = [s for s in ir_pass.SplitHostDevice(func)]
+            fhost.append(fsplits[0])
+            for x in fsplits[1:]:
+                fdevice.append(x)
+        elif func.func_type == collections.LoweredFunc.HostFunc:
+            fhost.append(func)
+        elif func.func_type == collections.LoweredFunc.DeviceFunc:
+            fdevice.append(func)
+        else:
+            raise ValueError("unknown function type %d" % func.func_type)
+    fhost = [ir_pass.LowerPackedCall(x) for x in fhost]
+    if not target.startswith("llvm") and target != "stackvm" and not fdevice:
+        raise ValueError(
+            "Specified target %s, but cannot find device code, did you do bind?" % target)
+    if fdevice:
         if not target_host:
             target_host = "llvm" if module.enabled("llvm") else "stackvm"
-        mhost = codegen.build_module(fsplits[0], target_host)
+        mhost = codegen.build_module(fhost, target_host)
         if target:
-            mdev = codegen.build_module(fsplits[1:], target)
+            mdev = codegen.build_module(fdevice, target)
             mhost.import_module(mdev)
         return mhost
     else:
-        return codegen.build_module(fsplits[0], target)
+        return codegen.build_module(fhost, target)
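
Note the new early failure in build(): if a device target is requested but the schedule produced no device functions (typically because nothing was bound to thread axes), the error surfaces immediately instead of at codegen. A hedged repro sketch:

import tvm

n = tvm.convert(16)
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op)
try:
    tvm.build(s, [A, B], "cuda")  # no s[B].bind(...), so fdevice stays empty
except ValueError as err:
    print(err)  # Specified target cuda, but cannot find device code, did you do bind?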
@@ -68,4 +68,6 @@ class Range(NodeBase):
 @register_node
 class LoweredFunc(NodeBase):
     """Represent a LoweredFunc in TVM."""
-    pass
+    MixedFunc = 0
+    HostFunc = 1
+    DeviceFunc = 2
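
These Python constants mirror the LoweredFuncType enum added to lowered_func.h, so build() can partition functions without a round trip to C++. A hedged sketch of inspecting the types around SplitHostDevice (assumes the pass is exposed as in this era's tvm.ir_pass):

import tvm
from tvm import collections, ir_pass

n = tvm.convert(16)
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op)
fmixed = tvm.lower(s, [A, B], name="myfunc")
assert fmixed.func_type == collections.LoweredFunc.MixedFunc  # the default
fsplits = ir_pass.SplitHostDevice(fmixed)
assert fsplits[0].func_type == collections.LoweredFunc.HostFunc
assert all(f.func_type == collections.LoweredFunc.DeviceFunc
           for f in fsplits[1:])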
@@ -96,7 +96,7 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
 void CodeGenLLVM::InitGlobalContext() {
   gv_mod_ctx_ = new llvm::GlobalVariable(
       *module_, t_void_p_, false,
-      llvm::GlobalValue::LinkOnceODRLinkage, 0, "__tvm_module_ctx");
+      llvm::GlobalValue::LinkOnceAnyLinkage, 0, "__tvm_module_ctx");
   gv_mod_ctx_->setAlignment(data_layout_->getTypeAllocSize(t_void_p_));
   gv_mod_ctx_->setInitializer(llvm::Constant::getNullValue(t_void_p_));
 }
@@ -142,21 +142,12 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
 void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
   llvm::Function* f = module_->getFunction(entry_func_name);
   CHECK(f) << "Function " << entry_func_name << "does not in module";
-  CHECK(!module_->getFunction(runtime::symbol::tvm_module_main));
-  llvm::FunctionType* ftype = f->getFunctionType();
-  function_ = llvm::cast<llvm::Function>(
-      module_->getOrInsertFunction(runtime::symbol::tvm_module_main, ftype));
-  function_->setCallingConv(llvm::CallingConv::C);
-  std::vector<llvm::Value*> args;
-  for (auto it = function_->arg_begin();
-       it != function_->arg_end(); ++it) {
-    args.push_back(&(*it));
-  }
-  llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_);
-  builder_->SetInsertPoint(block);
-  llvm::CallInst* call = builder_->CreateCall(f, args);
-  call->setTailCall(true);
-  builder_->CreateRet(call);
+  llvm::Type* type = llvm::ArrayType::get(t_char_, entry_func_name.length() + 1);
+  llvm::GlobalVariable* global = new llvm::GlobalVariable(
+      *module_, type, true, llvm::GlobalValue::WeakAnyLinkage, 0,
+      runtime::symbol::tvm_module_main);
+  global->setAlignment(1);
+  global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, entry_func_name));
 }
 
 class FPassManager : public llvm::legacy::FunctionPassManager {
@@ -424,7 +415,7 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
     // create the function handle
     hptr = new llvm::GlobalVariable(
         *module_, t_tvm_func_handle_, false,
-        llvm::GlobalValue::PrivateLinkage, 0, ".tvm_func");
+        llvm::GlobalValue::LinkOnceAnyLinkage, 0, ".tvm_func." + fname);
     hptr->setAlignment(align);
     hptr->setInitializer(llvm::Constant::getNullValue(t_tvm_func_handle_));
     func_handle_map_[fname] = hptr;
......
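AddMainFunction no longer synthesizes a wrapper function for the module entry; it now records the entry function's name in a weak, null-terminated global string named by runtime::symbol::tvm_module_main, which the module loaders below resolve at load time. A hedged ctypes sketch of what lands in a built shared library (assumes the symbol constant expands to "__tvm_main__", and a library produced from TVM output via save plus cc.create_shared):

import ctypes

lib = ctypes.CDLL("mylib.so")  # hypothetical library built from TVM object files
# The global is a char array, not a char*, so read it in place.
addr = ctypes.addressof(ctypes.c_char.in_dll(lib, "__tvm_main__"))
print(ctypes.string_at(addr))  # e.g. b'myadd', the real entry function name
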
@@ -36,8 +36,10 @@ class LLVMModuleNode final : public runtime::ModuleNode {
   void PreCompile(const std::string& name, TVMContext ctx) final {
     if (ee_ == nullptr) LazyInitJIT();
     std::lock_guard<std::mutex> lock(mutex_);
+    const std::string& fname = (name == runtime::symbol::tvm_module_main ?
+                                entry_func_ : name);
     BackendPackedCFunc faddr =
-        reinterpret_cast<BackendPackedCFunc>(ee_->getFunctionAddress(name));
+        reinterpret_cast<BackendPackedCFunc>(ee_->getFunctionAddress(fname));
     CHECK(faddr != nullptr)
         << "Failed to Precompile function " << name;
   }
@@ -47,8 +49,10 @@ class LLVMModuleNode final : public runtime::ModuleNode {
       const std::shared_ptr<ModuleNode>& sptr_to_self) final {
     if (ee_ == nullptr) LazyInitJIT();
     std::lock_guard<std::mutex> lock(mutex_);
+    const std::string& fname = (name == runtime::symbol::tvm_module_main ?
+                                entry_func_ : name);
     BackendPackedCFunc faddr =
-        reinterpret_cast<BackendPackedCFunc>(ee_->getFunctionAddress(name));
+        reinterpret_cast<BackendPackedCFunc>(ee_->getFunctionAddress(fname));
     if (faddr == nullptr) return PackedFunc();
     return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
       int ret = (*faddr)(
@@ -103,6 +107,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
     CHECK_NE(funcs.size(), 0U);
     ctx_ = std::make_shared<llvm::LLVMContext>();
     CodeGenLLVM cg;
+    entry_func_ = funcs[0]->name;
     cg.Init(funcs[0]->name, tm_, ctx_.get());
     for (LoweredFunc f : funcs) {
       cg.AddFunction(f);
@@ -147,6 +152,8 @@ class LLVMModuleNode final : public runtime::ModuleNode {
   }
   // The target configuration string
   std::string target_;
+  // Name of entry function.
+  std::string entry_func_;
   // JIT lock
   std::mutex mutex_;
   // execution engine
......
@@ -54,5 +54,10 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
       std::make_shared<SourceModuleNode>(code, fmt);
   return runtime::Module(n);
 }
+
+TVM_REGISTER_GLOBAL("module.source_module_create")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    *rv = SourceModuleCreate(args[0], args[1]);
+  });
 }  // namespace codegen
 }  // namespace tvm
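
Registering module.source_module_create exposes source-only modules to the Python side through the packed-function FFI. A hedged sketch (assumes get_global_func is reachable from the top-level tvm namespace, as in this era's tvm._ffi.function):

import tvm

create = tvm.get_global_func("module.source_module_create")  # name from the registration above
# Wrap a string of already-generated source; "cc" is an illustrative format tag.
mod = create("// generated source would go here", "cc")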
@@ -280,6 +280,7 @@ class ThreadAllreduceBuilder : public IRMutator {
 
 LoweredFunc
 LowerThreadAllreduce(LoweredFunc f, int warp_size) {
+  CHECK_NE(f->func_type, kHostFunc);
   auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
   n->body = ThreadAllreduceBuilder(warp_size).Mutate(n->body);
   return LoweredFunc(n);
......
@@ -155,6 +155,7 @@ class HostDeviceSplitter : public IRMutator {
   }
 
   Array<LoweredFunc> Split(LoweredFunc f) {
+    CHECK_EQ(f->func_type, kMixedFunc);
     for (auto kv : f->handle_data_type) {
       handle_data_type_[kv.first.get()] = kv.second;
     }
@@ -162,6 +163,7 @@ class HostDeviceSplitter : public IRMutator {
     std::shared_ptr<LoweredFuncNode> n =
         std::make_shared<LoweredFuncNode>(*f.operator->());
     n->body = this->Mutate(f->body);
+    n->func_type = kHostFunc;
     Array<LoweredFunc> ret{LoweredFunc(n)};
     for (LoweredFunc x : device_funcs_) {
       ret.push_back(x);
@@ -179,6 +181,7 @@ class HostDeviceSplitter : public IRMutator {
     m.visit_thread_extent_ = false;
     n->body = m.Mutate(body);
     n->name = os.str();
+    n->func_type = kDeviceFunc;
     n->thread_axis = m.thread_axis_;
     // Strictly order the arguments: Var pointers, positional arguments.
     for (Var v : m.undefined_) {
......
@@ -397,6 +397,7 @@ Stmt StorageSync(Stmt stmt, std::string storage_scope) {
 }
 
 LoweredFunc StorageSync(LoweredFunc f, std::string storage_scope) {
+  CHECK_NE(f->func_type, kHostFunc);
   auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
   n->body = StorageSync(f->body, storage_scope);
   return LoweredFunc(n);
......
@@ -62,12 +62,18 @@ class CUDAModuleNode : public runtime::ModuleNode {
   void SaveToFile(const std::string& file_name,
                   const std::string& format) final {
     std::string fmt = GetFileFormat(file_name, format);
-    CHECK_EQ(fmt, fmt_)
-        << "Can only save to format=" << fmt_;
     std::string meta_file = GetMetaFilePath(file_name);
-    SaveMetaDataToFile(meta_file, fmap_);
-    SaveBinaryToFile(file_name, data_);
+    if (fmt == "cu") {
+      CHECK_NE(cuda_source_.length(), 0);
+      SaveMetaDataToFile(meta_file, fmap_);
+      SaveBinaryToFile(file_name, cuda_source_);
+    } else {
+      CHECK_EQ(fmt, fmt_)
+          << "Can only save to format=" << fmt_;
+      SaveMetaDataToFile(meta_file, fmap_);
+      SaveBinaryToFile(file_name, data_);
+    }
   }
 
   void SaveToBinary(dmlc::Stream* stream) final {
     stream->Write(fmt_);
......
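With the SaveToFile branch above, a CUDA module that still holds its source (cuda_source_) can be dumped as a .cu file as well as in its native binary format. A hedged sketch, assuming fadd came from tvm.build(s, [A, B], "cuda") and temp is a util.tempdir():

mdev = fadd.imported_modules[0]       # the device module imported into the host module
mdev.save(temp.relpath("myadd.ptx"))  # native format (fmt_)
mdev.save(temp.relpath("myadd.cu"))   # new: raw CUDA source, when it was recorded
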
@@ -101,6 +101,17 @@ class DSOModuleNode final : public ModuleNode {
   }
 
  private:
+  BackendPackedCFunc GetFuncPtr(const std::string& name) {
+    if (name == runtime::symbol::tvm_module_main) {
+      const char* entry_name = reinterpret_cast<const char*>(
+          GetGlobalVPtr(runtime::symbol::tvm_module_main));
+      CHECK(entry_name != nullptr)
+          << "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
+      return GetFuncPtr_(entry_name);
+    } else {
+      return GetFuncPtr_(name);
+    }
+  }
   // Platform dependent handling.
 #if defined(_WIN32)
   // library handle
@@ -111,7 +122,7 @@ class DSOModuleNode final : public ModuleNode {
     std::wstring wname(name.begin(), name.end());
     lib_handle_ = LoadLibraryW(wname.c_str());
   }
-  BackendPackedCFunc GetFuncPtr(const std::string& name) {
+  BackendPackedCFunc GetFuncPtr_(const std::string& name) {
     return reinterpret_cast<BackendPackedCFunc>(
         GetProcAddress(lib_handle_, (LPCSTR)name.c_str()));  // NOLINT(*)
   }
@@ -129,7 +140,7 @@ class DSOModuleNode final : public ModuleNode {
   void Load(const std::string& name) {
     lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
   }
-  BackendPackedCFunc GetFuncPtr(const std::string& name) {
+  BackendPackedCFunc GetFuncPtr_(const std::string& name) {
     return reinterpret_cast<BackendPackedCFunc>(
         dlsym(lib_handle_, name.c_str()));
   }
......
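Because GetFuncPtr now dereferences tvm_module_main to the stored entry name, a loaded shared library can be invoked through the generic entry symbol without knowing the real function name. A hedged sketch against the mylib.so built in the test further down (assumes the symbol constant is "__tvm_main__"; when several weak definitions are linked together, the linker keeps one of them):

m = tvm.module.load(path_dso)
fentry = m["__tvm_main__"]  # resolves to whichever entry string survived linking
fentry(a, b)                # behaves like m['myadd1'] / m['myadd2'] here
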
@@ -103,7 +103,42 @@ def test_llvm_temp_space():
             c.asnumpy(), a.asnumpy() + 1 + 1)
     check_llvm()
 
+def test_multiple_func():
+    nn = 1024
+    n = tvm.convert(nn)
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.placeholder((n,), name='B')
+    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    s = tvm.create_schedule(C.op)
+    xo, xi = s[C].split(C.op.axis[0], factor=4)
+    s[C].parallel(xo)
+    s[C].vectorize(xi)
+    def check_llvm():
+        if not tvm.module.enabled("llvm"):
+            return
+        # build two functions
+        f2 = tvm.lower(s, [A, B, C], name="fadd1")
+        f1 = tvm.lower(s, [A, B, C], name="fadd2")
+        m = tvm.build([f1, f2], "llvm")
+        fadd1 = m['fadd1']
+        fadd2 = m['fadd2']
+        ctx = tvm.cpu(0)
+        # launch the kernel.
+        n = nn
+        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+        fadd1(a, b, c)
+        np.testing.assert_allclose(
+            c.asnumpy(), a.asnumpy() + b.asnumpy())
+        fadd2(a, b, c)
+        np.testing.assert_allclose(
+            c.asnumpy(), a.asnumpy() + b.asnumpy())
+    check_llvm()
+
 if __name__ == "__main__":
+    test_multiple_func()
     test_llvm_add_pipeline()
     test_llvm_flip_pipeline()
     test_llvm_madd_pipeline()
......
@@ -101,6 +101,44 @@ def test_device_module_dump():
     check_device("opencl")
     check_device("metal")
 
+def test_combine_module_llvm():
+    """Test combine multiple module into one shared lib."""
+    # graph
+    nn = 12
+    n = tvm.convert(nn)
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+    s = tvm.create_schedule(B.op)
+    def check_llvm():
+        ctx = tvm.cpu(0)
+        if not tvm.module.enabled("llvm"):
+            print("Skip because llvm is not enabled")
+            return
+        temp = util.tempdir()
+        fadd1 = tvm.build(s, [A, B], "llvm", name="myadd1")
+        fadd2 = tvm.build(s, [A, B], "llvm", name="myadd2")
+        path1 = temp.relpath("myadd1.o")
+        path2 = temp.relpath("myadd2.o")
+        path_dso = temp.relpath("mylib.so")
+        fadd1.save(path1)
+        fadd2.save(path2)
+        # create shared library with multiple functions
+        cc.create_shared(path_dso, [path1, path2])
+        m = tvm.module.load(path_dso)
+        fadd1 = m['myadd1']
+        fadd2 = m['myadd2']
+        a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.zeros(nn, dtype=A.dtype), ctx)
+        fadd1(a, b)
+        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
+        fadd2(a, b)
+        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
+    check_llvm()
+
 if __name__ == "__main__":
+    test_combine_module_llvm()
     test_device_module_dump()
     test_dso_module_load()