Commit 9037a4c2 by Tianqi Chen Committed by GitHub

[RUNTIME] Enable injection of some core runtime functions to avoid dynamic lookup (#260)

parent 6196cd50
...@@ -154,12 +154,12 @@ verilog: $(VER_LIBS) ...@@ -154,12 +154,12 @@ verilog: $(VER_LIBS)
# Special rules for LLVM related modules. # Special rules for LLVM related modules.
build/codegen/llvm/%.o: src/codegen/llvm/%.cc build/codegen/llvm/%.o: src/codegen/llvm/%.cc
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d $(CXX) $(CFLAGS) -MM -MT build/codegen/llvm/$*.o $< >build/codegen/llvm/$*.d
$(CXX) -c $(CFLAGS) $(LLVM_CFLAGS) -c $< -o $@ $(CXX) -c $(CFLAGS) $(LLVM_CFLAGS) -c $< -o $@
build/runtime/metal/%.o: src/runtime/metal/%.mm build/runtime/metal/%.o: src/runtime/metal/%.mm
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d $(CXX) $(CFLAGS) -MM -MT build/runtime/metal/$*.o $< >build/runtime/metal/$*.d
$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@ $(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
build/%.o: src/%.cc build/%.o: src/%.cc
...@@ -199,7 +199,7 @@ pylint: ...@@ -199,7 +199,7 @@ pylint:
pylint python/tvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc pylint python/tvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc
pylint topi/python/topi --rcfile=$(ROOTDIR)/tests/lint/pylintrc pylint topi/python/topi --rcfile=$(ROOTDIR)/tests/lint/pylintrc
jnilint: jnilint:
python dmlc-core/scripts/lint.py tvm4j-jni cpp jvm/native/src python dmlc-core/scripts/lint.py tvm4j-jni cpp jvm/native/src
lint: cpplint pylint jnilint lint: cpplint pylint jnilint
......
...@@ -41,11 +41,14 @@ class CodeGenLLVM : ...@@ -41,11 +41,14 @@ class CodeGenLLVM :
* \param tm Target machine model * \param tm Target machine model
* \param ctx The context. * \param ctx The context.
* \param system_lib Whether to insert system library registration. * \param system_lib Whether to insert system library registration.
* \param dynamic_lookup Whether dynamically lookup runtime function
* or use the runtime function table passed by caller.
*/ */
void Init(const std::string& module_name, void Init(const std::string& module_name,
llvm::TargetMachine* tm, llvm::TargetMachine* tm,
llvm::LLVMContext* ctx, llvm::LLVMContext* ctx,
bool system_lib); bool system_lib,
bool dynamic_lookup);
/*! /*!
* \brief Compile and add function f to the current module. * \brief Compile and add function f to the current module.
* \param f The function to be added. * \param f The function to be added.
...@@ -114,7 +117,7 @@ class CodeGenLLVM : ...@@ -114,7 +117,7 @@ class CodeGenLLVM :
void VisitStmt_(const Evaluate* op) override; void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const ProducerConsumer* op) override; void VisitStmt_(const ProducerConsumer* op) override;
// create intrinstic given call // create intrinstic given call
virtual llvm::Value* CreateIntrinstic(const Call* op); virtual llvm::Value* CreateIntrinsic(const Call* op);
// create extern function call // create extern function call
virtual llvm::Value* CreateCallExtern(const Call* op); virtual llvm::Value* CreateCallExtern(const Call* op);
// create call into tvm packed function. // create call into tvm packed function.
...@@ -178,6 +181,7 @@ class CodeGenLLVM : ...@@ -178,6 +181,7 @@ class CodeGenLLVM :
llvm::MDNode* md_very_likely_branch_{nullptr}; llvm::MDNode* md_very_likely_branch_{nullptr};
llvm::MDNode* md_tbaa_root_{nullptr}; llvm::MDNode* md_tbaa_root_{nullptr};
llvm::MDNode* md_tbaa_alias_set_{nullptr}; llvm::MDNode* md_tbaa_alias_set_{nullptr};
llvm::MDNode* md_tbaa_ctx_ptr_{nullptr};
// TVM related data types // TVM related data types
llvm::Type* t_tvm_shape_index_{nullptr}; llvm::Type* t_tvm_shape_index_{nullptr};
llvm::Type* t_tvm_func_handle_{nullptr}; llvm::Type* t_tvm_func_handle_{nullptr};
...@@ -185,13 +189,12 @@ class CodeGenLLVM : ...@@ -185,13 +189,12 @@ class CodeGenLLVM :
llvm::StructType* t_tvm_type_{nullptr}; llvm::StructType* t_tvm_type_{nullptr};
llvm::StructType* t_tvm_array_{nullptr}; llvm::StructType* t_tvm_array_{nullptr};
llvm::StructType* t_tvm_value_{nullptr}; llvm::StructType* t_tvm_value_{nullptr};
llvm::FunctionType* t_f_tvm_par_for_lambda_{nullptr}; llvm::FunctionType* ftype_tvm_par_for_lambda_{nullptr};
// tvm api functions llvm::FunctionType* ftype_tvm_func_call_{nullptr};
llvm::Function* f_tvm_func_call_{nullptr}; llvm::FunctionType* ftype_tvm_get_func_from_env_{nullptr};
llvm::Function* f_tvm_get_func_from_env_{nullptr}; llvm::FunctionType* ftype_tvm_api_set_last_error_{nullptr};
llvm::Function* f_tvm_api_set_last_error_{nullptr}; llvm::FunctionType* ftype_tvm_parallel_for_{nullptr};
llvm::Function* f_tvm_parallel_for_{nullptr}; llvm::FunctionType* ftype_tvm_register_system_symbol_{nullptr};
llvm::Function* f_tvm_register_system_symbol_{nullptr};
// The acting body // The acting body
llvm::BasicBlock* block_{nullptr}; llvm::BasicBlock* block_{nullptr};
/*! \brief native vector bits of current targetx*/ /*! \brief native vector bits of current targetx*/
...@@ -200,6 +203,13 @@ class CodeGenLLVM : ...@@ -200,6 +203,13 @@ class CodeGenLLVM :
std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_; std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_;
private: private:
// Get runtime functions
llvm::GlobalVariable* InitContextPtr(llvm::Type* type, std::string name);
llvm::Value* GetContextPtr(llvm::GlobalVariable* gv);
llvm::Value* RuntimeTVMFuncCall();
llvm::Value* RuntimeTVMGetFuncFromEnv();
llvm::Value* RuntimeTVMAPISetLastError();
llvm::Value* RuntimeTVMParallelFor();
// comparison op // comparison op
llvm::Value* GetVarValue(const Variable* v) const; llvm::Value* GetVarValue(const Variable* v) const;
llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
...@@ -232,7 +242,7 @@ class CodeGenLLVM : ...@@ -232,7 +242,7 @@ class CodeGenLLVM :
// return the end block after the check // return the end block after the check
llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode); llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode);
// Add a function to set global module context // Add a function to set global module context
void InitGlobalContext(); void InitGlobalContext(bool dynamic_lookup);
// Add module startup function if needed. // Add module startup function if needed.
void AddStartupFunction(); void AddStartupFunction();
// add alias information. // add alias information.
...@@ -247,8 +257,19 @@ class CodeGenLLVM : ...@@ -247,8 +257,19 @@ class CodeGenLLVM :
bool is_restricted_{true}; bool is_restricted_{true};
// set of var that are not restricted(can alias) // set of var that are not restricted(can alias)
std::unordered_set<const Variable*> alias_var_set_; std::unordered_set<const Variable*> alias_var_set_;
// The local module_context // Context for injection lookup
llvm::GlobalVariable* gv_mod_ctx_{nullptr}; llvm::GlobalVariable* gv_mod_ctx_{nullptr};
llvm::GlobalVariable* gv_tvm_func_call_{nullptr};
llvm::GlobalVariable* gv_tvm_get_func_from_env_{nullptr};
llvm::GlobalVariable* gv_tvm_api_set_last_error_{nullptr};
llvm::GlobalVariable* gv_tvm_parallel_for_{nullptr};
std::unordered_map<std::string, llvm::GlobalVariable*> gv_func_map_;
// context for direct dynamic lookup
llvm::Function* f_tvm_func_call_{nullptr};
llvm::Function* f_tvm_get_func_from_env_{nullptr};
llvm::Function* f_tvm_api_set_last_error_{nullptr};
llvm::Function* f_tvm_parallel_for_{nullptr};
llvm::Function* f_tvm_register_system_symbol_{nullptr};
// global to packed function handle // global to packed function handle
std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_; std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_;
// List of symbols to be exported to TVM system lib. // List of symbols to be exported to TVM system lib.
......
...@@ -113,8 +113,6 @@ GetLLVMTargetMachine(const std::string& target_str, bool allow_null) { ...@@ -113,8 +113,6 @@ GetLLVMTargetMachine(const std::string& target_str, bool allow_null) {
return tm; return tm;
} }
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
#endif // TVM_LLVM_VERSION #endif // TVM_LLVM_VERSION
...@@ -104,7 +104,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -104,7 +104,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
ctx_ = std::make_shared<llvm::LLVMContext>(); ctx_ = std::make_shared<llvm::LLVMContext>();
std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_); std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_);
entry_func_ = funcs[0]->name; entry_func_ = funcs[0]->name;
cg->Init(funcs[0]->name, tm_, ctx_.get(), system_lib); cg->Init(funcs[0]->name, tm_, ctx_.get(), system_lib, system_lib);
for (LoweredFunc f : funcs) { for (LoweredFunc f : funcs) {
cg->AddFunction(f); cg->AddFunction(f);
} }
...@@ -152,16 +152,17 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -152,16 +152,17 @@ class LLVMModuleNode final : public runtime::ModuleNode {
<< "Failed to initialize git engine for " << mptr_->getTargetTriple(); << "Failed to initialize git engine for " << mptr_->getTargetTriple();
ee_->runStaticConstructorsDestructors(false); ee_->runStaticConstructorsDestructors(false);
// setup context address. // setup context address.
void** ctx_addr =
reinterpret_cast<void**>(
ee_->getGlobalValueAddress(runtime::symbol::tvm_module_ctx));
// setup context address.
entry_func_ = entry_func_ =
reinterpret_cast<const char*>( reinterpret_cast<const char*>(
ee_->getGlobalValueAddress(runtime::symbol::tvm_module_main)); ee_->getGlobalValueAddress(runtime::symbol::tvm_module_main));
if (ctx_addr != nullptr) { if (void** ctx_addr = reinterpret_cast<void**>(
ee_->getGlobalValueAddress(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = this; *ctx_addr = this;
} }
runtime::InitContextFunctions([this](const char *name) {
auto value = ee_->getGlobalValueAddress(name);
return value;
});
} }
// The target configuration string // The target configuration string
std::string target_; std::string target_;
......
...@@ -40,7 +40,7 @@ class DSOModuleNode final : public ModuleNode { ...@@ -40,7 +40,7 @@ class DSOModuleNode final : public ModuleNode {
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented"; << "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name)); faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name));
} else { } else {
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name)); faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name.c_str()));
} }
if (faddr == nullptr) return PackedFunc(); if (faddr == nullptr) return PackedFunc();
return WrapPackedFunc(faddr, sptr_to_self); return WrapPackedFunc(faddr, sptr_to_self);
...@@ -48,12 +48,13 @@ class DSOModuleNode final : public ModuleNode { ...@@ -48,12 +48,13 @@ class DSOModuleNode final : public ModuleNode {
void Init(const std::string& name) { void Init(const std::string& name) {
Load(name); Load(name);
void** ctx_addr = if (auto *ctx_addr =
reinterpret_cast<void**>( reinterpret_cast<void**>(GetSymbol(runtime::symbol::tvm_module_ctx))) {
GetSymbol(runtime::symbol::tvm_module_ctx));
if (ctx_addr != nullptr) {
*ctx_addr = this; *ctx_addr = this;
} }
InitContextFunctions([this](const char* fname) {
return GetSymbol(fname);
});
// Load the imported modules // Load the imported modules
const char* dev_mblob = const char* dev_mblob =
reinterpret_cast<const char*>( reinterpret_cast<const char*>(
...@@ -76,9 +77,9 @@ class DSOModuleNode final : public ModuleNode { ...@@ -76,9 +77,9 @@ class DSOModuleNode final : public ModuleNode {
CHECK(lib_handle_ != nullptr) CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name; << "Failed to load dynamic shared library " << name;
} }
void* GetSymbol(const std::string& name) { void* GetSymbol(const char* name) {
return reinterpret_cast<void*>( return reinterpret_cast<void*>(
GetProcAddress(lib_handle_, (LPCSTR)name.c_str())); // NOLINT(*) GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
} }
void Unload() { void Unload() {
FreeLibrary(lib_handle_); FreeLibrary(lib_handle_);
...@@ -92,8 +93,8 @@ class DSOModuleNode final : public ModuleNode { ...@@ -92,8 +93,8 @@ class DSOModuleNode final : public ModuleNode {
CHECK(lib_handle_ != nullptr) CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name; << "Failed to load dynamic shared library " << name;
} }
void* GetSymbol(const std::string& name) { void* GetSymbol(const char* name) {
return dlsym(lib_handle_, name.c_str()); return dlsym(lib_handle_, name);
} }
void Unload() { void Unload() {
dlclose(lib_handle_); dlclose(lib_handle_);
......
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
#define TVM_RUNTIME_MODULE_UTIL_H_ #define TVM_RUNTIME_MODULE_UTIL_H_
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h>
#include <vector> #include <vector>
extern "C" { extern "C" {
...@@ -30,6 +32,39 @@ PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const std::shared_ptr<Module ...@@ -30,6 +32,39 @@ PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const std::shared_ptr<Module
* \param module_list The module list to append to * \param module_list The module list to append to
*/ */
void ImportModuleBlob(const char* mblob, std::vector<Module>* module_list); void ImportModuleBlob(const char* mblob, std::vector<Module>* module_list);
/*!
* \brief Utility to initialize conext function symbols during startup
* \param flookup A symbol lookup function.
* \tparam FLookup a function of signature string->void*
*/
template<typename FLookup>
void InitContextFunctions(FLookup flookup) {
if (auto *fp = reinterpret_cast<decltype(&TVMFuncCall)*>
(flookup("__TVMFuncCall"))) {
*fp = TVMFuncCall;
}
if (auto *fp = reinterpret_cast<decltype(&TVMAPISetLastError)*>
(flookup("__TVMAPISetLastError"))) {
*fp = TVMAPISetLastError;
}
if (auto *fp = reinterpret_cast<decltype(&TVMBackendGetFuncFromEnv)*>
(flookup("__TVMBackendGetFuncFromEnv"))) {
*fp = TVMBackendGetFuncFromEnv;
}
if (auto *fp = reinterpret_cast<decltype(&TVMBackendAllocWorkspace)*>
(flookup("__TVMBackendAllocWorkspace"))) {
*fp = TVMBackendAllocWorkspace;
}
if (auto *fp = reinterpret_cast<decltype(&TVMBackendFreeWorkspace)*>
(flookup("__TVMBackendFreeWorkspace"))) {
*fp = TVMBackendFreeWorkspace;
}
if (auto *fp = reinterpret_cast<decltype(&TVMBackendParallelFor)*>
(flookup("__TVMBackendParallelFor"))) {
*fp = TVMBackendParallelFor;
}
}
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_MODULE_UTIL_H_ #endif // TVM_RUNTIME_MODULE_UTIL_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment