Commit 9037a4c2 by Tianqi Chen Committed by GitHub

[RUNTIME] Enable injection of some core runtime functions to avoid dynamic lookup (#260)

parent 6196cd50
......@@ -154,12 +154,12 @@ verilog: $(VER_LIBS)
# Special rules for LLVM related modules.
build/codegen/llvm/%.o: src/codegen/llvm/%.cc
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) $(CFLAGS) -MM -MT build/codegen/llvm/$*.o $< >build/codegen/llvm/$*.d
$(CXX) -c $(CFLAGS) $(LLVM_CFLAGS) -c $< -o $@
build/runtime/metal/%.o: src/runtime/metal/%.mm
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) $(CFLAGS) -MM -MT build/runtime/metal/$*.o $< >build/runtime/metal/$*.d
$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
build/%.o: src/%.cc
......@@ -199,7 +199,7 @@ pylint:
pylint python/tvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc
pylint topi/python/topi --rcfile=$(ROOTDIR)/tests/lint/pylintrc
jnilint:
jnilint:
python dmlc-core/scripts/lint.py tvm4j-jni cpp jvm/native/src
lint: cpplint pylint jnilint
......
......@@ -29,7 +29,8 @@ std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine *tm) {
void CodeGenLLVM::Init(const std::string& module_name,
llvm::TargetMachine* tm,
llvm::LLVMContext* ctx,
bool system_lib) {
bool system_lib,
bool dynamic_lookup) {
InitializeLLVM();
static_assert(sizeof(TVMValue) == sizeof(double), "invariant");
// static_assert(alignof(TVMValue) == alignof(double), "invariant");
......@@ -62,7 +63,7 @@ void CodeGenLLVM::Init(const std::string& module_name,
t_tvm_shape_index_->getPointerTo(),
t_int64_});
t_tvm_value_ = llvm::StructType::create({t_float64_});
t_f_tvm_par_for_lambda_ = llvm::FunctionType::get(
ftype_tvm_par_for_lambda_ = llvm::FunctionType::get(
t_int_, {t_int64_, t_int64_, t_void_p_}, false);
md_builder_.reset(new llvm::MDBuilder(*ctx));
md_very_likely_branch_ =
......@@ -70,45 +71,56 @@ void CodeGenLLVM::Init(const std::string& module_name,
md_tbaa_root_ = md_builder_->createTBAARoot("tvmtbaa");
md_tbaa_alias_set_ = md_builder_->createTBAAScalarTypeNode(
"alias_set", md_tbaa_root_);
md_tbaa_ctx_ptr_ = md_builder_->createTBAAScalarTypeNode(
"ctx_ptr", md_tbaa_root_);
}
ctx_ = ctx;
// initialize modules
// initialize Modules and function type
module_.reset(new llvm::Module(module_name, *ctx));
// initialize TVM runtime API
f_tvm_func_call_ = llvm::Function::Create(
llvm::FunctionType::get(t_int_, {
t_tvm_func_handle_,
t_tvm_value_->getPointerTo(),
t_int_->getPointerTo(),
t_int_,
t_tvm_value_->getPointerTo(),
t_int_->getPointerTo()}, false),
llvm::Function::ExternalLinkage, "TVMFuncCall", module_.get());
f_tvm_get_func_from_env_ = llvm::Function::Create(
ftype_tvm_func_call_ = llvm::FunctionType::get(t_int_, {
t_tvm_func_handle_,
t_tvm_value_->getPointerTo(),
t_int_->getPointerTo(),
t_int_,
t_tvm_value_->getPointerTo(),
t_int_->getPointerTo()}, false);
ftype_tvm_get_func_from_env_ = llvm::FunctionType::get(t_int_, {
t_void_p_,
t_char_->getPointerTo(),
t_tvm_func_handle_->getPointerTo()}, false);
ftype_tvm_api_set_last_error_ = llvm::FunctionType::get(
t_void_, {t_char_->getPointerTo()}, false);
ftype_tvm_parallel_for_ =
llvm::FunctionType::get(t_int_, {
t_void_p_,
t_char_->getPointerTo(),
t_tvm_func_handle_->getPointerTo()}, false),
llvm::Function::ExternalLinkage, "TVMBackendGetFuncFromEnv", module_.get());
f_tvm_api_set_last_error_ = llvm::Function::Create(
llvm::FunctionType::get(t_void_, {t_char_->getPointerTo()}, false),
llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());
f_tvm_parallel_for_ = llvm::Function::Create(
llvm::FunctionType::get(t_int_, {
t_int64_, t_int64_, t_f_tvm_par_for_lambda_->getPointerTo(), t_void_p_}
, false),
llvm::Function::ExternalLinkage, "TVMBackendParallelFor", module_.get());
t_int64_, t_int64_, ftype_tvm_par_for_lambda_->getPointerTo(), t_void_p_}
, false);
// initialize TVM runtime API
if (system_lib) {
// We will need this in environment for backward registration.
f_tvm_register_system_symbol_ = llvm::Function::Create(
llvm::FunctionType::get(t_int_, {t_char_->getPointerTo(), t_void_p_}, false),
llvm::Function::ExternalLinkage, "TVMBackendRegisterSystemLibSymbol", module_.get());
} else {
f_tvm_register_system_symbol_ = nullptr;
}
if (dynamic_lookup || system_lib) {
f_tvm_func_call_ = llvm::Function::Create(
ftype_tvm_func_call_,
llvm::Function::ExternalLinkage, "TVMFuncCall", module_.get());
f_tvm_get_func_from_env_ = llvm::Function::Create(
ftype_tvm_get_func_from_env_,
llvm::Function::ExternalLinkage, "TVMBackendGetFuncFromEnv", module_.get());
f_tvm_api_set_last_error_ = llvm::Function::Create(
ftype_tvm_api_set_last_error_,
llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());
f_tvm_parallel_for_ = llvm::Function::Create(
ftype_tvm_parallel_for_,
llvm::Function::ExternalLinkage, "TVMBackendParallelFor", module_.get());
}
this->InitTarget(tm);
// initialize builder
builder_.reset(new IRBuilder(*ctx));
this->InitGlobalContext();
this->InitGlobalContext(dynamic_lookup);
}
void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
......@@ -131,17 +143,48 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
}
}
void CodeGenLLVM::InitGlobalContext() {
gv_mod_ctx_ = new llvm::GlobalVariable(
*module_, t_void_p_, false,
llvm::GlobalVariable* CodeGenLLVM::InitContextPtr(
llvm::Type* p_type, std::string name) {
llvm::GlobalVariable* gv = new llvm::GlobalVariable(
*module_, p_type, false,
llvm::GlobalValue::LinkOnceAnyLinkage, 0,
tvm::runtime::symbol::tvm_module_ctx);
gv_mod_ctx_->setAlignment(data_layout_->getTypeAllocSize(t_void_p_));
gv_mod_ctx_->setInitializer(llvm::Constant::getNullValue(t_void_p_));
name);
gv->setAlignment(data_layout_->getTypeAllocSize(p_type));
gv->setInitializer(llvm::Constant::getNullValue(p_type));
return gv;
}
llvm::Value* CodeGenLLVM::GetContextPtr(llvm::GlobalVariable* gv) {
CHECK(gv != nullptr);
llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, gv->getAlignment());
faddr->setMetadata(
"tbaa",
md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0));
return faddr;
}
void CodeGenLLVM::InitGlobalContext(bool dynamic_lookup) {
// Module context
gv_mod_ctx_ = InitContextPtr(t_void_p_, tvm::runtime::symbol::tvm_module_ctx);
// Register back the locations.
if (f_tvm_register_system_symbol_ != nullptr) {
export_system_symbols_.emplace_back(
std::make_pair(tvm::runtime::symbol::tvm_module_ctx, gv_mod_ctx_));
} else {
if (!dynamic_lookup) {
gv_tvm_func_call_ = InitContextPtr(
ftype_tvm_func_call_->getPointerTo(), "__TVMFuncCall");
gv_tvm_get_func_from_env_ = InitContextPtr(
ftype_tvm_get_func_from_env_->getPointerTo(), "__TVMBackendGetFuncFromEnv");
gv_tvm_api_set_last_error_ = InitContextPtr(
ftype_tvm_api_set_last_error_->getPointerTo(), "__TVMAPISetLastError");
gv_tvm_parallel_for_ = InitContextPtr(
ftype_tvm_parallel_for_->getPointerTo(), "__TVMBackendParallelFor");
// Mark as context functions
gv_func_map_["TVMBackendAllocWorkspace"] = nullptr;
gv_func_map_["TVMBackendFreeWorkspace"] = nullptr;
}
}
}
......@@ -528,9 +571,13 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
// Initialize the handle if needed.
builder_->SetInsertPoint(init_block);
llvm::Value* out = builder_->CreateAlloca(t_tvm_func_handle_);
llvm::Value* ctx = builder_->CreateLoad(gv_mod_ctx_);
llvm::LoadInst* ctx = builder_->CreateAlignedLoad(
gv_mod_ctx_, gv_mod_ctx_->getAlignment());
ctx->setMetadata(
"tbaa",
md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0));
llvm::Value* retcode = builder_->CreateCall(
f_tvm_get_func_from_env_, {ctx, GetConstString(fname), out});
RuntimeTVMGetFuncFromEnv(), {ctx, GetConstString(fname), out});
init_block = CheckCallSuccess(retcode);
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
builder_->CreateBr(end_block);
......@@ -565,7 +612,7 @@ llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
Int(32), stack_tcode, ConstInt32(end));
CheckCallSuccess(
builder_->CreateCall(
f_tvm_func_call_,
RuntimeTVMFuncCall(),
{handle, arg_value, arg_tcode, ConstInt32(nargs),
ret_value, ret_tcode}));
Type r_type = op->type;
......@@ -584,17 +631,28 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
arg_values[i] = MakeValue(op->args[i]);
}
if (op->type.is_scalar()) {
llvm::Function* f = module_->getFunction(op->name);
if (f == nullptr) {
std::vector<llvm::Type*> arg_types;
for (llvm::Value* v : arg_values) {
arg_types.push_back(v->getType());
std::vector<llvm::Type*> arg_types;
for (llvm::Value* v : arg_values) {
arg_types.push_back(v->getType());
}
llvm::FunctionType* ftype = llvm::FunctionType::get(
LLVMType(op->type), arg_types, false);
// Check if it is available in global function table as injected function.
auto it = gv_func_map_.find(op->name);
if (it != gv_func_map_.end()) {
if (it->second == nullptr) {
gv_func_map_[op->name] = InitContextPtr(ftype->getPointerTo(), "__" + op->name);
it = gv_func_map_.find(op->name);
}
f = llvm::Function::Create(
llvm::FunctionType::get(LLVMType(op->type), arg_types, false),
llvm::Function::ExternalLinkage, op->name, module_.get());
return builder_->CreateCall(GetContextPtr(it->second), arg_values);
} else {
llvm::Function* f = module_->getFunction(op->name);
if (f == nullptr) {
f = llvm::Function::Create(
ftype, llvm::Function::ExternalLinkage, op->name, module_.get());
}
return builder_->CreateCall(f, arg_values);
}
return builder_->CreateCall(f, arg_values);
} else {
llvm::Function* f = module_->getFunction(op->name);
if (f) {
......@@ -603,6 +661,7 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
LOG(FATAL) << "cannot find function " << op->name;
}
}
LOG(FATAL) << "canot reach here";
return nullptr;
}
......@@ -630,6 +689,24 @@ llvm::Value* CodeGenLLVM::CreateScalarizedCall(
return value;
}
llvm::Value* CodeGenLLVM::RuntimeTVMFuncCall() {
if (f_tvm_func_call_ != nullptr) return f_tvm_func_call_;
return GetContextPtr(gv_tvm_func_call_);
}
llvm::Value* CodeGenLLVM::RuntimeTVMGetFuncFromEnv() {
if (f_tvm_get_func_from_env_ != nullptr) return f_tvm_get_func_from_env_;
return GetContextPtr(gv_tvm_get_func_from_env_);
}
llvm::Value* CodeGenLLVM::RuntimeTVMAPISetLastError() {
if (f_tvm_api_set_last_error_ != nullptr) return f_tvm_api_set_last_error_;
return GetContextPtr(gv_tvm_api_set_last_error_);
}
llvm::Value* CodeGenLLVM::RuntimeTVMParallelFor() {
if (f_tvm_parallel_for_ != nullptr) return f_tvm_parallel_for_;
return GetContextPtr(gv_tvm_parallel_for_);
}
llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
auto it = var_map_.find(v);
CHECK(it != var_map_.end())
......@@ -723,7 +800,7 @@ void CodeGenLLVM::CreateParallelFor(const For* op) {
// closure data
llvm::StructType* tcdata = llvm::StructType::create(fields);
llvm::Function* f = llvm::Function::Create(
t_f_tvm_par_for_lambda_,
ftype_tvm_par_for_lambda_,
llvm::Function::PrivateLinkage,
"__tvm_par_for_lambda", module_.get());
// allocate and setup the closure, call the closure.
......@@ -737,7 +814,7 @@ void CodeGenLLVM::CreateParallelFor(const For* op) {
}
BasicBlock* par_for_end = CheckCallSuccess(
builder_->CreateCall(
f_tvm_parallel_for_,
RuntimeTVMParallelFor(),
{min, extent, f, builder_->CreatePointerCast(cdata, t_void_p_)}));
// Setup the closure function.
BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
......@@ -794,8 +871,9 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end,
builder_->SetInsertPoint(for_end);
}
llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
if (op->is_intrinsic("llvm_intrin")) {
CHECK_GE(op->args.size(), 1U);
std::vector<llvm::Value*> arg_values;
std::vector<llvm::Type*> arg_types;
for (size_t i = 1; i < op->args.size(); ++i) {
......@@ -808,6 +886,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
module_.get(), id, arg_types);
return builder_->CreateCall(f, arg_values);
} else if (op->is_intrinsic("llvm_builtin")) {
CHECK_GE(op->args.size(), 1U);
std::vector<llvm::Value*> arg_values;
for (size_t i = 1; i < op->args.size(); ++i) {
llvm::Value* v = MakeValue(op->args[i]);
......@@ -1391,7 +1470,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
return CreateCallPacked(op);
} else if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) {
return CreateIntrinstic(op);
return CreateIntrinsic(op);
} else {
CHECK(op->call_type == Call::Extern ||
op->call_type == Call::PureExtern);
......@@ -1508,7 +1587,7 @@ void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
builder_->CreateCondBr(cond, end_block, fail_block, md_very_likely_branch_);
// fail condition.
builder_->SetInsertPoint(fail_block);
builder_->CreateCall(f_tvm_api_set_last_error_, {msg});
builder_->CreateCall(RuntimeTVMAPISetLastError(), {msg});
builder_->CreateRet(ConstInt32(-1));
// otherwise set it to be new end.
builder_->SetInsertPoint(end_block);
......
......@@ -41,11 +41,14 @@ class CodeGenLLVM :
* \param tm Target machine model
* \param ctx The context.
* \param system_lib Whether to insert system library registration.
* \param dynamic_lookup Whether dynamically lookup runtime function
* or use the runtime function table passed by caller.
*/
void Init(const std::string& module_name,
llvm::TargetMachine* tm,
llvm::LLVMContext* ctx,
bool system_lib);
bool system_lib,
bool dynamic_lookup);
/*!
* \brief Compile and add function f to the current module.
* \param f The function to be added.
......@@ -114,7 +117,7 @@ class CodeGenLLVM :
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
// create intrinstic given call
virtual llvm::Value* CreateIntrinstic(const Call* op);
virtual llvm::Value* CreateIntrinsic(const Call* op);
// create extern function call
virtual llvm::Value* CreateCallExtern(const Call* op);
// create call into tvm packed function.
......@@ -178,6 +181,7 @@ class CodeGenLLVM :
llvm::MDNode* md_very_likely_branch_{nullptr};
llvm::MDNode* md_tbaa_root_{nullptr};
llvm::MDNode* md_tbaa_alias_set_{nullptr};
llvm::MDNode* md_tbaa_ctx_ptr_{nullptr};
// TVM related data types
llvm::Type* t_tvm_shape_index_{nullptr};
llvm::Type* t_tvm_func_handle_{nullptr};
......@@ -185,13 +189,12 @@ class CodeGenLLVM :
llvm::StructType* t_tvm_type_{nullptr};
llvm::StructType* t_tvm_array_{nullptr};
llvm::StructType* t_tvm_value_{nullptr};
llvm::FunctionType* t_f_tvm_par_for_lambda_{nullptr};
// tvm api functions
llvm::Function* f_tvm_func_call_{nullptr};
llvm::Function* f_tvm_get_func_from_env_{nullptr};
llvm::Function* f_tvm_api_set_last_error_{nullptr};
llvm::Function* f_tvm_parallel_for_{nullptr};
llvm::Function* f_tvm_register_system_symbol_{nullptr};
llvm::FunctionType* ftype_tvm_par_for_lambda_{nullptr};
llvm::FunctionType* ftype_tvm_func_call_{nullptr};
llvm::FunctionType* ftype_tvm_get_func_from_env_{nullptr};
llvm::FunctionType* ftype_tvm_api_set_last_error_{nullptr};
llvm::FunctionType* ftype_tvm_parallel_for_{nullptr};
llvm::FunctionType* ftype_tvm_register_system_symbol_{nullptr};
// The acting body
llvm::BasicBlock* block_{nullptr};
/*! \brief native vector bits of current targetx*/
......@@ -200,6 +203,13 @@ class CodeGenLLVM :
std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_;
private:
// Get runtime functions
llvm::GlobalVariable* InitContextPtr(llvm::Type* type, std::string name);
llvm::Value* GetContextPtr(llvm::GlobalVariable* gv);
llvm::Value* RuntimeTVMFuncCall();
llvm::Value* RuntimeTVMGetFuncFromEnv();
llvm::Value* RuntimeTVMAPISetLastError();
llvm::Value* RuntimeTVMParallelFor();
// comparison op
llvm::Value* GetVarValue(const Variable* v) const;
llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
......@@ -232,7 +242,7 @@ class CodeGenLLVM :
// return the end block after the check
llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode);
// Add a function to set global module context
void InitGlobalContext();
void InitGlobalContext(bool dynamic_lookup);
// Add module startup function if needed.
void AddStartupFunction();
// add alias information.
......@@ -247,8 +257,19 @@ class CodeGenLLVM :
bool is_restricted_{true};
// set of var that are not restricted(can alias)
std::unordered_set<const Variable*> alias_var_set_;
// The local module_context
// Context for injection lookup
llvm::GlobalVariable* gv_mod_ctx_{nullptr};
llvm::GlobalVariable* gv_tvm_func_call_{nullptr};
llvm::GlobalVariable* gv_tvm_get_func_from_env_{nullptr};
llvm::GlobalVariable* gv_tvm_api_set_last_error_{nullptr};
llvm::GlobalVariable* gv_tvm_parallel_for_{nullptr};
std::unordered_map<std::string, llvm::GlobalVariable*> gv_func_map_;
// context for direct dynamic lookup
llvm::Function* f_tvm_func_call_{nullptr};
llvm::Function* f_tvm_get_func_from_env_{nullptr};
llvm::Function* f_tvm_api_set_last_error_{nullptr};
llvm::Function* f_tvm_parallel_for_{nullptr};
llvm::Function* f_tvm_register_system_symbol_{nullptr};
// global to packed function handle
std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_;
// List of symbols to be exported to TVM system lib.
......
......@@ -113,8 +113,6 @@ GetLLVMTargetMachine(const std::string& target_str, bool allow_null) {
return tm;
}
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
......@@ -104,7 +104,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
ctx_ = std::make_shared<llvm::LLVMContext>();
std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_);
entry_func_ = funcs[0]->name;
cg->Init(funcs[0]->name, tm_, ctx_.get(), system_lib);
cg->Init(funcs[0]->name, tm_, ctx_.get(), system_lib, system_lib);
for (LoweredFunc f : funcs) {
cg->AddFunction(f);
}
......@@ -152,16 +152,17 @@ class LLVMModuleNode final : public runtime::ModuleNode {
<< "Failed to initialize git engine for " << mptr_->getTargetTriple();
ee_->runStaticConstructorsDestructors(false);
// setup context address.
void** ctx_addr =
reinterpret_cast<void**>(
ee_->getGlobalValueAddress(runtime::symbol::tvm_module_ctx));
// setup context address.
entry_func_ =
reinterpret_cast<const char*>(
ee_->getGlobalValueAddress(runtime::symbol::tvm_module_main));
if (ctx_addr != nullptr) {
if (void** ctx_addr = reinterpret_cast<void**>(
ee_->getGlobalValueAddress(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = this;
}
runtime::InitContextFunctions([this](const char *name) {
auto value = ee_->getGlobalValueAddress(name);
return value;
});
}
// The target configuration string
std::string target_;
......
......@@ -40,7 +40,7 @@ class DSOModuleNode final : public ModuleNode {
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name));
} else {
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name));
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name.c_str()));
}
if (faddr == nullptr) return PackedFunc();
return WrapPackedFunc(faddr, sptr_to_self);
......@@ -48,12 +48,13 @@ class DSOModuleNode final : public ModuleNode {
void Init(const std::string& name) {
Load(name);
void** ctx_addr =
reinterpret_cast<void**>(
GetSymbol(runtime::symbol::tvm_module_ctx));
if (ctx_addr != nullptr) {
if (auto *ctx_addr =
reinterpret_cast<void**>(GetSymbol(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = this;
}
InitContextFunctions([this](const char* fname) {
return GetSymbol(fname);
});
// Load the imported modules
const char* dev_mblob =
reinterpret_cast<const char*>(
......@@ -76,9 +77,9 @@ class DSOModuleNode final : public ModuleNode {
CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name;
}
void* GetSymbol(const std::string& name) {
void* GetSymbol(const char* name) {
return reinterpret_cast<void*>(
GetProcAddress(lib_handle_, (LPCSTR)name.c_str())); // NOLINT(*)
GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
}
void Unload() {
FreeLibrary(lib_handle_);
......@@ -92,8 +93,8 @@ class DSOModuleNode final : public ModuleNode {
CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name;
}
void* GetSymbol(const std::string& name) {
return dlsym(lib_handle_, name.c_str());
void* GetSymbol(const char* name) {
return dlsym(lib_handle_, name);
}
void Unload() {
dlclose(lib_handle_);
......
......@@ -7,6 +7,8 @@
#define TVM_RUNTIME_MODULE_UTIL_H_
#include <tvm/runtime/module.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h>
#include <vector>
extern "C" {
......@@ -30,6 +32,39 @@ PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const std::shared_ptr<Module
* \param module_list The module list to append to
*/
void ImportModuleBlob(const char* mblob, std::vector<Module>* module_list);
/*!
* \brief Utility to initialize conext function symbols during startup
* \param flookup A symbol lookup function.
* \tparam FLookup a function of signature string->void*
*/
template<typename FLookup>
void InitContextFunctions(FLookup flookup) {
if (auto *fp = reinterpret_cast<decltype(&TVMFuncCall)*>
(flookup("__TVMFuncCall"))) {
*fp = TVMFuncCall;
}
if (auto *fp = reinterpret_cast<decltype(&TVMAPISetLastError)*>
(flookup("__TVMAPISetLastError"))) {
*fp = TVMAPISetLastError;
}
if (auto *fp = reinterpret_cast<decltype(&TVMBackendGetFuncFromEnv)*>
(flookup("__TVMBackendGetFuncFromEnv"))) {
*fp = TVMBackendGetFuncFromEnv;
}
if (auto *fp = reinterpret_cast<decltype(&TVMBackendAllocWorkspace)*>
(flookup("__TVMBackendAllocWorkspace"))) {
*fp = TVMBackendAllocWorkspace;
}
if (auto *fp = reinterpret_cast<decltype(&TVMBackendFreeWorkspace)*>
(flookup("__TVMBackendFreeWorkspace"))) {
*fp = TVMBackendFreeWorkspace;
}
if (auto *fp = reinterpret_cast<decltype(&TVMBackendParallelFor)*>
(flookup("__TVMBackendParallelFor"))) {
*fp = TVMBackendParallelFor;
}
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_MODULE_UTIL_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment