Unverified Commit f63b249d by Zhi Committed by GitHub

refactor build module to take IRModule (#4988)

parent fe74b37a
......@@ -332,6 +332,16 @@ TVM_DLL Pass PartitionGraph();
*/
TVM_DLL Pass Inline();
/*!
* \brief Remove the unused functions in the Relay IRModule.
*
* \param entry_functions The entry functions used to search the functions that
* are being used.
*
* \return The pass.
*/
TVM_DLL Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions);
} // namespace transform
/*!
......
......@@ -62,7 +62,7 @@ def _convert_param_map(params):
class BuildModule(object):
"""Build a Relay function to run on TVM graph runtime. This class is used
"""Build an IR module to run on TVM graph runtime. This class is used
to expose the `RelayBuildModule` APIs implemented in C++.
"""
def __init__(self):
......@@ -74,12 +74,12 @@ class BuildModule(object):
self._set_params_func = self.mod["set_params"]
self._get_params_func = self.mod["get_params"]
def build(self, func, target=None, target_host=None, params=None):
def build(self, mod, target=None, target_host=None, params=None):
"""
Parameters
----------
func: relay.Function
The function to build.
mod : :py:class:`~tvm.IRModule`
The IRModule to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
......@@ -115,8 +115,8 @@ class BuildModule(object):
# Setup the params.
if params:
self._set_params(params)
# Build the function
self._build(func, target, target_host)
# Build the IR module
self._build(mod, target, target_host)
# Get artifacts
graph_json = self.get_json()
mod = self.get_module()
......@@ -124,12 +124,12 @@ class BuildModule(object):
return graph_json, mod, params
def optimize(self, func, target=None, params=None):
def optimize(self, mod, target=None, params=None):
"""
Parameters
----------
func: relay.Function
The function to build.
mod : :py:class:`~tvm.IRModule`
The IR module to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
......@@ -142,7 +142,7 @@ class BuildModule(object):
Returns
-------
mod : tvm.IRModule
mod : :py:class:`~tvm.IRModule`
The optimized relay module.
params : dict
......@@ -153,7 +153,7 @@ class BuildModule(object):
# Setup the params.
if params:
self._set_params(params)
mod = self._optimize(func, target)
mod = self._optimize(mod, target)
# Get artifacts
params = self.get_params()
......@@ -186,8 +186,8 @@ def build(mod, target=None, target_host=None, params=None):
Parameters
----------
mod : tvm.IRModule
The module to build. Using relay.Function is deprecated.
mod : :py:class:`~tvm.IRModule`
The IR module to build. Using relay.Function is deprecated.
target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
name) to str/tvm.target.Target, optional
......@@ -218,16 +218,15 @@ def build(mod, target=None, target_host=None, params=None):
params : dict
The parameters of the final graph.
"""
if isinstance(mod, IRModule):
func = mod["main"]
elif isinstance(mod, _expr.Function):
func = mod
if not isinstance(mod, (IRModule, _expr.Function)):
raise ValueError("Type of input parameter mod must be tvm.IRModule")
if isinstance(mod, _expr.Function):
mod = IRModule.from_expr(mod)
warnings.warn(
"Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter func (tvm.relay.expr.Function)",
"instead of deprecated parameter mod (tvm.relay.expr.Function)",
DeprecationWarning)
else:
raise ValueError("Type of input parameter mod must be tvm.IRModule")
target = _update_target(target)
......@@ -246,7 +245,7 @@ def build(mod, target=None, target_host=None, params=None):
with tophub_context:
bld_mod = BuildModule()
graph_json, mod, params = bld_mod.build(func, target, target_host, params)
graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
return graph_json, mod, params
......@@ -255,7 +254,7 @@ def optimize(mod, target=None, params=None):
Parameters
----------
mod : tvm.IRModule
mod : :py:class:`~tvm.IRModule`
The module to build. Using relay.Function is deprecated.
target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
......@@ -269,22 +268,21 @@ def optimize(mod, target=None, params=None):
Returns
-------
mod : tvm.IRModule
mod : :py:class:`~tvm.IRModule`
The optimized relay module.
params : dict
The parameters of the final graph.
"""
if isinstance(mod, IRModule):
func = mod["main"]
elif isinstance(mod, _expr.Function):
func = mod
if not isinstance(mod, (IRModule, _expr.Function)):
raise ValueError("Type of input parameter mod must be tvm.IRModule")
if isinstance(mod, _expr.Function):
mod = IRModule.from_expr(mod)
warnings.warn(
"Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter func (tvm.relay.expr.Function)",
DeprecationWarning)
else:
raise ValueError("Type of input parameter mod must be tvm.IRModule")
target = _update_target(target)
......@@ -297,7 +295,7 @@ def optimize(mod, target=None, params=None):
with tophub_context:
bld_mod = BuildModule()
mod, params = bld_mod.optimize(func, target, params)
mod, params = bld_mod.optimize(mod, target, params)
return mod, params
......
......@@ -233,42 +233,46 @@ class RelayBuildModule : public runtime::ModuleNode {
}
/*!
* \brief Build relay function for graph runtime
* \brief Build relay IRModule for graph runtime
*
* \param func Relay Function
* \param mod Relay IRModule
* \param target Target device
* \param target_host Host target device
*/
void Build(Function func,
void Build(IRModule mod,
const TargetsMap& targets,
const tvm::Target& target_host) {
targets_ = targets;
target_host_ = target_host;
BuildRelay(func, params_);
BuildRelay(mod, params_);
}
protected:
/*!
* \brief Optimize a Relay Function.
* \brief Optimize a Relay IRModule.
*
* \param func The input Function where optmization will be applied on.
* \param relay_module The input IRModule where optmization will be applied on.
* \param targets The device type to `Target` mapping.
* \param params The param name to value mapping.
*
* \return relay::Module The updated Relay module after optimization.
* \return relay::IRModule The updated Relay IR module after optimization.
*/
IRModule Optimize(
Function func,
IRModule relay_module,
const TargetsMap& targets,
const std::unordered_map<std::string, runtime::NDArray>& params) {
if (params.size()) {
func = BindParamsByName(func, params);
CHECK(relay_module->ContainGlobalVar("main"))
<< "Missing the main entry function";
GlobalVar main_glb_var = relay_module->GetGlobalVar("main");
Function main_func = Downcast<Function>(relay_module->Lookup(main_glb_var));
auto new_main = BindParamsByName(main_func, params);
relay_module->Update(main_glb_var, new_main);
}
// Perform Module->Module optimizations.
IRModule relay_module = IRModule::FromExpr(func);
Array<Pass> pass_seqs;
Array<tvm::PrimExpr> entry_functions{tvm::PrimExpr{"main"}};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
// Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize());
......@@ -418,18 +422,18 @@ class RelayBuildModule : public runtime::ModuleNode {
}
/*!
* \brief Compile a Relay function to runtime module.
* \brief Compile a Relay IR module to runtime module.
*
* \param func The Relay function.
* \param relay_module The Relay IR module.
* \param params The parameters.
*/
void BuildRelay(
Function func,
IRModule relay_module,
const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
// Optimize input Relay Function and returns Relay Module
IRModule relay_module = Optimize(func, targets_, params);
// Relay IRModule -> IRModule optimizations.
relay_module = Optimize(relay_module, targets_, params);
// Get the updated function.
func = Downcast<Function>(relay_module->Lookup("main"));
auto func = Downcast<Function>(relay_module->Lookup("main"));
// Generate code for the updated function.
graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
......
......@@ -51,7 +51,6 @@ namespace transform {
Pass LambdaLift();
Pass InlinePrimitives();
Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions);
Pass ManifestAlloc(Target target_host) {
auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc");
......
......@@ -29,6 +29,7 @@
#include <topi/broadcast.h>
#include <topi/generic/injective.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/ir/module.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
......@@ -115,7 +116,8 @@ TEST(Relay, BuildModule) {
Map<tvm::Integer, tvm::Target> targets;
Target llvm_tgt = Target::Create("llvm");
targets.Set(0, llvm_tgt);
build_f(func, targets, llvm_tgt);
auto relay_mod = tvm::IRModule::FromExpr(func);
build_f(relay_mod, targets, llvm_tgt);
std::string json = json_f();
tvm::runtime::Module mod = mod_f();
// run
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment