Commit e69bd128 by Zhi Committed by Yao Wang

[relay][tensor_array] test tensor_array in vm (#4608)

* [relay] test tensor_array in vm

* add tensor_array scatter test
parent 2b916975
......@@ -26,6 +26,7 @@ import tvm
from tvm import autotvm
from tvm.relay import expr as _expr
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm._ffi import base as _base
from . import _vm
from . import vmobj as _obj
from .interpreter import Executor
......@@ -34,7 +35,9 @@ Tensor = _obj.Tensor
ADT = _obj.ADT
def _convert(arg, cargs):
if isinstance(arg, _obj.Object):
if isinstance(arg, _expr.Constant):
cargs.append(_obj.Tensor(arg.data))
elif isinstance(arg, _obj.Object):
cargs.append(arg)
elif isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
cargs.append(_obj.Tensor(arg))
......@@ -43,8 +46,12 @@ def _convert(arg, cargs):
for field in arg:
_convert(field, field_args)
cargs.append(_obj.tuple_object(field_args))
elif isinstance(arg, (_base.numeric_types, bool)):
dtype = "int32" if isinstance(arg, (int, bool)) else "float32"
value = _obj.Tensor(np.array(arg, dtype=dtype))
cargs.append(value)
else:
raise "Unsupported type: %s" % (type(arg))
raise TypeError("Unsupported type: %s" % (type(arg)))
def convert(args):
......
......@@ -33,11 +33,11 @@ class TensorArrayOps(object):
self.dtype = dtype
def get_name(self, canonical):
"""Get name corresponding to the caninical name"""
"""Get name corresponding to the canonical name"""
return self.prelude.get_name(canonical, self.dtype)
def get_var(self, canonical):
"""Get var corresponding to the caninical name"""
"""Get var corresponding to the canonical name"""
return self.prelude.get_var(canonical, self.dtype)
def define_tensor_adt(self):
......
......@@ -31,17 +31,13 @@
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <tvm/relay/attrs/memory.h>
#include <topi/tags.h>
#include <algorithm>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "../../../runtime/vm/naive_allocator.h"
#include "../../backend/compile_engine.h"
#include "../../pass/pass_util.h"
#include "../../op/op_common.h"
......@@ -73,8 +69,6 @@ using namespace relay::transform;
// (@jroesch): VM passes, eventually declare as passes.
bool IsClosure(const Function& func);
void InstructionPrint(std::ostream& os, const Instruction& instr);
// Represent a runtime object that's going to be matched by pattern match expressions
struct MatchValue {
virtual ~MatchValue() {}
......@@ -156,12 +150,10 @@ TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data,
if (pattern.as<PatternWildcardNode>()) {
// We ignore wildcard binding since it's not producing new vars
return then_branch;
} else if (pattern.as<PatternVarNode>()) {
auto pat = pattern.as<PatternVarNode>();
auto pattern = GetRef<PatternVar>(pat);
auto cond = std::make_shared<VarBinding>(pattern->var, data);
} else if (const auto* pvn = pattern.as<PatternVarNode>()) {
auto cond = std::make_shared<VarBinding>(pvn->var, data);
return TreeBranchNode::Make(cond, then_branch, else_branch);
} else if (auto pcn = pattern.as<PatternConstructorNode>()) {
} else if (const auto* pcn = pattern.as<PatternConstructorNode>()) {
auto tag = pcn->constructor->tag;
size_t field_index = 0;
......@@ -173,13 +165,12 @@ TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data,
auto cond = std::make_shared<TagCompare>(data, tag);
return TreeBranchNode::Make(cond, then_branch, else_branch);
} else {
auto pt = pattern.as<PatternTupleNode>();
CHECK(pt) << "unhandled case: " << pattern;
const auto* pt = pattern.as<PatternTupleNode>();
CHECK(pt) << "unhandled case: " << AsText(pattern, false);
size_t field_index = 0;
for (auto& p : pt->patterns) {
auto d = std::make_shared<AccessField>(data, field_index);
auto d = std::make_shared<AccessField>(data, field_index++);
then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
field_index++;
}
return then_branch;
}
......@@ -675,16 +666,13 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
}
void CompileTreeNode(TreeObjectPtr tree) {
if (std::dynamic_pointer_cast<TreeLeafNode>(tree)) {
auto node = std::dynamic_pointer_cast<TreeLeafNode>(tree);
if (auto node = std::dynamic_pointer_cast<TreeLeafNode>(tree)) {
VisitExpr(node->body);
} else if (std::dynamic_pointer_cast<TreeLeafFatalNode>(tree)) {
Emit(Instruction::Fatal());
} else if (std::dynamic_pointer_cast<TreeBranchNode>(tree)) {
auto node = std::dynamic_pointer_cast<TreeBranchNode>(tree);
if (std::dynamic_pointer_cast<TagCompare>(node->cond)) {
} else if (auto node = std::dynamic_pointer_cast<TreeBranchNode>(tree)) {
if (auto cond = std::dynamic_pointer_cast<TagCompare>(node->cond)) {
// For Tag compariton, generate branches
auto cond = std::dynamic_pointer_cast<TagCompare>(node->cond);
auto r = CompileMatchValue(cond->obj);
Emit(Instruction::GetTag(r, NewRegister()));
auto operand1 = last_register_;
......@@ -707,8 +695,8 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
instructions_[goto_offset].pc_offset = else_offset - goto_offset + 1;
} else {
// For other non-branch conditions, move to then_branch directly
auto cond = std::dynamic_pointer_cast<VarBinding>(node->cond);
var_register_map_[cond->var] = CompileMatchValue(cond->val);
auto var_bind = std::dynamic_pointer_cast<VarBinding>(node->cond);
var_register_map_[var_bind->var] = CompileMatchValue(var_bind->val);
CompileTreeNode(node->then_branch);
}
}
......
......@@ -583,9 +583,9 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
break;
}
case Opcode::AllocStorage: {
os << "alloc_storage " <<
instr.dst << " " <<
instr.alloc_storage.allocation_size << " " <<
os << "alloc_storage $" <<
instr.dst << " $" <<
instr.alloc_storage.allocation_size << " $" <<
instr.alloc_storage.alignment << " " <<
TVMType2String(instr.alloc_storage.dtype_hint);
break;
......@@ -771,12 +771,14 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
for (size_t fi = 0; fi < dt_cell->size; ++fi) {
auto obj = (*dt_cell)[fi];
const auto* tensor = obj.as<TensorObj>();
CHECK(tensor != nullptr);
CHECK(tensor != nullptr) << "Expect tensor object, but received: "
<< obj->GetTypeKey();
setter(idx++, tensor->data);
}
} else {
const auto* tensor = args[i].as<TensorObj>();
CHECK(tensor != nullptr);
CHECK(tensor != nullptr) << "Expect tensor object, but received: "
<< args[i]->GetTypeKey();
setter(idx++, tensor->data);
}
}
......@@ -823,7 +825,8 @@ inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
int32_t result;
const auto& obj = ReadRegister(r);
const auto* tensor = obj.as<TensorObj>();
CHECK(tensor != nullptr);
CHECK(tensor != nullptr) << "Expect tensor object, but received: "
<< obj->GetTypeKey();
NDArray array = tensor->data.CopyTo({kDLCPU, 0});
if (array->dtype.bits <= 8) {
......@@ -984,7 +987,8 @@ void VirtualMachine::RunLoop() {
cpu_ctx.device_id = 0;
auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register);
const auto* tensor = shape_tensor_obj.as<TensorObj>();
CHECK(tensor != nullptr);
CHECK(tensor != nullptr) << "Expect tensor object, but received: "
<< shape_tensor_obj->GetTypeKey();
NDArray shape_tensor = tensor->data.CopyTo(cpu_ctx);
const DLTensor* dl_tensor = shape_tensor.operator->();
CHECK_EQ(dl_tensor->dtype.code, 0u);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment