Commit e69bd128 authored by Zhi, committed by Yao Wang

[relay][tensor_array] test tensor_array in vm (#4608)

* [relay] test tensor_array in vm

* add tensor_array scatter test
parent 2b916975
python/tvm/relay/backend/vm.py
@@ -26,6 +26,7 @@ import tvm
 from tvm import autotvm
 from tvm.relay import expr as _expr
 from tvm._ffi.runtime_ctypes import TVMByteArray
+from tvm._ffi import base as _base
 from . import _vm
 from . import vmobj as _obj
 from .interpreter import Executor
@@ -34,7 +35,9 @@ Tensor = _obj.Tensor
 ADT = _obj.ADT
 def _convert(arg, cargs):
-    if isinstance(arg, _obj.Object):
+    if isinstance(arg, _expr.Constant):
+        cargs.append(_obj.Tensor(arg.data))
+    elif isinstance(arg, _obj.Object):
         cargs.append(arg)
     elif isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
         cargs.append(_obj.Tensor(arg))
@@ -43,8 +46,12 @@ def _convert(arg, cargs):
         for field in arg:
             _convert(field, field_args)
         cargs.append(_obj.tuple_object(field_args))
+    elif isinstance(arg, (_base.numeric_types, bool)):
+        dtype = "int32" if isinstance(arg, (int, bool)) else "float32"
+        value = _obj.Tensor(np.array(arg, dtype=dtype))
+        cargs.append(value)
     else:
-        raise "Unsupported type: %s" % (type(arg))
+        raise TypeError("Unsupported type: %s" % (type(arg)))
 def convert(args):
...
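For context: with the new _convert branches, the VM executor accepts relay.Constant arguments and plain Python scalars directly; ints and bools are wrapped as 0-d int32 tensors, floats as float32. A minimal sketch of a call that now works, assuming the executor API of this era of TVM (the multiply workload is illustrative, not from the commit):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2, 2), dtype="float32")
    y = relay.var("y", shape=(), dtype="float32")
    mod = relay.Module.from_expr(relay.Function([x, y], x * y))
    ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
    # The Python float 2.0 is wrapped by _convert into a 0-d float32 tensor.
    res = ex.evaluate()(np.ones((2, 2), dtype="float32"), 2.0)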
python/tvm/relay/prelude.py
@@ -33,11 +33,11 @@ class TensorArrayOps(object):
         self.dtype = dtype
     def get_name(self, canonical):
-        """Get name corresponding to the caninical name"""
+        """Get name corresponding to the canonical name"""
         return self.prelude.get_name(canonical, self.dtype)
     def get_var(self, canonical):
-        """Get var corresponding to the caninical name"""
+        """Get var corresponding to the canonical name"""
         return self.prelude.get_var(canonical, self.dtype)
     def define_tensor_adt(self):
...
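get_name and get_var mangle a canonical operator name with the instance's dtype, so each dtype specialization of the tensor-array ops gets its own global definition in the prelude. A hedged usage sketch, assuming the Prelude API of this era:

    from tvm import relay
    from tvm.relay.prelude import Prelude

    mod = relay.Module()
    p = Prelude(mod)
    # "tensor_array" plus dtype "float32" resolves to the global var of the
    # float32-specialized tensor array constructor.
    tensor_array = p.get_var("tensor_array", "float32")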
src/relay/backend/vm/compiler.cc
@@ -31,17 +31,13 @@
 #include <tvm/relay/transform.h>
 #include <tvm/runtime/vm.h>
 #include <tvm/relay/attrs/memory.h>
-#include <topi/tags.h>
-#include <algorithm>
 #include <iostream>
 #include <memory>
-#include <set>
 #include <string>
 #include <tuple>
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
-#include "../../../runtime/vm/naive_allocator.h"
 #include "../../backend/compile_engine.h"
 #include "../../pass/pass_util.h"
 #include "../../op/op_common.h"
@@ -73,8 +69,6 @@ using namespace relay::transform;
 // (@jroesch): VM passes, eventually declare as passes.
 bool IsClosure(const Function& func);
-void InstructionPrint(std::ostream& os, const Instruction& instr);
 // Represent a runtime object that's going to be matched by pattern match expressions
 struct MatchValue {
   virtual ~MatchValue() {}
@@ -156,12 +150,10 @@ TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data,
   if (pattern.as<PatternWildcardNode>()) {
     // We ignore wildcard binding since it's not producing new vars
     return then_branch;
-  } else if (pattern.as<PatternVarNode>()) {
-    auto pat = pattern.as<PatternVarNode>();
-    auto pattern = GetRef<PatternVar>(pat);
-    auto cond = std::make_shared<VarBinding>(pattern->var, data);
+  } else if (const auto* pvn = pattern.as<PatternVarNode>()) {
+    auto cond = std::make_shared<VarBinding>(pvn->var, data);
     return TreeBranchNode::Make(cond, then_branch, else_branch);
-  } else if (auto pcn = pattern.as<PatternConstructorNode>()) {
+  } else if (const auto* pcn = pattern.as<PatternConstructorNode>()) {
     auto tag = pcn->constructor->tag;
     size_t field_index = 0;
@@ -173,13 +165,12 @@ TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data,
     auto cond = std::make_shared<TagCompare>(data, tag);
     return TreeBranchNode::Make(cond, then_branch, else_branch);
   } else {
-    auto pt = pattern.as<PatternTupleNode>();
-    CHECK(pt) << "unhandled case: " << pattern;
+    const auto* pt = pattern.as<PatternTupleNode>();
+    CHECK(pt) << "unhandled case: " << AsText(pattern, false);
     size_t field_index = 0;
     for (auto& p : pt->patterns) {
-      auto d = std::make_shared<AccessField>(data, field_index);
+      auto d = std::make_shared<AccessField>(data, field_index++);
       then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
-      field_index++;
     }
     return then_branch;
   }
@@ -633,7 +624,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
       // and emit a call to allocate the data structure.
       auto constructor = GetRef<Constructor>(constructor_node);
       Emit(Instruction::AllocADT(constructor->tag, call_node->args.size(), args_registers,
-          NewRegister()));
+                                 NewRegister()));
     } else if (auto var_node = op.as<VarNode>()) {
      // If we are calling a variable, it must be the case that it is a closure so we
      // emit invoke closure here.
@@ -675,16 +666,13 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
   }
   void CompileTreeNode(TreeObjectPtr tree) {
-    if (std::dynamic_pointer_cast<TreeLeafNode>(tree)) {
-      auto node = std::dynamic_pointer_cast<TreeLeafNode>(tree);
+    if (auto node = std::dynamic_pointer_cast<TreeLeafNode>(tree)) {
       VisitExpr(node->body);
     } else if (std::dynamic_pointer_cast<TreeLeafFatalNode>(tree)) {
       Emit(Instruction::Fatal());
-    } else if (std::dynamic_pointer_cast<TreeBranchNode>(tree)) {
-      auto node = std::dynamic_pointer_cast<TreeBranchNode>(tree);
-      if (std::dynamic_pointer_cast<TagCompare>(node->cond)) {
+    } else if (auto node = std::dynamic_pointer_cast<TreeBranchNode>(tree)) {
+      if (auto cond = std::dynamic_pointer_cast<TagCompare>(node->cond)) {
         // For Tag comparison, generate branches
-        auto cond = std::dynamic_pointer_cast<TagCompare>(node->cond);
         auto r = CompileMatchValue(cond->obj);
         Emit(Instruction::GetTag(r, NewRegister()));
         auto operand1 = last_register_;
@@ -707,8 +695,8 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
         instructions_[goto_offset].pc_offset = else_offset - goto_offset + 1;
       } else {
         // For other non-branch conditions, move to then_branch directly
-        auto cond = std::dynamic_pointer_cast<VarBinding>(node->cond);
-        var_register_map_[cond->var] = CompileMatchValue(cond->val);
+        auto var_bind = std::dynamic_pointer_cast<VarBinding>(node->cond);
+        var_register_map_[var_bind->var] = CompileMatchValue(var_bind->val);
         CompileTreeNode(node->then_branch);
       }
     }
...
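To see what BuildDecisionTreeFromPattern and CompileTreeNode consume, here is a hedged Python sketch of a Relay match over the prelude list ADT (illustrative, not part of the commit). The PatternConstructor clauses lower to TagCompare branches (GetTag plus a conditional jump), while the PatternVar lowers to a VarBinding that simply maps the matched variable to a register:

    from tvm import relay
    from tvm.relay.adt import Match, Clause, PatternConstructor, PatternVar, PatternWildcard
    from tvm.relay.prelude import Prelude

    mod = relay.Module()
    p = Prelude(mod)
    l = relay.var("l")   # a value of the prelude's list ADT
    h = relay.var("h")
    # match l { cons(h, _) => h; nil => 0 }
    body = Match(l, [
        Clause(PatternConstructor(p.cons, [PatternVar(h), PatternWildcard()]), h),
        Clause(PatternConstructor(p.nil, []), relay.zeros((), "float32")),
    ])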
src/runtime/vm/vm.cc
@@ -583,9 +583,9 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
       break;
     }
     case Opcode::AllocStorage: {
-      os << "alloc_storage " <<
-         instr.dst << " " <<
-         instr.alloc_storage.allocation_size << " " <<
+      os << "alloc_storage $" <<
+         instr.dst << " $" <<
+         instr.alloc_storage.allocation_size << " $" <<
          instr.alloc_storage.alignment << " " <<
          TVMType2String(instr.alloc_storage.dtype_hint);
       break;
@@ -771,12 +771,14 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
       for (size_t fi = 0; fi < dt_cell->size; ++fi) {
         auto obj = (*dt_cell)[fi];
         const auto* tensor = obj.as<TensorObj>();
-        CHECK(tensor != nullptr);
+        CHECK(tensor != nullptr) << "Expect tensor object, but received: "
+                                 << obj->GetTypeKey();
         setter(idx++, tensor->data);
       }
     } else {
       const auto* tensor = args[i].as<TensorObj>();
-      CHECK(tensor != nullptr);
+      CHECK(tensor != nullptr) << "Expect tensor object, but received: "
+                               << args[i]->GetTypeKey();
       setter(idx++, tensor->data);
     }
   }
@@ -823,7 +825,8 @@ inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
   int32_t result;
   const auto& obj = ReadRegister(r);
   const auto* tensor = obj.as<TensorObj>();
-  CHECK(tensor != nullptr);
+  CHECK(tensor != nullptr) << "Expect tensor object, but received: "
+                           << obj->GetTypeKey();
   NDArray array = tensor->data.CopyTo({kDLCPU, 0});
   if (array->dtype.bits <= 8) {
@@ -984,7 +987,8 @@ void VirtualMachine::RunLoop() {
         cpu_ctx.device_id = 0;
         auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register);
         const auto* tensor = shape_tensor_obj.as<TensorObj>();
-        CHECK(tensor != nullptr);
+        CHECK(tensor != nullptr) << "Expect tensor object, but received: "
+                                 << shape_tensor_obj->GetTypeKey();
         NDArray shape_tensor = tensor->data.CopyTo(cpu_ctx);
         const DLTensor* dl_tensor = shape_tensor.operator->();
         CHECK_EQ(dl_tensor->dtype.code, 0u);
...
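Putting the pieces together, the kind of test this commit adds drives tensor-array programs through the VM executor end to end. A hedged sketch under the same API assumptions as above (the actual test code in the repository may differ):

    import tvm
    from tvm import relay
    from tvm.relay.prelude import Prelude

    mod = relay.Module()
    p = Prelude(mod)
    tensor_array = p.get_var("tensor_array", "float32")
    n = relay.var("n", shape=(), dtype="int32")
    mod["main"] = relay.Function([n], tensor_array(n))
    ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
    # The plain Python int is accepted thanks to the new scalar branch in
    # _convert, which wraps it as a 0-d int32 tensor before invoking the VM.
    result = ex.evaluate()(5)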