Unverified Commit f63b249d by Zhi Committed by GitHub

refactor build module to take IRModule (#4988)

parent fe74b37a
...@@ -332,6 +332,16 @@ TVM_DLL Pass PartitionGraph(); ...@@ -332,6 +332,16 @@ TVM_DLL Pass PartitionGraph();
*/ */
TVM_DLL Pass Inline(); TVM_DLL Pass Inline();
/*!
* \brief Remove the unused functions in the Relay IRModule.
*
* \param entry_functions The entry functions used to search the functions that
* are being used.
*
* \return The pass.
*/
TVM_DLL Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions);
} // namespace transform } // namespace transform
/*! /*!
......
...@@ -62,7 +62,7 @@ def _convert_param_map(params): ...@@ -62,7 +62,7 @@ def _convert_param_map(params):
class BuildModule(object): class BuildModule(object):
"""Build a Relay function to run on TVM graph runtime. This class is used """Build an IR module to run on TVM graph runtime. This class is used
to expose the `RelayBuildModule` APIs implemented in C++. to expose the `RelayBuildModule` APIs implemented in C++.
""" """
def __init__(self): def __init__(self):
...@@ -74,12 +74,12 @@ class BuildModule(object): ...@@ -74,12 +74,12 @@ class BuildModule(object):
self._set_params_func = self.mod["set_params"] self._set_params_func = self.mod["set_params"]
self._get_params_func = self.mod["get_params"] self._get_params_func = self.mod["get_params"]
def build(self, func, target=None, target_host=None, params=None): def build(self, mod, target=None, target_host=None, params=None):
""" """
Parameters Parameters
---------- ----------
func: relay.Function mod : :py:class:`~tvm.IRModule`
The function to build. The IRModule to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e. target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional device/context name) to str/tvm.target.Target, optional
...@@ -115,8 +115,8 @@ class BuildModule(object): ...@@ -115,8 +115,8 @@ class BuildModule(object):
# Setup the params. # Setup the params.
if params: if params:
self._set_params(params) self._set_params(params)
# Build the function # Build the IR module
self._build(func, target, target_host) self._build(mod, target, target_host)
# Get artifacts # Get artifacts
graph_json = self.get_json() graph_json = self.get_json()
mod = self.get_module() mod = self.get_module()
...@@ -124,12 +124,12 @@ class BuildModule(object): ...@@ -124,12 +124,12 @@ class BuildModule(object):
return graph_json, mod, params return graph_json, mod, params
def optimize(self, func, target=None, params=None): def optimize(self, mod, target=None, params=None):
""" """
Parameters Parameters
---------- ----------
func: relay.Function mod : :py:class:`~tvm.IRModule`
The function to build. The IR module to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e. target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional device/context name) to str/tvm.target.Target, optional
...@@ -142,7 +142,7 @@ class BuildModule(object): ...@@ -142,7 +142,7 @@ class BuildModule(object):
Returns Returns
------- -------
mod : tvm.IRModule mod : :py:class:`~tvm.IRModule`
The optimized relay module. The optimized relay module.
params : dict params : dict
...@@ -153,7 +153,7 @@ class BuildModule(object): ...@@ -153,7 +153,7 @@ class BuildModule(object):
# Setup the params. # Setup the params.
if params: if params:
self._set_params(params) self._set_params(params)
mod = self._optimize(func, target) mod = self._optimize(mod, target)
# Get artifacts # Get artifacts
params = self.get_params() params = self.get_params()
...@@ -186,8 +186,8 @@ def build(mod, target=None, target_host=None, params=None): ...@@ -186,8 +186,8 @@ def build(mod, target=None, target_host=None, params=None):
Parameters Parameters
---------- ----------
mod : tvm.IRModule mod : :py:class:`~tvm.IRModule`
The module to build. Using relay.Function is deprecated. The IR module to build. Using relay.Function is deprecated.
target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
name) to str/tvm.target.Target, optional name) to str/tvm.target.Target, optional
...@@ -218,16 +218,15 @@ def build(mod, target=None, target_host=None, params=None): ...@@ -218,16 +218,15 @@ def build(mod, target=None, target_host=None, params=None):
params : dict params : dict
The parameters of the final graph. The parameters of the final graph.
""" """
if isinstance(mod, IRModule): if not isinstance(mod, (IRModule, _expr.Function)):
func = mod["main"] raise ValueError("Type of input parameter mod must be tvm.IRModule")
elif isinstance(mod, _expr.Function):
func = mod if isinstance(mod, _expr.Function):
mod = IRModule.from_expr(mod)
warnings.warn( warnings.warn(
"Please use input parameter mod (tvm.IRModule) " "Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter func (tvm.relay.expr.Function)", "instead of deprecated parameter mod (tvm.relay.expr.Function)",
DeprecationWarning) DeprecationWarning)
else:
raise ValueError("Type of input parameter mod must be tvm.IRModule")
target = _update_target(target) target = _update_target(target)
...@@ -246,7 +245,7 @@ def build(mod, target=None, target_host=None, params=None): ...@@ -246,7 +245,7 @@ def build(mod, target=None, target_host=None, params=None):
with tophub_context: with tophub_context:
bld_mod = BuildModule() bld_mod = BuildModule()
graph_json, mod, params = bld_mod.build(func, target, target_host, params) graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
return graph_json, mod, params return graph_json, mod, params
...@@ -255,7 +254,7 @@ def optimize(mod, target=None, params=None): ...@@ -255,7 +254,7 @@ def optimize(mod, target=None, params=None):
Parameters Parameters
---------- ----------
mod : tvm.IRModule mod : :py:class:`~tvm.IRModule`
The module to build. Using relay.Function is deprecated. The module to build. Using relay.Function is deprecated.
target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
...@@ -269,22 +268,21 @@ def optimize(mod, target=None, params=None): ...@@ -269,22 +268,21 @@ def optimize(mod, target=None, params=None):
Returns Returns
------- -------
mod : tvm.IRModule mod : :py:class:`~tvm.IRModule`
The optimized relay module. The optimized relay module.
params : dict params : dict
The parameters of the final graph. The parameters of the final graph.
""" """
if isinstance(mod, IRModule): if not isinstance(mod, (IRModule, _expr.Function)):
func = mod["main"] raise ValueError("Type of input parameter mod must be tvm.IRModule")
elif isinstance(mod, _expr.Function):
func = mod if isinstance(mod, _expr.Function):
mod = IRModule.from_expr(mod)
warnings.warn( warnings.warn(
"Please use input parameter mod (tvm.IRModule) " "Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter func (tvm.relay.expr.Function)", "instead of deprecated parameter func (tvm.relay.expr.Function)",
DeprecationWarning) DeprecationWarning)
else:
raise ValueError("Type of input parameter mod must be tvm.IRModule")
target = _update_target(target) target = _update_target(target)
...@@ -297,7 +295,7 @@ def optimize(mod, target=None, params=None): ...@@ -297,7 +295,7 @@ def optimize(mod, target=None, params=None):
with tophub_context: with tophub_context:
bld_mod = BuildModule() bld_mod = BuildModule()
mod, params = bld_mod.optimize(func, target, params) mod, params = bld_mod.optimize(mod, target, params)
return mod, params return mod, params
......
...@@ -233,42 +233,46 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -233,42 +233,46 @@ class RelayBuildModule : public runtime::ModuleNode {
} }
/*! /*!
* \brief Build relay function for graph runtime * \brief Build relay IRModule for graph runtime
* *
* \param func Relay Function * \param mod Relay IRModule
* \param target Target device * \param target Target device
* \param target_host Host target device * \param target_host Host target device
*/ */
void Build(Function func, void Build(IRModule mod,
const TargetsMap& targets, const TargetsMap& targets,
const tvm::Target& target_host) { const tvm::Target& target_host) {
targets_ = targets; targets_ = targets;
target_host_ = target_host; target_host_ = target_host;
BuildRelay(func, params_); BuildRelay(mod, params_);
} }
protected: protected:
/*! /*!
* \brief Optimize a Relay Function. * \brief Optimize a Relay IRModule.
* *
* \param func The input Function where optmization will be applied on. * \param relay_module The input IRModule where optmization will be applied on.
* \param targets The device type to `Target` mapping. * \param targets The device type to `Target` mapping.
* \param params The param name to value mapping. * \param params The param name to value mapping.
* *
* \return relay::Module The updated Relay module after optimization. * \return relay::IRModule The updated Relay IR module after optimization.
*/ */
IRModule Optimize( IRModule Optimize(
Function func, IRModule relay_module,
const TargetsMap& targets, const TargetsMap& targets,
const std::unordered_map<std::string, runtime::NDArray>& params) { const std::unordered_map<std::string, runtime::NDArray>& params) {
if (params.size()) { if (params.size()) {
func = BindParamsByName(func, params); CHECK(relay_module->ContainGlobalVar("main"))
<< "Missing the main entry function";
GlobalVar main_glb_var = relay_module->GetGlobalVar("main");
Function main_func = Downcast<Function>(relay_module->Lookup(main_glb_var));
auto new_main = BindParamsByName(main_func, params);
relay_module->Update(main_glb_var, new_main);
} }
// Perform Module->Module optimizations.
IRModule relay_module = IRModule::FromExpr(func);
Array<Pass> pass_seqs; Array<Pass> pass_seqs;
Array<tvm::PrimExpr> entry_functions{tvm::PrimExpr{"main"}};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
// Run all dialect legalization passes. // Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize()); pass_seqs.push_back(relay::qnn::transform::Legalize());
...@@ -418,18 +422,18 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -418,18 +422,18 @@ class RelayBuildModule : public runtime::ModuleNode {
} }
/*! /*!
* \brief Compile a Relay function to runtime module. * \brief Compile a Relay IR module to runtime module.
* *
* \param func The Relay function. * \param relay_module The Relay IR module.
* \param params The parameters. * \param params The parameters.
*/ */
void BuildRelay( void BuildRelay(
Function func, IRModule relay_module,
const std::unordered_map<std::string, tvm::runtime::NDArray>& params) { const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
// Optimize input Relay Function and returns Relay Module // Relay IRModule -> IRModule optimizations.
IRModule relay_module = Optimize(func, targets_, params); relay_module = Optimize(relay_module, targets_, params);
// Get the updated function. // Get the updated function.
func = Downcast<Function>(relay_module->Lookup("main")); auto func = Downcast<Function>(relay_module->Lookup("main"));
// Generate code for the updated function. // Generate code for the updated function.
graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen()); graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
......
...@@ -51,7 +51,6 @@ namespace transform { ...@@ -51,7 +51,6 @@ namespace transform {
Pass LambdaLift(); Pass LambdaLift();
Pass InlinePrimitives(); Pass InlinePrimitives();
Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions);
Pass ManifestAlloc(Target target_host) { Pass ManifestAlloc(Target target_host) {
auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc"); auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc");
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <topi/broadcast.h> #include <topi/broadcast.h>
#include <topi/generic/injective.h> #include <topi/generic/injective.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/ir/module.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
...@@ -115,7 +116,8 @@ TEST(Relay, BuildModule) { ...@@ -115,7 +116,8 @@ TEST(Relay, BuildModule) {
Map<tvm::Integer, tvm::Target> targets; Map<tvm::Integer, tvm::Target> targets;
Target llvm_tgt = Target::Create("llvm"); Target llvm_tgt = Target::Create("llvm");
targets.Set(0, llvm_tgt); targets.Set(0, llvm_tgt);
build_f(func, targets, llvm_tgt); auto relay_mod = tvm::IRModule::FromExpr(func);
build_f(relay_mod, targets, llvm_tgt);
std::string json = json_f(); std::string json = json_f();
tvm::runtime::Module mod = mod_f(); tvm::runtime::Module mod = mod_f();
// run // run
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment