Commit baae28b2 by Haichen Shen, committed by Zhi

[Autotvm] Use VM compile to extract autotvm tasks (#4328)

* [AutoTVM] Use vm compile in extracting task from relay

* update

* restructure vm compiler to reduce task extraction time

* x

* fix

* update doc

* update doc

* lint
parent 2077cd57
@@ -28,3 +28,6 @@ tvm.relay.backend
 .. automodule:: tvm.relay.backend.graph_runtime_codegen
     :members:
+
+.. automodule:: tvm.relay.backend.vm
+    :members:
@@ -32,7 +32,7 @@ logger = logging.getLogger('autotvm')

 # TODO(moreau89) find a more elegant way to lower for VTAs
-def _lower(func,
+def _lower(mod,
            target,
            params):
     """ Helper to lower VTA properly.
@@ -45,16 +45,16 @@ def _lower(func,
     with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
         import vta
         with vta.build_config():
-            mod, _ = relay.optimize(func, target, params)
+            mod, _ = relay.optimize(mod, target, params)
             grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
-            return grc.codegen(mod["main"])
+            grc.codegen(mod["main"])
     # default case
-    mod, _ = relay.optimize(func, target, params)
-    grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
-    return grc.codegen(mod["main"])
+    compiler = relay.vm.VMCompiler()
+    compiler.set_params(params)
+    compiler.lower(mod, target=target)

-def extract_from_program(func, params, ops, target, target_host=None,
+def extract_from_program(mod, params, ops, target, target_host=None,
                          template_keys=None):
     """ Extract tuning tasks from a relay program.
@@ -62,8 +62,8 @@ def extract_from_program(func, params, ops, target, target_host=None,
     Parameters
     ----------
-    func: relay.expr.Function
-        The func to tune
+    mod: relay.module.Module or relay.expr.Function
+        The module or function to tune
     params: dict of str to numpy array
         The associated parameters of the program
     ops: List of relay op
@@ -81,11 +81,11 @@ def extract_from_program(func, params, ops, target, target_host=None,
     task: Array of autotvm.task.Task
         collected tasks
     """
-    return extract_from_multiple_program([func], [params], ops, target, target_host,
-                                         template_keys=template_keys)
+    return extract_from_multiple_program([mod], [params], ops, target, target_host,
+                                         template_keys)

-def extract_from_multiple_program(funcs, params, ops, target, target_host=None,
+def extract_from_multiple_program(mods, params, ops, target, target_host=None,
                                   template_keys=None):
     """ Extract tuning tasks from multiple relay programs.
@@ -94,8 +94,8 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None,
     Parameters
     ----------
-    funcs: List of relay.expr.Function
-        The list of functions to tune
+    mods: List[relay.module.Module] or List[relay.expr.Function]
+        The list of modules or functions to tune
     params: List of dict of str to numpy array
         The associated parameters of the programs
     ops: List of relay op
@@ -145,10 +145,13 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None,
     old_state = logger.disabled
     logger.disabled = True

-    for func, param in zip(funcs, params):
+    for mod, param in zip(mods, params):
+        if isinstance(mod, relay.expr.Function):
+            mod = relay.Module.from_expr(mod)
+        assert isinstance(mod, relay.module.Module), \
+            "only support relay Module or Function to be tuned"
         relay.backend.compile_engine.get().clear()
         # wrap build call in thread to avoid multiprocessing problems
-        mod = relay.Module.from_expr(func)
         build_thread = threading.Thread(target=_lower,
                                         args=(mod, target, param))
         build_thread.start()
......
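For reference, a usage sketch (not part of the patch) of the updated extraction entry point above. It assumes the ResNet-18 workload from `tvm.relay.testing` as a stand-in for the `get_network` helper used in the test file further down; either a `relay.Module` or its `main` function is accepted now.

```python
# Usage sketch for the updated extract_from_program API (tvm.relay.testing is
# assumed available; it stands in for the get_network helper used in the tests).
from tvm import autotvm, relay
from tvm.relay import testing

mod, params = testing.resnet.get_workload(num_layers=18, batch_size=1)
target = "llvm"

# A relay.Module is now accepted directly.
tasks = autotvm.task.extract_from_program(mod, target=target, params=params,
                                          ops=(relay.op.nn.conv2d,))

# Passing the main Function still works: it is wrapped with relay.Module.from_expr
# inside extract_from_multiple_program, as the loop in the diff above shows.
tasks_from_func = autotvm.task.extract_from_program(mod["main"], target=target,
                                                    params=params,
                                                    ops=(relay.op.nn.conv2d,))
```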
@@ -363,7 +363,8 @@ class VirtualMachine(object):

 def compile(mod, target=None, target_host=None, params=None):
-    """
+    """Compile the module to VM executable. A helper function for VMCompiler.
+
     Parameters
     ----------
     mod : relay.Module
@@ -393,26 +394,31 @@ def compile(mod, target=None, target_host=None, params=None):
         The VM executable that contains both library code and bytecode.
     """
     compiler = VMCompiler()
-    target = compiler.update_target(target)
-    target_host = compiler.update_target_host(target, target_host)
     if params:
         compiler.set_params(params)
-    tophub_context = compiler.tophub_context(target)
-    with tophub_context:
-        compiler._compile(mod, target, target_host)
-    return Executable(compiler._get_exec())
+    compiler.lower(mod, target, target_host)
+    compiler.codegen()
+    return compiler.get_exec()


 class VMCompiler(object):
-    """Build Relay module to run on VM runtime."""
+    """Compiler that compiles Relay module to VM executable."""
     def __init__(self):
         self.mod = _vm._VMCompiler()
-        self._compile = self.mod["compile"]
+        self._lower = self.mod["lower"]
+        self._codegen = self.mod["codegen"]
         self._get_exec = self.mod["get_executable"]
         self._set_params_func = self.mod["set_params"]

     def set_params(self, params):
-        """Set constant parameters for the model"""
+        """Set constant parameters for the model.
+
+        Parameters
+        ----------
+        params : dict of str to NDArray
+            Input parameters to the graph that do not change
+            during inference time. Used for constant folding.
+        """
         inputs = {}
         for name, param in params.items():
             if isinstance(param, np.ndarray):
@@ -420,8 +426,50 @@ class VMCompiler(object):
                 inputs[name] = _expr.const(param)
         self._set_params_func(inputs)

-    def update_target(self, target):
-        """Update target"""
+    def lower(self, mod, target=None, target_host=None):
+        """Lower the module to VM bytecode.
+
+        Parameters
+        ----------
+        mod : relay.Module
+            The Relay module to build.
+
+        target : str, :any:`tvm.target.Target`, or dict of str(i.e.
+            device/context name) to str/tvm.target.Target, optional
+            For heterogeneous compilation, it is a dictionary indicating context
+            to target mapping. For homogeneous compilation, it is a build target.
+
+        target_host : str or :any:`tvm.target.Target`, optional
+            Host compilation target, if target is device.
+            When TVM compiles device specific program such as CUDA,
+            we also need host(CPU) side code to interact with the driver
+            to setup the dimensions and parameters correctly.
+            target_host is used to specify the host side codegen target.
+            By default, llvm is used if it is enabled,
+            otherwise a stackvm interpreter is used.
+        """
+        target = self._update_target(target)
+        target_host = self._update_target_host(target, target_host)
+        tophub_context = self._tophub_context(target)
+        with tophub_context:
+            self._lower(mod, target, target_host)
+
+    def codegen(self):
+        """Generate the kernel library."""
+        self._codegen()
+
+    def get_exec(self):
+        """Get the VM executable.
+
+        Returns
+        -------
+        exec : Executable
+            The VM executable that contains both library code and bytecode.
+        """
+        return Executable(self._get_exec())
+
+    def _update_target(self, target):
+        """Update target."""
         target = target if target else tvm.target.current_target()
         if target is None:
             raise ValueError("Target is not set in env or passed as argument.")
@@ -439,8 +487,8 @@ class VMCompiler(object):
                             "{}".format(type(target)))
         return tgts

-    def update_target_host(self, target, target_host):
-        """Update target host"""
+    def _update_target_host(self, target, target_host):
+        """Update target host."""
         target_host = None if target_host == "" else target_host
         if not target_host:
             for device_type, tgt in target.items():
@@ -449,9 +497,12 @@ class VMCompiler(object):
                     break
         if not target_host:
             target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
-        return tvm.target.create(target_host)
+        if isinstance(target_host, str):
+            target_host = tvm.target.create(target_host)
+        return target_host

-    def tophub_context(self, target):
+    def _tophub_context(self, target):
+        """Get the autotvm context."""
         # If current dispatch context is fallback context (the default root context),
         # then load pre-tuned parameters from TopHub
         if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
......
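To summarize the refactored Python API above: `compile()` stays a one-shot helper, while the new `lower()`, `codegen()`, and `get_exec()` methods expose the same pipeline step by step. A minimal sketch, assuming `mod` is a `relay.Module` and `params` its parameter dict (`relay.vm` refers to the module in the diff above, matching how relay_integration uses it):

```python
from tvm import relay

# One-shot helper: lowering, codegen, and executable creation in a single call.
exe = relay.vm.compile(mod, target="llvm", params=params)

# Step-by-step flow, mirroring what compile() now does internally.
compiler = relay.vm.VMCompiler()
compiler.set_params(params)          # bind constant parameters for folding
compiler.lower(mod, target="llvm")   # produce VM bytecode under a TopHub context
compiler.codegen()                   # build the kernel library
exe = compiler.get_exec()            # Executable holding bytecode plus library
```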
@@ -743,11 +743,16 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {

 PackedFunc VMCompiler::GetFunction(const std::string& name,
                                    const ObjectPtr<Object>& sptr_to_self) {
-  if (name == "compile") {
+  if (name == "lower") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       CHECK_EQ(args.num_args, 3);
       Module mod = args[0];
-      this->Compile(mod, args[1], args[2]);
+      this->Lower(mod, args[1], args[2]);
+    });
+  } else if (name == "codegen") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      CHECK_EQ(args.num_args, 0);
+      this->Codegen();
     });
   } else if (name == "get_executable") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
@@ -802,9 +807,9 @@ relay::Function VMCompiler::BindParamsByName(
   return ret;
 }

-void VMCompiler::Compile(Module mod,
-                         const TargetsMap& targets,
-                         const tvm::Target& target_host) {
+void VMCompiler::Lower(Module mod,
+                       const TargetsMap& targets,
+                       const tvm::Target& target_host) {
   CHECK_EQ(targets.size(), 1)
     << "Currently VM compiler doesn't support heterogeneous compilation";
   if (params_.size()) {
@@ -813,7 +818,7 @@ void VMCompiler::Compile(Module mod,
     mod->Add(gvar, f);
   }

-  InitVM();
+  exec_ = make_object<Executable>();
   targets_ = targets;
   target_host_ = target_host;
@@ -852,11 +857,20 @@ void VMCompiler::Compile(Module mod,
     exec_->constants.push_back(vm::Tensor(data));
   }

-  LibraryCodegen();
+  // update global function map
   for (auto gv : context_.global_map) {
     exec_->global_map.insert({gv.first->name_hint, gv.second});
   }
+  // update primitive function map
+  size_t primitive_index = 0;
+  for (const auto& cfunc : context_.cached_funcs) {
+    if (cfunc->target->str() == "ext_dev") {
+      exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
+    } else {
+      exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
+    }
+  }
 }

 Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) {
@@ -942,7 +956,11 @@ void VMCompiler::PopulateGlobalMap() {
   }
 }

-void VMCompiler::LibraryCodegen() {
+void VMCompiler::Codegen() {
+  if (!context_.module.defined()) {
+    LOG(WARNING) << "Did you forget to call VMCompiler::Lower?";
+    return;
+  }
   auto const &cached_funcs = context_.cached_funcs;
   if (cached_funcs.size() == 0) {
     return;
@@ -980,14 +998,6 @@
     }
   }
   exec_->lib = mod;
-  size_t primitive_index = 0;
-  for (auto cfunc : cached_funcs) {
-    if (cfunc->target->str() == "ext_dev") {
-      exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
-    } else {
-      exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
-    }
-  }
 }

 runtime::Module CreateVMCompiler() {
......
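The C++ restructuring above is what makes task extraction cheaper: the executable's global and primitive function maps are now filled during `Lower`, and kernel compilation is deferred to the separate `Codegen` step, which only warns if `Lower` was never called. A rough sketch of the resulting ordering contract as seen from the Python wrapper; `mod` is assumed to be a `relay.Module`:

```python
from tvm import relay

compiler = relay.vm.VMCompiler()

# Calling codegen() before lower() only triggers the new guard in VMCompiler::Codegen
# ("Did you forget to call VMCompiler::Lower?") and returns without doing any work.
compiler.codegen()

# After lower(), the bytecode and the global/primitive function maps are already in
# place, so AutoTVM's _lower() can stop here and skip kernel compilation entirely.
compiler.lower(mod, target="llvm")

# codegen() then only builds the kernel library and attaches it to the executable.
compiler.codegen()
exe = compiler.get_exec()
```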
@@ -91,10 +91,6 @@ class VMCompiler : public runtime::ModuleNode {
     return "VMCompiler";
   }

-  void InitVM() {
-    exec_ = make_object<Executable>();
-  }
-
   /*!
    * \brief Set the parameters
    *
@@ -104,16 +100,19 @@ class VMCompiler : public runtime::ModuleNode {
   void SetParam(const std::string& name, runtime::NDArray data_in);

   /*!
-   * \brief Compile functions in a Module
+   * \brief Lower the functions in a Module
    *
    * \param mod Relay Module
    * \param targets For heterogeneous compilation, it is a dictionary indicating context
            to target mapping. For homogeneous compilation, it is a build target.
    * \param target_host Host compilation target, if target is device.
    */
-  void Compile(Module mod,
-               const TargetsMap& targets,
-               const tvm::Target& target_host);
+  void Lower(Module mod,
+             const TargetsMap& targets,
+             const tvm::Target& target_host);
+
+  /*! \brief Generate the machine code for lowered functions. */
+  void Codegen();

  protected:
   /*!
@@ -130,8 +129,6 @@ class VMCompiler : public runtime::ModuleNode {
   void PopulateGlobalMap();

-  void LibraryCodegen();
-
  protected:
   /*! \brief Target devices. */
   TargetsMap targets_;
......
@@ -45,12 +45,20 @@ def test_task_extraction():
                                               params=params,
                                               ops=(relay.op.nn.conv2d,))
     assert len(tasks) == 12
+    tasks = autotvm.task.extract_from_program(mod, target=target,
+                                              params=params,
+                                              ops=(relay.op.nn.conv2d,))
+    assert len(tasks) == 12

     mod, params, _ = get_network('resnet-18', batch_size=1)
     tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                               params=params,
                                               ops=(relay.op.nn.dense,))
     assert len(tasks) == 1
+    tasks = autotvm.task.extract_from_program(mod, target=target,
+                                              params=params,
+                                              ops=(relay.op.nn.dense,))
+    assert len(tasks) == 1

     mod, params, _ = get_network('resnet-18', batch_size=1)
     mod_list.append(mod)
@@ -59,22 +67,26 @@ def test_task_extraction():
                                               params=params,
                                               ops=(relay.op.nn.conv2d, relay.op.nn.dense))
     assert len(tasks) == 13
+    tasks = autotvm.task.extract_from_program(mod, target=target,
+                                              params=params,
+                                              ops=(relay.op.nn.conv2d, relay.op.nn.dense))
+    assert len(tasks) == 13

     mod, params, _ = get_network('mobilenet', batch_size=1)
     mod_list.append(mod)
     params_list.append(params)
-    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
+    tasks = autotvm.task.extract_from_program(mod, target=target,
                                               params=params,
                                               ops=(relay.op.nn.conv2d, relay.op.nn.dense))
     assert len(tasks) == 20

     mod, params, _ = get_network('dcgan', batch_size=1)
-    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
+    tasks = autotvm.task.extract_from_program(mod, target=target,
                                               params=params,
                                               ops=(relay.op.nn.conv2d_transpose,))
     assert len(tasks) == 4

-    tasks = autotvm.task.extract_from_multiple_program([m['main'] for m in mod_list], params_list,
+    tasks = autotvm.task.extract_from_multiple_program(mod_list, params_list,
                                                        target=target,
                                                        ops=(relay.op.nn.conv2d,))
     assert len(tasks) == 31
......
@@ -572,4 +572,4 @@ def test_add_op_broadcast():

 if __name__ == "__main__":
-    pytest.main()
+    pytest.main([__file__])