Commit 72fcd4e6 by Tianqi Chen Committed by GitHub

[RUNTIME] Add System Lib (#227)

* [RUNTIME] Add System Lib

* lint

* lint

* fix compile
parent b759d0f3
...@@ -38,9 +38,10 @@ runtime::Module Build(const Array<LoweredFunc>& funcs, ...@@ -38,9 +38,10 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
* the dependency from the shared library. * the dependency from the shared library.
* *
* \param m The host module with the imports. * \param m The host module with the imports.
* \param system_lib Whether expose as system library.
* \return cstr The C string representation of the file. * \return cstr The C string representation of the file.
*/ */
std::string PackImportsToC(const runtime::Module& m); std::string PackImportsToC(const runtime::Module& m, bool system_lib);
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
......
...@@ -350,6 +350,17 @@ TVM_DLL int TVMFuncListGlobalNames(int *out_size, ...@@ -350,6 +350,17 @@ TVM_DLL int TVMFuncListGlobalNames(int *out_size,
TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node,
const char* func_name, const char* func_name,
TVMFunctionHandle *out); TVMFunctionHandle *out);
/*!
* \brief Backend function to register system-wide library symbol.
*
* \note This API is supposed to be used by backend,
* it is not supposed to be used by user.
*
* \param name The name of the symbol
* \param ptr The symbol address.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr);
/*! /*!
* \brief Backend function for running parallel for loop. * \brief Backend function for running parallel for loop.
......
...@@ -94,7 +94,7 @@ class Registry { ...@@ -94,7 +94,7 @@ class Registry {
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y) #define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
#define TVM_FUNC_REG_VAR_DEF \ #define TVM_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __make_ ## TVMOp static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM
/*! /*!
* \brief Register a function globally. * \brief Register a function globally.
......
...@@ -223,23 +223,31 @@ def build(sch, ...@@ -223,23 +223,31 @@ def build(sch,
The target and option of the compilation. The target and option of the compilation.
When the target is llvm, you can set options like: When the target is llvm, you can set options like:
* **-mtriple=<target triple>** or **-target** - **-mtriple=<target triple>** or **-target**
Specify the target triple, which is useful for cross Specify the target triple, which is useful for cross
compilation. compilation.
* **-mcpu=<cpuname>** - **-mcpu=<cpuname>**
Specify a specific chip in the current architecture to Specify a specific chip in the current architecture to
generate code for. By default this is infered from the generate code for. By default this is infered from the
target triple and autodetected to the current architecture. target triple and autodetected to the current architecture.
* **-mattr=a1,+a2,-a3,...** - **-mattr=a1,+a2,-a3,...**
Override or control specific attributes of the target, Override or control specific attributes of the target,
such as whether SIMD operations are enabled or not. The such as whether SIMD operations are enabled or not. The
default set of attributes is set by the current CPU. default set of attributes is set by the current CPU.
- **-system-lib**
Build TVM system library module. System lib is a global module that contains
self registered functions in program startup. User can get the module using
:any:`tvm.module.system_lib`.
It is useful in environments where dynamic loading api like dlopen is banned.
The system lib will be available as long as the result code is linked by the program.
target_host : str, optional target_host : str, optional
Host compilation target, if target is device. Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA, When TVM compiles device specific program such as CUDA,
......
...@@ -25,7 +25,7 @@ def create_shared(output, ...@@ -25,7 +25,7 @@ def create_shared(output,
The compile string. The compile string.
""" """
cmd = [cc] cmd = [cc]
cmd += ["-shared"] cmd += ["-shared", "-fPIC"]
if sys.platform == "darwin": if sys.platform == "darwin":
cmd += ["-undefined", "dynamic_lookup"] cmd += ["-undefined", "dynamic_lookup"]
......
...@@ -71,16 +71,26 @@ class Module(ModuleBase): ...@@ -71,16 +71,26 @@ class Module(ModuleBase):
file_name : str file_name : str
The name of the shared library. The name of the shared library.
""" """
if self.type_key == "stacktvm":
raise ValueError("Module[%s]: export_library requires llvm module,"
" did you build with LLVM enabled?" % self.type_key)
if self.type_key != "llvm": if self.type_key != "llvm":
raise ValueError("Module[%s]: Only llvm support export shared" % self.type_key) raise ValueError("Module[%s]: Only llvm support export shared" % self.type_key)
temp = _util.tempdir() temp = _util.tempdir()
path_obj = temp.relpath("lib.o") path_obj = temp.relpath("lib.o")
self.save(path_obj) self.save(path_obj)
files = [path_obj] files = [path_obj]
try:
self.get_function("__tvm_module_startup")
is_system_lib = True
except AttributeError:
is_system_lib = False
if self.imported_modules: if self.imported_modules:
path_cc = temp.relpath("devc.cc") path_cc = temp.relpath("devc.cc")
with open(path_cc, "w") as f: with open(path_cc, "w") as f:
f.write(_PackImportsToC(self)) f.write(_PackImportsToC(self, is_system_lib))
files.append(path_cc) files.append(path_cc)
_cc.create_shared(file_name, files) _cc.create_shared(file_name, files)
...@@ -116,6 +126,27 @@ class Module(ModuleBase): ...@@ -116,6 +126,27 @@ class Module(ModuleBase):
raise NameError("time_evaluate is only supported when RPC is enabled") raise NameError("time_evaluate is only supported when RPC is enabled")
def system_lib():
"""Get system-wide library module singleton.
System lib is a global module that contains self register functions in startup.
Unlike normal dso modules which need to be loaded explicitly.
It is useful in environments where dynamic loading api like dlopen is banned.
To build system lib function, simply specify target option ```llvm --system-lib```
The system lib will be available as long as the result code is linked by the program.
The system lib is intended to be linked and loaded during the entire life-cyle of the program.
If you want dynamic loading features, use dso modules instead.
Returns
-------
module : Module
The system-wide library module.
"""
return _GetSystemLib()
def load(path, fmt=""): def load(path, fmt=""):
"""Load module from file """Load module from file
......
...@@ -23,7 +23,7 @@ TVM_REGISTER_API("codegen._Build") ...@@ -23,7 +23,7 @@ TVM_REGISTER_API("codegen._Build")
TVM_REGISTER_API("module._PackImportsToC") TVM_REGISTER_API("module._PackImportsToC")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = PackImportsToC(args[0]); *ret = PackImportsToC(args[0], args[1]);
}); });
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -35,7 +35,7 @@ runtime::Module Build(const Array<LoweredFunc>& funcs, ...@@ -35,7 +35,7 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
return m; return m;
} }
std::string PackImportsToC(const runtime::Module& mod) { std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
std::string bin; std::string bin;
dmlc::MemoryStringStream ms(&bin); dmlc::MemoryStringStream ms(&bin);
dmlc::Stream* stream = &ms; dmlc::Stream* stream = &ms;
...@@ -55,29 +55,39 @@ std::string PackImportsToC(const runtime::Module& mod) { ...@@ -55,29 +55,39 @@ std::string PackImportsToC(const runtime::Module& mod) {
<< "extern \"C\" {\n" << "extern \"C\" {\n"
<< "#endif\n"; << "#endif\n";
os << "extern const char " << runtime::symbol::tvm_dev_mblob << "[];\n"; os << "extern const char " << runtime::symbol::tvm_dev_mblob << "[];\n";
os << "extern const unsigned long " << runtime::symbol::tvm_dev_mblob_nbytes << ";\n"; uint64_t nbytes = bin.length();
os << "const char " << runtime::symbol::tvm_dev_mblob os << "const char " << runtime::symbol::tvm_dev_mblob
<< "[" << bin.length() << "] = {\n "; << "[" << bin.length() + sizeof(nbytes) << "] = {\n ";
os << std::hex; os << std::hex;
size_t nunit = 80 / 4; size_t nunit = 80 / 4;
for (size_t i = 0; i < bin.length(); ++i) { for (size_t i = 0; i < sizeof(nbytes); ++i) {
// sperators // sperators
if (i != 0) { if (i != 0) {
if (i % nunit == 0) { os << ",";
os << ",\n "; }
} else { os << "0x" << ((nbytes >> (i * 8)) & 0xffUL);
os << ","; }
} for (size_t i = 0; i < bin.length(); ++i) {
// sperators
if ((i + sizeof(nbytes)) % nunit == 0) {
os << ",\n ";
} else {
os << ",";
} }
int c = bin[i]; int c = bin[i];
os << "0x" << (c & 0xff); os << "0x" << (c & 0xff);
} }
os << "\n};\n" os << "\n};\n";
<< "const unsigned long " << runtime::symbol::tvm_dev_mblob_nbytes if (system_lib) {
<< " = " << std::dec << bin.length() << "UL;\n" os << "extern int TVMBackendRegisterSystemLibSymbol(const char*, void*);\n";
<< "#ifdef __cplusplus\n" os << "static int " << runtime::symbol::tvm_dev_mblob << "_reg_ = "
<< "TVMBackendRegisterSystemLibSymbol(\"" << runtime::symbol::tvm_dev_mblob << "\", (void*)"
<< runtime::symbol::tvm_dev_mblob << ");\n";
}
os << "#ifdef __cplusplus\n"
<< "}\n" << "}\n"
<< "#endif\n"; << "#endif\n";
return os.str(); return os.str();
} }
} // namespace codegen } // namespace codegen
......
...@@ -28,7 +28,8 @@ std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine *tm) { ...@@ -28,7 +28,8 @@ std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine *tm) {
void CodeGenLLVM::Init(const std::string& module_name, void CodeGenLLVM::Init(const std::string& module_name,
llvm::TargetMachine* tm, llvm::TargetMachine* tm,
llvm::LLVMContext* ctx) { llvm::LLVMContext* ctx,
bool system_lib) {
InitializeLLVM(); InitializeLLVM();
static_assert(sizeof(TVMValue) == sizeof(double), "invariant"); static_assert(sizeof(TVMValue) == sizeof(double), "invariant");
// static_assert(alignof(TVMValue) == alignof(double), "invariant"); // static_assert(alignof(TVMValue) == alignof(double), "invariant");
...@@ -36,6 +37,7 @@ void CodeGenLLVM::Init(const std::string& module_name, ...@@ -36,6 +37,7 @@ void CodeGenLLVM::Init(const std::string& module_name,
var_map_.clear(); var_map_.clear();
str_map_.clear(); str_map_.clear();
func_handle_map_.clear(); func_handle_map_.clear();
export_system_symbols_.clear();
// initialize types. // initialize types.
if (ctx_ != ctx) { if (ctx_ != ctx) {
t_void_ = llvm::Type::getVoidTy(*ctx); t_void_ = llvm::Type::getVoidTy(*ctx);
...@@ -96,6 +98,13 @@ void CodeGenLLVM::Init(const std::string& module_name, ...@@ -96,6 +98,13 @@ void CodeGenLLVM::Init(const std::string& module_name,
t_int64_, t_int64_, t_f_tvm_par_for_lambda_->getPointerTo(), t_void_p_} t_int64_, t_int64_, t_f_tvm_par_for_lambda_->getPointerTo(), t_void_p_}
, false), , false),
llvm::Function::ExternalLinkage, "TVMBackendParallelFor", module_.get()); llvm::Function::ExternalLinkage, "TVMBackendParallelFor", module_.get());
if (system_lib) {
f_tvm_register_system_symbol_ = llvm::Function::Create(
llvm::FunctionType::get(t_int_, {t_char_->getPointerTo(), t_void_p_}, false),
llvm::Function::ExternalLinkage, "TVMBackendRegisterSystemLibSymbol", module_.get());
} else {
f_tvm_register_system_symbol_ = nullptr;
}
this->InitTarget(tm); this->InitTarget(tm);
// initialize builder // initialize builder
builder_.reset(new IRBuilder(*ctx)); builder_.reset(new IRBuilder(*ctx));
...@@ -125,9 +134,15 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) { ...@@ -125,9 +134,15 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
void CodeGenLLVM::InitGlobalContext() { void CodeGenLLVM::InitGlobalContext() {
gv_mod_ctx_ = new llvm::GlobalVariable( gv_mod_ctx_ = new llvm::GlobalVariable(
*module_, t_void_p_, false, *module_, t_void_p_, false,
llvm::GlobalValue::LinkOnceAnyLinkage, 0, "__tvm_module_ctx"); llvm::GlobalValue::LinkOnceAnyLinkage, 0,
tvm::runtime::symbol::tvm_module_ctx);
gv_mod_ctx_->setAlignment(data_layout_->getTypeAllocSize(t_void_p_)); gv_mod_ctx_->setAlignment(data_layout_->getTypeAllocSize(t_void_p_));
gv_mod_ctx_->setInitializer(llvm::Constant::getNullValue(t_void_p_)); gv_mod_ctx_->setInitializer(llvm::Constant::getNullValue(t_void_p_));
if (f_tvm_register_system_symbol_ != nullptr) {
export_system_symbols_.emplace_back(
std::make_pair(tvm::runtime::symbol::tvm_module_ctx, gv_mod_ctx_));
}
} }
void CodeGenLLVM::InitFuncState() { void CodeGenLLVM::InitFuncState() {
...@@ -171,6 +186,11 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) { ...@@ -171,6 +186,11 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
builder_->SetInsertPoint(block); builder_->SetInsertPoint(block);
this->VisitStmt(f->body); this->VisitStmt(f->body);
builder_->CreateRet(ConstInt32(0)); builder_->CreateRet(ConstInt32(0));
if (f_tvm_register_system_symbol_ != nullptr) {
export_system_symbols_.emplace_back(
std::make_pair(f->name, builder_->CreatePointerCast(function_, t_void_p_)));
}
} }
void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) { void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
...@@ -225,13 +245,35 @@ void CodeGenLLVM::Optimize() { ...@@ -225,13 +245,35 @@ void CodeGenLLVM::Optimize() {
} }
std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() { std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
this->AddStartupFunction();
this->Optimize(); this->Optimize();
var_map_.clear(); var_map_.clear();
str_map_.clear(); str_map_.clear();
func_handle_map_.clear(); func_handle_map_.clear();
export_system_symbols_.clear();
return std::move(module_); return std::move(module_);
} }
void CodeGenLLVM::AddStartupFunction() {
if (export_system_symbols_.size() != 0) {
llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_, {}, false);
function_ = llvm::Function::Create(
ftype,
llvm::Function::InternalLinkage,
"__tvm_module_startup", module_.get());
llvm::BasicBlock* startup_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(startup_entry);
for (const auto& kv : export_system_symbols_) {
llvm::Value* name = GetConstString(kv.first);
builder_->CreateCall(
f_tvm_register_system_symbol_, {
name, builder_->CreateBitCast(kv.second, t_void_p_)});
}
llvm::appendToGlobalCtors(*module_, function_, 65535);
builder_->CreateRet(nullptr);
}
}
llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const { llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const {
llvm::Type* ret = nullptr; llvm::Type* ret = nullptr;
if (t.is_uint() || t.is_int()) { if (t.is_uint() || t.is_int()) {
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <memory> #include <memory>
#include <utility>
#include <vector> #include <vector>
#include <string> #include <string>
#include "./llvm_common.h" #include "./llvm_common.h"
...@@ -39,10 +40,12 @@ class CodeGenLLVM : ...@@ -39,10 +40,12 @@ class CodeGenLLVM :
* \param module_name The name of the module. * \param module_name The name of the module.
* \param tm Target machine model * \param tm Target machine model
* \param ctx The context. * \param ctx The context.
* \param system_lib Whether to insert system library registration.
*/ */
void Init(const std::string& module_name, void Init(const std::string& module_name,
llvm::TargetMachine* tm, llvm::TargetMachine* tm,
llvm::LLVMContext* ctx); llvm::LLVMContext* ctx,
bool system_lib);
/*! /*!
* \brief Compile and add function f to the current module. * \brief Compile and add function f to the current module.
* \param f The function to be added. * \param f The function to be added.
...@@ -163,7 +166,7 @@ class CodeGenLLVM : ...@@ -163,7 +166,7 @@ class CodeGenLLVM :
llvm::LLVMContext* ctx_{nullptr}; llvm::LLVMContext* ctx_{nullptr};
// helpful data types // helpful data types
llvm::Type* t_void_{nullptr}; llvm::Type* t_void_{nullptr};
llvm::Type* t_void_p_{nullptr}; llvm::PointerType* t_void_p_{nullptr};
llvm::Type* t_int_{nullptr}; llvm::Type* t_int_{nullptr};
llvm::Type* t_char_{nullptr}; llvm::Type* t_char_{nullptr};
llvm::Type* t_int8_{nullptr}; llvm::Type* t_int8_{nullptr};
...@@ -188,6 +191,7 @@ class CodeGenLLVM : ...@@ -188,6 +191,7 @@ class CodeGenLLVM :
llvm::Function* f_tvm_get_func_from_env_{nullptr}; llvm::Function* f_tvm_get_func_from_env_{nullptr};
llvm::Function* f_tvm_api_set_last_error_{nullptr}; llvm::Function* f_tvm_api_set_last_error_{nullptr};
llvm::Function* f_tvm_parallel_for_{nullptr}; llvm::Function* f_tvm_parallel_for_{nullptr};
llvm::Function* f_tvm_register_system_symbol_{nullptr};
// The acting body // The acting body
llvm::BasicBlock* block_{nullptr}; llvm::BasicBlock* block_{nullptr};
/*! \brief native vector bits of current targetx*/ /*! \brief native vector bits of current targetx*/
...@@ -227,6 +231,8 @@ class CodeGenLLVM : ...@@ -227,6 +231,8 @@ class CodeGenLLVM :
llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode); llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode);
// Add a function to set global module context // Add a function to set global module context
void InitGlobalContext(); void InitGlobalContext();
// Add module startup function if needed.
void AddStartupFunction();
// add alias information. // add alias information.
void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, Type type); void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, Type type);
// The definition of local variable. // The definition of local variable.
...@@ -243,6 +249,8 @@ class CodeGenLLVM : ...@@ -243,6 +249,8 @@ class CodeGenLLVM :
llvm::GlobalVariable* gv_mod_ctx_{nullptr}; llvm::GlobalVariable* gv_mod_ctx_{nullptr};
// global to packed function handle // global to packed function handle
std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_; std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_;
// List of symbols to be exported to TVM system lib.
std::vector<std::pair<std::string, llvm::Value*> > export_system_symbols_;
}; };
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
......
...@@ -51,6 +51,9 @@ GetLLVMTargetMachine(const std::string& target_str) { ...@@ -51,6 +51,9 @@ GetLLVMTargetMachine(const std::string& target_str) {
if (target_str.length() > 5) { if (target_str.length() > 5) {
std::istringstream is(target_str.substr(5, target_str.length() - 5)); std::istringstream is(target_str.substr(5, target_str.length() - 5));
while (is >> key) { while (is >> key) {
if (key == "--system-lib" || key == "-system-lib") {
continue;
}
size_t pos = key.find('='); size_t pos = key.find('=');
if (pos != std::string::npos) { if (pos != std::string::npos) {
CHECK_GE(key.length(), pos + 1) CHECK_GE(key.length(), pos + 1)
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <llvm/IR/MDBuilder.h> #include <llvm/IR/MDBuilder.h>
#include <llvm/IR/LegacyPassManager.h> #include <llvm/IR/LegacyPassManager.h>
#include <llvm/Transforms/Utils/ModuleUtils.h>
#include <llvm/Transforms/IPO/PassManagerBuilder.h> #include <llvm/Transforms/IPO/PassManagerBuilder.h>
#include <llvm/Transforms/IPO.h> #include <llvm/Transforms/IPO.h>
...@@ -54,6 +55,7 @@ void InitializeLLVM(); ...@@ -54,6 +55,7 @@ void InitializeLLVM();
/*! /*!
* \brief Get target machine from target_str string. * \brief Get target machine from target_str string.
* \param target_str Target string, in format "llvm -target=xxx -mcpu=xxx" * \param target_str Target string, in format "llvm -target=xxx -mcpu=xxx"
*
* \return target machine * \return target machine
*/ */
llvm::TargetMachine* llvm::TargetMachine*
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include "./llvm_common.h" #include "./llvm_common.h"
#include "./codegen_llvm.h" #include "./codegen_llvm.h"
#include "../../runtime/file_util.h" #include "../../runtime/file_util.h"
#include "../../runtime/meta_data.h" #include "../../runtime/module_util.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
...@@ -99,11 +99,12 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -99,11 +99,12 @@ class LLVMModuleNode final : public runtime::ModuleNode {
void Init(const Array<LoweredFunc>& funcs, std::string target) { void Init(const Array<LoweredFunc>& funcs, std::string target) {
InitializeLLVM(); InitializeLLVM();
tm_ = GetLLVMTargetMachine(target); tm_ = GetLLVMTargetMachine(target);
bool system_lib = (target.find("-system-lib") != std::string::npos);
CHECK_NE(funcs.size(), 0U); CHECK_NE(funcs.size(), 0U);
ctx_ = std::make_shared<llvm::LLVMContext>(); ctx_ = std::make_shared<llvm::LLVMContext>();
std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_); std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_);
entry_func_ = funcs[0]->name; entry_func_ = funcs[0]->name;
cg->Init(funcs[0]->name, tm_, ctx_.get()); cg->Init(funcs[0]->name, tm_, ctx_.get(), system_lib);
for (LoweredFunc f : funcs) { for (LoweredFunc f : funcs) {
cg->AddFunction(f); cg->AddFunction(f);
} }
......
...@@ -113,7 +113,7 @@ class CandidateSelector final : public IRVisitor { ...@@ -113,7 +113,7 @@ class CandidateSelector final : public IRVisitor {
std::unordered_set<const Node*> candidates; std::unordered_set<const Node*> candidates;
private: private:
bool in_likely_; bool in_likely_{false};
bool no_split_{false}; bool no_split_{false};
std::unordered_map<const Variable*, VarIsUsed> record_; std::unordered_map<const Variable*, VarIsUsed> record_;
}; };
......
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file dso_module.cc * \file dso_dll_module.cc
* \brief Module to load from dynamic shared library. * \brief Module to load from dynamic shared library.
*/ */
#include <dmlc/memory_io.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include "./meta_data.h" #include "./module_util.h"
#if defined(_WIN32) #if defined(_WIN32)
#include <windows.h> #include <windows.h>
...@@ -19,7 +18,7 @@ namespace tvm { ...@@ -19,7 +18,7 @@ namespace tvm {
namespace runtime { namespace runtime {
// Module to load from dynamic shared libary. // Module to load from dynamic shared libary.
// This is the default module TVM used for hostside AOT // This is the default module TVM used for host-side AOT
class DSOModuleNode final : public ModuleNode { class DSOModuleNode final : public ModuleNode {
public: public:
~DSOModuleNode() { ~DSOModuleNode() {
...@@ -33,68 +32,38 @@ class DSOModuleNode final : public ModuleNode { ...@@ -33,68 +32,38 @@ class DSOModuleNode final : public ModuleNode {
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const std::shared_ptr<ModuleNode>& sptr_to_self) final {
BackendPackedCFunc faddr = GetFuncPtr(name); BackendPackedCFunc faddr;
if (name == runtime::symbol::tvm_module_main) {
const char* entry_name = reinterpret_cast<const char*>(
GetSymbol(runtime::symbol::tvm_module_main));
CHECK(entry_name!= nullptr)
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name));
} else {
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name));
}
if (faddr == nullptr) return PackedFunc(); if (faddr == nullptr) return PackedFunc();
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { return WrapPackedFunc(faddr, sptr_to_self);
int ret = (*faddr)(
(void*)args.values, // NOLINT(*)
(int*)args.type_codes, // NOLINT(*)
args.num_args);
CHECK_EQ(ret, 0) << TVMGetLastError();
});
} }
void Init(const std::string& name) { void Init(const std::string& name) {
Load(name); Load(name);
CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name;
void** ctx_addr = void** ctx_addr =
reinterpret_cast<void**>( reinterpret_cast<void**>(
GetGlobalVPtr(runtime::symbol::tvm_module_ctx)); GetSymbol(runtime::symbol::tvm_module_ctx));
if (ctx_addr != nullptr) { if (ctx_addr != nullptr) {
*ctx_addr = this; *ctx_addr = this;
} }
// Load the imported modules // Load the imported modules
const char* dev_mblob = const char* dev_mblob =
reinterpret_cast<const char*>( reinterpret_cast<const char*>(
GetGlobalVPtr(runtime::symbol::tvm_dev_mblob)); GetSymbol(runtime::symbol::tvm_dev_mblob));
const unsigned long* dev_mblob_nbytes = // NOLINT(*)
reinterpret_cast<const unsigned long*>( // NOLINT(*)
GetGlobalVPtr(runtime::symbol::tvm_dev_mblob_nbytes));
if (dev_mblob != nullptr) { if (dev_mblob != nullptr) {
CHECK(dev_mblob_nbytes != nullptr); ImportModuleBlob(dev_mblob, &imports_);
dmlc::MemoryFixedSizeStream fs(
(void*)dev_mblob, dev_mblob_nbytes[0]); // NOLINT(*)
dmlc::Stream* stream = &fs;
uint64_t size;
CHECK(stream->Read(&size));
for (uint64_t i = 0; i < size; ++i) {
std::string tkey;
CHECK(stream->Read(&tkey));
std::string fkey = "module.loadbinary_" + tkey;
const PackedFunc* f = Registry::Get(fkey);
CHECK(f != nullptr)
<< "Loader of " << tkey << "("
<< fkey << ") is not presented.";
Module m = (*f)(static_cast<void*>(stream));
this->imports_.push_back(m);
}
} }
} }
private: private:
BackendPackedCFunc GetFuncPtr(const std::string& name) {
if (name == runtime::symbol::tvm_module_main) {
const char* entry_name = reinterpret_cast<const char*>(
GetGlobalVPtr(runtime::symbol::tvm_module_main));
CHECK(entry_name!= nullptr)
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
return GetFuncPtr_(entry_name);
} else {
return GetFuncPtr_(name);
}
}
// Platform dependent handling. // Platform dependent handling.
#if defined(_WIN32) #if defined(_WIN32)
// library handle // library handle
...@@ -104,14 +73,12 @@ class DSOModuleNode final : public ModuleNode { ...@@ -104,14 +73,12 @@ class DSOModuleNode final : public ModuleNode {
// use wstring version that is needed by LLVM. // use wstring version that is needed by LLVM.
std::wstring wname(name.begin(), name.end()); std::wstring wname(name.begin(), name.end());
lib_handle_ = LoadLibraryW(wname.c_str()); lib_handle_ = LoadLibraryW(wname.c_str());
CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name;
} }
BackendPackedCFunc GetFuncPtr_(const std::string& name) { void* GetSymbol(const std::string& name) {
return reinterpret_cast<BackendPackedCFunc>(
GetProcAddress(lib_handle_, (LPCSTR)name.c_str())); // NOLINT(*)
}
void* GetGlobalVPtr(const std::string& name) {
return reinterpret_cast<void*>( return reinterpret_cast<void*>(
GetProcAddress(lib_handle_, name.c_str())); // NOLINT(*) GetProcAddress(lib_handle_, (LPCSTR)name.c_str())); // NOLINT(*)
} }
void Unload() { void Unload() {
FreeLibrary(lib_handle_); FreeLibrary(lib_handle_);
...@@ -122,12 +89,10 @@ class DSOModuleNode final : public ModuleNode { ...@@ -122,12 +89,10 @@ class DSOModuleNode final : public ModuleNode {
// load the library // load the library
void Load(const std::string& name) { void Load(const std::string& name) {
lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name;
} }
BackendPackedCFunc GetFuncPtr_(const std::string& name) { void* GetSymbol(const std::string& name) {
return reinterpret_cast<BackendPackedCFunc>(
dlsym(lib_handle_, name.c_str()));
}
void* GetGlobalVPtr(const std::string& name) {
return dlsym(lib_handle_, name.c_str()); return dlsym(lib_handle_, name.c_str());
} }
void Unload() { void Unload() {
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
*/ */
#include <dmlc/json.h> #include <dmlc/json.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <tvm/runtime/packed_func.h>
#include <fstream> #include <fstream>
#include "./file_util.h" #include "./file_util.h"
namespace tvm { namespace tvm {
......
...@@ -7,17 +7,12 @@ ...@@ -7,17 +7,12 @@
#define TVM_RUNTIME_META_DATA_H_ #define TVM_RUNTIME_META_DATA_H_
#include <dmlc/json.h> #include <dmlc/json.h>
#include <dmlc/io.h>
#include <tvm/runtime/packed_func.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include "./runtime_base.h" #include "./runtime_base.h"
extern "C" {
// Function signature for generated packed function in shared library
typedef int (*BackendPackedCFunc)(void* args,
int* type_codes,
int num_args);
} // extern "C"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
......
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file module.cc * \file module.cc
* \brief The global registry of packed function. * \brief TVM module system
*/ */
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <unordered_set> #include <unordered_set>
#include "./file_util.h" #include "./file_util.h"
#include "./meta_data.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
......
/*!
* Copyright (c) 2017 by Contributors
* \file module_util.cc
* \brief Utilities for module.
*/
#include <dmlc/memory_io.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include "./module_util.h"
namespace tvm {
namespace runtime {
void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
CHECK(mblob != nullptr);
uint64_t nbytes = 0;
for (size_t i = 0; i < sizeof(nbytes); ++i) {
uint64_t c = mblob[i];
nbytes |= (c & 0xffUL) << (i * 8);
}
dmlc::MemoryFixedSizeStream fs(
const_cast<char*>(mblob + sizeof(nbytes)), nbytes);
dmlc::Stream* stream = &fs;
uint64_t size;
CHECK(stream->Read(&size));
for (uint64_t i = 0; i < size; ++i) {
std::string tkey;
CHECK(stream->Read(&tkey));
std::string fkey = "module.loadbinary_" + tkey;
const PackedFunc* f = Registry::Get(fkey);
CHECK(f != nullptr)
<< "Loader of " << tkey << "("
<< fkey << ") is not presented.";
Module m = (*f)(static_cast<void*>(stream));
mlist->push_back(m);
}
}
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
int ret = (*faddr)(
const_cast<TVMValue*>(args.values),
const_cast<int*>(args.type_codes),
args.num_args);
CHECK_EQ(ret, 0) << TVMGetLastError();
});
}
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file module_util.h
* \brief Helper utilities for module building
*/
#ifndef TVM_RUNTIME_MODULE_UTIL_H_
#define TVM_RUNTIME_MODULE_UTIL_H_
#include <tvm/runtime/module.h>
#include <vector>
extern "C" {
// Function signature for generated packed function in shared library
typedef int (*BackendPackedCFunc)(void* args,
int* type_codes,
int num_args);
} // extern "C"
namespace tvm {
namespace runtime {
/*!
* \brief Wrap a BackendPackedCFunc to packed function.
* \param faddr The function address
* \param mptr The module pointer node.
*/
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const std::shared_ptr<ModuleNode>& mptr);
/*!
* \brief Load and append module blob to module list
* \param mblob The module blob.
* \param module_list The module list to append to
*/
void ImportModuleBlob(const char* mblob, std::vector<Module>* module_list);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_MODULE_UTIL_H_
...@@ -153,7 +153,7 @@ class OpenCLWorkspace final : public DeviceAPI { ...@@ -153,7 +153,7 @@ class OpenCLWorkspace final : public DeviceAPI {
TVMStreamHandle stream) final; TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
// get the global workspace // get the global workspace
static OpenCLWorkspace* Global(); static const std::shared_ptr<OpenCLWorkspace>& Global();
}; };
......
...@@ -13,9 +13,9 @@ namespace tvm { ...@@ -13,9 +13,9 @@ namespace tvm {
namespace runtime { namespace runtime {
namespace cl { namespace cl {
OpenCLWorkspace* OpenCLWorkspace::Global() { const std::shared_ptr<OpenCLWorkspace>& OpenCLWorkspace::Global() {
static OpenCLWorkspace inst; static std::shared_ptr<OpenCLWorkspace> inst = std::make_shared<OpenCLWorkspace>();
return &inst; return inst;
} }
void OpenCLWorkspace::SetDevice(TVMContext ctx) { void OpenCLWorkspace::SetDevice(TVMContext ctx) {
...@@ -210,14 +210,13 @@ void OpenCLWorkspace::Init() { ...@@ -210,14 +210,13 @@ void OpenCLWorkspace::Init() {
} }
bool InitOpenCL(TVMArgs args, TVMRetValue* rv) { bool InitOpenCL(TVMArgs args, TVMRetValue* rv) {
cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global(); cl::OpenCLWorkspace::Global()->Init();
w->Init();
return true; return true;
} }
TVM_REGISTER_GLOBAL("device_api.opencl") TVM_REGISTER_GLOBAL("device_api.opencl")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = OpenCLWorkspace::Global(); DeviceAPI* ptr = OpenCLWorkspace::Global().get();
*rv = static_cast<void*>(ptr); *rv = static_cast<void*>(ptr);
}); });
......
...@@ -40,10 +40,9 @@ class OpenCLModuleNode : public ModuleNode { ...@@ -40,10 +40,9 @@ class OpenCLModuleNode : public ModuleNode {
~OpenCLModuleNode() { ~OpenCLModuleNode() {
{ {
// free the kernel ids in global table. // free the kernel ids in global table.
cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global(); std::lock_guard<std::mutex> lock(workspace_->mu);
std::lock_guard<std::mutex> lock(w->mu);
for (auto& kv : kid_map_) { for (auto& kv : kid_map_) {
w->free_kernel_ids.push_back(kv.second.kernel_id); workspace_->free_kernel_ids.push_back(kv.second.kernel_id);
} }
} }
// free the kernels // free the kernels
...@@ -89,33 +88,33 @@ class OpenCLModuleNode : public ModuleNode { ...@@ -89,33 +88,33 @@ class OpenCLModuleNode : public ModuleNode {
} }
// Initialize the programs // Initialize the programs
void InitProgram() { void Init() {
cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global(); workspace_ = cl::OpenCLWorkspace::Global();
w->Init(); workspace_->Init();
CHECK(w->context != nullptr) << "No OpenCL device"; CHECK(workspace_->context != nullptr) << "No OpenCL device";
if (fmt_ == "cl") { if (fmt_ == "cl") {
const char* s = data_.c_str(); const char* s = data_.c_str();
size_t len = data_.length(); size_t len = data_.length();
cl_int err; cl_int err;
program_ = clCreateProgramWithSource( program_ = clCreateProgramWithSource(
w->context, 1, &s, &len, &err); workspace_->context, 1, &s, &len, &err);
OPENCL_CHECK_ERROR(err); OPENCL_CHECK_ERROR(err);
} else { } else {
LOG(FATAL) << "Unknown OpenCL format " << fmt_; LOG(FATAL) << "Unknown OpenCL format " << fmt_;
} }
device_built_flag_.resize(w->devices.size(), false); device_built_flag_.resize(workspace_->devices.size(), false);
// initialize the kernel id, need to lock global table. // initialize the kernel id, need to lock global table.
std::lock_guard<std::mutex> lock(w->mu); std::lock_guard<std::mutex> lock(workspace_->mu);
for (const auto& kv : fmap_) { for (const auto& kv : fmap_) {
const std::string& key = kv.first; const std::string& key = kv.first;
KTRefEntry e; KTRefEntry e;
if (w->free_kernel_ids.size() != 0) { if (workspace_->free_kernel_ids.size() != 0) {
e.kernel_id = w->free_kernel_ids.back(); e.kernel_id = workspace_->free_kernel_ids.back();
w->free_kernel_ids.pop_back(); workspace_->free_kernel_ids.pop_back();
} else { } else {
e.kernel_id = w->num_registered_kernels++; e.kernel_id = workspace_->num_registered_kernels++;
} }
e.version = w->timestamp++; e.version = workspace_->timestamp++;
kid_map_[key] = e; kid_map_[key] = e;
} }
} }
...@@ -154,6 +153,9 @@ class OpenCLModuleNode : public ModuleNode { ...@@ -154,6 +153,9 @@ class OpenCLModuleNode : public ModuleNode {
} }
private: private:
// The workspace, need to keep reference to use it in destructor.
// In case of static destruction order problem.
std::shared_ptr<cl::OpenCLWorkspace> workspace_;
// the binary data // the binary data
std::string data_; std::string data_;
// The format // The format
...@@ -181,7 +183,7 @@ class OpenCLWrappedFunc { ...@@ -181,7 +183,7 @@ class OpenCLWrappedFunc {
std::string func_name, std::string func_name,
std::vector<size_t> arg_size, std::vector<size_t> arg_size,
const std::vector<std::string>& thread_axis_tags) { const std::vector<std::string>& thread_axis_tags) {
w_ = cl::OpenCLWorkspace::Global(); w_ = cl::OpenCLWorkspace::Global().get();
m_ = m; m_ = m;
sptr_ = sptr; sptr_ = sptr;
entry_ = entry; entry_ = entry;
...@@ -268,7 +270,7 @@ Module OpenCLModuleCreate( ...@@ -268,7 +270,7 @@ Module OpenCLModuleCreate(
std::unordered_map<std::string, FunctionInfo> fmap) { std::unordered_map<std::string, FunctionInfo> fmap) {
std::shared_ptr<OpenCLModuleNode> n = std::shared_ptr<OpenCLModuleNode> n =
std::make_shared<OpenCLModuleNode>(data, fmt, fmap); std::make_shared<OpenCLModuleNode>(data, fmt, fmap);
n->InitProgram(); n->Init();
return Module(n); return Module(n);
} }
......
/*!
* Copyright (c) 2017 by Contributors
* \file system_lib_module.cc
* \brief SystemLib module.
*/
#include <tvm/runtime/registry.h>
#include <mutex>
#include "./module_util.h"
namespace tvm {
namespace runtime {
class SystemLibModuleNode : public ModuleNode {
public:
SystemLibModuleNode() {
}
const char* type_key() const final {
return "system_lib";
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
std::lock_guard<std::mutex> lock(mutex_);
auto it = tbl_.find(name);
if (it != tbl_.end()) {
return WrapPackedFunc(
reinterpret_cast<BackendPackedCFunc>(it->second), sptr_to_self);
} else {
return PackedFunc();
}
}
void RegisterSymbol(const std::string& name, void* ptr) {
std::lock_guard<std::mutex> lock(mutex_);
if (name == symbol::tvm_module_ctx) {
void** ctx_addr = reinterpret_cast<void**>(ptr);
*ctx_addr = this;
} else if (name == symbol::tvm_dev_mblob) {
ImportModuleBlob(reinterpret_cast<const char*>(ptr), &imports_);
} else {
auto it = tbl_.find(name);
if (it != tbl_.end()) {
if (ptr != it->second) {
LOG(WARNING) << "SystemLib symbol " << name
<< " get overriden to a different address "
<< ptr << "->" << it->second;
tbl_[name] = ptr;
}
} else {
tbl_[name] = ptr;
}
}
}
static const std::shared_ptr<SystemLibModuleNode>& Global() {
static std::shared_ptr<SystemLibModuleNode> inst =
std::make_shared<SystemLibModuleNode>();
return inst;
}
private:
// Internal mutex
std::mutex mutex_;
// Internal symbol table
std::unordered_map<std::string, void*> tbl_;
};
TVM_REGISTER_GLOBAL("module._GetSystemLib")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = runtime::Module(SystemLibModuleNode::Global());
});
} // namespace runtime
} // namespace tvm
int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr) {
tvm::runtime::SystemLibModuleNode::Global()->RegisterSymbol(name, ptr);
return 0;
}
import tvm import tvm
from tvm.contrib import cc_compiler as cc, util from tvm.contrib import cc_compiler as cc, util
import ctypes
import os import os
import numpy as np import numpy as np
import subprocess import subprocess
...@@ -86,7 +87,8 @@ def test_device_module_dump(): ...@@ -86,7 +87,8 @@ def test_device_module_dump():
print("Skip because %s is not enabled" % device) print("Skip because %s is not enabled" % device)
return return
temp = util.tempdir() temp = util.tempdir()
f = tvm.build(s, [A, B], device, name="myadd") name = "myadd_%s" % device
f = tvm.build(s, [A, B], device, "llvm -system-lib", name=name)
path_dso = temp.relpath("dev_lib.so") path_dso = temp.relpath("dev_lib.so")
f.export_library(path_dso) f.export_library(path_dso)
...@@ -94,6 +96,9 @@ def test_device_module_dump(): ...@@ -94,6 +96,9 @@ def test_device_module_dump():
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
f1(a, b) f1(a, b)
f2 = tvm.module.system_lib()
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
f2[name](a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_device("cuda") check_device("cuda")
...@@ -134,9 +139,37 @@ def test_combine_module_llvm(): ...@@ -134,9 +139,37 @@ def test_combine_module_llvm():
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
fadd2(a, b) fadd2(a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
def check_system_lib():
ctx = tvm.cpu(0)
if not tvm.module.enabled("llvm"):
print("Skip because llvm is not enabled" )
return
temp = util.tempdir()
fadd1 = tvm.build(s, [A, B], "llvm -system-lib", name="myadd1")
fadd2 = tvm.build(s, [A, B], "llvm -system-lib", name="myadd2")
path1 = temp.relpath("myadd1.o")
path2 = temp.relpath("myadd2.o")
path_dso = temp.relpath("mylib.so")
fadd1.save(path1)
fadd2.save(path2)
cc.create_shared(path_dso, [path1, path2])
# Load dll, will trigger system library registration
dll = ctypes.CDLL(path_dso)
# Load the system wide library
mm = tvm.module.system_lib()
a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(nn, dtype=A.dtype), ctx)
mm['myadd1'](a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
mm['myadd2'](a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_system_lib()
check_llvm() check_llvm()
if __name__ == "__main__": if __name__ == "__main__":
test_combine_module_llvm() test_combine_module_llvm()
test_device_module_dump() test_device_module_dump()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment