Commit fc2713e5 by Wei Chen, committed by Zhi

[Relay][VM] Fix constant folding issue in VM compiler (#4077)

* [Relay][VM] Fix constant folding issue in VM compiler

1. allow passing params when compiling a module
2. enhance profiler robustness

* remove dead code

* fix lint

* add get_params

* fix test

* don't pass params back

* remove get_params

* docs

* move compile function to api

* compile clashes with builtin name

* fix compilation error

* remove dead code
parent 4d875d1f
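
For context, the headline change replaces the `VMCompiler().compile(...)` method with a free `relay.vm.compile` function that accepts constant parameters. A minimal usage sketch (the shapes and the `w` parameter name are illustrative, mirroring `test_set_params` further down in this diff):

```python
import numpy as np
import tvm
from tvm import relay

x = relay.var('x', shape=(10, 5))
w = relay.var('w', shape=(6, 5))
mod = relay.Module()
mod["main"] = relay.Function([x, w], relay.nn.dense(x, w))

# 'w' never changes at inference time, so it is bound as a constant
# and folded during compilation instead of being passed at runtime.
params = {'w': np.random.uniform(size=(6, 5)).astype('float32')}
vm = relay.vm.compile(mod, target='llvm', params=params)
vm.init(tvm.cpu())
res = vm.invoke("main", np.random.uniform(size=(10, 5)).astype('float32'))
```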
......@@ -14,43 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name, redefined-builtin
"""
The Relay Virtual Machine profiler.
Provides extra APIs for profiling vm execution.
"""
import tvm
from . import vm, _vm
def _update_target(target):
target = target if target else tvm.target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")
tgts = {}
if isinstance(target, (str, tvm.target.Target)):
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
tgts[dev_type] = tvm.target.create(target)
elif isinstance(target, dict):
for dev, tgt in target.items():
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
tgts[dev_type] = tvm.target.create(tgt)
else:
raise TypeError("target is expected to be str, tvm.target.Target, " +
"or dict of str to str/tvm.target.Target, but received " +
"{}".format(type(target)))
return tgts
class VMCompilerProfiler(vm.VMCompiler):
"""Build Relay module to run on VM runtime."""
def __init__(self):
super().__init__()
self.mod = _vm._VMCompilerProfiler()
self._compile = self.mod["compile"]
self._get_vm = self.mod["get_vm"]
def compile(self, mod, target=None, target_host=None):
def compile(mod, target=None, target_host=None, params=None):
"""
Parameters
----------
......@@ -71,14 +43,33 @@ class VMCompilerProfiler(vm.VMCompiler):
By default, llvm is used if it is enabled,
otherwise a stackvm interpreter is used.
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
Returns
-------
vm : VirtualMachineProfiler
The profile VM runtime.
"""
target = _update_target(target)
self._compile(mod, target, target_host)
return VirtualMachineProfiler(self._get_vm())
compiler = VMCompilerProfiler()
target = compiler.update_target(target)
target_host = compiler.update_target_host(target, target_host)
if params:
compiler.set_params(params)
tophub_context = compiler.tophub_context(target)
with tophub_context:
compiler._compile(mod, target, target_host)
return VirtualMachineProfiler(compiler._get_vm())
class VMCompilerProfiler(vm.VMCompiler):
"""Build Relay module to run on VM runtime."""
def __init__(self):
super().__init__()
self.mod = _vm._VMCompilerProfiler()
self._compile = self.mod["compile"]
self._get_vm = self.mod["get_vm"]
self._set_params_func = self.mod["set_params"]
class VirtualMachineProfiler(vm.VirtualMachine):
"""Relay profile VM runtime."""
......
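
The profiler variant gets the same free entry point; a hedged sketch mirroring the updated `test_basic` at the end of this diff:

```python
import tvm
from tvm import relay
from tvm.relay.testing import resnet

mod, params = resnet.get_workload()
vm = relay.profiler_vm.compile(mod, target='llvm')
vm.init(tvm.cpu())
vm.load_params(params)
```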
......@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name, redefined-builtin
"""
The Relay Virtual Machine.
......@@ -25,30 +25,11 @@ import numpy as np
import tvm
from tvm import autotvm
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm.relay import expr as _expr
from . import _vm
from . import vmobj as _obj
from .interpreter import Executor
def _update_target(target):
target = target if target else tvm.target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")
tgts = {}
if isinstance(target, (str, tvm.target.Target)):
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
tgts[dev_type] = tvm.target.create(target)
elif isinstance(target, dict):
for dev, tgt in target.items():
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
tgts[dev_type] = tvm.target.create(tgt)
else:
raise TypeError("target is expected to be str, tvm.target.Target, " +
"or dict of str to str/tvm.target.Target, but received " +
"{}".format(type(target)))
return tgts
def _convert(arg, cargs):
if isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
cargs.append(_obj.tensor_object(arg))
......@@ -144,14 +125,7 @@ class VirtualMachine(object):
return self.mod
class VMCompiler(object):
"""Build Relay module to run on VM runtime."""
def __init__(self):
self.mod = _vm._VMCompiler()
self._compile = self.mod["compile"]
self._get_vm = self.mod["get_vm"]
def compile(self, mod, target=None, target_host=None):
def compile(mod, target=None, target_host=None, params=None):
"""
Parameters
----------
......@@ -172,12 +146,64 @@ class VMCompiler(object):
By default, llvm is used if it is enabled,
otherwise a stackvm interpreter is used.
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
Returns
-------
vm : VirtualMachine
The VM runtime.
"""
target = _update_target(target)
compiler = VMCompiler()
target = compiler.update_target(target)
target_host = compiler.update_target_host(target, target_host)
if params:
compiler.set_params(params)
tophub_context = compiler.tophub_context(target)
with tophub_context:
compiler._compile(mod, target, target_host)
return VirtualMachine(compiler._get_vm())
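
The free function is thin sugar over the stateful `VMCompiler`; a manual equivalent, assuming the `mod` and `params` from the first sketch:

```python
from tvm import relay

compiler = relay.vm.VMCompiler()
target = compiler.update_target('llvm')
target_host = compiler.update_target_host(target, None)
if params:
    compiler.set_params(params)
with compiler.tophub_context(target):
    compiler._compile(mod, target, target_host)
vm = relay.vm.VirtualMachine(compiler._get_vm())
```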
class VMCompiler(object):
"""Build Relay module to run on VM runtime."""
def __init__(self):
self.mod = _vm._VMCompiler()
self._compile = self.mod["compile"]
self._get_vm = self.mod["get_vm"]
self._set_params_func = self.mod["set_params"]
def set_params(self, params):
"""Set constant parameters for the model"""
inputs = {}
for name, param in params.items():
if isinstance(param, np.ndarray):
param = _nd.array(param)
inputs[name] = _expr.const(param)
self._set_params_func(inputs)
def update_target(self, target):
"""Update target"""
target = target if target else tvm.target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")
tgts = {}
if isinstance(target, (str, tvm.target.Target)):
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
tgts[dev_type] = tvm.target.create(target)
elif isinstance(target, dict):
for dev, tgt in target.items():
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
tgts[dev_type] = tvm.target.create(tgt)
else:
raise TypeError("target is expected to be str, tvm.target.Target, " +
"or dict of str to str/tvm.target.Target, but received " +
"{}".format(type(target)))
return tgts
def update_target_host(self, target, target_host):
"""Update target host"""
target_host = None if target_host == "" else target_host
if not target_host:
for device_type, tgt in target.items():
......@@ -186,19 +212,16 @@ class VMCompiler(object):
break
if not target_host:
target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
target_host = tvm.target.create(target_host)
return tvm.target.create(target_host)
def tophub_context(self, target):
# If current dispatch context is fallback context (the default root context),
# then load pre-tuned parameters from TopHub
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
tophub_context = autotvm.tophub.context(list(target.values()))
else:
tophub_context = autotvm.util.EmptyContext()
with tophub_context:
self._compile(mod, target, target_host)
return VirtualMachine(self._get_vm())
return tophub_context
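
A hedged illustration of the refactored helpers in isolation (target strings are illustrative; note the `CHECK` in `compiler.cc` below still rejects heterogeneous modules at compile time):

```python
from tvm import relay

compiler = relay.vm.VMCompiler()

# A plain string or Target is normalized into {device_type: Target}.
tgts = compiler.update_target('llvm')

# A dict maps device names to per-device targets.
tgts = compiler.update_target({'cpu': 'llvm', 'gpu': 'cuda'})

# With no explicit host, update_target_host falls back to a CPU target
# (or llvm/stackvm) and now *returns* the created Target rather than
# assigning it in place, per the one-line change above.
tgt_host = compiler.update_target_host(tgts, None)
```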
class VMExecutor(Executor):
"""
......@@ -226,8 +249,7 @@ class VMExecutor(Executor):
self.mod = mod
self.ctx = ctx
self.target = target
compiler = VMCompiler()
self.vm = compiler.compile(mod, target)
self.vm = compile(mod, target)
self.vm.init(ctx)
def _make_executor(self, expr=None):
......
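
`VMExecutor` is what backs `relay.create_executor("vm", ...)`, so that path now routes through the free `compile` as well; a usage sketch (the input array is hypothetical, reusing the first sketch's `mod`):

```python
import numpy as np
import tvm
from tvm import relay

ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
f = ex.evaluate()
out = f(np.random.uniform(size=(10, 5)).astype('float32'))
```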
......@@ -780,23 +780,73 @@ PackedFunc VMCompiler::GetFunction(const std::string& name,
if (name == "compile") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 3);
this->Compile(args[0], args[1], args[2]);
Module mod = args[0];
this->Compile(mod, args[1], args[2]);
});
} else if (name == "get_vm") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = runtime::Module(vm_);
});
} else if (name == "set_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Map<std::string, Constant> params = args[0];
for (const auto& kv : params) {
this->SetParam(kv.first, kv.second->data);
}
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
}
}
void VMCompiler::Compile(const Module& mod_ref,
void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) {
params_[name] = data_in;
}
relay::Function VMCompiler::BindParamsByName(
relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, NodeHash, NodeEqual> repeat_var;
for (auto arg : func->params) {
const auto &name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(arg);
} else {
name_dict[name] = arg;
}
}
std::unordered_map<relay::Var, Expr, NodeHash, NodeEqual> bind_dict;
for (auto &kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
auto arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
}
bind_dict[arg] = ConstantNode::make(kv.second);
}
Expr bound_expr = relay::Bind(func, bind_dict);
Function ret = Downcast<Function>(bound_expr);
CHECK(ret.defined())
<< "The returning type is expected to be a Relay Function."
<< "\n";
return ret;
}
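
`BindParamsByName` has a close Python analogue using `relay.bind`; a hedged sketch of what it does to `main` (ignoring the duplicate-name check), assuming `params` is a dict of str to NDArray and `mod` is a Relay module:

```python
from tvm import relay

func = mod["main"]
binds = {}
for arg in func.params:
    if arg.name_hint in params:
        # Replace the matching free variable with a constant node.
        binds[arg] = relay.const(params[arg.name_hint])
mod["main"] = relay.bind(func, binds)
```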
void VMCompiler::Compile(Module mod,
const TargetsMap& targets,
const tvm::Target& target_host) {
CHECK_EQ(targets.size(), 1)
<< "Currently VM compiler doesn't support heterogeneous compilation";
if (params_.size()) {
auto f = BindParamsByName(mod->Lookup("main"), params_);
auto gvar = mod->GetGlobalVar("main");
mod->Add(gvar, f);
}
InitVM();
targets_ = targets;
......@@ -804,7 +854,7 @@ void VMCompiler::Compile(const Module& mod_ref,
// Run some optimizations first, this code should
// be moved to pass manager.
context_.module = OptimizeModule(mod_ref, targets_);
context_.module = OptimizeModule(mod, targets_);
// Populate the global map.
//
......
......@@ -100,11 +100,37 @@ class VMCompiler : public runtime::ModuleNode {
vm_ = std::make_shared<VirtualMachine>();
}
void Compile(const Module& mod_ref,
/*!
* \brief Set the parameters
*
* \param name name of parameter
* \param data_in input DLTensor
*/
void SetParam(const std::string& name, runtime::NDArray data_in);
/*!
* \brief Compile functions in a Module
*
* \param mod Relay Module
* \param targets For heterogeneous compilation, it is a dictionary indicating context
to target mapping. For homogeneous compilation, it is a build target.
* \param target_host Host compilation target, if target is device.
*/
void Compile(Module mod,
const TargetsMap& targets,
const tvm::Target& target_host);
protected:
/*!
* \brief Bind params to function by using name
* \param func Relay function
* \param params params dict
* \return relay::Function
*/
relay::Function BindParamsByName(
relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params);
Module OptimizeModule(const Module& mod, const TargetsMap& targets);
void PopulateGlobalMap();
......@@ -120,6 +146,8 @@ class VMCompiler : public runtime::ModuleNode {
VMCompilerContext context_;
/*! \brief Compiled virtual machine. */
std::shared_ptr<VirtualMachine> vm_;
/*! \brief parameters */
std::unordered_map<std::string, runtime::NDArray> params_;
};
} // namespace vm
......
......@@ -98,6 +98,11 @@ void VirtualMachineDebug::InvokePacked(Index packed_index,
Index output_size,
const std::vector<Object>& args) {
auto ctx = VirtualMachine::GetParamsContext();
// warmup
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
args);
TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
auto op_begin = std::chrono::high_resolution_clock::now();
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
args);
......
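
The added warmup invocation plus `TVMSynchronize` keeps one-time costs (lazy initialization, kernel loading) out of the measured per-op timings. A hedged usage sketch, assuming the profiler VM's existing `get_stat` accessor and a prepared input `data`:

```python
import tvm
from tvm import relay

vm = relay.profiler_vm.compile(mod, target='llvm')
vm.init(tvm.cpu())
vm.invoke("main", data)   # per-op timings now exclude the warmup run
print(vm.get_stat())      # assumption: get_stat() reports per-op stats
```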
......@@ -47,15 +47,13 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
if isinstance(f, relay.Expr):
mod = relay.Module()
mod["main"] = f
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, target)
vm = relay.vm.compile(mod, target)
vm.init(tvm.cpu())
return vm.invoke("main", *args)
else:
assert isinstance(f, relay.Module), "expected expression or module"
mod = f
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, target)
vm = relay.vm.compile(mod, target)
vm.init(tvm.cpu())
ret = vm.invoke("main", *args)
return ret
......@@ -582,8 +580,7 @@ def test_set_params():
b = relay.var('b', shape=(6,))
y = relay.nn.bias_add(relay.nn.dense(x, w), b)
mod["main"] = relay.Function([x, w, b], y)
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, 'llvm')
vm = relay.vm.compile(mod, 'llvm')
vm.init(tvm.cpu())
x_np = np.random.uniform(size=(10, 5)).astype('float32')
......
......@@ -28,18 +28,16 @@ from tvm.relay.prelude import Prelude
from tvm.contrib import util
from tvm.relay import testing
def create_vm(f, ctx=tvm.cpu(), target="llvm"):
def create_vm(f, ctx=tvm.cpu(), target="llvm", params=None):
if isinstance(f, relay.Expr):
mod = relay.Module()
mod["main"] = f
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, target)
vm = _vm.compile(mod, target=target, params=params)
vm.init(ctx)
return vm
else:
assert isinstance(f, relay.Module), "expected mod as relay.Module"
compiler = relay.vm.VMCompiler()
vm = compiler.compile(f, target)
vm = _vm.compile(f, target=target, params=params)
vm.init(ctx)
return vm
......@@ -61,7 +59,7 @@ def run_network(mod,
return result.asnumpy().astype(dtype)
def get_serialized_output(mod, data, params, target, ctx, dtype='float32'):
vm = create_vm(mod, ctx, target)
vm = create_vm(mod, ctx, target, params=params)
ser = serializer.Serializer(vm)
code, lib = ser.serialize()
deser = deserializer.Deserializer(code, lib)
......
......@@ -22,13 +22,11 @@ import pytest
from tvm import relay
from tvm.relay.testing import resnet
@pytest.mark.skip
def test_basic():
mod, params = resnet.get_workload()
compiler = relay.profiler_vm.VMCompilerProfiler()
target = 'llvm'
ctx = tvm.cpu()
vm = compiler.compile(mod, target)
vm = relay.profiler_vm.compile(mod, target)
vm.init(ctx)
vm.load_params(params)
......