Unverified Commit 203d2c92 by Tianqi Chen Committed by GitHub

[REFACTOR][RUNTIME] Add LibraryModule that merges systemlib and dso. (#4481)

Historically we have two variations of modules(DSOModule and SystemLibModule)
that both exposes module via symbols.

This PR creates a common implementation for both, and introduce a Library
base class that allows us to have different implementations of GetSymbol.

It paves ways for future library related module enhancements.
parent 18592c8d
......@@ -27,12 +27,12 @@
#include "../src/runtime/c_runtime_api.cc"
#include "../src/runtime/cpu_device_api.cc"
#include "../src/runtime/workspace_pool.cc"
#include "../src/runtime/module_util.cc"
#include "../src/runtime/system_lib_module.cc"
#include "../src/runtime/library_module.cc"
#include "../src/runtime/system_library.cc"
#include "../src/runtime/module.cc"
#include "../src/runtime/registry.cc"
#include "../src/runtime/file_util.cc"
#include "../src/runtime/dso_module.cc"
#include "../src/runtime/dso_library.cc"
#include "../src/runtime/thread_pool.cc"
#include "../src/runtime/object.cc"
#include "../src/runtime/threading_backend.cc"
......
......@@ -39,12 +39,12 @@
#include "../src/runtime/c_runtime_api.cc"
#include "../src/runtime/cpu_device_api.cc"
#include "../src/runtime/workspace_pool.cc"
#include "../src/runtime/module_util.cc"
#include "../src/runtime/system_lib_module.cc"
#include "../src/runtime/library_module.cc"
#include "../src/runtime/system_library.cc"
#include "../src/runtime/module.cc"
#include "../src/runtime/registry.cc"
#include "../src/runtime/file_util.cc"
#include "../src/runtime/dso_module.cc"
#include "../src/runtime/dso_library.cc"
#include "../src/runtime/rpc/rpc_session.cc"
#include "../src/runtime/rpc/rpc_event_impl.cc"
#include "../src/runtime/rpc/rpc_server_env.cc"
......
......@@ -29,4 +29,4 @@ echo "Run the cpp deployment with all in normal library..."
lib/cpp_deploy_normal
echo "Run the python deployment with all in normal library..."
python python_deploy.py
python3 python_deploy.py
......@@ -40,7 +40,7 @@
#include "../../src/runtime/c_runtime_api.cc"
#include "../../src/runtime/cpu_device_api.cc"
#include "../../src/runtime/workspace_pool.cc"
#include "../../src/runtime/module_util.cc"
#include "../../src/runtime/library_module.cc"
#include "../../src/runtime/module.cc"
#include "../../src/runtime/registry.cc"
#include "../../src/runtime/file_util.cc"
......@@ -55,8 +55,8 @@
// Likely we only need to enable one of the following
// If you use Module::Load, use dso_module
// For system packed library, use system_lib_module
#include "../../src/runtime/dso_module.cc"
#include "../../src/runtime/system_lib_module.cc"
#include "../../src/runtime/dso_library.cc"
#include "../../src/runtime/system_library.cc"
// Graph runtime
#include "../../src/runtime/graph/graph_runtime.cc"
......
......@@ -27,12 +27,12 @@
#include "../../../src/runtime/workspace_pool.cc"
#include "../../../src/runtime/thread_pool.cc"
#include "../../../src/runtime/threading_backend.cc"
#include "../../../src/runtime/module_util.cc"
#include "../../../src/runtime/system_lib_module.cc"
#include "../../../src/runtime/library_module.cc"
#include "../../../src/runtime/system_library.cc"
#include "../../../src/runtime/module.cc"
#include "../../../src/runtime/registry.cc"
#include "../../../src/runtime/file_util.cc"
#include "../../../src/runtime/dso_module.cc"
#include "../../../src/runtime/dso_library.cc"
#include "../../../src/runtime/ndarray.cc"
#include "../../../src/runtime/object.cc"
......
......@@ -24,7 +24,7 @@
#include "src/runtime/c_runtime_api.cc"
#include "src/runtime/cpu_device_api.cc"
#include "src/runtime/workspace_pool.cc"
#include "src/runtime/module_util.cc"
#include "src/runtime/library_module.cc"
#include "src/runtime/module.cc"
#include "src/runtime/registry.cc"
#include "src/runtime/file_util.cc"
......@@ -39,8 +39,8 @@
// Likely we only need to enable one of the following
// If you use Module::Load, use dso_module
// For system packed library, use system_lib_module
#include "src/runtime/dso_module.cc"
#include "src/runtime/system_lib_module.cc"
#include "src/runtime/dso_library.cc"
#include "src/runtime/system_library.cc"
// Graph runtime
#include "src/runtime/graph/graph_runtime.cc"
......
......@@ -201,6 +201,7 @@ class ModuleNode : public Object {
protected:
friend class Module;
friend class ModuleInternal;
/*! \brief The modules this module depend on */
std::vector<Module> imports_;
......
......@@ -28,7 +28,7 @@
#include "llvm_common.h"
#include "codegen_llvm.h"
#include "../../runtime/file_util.h"
#include "../../runtime/module_util.h"
#include "../../runtime/library_module.h"
namespace tvm {
namespace codegen {
......@@ -286,7 +286,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
*ctx_addr = this;
}
runtime::InitContextFunctions([this](const char *name) {
return GetGlobalAddr(name);
return reinterpret_cast<void*>(GetGlobalAddr(name));
});
}
// Get global address from execution engine.
......
......@@ -18,14 +18,14 @@
*/
/*!
* \file dso_module.cc
* \brief Module to load from dynamic shared library.
* \file dso_libary.cc
* \brief Create library module to load from dynamic shared library.
*/
#include <tvm/runtime/module.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#include "module_util.h"
#include "library_module.h"
#if defined(_WIN32)
#include <windows.h>
......@@ -36,51 +36,19 @@
namespace tvm {
namespace runtime {
// Module to load from dynamic shared libary.
// Dynamic shared libary.
// This is the default module TVM used for host-side AOT
class DSOModuleNode final : public ModuleNode {
class DSOLibrary final : public Library {
public:
~DSOModuleNode() {
~DSOLibrary() {
if (lib_handle_) Unload();
}
const char* type_key() const final {
return "dso";
}
PackedFunc GetFunction(
const std::string& name,
const ObjectPtr<Object>& sptr_to_self) final {
BackendPackedCFunc faddr;
if (name == runtime::symbol::tvm_module_main) {
const char* entry_name = reinterpret_cast<const char*>(
GetSymbol(runtime::symbol::tvm_module_main));
CHECK(entry_name!= nullptr)
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name));
} else {
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name.c_str()));
}
if (faddr == nullptr) return PackedFunc();
return WrapPackedFunc(faddr, sptr_to_self);
}
void Init(const std::string& name) {
Load(name);
if (auto *ctx_addr =
reinterpret_cast<void**>(GetSymbol(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = this;
}
InitContextFunctions([this](const char* fname) {
return GetSymbol(fname);
});
// Load the imported modules
const char* dev_mblob =
reinterpret_cast<const char*>(
GetSymbol(runtime::symbol::tvm_dev_mblob));
if (dev_mblob != nullptr) {
ImportModuleBlob(dev_mblob, &imports_);
}
void* GetSymbol(const char* name) final {
return GetSymbol_(name);
}
private:
......@@ -88,6 +56,12 @@ class DSOModuleNode final : public ModuleNode {
#if defined(_WIN32)
// library handle
HMODULE lib_handle_{nullptr};
void* GetSymbol_(const char* name) {
return reinterpret_cast<void*>(
GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
}
// Load the library
void Load(const std::string& name) {
// use wstring version that is needed by LLVM.
......@@ -96,12 +70,10 @@ class DSOModuleNode final : public ModuleNode {
CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name;
}
void* GetSymbol(const char* name) {
return reinterpret_cast<void*>(
GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
}
void Unload() {
FreeLibrary(lib_handle_);
lib_handle_ = nullptr;
}
#else
// Library handle
......@@ -113,20 +85,23 @@ class DSOModuleNode final : public ModuleNode {
<< "Failed to load dynamic shared library " << name
<< " " << dlerror();
}
void* GetSymbol(const char* name) {
void* GetSymbol_(const char* name) {
return dlsym(lib_handle_, name);
}
void Unload() {
dlclose(lib_handle_);
lib_handle_ = nullptr;
}
#endif
};
TVM_REGISTER_GLOBAL("module.loadfile_so")
.set_body([](TVMArgs args, TVMRetValue* rv) {
auto n = make_object<DSOModuleNode>();
auto n = make_object<DSOLibrary>();
n->Init(args[0]);
*rv = runtime::Module(n);
*rv = CreateModuleFromLibrary(n);
});
} // namespace runtime
} // namespace tvm
......@@ -27,12 +27,89 @@
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <string>
#include <memory>
#include "module_util.h"
#include <vector>
#include "library_module.h"
namespace tvm {
namespace runtime {
// Library module that exposes symbols from a library.
class LibraryModuleNode final : public ModuleNode {
public:
explicit LibraryModuleNode(ObjectPtr<Library> lib)
: lib_(lib) {
}
const char* type_key() const final {
return "library";
}
PackedFunc GetFunction(
const std::string& name,
const ObjectPtr<Object>& sptr_to_self) final {
BackendPackedCFunc faddr;
if (name == runtime::symbol::tvm_module_main) {
const char* entry_name = reinterpret_cast<const char*>(
lib_->GetSymbol(runtime::symbol::tvm_module_main));
CHECK(entry_name!= nullptr)
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
faddr = reinterpret_cast<BackendPackedCFunc>(lib_->GetSymbol(entry_name));
} else {
faddr = reinterpret_cast<BackendPackedCFunc>(lib_->GetSymbol(name.c_str()));
}
if (faddr == nullptr) return PackedFunc();
return WrapPackedFunc(faddr, sptr_to_self);
}
private:
ObjectPtr<Library> lib_;
};
/*!
* \brief Helper classes to get into internal of a module.
*/
class ModuleInternal {
public:
// Get mutable reference of imports.
static std::vector<Module>* GetImportsAddr(ModuleNode* node) {
return &(node->imports_);
}
};
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr,
const ObjectPtr<Object>& sptr_to_self) {
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
int ret = (*faddr)(
const_cast<TVMValue*>(args.values),
const_cast<int*>(args.type_codes),
args.num_args);
CHECK_EQ(ret, 0) << TVMGetLastError();
});
}
void InitContextFunctions(std::function<void*(const char*)> fgetsymbol) {
#define TVM_INIT_CONTEXT_FUNC(FuncName) \
if (auto *fp = reinterpret_cast<decltype(&FuncName)*> \
(fgetsymbol("__" #FuncName))) { \
*fp = FuncName; \
}
// Initialize the functions
TVM_INIT_CONTEXT_FUNC(TVMFuncCall);
TVM_INIT_CONTEXT_FUNC(TVMAPISetLastError);
TVM_INIT_CONTEXT_FUNC(TVMBackendGetFuncFromEnv);
TVM_INIT_CONTEXT_FUNC(TVMBackendAllocWorkspace);
TVM_INIT_CONTEXT_FUNC(TVMBackendFreeWorkspace);
TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch);
TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier);
#undef TVM_INIT_CONTEXT_FUNC
}
/*!
* \brief Load and append module blob to module list
* \param mblob The module blob.
* \param module_list The module list to append to
*/
void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
#ifndef _LIBCPP_SGX_CONFIG
CHECK(mblob != nullptr);
......@@ -62,16 +139,27 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
#endif
}
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr,
const ObjectPtr<Object>& sptr_to_self) {
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
int ret = (*faddr)(
const_cast<TVMValue*>(args.values),
const_cast<int*>(args.type_codes),
args.num_args);
CHECK_EQ(ret, 0) << TVMGetLastError();
Module CreateModuleFromLibrary(ObjectPtr<Library> lib) {
InitContextFunctions([lib](const char* fname) {
return lib->GetSymbol(fname);
});
}
auto n = make_object<LibraryModuleNode>(lib);
// Load the imported modules
const char* dev_mblob =
reinterpret_cast<const char*>(
lib->GetSymbol(runtime::symbol::tvm_dev_mblob));
if (dev_mblob != nullptr) {
ImportModuleBlob(
dev_mblob, ModuleInternal::GetImportsAddr(n.operator->()));
}
Module root_mod = Module(n);
// allow lookup of symbol from root(so all symbols are visible).
if (auto *ctx_addr =
reinterpret_cast<void**>(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = root_mod.operator->();
}
return root_mod;
}
} // namespace runtime
} // namespace tvm
......@@ -18,17 +18,16 @@
*/
/*!
* \file module_util.h
* \brief Helper utilities for module building
* \file library_module.h
* \brief Module that builds from a libary of symbols.
*/
#ifndef TVM_RUNTIME_MODULE_UTIL_H_
#define TVM_RUNTIME_MODULE_UTIL_H_
#ifndef TVM_RUNTIME_LIBRARY_MODULE_H_
#define TVM_RUNTIME_LIBRARY_MODULE_H_
#include <tvm/runtime/module.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h>
#include <memory>
#include <vector>
#include <functional>
extern "C" {
// Function signature for generated packed function in shared library
......@@ -40,41 +39,47 @@ typedef int (*BackendPackedCFunc)(void* args,
namespace tvm {
namespace runtime {
/*!
* \brief Library is the common interface
* for storing data in the form of shared libaries.
*
* \sa dso_library.cc
* \sa system_library.cc
*/
class Library : public Object {
public:
/*!
* \brief Get the symbol address for a given name.
* \param name The name of the symbol.
* \return The symbol.
*/
virtual void *GetSymbol(const char* name) = 0;
// NOTE: we do not explicitly create an type index and type_key here for libary.
// This is because we do not need dynamic type downcasting.
};
/*!
* \brief Wrap a BackendPackedCFunc to packed function.
* \param faddr The function address
* \param mptr The module pointer node.
*/
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const ObjectPtr<Object>& mptr);
/*!
* \brief Load and append module blob to module list
* \param mblob The module blob.
* \param module_list The module list to append to
*/
void ImportModuleBlob(const char* mblob, std::vector<Module>* module_list);
/*!
* \brief Utility to initialize conext function symbols during startup
* \param flookup A symbol lookup function.
* \tparam FLookup a function of signature string->void*
* \param fgetsymbol A symbol lookup function.
*/
template<typename FLookup>
void InitContextFunctions(FLookup flookup) {
#define TVM_INIT_CONTEXT_FUNC(FuncName) \
if (auto *fp = reinterpret_cast<decltype(&FuncName)*> \
(flookup("__" #FuncName))) { \
*fp = FuncName; \
}
// Initialize the functions
TVM_INIT_CONTEXT_FUNC(TVMFuncCall);
TVM_INIT_CONTEXT_FUNC(TVMAPISetLastError);
TVM_INIT_CONTEXT_FUNC(TVMBackendGetFuncFromEnv);
TVM_INIT_CONTEXT_FUNC(TVMBackendAllocWorkspace);
TVM_INIT_CONTEXT_FUNC(TVMBackendFreeWorkspace);
TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch);
TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier);
void InitContextFunctions(std::function<void*(const char*)> fgetsymbol);
#undef TVM_INIT_CONTEXT_FUNC
}
/*!
* \brief Create a module from a library.
*
* \param lib The library.
* \return The corresponding loaded module.
*
* \note This function can create multiple linked modules
* by parsing the binary blob section of the library.
*/
Module CreateModuleFromLibrary(ObjectPtr<Library> lib);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_MODULE_UTIL_H_
#endif // TVM_RUNTIME_LIBRARY_MODULE_H_
......@@ -28,7 +28,6 @@
#include <unordered_map>
#include "stackvm_module.h"
#include "../file_util.h"
#include "../module_util.h"
namespace tvm {
namespace runtime {
......
......@@ -18,73 +18,46 @@
*/
/*!
* \file system_lib_module.cc
* \brief SystemLib module.
* \file system_library.cc
* \brief Create library module that directly get symbol from the system lib.
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/c_backend_api.h>
#include <mutex>
#include "module_util.h"
#include "library_module.h"
namespace tvm {
namespace runtime {
class SystemLibModuleNode : public ModuleNode {
class SystemLibrary : public Library {
public:
SystemLibModuleNode() = default;
SystemLibrary() = default;
const char* type_key() const final {
return "system_lib";
}
PackedFunc GetFunction(
const std::string& name,
const ObjectPtr<Object>& sptr_to_self) final {
void* GetSymbol(const char* name) final {
std::lock_guard<std::mutex> lock(mutex_);
if (module_blob_ != nullptr) {
// If we previously recorded submodules, load them now.
ImportModuleBlob(reinterpret_cast<const char*>(module_blob_), &imports_);
module_blob_ = nullptr;
}
auto it = tbl_.find(name);
if (it != tbl_.end()) {
return WrapPackedFunc(
reinterpret_cast<BackendPackedCFunc>(it->second), sptr_to_self);
return it->second;
} else {
return PackedFunc();
return nullptr;
}
}
void RegisterSymbol(const std::string& name, void* ptr) {
std::lock_guard<std::mutex> lock(mutex_);
if (name == symbol::tvm_module_ctx) {
void** ctx_addr = reinterpret_cast<void**>(ptr);
*ctx_addr = this;
} else if (name == symbol::tvm_dev_mblob) {
// Record pointer to content of submodules to be loaded.
// We defer loading submodules to the first call to GetFunction().
// The reason is that RegisterSymbol() gets called when initializing the
// syslib (i.e. library loading time), and the registeries aren't ready
// yet. Therefore, we might not have the functionality to load submodules
// now.
CHECK(module_blob_ == nullptr) << "Resetting mobule blob?";
module_blob_ = ptr;
} else {
auto it = tbl_.find(name);
if (it != tbl_.end() && ptr != it->second) {
LOG(WARNING) << "SystemLib symbol " << name
LOG(WARNING)
<< "SystemLib symbol " << name
<< " get overriden to a different address "
<< ptr << "->" << it->second;
}
tbl_[name] = ptr;
}
}
static const ObjectPtr<SystemLibModuleNode>& Global() {
static auto inst = make_object<SystemLibModuleNode>();
static const ObjectPtr<SystemLibrary>& Global() {
static auto inst = make_object<SystemLibrary>();
return inst;
}
......@@ -93,18 +66,18 @@ class SystemLibModuleNode : public ModuleNode {
std::mutex mutex_;
// Internal symbol table
std::unordered_map<std::string, void*> tbl_;
// Module blob to be imported
void* module_blob_{nullptr};
};
TVM_REGISTER_GLOBAL("module._GetSystemLib")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = runtime::Module(SystemLibModuleNode::Global());
static auto mod = CreateModuleFromLibrary(
SystemLibrary::Global());
*rv = mod;
});
} // namespace runtime
} // namespace tvm
int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr) {
tvm::runtime::SystemLibModuleNode::Global()->RegisterSymbol(name, ptr);
tvm::runtime::SystemLibrary::Global()->RegisterSymbol(name, ptr);
return 0;
}
......@@ -26,14 +26,14 @@
#include "../src/runtime/c_runtime_api.cc"
#include "../src/runtime/cpu_device_api.cc"
#include "../src/runtime/workspace_pool.cc"
#include "../src/runtime/module_util.cc"
#include "../src/runtime/system_lib_module.cc"
#include "../src/runtime/library_module.cc"
#include "../src/runtime/system_library.cc"
#include "../src/runtime/module.cc"
#include "../src/runtime/ndarray.cc"
#include "../src/runtime/object.cc"
#include "../src/runtime/registry.cc"
#include "../src/runtime/file_util.cc"
#include "../src/runtime/dso_module.cc"
#include "../src/runtime/dso_library.cc"
#include "../src/runtime/rpc/rpc_session.cc"
#include "../src/runtime/rpc/rpc_event_impl.cc"
#include "../src/runtime/rpc/rpc_server_env.cc"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment