Commit fc2713e5 by Wei Chen Committed by Zhi

[Relay][VM] Fix constant folding issue in VM compiler (#4077)

* [Relay][VM] Fix constant folding issue in VM compiler

1. allow pass params when compile a module
2. enhance profiler robustness

* remove dead code

* fix lint

* add get_params

* fix test

* don't pass params back

* remove get_params

* docs

* move compile function to api

* compile clashes with builtin name

* fix compilation error

* remove dead code
parent 4d875d1f
...@@ -14,33 +14,53 @@ ...@@ -14,33 +14,53 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name, redefined-builtin
""" """
The Relay Virtual Machine profiler. The Relay Virtual Machine profiler.
Provides extra APIs for profiling vm execution. Provides extra APIs for profiling vm execution.
""" """
import tvm
from . import vm, _vm from . import vm, _vm
def _update_target(target): def compile(mod, target=None, target_host=None, params=None):
target = target if target else tvm.target.current_target() """
if target is None: Parameters
raise ValueError("Target is not set in env or passed as argument.") ----------
mod : relay.Module
The Relay module to build.
tgts = {} target : str, :any:`tvm.target.Target`, or dict of str(i.e.
if isinstance(target, (str, tvm.target.Target)): device/context name) to str/tvm.target.Target, optional
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type) For heterogeneous compilation, it is a dictionary indicating context
tgts[dev_type] = tvm.target.create(target) to target mapping. For homogeneous compilation, it is a build target.
elif isinstance(target, dict):
for dev, tgt in target.items(): target_host : str or :any:`tvm.target.Target`, optional
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type) Host compilation target, if target is device.
tgts[dev_type] = tvm.target.create(tgt) When TVM compiles device specific program such as CUDA,
else: we also need host(CPU) side code to interact with the driver
raise TypeError("target is expected to be str, tvm.target.Target, " + to setup the dimensions and parameters correctly.
"or dict of str to str/tvm.target.Target, but received " + target_host is used to specify the host side codegen target.
"{}".format(type(target))) By default, llvm is used if it is enabled,
return tgts otherwise a stackvm intepreter is used.
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
Returns
-------
vm : VirtualMachineProfiler
The profile VM runtime.
"""
compiler = VMCompilerProfiler()
target = compiler.update_target(target)
target_host = compiler.update_target_host(target, target_host)
if params:
compiler.set_params(params)
tophub_context = compiler.tophub_context(target)
with tophub_context:
compiler._compile(mod, target, target_host)
return VirtualMachineProfiler(compiler._get_vm())
class VMCompilerProfiler(vm.VMCompiler): class VMCompilerProfiler(vm.VMCompiler):
"""Build Relay module to run on VM runtime.""" """Build Relay module to run on VM runtime."""
...@@ -49,36 +69,7 @@ class VMCompilerProfiler(vm.VMCompiler): ...@@ -49,36 +69,7 @@ class VMCompilerProfiler(vm.VMCompiler):
self.mod = _vm._VMCompilerProfiler() self.mod = _vm._VMCompilerProfiler()
self._compile = self.mod["compile"] self._compile = self.mod["compile"]
self._get_vm = self.mod["get_vm"] self._get_vm = self.mod["get_vm"]
self._set_params_func = self.mod["set_params"]
def compile(self, mod, target=None, target_host=None):
"""
Parameters
----------
mod : relay.Module
The Relay module to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
For heterogeneous compilation, it is a dictionary indicating context
to target mapping. For homogeneous compilation, it is a build target.
target_host : str or :any:`tvm.target.Target`, optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
to setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
Returns
-------
vm : VirtualMachineProfiler
The profile VM runtime.
"""
target = _update_target(target)
self._compile(mod, target, target_host)
return VirtualMachineProfiler(self._get_vm())
class VirtualMachineProfiler(vm.VirtualMachine): class VirtualMachineProfiler(vm.VirtualMachine):
"""Relay profile VM runtime.""" """Relay profile VM runtime."""
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name, redefined-builtin
""" """
The Relay Virtual Machine. The Relay Virtual Machine.
...@@ -25,30 +25,11 @@ import numpy as np ...@@ -25,30 +25,11 @@ import numpy as np
import tvm import tvm
from tvm import autotvm from tvm import autotvm
from tvm._ffi.runtime_ctypes import TVMByteArray from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm.relay import expr as _expr
from . import _vm from . import _vm
from . import vmobj as _obj from . import vmobj as _obj
from .interpreter import Executor from .interpreter import Executor
def _update_target(target):
target = target if target else tvm.target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")
tgts = {}
if isinstance(target, (str, tvm.target.Target)):
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
tgts[dev_type] = tvm.target.create(target)
elif isinstance(target, dict):
for dev, tgt in target.items():
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
tgts[dev_type] = tvm.target.create(tgt)
else:
raise TypeError("target is expected to be str, tvm.target.Target, " +
"or dict of str to str/tvm.target.Target, but received " +
"{}".format(type(target)))
return tgts
def _convert(arg, cargs): def _convert(arg, cargs):
if isinstance(arg, (np.ndarray, tvm.nd.NDArray)): if isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
cargs.append(_obj.tensor_object(arg)) cargs.append(_obj.tensor_object(arg))
...@@ -144,40 +125,85 @@ class VirtualMachine(object): ...@@ -144,40 +125,85 @@ class VirtualMachine(object):
return self.mod return self.mod
def compile(mod, target=None, target_host=None, params=None):
"""
Parameters
----------
mod : relay.Module
The Relay module to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
For heterogeneous compilation, it is a dictionary indicating context
to target mapping. For homogeneous compilation, it is a build target.
target_host : str or :any:`tvm.target.Target`, optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
to setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
Returns
-------
vm : VirtualMachine
The VM runtime.
"""
compiler = VMCompiler()
target = compiler.update_target(target)
target_host = compiler.update_target_host(target, target_host)
if params:
compiler.set_params(params)
tophub_context = compiler.tophub_context(target)
with tophub_context:
compiler._compile(mod, target, target_host)
return VirtualMachine(compiler._get_vm())
class VMCompiler(object): class VMCompiler(object):
"""Build Relay module to run on VM runtime.""" """Build Relay module to run on VM runtime."""
def __init__(self): def __init__(self):
self.mod = _vm._VMCompiler() self.mod = _vm._VMCompiler()
self._compile = self.mod["compile"] self._compile = self.mod["compile"]
self._get_vm = self.mod["get_vm"] self._get_vm = self.mod["get_vm"]
self._set_params_func = self.mod["set_params"]
def set_params(self, params):
"""Set constant parameters for the model"""
inputs = {}
for name, param in params.items():
if isinstance(param, np.ndarray):
param = _nd.array(param)
inputs[name] = _expr.const(param)
self._set_params_func(inputs)
def update_target(self, target):
"""Update target"""
target = target if target else tvm.target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")
tgts = {}
if isinstance(target, (str, tvm.target.Target)):
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
tgts[dev_type] = tvm.target.create(target)
elif isinstance(target, dict):
for dev, tgt in target.items():
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
tgts[dev_type] = tvm.target.create(tgt)
else:
raise TypeError("target is expected to be str, tvm.target.Target, " +
"or dict of str to str/tvm.target.Target, but received " +
"{}".format(type(target)))
return tgts
def compile(self, mod, target=None, target_host=None): def update_target_host(self, target, target_host):
""" """Update target host"""
Parameters
----------
mod : relay.Module
The Relay module to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
For heterogeneous compilation, it is a dictionary indicating context
to target mapping. For homogeneous compilation, it is a build target.
target_host : str or :any:`tvm.target.Target`, optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
to setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
Returns
-------
vm : VirtualMachine
The VM runtime.
"""
target = _update_target(target)
target_host = None if target_host == "" else target_host target_host = None if target_host == "" else target_host
if not target_host: if not target_host:
for device_type, tgt in target.items(): for device_type, tgt in target.items():
...@@ -186,19 +212,16 @@ class VMCompiler(object): ...@@ -186,19 +212,16 @@ class VMCompiler(object):
break break
if not target_host: if not target_host:
target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm" target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
target_host = tvm.target.create(target_host) return tvm.target.create(target_host)
def tophub_context(self, target):
# If current dispatch context is fallback context (the default root context), # If current dispatch context is fallback context (the default root context),
# then load pre-tuned parameters from TopHub # then load pre-tuned parameters from TopHub
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext): if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
tophub_context = autotvm.tophub.context(list(target.values())) tophub_context = autotvm.tophub.context(list(target.values()))
else: else:
tophub_context = autotvm.util.EmptyContext() tophub_context = autotvm.util.EmptyContext()
return tophub_context
with tophub_context:
self._compile(mod, target, target_host)
return VirtualMachine(self._get_vm())
class VMExecutor(Executor): class VMExecutor(Executor):
""" """
...@@ -226,8 +249,7 @@ class VMExecutor(Executor): ...@@ -226,8 +249,7 @@ class VMExecutor(Executor):
self.mod = mod self.mod = mod
self.ctx = ctx self.ctx = ctx
self.target = target self.target = target
compiler = VMCompiler() self.vm = compile(mod, target)
self.vm = compiler.compile(mod, target)
self.vm.init(ctx) self.vm.init(ctx)
def _make_executor(self, expr=None): def _make_executor(self, expr=None):
......
...@@ -780,23 +780,73 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, ...@@ -780,23 +780,73 @@ PackedFunc VMCompiler::GetFunction(const std::string& name,
if (name == "compile") { if (name == "compile") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 3); CHECK_EQ(args.num_args, 3);
this->Compile(args[0], args[1], args[2]); Module mod = args[0];
this->Compile(mod, args[1], args[2]);
}); });
} else if (name == "get_vm") { } else if (name == "get_vm") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = runtime::Module(vm_); *rv = runtime::Module(vm_);
}); });
} else if (name == "set_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Map<std::string, Constant> params = args[0];
for (const auto& kv : params) {
this->SetParam(kv.first, kv.second->data);
}
});
} else { } else {
LOG(FATAL) << "Unknown packed function: " << name; LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
} }
} }
void VMCompiler::Compile(const Module& mod_ref, void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) {
params_[name] = data_in;
}
relay::Function VMCompiler::BindParamsByName(
relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, NodeHash, NodeEqual> repeat_var;
for (auto arg : func->params) {
const auto &name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(arg);
} else {
name_dict[name] = arg;
}
}
std::unordered_map<relay::Var, Expr, NodeHash, NodeEqual> bind_dict;
for (auto &kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
auto arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
}
bind_dict[arg] = ConstantNode::make(kv.second);
}
Expr bound_expr = relay::Bind(func, bind_dict);
Function ret = Downcast<Function>(bound_expr);
CHECK(ret.defined())
<< "The returning type is expected to be a Relay Function."
<< "\n";
return ret;
}
void VMCompiler::Compile(Module mod,
const TargetsMap& targets, const TargetsMap& targets,
const tvm::Target& target_host) { const tvm::Target& target_host) {
CHECK_EQ(targets.size(), 1) CHECK_EQ(targets.size(), 1)
<< "Currently VM compiler doesn't support heterogeneous compilation"; << "Currently VM compiler doesn't support heterogeneous compilation";
if (params_.size()) {
auto f = BindParamsByName(mod->Lookup("main"), params_);
auto gvar = mod->GetGlobalVar("main");
mod->Add(gvar, f);
}
InitVM(); InitVM();
targets_ = targets; targets_ = targets;
...@@ -804,7 +854,7 @@ void VMCompiler::Compile(const Module& mod_ref, ...@@ -804,7 +854,7 @@ void VMCompiler::Compile(const Module& mod_ref,
// Run some optimizations first, this code should // Run some optimizations first, this code should
// be moved to pass manager. // be moved to pass manager.
context_.module = OptimizeModule(mod_ref, targets_); context_.module = OptimizeModule(mod, targets_);
// Populate the global map. // Populate the global map.
// //
......
...@@ -100,11 +100,37 @@ class VMCompiler : public runtime::ModuleNode { ...@@ -100,11 +100,37 @@ class VMCompiler : public runtime::ModuleNode {
vm_ = std::make_shared<VirtualMachine>(); vm_ = std::make_shared<VirtualMachine>();
} }
void Compile(const Module& mod_ref, /*!
* \brief Set the parameters
*
* \param name name of parameter
* \param data_in input DLTensor
*/
void SetParam(const std::string& name, runtime::NDArray data_in);
/*!
* \brief Compile functions in a Module
*
* \param mod Relay Module
* \param targets For heterogeneous compilation, it is a dictionary indicating context
to target mapping. For homogeneous compilation, it is a build target.
* \param target_host Host compilation target, if target is device.
*/
void Compile(Module mod,
const TargetsMap& targets, const TargetsMap& targets,
const tvm::Target& target_host); const tvm::Target& target_host);
protected: protected:
/*!
* \brief Bind params to function by using name
* \param func Relay function
* \param params params dict
* \return relay::Function
*/
relay::Function BindParamsByName(
relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params);
Module OptimizeModule(const Module& mod, const TargetsMap& targets); Module OptimizeModule(const Module& mod, const TargetsMap& targets);
void PopulateGlobalMap(); void PopulateGlobalMap();
...@@ -120,6 +146,8 @@ class VMCompiler : public runtime::ModuleNode { ...@@ -120,6 +146,8 @@ class VMCompiler : public runtime::ModuleNode {
VMCompilerContext context_; VMCompilerContext context_;
/*! \brief Compiled virtual machine. */ /*! \brief Compiled virtual machine. */
std::shared_ptr<VirtualMachine> vm_; std::shared_ptr<VirtualMachine> vm_;
/*! \brief parameters */
std::unordered_map<std::string, runtime::NDArray> params_;
}; };
} // namespace vm } // namespace vm
......
...@@ -98,6 +98,11 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, ...@@ -98,6 +98,11 @@ void VirtualMachineDebug::InvokePacked(Index packed_index,
Index output_size, Index output_size,
const std::vector<Object>& args) { const std::vector<Object>& args) {
auto ctx = VirtualMachine::GetParamsContext(); auto ctx = VirtualMachine::GetParamsContext();
// warmup
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
args);
TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
auto op_begin = std::chrono::high_resolution_clock::now(); auto op_begin = std::chrono::high_resolution_clock::now();
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
args); args);
......
...@@ -47,15 +47,13 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): ...@@ -47,15 +47,13 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
if isinstance(f, relay.Expr): if isinstance(f, relay.Expr):
mod = relay.Module() mod = relay.Module()
mod["main"] = f mod["main"] = f
compiler = relay.vm.VMCompiler() vm = relay.vm.compile(mod, target)
vm = compiler.compile(mod, target)
vm.init(tvm.cpu()) vm.init(tvm.cpu())
return vm.invoke("main", *args) return vm.invoke("main", *args)
else: else:
assert isinstance(f, relay.Module), "expected expression or module" assert isinstance(f, relay.Module), "expected expression or module"
mod = f mod = f
compiler = relay.vm.VMCompiler() vm = relay.vm.compile(mod, target)
vm = compiler.compile(mod, target)
vm.init(tvm.cpu()) vm.init(tvm.cpu())
ret = vm.invoke("main", *args) ret = vm.invoke("main", *args)
return ret return ret
...@@ -582,8 +580,7 @@ def test_set_params(): ...@@ -582,8 +580,7 @@ def test_set_params():
b = relay.var('b', shape=(6,)) b = relay.var('b', shape=(6,))
y = relay.nn.bias_add(relay.nn.dense(x, w), b) y = relay.nn.bias_add(relay.nn.dense(x, w), b)
mod["main"] = relay.Function([x, w, b], y) mod["main"] = relay.Function([x, w, b], y)
compiler = relay.vm.VMCompiler() vm = relay.vm.compile(mod, 'llvm')
vm = compiler.compile(mod, 'llvm')
vm.init(tvm.cpu()) vm.init(tvm.cpu())
x_np = np.random.uniform(size=(10, 5)).astype('float32') x_np = np.random.uniform(size=(10, 5)).astype('float32')
......
...@@ -28,18 +28,16 @@ from tvm.relay.prelude import Prelude ...@@ -28,18 +28,16 @@ from tvm.relay.prelude import Prelude
from tvm.contrib import util from tvm.contrib import util
from tvm.relay import testing from tvm.relay import testing
def create_vm(f, ctx=tvm.cpu(), target="llvm"): def create_vm(f, ctx=tvm.cpu(), target="llvm", params=None):
if isinstance(f, relay.Expr): if isinstance(f, relay.Expr):
mod = relay.Module() mod = relay.Module()
mod["main"] = f mod["main"] = f
compiler = relay.vm.VMCompiler() vm = _vm.compile(mod, target=target, params=params)
vm = compiler.compile(mod, target)
vm.init(ctx) vm.init(ctx)
return vm return vm
else: else:
assert isinstance(f, relay.Module), "expected mod as relay.Module" assert isinstance(f, relay.Module), "expected mod as relay.Module"
compiler = relay.vm.VMCompiler() vm = _vm.compile(f, target=target, params=params)
vm = compiler.compile(f, target)
vm.init(ctx) vm.init(ctx)
return vm return vm
...@@ -61,7 +59,7 @@ def run_network(mod, ...@@ -61,7 +59,7 @@ def run_network(mod,
return result.asnumpy().astype(dtype) return result.asnumpy().astype(dtype)
def get_serialized_output(mod, data, params, target, ctx, dtype='float32'): def get_serialized_output(mod, data, params, target, ctx, dtype='float32'):
vm = create_vm(mod, ctx, target) vm = create_vm(mod, ctx, target, params=params)
ser = serializer.Serializer(vm) ser = serializer.Serializer(vm)
code, lib = ser.serialize() code, lib = ser.serialize()
deser = deserializer.Deserializer(code, lib) deser = deserializer.Deserializer(code, lib)
......
...@@ -22,13 +22,11 @@ import pytest ...@@ -22,13 +22,11 @@ import pytest
from tvm import relay from tvm import relay
from tvm.relay.testing import resnet from tvm.relay.testing import resnet
@pytest.mark.skip
def test_basic(): def test_basic():
mod, params = resnet.get_workload() mod, params = resnet.get_workload()
compiler = relay.profiler_vm.VMCompilerProfiler()
target = 'llvm' target = 'llvm'
ctx = tvm.cpu() ctx = tvm.cpu()
vm = compiler.compile(mod, target) vm = relay.profiler_vm.compile(mod, target)
vm.init(ctx) vm.init(ctx)
vm.load_params(params) vm.load_params(params)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment