Commit fc2713e5 by Wei Chen, committed by Zhi

[Relay][VM] Fix constant folding issue in VM compiler (#4077)

* [Relay][VM] Fix constant folding issue in VM compiler

1. allow passing params when compiling a module (see the usage sketch after these notes)
2. enhance profiler robustness

* remove dead code

* fix lint

* add get_params

* fix test

* don't pass params back

* remove get_params

* docs

* move compile function to api

* compile clashes with builtin name

* fix compilation error

* remove dead code
parent 4d875d1f
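The tests below switch from the old VMCompiler().compile(...) pattern to the new module-level relay.vm.compile, which also accepts params so that weights can be bound as constants and folded at compile time. A minimal usage sketch in the spirit of the updated tests; the small dense model and its shapes are illustrative, not part of this change:

import numpy as np
import tvm
from tvm import relay

# Illustrative model: y = bias_add(dense(x, w), b), with w and b supplied as params.
x = relay.var("x", shape=(10, 5))
w = relay.var("w", shape=(6, 5))
b = relay.var("b", shape=(6,))
mod = relay.Module()
mod["main"] = relay.Function([x, w, b], relay.nn.bias_add(relay.nn.dense(x, w), b))

params = {
    "w": np.random.uniform(size=(6, 5)).astype("float32"),
    "b": np.random.uniform(size=(6,)).astype("float32"),
}

# New API: params are bound as constants and folded during compilation,
# so the compiled "main" only needs x at runtime.
vm = relay.vm.compile(mod, target="llvm", params=params)
vm.init(tvm.cpu())
out = vm.invoke("main", np.random.uniform(size=(10, 5)).astype("float32"))

Passing params here routes through VMCompiler.set_params and the new C++ BindParamsByName, which substitutes the matching function arguments with constants before the optimization passes run.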
@@ -14,33 +14,53 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name, redefined-builtin
"""
The Relay Virtual Machine profiler.
Provides extra APIs for profiling vm execution.
"""
import tvm
from . import vm, _vm
def _update_target(target):
    target = target if target else tvm.target.current_target()
    if target is None:
        raise ValueError("Target is not set in env or passed as argument.")
    tgts = {}
    if isinstance(target, (str, tvm.target.Target)):
        dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
        tgts[dev_type] = tvm.target.create(target)
    elif isinstance(target, dict):
        for dev, tgt in target.items():
            dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
            tgts[dev_type] = tvm.target.create(tgt)
    else:
        raise TypeError("target is expected to be str, tvm.target.Target, " +
                        "or dict of str to str/tvm.target.Target, but received " +
                        "{}".format(type(target)))
    return tgts
def compile(mod, target=None, target_host=None, params=None):
    """
    Parameters
    ----------
    mod : relay.Module
        The Relay module to build.
    target : str, :any:`tvm.target.Target`, or dict of str (i.e.
        device/context name) to str/tvm.target.Target, optional
        For heterogeneous compilation, it is a dictionary indicating context
        to target mapping. For homogeneous compilation, it is a build target.
    target_host : str or :any:`tvm.target.Target`, optional
        Host compilation target, if target is device.
        When TVM compiles device specific program such as CUDA,
        we also need host (CPU) side code to interact with the driver
        to set up the dimensions and parameters correctly.
        target_host is used to specify the host side codegen target.
        By default, llvm is used if it is enabled,
        otherwise a stackvm interpreter is used.
    params : dict of str to NDArray
        Input parameters to the graph that do not change
        during inference time. Used for constant folding.
    Returns
    -------
    vm : VirtualMachineProfiler
        The profile VM runtime.
    """
    compiler = VMCompilerProfiler()
    target = compiler.update_target(target)
    target_host = compiler.update_target_host(target, target_host)
    if params:
        compiler.set_params(params)
    tophub_context = compiler.tophub_context(target)
    with tophub_context:
        compiler._compile(mod, target, target_host)
    return VirtualMachineProfiler(compiler._get_vm())
class VMCompilerProfiler(vm.VMCompiler):
"""Build Relay module to run on VM runtime."""
@@ -49,36 +69,7 @@ class VMCompilerProfiler(vm.VMCompiler):
        self.mod = _vm._VMCompilerProfiler()
        self._compile = self.mod["compile"]
        self._get_vm = self.mod["get_vm"]
        self._set_params_func = self.mod["set_params"]
    def compile(self, mod, target=None, target_host=None):
        """
        Parameters
        ----------
        mod : relay.Module
            The Relay module to build.
        target : str, :any:`tvm.target.Target`, or dict of str (i.e.
            device/context name) to str/tvm.target.Target, optional
            For heterogeneous compilation, it is a dictionary indicating context
            to target mapping. For homogeneous compilation, it is a build target.
        target_host : str or :any:`tvm.target.Target`, optional
            Host compilation target, if target is device.
            When TVM compiles device specific program such as CUDA,
            we also need host (CPU) side code to interact with the driver
            to set up the dimensions and parameters correctly.
            target_host is used to specify the host side codegen target.
            By default, llvm is used if it is enabled,
            otherwise a stackvm interpreter is used.
        Returns
        -------
        vm : VirtualMachineProfiler
            The profile VM runtime.
        """
        target = _update_target(target)
        self._compile(mod, target, target_host)
        return VirtualMachineProfiler(self._get_vm())
class VirtualMachineProfiler(vm.VirtualMachine):
"""Relay profile VM runtime."""
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name, redefined-builtin
"""
The Relay Virtual Machine.
@@ -25,30 +25,11 @@ import numpy as np
import tvm
from tvm import autotvm
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm.relay import expr as _expr
from . import _vm
from . import vmobj as _obj
from .interpreter import Executor
def _update_target(target):
target = target if target else tvm.target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")
tgts = {}
if isinstance(target, (str, tvm.target.Target)):
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
tgts[dev_type] = tvm.target.create(target)
elif isinstance(target, dict):
for dev, tgt in target.items():
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
tgts[dev_type] = tvm.target.create(tgt)
else:
raise TypeError("target is expected to be str, tvm.target.Target, " +
"or dict of str to str/tvm.target.Target, but received " +
"{}".format(type(target)))
return tgts
def _convert(arg, cargs):
if isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
cargs.append(_obj.tensor_object(arg))
@@ -144,40 +125,85 @@ class VirtualMachine(object):
return self.mod
def compile(mod, target=None, target_host=None, params=None):
"""
Parameters
----------
mod : relay.Module
The Relay module to build.
target : str, :any:`tvm.target.Target`, or dict of str (i.e.
device/context name) to str/tvm.target.Target, optional
For heterogeneous compilation, it is a dictionary indicating context
to target mapping. For homogeneous compilation, it is a build target.
target_host : str or :any:`tvm.target.Target`, optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host (CPU) side code to interact with the driver
to set up the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm interpreter is used.
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
Returns
-------
vm : VirtualMachine
The VM runtime.
"""
compiler = VMCompiler()
target = compiler.update_target(target)
target_host = compiler.update_target_host(target, target_host)
if params:
compiler.set_params(params)
tophub_context = compiler.tophub_context(target)
with tophub_context:
compiler._compile(mod, target, target_host)
return VirtualMachine(compiler._get_vm())
class VMCompiler(object):
"""Build Relay module to run on VM runtime."""
def __init__(self):
self.mod = _vm._VMCompiler()
self._compile = self.mod["compile"]
self._get_vm = self.mod["get_vm"]
self._set_params_func = self.mod["set_params"]
def set_params(self, params):
"""Set constant parameters for the model"""
inputs = {}
for name, param in params.items():
if isinstance(param, np.ndarray):
param = _nd.array(param)
inputs[name] = _expr.const(param)
self._set_params_func(inputs)
def update_target(self, target):
"""Update target"""
target = target if target else tvm.target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")
tgts = {}
if isinstance(target, (str, tvm.target.Target)):
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
tgts[dev_type] = tvm.target.create(target)
elif isinstance(target, dict):
for dev, tgt in target.items():
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
tgts[dev_type] = tvm.target.create(tgt)
else:
raise TypeError("target is expected to be str, tvm.target.Target, " +
"or dict of str to str/tvm.target.Target, but received " +
"{}".format(type(target)))
return tgts
def compile(self, mod, target=None, target_host=None):
"""
Parameters
----------
mod : relay.Module
The Relay module to build.
target : str, :any:`tvm.target.Target`, or dict of str (i.e.
device/context name) to str/tvm.target.Target, optional
For heterogeneous compilation, it is a dictionary indicating context
to target mapping. For homogeneous compilation, it is a build target.
target_host : str or :any:`tvm.target.Target`, optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host (CPU) side code to interact with the driver
to set up the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm interpreter is used.
Returns
-------
vm : VirtualMachine
The VM runtime.
"""
target = _update_target(target)
def update_target_host(self, target, target_host):
"""Update target host"""
target_host = None if target_host == "" else target_host
if not target_host:
for device_type, tgt in target.items():
@@ -186,19 +212,16 @@ class VMCompiler(object):
break
if not target_host:
target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
target_host = tvm.target.create(target_host)
return tvm.target.create(target_host)
def tophub_context(self, target):
# If current dispatch context is fallback context (the default root context),
# then load pre-tuned parameters from TopHub
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
tophub_context = autotvm.tophub.context(list(target.values()))
else:
tophub_context = autotvm.util.EmptyContext()
with tophub_context:
self._compile(mod, target, target_host)
return VirtualMachine(self._get_vm())
return tophub_context
class VMExecutor(Executor):
"""
@@ -226,8 +249,7 @@ class VMExecutor(Executor):
self.mod = mod
self.ctx = ctx
self.target = target
compiler = VMCompiler()
self.vm = compiler.compile(mod, target)
self.vm = compile(mod, target)
self.vm.init(ctx)
def _make_executor(self, expr=None):
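As the compile docstring above notes, target may also be given as a dict of device/context name to target. A hedged sketch of that form (the tiny add-constant model is illustrative); note that the C++ Compile shown below currently checks that exactly one target is supplied, so the dict can hold only a single entry:

import numpy as np
import tvm
from tvm import relay

# Illustrative one-op model: add a constant vector to the input.
x = relay.var("x", shape=(4,))
mod = relay.Module()
mod["main"] = relay.Function([x], relay.add(x, relay.const(np.ones(4, dtype="float32"))))

# Dict form of target: device/context name -> target string or tvm.target.Target.
vm = relay.vm.compile(mod, target={"cpu": "llvm"})
vm.init(tvm.cpu())
res = vm.invoke("main", np.zeros(4, dtype="float32"))

Internally update_target converts the dict keys to device types via tvm.nd.context and creates tvm.target.Target values, exactly as in the update_target method above.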
@@ -780,23 +780,73 @@ PackedFunc VMCompiler::GetFunction(const std::string& name,
if (name == "compile") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 3);
this->Compile(args[0], args[1], args[2]);
Module mod = args[0];
this->Compile(mod, args[1], args[2]);
});
} else if (name == "get_vm") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = runtime::Module(vm_);
});
} else if (name == "set_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Map<std::string, Constant> params = args[0];
for (const auto& kv : params) {
this->SetParam(kv.first, kv.second->data);
}
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
}
}
void VMCompiler::Compile(const Module& mod_ref,
void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) {
params_[name] = data_in;
}
relay::Function VMCompiler::BindParamsByName(
relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, NodeHash, NodeEqual> repeat_var;
for (auto arg : func->params) {
const auto &name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(arg);
} else {
name_dict[name] = arg;
}
}
std::unordered_map<relay::Var, Expr, NodeHash, NodeEqual> bind_dict;
for (auto &kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
auto arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
}
bind_dict[arg] = ConstantNode::make(kv.second);
}
Expr bound_expr = relay::Bind(func, bind_dict);
Function ret = Downcast<Function>(bound_expr);
CHECK(ret.defined())
<< "The returning type is expected to be a Relay Function."
<< "\n";
return ret;
}
void VMCompiler::Compile(Module mod,
const TargetsMap& targets,
const tvm::Target& target_host) {
CHECK_EQ(targets.size(), 1)
<< "Currently VM compiler doesn't support heterogeneous compilation";
if (params_.size()) {
auto f = BindParamsByName(mod->Lookup("main"), params_);
auto gvar = mod->GetGlobalVar("main");
mod->Add(gvar, f);
}
InitVM();
targets_ = targets;
@@ -804,7 +854,7 @@ void VMCompiler::Compile(const Module& mod_ref,
// Run some optimizations first, this code should
// be moved to pass manager.
context_.module = OptimizeModule(mod_ref, targets_);
context_.module = OptimizeModule(mod, targets_);
// Populate the global map.
//
@@ -100,11 +100,37 @@ class VMCompiler : public runtime::ModuleNode {
vm_ = std::make_shared<VirtualMachine>();
}
void Compile(const Module& mod_ref,
/*!
* \brief Set the parameters
*
* \param name name of parameter
* \param data_in input DLTensor
*/
void SetParam(const std::string& name, runtime::NDArray data_in);
/*!
* \brief Compile functions in a Module
*
* \param mod Relay Module
* \param targets For heterogeneous compilation, it is a dictionary indicating context
* to target mapping. For homogeneous compilation, it is a build target.
* \param target_host Host compilation target, if target is device.
*/
void Compile(Module mod,
const TargetsMap& targets,
const tvm::Target& target_host);
protected:
/*!
* \brief Bind params to function by using name
* \param func Relay function
* \param params params dict
* \return relay::Function
*/
relay::Function BindParamsByName(
relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params);
Module OptimizeModule(const Module& mod, const TargetsMap& targets);
void PopulateGlobalMap();
@@ -120,6 +146,8 @@ class VMCompiler : public runtime::ModuleNode {
VMCompilerContext context_;
/*! \brief Compiled virtual machine. */
std::shared_ptr<VirtualMachine> vm_;
/*! \brief parameters */
std::unordered_map<std::string, runtime::NDArray> params_;
};
} // namespace vm
@@ -98,6 +98,11 @@ void VirtualMachineDebug::InvokePacked(Index packed_index,
Index output_size,
const std::vector<Object>& args) {
auto ctx = VirtualMachine::GetParamsContext();
// warmup
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
args);
TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
auto op_begin = std::chrono::high_resolution_clock::now();
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
args);
@@ -47,15 +47,13 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
if isinstance(f, relay.Expr):
mod = relay.Module()
mod["main"] = f
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, target)
vm = relay.vm.compile(mod, target)
vm.init(tvm.cpu())
return vm.invoke("main", *args)
else:
assert isinstance(f, relay.Module), "expected expression or module"
mod = f
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, target)
vm = relay.vm.compile(mod, target)
vm.init(tvm.cpu())
ret = vm.invoke("main", *args)
return ret
@@ -582,8 +580,7 @@ def test_set_params():
b = relay.var('b', shape=(6,))
y = relay.nn.bias_add(relay.nn.dense(x, w), b)
mod["main"] = relay.Function([x, w, b], y)
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, 'llvm')
vm = relay.vm.compile(mod, 'llvm')
vm.init(tvm.cpu())
x_np = np.random.uniform(size=(10, 5)).astype('float32')
@@ -28,18 +28,16 @@ from tvm.relay.prelude import Prelude
from tvm.contrib import util
from tvm.relay import testing
def create_vm(f, ctx=tvm.cpu(), target="llvm"):
def create_vm(f, ctx=tvm.cpu(), target="llvm", params=None):
if isinstance(f, relay.Expr):
mod = relay.Module()
mod["main"] = f
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, target)
vm = _vm.compile(mod, target=target, params=params)
vm.init(ctx)
return vm
else:
assert isinstance(f, relay.Module), "expected mod as relay.Module"
compiler = relay.vm.VMCompiler()
vm = compiler.compile(f, target)
vm = _vm.compile(f, target=target, params=params)
vm.init(ctx)
return vm
@@ -61,7 +59,7 @@ def run_network(mod,
return result.asnumpy().astype(dtype)
def get_serialized_output(mod, data, params, target, ctx, dtype='float32'):
vm = create_vm(mod, ctx, target)
vm = create_vm(mod, ctx, target, params=params)
ser = serializer.Serializer(vm)
code, lib = ser.serialize()
deser = deserializer.Deserializer(code, lib)
@@ -22,13 +22,11 @@ import pytest
from tvm import relay
from tvm.relay.testing import resnet
@pytest.mark.skip
def test_basic():
mod, params = resnet.get_workload()
compiler = relay.profiler_vm.VMCompilerProfiler()
target = 'llvm'
ctx = tvm.cpu()
vm = compiler.compile(mod, target)
vm = relay.profiler_vm.compile(mod, target)
vm.init(ctx)
vm.load_params(params)
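Since the profiler entry point added here shares the same signature, parameters can also be folded at compile time through the params argument instead of being loaded into the runtime afterwards; a sketch of that variant, which the test above does not exercise:

import tvm
from tvm import relay
from tvm.relay.testing import resnet

# Same workload as in the profiler test above.
mod, params = resnet.get_workload()

# Bind the ResNet weights as constants during compilation (constant folding path)
# rather than calling load_params on the profiling runtime.
vm = relay.profiler_vm.compile(mod, target="llvm", params=params)
vm.init(tvm.cpu())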