Commit 2386e74b authored by LiangHao, committed by Yao Wang

Optimizing autotvm task extraction speed (#4138)

* Optimize task extraction speed

* correct pylint errors

* Delete unused function

* remove unnecessary argument

* resolve code review comments

* correct cpp lint errors

* remove one more graph_json return value

* fix test bugs
parent 794db370
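The commit speeds up autotvm task extraction by stopping after graph-level optimization and graph-runtime codegen instead of running a full relay.build() for every extracted function. Below is a minimal sketch of the extraction call that benefits; the ResNet-18 workload, batch size, and conv2d-only op list are illustrative assumptions, not part of the commit.

# Sketch only: the workload and op list below are assumptions for illustration.
import tvm
from tvm import relay, autotvm
from tvm.relay.testing import resnet

mod, params = resnet.get_workload(num_layers=18, batch_size=1)
target = tvm.target.create("llvm")

# extract_from_program() now lowers each function through the new _lower()
# helper (relay.optimize + graph-runtime codegen) instead of a full
# relay.build(), which is where the extraction speedup comes from.
tasks = autotvm.task.extract_from_program(mod["main"], params=params,
                                          ops=(relay.op.nn.conv2d,),
                                          target=target)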
@@ -31,23 +31,28 @@ from .topi_integration import TaskExtractEnv
 logger = logging.getLogger('autotvm')
 
-# TODO(moreau89) find a more elegant way to build for VTAs
-def _build(func,
-           target,
-           target_host,
-           params):
-    """ Helper to build VTA properly.
+# TODO(moreau89) find a more elegant way to lower for VTAs
+def _lower(func,
+           target,
+           params):
+    """ Helper to lower VTA properly.
     """
     from tvm import relay
+    from tvm.relay.backend import graph_runtime_codegen
 
     if hasattr(target, 'device_name') and target.device_name == "vta":
         with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
             import vta
             with vta.build_config():
-                return relay.build(func, target, target_host, params)
+                mod, _ = relay.optimize(func, target, params)
+                grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
+                return grc.codegen(mod["main"])
     # default case
-    return relay.build(func, target, target_host, params)
+    mod, _ = relay.optimize(func, target, params)
+    grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
+    return grc.codegen(mod["main"])
 
 
 def extract_from_program(func, params, ops, target, target_host=None):
     """ Extract tuning tasks from a relay program.

@@ -133,8 +138,8 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
             relay.backend.compile_engine.get().clear()
             # wrap build call in thread to avoid multiprocessing problems
             mod = relay.Module.from_expr(func)
-            build_thread = threading.Thread(target=_build,
-                                            args=(mod, target, target_host, param))
+            build_thread = threading.Thread(target=_lower,
+                                            args=(mod, target, param))
             build_thread.start()
             build_thread.join()
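For reference, here is the default (non-VTA) branch of the new _lower() helper rewritten as a standalone sketch; lower_for_tasks is a hypothetical name used only for illustration.

# Sketch of the non-VTA path in _lower(); it stops after graph codegen,
# so no target machine code is compiled during task extraction.
from tvm import relay
from tvm.relay.backend import graph_runtime_codegen

def lower_for_tasks(mod, target, params):
    # Graph-level optimization only (fusion, folding, layout passes, ...).
    opt_mod, _ = relay.optimize(mod, target, params)
    # Graph-runtime codegen still lowers the fused functions through TOPI,
    # which is what lets TaskExtractEnv record the tuning tasks.
    grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
    return grc.codegen(opt_mod["main"])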
...
@@ -28,7 +28,7 @@ from . import module
 from . import adt
 from . import analysis
 from . import transform
-from .build_module import build, create_executor
+from .build_module import build, create_executor, optimize
 from .transform import build_config
 from . import prelude
 from . import parser
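With the export added above, the optimization-only entry point can be imported directly from the relay package:

# The new top-level export; optimize() is defined in build_module.py below.
from tvm.relay import optimize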
...
@@ -36,7 +36,7 @@ contrib.graph_runtime or any other TVM runtime compatible systems.
 from __future__ import absolute_import
 from tvm.ndarray import empty
-from tvm.relay import build_module
+from tvm.relay import _build_module
 from tvm import target as _target
 from tvm import expr as _expr

@@ -44,7 +44,7 @@ class GraphRuntimeCodegen(object):
     """The compiler from Relay to the TVM runtime system."""
     def __init__(self, mod, target):
-        self._mod = build_module._GraphRuntimeCodegen()
+        self._mod = _build_module._GraphRuntimeCodegen()
         self._init = self._mod["init"]
         self._codegen = self._mod["codegen"]
         self._get_graph_json = self._mod["get_graph_json"]

...
@@ -60,6 +60,7 @@ class BuildModule(object):
         self._get_graph_json = self.mod["get_graph_json"]
         self._get_module = self.mod["get_module"]
         self._build = self.mod["build"]
+        self._optimize = self.mod["optimize"]
         self._set_params_func = self.mod["set_params"]
         self._get_params_func = self.mod["get_params"]
@@ -113,6 +114,42 @@ class BuildModule(object):
         return graph_json, mod, params
 
+    def optimize(self, func, target=None, params=None):
+        """
+        Parameters
+        ----------
+        func: relay.Function
+            The function to build.
+
+        target : str, :any:`tvm.target.Target`, or dict of str(i.e.
+            device/context name) to str/tvm.target.Target, optional
+            For heterogeneous compilation, it is a dictionary indicating context
+            to target mapping. For homogeneous compilation, it is a build target.
+
+        params : dict of str to NDArray
+            Input parameters to the graph that do not change
+            during inference time. Used for constant folding.
+
+        Returns
+        -------
+        mod : relay.Module
+            The optimized relay module.
+
+        params : dict
+            The parameters of the final graph.
+        """
+        target = _update_target(target)
+
+        # Setup the params.
+        if params:
+            self._set_params(params)
+        mod = self._optimize(func, target)
+        # Get artifacts
+        params = self.get_params()
+
+        return mod, params
+
     def _set_params(self, params):
         inputs = {}
         for name, param in params.items():
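The method above follows the same parameter protocol as build(): parameters are pushed into the underlying C++ module first, the packed "optimize" function is invoked with exactly two arguments (function and target map), and the possibly constant-folded parameters are read back. A hedged sketch of driving BuildModule directly; func and params are assumed to exist already.

# Sketch only: func (relay.Function) and params (dict of str to NDArray)
# are assumptions standing in for a real workload.
from tvm.relay import build_module

bld_mod = build_module.BuildModule()
# Internally: _set_params(params), then self.mod["optimize"](func, target),
# then get_params() to retrieve the updated parameters.
opt_mod, opt_params = bld_mod.optimize(func, target="llvm", params=params)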
@@ -208,6 +245,57 @@ def build(mod, target=None, target_host=None, params=None):
     return graph_json, mod, params
 
+def optimize(mod, target=None, params=None):
+    """Helper function that optimizes a Relay module.
+
+    Parameters
+    ----------
+    mod : relay.Module
+        The module to build. Using relay.Function is deprecated.
+
+    target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
+        name) to str/tvm.target.Target, optional
+        For heterogeneous compilation, it is a dictionary indicating context to
+        target mapping. For homogeneous compilation, it is a build target.
+
+    params : dict of str to NDArray
+        Input parameters to the graph that do not change
+        during inference time. Used for constant folding.
+
+    Returns
+    -------
+    mod : relay.Module
+        The optimized relay module.
+
+    params : dict
+        The parameters of the final graph.
+    """
+    if isinstance(mod, _Module):
+        func = mod["main"]
+    elif isinstance(mod, _expr.Function):
+        func = mod
+        warnings.warn(
+            "Please use input parameter mod (tvm.relay.module.Module) "
+            "instead of deprecated parameter func (tvm.relay.expr.Function)",
+            DeprecationWarning)
+    else:
+        raise ValueError("Type of input parameter mod must be tvm.relay.module.Module")
+
+    target = _update_target(target)
+
+    # If current dispatch context is fallback context (the default root context),
+    # then load pre-tuned parameters from TopHub
+    if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
+        tophub_context = autotvm.tophub.context(list(target.values()))
+    else:
+        tophub_context = autotvm.util.EmptyContext()
+
+    with tophub_context:
+        bld_mod = BuildModule()
+        mod, params = bld_mod.optimize(func, target, params)
+    return mod, params
+
+
 class GraphExecutor(_interpreter.Executor):
     """Wrapper around Executor interface.
...
@@ -148,6 +148,11 @@ class RelayBuildModule : public runtime::ModuleNode {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
         *rv = this->graph_codegen_->GetLoweredFunc();
       });
+    } else if (name == "optimize") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        CHECK_EQ(args.num_args, 2);
+        *rv = this->Optimize(args[0], args[1], this->params_);
+      });
     } else {
       LOG(FATAL) << "Unknown packed function: " << name;
       return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
@@ -273,19 +278,25 @@ class RelayBuildModule : public runtime::ModuleNode {
   }
 
   /*!
-   * \brief Optimize a Relay module.
+   * \brief Optimize a Relay Function.
    *
-   * \param relay_module The input Relay module where optmization will be
-   * applied on.
+   * \param func The input Function where optmization will be applied on.
    * \param targets The device type to `Target` mapping.
    * \param params The param name to value mapping.
    *
    * \return relay::Module The updated Relay module after optimization.
    */
   relay::Module Optimize(
-      relay::Module relay_module,
+      Function func,
       const TargetsMap& targets,
       const std::unordered_map<std::string, runtime::NDArray>& params) {
+    if (params.size()) {
+      func = BindParamsByName(func, params);
+    }
+    // Perform Module->Module optimizations.
+    relay::Module relay_module = relay::ModuleNode::FromExpr(func);
+
     Array<Pass> pass_seqs;
     // Run all dialect legalization passes.
@@ -345,6 +356,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     // Fuse the operations if it is needed.
     relay_module = transform::FuseOps()(relay_module);
     relay_module = transform::InferType()(relay_module);
+    CHECK(relay_module.defined());
 
     return relay_module;
   }
@@ -440,14 +452,8 @@ class RelayBuildModule : public runtime::ModuleNode {
   void BuildRelay(
       Function func,
       const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
-    if (params.size()) {
-      func = BindParamsByName(func, params);
-    }
-
-    // Perform Module->Module optimizations.
-    relay::Module relay_module = relay::ModuleNode::FromExpr(func);
-    relay_module = Optimize(relay_module, targets_, params);
-    CHECK(relay_module.defined());
+    // Optimize input Relay Function and returns Relay Module
+    relay::Module relay_module = Optimize(func, targets_, params);
     // Get the updated function.
     func = relay_module->Lookup("main");

...