Commit baae28b2 by Haichen Shen, committed by Zhi

[Autotvm] Use VM compile to extract autotvm tasks (#4328)

* [AutoTVM] Use vm compile in extracting task from relay

* update

* restructure vm compiler to reduce task extraction time

* x

* fix

* update doc

* update doc

* lint
parent 2077cd57
......@@ -28,3 +28,6 @@ tvm.relay.backend
.. automodule:: tvm.relay.backend.graph_runtime_codegen
:members:
.. automodule:: tvm.relay.backend.vm
:members:
......@@ -32,7 +32,7 @@ logger = logging.getLogger('autotvm')
# TODO(moreau89) find a more elegant way to lower for VTAs
def _lower(func,
def _lower(mod,
target,
params):
""" Helper to lower VTA properly.
......@@ -45,16 +45,16 @@ def _lower(func,
with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
import vta
with vta.build_config():
mod, _ = relay.optimize(func, target, params)
mod, _ = relay.optimize(mod, target, params)
grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
return grc.codegen(mod["main"])
grc.codegen(mod["main"])
# default case
mod, _ = relay.optimize(func, target, params)
grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
return grc.codegen(mod["main"])
compiler = relay.vm.VMCompiler()
compiler.set_params(params)
compiler.lower(mod, target=target)
def extract_from_program(func, params, ops, target, target_host=None,
def extract_from_program(mod, params, ops, target, target_host=None,
template_keys=None):
""" Extract tuning tasks from a relay program.
......@@ -62,8 +62,8 @@ def extract_from_program(func, params, ops, target, target_host=None,
Parameters
----------
func: relay.expr.Function
The func to tune
mod: relay.module.Module or relay.expr.Function
The module or function to tune
params: dict of str to numpy array
The associated parameters of the program
ops: List of relay op
......@@ -81,11 +81,11 @@ def extract_from_program(func, params, ops, target, target_host=None,
task: Array of autotvm.task.Task
collected tasks
"""
return extract_from_multiple_program([func], [params], ops, target, target_host,
template_keys=template_keys)
return extract_from_multiple_program([mod], [params], ops, target, target_host,
template_keys)
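With this change the entry point accepts either a relay.Module or a bare relay Function. A minimal usage sketch, assuming `mod` and `params` come from a Relay model importer and an llvm target is available (mirrors the updated test below):

    # Sketch: extract conv2d tuning tasks from a Relay module (mod["main"] also works).
    from tvm import autotvm, relay

    tasks = autotvm.task.extract_from_program(
        mod, params=params, target="llvm",
        ops=(relay.op.nn.conv2d,))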
def extract_from_multiple_program(funcs, params, ops, target, target_host=None,
def extract_from_multiple_program(mods, params, ops, target, target_host=None,
template_keys=None):
""" Extract tuning tasks from multiple relay programs.
......@@ -94,8 +94,8 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None,
Parameters
----------
funcs: List of relay.expr.Function
The list of functions to tune
mods: List[relay.module.Module] or List[relay.expr.Function]
The list of modules or functions to tune
params: List of dict of str to numpy array
The associated parameters of the programs
ops: List of relay op
......@@ -145,10 +145,13 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None,
old_state = logger.disabled
logger.disabled = True
for func, param in zip(funcs, params):
for mod, param in zip(mods, params):
if isinstance(mod, relay.expr.Function):
mod = relay.Module.from_expr(mod)
assert isinstance(mod, relay.module.Module), \
"only support relay Module or Function to be tuned"
relay.backend.compile_engine.get().clear()
# wrap build call in thread to avoid multiprocessing problems
mod = relay.Module.from_expr(func)
build_thread = threading.Thread(target=_lower,
args=(mod, target, param))
build_thread.start()
......
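The multi-program variant follows the same convention: each list entry may be a relay.Module or a relay Function, paired with its own params dict. A hedged usage sketch, mirroring the updated test at the end of this diff (`mod_list` and `params_list` are assumed to be populated by the caller):

    # Sketch: extract tasks from several networks in one call.
    from tvm import autotvm, relay

    tasks = autotvm.task.extract_from_multiple_program(
        mod_list, params_list, target="llvm",
        ops=(relay.op.nn.conv2d,))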
......@@ -363,7 +363,8 @@ class VirtualMachine(object):
def compile(mod, target=None, target_host=None, params=None):
"""
"""Compile the module to VM executable. A helper function for VMCompiler.
Parameters
----------
mod : relay.Module
......@@ -393,26 +394,31 @@ def compile(mod, target=None, target_host=None, params=None):
The VM executable that contains both library code and bytecode.
"""
compiler = VMCompiler()
target = compiler.update_target(target)
target_host = compiler.update_target_host(target, target_host)
if params:
compiler.set_params(params)
tophub_context = compiler.tophub_context(target)
with tophub_context:
compiler._compile(mod, target, target_host)
return Executable(compiler._get_exec())
compiler.lower(mod, target, target_host)
compiler.codegen()
return compiler.get_exec()
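The convenience helper now chains the three stages (lower, codegen, get_exec) explicitly. A hedged sketch of calling it, assuming `mod` and `params` come from a model importer:

    # Sketch: one-shot VM compilation via the helper; exe bundles library code and bytecode.
    from tvm import relay

    exe = relay.vm.compile(mod, target="llvm", params=params)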
class VMCompiler(object):
"""Build Relay module to run on VM runtime."""
"""Compiler that compiles Relay module to VM executable."""
def __init__(self):
self.mod = _vm._VMCompiler()
self._compile = self.mod["compile"]
self._lower = self.mod["lower"]
self._codegen = self.mod["codegen"]
self._get_exec = self.mod["get_executable"]
self._set_params_func = self.mod["set_params"]
def set_params(self, params):
"""Set constant parameters for the model"""
"""Set constant parameters for the model.
Parameters
----------
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
"""
inputs = {}
for name, param in params.items():
if isinstance(param, np.ndarray):
......@@ -420,8 +426,50 @@ class VMCompiler(object):
inputs[name] = _expr.const(param)
self._set_params_func(inputs)
def update_target(self, target):
"""Update target"""
def lower(self, mod, target=None, target_host=None):
"""Lower the module to VM bytecode.
Parameters
----------
mod : relay.Module
The Relay module to build.
target : str, :any:`tvm.target.Target`, or dict of str (i.e.
device/context name) to str/tvm.target.Target, optional
For heterogeneous compilation, it is a dictionary indicating context
to target mapping. For homogeneous compilation, it is a build target.
target_host : str or :any:`tvm.target.Target`, optional
Host compilation target, if target is device.
When TVM compiles device-specific programs such as CUDA,
we also need host (CPU) side code to interact with the driver
to set up the dimensions and parameters correctly.
target_host is used to specify the host-side codegen target.
By default, llvm is used if it is enabled;
otherwise a stackvm interpreter is used.
"""
target = self._update_target(target)
target_host = self._update_target_host(target, target_host)
tophub_context = self._tophub_context(target)
with tophub_context:
self._lower(mod, target, target_host)
def codegen(self):
"""Generate the kernel library."""
self._codegen()
def get_exec(self):
"""Get the VM executable.
Returns
-------
exec : Executable
The VM executable that contains both library code and bytecode.
"""
return Executable(self._get_exec())
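For callers that need finer control, the refactored class exposes the stages separately; autotvm's task extraction stops after lower(), which is where the reduction in extraction time comes from. A minimal sketch of the staged flow, assuming a Relay module `mod` and a `params` dict:

    # Sketch: staged VM compilation with the refactored VMCompiler.
    from tvm import relay

    compiler = relay.vm.VMCompiler()
    compiler.set_params(params)         # bind constant parameters for folding
    compiler.lower(mod, target="llvm")  # produce VM bytecode (enough for task extraction)
    compiler.codegen()                  # build the kernel library
    exe = compiler.get_exec()           # Executable with both library code and bytecode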
def _update_target(self, target):
"""Update target."""
target = target if target else tvm.target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")
......@@ -439,8 +487,8 @@ class VMCompiler(object):
"{}".format(type(target)))
return tgts
def update_target_host(self, target, target_host):
"""Update target host"""
def _update_target_host(self, target, target_host):
"""Update target host."""
target_host = None if target_host == "" else target_host
if not target_host:
for device_type, tgt in target.items():
......@@ -449,9 +497,12 @@ class VMCompiler(object):
break
if not target_host:
target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
return tvm.target.create(target_host)
if isinstance(target_host, str):
target_host = tvm.target.create(target_host)
return target_host
def tophub_context(self, target):
def _tophub_context(self, target):
"""Get the autotvm context."""
# If current dispatch context is fallback context (the default root context),
# then load pre-tuned parameters from TopHub
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
......
......@@ -743,11 +743,16 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
PackedFunc VMCompiler::GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self) {
if (name == "compile") {
if (name == "lower") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 3);
Module mod = args[0];
this->Compile(mod, args[1], args[2]);
this->Lower(mod, args[1], args[2]);
});
} else if (name == "codegen") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 0);
this->Codegen();
});
} else if (name == "get_executable") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
......@@ -802,9 +807,9 @@ relay::Function VMCompiler::BindParamsByName(
return ret;
}
void VMCompiler::Compile(Module mod,
const TargetsMap& targets,
const tvm::Target& target_host) {
void VMCompiler::Lower(Module mod,
const TargetsMap& targets,
const tvm::Target& target_host) {
CHECK_EQ(targets.size(), 1)
<< "Currently VM compiler doesn't support heterogeneous compilation";
if (params_.size()) {
......@@ -813,7 +818,7 @@ void VMCompiler::Compile(Module mod,
mod->Add(gvar, f);
}
InitVM();
exec_ = make_object<Executable>();
targets_ = targets;
target_host_ = target_host;
......@@ -852,11 +857,20 @@ void VMCompiler::Compile(Module mod,
exec_->constants.push_back(vm::Tensor(data));
}
LibraryCodegen();
// update global function map
for (auto gv : context_.global_map) {
exec_->global_map.insert({gv.first->name_hint, gv.second});
}
// update primitive function map
size_t primitive_index = 0;
for (const auto& cfunc : context_.cached_funcs) {
if (cfunc->target->str() == "ext_dev") {
exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
} else {
exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
}
}
}
Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) {
......@@ -942,7 +956,11 @@ void VMCompiler::PopulateGlobalMap() {
}
}
void VMCompiler::LibraryCodegen() {
void VMCompiler::Codegen() {
if (!context_.module.defined()) {
LOG(WARNING) << "Did you forget to call VMCompiler::Lower?";
return;
}
auto const &cached_funcs = context_.cached_funcs;
if (cached_funcs.size() == 0) {
return;
......@@ -980,14 +998,6 @@ void VMCompiler::LibraryCodegen() {
}
}
exec_->lib = mod;
size_t primitive_index = 0;
for (auto cfunc : cached_funcs) {
if (cfunc->target->str() == "ext_dev") {
exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
} else {
exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
}
}
}
runtime::Module CreateVMCompiler() {
......
......@@ -91,10 +91,6 @@ class VMCompiler : public runtime::ModuleNode {
return "VMCompiler";
}
void InitVM() {
exec_ = make_object<Executable>();
}
/*!
* \brief Set the parameters
*
......@@ -104,16 +100,19 @@ class VMCompiler : public runtime::ModuleNode {
void SetParam(const std::string& name, runtime::NDArray data_in);
/*!
* \brief Compile functions in a Module
* \brief Lower the functions in a Module
*
* \param mod Relay Module
* \param targets For heterogeneous compilation, it is a dictionary indicating context
to target mapping. For homogeneous compilation, it is a build target.
* \param target_host Host compilation target, if target is device.
*/
void Compile(Module mod,
const TargetsMap& targets,
const tvm::Target& target_host);
void Lower(Module mod,
const TargetsMap& targets,
const tvm::Target& target_host);
/*! \brief Generate the machine code for lowered functions. */
void Codegen();
protected:
/*!
......@@ -130,8 +129,6 @@ class VMCompiler : public runtime::ModuleNode {
void PopulateGlobalMap();
void LibraryCodegen();
protected:
/*! \brief Target devices. */
TargetsMap targets_;
......
......@@ -45,12 +45,20 @@ def test_task_extraction():
params=params,
ops=(relay.op.nn.conv2d,))
assert len(tasks) == 12
tasks = autotvm.task.extract_from_program(mod, target=target,
params=params,
ops=(relay.op.nn.conv2d,))
assert len(tasks) == 12
mod, params, _ = get_network('resnet-18', batch_size=1)
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params,
ops=(relay.op.nn.dense,))
assert len(tasks) == 1
tasks = autotvm.task.extract_from_program(mod, target=target,
params=params,
ops=(relay.op.nn.dense,))
assert len(tasks) == 1
mod, params, _ = get_network('resnet-18', batch_size=1)
mod_list.append(mod)
......@@ -59,22 +67,26 @@ def test_task_extraction():
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense))
assert len(tasks) == 13
tasks = autotvm.task.extract_from_program(mod, target=target,
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense))
assert len(tasks) == 13
mod, params, _ = get_network('mobilenet', batch_size=1)
mod_list.append(mod)
params_list.append(params)
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
tasks = autotvm.task.extract_from_program(mod, target=target,
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense))
assert len(tasks) == 20
mod, params, _ = get_network('dcgan', batch_size=1)
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
tasks = autotvm.task.extract_from_program(mod, target=target,
params=params,
ops=(relay.op.nn.conv2d_transpose,))
assert len(tasks) == 4
tasks = autotvm.task.extract_from_multiple_program([m['main'] for m in mod_list], params_list,
tasks = autotvm.task.extract_from_multiple_program(mod_list, params_list,
target=target,
ops=(relay.op.nn.conv2d,))
assert len(tasks) == 31
......
......@@ -572,4 +572,4 @@ def test_add_op_broadcast():
if __name__ == "__main__":
pytest.main()
pytest.main([__file__])