Commit 5357f49b authored by Wei Chen, committed by Haichen Shen

[Relay][VM] Support execution on devices (#3678)

* [Relay][VM] Support execution on devices

* Reduce Copy calls

* Cleanup

* Lint

* CR comments

* Merge test into test_vm.py
parent a279dd0e
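In practice, this change lets the VM executor run a compiled module on non-CPU devices. A minimal usage sketch (the 'cuda' target, GPU context, and toy function here are illustrative and assume a CUDA-enabled build):

import numpy as np
import tvm
from tvm import relay

mod = relay.Module()
x = relay.var('x', shape=(4, 4))
mod["main"] = relay.Function([x], x + x)

# Compile for CUDA and evaluate on the GPU via the VM executor.
ex = relay.create_executor('vm', mod=mod, ctx=tvm.gpu(0), target='cuda')
x_data = np.random.rand(4, 4).astype('float32')
res = ex.evaluate()(x_data)  # inputs are copied to the device as needed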
@@ -451,7 +451,10 @@ class VirtualMachine : public runtime::ModuleNode {
* \param contexts The set of TVM contexts.
*/
void Init(const std::vector<TVMContext>& contexts);
void Run();
/*! \brief Run VM dispatch loop.
*/
void RunLoop();
/*!
* \brief Load parameters from the parameter bytearray.
@@ -475,6 +478,10 @@ class VirtualMachine : public runtime::ModuleNode {
*/
void InvokeGlobal(const VMFunction& func, const std::vector<Object>& args);
/*! \brief Get device context for params.
*/
TVMContext GetParamsContext() const;
/*! \brief The parameter name to data mapping. */
std::unordered_map<std::string, Object> params_;
};
@@ -212,7 +212,6 @@ class VMExecutor(Executor):
self.vm.init(ctx)
def _make_executor(self, expr=None):
assert expr is None
main = self.mod["main"]
def _vm_wrapper(*args, **kwargs):
@@ -27,6 +27,7 @@
#include <tvm/logging.h>
#include <tvm/runtime/vm.h>
#include <algorithm>
#include <chrono>
#include <iostream>
#include <sstream>
@@ -569,20 +570,36 @@ std::ostream& operator<<(std::ostream& os, const VMFunction& vm_func) {
return os;
}
// Copy a tensor object to the requested device context when it lives on a
// different device type; non-tensor objects and tensors already on the
// right device type are returned unchanged.
Object CopyTo(Object src, const DLContext& ctx) {
  if (src->tag == ObjectTag::kTensor) {
    auto tensor = ToNDArray(src);
    if (tensor->ctx.device_type != ctx.device_type) {
      auto copy = tensor.CopyTo(ctx);
      return Object::Tensor(copy);
    } else {
      return src;
    }
  } else {
    return src;
  }
}
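For intuition, a rough Python-level analogue of this helper using TVM's NDArray API (the GPU context is illustrative and assumes a CUDA-enabled build):

import numpy as np
import tvm

def copy_to(arr, ctx):
    # Mirror CopyTo: copy only when the tensor lives on a different
    # device type than the requested context.
    if arr.ctx.device_type != ctx.device_type:
        return arr.copyto(ctx)
    return arr

cpu_arr = tvm.nd.array(np.zeros((2, 2), dtype='float32'), tvm.cpu(0))
gpu_arr = copy_to(cpu_arr, tvm.gpu(0))  # performs the device copy

Note that only device_type is compared, so a tensor that sits on a different device of the same type is not copied.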
PackedFunc VirtualMachine::GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
if (name == "invoke") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string func_name = args[0];
auto ctx = this->GetParamsContext();
std::vector<Object> func_args;
for (int i = 1; i < args.size(); ++i) {
Object obj = args[i];
Object obj = CopyTo(args[i], ctx);
func_args.push_back(obj);
}
auto it = std::find_if(functions.begin(), functions.end(),
[func_name](const VMFunction& func) {
return func.name == func_name;
});
CHECK(it != functions.end()) << "Cannot find function " << func_name << "\n";
CHECK_EQ(func_args.size() + params_.size(), it->params.size())
<< "The number of provided parameters doesn't match the number of arguments"
@@ -621,6 +638,18 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
}
}
TVMContext VirtualMachine::GetParamsContext() const {
// Use the fallback device if no device index is available.
int fallback_device_type = static_cast<int>(ctxs[0].device_type);
// TODO(wweic): For heterogeneous execution, get device information from bytecode.
const auto& cit =
std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) {
return fallback_device_type == static_cast<int>(c.device_type);
});
return (cit == ctxs.end() ? ctxs[0] : *cit);
}
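As written, this lookup always resolves to ctxs[0], since the device type being searched for is taken from ctxs[0] itself; the TODO above marks where per-parameter device information from the bytecode would slot in for heterogeneous execution.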
void VirtualMachine::LoadParams(const std::string& params) {
dmlc::MemoryStringStream mss(const_cast<std::string*>(&params));
dmlc::Stream* strm = &mss;
@@ -637,11 +666,13 @@ void VirtualMachine::LoadParams(const std::string& params) {
size_t size = static_cast<size_t>(sz);
CHECK(size == names.size()) << "Invalid parameter file";
auto ctx = GetParamsContext();
for (size_t i = 0; i < size; i++) {
NDArray arr;
CHECK(arr.Load(strm)) << "Invalid parameter file";
runtime::Object obj = runtime::Object::Tensor(arr);
params_.emplace(std::make_pair(names[i], obj));
auto copy = CopyTo(obj, ctx);
params_.emplace(std::make_pair(names[i], copy));
}
}
@@ -678,7 +709,7 @@ Object VirtualMachine::Invoke(const VMFunction& func, const std::vector<Object>&
DLOG(INFO) << "Executing Function: " << std::endl << func;
InvokeGlobal(func, args);
Run();
RunLoop();
auto alloc = MemoryManager::Global()->GetAllocator(ctxs[0]);
DLOG(INFO) << "Memory used: " << alloc->UsedMemory() << " B";
return return_register;
@@ -762,7 +793,7 @@ inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
return result;
}
void VirtualMachine::Run() {
void VirtualMachine::RunLoop() {
CHECK(this->code);
this->pc = 0;
Index frame_start = frames.size();
@@ -786,7 +817,9 @@ void VirtualMachine::Run() {
throw std::runtime_error("VM encountered fatal error");
}
case Opcode::LoadConst: {
WriteRegister(instr.dst, this->constants[instr.const_index]);
auto constant_obj = this->constants[instr.const_index];
auto device_obj = CopyTo(constant_obj, ctxs[0]);
WriteRegister(instr.dst, device_obj);
pc++;
goto main_loop;
}
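With this change, each LoadConst copies the constant to the primary device (ctxs[0]) on demand; CopyTo makes the copy a no-op when the constant already resides on that device type.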
@@ -21,8 +21,28 @@ import tvm
import numpy as np
from tvm import relay
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.testing.config import ctx_list
from tvm.relay.prelude import Prelude
def check_result(args, expected_result, mod=None):
"""
Check that evaluating `expr` applied to the arguments produces
`result` on Relay VM.
Parameters
----------
args: list of Expr
The arguments to supply the expr.
expected_result:
The expected result of running the expression.
"""
for target, ctx in ctx_list():
vm = relay.create_executor('vm', ctx=ctx, target=target, mod=mod)
rts_result = vm.evaluate()(*args)
tvm.testing.assert_allclose(expected_result, rts_result.asnumpy())
def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
if isinstance(f, relay.Expr):
mod = relay.Module()
@@ -30,14 +50,14 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, target)
vm.init(tvm.cpu())
return vm.run(*args)
return vm.invoke("main", *args)
else:
assert isinstance(f, relay.Module), "expected expression or module"
mod = f
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, target)
vm.init(tvm.cpu())
ret = vm.run(*args)
ret = vm.invoke("main", *args)
return ret
def vmobj_to_list(o):
@@ -76,15 +96,17 @@ def test_id():
x = relay.var('x', shape=(10, 10), dtype='float64')
f = relay.Function([x], x)
x_data = np.random.rand(10, 10).astype('float64')
res = veval(f, x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data)
mod = relay.Module()
mod["main"] = f
check_result([x_data], x_data, mod=mod)
def test_op():
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x + x)
x_data = np.random.rand(10, 10).astype('float32')
res = veval(f, x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data)
mod = relay.Module()
mod["main"] = f
check_result([x_data], 2 * x_data, mod=mod)
def any(x):
x = relay.op.nn.batch_flatten(x)
@@ -92,20 +114,19 @@ def any(x):
def test_cond():
x = relay.var('x', shape=(10, 10))
y = relay.var('x', shape=(10, 10))
y = relay.var('y', shape=(10, 10))
# f = relay.Function([x, y], relay.op.equal(x, y))
f = relay.Function([x, y], any(relay.op.equal(x, y)))
x_data = np.random.rand(10, 10).astype('float32')
y_data = np.random.rand(10, 10).astype('float32')
mod = relay.Module()
mod["main"] = f
# same
res = veval(f, x_data, x_data)
np.testing.assert_allclose(res.asnumpy(), True)
check_result([x_data, x_data], True, mod=mod)
# diff
res = veval(f, x_data, y_data)
tvm.testing.assert_allclose(res.asnumpy(), False)
check_result([x_data, y_data], False, mod=mod)
def test_simple_if():
x = relay.var('x', shape=(10, 10))
@@ -115,13 +136,13 @@ def test_simple_if():
x_data = np.random.rand(10, 10).astype('float32')
y_data = np.random.rand(10, 10).astype('float32')
mod = relay.Module()
mod["main"] = f
# same
res = veval(f, x_data, x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data)
check_result([x_data, x_data], x_data, mod=mod)
# diff
res = veval(f, x_data, y_data)
tvm.testing.assert_allclose(res.asnumpy(), y_data)
check_result([x_data, y_data], y_data, mod=mod)
def test_simple_call():
mod = relay.module.Module({})
@@ -132,10 +153,9 @@ def test_simple_call():
func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
mod[sum_up] = func
i_data = np.array(0, dtype='int32')
iarg = relay.var('i', shape=[], dtype='int32')
iarg = relay.var('iarg', shape=[], dtype='int32')
mod["main"] = relay.Function([iarg], sum_up(iarg))
result = veval(mod, i_data)
tvm.testing.assert_allclose(result.asnumpy(), i_data)
check_result([i_data], i_data, mod=mod)
def test_count_loop():
mod = relay.module.Module({})
@@ -155,6 +175,7 @@ def test_count_loop():
mod["main"] = relay.Function([iarg], sum_up(iarg))
result = veval(mod, i_data)
tvm.testing.assert_allclose(result.asnumpy(), i_data)
check_result([i_data], i_data, mod=mod)
def test_sum_loop():
mod = relay.module.Module({})
@@ -176,8 +197,7 @@ def test_sum_loop():
iarg = relay.var('i', shape=[], dtype='int32')
aarg = relay.var('accum', shape=[], dtype='int32')
mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg))
result = veval(mod, i_data, accum_data)
tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1)))
check_result([i_data, accum_data], sum(range(1, loop_bound + 1)), mod=mod)
def test_tuple_fst():
ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))])
@@ -185,8 +205,9 @@ def test_tuple_fst():
f = relay.Function([tup], relay.TupleGetItem(tup, 0))
i_data = np.random.rand(41).astype('float32')
j_data = np.random.rand(10).astype('float32')
result = veval(f, (i_data, j_data))
tvm.testing.assert_allclose(result.asnumpy(), i_data)
mod = relay.Module()
mod["main"] = f
check_result([(i_data, j_data)], i_data, mod=mod)
def test_tuple_second():
ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))])
@@ -194,8 +215,9 @@ def test_tuple_second():
f = relay.Function([tup], relay.TupleGetItem(tup, 1))
i_data = np.random.rand(41).astype('float32')
j_data = np.random.rand(10).astype('float32')
result = veval(f, (i_data, j_data))
tvm.testing.assert_allclose(result.asnumpy(), j_data)
mod = relay.Module()
mod["main"] = f
check_result([(i_data, j_data)], j_data, mod=mod)
def test_list_constructor():
mod = relay.Module()
@@ -233,8 +255,9 @@ def test_let_tensor():
f = relay.Function([x], body)
x_data = np.random.rand(*shape).astype('float32')
result = veval(f, x_data)
tvm.testing.assert_allclose(result.asnumpy(), x_data + 42.0)
mod = relay.Module()
mod["main"] = f
check_result([x_data], x_data + 42.0, mod=mod)
def test_let_scalar():
sb = relay.ScopeBuilder()
@@ -248,8 +271,9 @@ def test_let_scalar():
f = relay.Function([x], body)
x_data = np.array(np.random.rand()).astype('float32')
result = veval(f, x_data)
tvm.testing.assert_allclose(result.asnumpy(), x_data + 42.0)
mod = relay.Module()
mod["main"] = f
check_result([x_data], x_data + 42.0, mod=mod)
def test_compose():
mod = relay.Module()
@@ -281,8 +305,7 @@ def test_compose():
mod["main"] = f
x_data = np.array(np.random.rand()).astype('float32')
result = veval(mod, x_data)
result = veval(mod, [x_data])
tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0)
def test_list_hd():
@@ -504,6 +527,54 @@ def test_closure():
res = veval(main)
tvm.testing.assert_allclose(res.asnumpy(), 3.0)
def test_add_op_scalar():
"""
test_add_op_scalar:
fn (x, y) {
return x + y;
}
"""
mod = relay.Module()
x = relay.var('x', shape=())
y = relay.var('y', shape=())
func = relay.Function([x, y], relay.op.add(x, y))
x_data = np.array(10.0, dtype='float32')
y_data = np.array(1.0, dtype='float32')
mod["main"] = func
check_result([x_data, y_data], x_data + y_data, mod=mod)
def test_add_op_tensor():
"""
test_add_op_tensor:
fn (x, y) {
return x + y;
}
"""
mod = relay.Module()
x = relay.var('x', shape=(10, 5))
y = relay.var('y', shape=(10, 5))
func = relay.Function([x, y], relay.op.add(x, y))
x_data = np.random.rand(10, 5).astype('float32')
y_data = np.random.rand(10, 5).astype('float32')
mod["main"] = func
check_result([x_data, y_data], x_data + y_data, mod=mod)
def test_add_op_broadcast():
"""
test_add_op_broadcast:
fn (x, y) {
return x + y;
}
"""
mod = relay.Module()
x = relay.var('x', shape=(10, 5))
y = relay.var('y', shape=(1, 5))
func = relay.Function([x, y], relay.op.add(x, y))
x_data = np.random.rand(10, 5).astype('float32')
y_data = np.random.rand(1, 5).astype('float32')
mod["main"] = func
check_result([x_data, y_data], x_data + y_data, mod=mod)
if __name__ == "__main__":
test_id()
test_op()
@@ -534,3 +605,6 @@ if __name__ == "__main__":
test_list_sum()
test_list_filter()
test_closure()
test_add_op_scalar()
test_add_op_tensor()
test_add_op_broadcast()