Commit ef50162b by Tianqi Chen Committed by GitHub

[MODULE/RUNTIME] Remove Precompile, simplify module (#174)

parent 84aeaf48
Subproject commit a6e09b58dc00ee0065f5b7879800e646fbb01d1e Subproject commit 36e573893fc39d324ccf2f2962300da6da5898a2
...@@ -183,21 +183,6 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, ...@@ -183,21 +183,6 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod,
TVMFunctionHandle *out); TVMFunctionHandle *out);
/*! /*!
* \brief Precompile the function under given context.
* Many TVMFunctionHandle is initialized lazily,
* This call eagerly prepares the resources under given context.
* Useful for benchmarking purposes.
*
* \param mod The module handle.
* \param func_name The name of the function.
* \param ctx The context to be precompiled on.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMModPreCompile(TVMModuleHandle mod,
const char* func_name,
TVMContext ctx);
/*!
* \brief Free the Module * \brief Free the Module
* \param mod The module to be freed. * \param mod The module to be freed.
* *
......
...@@ -77,17 +77,6 @@ class ModuleNode { ...@@ -77,17 +77,6 @@ class ModuleNode {
/*! \return The module type key */ /*! \return The module type key */
virtual const char* type_key() const = 0; virtual const char* type_key() const = 0;
/*! /*!
* \brief Eagerly compile the function under certain context,
* assuming that it is used by the current thread.
*
* This is useful for benchmarking to eliminate lazy compilation
* overhead during the first execution of the kernel.
*
* \param name The name of the function.
* \param ctx The context to be executed.
*/
virtual void PreCompile(const std::string& name, TVMContext ctx) = 0;
/*!
* \brief Get a PackedFunc from module. * \brief Get a PackedFunc from module.
* *
* The PackedFunc may not be fully initialized, * The PackedFunc may not be fully initialized,
...@@ -113,7 +102,7 @@ class ModuleNode { ...@@ -113,7 +102,7 @@ class ModuleNode {
* \param format The format of the file. * \param format The format of the file.
*/ */
virtual void SaveToFile(const std::string& file_name, virtual void SaveToFile(const std::string& file_name,
const std::string& format) = 0; const std::string& format);
/*! /*!
* \brief Save the module to binary stream. * \brief Save the module to binary stream.
* \param stream The binary stream to save to. * \param stream The binary stream to save to.
...@@ -121,14 +110,13 @@ class ModuleNode { ...@@ -121,14 +110,13 @@ class ModuleNode {
* but not necessarily host modules. * but not necessarily host modules.
* We can use this to do AOT loading of bundled device functions. * We can use this to do AOT loading of bundled device functions.
*/ */
virtual void SaveToBinary(dmlc::Stream* stream) = 0; virtual void SaveToBinary(dmlc::Stream* stream);
/*! /*!
* \brief Get the source code of module, when available. * \brief Get the source code of module, when available.
* \param format Format of the source code, can be empty by default. * \param format Format of the source code, can be empty by default.
* \return Possible source code when available. * \return Possible source code when available.
*/ */
virtual std::string GetSource( virtual std::string GetSource(const std::string& format = "");
const std::string& format = "") = 0;
/*! /*!
* \brief Get a function from current environment * \brief Get a function from current environment
* The environment includes all the imports as well as Global functions. * The environment includes all the imports as well as Global functions.
......
...@@ -116,20 +116,6 @@ class ModuleBase(object): ...@@ -116,20 +116,6 @@ class ModuleBase(object):
""" """
check_call(_LIB.TVMModImport(self.handle, module.handle)) check_call(_LIB.TVMModImport(self.handle, module.handle))
def precompile(self, func_name, ctx):
"""Add module to the import list of current one.
Parameters
----------
func_name : str
The name of function to be precompiled.
ctx : Context
The context to be precompiled.
"""
check_call(_LIB.TVMModPreCompile(
self.handle, c_str(func_name), ctx))
def __getitem__(self, name): def __getitem__(self, name):
if not isinstance(name, string_types): if not isinstance(name, string_types):
raise ValueError("Can only take string as function name") raise ValueError("Can only take string as function name")
......
...@@ -33,17 +33,6 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -33,17 +33,6 @@ class LLVMModuleNode final : public runtime::ModuleNode {
return "llvm"; return "llvm";
} }
void PreCompile(const std::string& name, TVMContext ctx) final {
if (ee_ == nullptr) LazyInitJIT();
std::lock_guard<std::mutex> lock(mutex_);
const std::string& fname = (name == runtime::symbol::tvm_module_main ?
entry_func_ : name);
BackendPackedCFunc faddr =
reinterpret_cast<BackendPackedCFunc>(ee_->getFunctionAddress(fname));
CHECK(faddr != nullptr)
<< "Failed to Precompile function " << name;
}
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const std::shared_ptr<ModuleNode>& sptr_to_self) final {
......
...@@ -21,8 +21,6 @@ class SourceModuleNode final : public runtime::ModuleNode { ...@@ -21,8 +21,6 @@ class SourceModuleNode final : public runtime::ModuleNode {
const char* type_key() const { const char* type_key() const {
return "source"; return "source";
} }
void PreCompile(const std::string& name, TVMContext ctx) final {
}
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const std::shared_ptr<ModuleNode>& sptr_to_self) final {
...@@ -31,15 +29,6 @@ class SourceModuleNode final : public runtime::ModuleNode { ...@@ -31,15 +29,6 @@ class SourceModuleNode final : public runtime::ModuleNode {
return PackedFunc(); return PackedFunc();
} }
void SaveToFile(const std::string& file_name,
const std::string& format) final {
LOG(FATAL) << "SourceModule: SaveToFile not supported";
}
void SaveToBinary(dmlc::Stream* stream) final {
LOG(FATAL) << "SourceModule: SaveToBinary not supported";
}
std::string GetSource(const std::string& format) final { std::string GetSource(const std::string& format) final {
return code_; return code_;
} }
......
...@@ -16,8 +16,6 @@ class StackVMModuleNode : public runtime::ModuleNode { ...@@ -16,8 +16,6 @@ class StackVMModuleNode : public runtime::ModuleNode {
return "stackvm"; return "stackvm";
} }
void PreCompile(const std::string& name, TVMContext ctx) final {}
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const std::shared_ptr<ModuleNode>& sptr_to_self) final {
...@@ -33,15 +31,6 @@ class StackVMModuleNode : public runtime::ModuleNode { ...@@ -33,15 +31,6 @@ class StackVMModuleNode : public runtime::ModuleNode {
}); });
} }
void SaveToFile(const std::string& file_name,
const std::string& format) final {
LOG(FATAL) << "StackVMModule: SaveToFile not supported";
}
void SaveToBinary(dmlc::Stream* stream) final {
LOG(FATAL) << "StackVMModule: SaveToBinary not supported";
}
std::string GetSource(const std::string& format) final { std::string GetSource(const std::string& format) final {
std::ostringstream os; std::ostringstream os;
for (const auto& kv : fmap_) { for (const auto& kv : fmap_) {
......
...@@ -24,8 +24,6 @@ class VerilogModuleNode : public runtime::ModuleNode { ...@@ -24,8 +24,6 @@ class VerilogModuleNode : public runtime::ModuleNode {
const char* type_key() const { const char* type_key() const {
return "verilog"; return "verilog";
} }
void PreCompile(const std::string& name, TVMContext ctx) final {
}
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
...@@ -57,15 +55,6 @@ class VerilogModuleNode : public runtime::ModuleNode { ...@@ -57,15 +55,6 @@ class VerilogModuleNode : public runtime::ModuleNode {
return PackedFunc(f); return PackedFunc(f);
} }
void SaveToFile(const std::string& file_name,
const std::string& format) final {
LOG(FATAL) << "VerilogModule: SaveToFile not supported";
}
void SaveToBinary(dmlc::Stream* stream) final {
LOG(FATAL) << "VerilogModule: SaveToBinary not supported";
}
std::string GetSource(const std::string& format) final { std::string GetSource(const std::string& format) final {
return m_.code; return m_.code;
} }
......
...@@ -196,14 +196,6 @@ int TVMModGetFunction(TVMModuleHandle mod, ...@@ -196,14 +196,6 @@ int TVMModGetFunction(TVMModuleHandle mod,
API_END(); API_END();
} }
int TVMModPreCompile(TVMModuleHandle mod,
const char* func_name,
TVMContext ctx) {
API_BEGIN();
(*static_cast<Module*>(mod))->PreCompile(func_name, ctx);
API_END();
}
int TVMModFree(TVMModuleHandle mod) { int TVMModFree(TVMModuleHandle mod) {
API_BEGIN(); API_BEGIN();
delete static_cast<Module*>(mod); delete static_cast<Module*>(mod);
......
...@@ -49,12 +49,6 @@ class CUDAModuleNode : public runtime::ModuleNode { ...@@ -49,12 +49,6 @@ class CUDAModuleNode : public runtime::ModuleNode {
return "cuda"; return "cuda";
} }
void PreCompile(const std::string& name, TVMContext ctx) final {
CUDA_CALL(cudaSetDevice(ctx.device_id));
cudaFree(nullptr);
this->GetFunc(ctx.device_id, name);
}
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final; const std::shared_ptr<ModuleNode>& sptr_to_self) final;
......
...@@ -30,10 +30,6 @@ class DSOModuleNode final : public ModuleNode { ...@@ -30,10 +30,6 @@ class DSOModuleNode final : public ModuleNode {
return "dso"; return "dso";
} }
void PreCompile(const std::string& name, TVMContext ctx) final {
GetFuncPtr(name);
}
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const std::shared_ptr<ModuleNode>& sptr_to_self) final {
...@@ -48,19 +44,6 @@ class DSOModuleNode final : public ModuleNode { ...@@ -48,19 +44,6 @@ class DSOModuleNode final : public ModuleNode {
}); });
} }
void SaveToFile(const std::string& file_name,
const std::string& format) final {
LOG(FATAL) << "DSOModule: SaveToFile not supported";
}
void SaveToBinary(dmlc::Stream* stream) final {
LOG(FATAL) << "DSOModule: SaveToBinary not supported";
}
std::string GetSource(const std::string& format) final {
return "";
}
void Init(const std::string& name) { void Init(const std::string& name) {
Load(name); Load(name);
CHECK(lib_handle_ != nullptr) CHECK(lib_handle_ != nullptr)
......
...@@ -37,10 +37,6 @@ class MetalModuleNode final :public runtime::ModuleNode { ...@@ -37,10 +37,6 @@ class MetalModuleNode final :public runtime::ModuleNode {
return "metal"; return "metal";
} }
void PreCompile(const std::string& name, TVMContext ctx) final {
GetPipelineState(ctx.device_id, name);
}
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final; const std::shared_ptr<ModuleNode>& sptr_to_self) final;
......
...@@ -62,6 +62,20 @@ Module Module::LoadFromFile(const std::string& file_name, ...@@ -62,6 +62,20 @@ Module Module::LoadFromFile(const std::string& file_name,
return m; return m;
} }
void ModuleNode::SaveToFile(const std::string& file_name,
const std::string& format) {
LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile";
}
void ModuleNode::SaveToBinary(dmlc::Stream* stream) {
LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToBinary";
}
std::string ModuleNode::GetSource(const std::string& format) {
LOG(FATAL) << "Module[" << type_key() << "] does not support GetSource";
return "";
}
const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) { const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
auto it = import_cache_.find(name); auto it = import_cache_.find(name);
if (it != import_cache_.end()) return it->second.get(); if (it != import_cache_.end()) return it->second.get();
......
...@@ -59,12 +59,6 @@ class OpenCLModuleNode : public ModuleNode { ...@@ -59,12 +59,6 @@ class OpenCLModuleNode : public ModuleNode {
return "opencl"; return "opencl";
} }
void PreCompile(const std::string& name, TVMContext ctx) final {
InstallKernel(cl::OpenCLWorkspace::Global(),
cl::OpenCLThreadEntry::ThreadLocal(),
name, kid_map_.at(name));
}
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final; const std::shared_ptr<ModuleNode>& sptr_to_self) final;
......
...@@ -45,9 +45,6 @@ class RPCModuleNode final : public ModuleNode { ...@@ -45,9 +45,6 @@ class RPCModuleNode final : public ModuleNode {
return "rpc"; return "rpc";
} }
void PreCompile(const std::string& name, TVMContext ctx) final {
}
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const std::shared_ptr<ModuleNode>& sptr_to_self) final {
...@@ -55,15 +52,6 @@ class RPCModuleNode final : public ModuleNode { ...@@ -55,15 +52,6 @@ class RPCModuleNode final : public ModuleNode {
return WrapRemote(handle); return WrapRemote(handle);
} }
void SaveToFile(const std::string& file_name,
const std::string& format) final {
LOG(FATAL) << "RPCModule: SaveToFile not supported";
}
void SaveToBinary(dmlc::Stream* stream) final {
LOG(FATAL) << "RPCModule: SaveToBinary not supported";
}
std::string GetSource(const std::string& format) final { std::string GetSource(const std::string& format) final {
if (module_handle_ != nullptr) { if (module_handle_ != nullptr) {
std::string ret = sess_->CallRemote( std::string ret = sess_->CallRemote(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment