Commit 72fcd4e6 by Tianqi Chen Committed by GitHub

[RUNTIME] Add System Lib (#227)

* [RUNTIME] Add System Lib

* lint

* lint

* fix compile
parent b759d0f3
......@@ -38,9 +38,10 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
* the dependency from the shared library.
*
* \param m The host module with the imports.
* \param system_lib Whether expose as system library.
* \return cstr The C string representation of the file.
*/
std::string PackImportsToC(const runtime::Module& m);
std::string PackImportsToC(const runtime::Module& m, bool system_lib);
} // namespace codegen
} // namespace tvm
......
......@@ -350,6 +350,17 @@ TVM_DLL int TVMFuncListGlobalNames(int *out_size,
TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node,
const char* func_name,
TVMFunctionHandle *out);
/*!
* \brief Backend function to register system-wide library symbol.
*
* \note This API is supposed to be used by backend,
* it is not supposed to be used by user.
*
* \param name The name of the symbol
* \param ptr The symbol address.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr);
/*!
* \brief Backend function for running parallel for loop.
......
......@@ -94,7 +94,7 @@ class Registry {
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
#define TVM_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __make_ ## TVMOp
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM
/*!
* \brief Register a function globally.
......
......@@ -223,23 +223,31 @@ def build(sch,
The target and option of the compilation.
When the target is llvm, you can set options like:
* **-mtriple=<target triple>** or **-target**
- **-mtriple=<target triple>** or **-target**
Specify the target triple, which is useful for cross
compilation.
* **-mcpu=<cpuname>**
- **-mcpu=<cpuname>**
Specify a specific chip in the current architecture to
generate code for. By default this is infered from the
target triple and autodetected to the current architecture.
* **-mattr=a1,+a2,-a3,...**
- **-mattr=a1,+a2,-a3,...**
Override or control specific attributes of the target,
such as whether SIMD operations are enabled or not. The
default set of attributes is set by the current CPU.
- **-system-lib**
Build TVM system library module. System lib is a global module that contains
self registered functions in program startup. User can get the module using
:any:`tvm.module.system_lib`.
It is useful in environments where dynamic loading api like dlopen is banned.
The system lib will be available as long as the result code is linked by the program.
target_host : str, optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
......
......@@ -25,7 +25,7 @@ def create_shared(output,
The compile string.
"""
cmd = [cc]
cmd += ["-shared"]
cmd += ["-shared", "-fPIC"]
if sys.platform == "darwin":
cmd += ["-undefined", "dynamic_lookup"]
......
......@@ -71,16 +71,26 @@ class Module(ModuleBase):
file_name : str
The name of the shared library.
"""
if self.type_key == "stacktvm":
raise ValueError("Module[%s]: export_library requires llvm module,"
" did you build with LLVM enabled?" % self.type_key)
if self.type_key != "llvm":
raise ValueError("Module[%s]: Only llvm support export shared" % self.type_key)
temp = _util.tempdir()
path_obj = temp.relpath("lib.o")
self.save(path_obj)
files = [path_obj]
try:
self.get_function("__tvm_module_startup")
is_system_lib = True
except AttributeError:
is_system_lib = False
if self.imported_modules:
path_cc = temp.relpath("devc.cc")
with open(path_cc, "w") as f:
f.write(_PackImportsToC(self))
f.write(_PackImportsToC(self, is_system_lib))
files.append(path_cc)
_cc.create_shared(file_name, files)
......@@ -116,6 +126,27 @@ class Module(ModuleBase):
raise NameError("time_evaluate is only supported when RPC is enabled")
def system_lib():
"""Get system-wide library module singleton.
System lib is a global module that contains self register functions in startup.
Unlike normal dso modules which need to be loaded explicitly.
It is useful in environments where dynamic loading api like dlopen is banned.
To build system lib function, simply specify target option ```llvm --system-lib```
The system lib will be available as long as the result code is linked by the program.
The system lib is intended to be linked and loaded during the entire life-cyle of the program.
If you want dynamic loading features, use dso modules instead.
Returns
-------
module : Module
The system-wide library module.
"""
return _GetSystemLib()
def load(path, fmt=""):
"""Load module from file
......
......@@ -23,7 +23,7 @@ TVM_REGISTER_API("codegen._Build")
TVM_REGISTER_API("module._PackImportsToC")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = PackImportsToC(args[0]);
*ret = PackImportsToC(args[0], args[1]);
});
} // namespace codegen
} // namespace tvm
......@@ -35,7 +35,7 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
return m;
}
std::string PackImportsToC(const runtime::Module& mod) {
std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
std::string bin;
dmlc::MemoryStringStream ms(&bin);
dmlc::Stream* stream = &ms;
......@@ -55,29 +55,39 @@ std::string PackImportsToC(const runtime::Module& mod) {
<< "extern \"C\" {\n"
<< "#endif\n";
os << "extern const char " << runtime::symbol::tvm_dev_mblob << "[];\n";
os << "extern const unsigned long " << runtime::symbol::tvm_dev_mblob_nbytes << ";\n";
uint64_t nbytes = bin.length();
os << "const char " << runtime::symbol::tvm_dev_mblob
<< "[" << bin.length() << "] = {\n ";
<< "[" << bin.length() + sizeof(nbytes) << "] = {\n ";
os << std::hex;
size_t nunit = 80 / 4;
for (size_t i = 0; i < bin.length(); ++i) {
for (size_t i = 0; i < sizeof(nbytes); ++i) {
// sperators
if (i != 0) {
if (i % nunit == 0) {
os << ",";
}
os << "0x" << ((nbytes >> (i * 8)) & 0xffUL);
}
for (size_t i = 0; i < bin.length(); ++i) {
// sperators
if ((i + sizeof(nbytes)) % nunit == 0) {
os << ",\n ";
} else {
os << ",";
}
}
int c = bin[i];
os << "0x" << (c & 0xff);
}
os << "\n};\n"
<< "const unsigned long " << runtime::symbol::tvm_dev_mblob_nbytes
<< " = " << std::dec << bin.length() << "UL;\n"
<< "#ifdef __cplusplus\n"
os << "\n};\n";
if (system_lib) {
os << "extern int TVMBackendRegisterSystemLibSymbol(const char*, void*);\n";
os << "static int " << runtime::symbol::tvm_dev_mblob << "_reg_ = "
<< "TVMBackendRegisterSystemLibSymbol(\"" << runtime::symbol::tvm_dev_mblob << "\", (void*)"
<< runtime::symbol::tvm_dev_mblob << ");\n";
}
os << "#ifdef __cplusplus\n"
<< "}\n"
<< "#endif\n";
return os.str();
}
} // namespace codegen
......
......@@ -28,7 +28,8 @@ std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine *tm) {
void CodeGenLLVM::Init(const std::string& module_name,
llvm::TargetMachine* tm,
llvm::LLVMContext* ctx) {
llvm::LLVMContext* ctx,
bool system_lib) {
InitializeLLVM();
static_assert(sizeof(TVMValue) == sizeof(double), "invariant");
// static_assert(alignof(TVMValue) == alignof(double), "invariant");
......@@ -36,6 +37,7 @@ void CodeGenLLVM::Init(const std::string& module_name,
var_map_.clear();
str_map_.clear();
func_handle_map_.clear();
export_system_symbols_.clear();
// initialize types.
if (ctx_ != ctx) {
t_void_ = llvm::Type::getVoidTy(*ctx);
......@@ -96,6 +98,13 @@ void CodeGenLLVM::Init(const std::string& module_name,
t_int64_, t_int64_, t_f_tvm_par_for_lambda_->getPointerTo(), t_void_p_}
, false),
llvm::Function::ExternalLinkage, "TVMBackendParallelFor", module_.get());
if (system_lib) {
f_tvm_register_system_symbol_ = llvm::Function::Create(
llvm::FunctionType::get(t_int_, {t_char_->getPointerTo(), t_void_p_}, false),
llvm::Function::ExternalLinkage, "TVMBackendRegisterSystemLibSymbol", module_.get());
} else {
f_tvm_register_system_symbol_ = nullptr;
}
this->InitTarget(tm);
// initialize builder
builder_.reset(new IRBuilder(*ctx));
......@@ -125,9 +134,15 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
void CodeGenLLVM::InitGlobalContext() {
gv_mod_ctx_ = new llvm::GlobalVariable(
*module_, t_void_p_, false,
llvm::GlobalValue::LinkOnceAnyLinkage, 0, "__tvm_module_ctx");
llvm::GlobalValue::LinkOnceAnyLinkage, 0,
tvm::runtime::symbol::tvm_module_ctx);
gv_mod_ctx_->setAlignment(data_layout_->getTypeAllocSize(t_void_p_));
gv_mod_ctx_->setInitializer(llvm::Constant::getNullValue(t_void_p_));
if (f_tvm_register_system_symbol_ != nullptr) {
export_system_symbols_.emplace_back(
std::make_pair(tvm::runtime::symbol::tvm_module_ctx, gv_mod_ctx_));
}
}
void CodeGenLLVM::InitFuncState() {
......@@ -171,6 +186,11 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
builder_->SetInsertPoint(block);
this->VisitStmt(f->body);
builder_->CreateRet(ConstInt32(0));
if (f_tvm_register_system_symbol_ != nullptr) {
export_system_symbols_.emplace_back(
std::make_pair(f->name, builder_->CreatePointerCast(function_, t_void_p_)));
}
}
void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
......@@ -225,13 +245,35 @@ void CodeGenLLVM::Optimize() {
}
std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
this->AddStartupFunction();
this->Optimize();
var_map_.clear();
str_map_.clear();
func_handle_map_.clear();
export_system_symbols_.clear();
return std::move(module_);
}
void CodeGenLLVM::AddStartupFunction() {
if (export_system_symbols_.size() != 0) {
llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_, {}, false);
function_ = llvm::Function::Create(
ftype,
llvm::Function::InternalLinkage,
"__tvm_module_startup", module_.get());
llvm::BasicBlock* startup_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(startup_entry);
for (const auto& kv : export_system_symbols_) {
llvm::Value* name = GetConstString(kv.first);
builder_->CreateCall(
f_tvm_register_system_symbol_, {
name, builder_->CreateBitCast(kv.second, t_void_p_)});
}
llvm::appendToGlobalCtors(*module_, function_, 65535);
builder_->CreateRet(nullptr);
}
}
llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const {
llvm::Type* ret = nullptr;
if (t.is_uint() || t.is_int()) {
......
......@@ -12,6 +12,7 @@
#include <tvm/codegen.h>
#include <tvm/arithmetic.h>
#include <memory>
#include <utility>
#include <vector>
#include <string>
#include "./llvm_common.h"
......@@ -39,10 +40,12 @@ class CodeGenLLVM :
* \param module_name The name of the module.
* \param tm Target machine model
* \param ctx The context.
* \param system_lib Whether to insert system library registration.
*/
void Init(const std::string& module_name,
llvm::TargetMachine* tm,
llvm::LLVMContext* ctx);
llvm::LLVMContext* ctx,
bool system_lib);
/*!
* \brief Compile and add function f to the current module.
* \param f The function to be added.
......@@ -163,7 +166,7 @@ class CodeGenLLVM :
llvm::LLVMContext* ctx_{nullptr};
// helpful data types
llvm::Type* t_void_{nullptr};
llvm::Type* t_void_p_{nullptr};
llvm::PointerType* t_void_p_{nullptr};
llvm::Type* t_int_{nullptr};
llvm::Type* t_char_{nullptr};
llvm::Type* t_int8_{nullptr};
......@@ -188,6 +191,7 @@ class CodeGenLLVM :
llvm::Function* f_tvm_get_func_from_env_{nullptr};
llvm::Function* f_tvm_api_set_last_error_{nullptr};
llvm::Function* f_tvm_parallel_for_{nullptr};
llvm::Function* f_tvm_register_system_symbol_{nullptr};
// The acting body
llvm::BasicBlock* block_{nullptr};
/*! \brief native vector bits of current targetx*/
......@@ -227,6 +231,8 @@ class CodeGenLLVM :
llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode);
// Add a function to set global module context
void InitGlobalContext();
// Add module startup function if needed.
void AddStartupFunction();
// add alias information.
void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, Type type);
// The definition of local variable.
......@@ -243,6 +249,8 @@ class CodeGenLLVM :
llvm::GlobalVariable* gv_mod_ctx_{nullptr};
// global to packed function handle
std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_;
// List of symbols to be exported to TVM system lib.
std::vector<std::pair<std::string, llvm::Value*> > export_system_symbols_;
};
} // namespace codegen
} // namespace tvm
......
......@@ -51,6 +51,9 @@ GetLLVMTargetMachine(const std::string& target_str) {
if (target_str.length() > 5) {
std::istringstream is(target_str.substr(5, target_str.length() - 5));
while (is >> key) {
if (key == "--system-lib" || key == "-system-lib") {
continue;
}
size_t pos = key.find('=');
if (pos != std::string::npos) {
CHECK_GE(key.length(), pos + 1)
......
......@@ -28,6 +28,7 @@
#include <llvm/IR/MDBuilder.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/Transforms/Utils/ModuleUtils.h>
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
#include <llvm/Transforms/IPO.h>
......@@ -54,6 +55,7 @@ void InitializeLLVM();
/*!
* \brief Get target machine from target_str string.
* \param target_str Target string, in format "llvm -target=xxx -mcpu=xxx"
*
* \return target machine
*/
llvm::TargetMachine*
......
......@@ -10,7 +10,7 @@
#include "./llvm_common.h"
#include "./codegen_llvm.h"
#include "../../runtime/file_util.h"
#include "../../runtime/meta_data.h"
#include "../../runtime/module_util.h"
namespace tvm {
namespace codegen {
......@@ -99,11 +99,12 @@ class LLVMModuleNode final : public runtime::ModuleNode {
void Init(const Array<LoweredFunc>& funcs, std::string target) {
InitializeLLVM();
tm_ = GetLLVMTargetMachine(target);
bool system_lib = (target.find("-system-lib") != std::string::npos);
CHECK_NE(funcs.size(), 0U);
ctx_ = std::make_shared<llvm::LLVMContext>();
std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_);
entry_func_ = funcs[0]->name;
cg->Init(funcs[0]->name, tm_, ctx_.get());
cg->Init(funcs[0]->name, tm_, ctx_.get(), system_lib);
for (LoweredFunc f : funcs) {
cg->AddFunction(f);
}
......
......@@ -113,7 +113,7 @@ class CandidateSelector final : public IRVisitor {
std::unordered_set<const Node*> candidates;
private:
bool in_likely_;
bool in_likely_{false};
bool no_split_{false};
std::unordered_map<const Variable*, VarIsUsed> record_;
};
......
/*!
* Copyright (c) 2017 by Contributors
* \file dso_module.cc
* \file dso_dll_module.cc
* \brief Module to load from dynamic shared library.
*/
#include <dmlc/memory_io.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#include "./meta_data.h"
#include "./module_util.h"
#if defined(_WIN32)
#include <windows.h>
......@@ -19,7 +18,7 @@ namespace tvm {
namespace runtime {
// Module to load from dynamic shared libary.
// This is the default module TVM used for hostside AOT
// This is the default module TVM used for host-side AOT
class DSOModuleNode final : public ModuleNode {
public:
~DSOModuleNode() {
......@@ -33,68 +32,38 @@ class DSOModuleNode final : public ModuleNode {
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
BackendPackedCFunc faddr = GetFuncPtr(name);
BackendPackedCFunc faddr;
if (name == runtime::symbol::tvm_module_main) {
const char* entry_name = reinterpret_cast<const char*>(
GetSymbol(runtime::symbol::tvm_module_main));
CHECK(entry_name!= nullptr)
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name));
} else {
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name));
}
if (faddr == nullptr) return PackedFunc();
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
int ret = (*faddr)(
(void*)args.values, // NOLINT(*)
(int*)args.type_codes, // NOLINT(*)
args.num_args);
CHECK_EQ(ret, 0) << TVMGetLastError();
});
return WrapPackedFunc(faddr, sptr_to_self);
}
void Init(const std::string& name) {
Load(name);
CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name;
void** ctx_addr =
reinterpret_cast<void**>(
GetGlobalVPtr(runtime::symbol::tvm_module_ctx));
GetSymbol(runtime::symbol::tvm_module_ctx));
if (ctx_addr != nullptr) {
*ctx_addr = this;
}
// Load the imported modules
const char* dev_mblob =
reinterpret_cast<const char*>(
GetGlobalVPtr(runtime::symbol::tvm_dev_mblob));
const unsigned long* dev_mblob_nbytes = // NOLINT(*)
reinterpret_cast<const unsigned long*>( // NOLINT(*)
GetGlobalVPtr(runtime::symbol::tvm_dev_mblob_nbytes));
GetSymbol(runtime::symbol::tvm_dev_mblob));
if (dev_mblob != nullptr) {
CHECK(dev_mblob_nbytes != nullptr);
dmlc::MemoryFixedSizeStream fs(
(void*)dev_mblob, dev_mblob_nbytes[0]); // NOLINT(*)
dmlc::Stream* stream = &fs;
uint64_t size;
CHECK(stream->Read(&size));
for (uint64_t i = 0; i < size; ++i) {
std::string tkey;
CHECK(stream->Read(&tkey));
std::string fkey = "module.loadbinary_" + tkey;
const PackedFunc* f = Registry::Get(fkey);
CHECK(f != nullptr)
<< "Loader of " << tkey << "("
<< fkey << ") is not presented.";
Module m = (*f)(static_cast<void*>(stream));
this->imports_.push_back(m);
}
ImportModuleBlob(dev_mblob, &imports_);
}
}
private:
BackendPackedCFunc GetFuncPtr(const std::string& name) {
if (name == runtime::symbol::tvm_module_main) {
const char* entry_name = reinterpret_cast<const char*>(
GetGlobalVPtr(runtime::symbol::tvm_module_main));
CHECK(entry_name!= nullptr)
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
return GetFuncPtr_(entry_name);
} else {
return GetFuncPtr_(name);
}
}
// Platform dependent handling.
#if defined(_WIN32)
// library handle
......@@ -104,14 +73,12 @@ class DSOModuleNode final : public ModuleNode {
// use wstring version that is needed by LLVM.
std::wstring wname(name.begin(), name.end());
lib_handle_ = LoadLibraryW(wname.c_str());
CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name;
}
BackendPackedCFunc GetFuncPtr_(const std::string& name) {
return reinterpret_cast<BackendPackedCFunc>(
GetProcAddress(lib_handle_, (LPCSTR)name.c_str())); // NOLINT(*)
}
void* GetGlobalVPtr(const std::string& name) {
void* GetSymbol(const std::string& name) {
return reinterpret_cast<void*>(
GetProcAddress(lib_handle_, name.c_str())); // NOLINT(*)
GetProcAddress(lib_handle_, (LPCSTR)name.c_str())); // NOLINT(*)
}
void Unload() {
FreeLibrary(lib_handle_);
......@@ -122,12 +89,10 @@ class DSOModuleNode final : public ModuleNode {
// load the library
void Load(const std::string& name) {
lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name;
}
BackendPackedCFunc GetFuncPtr_(const std::string& name) {
return reinterpret_cast<BackendPackedCFunc>(
dlsym(lib_handle_, name.c_str()));
}
void* GetGlobalVPtr(const std::string& name) {
void* GetSymbol(const std::string& name) {
return dlsym(lib_handle_, name.c_str());
}
void Unload() {
......
......@@ -4,8 +4,8 @@
*/
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <tvm/runtime/packed_func.h>
#include <fstream>
#include "./file_util.h"
namespace tvm {
......
......@@ -7,17 +7,12 @@
#define TVM_RUNTIME_META_DATA_H_
#include <dmlc/json.h>
#include <dmlc/io.h>
#include <tvm/runtime/packed_func.h>
#include <string>
#include <vector>
#include "./runtime_base.h"
extern "C" {
// Function signature for generated packed function in shared library
typedef int (*BackendPackedCFunc)(void* args,
int* type_codes,
int num_args);
} // extern "C"
namespace tvm {
namespace runtime {
......
/*!
* Copyright (c) 2017 by Contributors
* \file module.cc
* \brief The global registry of packed function.
* \brief TVM module system
*/
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#include <unordered_set>
#include "./file_util.h"
#include "./meta_data.h"
namespace tvm {
namespace runtime {
......
/*!
* Copyright (c) 2017 by Contributors
* \file module_util.cc
* \brief Utilities for module.
*/
#include <dmlc/memory_io.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include "./module_util.h"
namespace tvm {
namespace runtime {
void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
CHECK(mblob != nullptr);
uint64_t nbytes = 0;
for (size_t i = 0; i < sizeof(nbytes); ++i) {
uint64_t c = mblob[i];
nbytes |= (c & 0xffUL) << (i * 8);
}
dmlc::MemoryFixedSizeStream fs(
const_cast<char*>(mblob + sizeof(nbytes)), nbytes);
dmlc::Stream* stream = &fs;
uint64_t size;
CHECK(stream->Read(&size));
for (uint64_t i = 0; i < size; ++i) {
std::string tkey;
CHECK(stream->Read(&tkey));
std::string fkey = "module.loadbinary_" + tkey;
const PackedFunc* f = Registry::Get(fkey);
CHECK(f != nullptr)
<< "Loader of " << tkey << "("
<< fkey << ") is not presented.";
Module m = (*f)(static_cast<void*>(stream));
mlist->push_back(m);
}
}
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
int ret = (*faddr)(
const_cast<TVMValue*>(args.values),
const_cast<int*>(args.type_codes),
args.num_args);
CHECK_EQ(ret, 0) << TVMGetLastError();
});
}
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file module_util.h
* \brief Helper utilities for module building
*/
#ifndef TVM_RUNTIME_MODULE_UTIL_H_
#define TVM_RUNTIME_MODULE_UTIL_H_
#include <tvm/runtime/module.h>
#include <vector>
extern "C" {
// Function signature for generated packed function in shared library
typedef int (*BackendPackedCFunc)(void* args,
int* type_codes,
int num_args);
} // extern "C"
namespace tvm {
namespace runtime {
/*!
* \brief Wrap a BackendPackedCFunc to packed function.
* \param faddr The function address
* \param mptr The module pointer node.
*/
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const std::shared_ptr<ModuleNode>& mptr);
/*!
* \brief Load and append module blob to module list
* \param mblob The module blob.
* \param module_list The module list to append to
*/
void ImportModuleBlob(const char* mblob, std::vector<Module>* module_list);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_MODULE_UTIL_H_
......@@ -153,7 +153,7 @@ class OpenCLWorkspace final : public DeviceAPI {
TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
// get the global workspace
static OpenCLWorkspace* Global();
static const std::shared_ptr<OpenCLWorkspace>& Global();
};
......
......@@ -13,9 +13,9 @@ namespace tvm {
namespace runtime {
namespace cl {
OpenCLWorkspace* OpenCLWorkspace::Global() {
static OpenCLWorkspace inst;
return &inst;
const std::shared_ptr<OpenCLWorkspace>& OpenCLWorkspace::Global() {
static std::shared_ptr<OpenCLWorkspace> inst = std::make_shared<OpenCLWorkspace>();
return inst;
}
void OpenCLWorkspace::SetDevice(TVMContext ctx) {
......@@ -210,14 +210,13 @@ void OpenCLWorkspace::Init() {
}
bool InitOpenCL(TVMArgs args, TVMRetValue* rv) {
cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global();
w->Init();
cl::OpenCLWorkspace::Global()->Init();
return true;
}
TVM_REGISTER_GLOBAL("device_api.opencl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = OpenCLWorkspace::Global();
DeviceAPI* ptr = OpenCLWorkspace::Global().get();
*rv = static_cast<void*>(ptr);
});
......
......@@ -40,10 +40,9 @@ class OpenCLModuleNode : public ModuleNode {
~OpenCLModuleNode() {
{
// free the kernel ids in global table.
cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global();
std::lock_guard<std::mutex> lock(w->mu);
std::lock_guard<std::mutex> lock(workspace_->mu);
for (auto& kv : kid_map_) {
w->free_kernel_ids.push_back(kv.second.kernel_id);
workspace_->free_kernel_ids.push_back(kv.second.kernel_id);
}
}
// free the kernels
......@@ -89,33 +88,33 @@ class OpenCLModuleNode : public ModuleNode {
}
// Initialize the programs
void InitProgram() {
cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global();
w->Init();
CHECK(w->context != nullptr) << "No OpenCL device";
void Init() {
workspace_ = cl::OpenCLWorkspace::Global();
workspace_->Init();
CHECK(workspace_->context != nullptr) << "No OpenCL device";
if (fmt_ == "cl") {
const char* s = data_.c_str();
size_t len = data_.length();
cl_int err;
program_ = clCreateProgramWithSource(
w->context, 1, &s, &len, &err);
workspace_->context, 1, &s, &len, &err);
OPENCL_CHECK_ERROR(err);
} else {
LOG(FATAL) << "Unknown OpenCL format " << fmt_;
}
device_built_flag_.resize(w->devices.size(), false);
device_built_flag_.resize(workspace_->devices.size(), false);
// initialize the kernel id, need to lock global table.
std::lock_guard<std::mutex> lock(w->mu);
std::lock_guard<std::mutex> lock(workspace_->mu);
for (const auto& kv : fmap_) {
const std::string& key = kv.first;
KTRefEntry e;
if (w->free_kernel_ids.size() != 0) {
e.kernel_id = w->free_kernel_ids.back();
w->free_kernel_ids.pop_back();
if (workspace_->free_kernel_ids.size() != 0) {
e.kernel_id = workspace_->free_kernel_ids.back();
workspace_->free_kernel_ids.pop_back();
} else {
e.kernel_id = w->num_registered_kernels++;
e.kernel_id = workspace_->num_registered_kernels++;
}
e.version = w->timestamp++;
e.version = workspace_->timestamp++;
kid_map_[key] = e;
}
}
......@@ -154,6 +153,9 @@ class OpenCLModuleNode : public ModuleNode {
}
private:
// The workspace, need to keep reference to use it in destructor.
// In case of static destruction order problem.
std::shared_ptr<cl::OpenCLWorkspace> workspace_;
// the binary data
std::string data_;
// The format
......@@ -181,7 +183,7 @@ class OpenCLWrappedFunc {
std::string func_name,
std::vector<size_t> arg_size,
const std::vector<std::string>& thread_axis_tags) {
w_ = cl::OpenCLWorkspace::Global();
w_ = cl::OpenCLWorkspace::Global().get();
m_ = m;
sptr_ = sptr;
entry_ = entry;
......@@ -268,7 +270,7 @@ Module OpenCLModuleCreate(
std::unordered_map<std::string, FunctionInfo> fmap) {
std::shared_ptr<OpenCLModuleNode> n =
std::make_shared<OpenCLModuleNode>(data, fmt, fmap);
n->InitProgram();
n->Init();
return Module(n);
}
......
/*!
* Copyright (c) 2017 by Contributors
* \file system_lib_module.cc
* \brief SystemLib module.
*/
#include <tvm/runtime/registry.h>
#include <mutex>
#include "./module_util.h"
namespace tvm {
namespace runtime {
class SystemLibModuleNode : public ModuleNode {
public:
SystemLibModuleNode() {
}
const char* type_key() const final {
return "system_lib";
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
std::lock_guard<std::mutex> lock(mutex_);
auto it = tbl_.find(name);
if (it != tbl_.end()) {
return WrapPackedFunc(
reinterpret_cast<BackendPackedCFunc>(it->second), sptr_to_self);
} else {
return PackedFunc();
}
}
void RegisterSymbol(const std::string& name, void* ptr) {
std::lock_guard<std::mutex> lock(mutex_);
if (name == symbol::tvm_module_ctx) {
void** ctx_addr = reinterpret_cast<void**>(ptr);
*ctx_addr = this;
} else if (name == symbol::tvm_dev_mblob) {
ImportModuleBlob(reinterpret_cast<const char*>(ptr), &imports_);
} else {
auto it = tbl_.find(name);
if (it != tbl_.end()) {
if (ptr != it->second) {
LOG(WARNING) << "SystemLib symbol " << name
<< " get overriden to a different address "
<< ptr << "->" << it->second;
tbl_[name] = ptr;
}
} else {
tbl_[name] = ptr;
}
}
}
static const std::shared_ptr<SystemLibModuleNode>& Global() {
static std::shared_ptr<SystemLibModuleNode> inst =
std::make_shared<SystemLibModuleNode>();
return inst;
}
private:
// Internal mutex
std::mutex mutex_;
// Internal symbol table
std::unordered_map<std::string, void*> tbl_;
};
TVM_REGISTER_GLOBAL("module._GetSystemLib")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = runtime::Module(SystemLibModuleNode::Global());
});
} // namespace runtime
} // namespace tvm
int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr) {
tvm::runtime::SystemLibModuleNode::Global()->RegisterSymbol(name, ptr);
return 0;
}
import tvm
from tvm.contrib import cc_compiler as cc, util
import ctypes
import os
import numpy as np
import subprocess
......@@ -86,7 +87,8 @@ def test_device_module_dump():
print("Skip because %s is not enabled" % device)
return
temp = util.tempdir()
f = tvm.build(s, [A, B], device, name="myadd")
name = "myadd_%s" % device
f = tvm.build(s, [A, B], device, "llvm -system-lib", name=name)
path_dso = temp.relpath("dev_lib.so")
f.export_library(path_dso)
......@@ -94,6 +96,9 @@ def test_device_module_dump():
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
f1(a, b)
f2 = tvm.module.system_lib()
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
f2[name](a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_device("cuda")
......@@ -134,9 +139,37 @@ def test_combine_module_llvm():
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
fadd2(a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
def check_system_lib():
ctx = tvm.cpu(0)
if not tvm.module.enabled("llvm"):
print("Skip because llvm is not enabled" )
return
temp = util.tempdir()
fadd1 = tvm.build(s, [A, B], "llvm -system-lib", name="myadd1")
fadd2 = tvm.build(s, [A, B], "llvm -system-lib", name="myadd2")
path1 = temp.relpath("myadd1.o")
path2 = temp.relpath("myadd2.o")
path_dso = temp.relpath("mylib.so")
fadd1.save(path1)
fadd2.save(path2)
cc.create_shared(path_dso, [path1, path2])
# Load dll, will trigger system library registration
dll = ctypes.CDLL(path_dso)
# Load the system wide library
mm = tvm.module.system_lib()
a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(nn, dtype=A.dtype), ctx)
mm['myadd1'](a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
mm['myadd2'](a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_system_lib()
check_llvm()
if __name__ == "__main__":
test_combine_module_llvm()
test_device_module_dump()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment