Unverified Commit 203d2c92 by Tianqi Chen Committed by GitHub

[REFACTOR][RUNTIME] Add LibraryModule that merges systemlib and dso. (#4481)

Historically we have two variations of modules(DSOModule and SystemLibModule)
that both exposes module via symbols.

This PR creates a common implementation for both, and introduce a Library
base class that allows us to have different implementations of GetSymbol.

It paves ways for future library related module enhancements.
parent 18592c8d
...@@ -27,12 +27,12 @@ ...@@ -27,12 +27,12 @@
#include "../src/runtime/c_runtime_api.cc" #include "../src/runtime/c_runtime_api.cc"
#include "../src/runtime/cpu_device_api.cc" #include "../src/runtime/cpu_device_api.cc"
#include "../src/runtime/workspace_pool.cc" #include "../src/runtime/workspace_pool.cc"
#include "../src/runtime/module_util.cc" #include "../src/runtime/library_module.cc"
#include "../src/runtime/system_lib_module.cc" #include "../src/runtime/system_library.cc"
#include "../src/runtime/module.cc" #include "../src/runtime/module.cc"
#include "../src/runtime/registry.cc" #include "../src/runtime/registry.cc"
#include "../src/runtime/file_util.cc" #include "../src/runtime/file_util.cc"
#include "../src/runtime/dso_module.cc" #include "../src/runtime/dso_library.cc"
#include "../src/runtime/thread_pool.cc" #include "../src/runtime/thread_pool.cc"
#include "../src/runtime/object.cc" #include "../src/runtime/object.cc"
#include "../src/runtime/threading_backend.cc" #include "../src/runtime/threading_backend.cc"
......
...@@ -39,12 +39,12 @@ ...@@ -39,12 +39,12 @@
#include "../src/runtime/c_runtime_api.cc" #include "../src/runtime/c_runtime_api.cc"
#include "../src/runtime/cpu_device_api.cc" #include "../src/runtime/cpu_device_api.cc"
#include "../src/runtime/workspace_pool.cc" #include "../src/runtime/workspace_pool.cc"
#include "../src/runtime/module_util.cc" #include "../src/runtime/library_module.cc"
#include "../src/runtime/system_lib_module.cc" #include "../src/runtime/system_library.cc"
#include "../src/runtime/module.cc" #include "../src/runtime/module.cc"
#include "../src/runtime/registry.cc" #include "../src/runtime/registry.cc"
#include "../src/runtime/file_util.cc" #include "../src/runtime/file_util.cc"
#include "../src/runtime/dso_module.cc" #include "../src/runtime/dso_library.cc"
#include "../src/runtime/rpc/rpc_session.cc" #include "../src/runtime/rpc/rpc_session.cc"
#include "../src/runtime/rpc/rpc_event_impl.cc" #include "../src/runtime/rpc/rpc_event_impl.cc"
#include "../src/runtime/rpc/rpc_server_env.cc" #include "../src/runtime/rpc/rpc_server_env.cc"
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
# to you under the Apache License, Version 2.0 (the # to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance # "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at # with the License. You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -29,4 +29,4 @@ echo "Run the cpp deployment with all in normal library..." ...@@ -29,4 +29,4 @@ echo "Run the cpp deployment with all in normal library..."
lib/cpp_deploy_normal lib/cpp_deploy_normal
echo "Run the python deployment with all in normal library..." echo "Run the python deployment with all in normal library..."
python python_deploy.py python3 python_deploy.py
...@@ -40,7 +40,7 @@ ...@@ -40,7 +40,7 @@
#include "../../src/runtime/c_runtime_api.cc" #include "../../src/runtime/c_runtime_api.cc"
#include "../../src/runtime/cpu_device_api.cc" #include "../../src/runtime/cpu_device_api.cc"
#include "../../src/runtime/workspace_pool.cc" #include "../../src/runtime/workspace_pool.cc"
#include "../../src/runtime/module_util.cc" #include "../../src/runtime/library_module.cc"
#include "../../src/runtime/module.cc" #include "../../src/runtime/module.cc"
#include "../../src/runtime/registry.cc" #include "../../src/runtime/registry.cc"
#include "../../src/runtime/file_util.cc" #include "../../src/runtime/file_util.cc"
...@@ -55,8 +55,8 @@ ...@@ -55,8 +55,8 @@
// Likely we only need to enable one of the following // Likely we only need to enable one of the following
// If you use Module::Load, use dso_module // If you use Module::Load, use dso_module
// For system packed library, use system_lib_module // For system packed library, use system_lib_module
#include "../../src/runtime/dso_module.cc" #include "../../src/runtime/dso_library.cc"
#include "../../src/runtime/system_lib_module.cc" #include "../../src/runtime/system_library.cc"
// Graph runtime // Graph runtime
#include "../../src/runtime/graph/graph_runtime.cc" #include "../../src/runtime/graph/graph_runtime.cc"
......
...@@ -27,12 +27,12 @@ ...@@ -27,12 +27,12 @@
#include "../../../src/runtime/workspace_pool.cc" #include "../../../src/runtime/workspace_pool.cc"
#include "../../../src/runtime/thread_pool.cc" #include "../../../src/runtime/thread_pool.cc"
#include "../../../src/runtime/threading_backend.cc" #include "../../../src/runtime/threading_backend.cc"
#include "../../../src/runtime/module_util.cc" #include "../../../src/runtime/library_module.cc"
#include "../../../src/runtime/system_lib_module.cc" #include "../../../src/runtime/system_library.cc"
#include "../../../src/runtime/module.cc" #include "../../../src/runtime/module.cc"
#include "../../../src/runtime/registry.cc" #include "../../../src/runtime/registry.cc"
#include "../../../src/runtime/file_util.cc" #include "../../../src/runtime/file_util.cc"
#include "../../../src/runtime/dso_module.cc" #include "../../../src/runtime/dso_library.cc"
#include "../../../src/runtime/ndarray.cc" #include "../../../src/runtime/ndarray.cc"
#include "../../../src/runtime/object.cc" #include "../../../src/runtime/object.cc"
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include "src/runtime/c_runtime_api.cc" #include "src/runtime/c_runtime_api.cc"
#include "src/runtime/cpu_device_api.cc" #include "src/runtime/cpu_device_api.cc"
#include "src/runtime/workspace_pool.cc" #include "src/runtime/workspace_pool.cc"
#include "src/runtime/module_util.cc" #include "src/runtime/library_module.cc"
#include "src/runtime/module.cc" #include "src/runtime/module.cc"
#include "src/runtime/registry.cc" #include "src/runtime/registry.cc"
#include "src/runtime/file_util.cc" #include "src/runtime/file_util.cc"
...@@ -39,8 +39,8 @@ ...@@ -39,8 +39,8 @@
// Likely we only need to enable one of the following // Likely we only need to enable one of the following
// If you use Module::Load, use dso_module // If you use Module::Load, use dso_module
// For system packed library, use system_lib_module // For system packed library, use system_lib_module
#include "src/runtime/dso_module.cc" #include "src/runtime/dso_library.cc"
#include "src/runtime/system_lib_module.cc" #include "src/runtime/system_library.cc"
// Graph runtime // Graph runtime
#include "src/runtime/graph/graph_runtime.cc" #include "src/runtime/graph/graph_runtime.cc"
......
...@@ -201,6 +201,7 @@ class ModuleNode : public Object { ...@@ -201,6 +201,7 @@ class ModuleNode : public Object {
protected: protected:
friend class Module; friend class Module;
friend class ModuleInternal;
/*! \brief The modules this module depend on */ /*! \brief The modules this module depend on */
std::vector<Module> imports_; std::vector<Module> imports_;
......
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
#include "llvm_common.h" #include "llvm_common.h"
#include "codegen_llvm.h" #include "codegen_llvm.h"
#include "../../runtime/file_util.h" #include "../../runtime/file_util.h"
#include "../../runtime/module_util.h" #include "../../runtime/library_module.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
...@@ -286,7 +286,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -286,7 +286,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
*ctx_addr = this; *ctx_addr = this;
} }
runtime::InitContextFunctions([this](const char *name) { runtime::InitContextFunctions([this](const char *name) {
return GetGlobalAddr(name); return reinterpret_cast<void*>(GetGlobalAddr(name));
}); });
} }
// Get global address from execution engine. // Get global address from execution engine.
......
...@@ -18,14 +18,14 @@ ...@@ -18,14 +18,14 @@
*/ */
/*! /*!
* \file dso_module.cc * \file dso_libary.cc
* \brief Module to load from dynamic shared library. * \brief Create library module to load from dynamic shared library.
*/ */
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/memory.h> #include <tvm/runtime/memory.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include "module_util.h" #include "library_module.h"
#if defined(_WIN32) #if defined(_WIN32)
#include <windows.h> #include <windows.h>
...@@ -36,51 +36,19 @@ ...@@ -36,51 +36,19 @@
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
// Module to load from dynamic shared libary. // Dynamic shared libary.
// This is the default module TVM used for host-side AOT // This is the default module TVM used for host-side AOT
class DSOModuleNode final : public ModuleNode { class DSOLibrary final : public Library {
public: public:
~DSOModuleNode() { ~DSOLibrary() {
if (lib_handle_) Unload(); if (lib_handle_) Unload();
} }
const char* type_key() const final {
return "dso";
}
PackedFunc GetFunction(
const std::string& name,
const ObjectPtr<Object>& sptr_to_self) final {
BackendPackedCFunc faddr;
if (name == runtime::symbol::tvm_module_main) {
const char* entry_name = reinterpret_cast<const char*>(
GetSymbol(runtime::symbol::tvm_module_main));
CHECK(entry_name!= nullptr)
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name));
} else {
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name.c_str()));
}
if (faddr == nullptr) return PackedFunc();
return WrapPackedFunc(faddr, sptr_to_self);
}
void Init(const std::string& name) { void Init(const std::string& name) {
Load(name); Load(name);
if (auto *ctx_addr = }
reinterpret_cast<void**>(GetSymbol(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = this; void* GetSymbol(const char* name) final {
} return GetSymbol_(name);
InitContextFunctions([this](const char* fname) {
return GetSymbol(fname);
});
// Load the imported modules
const char* dev_mblob =
reinterpret_cast<const char*>(
GetSymbol(runtime::symbol::tvm_dev_mblob));
if (dev_mblob != nullptr) {
ImportModuleBlob(dev_mblob, &imports_);
}
} }
private: private:
...@@ -88,6 +56,12 @@ class DSOModuleNode final : public ModuleNode { ...@@ -88,6 +56,12 @@ class DSOModuleNode final : public ModuleNode {
#if defined(_WIN32) #if defined(_WIN32)
// library handle // library handle
HMODULE lib_handle_{nullptr}; HMODULE lib_handle_{nullptr};
void* GetSymbol_(const char* name) {
return reinterpret_cast<void*>(
GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
}
// Load the library // Load the library
void Load(const std::string& name) { void Load(const std::string& name) {
// use wstring version that is needed by LLVM. // use wstring version that is needed by LLVM.
...@@ -96,12 +70,10 @@ class DSOModuleNode final : public ModuleNode { ...@@ -96,12 +70,10 @@ class DSOModuleNode final : public ModuleNode {
CHECK(lib_handle_ != nullptr) CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name; << "Failed to load dynamic shared library " << name;
} }
void* GetSymbol(const char* name) {
return reinterpret_cast<void*>(
GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
}
void Unload() { void Unload() {
FreeLibrary(lib_handle_); FreeLibrary(lib_handle_);
lib_handle_ = nullptr;
} }
#else #else
// Library handle // Library handle
...@@ -113,20 +85,23 @@ class DSOModuleNode final : public ModuleNode { ...@@ -113,20 +85,23 @@ class DSOModuleNode final : public ModuleNode {
<< "Failed to load dynamic shared library " << name << "Failed to load dynamic shared library " << name
<< " " << dlerror(); << " " << dlerror();
} }
void* GetSymbol(const char* name) {
void* GetSymbol_(const char* name) {
return dlsym(lib_handle_, name); return dlsym(lib_handle_, name);
} }
void Unload() { void Unload() {
dlclose(lib_handle_); dlclose(lib_handle_);
lib_handle_ = nullptr;
} }
#endif #endif
}; };
TVM_REGISTER_GLOBAL("module.loadfile_so") TVM_REGISTER_GLOBAL("module.loadfile_so")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
auto n = make_object<DSOModuleNode>(); auto n = make_object<DSOLibrary>();
n->Init(args[0]); n->Init(args[0]);
*rv = runtime::Module(n); *rv = CreateModuleFromLibrary(n);
}); });
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -27,12 +27,89 @@ ...@@ -27,12 +27,89 @@
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <string> #include <string>
#include <memory> #include <vector>
#include "module_util.h" #include "library_module.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
// Library module that exposes symbols from a library.
class LibraryModuleNode final : public ModuleNode {
public:
explicit LibraryModuleNode(ObjectPtr<Library> lib)
: lib_(lib) {
}
const char* type_key() const final {
return "library";
}
PackedFunc GetFunction(
const std::string& name,
const ObjectPtr<Object>& sptr_to_self) final {
BackendPackedCFunc faddr;
if (name == runtime::symbol::tvm_module_main) {
const char* entry_name = reinterpret_cast<const char*>(
lib_->GetSymbol(runtime::symbol::tvm_module_main));
CHECK(entry_name!= nullptr)
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
faddr = reinterpret_cast<BackendPackedCFunc>(lib_->GetSymbol(entry_name));
} else {
faddr = reinterpret_cast<BackendPackedCFunc>(lib_->GetSymbol(name.c_str()));
}
if (faddr == nullptr) return PackedFunc();
return WrapPackedFunc(faddr, sptr_to_self);
}
private:
ObjectPtr<Library> lib_;
};
/*!
* \brief Helper classes to get into internal of a module.
*/
class ModuleInternal {
public:
// Get mutable reference of imports.
static std::vector<Module>* GetImportsAddr(ModuleNode* node) {
return &(node->imports_);
}
};
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr,
const ObjectPtr<Object>& sptr_to_self) {
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
int ret = (*faddr)(
const_cast<TVMValue*>(args.values),
const_cast<int*>(args.type_codes),
args.num_args);
CHECK_EQ(ret, 0) << TVMGetLastError();
});
}
void InitContextFunctions(std::function<void*(const char*)> fgetsymbol) {
#define TVM_INIT_CONTEXT_FUNC(FuncName) \
if (auto *fp = reinterpret_cast<decltype(&FuncName)*> \
(fgetsymbol("__" #FuncName))) { \
*fp = FuncName; \
}
// Initialize the functions
TVM_INIT_CONTEXT_FUNC(TVMFuncCall);
TVM_INIT_CONTEXT_FUNC(TVMAPISetLastError);
TVM_INIT_CONTEXT_FUNC(TVMBackendGetFuncFromEnv);
TVM_INIT_CONTEXT_FUNC(TVMBackendAllocWorkspace);
TVM_INIT_CONTEXT_FUNC(TVMBackendFreeWorkspace);
TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch);
TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier);
#undef TVM_INIT_CONTEXT_FUNC
}
/*!
* \brief Load and append module blob to module list
* \param mblob The module blob.
* \param module_list The module list to append to
*/
void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) { void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
#ifndef _LIBCPP_SGX_CONFIG #ifndef _LIBCPP_SGX_CONFIG
CHECK(mblob != nullptr); CHECK(mblob != nullptr);
...@@ -62,16 +139,27 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) { ...@@ -62,16 +139,27 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
#endif #endif
} }
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, Module CreateModuleFromLibrary(ObjectPtr<Library> lib) {
const ObjectPtr<Object>& sptr_to_self) { InitContextFunctions([lib](const char* fname) {
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { return lib->GetSymbol(fname);
int ret = (*faddr)(
const_cast<TVMValue*>(args.values),
const_cast<int*>(args.type_codes),
args.num_args);
CHECK_EQ(ret, 0) << TVMGetLastError();
}); });
} auto n = make_object<LibraryModuleNode>(lib);
// Load the imported modules
const char* dev_mblob =
reinterpret_cast<const char*>(
lib->GetSymbol(runtime::symbol::tvm_dev_mblob));
if (dev_mblob != nullptr) {
ImportModuleBlob(
dev_mblob, ModuleInternal::GetImportsAddr(n.operator->()));
}
Module root_mod = Module(n);
// allow lookup of symbol from root(so all symbols are visible).
if (auto *ctx_addr =
reinterpret_cast<void**>(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = root_mod.operator->();
}
return root_mod;
}
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -18,17 +18,16 @@ ...@@ -18,17 +18,16 @@
*/ */
/*! /*!
* \file module_util.h * \file library_module.h
* \brief Helper utilities for module building * \brief Module that builds from a libary of symbols.
*/ */
#ifndef TVM_RUNTIME_MODULE_UTIL_H_ #ifndef TVM_RUNTIME_LIBRARY_MODULE_H_
#define TVM_RUNTIME_MODULE_UTIL_H_ #define TVM_RUNTIME_LIBRARY_MODULE_H_
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h> #include <tvm/runtime/c_backend_api.h>
#include <memory> #include <functional>
#include <vector>
extern "C" { extern "C" {
// Function signature for generated packed function in shared library // Function signature for generated packed function in shared library
...@@ -40,41 +39,47 @@ typedef int (*BackendPackedCFunc)(void* args, ...@@ -40,41 +39,47 @@ typedef int (*BackendPackedCFunc)(void* args,
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
/*! /*!
* \brief Library is the common interface
* for storing data in the form of shared libaries.
*
* \sa dso_library.cc
* \sa system_library.cc
*/
class Library : public Object {
public:
/*!
* \brief Get the symbol address for a given name.
* \param name The name of the symbol.
* \return The symbol.
*/
virtual void *GetSymbol(const char* name) = 0;
// NOTE: we do not explicitly create an type index and type_key here for libary.
// This is because we do not need dynamic type downcasting.
};
/*!
* \brief Wrap a BackendPackedCFunc to packed function. * \brief Wrap a BackendPackedCFunc to packed function.
* \param faddr The function address * \param faddr The function address
* \param mptr The module pointer node. * \param mptr The module pointer node.
*/ */
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const ObjectPtr<Object>& mptr); PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const ObjectPtr<Object>& mptr);
/*!
* \brief Load and append module blob to module list
* \param mblob The module blob.
* \param module_list The module list to append to
*/
void ImportModuleBlob(const char* mblob, std::vector<Module>* module_list);
/*! /*!
* \brief Utility to initialize conext function symbols during startup * \brief Utility to initialize conext function symbols during startup
* \param flookup A symbol lookup function. * \param fgetsymbol A symbol lookup function.
* \tparam FLookup a function of signature string->void*
*/ */
template<typename FLookup> void InitContextFunctions(std::function<void*(const char*)> fgetsymbol);
void InitContextFunctions(FLookup flookup) {
#define TVM_INIT_CONTEXT_FUNC(FuncName) \
if (auto *fp = reinterpret_cast<decltype(&FuncName)*> \
(flookup("__" #FuncName))) { \
*fp = FuncName; \
}
// Initialize the functions
TVM_INIT_CONTEXT_FUNC(TVMFuncCall);
TVM_INIT_CONTEXT_FUNC(TVMAPISetLastError);
TVM_INIT_CONTEXT_FUNC(TVMBackendGetFuncFromEnv);
TVM_INIT_CONTEXT_FUNC(TVMBackendAllocWorkspace);
TVM_INIT_CONTEXT_FUNC(TVMBackendFreeWorkspace);
TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch);
TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier);
#undef TVM_INIT_CONTEXT_FUNC /*!
} * \brief Create a module from a library.
*
* \param lib The library.
* \return The corresponding loaded module.
*
* \note This function can create multiple linked modules
* by parsing the binary blob section of the library.
*/
Module CreateModuleFromLibrary(ObjectPtr<Library> lib);
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_MODULE_UTIL_H_ #endif // TVM_RUNTIME_LIBRARY_MODULE_H_
...@@ -28,7 +28,6 @@ ...@@ -28,7 +28,6 @@
#include <unordered_map> #include <unordered_map>
#include "stackvm_module.h" #include "stackvm_module.h"
#include "../file_util.h" #include "../file_util.h"
#include "../module_util.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
......
...@@ -18,73 +18,46 @@ ...@@ -18,73 +18,46 @@
*/ */
/*! /*!
* \file system_lib_module.cc * \file system_library.cc
* \brief SystemLib module. * \brief Create library module that directly get symbol from the system lib.
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/memory.h> #include <tvm/runtime/memory.h>
#include <tvm/runtime/c_backend_api.h> #include <tvm/runtime/c_backend_api.h>
#include <mutex> #include <mutex>
#include "module_util.h" #include "library_module.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
class SystemLibModuleNode : public ModuleNode { class SystemLibrary : public Library {
public: public:
SystemLibModuleNode() = default; SystemLibrary() = default;
const char* type_key() const final { void* GetSymbol(const char* name) final {
return "system_lib";
}
PackedFunc GetFunction(
const std::string& name,
const ObjectPtr<Object>& sptr_to_self) final {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (module_blob_ != nullptr) {
// If we previously recorded submodules, load them now.
ImportModuleBlob(reinterpret_cast<const char*>(module_blob_), &imports_);
module_blob_ = nullptr;
}
auto it = tbl_.find(name); auto it = tbl_.find(name);
if (it != tbl_.end()) { if (it != tbl_.end()) {
return WrapPackedFunc( return it->second;
reinterpret_cast<BackendPackedCFunc>(it->second), sptr_to_self);
} else { } else {
return PackedFunc(); return nullptr;
} }
} }
void RegisterSymbol(const std::string& name, void* ptr) { void RegisterSymbol(const std::string& name, void* ptr) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (name == symbol::tvm_module_ctx) { auto it = tbl_.find(name);
void** ctx_addr = reinterpret_cast<void**>(ptr); if (it != tbl_.end() && ptr != it->second) {
*ctx_addr = this; LOG(WARNING)
} else if (name == symbol::tvm_dev_mblob) { << "SystemLib symbol " << name
// Record pointer to content of submodules to be loaded. << " get overriden to a different address "
// We defer loading submodules to the first call to GetFunction(). << ptr << "->" << it->second;
// The reason is that RegisterSymbol() gets called when initializing the
// syslib (i.e. library loading time), and the registeries aren't ready
// yet. Therefore, we might not have the functionality to load submodules
// now.
CHECK(module_blob_ == nullptr) << "Resetting mobule blob?";
module_blob_ = ptr;
} else {
auto it = tbl_.find(name);
if (it != tbl_.end() && ptr != it->second) {
LOG(WARNING) << "SystemLib symbol " << name
<< " get overriden to a different address "
<< ptr << "->" << it->second;
}
tbl_[name] = ptr;
} }
tbl_[name] = ptr;
} }
static const ObjectPtr<SystemLibModuleNode>& Global() { static const ObjectPtr<SystemLibrary>& Global() {
static auto inst = make_object<SystemLibModuleNode>(); static auto inst = make_object<SystemLibrary>();
return inst; return inst;
} }
...@@ -93,18 +66,18 @@ class SystemLibModuleNode : public ModuleNode { ...@@ -93,18 +66,18 @@ class SystemLibModuleNode : public ModuleNode {
std::mutex mutex_; std::mutex mutex_;
// Internal symbol table // Internal symbol table
std::unordered_map<std::string, void*> tbl_; std::unordered_map<std::string, void*> tbl_;
// Module blob to be imported
void* module_blob_{nullptr};
}; };
TVM_REGISTER_GLOBAL("module._GetSystemLib") TVM_REGISTER_GLOBAL("module._GetSystemLib")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = runtime::Module(SystemLibModuleNode::Global()); static auto mod = CreateModuleFromLibrary(
SystemLibrary::Global());
*rv = mod;
}); });
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr) { int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr) {
tvm::runtime::SystemLibModuleNode::Global()->RegisterSymbol(name, ptr); tvm::runtime::SystemLibrary::Global()->RegisterSymbol(name, ptr);
return 0; return 0;
} }
...@@ -26,14 +26,14 @@ ...@@ -26,14 +26,14 @@
#include "../src/runtime/c_runtime_api.cc" #include "../src/runtime/c_runtime_api.cc"
#include "../src/runtime/cpu_device_api.cc" #include "../src/runtime/cpu_device_api.cc"
#include "../src/runtime/workspace_pool.cc" #include "../src/runtime/workspace_pool.cc"
#include "../src/runtime/module_util.cc" #include "../src/runtime/library_module.cc"
#include "../src/runtime/system_lib_module.cc" #include "../src/runtime/system_library.cc"
#include "../src/runtime/module.cc" #include "../src/runtime/module.cc"
#include "../src/runtime/ndarray.cc" #include "../src/runtime/ndarray.cc"
#include "../src/runtime/object.cc" #include "../src/runtime/object.cc"
#include "../src/runtime/registry.cc" #include "../src/runtime/registry.cc"
#include "../src/runtime/file_util.cc" #include "../src/runtime/file_util.cc"
#include "../src/runtime/dso_module.cc" #include "../src/runtime/dso_library.cc"
#include "../src/runtime/rpc/rpc_session.cc" #include "../src/runtime/rpc/rpc_session.cc"
#include "../src/runtime/rpc/rpc_event_impl.cc" #include "../src/runtime/rpc/rpc_event_impl.cc"
#include "../src/runtime/rpc/rpc_server_env.cc" #include "../src/runtime/rpc/rpc_server_env.cc"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment