Commit ef50162b by Tianqi Chen Committed by GitHub

[MODULE/RUNTIME] Remove Precompile, simplify module (#174)

parent 84aeaf48
Subproject commit a6e09b58dc00ee0065f5b7879800e646fbb01d1e
Subproject commit 36e573893fc39d324ccf2f2962300da6da5898a2
......@@ -183,21 +183,6 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod,
TVMFunctionHandle *out);
/*!
* \brief Precompile the function under given context.
* Many TVMFunctionHandle is initialized lazily,
* This call eagerly prepares the resources under given context.
* Useful for benchmarking purposes.
*
* \param mod The module handle.
* \param func_name The name of the function.
* \param ctx The context to be precompiled on.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMModPreCompile(TVMModuleHandle mod,
const char* func_name,
TVMContext ctx);
/*!
* \brief Free the Module
* \param mod The module to be freed.
*
......
......@@ -77,17 +77,6 @@ class ModuleNode {
/*! \return The module type key */
virtual const char* type_key() const = 0;
/*!
* \brief Eagerly compile the function under certain context,
* assuming that it is used by the current thread.
*
* This is useful for benchmarking to eliminate lazy compilation
* overhead during the first execution of the kernel.
*
* \param name The name of the function.
* \param ctx The context to be executed.
*/
virtual void PreCompile(const std::string& name, TVMContext ctx) = 0;
/*!
* \brief Get a PackedFunc from module.
*
* The PackedFunc may not be fully initialized,
......@@ -113,7 +102,7 @@ class ModuleNode {
* \param format The format of the file.
*/
virtual void SaveToFile(const std::string& file_name,
const std::string& format) = 0;
const std::string& format);
/*!
* \brief Save the module to binary stream.
* \param stream The binary stream to save to.
......@@ -121,14 +110,13 @@ class ModuleNode {
* but not necessarily host modules.
* We can use this to do AOT loading of bundled device functions.
*/
virtual void SaveToBinary(dmlc::Stream* stream) = 0;
virtual void SaveToBinary(dmlc::Stream* stream);
/*!
* \brief Get the source code of module, when available.
* \param format Format of the source code, can be empty by default.
* \return Possible source code when available.
*/
virtual std::string GetSource(
const std::string& format = "") = 0;
virtual std::string GetSource(const std::string& format = "");
/*!
* \brief Get a function from current environment
* The environment includes all the imports as well as Global functions.
......
......@@ -116,20 +116,6 @@ class ModuleBase(object):
"""
check_call(_LIB.TVMModImport(self.handle, module.handle))
def precompile(self, func_name, ctx):
"""Add module to the import list of current one.
Parameters
----------
func_name : str
The name of function to be precompiled.
ctx : Context
The context to be precompiled.
"""
check_call(_LIB.TVMModPreCompile(
self.handle, c_str(func_name), ctx))
def __getitem__(self, name):
if not isinstance(name, string_types):
raise ValueError("Can only take string as function name")
......
......@@ -33,17 +33,6 @@ class LLVMModuleNode final : public runtime::ModuleNode {
return "llvm";
}
void PreCompile(const std::string& name, TVMContext ctx) final {
if (ee_ == nullptr) LazyInitJIT();
std::lock_guard<std::mutex> lock(mutex_);
const std::string& fname = (name == runtime::symbol::tvm_module_main ?
entry_func_ : name);
BackendPackedCFunc faddr =
reinterpret_cast<BackendPackedCFunc>(ee_->getFunctionAddress(fname));
CHECK(faddr != nullptr)
<< "Failed to Precompile function " << name;
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
......
......@@ -21,8 +21,6 @@ class SourceModuleNode final : public runtime::ModuleNode {
const char* type_key() const {
return "source";
}
void PreCompile(const std::string& name, TVMContext ctx) final {
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
......@@ -31,15 +29,6 @@ class SourceModuleNode final : public runtime::ModuleNode {
return PackedFunc();
}
void SaveToFile(const std::string& file_name,
const std::string& format) final {
LOG(FATAL) << "SourceModule: SaveToFile not supported";
}
void SaveToBinary(dmlc::Stream* stream) final {
LOG(FATAL) << "SourceModule: SaveToBinary not supported";
}
std::string GetSource(const std::string& format) final {
return code_;
}
......
......@@ -16,8 +16,6 @@ class StackVMModuleNode : public runtime::ModuleNode {
return "stackvm";
}
void PreCompile(const std::string& name, TVMContext ctx) final {}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
......@@ -33,15 +31,6 @@ class StackVMModuleNode : public runtime::ModuleNode {
});
}
void SaveToFile(const std::string& file_name,
const std::string& format) final {
LOG(FATAL) << "StackVMModule: SaveToFile not supported";
}
void SaveToBinary(dmlc::Stream* stream) final {
LOG(FATAL) << "StackVMModule: SaveToBinary not supported";
}
std::string GetSource(const std::string& format) final {
std::ostringstream os;
for (const auto& kv : fmap_) {
......
......@@ -24,8 +24,6 @@ class VerilogModuleNode : public runtime::ModuleNode {
const char* type_key() const {
return "verilog";
}
void PreCompile(const std::string& name, TVMContext ctx) final {
}
PackedFunc GetFunction(
const std::string& name,
......@@ -57,15 +55,6 @@ class VerilogModuleNode : public runtime::ModuleNode {
return PackedFunc(f);
}
void SaveToFile(const std::string& file_name,
const std::string& format) final {
LOG(FATAL) << "VerilogModule: SaveToFile not supported";
}
void SaveToBinary(dmlc::Stream* stream) final {
LOG(FATAL) << "VerilogModule: SaveToBinary not supported";
}
std::string GetSource(const std::string& format) final {
return m_.code;
}
......
......@@ -196,14 +196,6 @@ int TVMModGetFunction(TVMModuleHandle mod,
API_END();
}
int TVMModPreCompile(TVMModuleHandle mod,
const char* func_name,
TVMContext ctx) {
API_BEGIN();
(*static_cast<Module*>(mod))->PreCompile(func_name, ctx);
API_END();
}
int TVMModFree(TVMModuleHandle mod) {
API_BEGIN();
delete static_cast<Module*>(mod);
......
......@@ -49,12 +49,6 @@ class CUDAModuleNode : public runtime::ModuleNode {
return "cuda";
}
void PreCompile(const std::string& name, TVMContext ctx) final {
CUDA_CALL(cudaSetDevice(ctx.device_id));
cudaFree(nullptr);
this->GetFunc(ctx.device_id, name);
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
......
......@@ -30,10 +30,6 @@ class DSOModuleNode final : public ModuleNode {
return "dso";
}
void PreCompile(const std::string& name, TVMContext ctx) final {
GetFuncPtr(name);
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
......@@ -48,19 +44,6 @@ class DSOModuleNode final : public ModuleNode {
});
}
void SaveToFile(const std::string& file_name,
const std::string& format) final {
LOG(FATAL) << "DSOModule: SaveToFile not supported";
}
void SaveToBinary(dmlc::Stream* stream) final {
LOG(FATAL) << "DSOModule: SaveToBinary not supported";
}
std::string GetSource(const std::string& format) final {
return "";
}
void Init(const std::string& name) {
Load(name);
CHECK(lib_handle_ != nullptr)
......
......@@ -37,10 +37,6 @@ class MetalModuleNode final :public runtime::ModuleNode {
return "metal";
}
void PreCompile(const std::string& name, TVMContext ctx) final {
GetPipelineState(ctx.device_id, name);
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
......
......@@ -62,6 +62,20 @@ Module Module::LoadFromFile(const std::string& file_name,
return m;
}
void ModuleNode::SaveToFile(const std::string& file_name,
const std::string& format) {
LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile";
}
void ModuleNode::SaveToBinary(dmlc::Stream* stream) {
LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToBinary";
}
std::string ModuleNode::GetSource(const std::string& format) {
LOG(FATAL) << "Module[" << type_key() << "] does not support GetSource";
return "";
}
const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
auto it = import_cache_.find(name);
if (it != import_cache_.end()) return it->second.get();
......
......@@ -59,12 +59,6 @@ class OpenCLModuleNode : public ModuleNode {
return "opencl";
}
void PreCompile(const std::string& name, TVMContext ctx) final {
InstallKernel(cl::OpenCLWorkspace::Global(),
cl::OpenCLThreadEntry::ThreadLocal(),
name, kid_map_.at(name));
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
......
......@@ -45,9 +45,6 @@ class RPCModuleNode final : public ModuleNode {
return "rpc";
}
void PreCompile(const std::string& name, TVMContext ctx) final {
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
......@@ -55,15 +52,6 @@ class RPCModuleNode final : public ModuleNode {
return WrapRemote(handle);
}
void SaveToFile(const std::string& file_name,
const std::string& format) final {
LOG(FATAL) << "RPCModule: SaveToFile not supported";
}
void SaveToBinary(dmlc::Stream* stream) final {
LOG(FATAL) << "RPCModule: SaveToBinary not supported";
}
std::string GetSource(const std::string& format) final {
if (module_handle_ != nullptr) {
std::string ret = sess_->CallRemote(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment