Commit 5357f49b by Wei Chen, committed by Haichen Shen

[Relay][VM] Support execution on devices (#3678)

* [Relay][VM] Support execution on devices

* Reduce Copy calls

* Cleanup

* Lint

* CR comments

* Merge test into test_vm.py
parent a279dd0e
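
This change enables execution of Relay VM programs on non-CPU devices. A minimal usage sketch, assuming a CUDA-enabled TVM build (the `"cuda"` target and `tvm.gpu(0)` context are illustrative; the `compile`/`init`/`invoke` calls are the ones exercised by this diff):

```python
# Sketch: compile a Relay module with the VM compiler and run it on a GPU.
import numpy as np
import tvm
from tvm import relay

x = relay.var('x', shape=(10, 10))
mod = relay.Module()
mod["main"] = relay.Function([x], x + x)

compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, "cuda")  # build device kernels for the GPU target
vm.init(tvm.gpu(0))                 # the VM's contexts now include the GPU
x_data = np.random.rand(10, 10).astype('float32')
res = vm.invoke("main", x_data)     # inputs are copied to the device as needed
print(res.asnumpy())
```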
@@ -451,7 +451,10 @@ class VirtualMachine : public runtime::ModuleNode {
    * \param contexts The set of TVM contexts.
    */
   void Init(const std::vector<TVMContext>& contexts);
-  void Run();
+  /*! \brief Run VM dispatch loop.
+   */
+  void RunLoop();
   /*!
    * \brief Load parameters from the parameter bytearray.
@@ -475,6 +478,10 @@ class VirtualMachine : public runtime::ModuleNode {
    */
   void InvokeGlobal(const VMFunction& func, const std::vector<Object>& args);
+  /*! \brief Get device context for params.
+   */
+  TVMContext GetParamsContext() const;
   /*! \brief The parameter name to data mapping. */
   std::unordered_map<std::string, Object> params_;
 };
...
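
For orientation, `RunLoop` (the renamed `Run`) is the VM's fetch-decode-execute loop. A hypothetical, heavily simplified Python sketch of that pattern (the real loop is the C++ `switch` over `Opcode` in the runtime source below; all names here are illustrative):

```python
# Hypothetical sketch of a register-VM dispatch loop.
from collections import namedtuple

Instr = namedtuple('Instr', ['op', 'dst', 'src', 'value'])

def run_loop(code, registers):
    pc = 0  # program counter, reset on entry as in RunLoop
    while True:
        instr = code[pc]
        if instr.op == 'LoadConst':
            # The real VM now copies the constant to the device context here.
            registers[instr.dst] = instr.value
            pc += 1
        elif instr.op == 'Ret':
            return registers[instr.src]

code = [Instr('LoadConst', 0, None, 42), Instr('Ret', None, 0, None)]
print(run_loop(code, {}))  # -> 42
```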
@@ -212,7 +212,6 @@ class VMExecutor(Executor):
         self.vm.init(ctx)
     def _make_executor(self, expr=None):
-        assert expr is None
         main = self.mod["main"]
         def _vm_wrapper(*args, **kwargs):
...
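
With the `assert expr is None` gone, the VM can be driven through the generic executor interface, which is exactly what the new `check_result` test helper does. A short sketch of that entry point (grounded in the calls this diff makes):

```python
# Sketch: drive the VM via relay.create_executor, as check_result does.
import numpy as np
import tvm
from tvm import relay

x = relay.var('x', shape=(2,))
mod = relay.Module()
mod["main"] = relay.Function([x], x + x)

ex = relay.create_executor('vm', ctx=tvm.cpu(), target='llvm', mod=mod)
f = ex.evaluate()  # callable wrapping the compiled "main"
print(f(np.array([1.0, 2.0], dtype='float32')).asnumpy())
```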
@@ -27,6 +27,7 @@
 #include <tvm/logging.h>
 #include <tvm/runtime/vm.h>
+#include <algorithm>
 #include <chrono>
 #include <iostream>
 #include <sstream>
@@ -569,20 +570,36 @@ std::ostream& operator<<(std::ostream& os, const VMFunction& vm_func) {
   return os;
 }
+Object CopyTo(Object src, const DLContext& ctx) {
+  if (src->tag == ObjectTag::kTensor) {
+    auto tensor = ToNDArray(src);
+    if (tensor->ctx.device_type != ctx.device_type) {
+      auto copy = tensor.CopyTo(ctx);
+      return Object::Tensor(copy);
+    } else {
+      return src;
+    }
+  } else {
+    return src;
+  }
+}
 PackedFunc VirtualMachine::GetFunction(const std::string& name,
                                        const std::shared_ptr<ModuleNode>& sptr_to_self) {
   if (name == "invoke") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       std::string func_name = args[0];
+      auto ctx = this->GetParamsContext();
       std::vector<Object> func_args;
       for (int i = 1; i < args.size(); ++i) {
-        Object obj = args[i];
+        Object obj = CopyTo(args[i], ctx);
         func_args.push_back(obj);
       }
       auto it = std::find_if(functions.begin(), functions.end(),
                              [func_name](const VMFunction& func) {
                                return func.name == func_name;
                              });
       CHECK(it != functions.end()) << "Cannot find function " << func_name << "\n";
       CHECK_EQ(func_args.size() + params_.size(), it->params.size())
           << "The number of provided parameters doesn't match the number of arguments"
@@ -621,6 +638,18 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
   }
 }
+TVMContext VirtualMachine::GetParamsContext() const {
+  // Use the fallback device if no device index is available.
+  int fallback_device_type = static_cast<int>(ctxs[0].device_type);
+  // TODO(wweic): For heterogeneous execution, get device information from byte code.
+  const auto& cit =
+      std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) {
+        return fallback_device_type == static_cast<int>(c.device_type);
+      });
+  return (cit == ctxs.end() ? ctxs[0] : *cit);
+}
 void VirtualMachine::LoadParams(const std::string& params) {
   dmlc::MemoryStringStream mss(const_cast<std::string*>(&params));
   dmlc::Stream* strm = &mss;
@@ -637,11 +666,13 @@ void VirtualMachine::LoadParams(const std::string& params) {
   size_t size = static_cast<size_t>(sz);
   CHECK(size == names.size()) << "Invalid parameter file";
+  auto ctx = GetParamsContext();
   for (size_t i = 0; i < size; i++) {
     NDArray arr;
     CHECK(arr.Load(strm)) << "Invalid parameter file";
     runtime::Object obj = runtime::Object::Tensor(arr);
-    params_.emplace(std::make_pair(names[i], obj));
+    auto copy = CopyTo(obj, ctx);
+    params_.emplace(std::make_pair(names[i], copy));
   }
 }
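
Net effect of `GetParamsContext` plus this `LoadParams` change: every deserialized parameter ends up on the VM's parameter context rather than wherever `NDArray::Load` placed it. A hedged Python analogue of the select-then-copy flow (illustrative only; the real logic is the C++ above):

```python
# Illustrative analogue: pick the params context by fallback device type,
# then make sure a loaded array lives on it.
import numpy as np
import tvm

def get_params_context(ctxs):
    fallback = ctxs[0].device_type  # fallback device type
    matches = [c for c in ctxs if c.device_type == fallback]
    return matches[0] if matches else ctxs[0]

def load_param(arr, ctxs):
    ctx = get_params_context(ctxs)
    return arr.copyto(ctx) if arr.ctx.device_type != ctx.device_type else arr

p = load_param(tvm.nd.array(np.ones((4,), dtype='float32')), [tvm.cpu(0)])
```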
@@ -678,7 +709,7 @@ Object VirtualMachine::Invoke(const VMFunction& func, const std::vector<Object>&
   DLOG(INFO) << "Executing Function: " << std::endl << func;
   InvokeGlobal(func, args);
-  Run();
+  RunLoop();
   auto alloc = MemoryManager::Global()->GetAllocator(ctxs[0]);
   DLOG(INFO) << "Memory used: " << alloc->UsedMemory() << " B";
   return return_register;
@@ -762,7 +793,7 @@ inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
   return result;
 }
-void VirtualMachine::Run() {
+void VirtualMachine::RunLoop() {
   CHECK(this->code);
   this->pc = 0;
   Index frame_start = frames.size();
@@ -786,7 +817,9 @@ void VirtualMachine::Run() {
       throw std::runtime_error("VM encountered fatal error");
     }
     case Opcode::LoadConst: {
-      WriteRegister(instr.dst, this->constants[instr.const_index]);
+      auto constant_obj = this->constants[instr.const_index];
+      auto device_obj = CopyTo(constant_obj, ctxs[0]);
+      WriteRegister(instr.dst, device_obj);
       pc++;
       goto main_loop;
     }
...
@@ -21,8 +21,28 @@ import tvm
 import numpy as np
 from tvm import relay
 from tvm.relay.scope_builder import ScopeBuilder
+from tvm.relay.testing.config import ctx_list
 from tvm.relay.prelude import Prelude
+def check_result(args, expected_result, mod=None):
+    """
+    Check that evaluating the `main` function of `mod` applied to `args`
+    produces `expected_result` on the Relay VM.
+
+    Parameters
+    ----------
+    args: list of Expr
+        The arguments to supply to the main function.
+
+    expected_result:
+        The expected result of running the expression.
+
+    mod: relay.Module
+        The module whose `main` function is evaluated.
+    """
+    for target, ctx in ctx_list():
+        vm = relay.create_executor('vm', ctx=ctx, target=target, mod=mod)
+        rts_result = vm.evaluate()(*args)
+        tvm.testing.assert_allclose(expected_result, rts_result.asnumpy())
 def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
     if isinstance(f, relay.Expr):
         mod = relay.Module()
@@ -30,14 +50,14 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
         compiler = relay.vm.VMCompiler()
         vm = compiler.compile(mod, target)
         vm.init(tvm.cpu())
-        return vm.run(*args)
+        return vm.invoke("main", *args)
     else:
         assert isinstance(f, relay.Module), "expected expression or module"
         mod = f
         compiler = relay.vm.VMCompiler()
         vm = compiler.compile(mod, target)
         vm.init(tvm.cpu())
-        ret = vm.run(*args)
+        ret = vm.invoke("main", *args)
         return ret
 def vmobj_to_list(o):
@@ -76,15 +96,17 @@ def test_id():
     x = relay.var('x', shape=(10, 10), dtype='float64')
     f = relay.Function([x], x)
     x_data = np.random.rand(10, 10).astype('float64')
-    res = veval(f, x_data)
-    tvm.testing.assert_allclose(res.asnumpy(), x_data)
+    mod = relay.Module()
+    mod["main"] = f
+    check_result([x_data], x_data, mod=mod)
 def test_op():
     x = relay.var('x', shape=(10, 10))
     f = relay.Function([x], x + x)
     x_data = np.random.rand(10, 10).astype('float32')
-    res = veval(f, x_data)
-    tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data)
+    mod = relay.Module()
+    mod["main"] = f
+    check_result([x_data], 2 * x_data, mod=mod)
 def any(x):
     x = relay.op.nn.batch_flatten(x)
@@ -92,20 +114,19 @@ def any(x):
 def test_cond():
     x = relay.var('x', shape=(10, 10))
-    y = relay.var('x', shape=(10, 10))
+    y = relay.var('y', shape=(10, 10))
     # f = relay.Function([x, y], relay.op.equal(x, y))
     f = relay.Function([x, y], any(relay.op.equal(x, y)))
     x_data = np.random.rand(10, 10).astype('float32')
     y_data = np.random.rand(10, 10).astype('float32')
+    mod = relay.Module()
+    mod["main"] = f
     # same
-    res = veval(f, x_data, x_data)
-    np.testing.assert_allclose(res.asnumpy(), True)
+    check_result([x_data, x_data], True, mod=mod)
     # diff
-    res = veval(f, x_data, y_data)
-    tvm.testing.assert_allclose(res.asnumpy(), False)
+    check_result([x_data, y_data], False, mod=mod)
 def test_simple_if():
     x = relay.var('x', shape=(10, 10))
@@ -115,13 +136,13 @@ def test_simple_if():
     x_data = np.random.rand(10, 10).astype('float32')
     y_data = np.random.rand(10, 10).astype('float32')
+    mod = relay.Module()
+    mod["main"] = f
     # same
-    res = veval(f, x_data, x_data)
-    tvm.testing.assert_allclose(res.asnumpy(), x_data)
+    check_result([x_data, x_data], x_data, mod=mod)
     # diff
-    res = veval(f, x_data, y_data)
-    tvm.testing.assert_allclose(res.asnumpy(), y_data)
+    check_result([x_data, y_data], y_data, mod=mod)
 def test_simple_call():
     mod = relay.module.Module({})
@@ -132,10 +153,9 @@ def test_simple_call():
     func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
     mod[sum_up] = func
     i_data = np.array(0, dtype='int32')
-    iarg = relay.var('i', shape=[], dtype='int32')
+    iarg = relay.var('iarg', shape=[], dtype='int32')
     mod["main"] = relay.Function([iarg], sum_up(iarg))
-    result = veval(mod, i_data)
-    tvm.testing.assert_allclose(result.asnumpy(), i_data)
+    check_result([i_data], i_data, mod=mod)
 def test_count_loop():
     mod = relay.module.Module({})
@@ -155,6 +175,7 @@ def test_count_loop():
     mod["main"] = relay.Function([iarg], sum_up(iarg))
     result = veval(mod, i_data)
     tvm.testing.assert_allclose(result.asnumpy(), i_data)
+    check_result([i_data], i_data, mod=mod)
 def test_sum_loop():
     mod = relay.module.Module({})
@@ -176,8 +197,7 @@ def test_sum_loop():
     iarg = relay.var('i', shape=[], dtype='int32')
     aarg = relay.var('accum', shape=[], dtype='int32')
     mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg))
-    result = veval(mod, i_data, accum_data)
-    tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1)))
+    check_result([i_data, accum_data], sum(range(1, loop_bound + 1)), mod=mod)
 def test_tuple_fst():
     ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))])
@@ -185,8 +205,9 @@ def test_tuple_fst():
     f = relay.Function([tup], relay.TupleGetItem(tup, 0))
     i_data = np.random.rand(41).astype('float32')
     j_data = np.random.rand(10).astype('float32')
-    result = veval(f, (i_data, j_data))
-    tvm.testing.assert_allclose(result.asnumpy(), i_data)
+    mod = relay.Module()
+    mod["main"] = f
+    check_result([(i_data, j_data)], i_data, mod=mod)
 def test_tuple_second():
     ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))])
@@ -194,8 +215,9 @@ def test_tuple_second():
     f = relay.Function([tup], relay.TupleGetItem(tup, 1))
     i_data = np.random.rand(41).astype('float32')
     j_data = np.random.rand(10).astype('float32')
-    result = veval(f, (i_data, j_data))
-    tvm.testing.assert_allclose(result.asnumpy(), j_data)
+    mod = relay.Module()
+    mod["main"] = f
+    check_result([(i_data, j_data)], j_data, mod=mod)
 def test_list_constructor():
     mod = relay.Module()
@@ -233,8 +255,9 @@ def test_let_tensor():
     f = relay.Function([x], body)
     x_data = np.random.rand(*shape).astype('float32')
-    result = veval(f, x_data)
-    tvm.testing.assert_allclose(result.asnumpy(), x_data + 42.0)
+    mod = relay.Module()
+    mod["main"] = f
+    check_result([x_data], x_data + 42.0, mod=mod)
 def test_let_scalar():
     sb = relay.ScopeBuilder()
@@ -248,8 +271,9 @@ def test_let_scalar():
     f = relay.Function([x], body)
     x_data = np.array(np.random.rand()).astype('float32')
-    result = veval(f, x_data)
-    tvm.testing.assert_allclose(result.asnumpy(), x_data + 42.0)
+    mod = relay.Module()
+    mod["main"] = f
+    check_result([x_data], x_data + 42.0, mod=mod)
 def test_compose():
     mod = relay.Module()
@@ -281,8 +305,7 @@ def test_compose():
     mod["main"] = f
     x_data = np.array(np.random.rand()).astype('float32')
-    result = veval(mod, x_data)
+    result = veval(mod, [x_data])
     tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0)
 def test_list_hd():
@@ -504,6 +527,54 @@ def test_closure():
     res = veval(main)
     tvm.testing.assert_allclose(res.asnumpy(), 3.0)
+def test_add_op_scalar():
+    """
+    test_add_op_scalar:
+    fn (x, y) {
+        return x + y;
+    }
+    """
+    mod = relay.Module()
+    x = relay.var('x', shape=())
+    y = relay.var('y', shape=())
+    func = relay.Function([x, y], relay.op.add(x, y))
+    x_data = np.array(10.0, dtype='float32')
+    y_data = np.array(1.0, dtype='float32')
+    mod["main"] = func
+    check_result([x_data, y_data], x_data + y_data, mod=mod)
+def test_add_op_tensor():
+    """
+    test_add_op_tensor:
+    fn (x, y) {
+        return x + y;
+    }
+    """
+    mod = relay.Module()
+    x = relay.var('x', shape=(10, 5))
+    y = relay.var('y', shape=(10, 5))
+    func = relay.Function([x, y], relay.op.add(x, y))
+    x_data = np.random.rand(10, 5).astype('float32')
+    y_data = np.random.rand(10, 5).astype('float32')
+    mod["main"] = func
+    check_result([x_data, y_data], x_data + y_data, mod=mod)
+def test_add_op_broadcast():
+    """
+    test_add_op_broadcast:
+    fn (x, y) {
+        return x + y;
+    }
+    """
+    mod = relay.Module()
+    x = relay.var('x', shape=(10, 5))
+    y = relay.var('y', shape=(1, 5))
+    func = relay.Function([x, y], relay.op.add(x, y))
+    x_data = np.random.rand(10, 5).astype('float32')
+    y_data = np.random.rand(1, 5).astype('float32')
+    mod["main"] = func
+    check_result([x_data, y_data], x_data + y_data, mod=mod)
 if __name__ == "__main__":
     test_id()
     test_op()
@@ -534,3 +605,6 @@ if __name__ == "__main__":
     test_list_sum()
     test_list_filter()
     test_closure()
+    test_add_op_scalar()
+    test_add_op_tensor()
+    test_add_op_broadcast()