Commit 9037a4c2 by Tianqi Chen Committed by GitHub

[RUNTIME] Enable injection of some core runtime functions to avoid dynamic lookup (#260)

parent 6196cd50
......@@ -154,12 +154,12 @@ verilog: $(VER_LIBS)
# Special rules for LLVM related modules.
build/codegen/llvm/%.o: src/codegen/llvm/%.cc
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) $(CFLAGS) -MM -MT build/codegen/llvm/$*.o $< >build/codegen/llvm/$*.d
$(CXX) -c $(CFLAGS) $(LLVM_CFLAGS) -c $< -o $@
build/runtime/metal/%.o: src/runtime/metal/%.mm
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) $(CFLAGS) -MM -MT build/runtime/metal/$*.o $< >build/runtime/metal/$*.d
$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
build/%.o: src/%.cc
......@@ -199,7 +199,7 @@ pylint:
pylint python/tvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc
pylint topi/python/topi --rcfile=$(ROOTDIR)/tests/lint/pylintrc
jnilint:
jnilint:
python dmlc-core/scripts/lint.py tvm4j-jni cpp jvm/native/src
lint: cpplint pylint jnilint
......
......@@ -41,11 +41,14 @@ class CodeGenLLVM :
* \param tm Target machine model
* \param ctx The context.
* \param system_lib Whether to insert system library registration.
* \param dynamic_lookup Whether dynamically lookup runtime function
* or use the runtime function table passed by caller.
*/
void Init(const std::string& module_name,
llvm::TargetMachine* tm,
llvm::LLVMContext* ctx,
bool system_lib);
bool system_lib,
bool dynamic_lookup);
/*!
* \brief Compile and add function f to the current module.
* \param f The function to be added.
......@@ -114,7 +117,7 @@ class CodeGenLLVM :
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
// create intrinstic given call
virtual llvm::Value* CreateIntrinstic(const Call* op);
virtual llvm::Value* CreateIntrinsic(const Call* op);
// create extern function call
virtual llvm::Value* CreateCallExtern(const Call* op);
// create call into tvm packed function.
......@@ -178,6 +181,7 @@ class CodeGenLLVM :
llvm::MDNode* md_very_likely_branch_{nullptr};
llvm::MDNode* md_tbaa_root_{nullptr};
llvm::MDNode* md_tbaa_alias_set_{nullptr};
llvm::MDNode* md_tbaa_ctx_ptr_{nullptr};
// TVM related data types
llvm::Type* t_tvm_shape_index_{nullptr};
llvm::Type* t_tvm_func_handle_{nullptr};
......@@ -185,13 +189,12 @@ class CodeGenLLVM :
llvm::StructType* t_tvm_type_{nullptr};
llvm::StructType* t_tvm_array_{nullptr};
llvm::StructType* t_tvm_value_{nullptr};
llvm::FunctionType* t_f_tvm_par_for_lambda_{nullptr};
// tvm api functions
llvm::Function* f_tvm_func_call_{nullptr};
llvm::Function* f_tvm_get_func_from_env_{nullptr};
llvm::Function* f_tvm_api_set_last_error_{nullptr};
llvm::Function* f_tvm_parallel_for_{nullptr};
llvm::Function* f_tvm_register_system_symbol_{nullptr};
llvm::FunctionType* ftype_tvm_par_for_lambda_{nullptr};
llvm::FunctionType* ftype_tvm_func_call_{nullptr};
llvm::FunctionType* ftype_tvm_get_func_from_env_{nullptr};
llvm::FunctionType* ftype_tvm_api_set_last_error_{nullptr};
llvm::FunctionType* ftype_tvm_parallel_for_{nullptr};
llvm::FunctionType* ftype_tvm_register_system_symbol_{nullptr};
// The acting body
llvm::BasicBlock* block_{nullptr};
/*! \brief native vector bits of current targetx*/
......@@ -200,6 +203,13 @@ class CodeGenLLVM :
std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_;
private:
// Get runtime functions
llvm::GlobalVariable* InitContextPtr(llvm::Type* type, std::string name);
llvm::Value* GetContextPtr(llvm::GlobalVariable* gv);
llvm::Value* RuntimeTVMFuncCall();
llvm::Value* RuntimeTVMGetFuncFromEnv();
llvm::Value* RuntimeTVMAPISetLastError();
llvm::Value* RuntimeTVMParallelFor();
// comparison op
llvm::Value* GetVarValue(const Variable* v) const;
llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
......@@ -232,7 +242,7 @@ class CodeGenLLVM :
// return the end block after the check
llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode);
// Add a function to set global module context
void InitGlobalContext();
void InitGlobalContext(bool dynamic_lookup);
// Add module startup function if needed.
void AddStartupFunction();
// add alias information.
......@@ -247,8 +257,19 @@ class CodeGenLLVM :
bool is_restricted_{true};
// set of var that are not restricted(can alias)
std::unordered_set<const Variable*> alias_var_set_;
// The local module_context
// Context for injection lookup
llvm::GlobalVariable* gv_mod_ctx_{nullptr};
llvm::GlobalVariable* gv_tvm_func_call_{nullptr};
llvm::GlobalVariable* gv_tvm_get_func_from_env_{nullptr};
llvm::GlobalVariable* gv_tvm_api_set_last_error_{nullptr};
llvm::GlobalVariable* gv_tvm_parallel_for_{nullptr};
std::unordered_map<std::string, llvm::GlobalVariable*> gv_func_map_;
// context for direct dynamic lookup
llvm::Function* f_tvm_func_call_{nullptr};
llvm::Function* f_tvm_get_func_from_env_{nullptr};
llvm::Function* f_tvm_api_set_last_error_{nullptr};
llvm::Function* f_tvm_parallel_for_{nullptr};
llvm::Function* f_tvm_register_system_symbol_{nullptr};
// global to packed function handle
std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_;
// List of symbols to be exported to TVM system lib.
......
......@@ -113,8 +113,6 @@ GetLLVMTargetMachine(const std::string& target_str, bool allow_null) {
return tm;
}
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
......@@ -104,7 +104,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
ctx_ = std::make_shared<llvm::LLVMContext>();
std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_);
entry_func_ = funcs[0]->name;
cg->Init(funcs[0]->name, tm_, ctx_.get(), system_lib);
cg->Init(funcs[0]->name, tm_, ctx_.get(), system_lib, system_lib);
for (LoweredFunc f : funcs) {
cg->AddFunction(f);
}
......@@ -152,16 +152,17 @@ class LLVMModuleNode final : public runtime::ModuleNode {
<< "Failed to initialize git engine for " << mptr_->getTargetTriple();
ee_->runStaticConstructorsDestructors(false);
// setup context address.
void** ctx_addr =
reinterpret_cast<void**>(
ee_->getGlobalValueAddress(runtime::symbol::tvm_module_ctx));
// setup context address.
entry_func_ =
reinterpret_cast<const char*>(
ee_->getGlobalValueAddress(runtime::symbol::tvm_module_main));
if (ctx_addr != nullptr) {
if (void** ctx_addr = reinterpret_cast<void**>(
ee_->getGlobalValueAddress(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = this;
}
runtime::InitContextFunctions([this](const char *name) {
auto value = ee_->getGlobalValueAddress(name);
return value;
});
}
// The target configuration string
std::string target_;
......
......@@ -40,7 +40,7 @@ class DSOModuleNode final : public ModuleNode {
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name));
} else {
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name));
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name.c_str()));
}
if (faddr == nullptr) return PackedFunc();
return WrapPackedFunc(faddr, sptr_to_self);
......@@ -48,12 +48,13 @@ class DSOModuleNode final : public ModuleNode {
void Init(const std::string& name) {
Load(name);
void** ctx_addr =
reinterpret_cast<void**>(
GetSymbol(runtime::symbol::tvm_module_ctx));
if (ctx_addr != nullptr) {
if (auto *ctx_addr =
reinterpret_cast<void**>(GetSymbol(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = this;
}
InitContextFunctions([this](const char* fname) {
return GetSymbol(fname);
});
// Load the imported modules
const char* dev_mblob =
reinterpret_cast<const char*>(
......@@ -76,9 +77,9 @@ class DSOModuleNode final : public ModuleNode {
CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name;
}
void* GetSymbol(const std::string& name) {
void* GetSymbol(const char* name) {
return reinterpret_cast<void*>(
GetProcAddress(lib_handle_, (LPCSTR)name.c_str())); // NOLINT(*)
GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
}
void Unload() {
FreeLibrary(lib_handle_);
......@@ -92,8 +93,8 @@ class DSOModuleNode final : public ModuleNode {
CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name;
}
void* GetSymbol(const std::string& name) {
return dlsym(lib_handle_, name.c_str());
void* GetSymbol(const char* name) {
return dlsym(lib_handle_, name);
}
void Unload() {
dlclose(lib_handle_);
......
......@@ -7,6 +7,8 @@
#define TVM_RUNTIME_MODULE_UTIL_H_
#include <tvm/runtime/module.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h>
#include <vector>
extern "C" {
......@@ -30,6 +32,39 @@ PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const std::shared_ptr<Module
* \param module_list The module list to append to
*/
void ImportModuleBlob(const char* mblob, std::vector<Module>* module_list);
/*!
* \brief Utility to initialize conext function symbols during startup
* \param flookup A symbol lookup function.
* \tparam FLookup a function of signature string->void*
*/
template<typename FLookup>
void InitContextFunctions(FLookup flookup) {
if (auto *fp = reinterpret_cast<decltype(&TVMFuncCall)*>
(flookup("__TVMFuncCall"))) {
*fp = TVMFuncCall;
}
if (auto *fp = reinterpret_cast<decltype(&TVMAPISetLastError)*>
(flookup("__TVMAPISetLastError"))) {
*fp = TVMAPISetLastError;
}
if (auto *fp = reinterpret_cast<decltype(&TVMBackendGetFuncFromEnv)*>
(flookup("__TVMBackendGetFuncFromEnv"))) {
*fp = TVMBackendGetFuncFromEnv;
}
if (auto *fp = reinterpret_cast<decltype(&TVMBackendAllocWorkspace)*>
(flookup("__TVMBackendAllocWorkspace"))) {
*fp = TVMBackendAllocWorkspace;
}
if (auto *fp = reinterpret_cast<decltype(&TVMBackendFreeWorkspace)*>
(flookup("__TVMBackendFreeWorkspace"))) {
*fp = TVMBackendFreeWorkspace;
}
if (auto *fp = reinterpret_cast<decltype(&TVMBackendParallelFor)*>
(flookup("__TVMBackendParallelFor"))) {
*fp = TVMBackendParallelFor;
}
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_MODULE_UTIL_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment