Commit e69bd128 by Zhi Committed by Yao Wang

[relay][tensor_array] test tensor_array in vm (#4608)

* [relay] test tensor_array in vm

* add tensor_array scatter test
parent 2b916975
......@@ -26,6 +26,7 @@ import tvm
from tvm import autotvm
from tvm.relay import expr as _expr
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm._ffi import base as _base
from . import _vm
from . import vmobj as _obj
from .interpreter import Executor
......@@ -34,7 +35,9 @@ Tensor = _obj.Tensor
ADT = _obj.ADT
def _convert(arg, cargs):
if isinstance(arg, _obj.Object):
if isinstance(arg, _expr.Constant):
cargs.append(_obj.Tensor(arg.data))
elif isinstance(arg, _obj.Object):
cargs.append(arg)
elif isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
cargs.append(_obj.Tensor(arg))
......@@ -43,8 +46,12 @@ def _convert(arg, cargs):
for field in arg:
_convert(field, field_args)
cargs.append(_obj.tuple_object(field_args))
elif isinstance(arg, (_base.numeric_types, bool)):
dtype = "int32" if isinstance(arg, (int, bool)) else "float32"
value = _obj.Tensor(np.array(arg, dtype=dtype))
cargs.append(value)
else:
raise "Unsupported type: %s" % (type(arg))
raise TypeError("Unsupported type: %s" % (type(arg)))
def convert(args):
......
......@@ -33,11 +33,11 @@ class TensorArrayOps(object):
self.dtype = dtype
def get_name(self, canonical):
"""Get name corresponding to the caninical name"""
"""Get name corresponding to the canonical name"""
return self.prelude.get_name(canonical, self.dtype)
def get_var(self, canonical):
"""Get var corresponding to the caninical name"""
"""Get var corresponding to the canonical name"""
return self.prelude.get_var(canonical, self.dtype)
def define_tensor_adt(self):
......
......@@ -31,17 +31,13 @@
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <tvm/relay/attrs/memory.h>
#include <topi/tags.h>
#include <algorithm>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "../../../runtime/vm/naive_allocator.h"
#include "../../backend/compile_engine.h"
#include "../../pass/pass_util.h"
#include "../../op/op_common.h"
......@@ -73,8 +69,6 @@ using namespace relay::transform;
// (@jroesch): VM passes, eventually declare as passes.
bool IsClosure(const Function& func);
void InstructionPrint(std::ostream& os, const Instruction& instr);
// Represent a runtime object that's going to be matched by pattern match expressions
struct MatchValue {
virtual ~MatchValue() {}
......@@ -156,12 +150,10 @@ TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data,
if (pattern.as<PatternWildcardNode>()) {
// We ignore wildcard binding since it's not producing new vars
return then_branch;
} else if (pattern.as<PatternVarNode>()) {
auto pat = pattern.as<PatternVarNode>();
auto pattern = GetRef<PatternVar>(pat);
auto cond = std::make_shared<VarBinding>(pattern->var, data);
} else if (const auto* pvn = pattern.as<PatternVarNode>()) {
auto cond = std::make_shared<VarBinding>(pvn->var, data);
return TreeBranchNode::Make(cond, then_branch, else_branch);
} else if (auto pcn = pattern.as<PatternConstructorNode>()) {
} else if (const auto* pcn = pattern.as<PatternConstructorNode>()) {
auto tag = pcn->constructor->tag;
size_t field_index = 0;
......@@ -173,13 +165,12 @@ TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data,
auto cond = std::make_shared<TagCompare>(data, tag);
return TreeBranchNode::Make(cond, then_branch, else_branch);
} else {
auto pt = pattern.as<PatternTupleNode>();
CHECK(pt) << "unhandled case: " << pattern;
const auto* pt = pattern.as<PatternTupleNode>();
CHECK(pt) << "unhandled case: " << AsText(pattern, false);
size_t field_index = 0;
for (auto& p : pt->patterns) {
auto d = std::make_shared<AccessField>(data, field_index);
auto d = std::make_shared<AccessField>(data, field_index++);
then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
field_index++;
}
return then_branch;
}
......@@ -633,7 +624,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
// and emit a call to allocate the data structure.
auto constructor = GetRef<Constructor>(constructor_node);
Emit(Instruction::AllocADT(constructor->tag, call_node->args.size(), args_registers,
NewRegister()));
NewRegister()));
} else if (auto var_node = op.as<VarNode>()) {
// If we are calling a variable, it must be the case that it is a closure so we
// emit invoke closure here.
......@@ -675,16 +666,13 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
}
void CompileTreeNode(TreeObjectPtr tree) {
if (std::dynamic_pointer_cast<TreeLeafNode>(tree)) {
auto node = std::dynamic_pointer_cast<TreeLeafNode>(tree);
if (auto node = std::dynamic_pointer_cast<TreeLeafNode>(tree)) {
VisitExpr(node->body);
} else if (std::dynamic_pointer_cast<TreeLeafFatalNode>(tree)) {
Emit(Instruction::Fatal());
} else if (std::dynamic_pointer_cast<TreeBranchNode>(tree)) {
auto node = std::dynamic_pointer_cast<TreeBranchNode>(tree);
if (std::dynamic_pointer_cast<TagCompare>(node->cond)) {
} else if (auto node = std::dynamic_pointer_cast<TreeBranchNode>(tree)) {
if (auto cond = std::dynamic_pointer_cast<TagCompare>(node->cond)) {
// For tag comparison, generate branches
auto cond = std::dynamic_pointer_cast<TagCompare>(node->cond);
auto r = CompileMatchValue(cond->obj);
Emit(Instruction::GetTag(r, NewRegister()));
auto operand1 = last_register_;
......@@ -707,8 +695,8 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
instructions_[goto_offset].pc_offset = else_offset - goto_offset + 1;
} else {
// For other non-branch conditions, move to then_branch directly
auto cond = std::dynamic_pointer_cast<VarBinding>(node->cond);
var_register_map_[cond->var] = CompileMatchValue(cond->val);
auto var_bind = std::dynamic_pointer_cast<VarBinding>(node->cond);
var_register_map_[var_bind->var] = CompileMatchValue(var_bind->val);
CompileTreeNode(node->then_branch);
}
}
......
......@@ -583,9 +583,9 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
break;
}
case Opcode::AllocStorage: {
os << "alloc_storage " <<
instr.dst << " " <<
instr.alloc_storage.allocation_size << " " <<
os << "alloc_storage $" <<
instr.dst << " $" <<
instr.alloc_storage.allocation_size << " $" <<
instr.alloc_storage.alignment << " " <<
TVMType2String(instr.alloc_storage.dtype_hint);
break;
......@@ -771,12 +771,14 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
for (size_t fi = 0; fi < dt_cell->size; ++fi) {
auto obj = (*dt_cell)[fi];
const auto* tensor = obj.as<TensorObj>();
CHECK(tensor != nullptr);
CHECK(tensor != nullptr) << "Expect tensor object, but received: "
<< obj->GetTypeKey();
setter(idx++, tensor->data);
}
} else {
const auto* tensor = args[i].as<TensorObj>();
CHECK(tensor != nullptr);
CHECK(tensor != nullptr) << "Expect tensor object, but received: "
<< args[i]->GetTypeKey();
setter(idx++, tensor->data);
}
}
......@@ -823,7 +825,8 @@ inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
int32_t result;
const auto& obj = ReadRegister(r);
const auto* tensor = obj.as<TensorObj>();
CHECK(tensor != nullptr);
CHECK(tensor != nullptr) << "Expect tensor object, but received: "
<< obj->GetTypeKey();
NDArray array = tensor->data.CopyTo({kDLCPU, 0});
if (array->dtype.bits <= 8) {
......@@ -984,7 +987,8 @@ void VirtualMachine::RunLoop() {
cpu_ctx.device_id = 0;
auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register);
const auto* tensor = shape_tensor_obj.as<TensorObj>();
CHECK(tensor != nullptr);
CHECK(tensor != nullptr) << "Expect tensor object, but received: "
<< shape_tensor_obj->GetTypeKey();
NDArray shape_tensor = tensor->data.CopyTo(cpu_ctx);
const DLTensor* dl_tensor = shape_tensor.operator->();
CHECK_EQ(dl_tensor->dtype.code, 0u);
......
......@@ -114,6 +114,40 @@ def tree_to_dict(t):
return ret
def vmobj_to_list(o, dtype="float32"):
    """Flatten an executor result object into a plain Python list.

    Parameters
    ----------
    o : object
        A ``vmobj.Tensor``, ``vmobj.ADT``, ``interpreter.TensorValue`` or
        ``interpreter.ConstructorValue`` from ``tvm.relay.backend``.
    dtype : str
        Forwarded to ``p.get_var("tensor_nil", dtype=dtype)`` when an empty
        ADT is met, and passed down to every recursive call.

    Returns
    -------
    list
        The flattened contents of ``o``.

    Raises
    ------
    RuntimeError
        If ``o`` (or its constructor name) is not one of the handled kinds.

    Notes
    -----
    Reads the name ``p`` from the enclosing module scope
    (presumably a ``Prelude`` instance -- TODO confirm against the test file).
    """
    backend = tvm.relay.backend

    if isinstance(o, backend.vmobj.Tensor):
        return [o.asnumpy().tolist()]

    if isinstance(o, backend.interpreter.TensorValue):
        return [o.asnumpy()]

    if isinstance(o, backend.vmobj.ADT):
        if len(o) == 0:
            # An empty ADT is reported as [0] only when its tag matches the
            # tag of the `tensor_nil` constructor for this dtype.
            nil_ctor = p.get_var("tensor_nil", dtype=dtype)
            return [0] if nil_ctor.tag == o.tag else []
        flattened = []
        for field in o:
            flattened.extend(vmobj_to_list(field, dtype))
        return flattened

    if isinstance(o, backend.interpreter.ConstructorValue):
        ctor_name = o.constructor.name_hint
        if ctor_name == 'Cons':
            # The tail is converted before the head, then appended to it.
            tail = vmobj_to_list(o.fields[1], dtype)
            head = vmobj_to_list(o.fields[0], dtype)
            head.extend(tail)
            return head
        if ctor_name == 'Nil':
            return []
        # 'tensor_nil' has to be tested before the broader 'tensor' match.
        if 'tensor_nil' in ctor_name:
            return [0]
        if 'tensor' in ctor_name:
            return [o.fields[0].asnumpy()]
        raise RuntimeError("Unknown object type: %s" % ctor_name)

    raise RuntimeError("Unknown object type: %s" % type(o))
# Convert a scalar-valued relay tensor value into a native Python number.
def get_scalar(tv):
    """Return ``tv.asnumpy().item()``: the single element of ``tv`` as a Python scalar."""
    as_array = tv.asnumpy()
    return as_array.item()
......@@ -685,6 +719,16 @@ def test_iterate():
res = intrp.evaluate(relay.Function([], expr)())
assert count(res) == 12
def check_tensor_array(ta_mod, ref_res, *args, dtype="float32",
                       ta_ctx=tvm.cpu(), target="llvm", rtol=1e-5):
    """Evaluate ``ta_mod`` on both executors and compare with ``ref_res``.

    Parameters
    ----------
    ta_mod : module passed as ``mod`` to ``relay.create_executor``.
    ref_res : reference result handed to ``tvm.testing.assert_allclose``.
    *args : positional inputs forwarded to the evaluated entry function.
    dtype : str
        Passed to ``vmobj_to_list`` when decoding the result.
    ta_ctx : context passed as ``ctx`` to ``relay.create_executor``.
    target : str
        Compilation target passed to ``relay.create_executor``.
    rtol : float
        Used for both the relative and the absolute tolerance.
    """
    for executor_kind in ("debug", "vm"):
        executor = relay.create_executor(executor_kind, mod=ta_mod,
                                         ctx=ta_ctx, target=target)
        raw_result = executor.evaluate()(*args)
        decoded = vmobj_to_list(raw_result, dtype)
        # The same tolerance value is deliberately reused for atol.
        tvm.testing.assert_allclose(ref_res, decoded, rtol=rtol, atol=rtol)
def test_tensor_expand_dims():
def run(dtype):
x = relay.var('x')
......@@ -693,16 +737,13 @@ def test_tensor_expand_dims():
expand_dims_func = p.get_var('tensor_expand_dims', dtype)
tensor1 = p.get_var('tensor1', dtype)
mod["main"] = relay.Function([x], expand_dims_func(tensor1(x)))
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
x_np = np.random.uniform(size=(1,)).astype(dtype)
result = ex.evaluate()(x_np)
got = vmobj_to_list(result)
expected = [np.expand_dims(x_np, axis=0)]
tvm.testing.assert_allclose(expected, got)
x_np = np.random.uniform(size=(1,)).astype(dtype)
expected = [np.expand_dims(x_np, axis=0)]
check_tensor_array(mod, expected, x_np)
run('float32')
run('int32')
def test_tensor_array_constructor():
def run(dtype):
x = relay.var('x')
......@@ -710,15 +751,12 @@ def test_tensor_array_constructor():
p = Prelude(mod)
tensor_array = p.get_var('tensor_array', dtype)
mod["main"] = relay.Function([x], tensor_array(x))
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(5)
got = vmobj_to_list(result)
expected = np.array([0, 0, 0, 0, 0])
tvm.testing.assert_allclose(expected, got)
expected = np.array([0, 0, 0, 0, 0])
check_tensor_array(mod, expected, 5, dtype=dtype)
run('float32')
run('int32')
def test_tensor_array_read():
def run(dtype):
mod = relay.Module()
......@@ -728,41 +766,32 @@ def test_tensor_array_read():
read_func = p.get_var('tensor_array_read', dtype)
tensor_array = p.get_var('tensor_array', dtype)
mod["main"] = relay.Function([l, i], read_func(tensor_array(l), i))
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(10, 5)
got = vmobj_to_list(result)
expected = [0]
tvm.testing.assert_allclose(expected, got)
expected = [0]
check_tensor_array(mod, expected, *(1, 0), dtype=dtype)
check_tensor_array(mod, expected, *(5, 1), dtype=dtype)
run('float32')
run('int32')
def test_tensor_array_write():
    """Write two values into a size-2 tensor array and check the resulting contents."""
    def run(dtype):
        mod = relay.Module()
        p = Prelude(mod)

        # Prelude entries specialised for this dtype.
        tensor_array = p.get_var('tensor_array', dtype)
        write_func = p.get_var('tensor_array_write', dtype)
        tensor1 = p.get_var('tensor1', dtype)

        v1 = relay.var('v1')
        v2 = relay.var('v2')

        # Start from a 2-slot array and fill slot 0, then slot 1.
        empty_array = tensor_array(relay.const(2))
        after_first = write_func(empty_array, relay.const(0), tensor1(v1))
        after_second = write_func(after_first, relay.const(1), tensor1(v2))
        mod["main"] = relay.Function([v1, v2], after_second)

        check_tensor_array(mod, [3, 7], 3, 7, dtype=dtype)

    for element_type in ('float32', 'int32'):
        run(element_type)
def vmobj_to_list(o, dtype="float32"):
    """Flatten an executor result object into a plain Python list.

    This definition accepts the optional ``dtype`` argument so that callers
    passing ``vmobj_to_list(result, dtype)`` keep working; single-argument
    calls behave as before apart from the empty-ADT case described below.

    Parameters
    ----------
    o : object
        A ``vmobj.Tensor``, ``vmobj.ADT``, ``interpreter.TensorValue`` or
        ``interpreter.ConstructorValue`` from ``tvm.relay.backend``.
    dtype : str, optional
        Used to look up the ``tensor_nil`` constructor when an empty ADT is
        met, and forwarded to recursive calls. Defaults to ``"float32"``.

    Returns
    -------
    list
        The flattened contents of ``o``.

    Raises
    ------
    RuntimeError
        If ``o`` (or its constructor name) is not one of the handled kinds.

    Notes
    -----
    Reads the name ``p`` from the enclosing module scope
    (presumably a ``Prelude`` instance -- TODO confirm against the test file).
    """
    if isinstance(o, tvm.relay.backend.vmobj.Tensor):
        return [o.asnumpy().tolist()]
    elif isinstance(o, tvm.relay.backend.interpreter.TensorValue):
        return [o.asnumpy()]
    elif isinstance(o, tvm.relay.backend.vmobj.ADT):
        if len(o) == 0:
            # An empty ADT whose tag matches `tensor_nil` is reported as [0],
            # mirroring the 'tensor_nil' ConstructorValue branch below.
            tensor_nil = p.get_var("tensor_nil", dtype=dtype)
            if tensor_nil.tag == o.tag:
                return [0]
            return []
        result = []
        for f in o:
            result.extend(vmobj_to_list(f, dtype))
        return result
    elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
        if o.constructor.name_hint == 'Cons':
            tl = vmobj_to_list(o.fields[1], dtype)
            hd = vmobj_to_list(o.fields[0], dtype)
            hd.extend(tl)
            return hd
        elif o.constructor.name_hint == 'Nil':
            return []
        # 'tensor_nil' has to be tested before the broader 'tensor' match.
        elif 'tensor_nil' in o.constructor.name_hint:
            return [0]
        elif 'tensor' in o.constructor.name_hint:
            return [o.fields[0].asnumpy()]
        else:
            raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint)
    else:
        raise RuntimeError("Unknown object type: %s" % type(o))
def test_tensor_array_stack():
def run(dtype):
......@@ -772,24 +801,20 @@ def test_tensor_array_stack():
tensor1 = p.get_var('tensor1', dtype)
write = p.get_var('tensor_array_write', dtype)
stack = p.get_var('tensor_array_stack', dtype)
l = relay.var('l')
v = relay.var('v')
init_tensor_array = tensor_array(relay.const(3))
tensor_array1 = write(init_tensor_array, relay.const(0), tensor1(v))
tensor_array2 = write(tensor_array1, relay.const(1), tensor1(v))
tensor_array3 = write(tensor_array2, relay.const(2), tensor1(v))
tensor_array2 = write(tensor_array1, relay.const(1), tensor1(v))
tensor_array3 = write(tensor_array2, relay.const(2), tensor1(v))
tensor_array4 = stack(tensor_array3)
mod["main"] = relay.Function([v], tensor_array4)
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
t = np.random.uniform(size=(1,)).astype(dtype)
result = ex.evaluate()(t)
res = vmobj_to_list(result)
expected = [np.stack([t, t, t])]
tvm.testing.assert_allclose(expected, res)
t = np.random.uniform(size=(1,)).astype(dtype)
expected = [np.stack([t, t, t])]
check_tensor_array(mod, expected, t, dtype=dtype)
run('float32')
run('int32')
def test_tensor_array_unstack():
def run(dtype):
mod = relay.Module()
......@@ -797,15 +822,12 @@ def test_tensor_array_unstack():
unstack_tensor1 = p.get_var('tensor_array_unstack_tensor1', dtype)
v = relay.var('v')
mod["main"] = relay.Function([v], unstack_tensor1(v))
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
t = np.random.uniform(size=(1,)).astype(dtype)
result = ex.evaluate()(t)
res = vmobj_to_list(result)
tvm.testing.assert_allclose(t, res)
t = np.random.uniform(size=(1,)).astype(dtype)
check_tensor_array(mod, t, t, dtype=dtype)
run('float32')
run('int32')
def test_tensor_take():
def run(dtype):
mod = relay.Module()
......@@ -816,16 +838,106 @@ def test_tensor_take():
lower = relay.var('lower')
upper = relay.var('upper')
mod["main"] = relay.Function([v, lower, upper], take(tensor2(v), lower, upper))
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
t = np.random.uniform(size=(10, 10)).astype(dtype)
result = ex.evaluate()(t, 2, 5)
res = vmobj_to_list(result)
expected = [np.take(t, range(2, 5), axis=0)]
tvm.testing.assert_allclose(expected, res)
v_data = np.random.uniform(size=(10, 10)).astype(dtype)
expected = [np.take(v_data, range(2, 5), axis=0)]
check_tensor_array(mod, expected, *(v_data, 2, 5), dtype=dtype)
expected = [np.take(v_data, range(0, 9), axis=0)]
check_tensor_array(mod, expected, *(v_data, 0, 9), dtype=dtype)
run('float32')
run('int32')
def test_tensor_concatenate():
    """Concatenate two ``tensor1``-wrapped inputs and compare with ``np.concatenate``."""
    def run(dtype):
        mod = relay.Module()
        p = Prelude(mod)
        concat = p.get_var('tensor_concatenate', dtype)
        tensor1 = p.get_var('tensor1', dtype)

        v1 = relay.var('v1')
        v2 = relay.var('v2')
        body = concat(tensor1(v1), tensor1(v2))
        mod["main"] = relay.Function([v1, v2], body)

        # Two random length-5 inputs, drawn in the same order as the arguments.
        v1_data = np.random.uniform(size=(5,)).astype(dtype)
        v2_data = np.random.uniform(size=(5,)).astype(dtype)
        reference = [np.concatenate((v1_data, v2_data))]
        check_tensor_array(mod, reference, v1_data, v2_data, dtype=dtype)

    for element_type in ('float32', 'int32'):
        run(element_type)
def test_tensor_array_concat():
    """Fill a 2-slot tensor array, concatenate it, and compare with ``np.concatenate`` on axis 0."""
    def run(dtype):
        mod = relay.Module()
        p = Prelude(mod)

        tensor_array = p.get_var('tensor_array', dtype)
        write_func = p.get_var('tensor_array_write', dtype)
        concat_func = p.get_var('tensor_array_concat', dtype)
        # Local named after the constructor it actually holds ('tensor2').
        tensor2 = p.get_var('tensor2', dtype)

        v1 = relay.var('v1')
        v2 = relay.var('v2')

        array = tensor_array(relay.const(2))
        for slot, var in enumerate((v1, v2)):
            array = write_func(array, relay.const(slot), tensor2(var))
        mod["main"] = relay.Function([v1, v2], concat_func(array))

        # Inputs differ in their first dimension: (2, 3) and (1, 3).
        v1_data = np.random.uniform(size=(2, 3)).astype(dtype)
        v2_data = np.random.uniform(size=(1, 3)).astype(dtype)
        reference = [np.concatenate((v1_data, v2_data), axis=0)]
        check_tensor_array(mod, reference, v1_data, v2_data, dtype=dtype)

    for element_type in ('float32', 'int32'):
        run(element_type)
def test_tensor_array_scatter():
    """Scatter a 2-element values array into a 3-element tensor array.

    With indices ``[0, 1]`` the first two slots are expected to be replaced by
    the scattered values while the third slot keeps its original content.
    """
    def run(dtype):
        mod = relay.Module()
        p = Prelude(mod)

        # tensor array
        v1 = relay.var('v1')
        v2 = relay.var('v2')
        # Name hint fixed: this was a copy-pasted 'v2', which gave two
        # parameters of "main" the same printed name.
        v3 = relay.var('v3')
        tensor_array = p.get_var('tensor_array', dtype)
        tensor_array1 = tensor_array(relay.const(3))
        write_func = p.get_var('tensor_array_write', dtype)
        scatter_func = p.get_var('tensor_array_scatter', dtype)
        tensor2 = p.get_var('tensor2', dtype)
        tensor_array1 = write_func(tensor_array1, relay.const(0), tensor2(v1))
        tensor_array1 = write_func(tensor_array1, relay.const(1), tensor2(v2))
        tensor_array1 = write_func(tensor_array1, relay.const(2), tensor2(v3))

        # indices array
        index = relay.var('index')

        # values array
        value_0 = relay.var('value_0')
        value_1 = relay.var('value_1')
        values_array = tensor_array(relay.const(2))
        values_array = write_func(values_array, relay.const(0),
                                  tensor2(value_0))
        values_array = write_func(values_array, relay.const(1),
                                  tensor2(value_1))

        # create the scatter function
        tensor_array_scatter = scatter_func(tensor_array1, index, values_array)
        mod["main"] = relay.Function([v1, v2, v3, index, value_0, value_1],
                                     tensor_array_scatter)

        # initialize and check
        v1_data = np.random.uniform(size=(2, 3)).astype(dtype)
        v2_data = np.random.uniform(size=(2, 3)).astype(dtype)
        v3_data = np.random.uniform(size=(2, 3)).astype(dtype)
        index_data = np.array([0, 1], dtype="int32")
        val1_data = np.random.uniform(size=(2, 3)).astype(dtype)
        val2_data = np.random.uniform(size=(2, 3)).astype(dtype)
        # Slots 0 and 1 are overwritten; slot 2 still holds v3_data.
        expected = [val1_data, val2_data, v3_data]
        check_tensor_array(mod, expected, *(v1_data, v2_data, v3_data,
                                            index_data, val1_data,
                                            val2_data), dtype=dtype)

    run('float32')
    run('int32')
if __name__ == "__main__":
test_nat_constructor()
test_double()
......@@ -853,5 +965,10 @@ if __name__ == "__main__":
test_tensor_expand_dims()
test_tensor_array_constructor()
test_tensor_array_read()
test_tensor_array_write()
test_tensor_array_stack()
test_tensor_array_unstack()
test_tensor_take()
test_tensor_concatenate()
test_tensor_array_concat()
test_tensor_array_scatter()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment