Commit acbf8851 by Zhi Committed by Haichen Shen

[runtime][refactor] Unify vm and interpreter objects (#4693)

* unify vm and interpreter objects

* move closure back vm

* adt/closure back to vm.adt/vm.closure

* closure base
parent 2630ffcb
......@@ -103,11 +103,9 @@ def __lldb_init_module(debugger, _):
"tvm::relay::Span",
"tvm::relay::TempExpr",
"tvm::relay::TensorType",
"tvm::relay::TensorValue",
"tvm::relay::Tuple",
"tvm::relay::TupleGetItem",
"tvm::relay::TupleType",
"tvm::relay::TupleValue",
"tvm::relay::Type",
"tvm::relay::TypeCall",
"tvm::relay::TypeConstraint",
......
......@@ -25,8 +25,8 @@
* Given a Relay module, and a Relay expression it produces a value.
*
* The interpreter's values are a naive representation of the values that
* can be produced by a Relay program and are exposed via tvm::Node's
* system to Python for introspection and debugging.
* can be produced by a Relay program and are exposed via TVM's object
* protocol to Python for introspection and debugging.
*
* The interpreter's intent is to serve as a reference semantics for the Relay IR,
* as well as for debugging and testing.
......@@ -38,6 +38,8 @@
#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/vm.h>
namespace tvm {
namespace relay {
......@@ -64,11 +66,8 @@ namespace relay {
runtime::TypedPackedFunc<ObjectRef(Expr)>
CreateInterpreter(IRModule mod, DLContext context, Target target);
/*! \brief A Relay closure, i.e a scope and a function. */
class Closure;
/*! \brief The container type of Closures. */
class ClosureNode : public Object {
/*! \brief The container type of Closures used by the interpreter. */
class InterpreterClosureObj : public runtime::vm::ClosureObj {
public:
/*! \brief The set of free variables in the closure.
*
......@@ -82,102 +81,69 @@ class ClosureNode : public Object {
*/
Function func;
ClosureNode() {}
InterpreterClosureObj() {}
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("env", &env);
v->Visit("func", &func);
}
TVM_DLL static Closure make(tvm::Map<Var, ObjectRef> env, Function func);
static constexpr const char* _type_key = "relay.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, Object);
static constexpr const char* _type_key = "interpreter.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterClosureObj, runtime::vm::ClosureObj);
};
class Closure : public ObjectRef {
class InterpreterClosure : public runtime::vm::Closure {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureNode);
TVM_DLL InterpreterClosure(tvm::Map<Var, ObjectRef> env, Function func);
TVM_DEFINE_OBJECT_REF_METHODS(InterpreterClosure, runtime::vm::Closure,
InterpreterClosureObj);
};
/*! \brief A Relay Recursive Closure. A closure that has a name. */
class RecClosure;
/*! \brief The container type of RecClosure. */
class RecClosureNode : public Object {
class RecClosureObj : public Object {
public:
/*! \brief The closure. */
Closure clos;
InterpreterClosure clos;
/*! \brief variable the closure bind to. */
Var bind;
RecClosureNode() {}
RecClosureObj() {}
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("clos", &clos);
v->Visit("bind", &bind);
}
TVM_DLL static RecClosure make(Closure clos, Var bind);
static constexpr const char* _type_key = "relay.RecClosure";
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, Object);
static constexpr const char* _type_key = "interpreter.RecClosure";
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureObj, Object);
};
class RecClosure : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureNode);
};
/*! \brief A tuple value. */
class TupleValue;
/*! \brief Tuple (x, ... y). */
struct TupleValueNode : Object {
tvm::Array<ObjectRef> fields;
TupleValueNode() {}
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); }
TVM_DLL static TupleValue make(tvm::Array<ObjectRef> value);
static constexpr const char* _type_key = "relay.TupleValue";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, Object);
};
class TupleValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, ObjectRef, TupleValueNode);
TVM_DLL RecClosure(InterpreterClosure clos, Var bind);
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureObj);
};
/*! \brief A reference value. */
class RefValue;
struct RefValueNode : Object {
struct RefValueObj : Object {
mutable ObjectRef value;
RefValueNode() {}
RefValueObj() {}
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
}
TVM_DLL static RefValue make(ObjectRef val);
static constexpr const char* _type_key = "relay.RefValue";
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueObj, Object);
};
class RefValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueNode);
TVM_DLL RefValue(ObjectRef val);
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueObj);
};
/*! \brief An ADT constructor value. */
class ConstructorValue;
struct ConstructorValueNode : Object {
struct ConstructorValueObj : Object {
int32_t tag;
tvm::Array<ObjectRef> fields;
......@@ -191,17 +157,17 @@ struct ConstructorValueNode : Object {
v->Visit("constructor", &constructor);
}
TVM_DLL static ConstructorValue make(int32_t tag,
tvm::Array<ObjectRef> fields,
Constructor construtor = {});
static constexpr const char* _type_key = "relay.ConstructorValue";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueObj, Object);
};
class ConstructorValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueNode);
TVM_DLL ConstructorValue(int32_t tag,
tvm::Array<ObjectRef> fields,
Constructor construtor = {});
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj);
};
} // namespace relay
......
......@@ -50,10 +50,9 @@ namespace runtime {
enum TypeIndex {
/*! \brief Root object type. */
kRoot = 0,
kVMTensor = 1,
kVMClosure = 2,
kVMADT = 3,
kRuntimeModule = 4,
kClosure = 1,
kVMADT = 2,
kRuntimeModule = 3,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
......
......@@ -25,36 +25,58 @@
#define TVM_RUNTIME_VM_H_
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace tvm {
namespace runtime {
namespace vm {
/*! \brief An object representing a closure. */
/*!
* \brief An object representing a closure. This object is used by both the
* Relay VM and interpreter.
*/
class ClosureObj : public Object {
public:
/*! \brief The index into the VM function table. */
static constexpr const uint32_t _type_index = TypeIndex::kClosure;
static constexpr const char* _type_key = "Closure";
TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object);
};
/*! \brief reference to closure. */
class Closure : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
};
/*!
* \brief An object representing a vm closure.
*/
class VMClosureObj : public ClosureObj {
public:
/*!
* \brief The index into the function list. The function could be any
* function object that is compatible to the VM runtime.
*/
size_t func_index;
/*! \brief The free variables of the closure. */
std::vector<ObjectRef> free_vars;
static constexpr const uint32_t _type_index = TypeIndex::kVMClosure;
static constexpr const char* _type_key = "vm.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureObj, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(VMClosureObj, ClosureObj);
};
/*! \brief reference to closure. */
class Closure : public ObjectRef {
class VMClosure : public Closure {
public:
Closure(size_t func_index, std::vector<ObjectRef> free_vars);
TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
VMClosure(size_t func_index, std::vector<ObjectRef> free_vars);
TVM_DEFINE_OBJECT_REF_METHODS(VMClosure, Closure, VMClosureObj);
};
/*! \brief Magic number for NDArray list file */
......
......@@ -16,8 +16,10 @@
# under the License.
"""Container data structures used in TVM DSL."""
from __future__ import absolute_import as _abs
from ._ffi.object import Object, register_object
from tvm import ndarray as _nd
from . import _api_internal
from ._ffi.object import Object, register_object, getitem_helper
from ._ffi.function import _init_api
@register_object
class Array(Object):
......@@ -114,3 +116,56 @@ class LoweredFunc(Object):
MixedFunc = 0
HostFunc = 1
DeviceFunc = 2
@register_object("vm.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.
Parameters
----------
tag : int
The tag of ADT.
fields : list[Object] or tuple[Object]
The source tuple.
"""
def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(_ADT, tag, *fields)
@property
def tag(self):
return _GetADTTag(self)
def __getitem__(self, idx):
return getitem_helper(
self, _GetADTFields, len(self), idx)
def __len__(self):
return _GetADTSize(self)
def tuple_object(fields=None):
"""Create a ADT object from source tuple.
Parameters
----------
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : ADT
The created object.
"""
fields = fields if fields else []
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f))
return _Tuple(*fields)
_init_api("tvm.container")
......@@ -38,7 +38,6 @@ from . import param_dict
from . import feature
from .backend import vm
from .backend import profiler_vm
from .backend import vmobj
# Root operators
from .op import Op
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The VM Object FFI namespace."""
from tvm._ffi.function import _init_api
_init_api("_vmobj", __name__)
......@@ -20,6 +20,7 @@ from __future__ import absolute_import
import numpy as np
from tvm import container
from . import _backend
from .. import _make, analysis, transform
from .. import module
......@@ -28,40 +29,6 @@ from ..base import Object, register_relay_node
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
from ..scope_builder import ScopeBuilder
@register_relay_node
class TupleValue(Object):
"""A tuple value produced by the interpreter."""
def __init__(self, *fields):
self.__init_handle_by_constructor__(
_make.TupleValue, fields)
def __getitem__(self, field_no):
return self.fields[field_no]
def __len__(self):
return len(self.fields)
def __str__(self):
body = ','.join(str(f) for f in self.fields)
return '({0})'.format(body)
def __repr__(self):
body = ','.join(repr(f) for f in self.fields)
return '({0})'.format(body)
def __iter__(self):
return iter(self.fields)
@register_relay_node
class Closure(Object):
"""A closure produced by the interpreter."""
@register_relay_node
class RecClosure(Object):
"""A recursive closure produced by the interpreter."""
@register_relay_node
class ConstructorValue(Object):
......@@ -80,8 +47,8 @@ class RefValue(Object):
def _arg_to_ast(mod, arg):
if isinstance(arg, nd.NDArray):
return Constant(arg.copyto(nd.cpu(0)))
elif isinstance(arg, TupleValue):
return Tuple([_arg_to_ast(mod, field) for field in arg.fields])
elif isinstance(arg, container.ADT):
return Tuple([_arg_to_ast(mod, field) for field in arg])
elif isinstance(arg, tuple):
return Tuple([_arg_to_ast(mod, field) for field in arg])
elif isinstance(arg, RefValue):
......
......@@ -23,20 +23,18 @@ Implements a Python interface to compiling and executing on the Relay VM.
import numpy as np
import tvm
from tvm import autotvm
from tvm import autotvm, container
from tvm.object import Object
from tvm.relay import expr as _expr
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm._ffi import base as _base
from . import _vm
from . import vmobj as _obj
from .interpreter import Executor
ADT = _obj.ADT
def _convert(arg, cargs):
if isinstance(arg, _expr.Constant):
cargs.append(arg.data)
elif isinstance(arg, _obj.Object):
elif isinstance(arg, Object):
cargs.append(arg)
elif isinstance(arg, np.ndarray):
nd_arr = tvm.nd.array(arg, ctx=tvm.cpu(0))
......@@ -47,7 +45,7 @@ def _convert(arg, cargs):
field_args = []
for field in arg:
_convert(field, field_args)
cargs.append(_obj.tuple_object(field_args))
cargs.append(container.tuple_object(field_args))
elif isinstance(arg, (_base.numeric_types, bool)):
dtype = "int32" if isinstance(arg, (int, bool)) else "float32"
value = tvm.nd.array(np.array(arg, dtype=dtype), ctx=tvm.cpu(0))
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Runtime Object API."""
from __future__ import absolute_import as _abs
from tvm._ffi.object import Object, register_object, getitem_helper
from tvm import ndarray as _nd
from . import _vmobj
@register_object("vm.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.
Parameters
----------
tag : int
The tag of ADT.
fields : list[Object] or tuple[Object]
The source tuple.
"""
def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or "
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(
_vmobj.ADT, tag, *fields)
@property
def tag(self):
return _vmobj.GetADTTag(self)
def __getitem__(self, idx):
return getitem_helper(
self, _vmobj.GetADTFields, len(self), idx)
def __len__(self):
return _vmobj.GetADTNumberOfFields(self)
def tuple_object(fields):
"""Create a ADT object from source tuple.
Parameters
----------
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : ADT
The created object.
"""
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm "
"NDArray type, but received : {0}".format(type(f))
return _vmobj.Tuple(*fields)
......@@ -17,9 +17,9 @@
"""Common utilities"""
from __future__ import absolute_import as _abs
import logging
import numpy as np
import tvm
import numpy as np
from topi.util import get_const_tuple
from .. import expr as _expr
from .. import module as _module
......
......@@ -17,6 +17,7 @@
#pylint: disable=invalid-name
"""Utilities for testing and benchmarks"""
from __future__ import absolute_import as _abs
import numpy as np
import tvm
import tvm.relay as relay
......@@ -24,7 +25,6 @@ import tvm.relay.op as op
from tvm.relay import transform
from tvm.relay import Function, GlobalVar, ScopeBuilder, Tuple, TupleGetItem, create_executor
from tvm.relay import TensorType, TupleType
import numpy as np
from . import mlp
from . import resnet
......
......@@ -32,18 +32,20 @@ OUTPUT_VAR_NAME = '_py_out'
# import numpy
# import tvm
# from tvm import relay
# from tvm import import container as _container
# from tvm import nd
# from tvm.relay.backend.interpreter import RefValue, TupleValue, ConstructorValue
# from tvm.relay.backend.interpreter import RefValue, ConstructorValue
PROLOGUE = [
ast.Import([alias('numpy', None)]),
ast.Import([alias('tvm', None)]),
ast.ImportFrom('tvm', [alias('relay', None)], 0),
ast.ImportFrom('tvm', [alias('nd', None)], 0),
ast.ImportFrom('tvm', [alias('container', '_container')],
0),
ast.ImportFrom('tvm.relay.backend.interpreter',
[alias('RefValue', None),
alias('TupleValue', None),
alias('ConstructorValue', None)],
0)
0),
]
class PythonConverter(ExprFunctor):
......@@ -253,7 +255,7 @@ class PythonConverter(ExprFunctor):
for i in range(len(arg_type.fields)):
ret += convert_input(
ast.Subscript(
ast.Attribute(py_input, 'fields', Load()),
py_input,
ast.Index(Num(i)), Load()),
arg_type.fields[i])
return ret
......@@ -282,7 +284,8 @@ class PythonConverter(ExprFunctor):
assignments += inner_assignments
extra_args += inner_args
fields.append(inner_output)
return (assignments, extra_args, self.create_call('TupleValue', fields))
fields = [ast.List(fields, Load())]
return (assignments, extra_args, self.create_call('_container.tuple_object', fields))
# create a function to wrap the call of the lowered op and return
# a call to that function
......@@ -444,7 +447,8 @@ class PythonConverter(ExprFunctor):
def visit_tuple(self, tup: Expr):
fields, ret_defs = self.convert_fields(tup.fields)
return (self.create_call('TupleValue', fields), ret_defs)
fields = [ast.List(fields, Load())]
return (self.create_call('_container.tuple_object', fields), ret_defs)
def visit_tuple_getitem(self, tgi: Expr):
......@@ -534,7 +538,7 @@ class PythonConverter(ExprFunctor):
thunk_name, [],
ref_defs + val_defs + [
Assign([ast.Attribute(ref, 'value', Store())], val),
Return(self.create_call('TupleValue', []))
Return(self.create_call('_container.tuple_object', []))
])
return (self.create_call(thunk_name, []), [thunk])
......
......@@ -22,6 +22,7 @@
* \brief An interpreter for the Relay IR.
*/
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/object.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/interpreter.h>
......@@ -36,100 +37,82 @@ namespace relay {
using namespace runtime;
inline const PackedFunc& GetPackedFunc(const std::string& name) {
const PackedFunc* pf = tvm::runtime::Registry::Get(name);
CHECK(pf != nullptr) << "Cannot find function " << name << " in registry";
return *pf;
}
/* Object Implementation */
Closure ClosureNode::make(tvm::Map<Var, ObjectRef> env, Function func) {
ObjectPtr<ClosureNode> n = make_object<ClosureNode>();
InterpreterClosure::InterpreterClosure(tvm::Map<Var, ObjectRef> env,
Function func) {
ObjectPtr<InterpreterClosureObj> n = make_object<InterpreterClosureObj>();
n->env = std::move(env);
n->func = std::move(func);
return Closure(n);
data_ = std::move(n);
}
TVM_REGISTER_GLOBAL("relay._make.Closure")
.set_body_typed(ClosureNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ClosureNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const ClosureNode*>(ref.get());
p->stream << "ClosureNode(" << node->func << ", " << node->env << ")";
});
.set_dispatch<InterpreterClosureObj >([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const InterpreterClosureObj*>(ref.get());
p->stream << "InterpreterClosureNode(" << node->func << ", " << node->env << ")";
});
inline const PackedFunc& GetPackedFunc(const std::string& name) {
const PackedFunc* pf = tvm::runtime::Registry::Get(name);
CHECK(pf != nullptr) << "Cannot find function " << name << " in registry";
return *pf;
}
// TODO(@jroesch): this doesn't support mutual letrec
/* Object Implementation */
RecClosure RecClosureNode::make(Closure clos, Var bind) {
ObjectPtr<RecClosureNode> n = make_object<RecClosureNode>();
RecClosure::RecClosure(InterpreterClosure clos, Var bind) {
ObjectPtr<RecClosureObj> n = make_object<RecClosureObj>();
n->clos = std::move(clos);
n->bind = std::move(bind);
return RecClosure(n);
data_ = std::move(n);
}
TVM_REGISTER_GLOBAL("relay._make.RecClosure")
.set_body_typed(RecClosureNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RecClosureNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const RecClosureNode*>(ref.get());
p->stream << "RecClosureNode(" << node->clos << ")";
.set_dispatch<RecClosureObj>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const RecClosureObj*>(ref.get());
p->stream << "RecClosureObj(" << node->clos << ")";
});
TupleValue TupleValueNode::make(tvm::Array<ObjectRef> value) {
ObjectPtr<TupleValueNode> n = make_object<TupleValueNode>();
n->fields = value;
return TupleValue(n);
}
TVM_REGISTER_GLOBAL("relay._make.TupleValue")
.set_body_typed(TupleValueNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TupleValueNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TupleValueNode*>(ref.get());
p->stream << "TupleValueNode(" << node->fields << ")";
});
RefValue RefValueNode::make(ObjectRef value) {
ObjectPtr<RefValueNode> n = make_object<RefValueNode>();
RefValue::RefValue(ObjectRef value) {
ObjectPtr<RefValueObj> n = make_object<RefValueObj>();
n->value = value;
return RefValue(n);
data_ = std::move(n);
}
TVM_REGISTER_GLOBAL("relay._make.RefValue")
.set_body_typed(RefValueNode::make);
.set_body_typed([](ObjectRef value){
return RefValue(value);
});
TVM_REGISTER_NODE_TYPE(RefValueNode);
TVM_REGISTER_NODE_TYPE(RefValueObj);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RefValueNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const RefValueNode*>(ref.get());
p->stream << "RefValueNode(" << node->value << ")";
.set_dispatch<RefValueObj>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const RefValueObj*>(ref.get());
p->stream << "RefValueObj(" << node->value << ")";
});
ConstructorValue ConstructorValueNode::make(int32_t tag,
ConstructorValue::ConstructorValue(int32_t tag,
tvm::Array<ObjectRef> fields,
Constructor constructor) {
ObjectPtr<ConstructorValueNode> n = make_object<ConstructorValueNode>();
ObjectPtr<ConstructorValueObj> n = make_object<ConstructorValueObj>();
n->tag = tag;
n->fields = fields;
n->constructor = constructor;
return ConstructorValue(n);
data_ = std::move(n);
}
TVM_REGISTER_GLOBAL("relay._make.ConstructorValue")
.set_body_typed(ConstructorValueNode::make);
.set_body_typed([](int32_t tag, tvm::Array<ObjectRef> fields,
Constructor constructor) {
return ConstructorValue(tag, fields, constructor);
});
TVM_REGISTER_NODE_TYPE(ConstructorValueNode);
TVM_REGISTER_NODE_TYPE(ConstructorValueObj);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstructorValueNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const ConstructorValueNode*>(ref.get());
p->stream << "ConstructorValueNode(" << node->tag << ","
.set_dispatch<ConstructorValueObj>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const ConstructorValueObj*>(ref.get());
p->stream << "ConstructorValueObj(" << node->tag << ","
<< node->fields << ")";
});
......@@ -187,7 +170,7 @@ struct Stack {
class InterpreterState;
/*! \brief A container capturing the state of the interpreter. */
class InterpreterStateNode : public Object {
class InterpreterStateObj : public Object {
public:
using Frame = tvm::Map<Var, ObjectRef>;
using Stack = tvm::Array<Frame>;
......@@ -206,16 +189,16 @@ class InterpreterStateNode : public Object {
static InterpreterState make(Expr current_expr, Stack stack);
static constexpr const char* _type_key = "relay.InterpreterState";
TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterStateNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterStateObj, Object);
};
class InterpreterState : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(InterpreterState, ObjectRef, InterpreterStateNode);
TVM_DEFINE_OBJECT_REF_METHODS(InterpreterState, ObjectRef, InterpreterStateObj);
};
InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) {
ObjectPtr<InterpreterStateNode> n = make_object<InterpreterStateNode>();
InterpreterState InterpreterStateObj::make(Expr current_expr, Stack stack) {
ObjectPtr<InterpreterStateObj> n = make_object<InterpreterStateObj>();
n->current_expr = std::move(current_expr);
n->stack = std::move(stack);
return InterpreterState(n);
......@@ -292,7 +275,7 @@ class Interpreter :
values.push_back(field_value);
}
return TupleValueNode::make(values);
return ADT::Tuple(values);
}
ObjectRef MakeClosure(const Function& func, Var letrec_name = Var()) {
......@@ -310,9 +293,9 @@ class Interpreter :
}
// We must use mutation here to build a self referential closure.
auto closure = ClosureNode::make(captured_mod, func);
InterpreterClosure closure(captured_mod, func);
if (letrec_name.defined()) {
return RecClosureNode::make(closure, letrec_name);
return RecClosure(closure, letrec_name);
}
return std::move(closure);
}
......@@ -374,16 +357,15 @@ class Interpreter :
fset_input(arg_counter++, arg, true);
}
} else {
const TupleValueNode* tuple = arg.as<TupleValueNode>();
CHECK(tuple != nullptr);
const ADT adt = Downcast<ADT>(arg);
if (state & kNeedInputData) {
for (size_t i = 0; i < tuple->fields.size(); ++i) {
fset_input(arg_counter++, tuple->fields[i], false);
for (size_t i = 0; i < adt.size(); ++i) {
fset_input(arg_counter++, adt[i], false);
}
}
if (state & kNeedInputShape) {
for (size_t i = 0; i < tuple->fields.size(); ++i) {
fset_input(arg_counter++, tuple->fields[i], true);
for (size_t i = 0; i < adt.size(); ++i) {
fset_input(arg_counter++, adt[i], true);
}
}
}
......@@ -458,14 +440,14 @@ class Interpreter :
}
// Marshal the arguments.
// Handle tuple input/output by flattening them.
// Handle adt input/output by flattening them.
size_t arg_len = 0;
for (size_t i = 0; i < args.size(); ++i) {
if (args[i]->IsInstance<NDArray::ContainerType>()) {
++arg_len;
} else {
const auto* tvalue = args[i].as<TupleValueNode>();
arg_len += tvalue->fields.size();
auto adt = Downcast<ADT>(args[i]);
arg_len += adt.size();
}
}
size_t num_inputs = arg_len;
......@@ -495,10 +477,9 @@ class Interpreter :
if (arg->IsInstance<NDArray::ContainerType>()) {
fset_input(arg_counter++, arg);
} else {
const TupleValueNode* tuple = arg.as<TupleValueNode>();
CHECK(tuple != nullptr);
for (size_t i = 0; i < tuple->fields.size(); ++i) {
fset_input(arg_counter++, tuple->fields[i]);
auto adt = Downcast<ADT>(arg);
for (size_t i = 0; i < adt.size(); ++i) {
fset_input(arg_counter++, adt[i]);
}
}
}
......@@ -541,7 +522,7 @@ class Interpreter :
TVMRetValue rv;
if (const TupleTypeNode* rtype = func->body->checked_type().as<TupleTypeNode>()) {
CHECK(!is_dyn || out_shapes.size() == rtype->fields.size());
Array<ObjectRef> fields;
std::vector<ObjectRef> fields;
for (size_t i = 0; i < rtype->fields.size(); ++i) {
if (is_dyn) {
auto sh = out_shapes[i];
......@@ -552,7 +533,7 @@ class Interpreter :
}
}
packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv);
return TupleValueNode::make(fields);
return ADT::Tuple(fields);
} else {
ObjectRef out_tensor;
if (is_dyn) {
......@@ -569,7 +550,7 @@ class Interpreter :
}
// Invoke the closure
ObjectRef Invoke(const Closure& closure,
ObjectRef Invoke(const InterpreterClosure& closure,
const tvm::Array<ObjectRef>& args,
const Var& bind = Var()) {
// Get a reference to the function inside the closure.
......@@ -594,7 +575,7 @@ class Interpreter :
}
if (bind.defined()) {
locals.Set(bind, RecClosureNode::make(closure, bind));
locals.Set(bind, RecClosure(closure, bind));
}
return WithFrame<ObjectRef>(Frame(locals), [&]() { return Eval(func->body); });
......@@ -616,14 +597,14 @@ class Interpreter :
"fusing and lowering";
}
if (auto con = call->op.as<ConstructorNode>()) {
return ConstructorValueNode::make(con->tag, args, GetRef<Constructor>(con));
return ConstructorValue(con->tag, args, GetRef<Constructor>(con));
}
// Now we just evaluate and expect to find a closure.
ObjectRef fn_val = Eval(call->op);
if (const ClosureNode* closure_node = fn_val.as<ClosureNode>()) {
auto closure = GetRef<Closure>(closure_node);
if (const InterpreterClosureObj* closure_node = fn_val.as<InterpreterClosureObj>()) {
auto closure = GetRef<InterpreterClosure>(closure_node);
return this->Invoke(closure, args);
} else if (const RecClosureNode* closure_node = fn_val.as<RecClosureNode>()) {
} else if (const RecClosureObj* closure_node = fn_val.as<RecClosureObj>()) {
return this->Invoke(closure_node->clos, args, closure_node->bind);
} else {
LOG(FATAL) << "internal error: type error, expected function value in the call "
......@@ -646,12 +627,13 @@ class Interpreter :
ObjectRef VisitExpr_(const TupleGetItemNode* op) final {
ObjectRef val = Eval(op->tuple);
auto product_node = val.as<TupleValueNode>();
CHECK(product_node)
<< "interal error: when evaluating TupleGetItem expected a tuple value";
CHECK_LT(static_cast<size_t>(op->index), product_node->fields.size())
const auto* adt_obj = val.as<ADTObj>();
CHECK(adt_obj)
<< "interal error: when evaluating TupleGetItem expected an ADT value";
auto adt = GetRef<ADT>(adt_obj);
CHECK_LT(static_cast<size_t>(op->index), adt.size())
<< "internal error: index out of bounds";
return product_node->fields[op->index];
return adt[op->index];
}
ObjectRef VisitExpr_(const IfNode* op) final {
......@@ -677,9 +659,9 @@ class Interpreter :
ObjectRef VisitExpr_(const RefWriteNode* op) final {
ObjectRef r = Eval(op->ref);
if (const RefValueNode* rv = r.as<RefValueNode>()) {
if (const RefValueObj* rv = r.as<RefValueObj>()) {
rv->value = Eval(op->value);
return TupleValueNode::make({});
return ADT::Tuple(std::vector<ObjectRef>());
} else {
LOG(FATAL) << "type error, type system should have caught this";
return ObjectRef();
......@@ -687,12 +669,12 @@ class Interpreter :
}
ObjectRef VisitExpr_(const RefCreateNode* op) final {
return RefValueNode::make(Eval(op->value));
return RefValue(Eval(op->value));
}
ObjectRef VisitExpr_(const RefReadNode* op) final {
ObjectRef r = Eval(op->ref);
if (const RefValueNode* rv = r.as<RefValueNode>()) {
if (const RefValueObj* rv = r.as<RefValueObj>()) {
return rv->value;
} else {
LOG(FATAL) << "type error, type system should have caught this";
......@@ -712,7 +694,7 @@ class Interpreter :
}
bool VisitPattern_(const PatternConstructorNode* op, const ObjectRef& v) final {
const ConstructorValueNode* cvn = v.as<ConstructorValueNode>();
const ConstructorValueObj* cvn = v.as<ConstructorValueObj>();
CHECK(cvn) << "need to be a constructor for match";
CHECK_NE(op->constructor->tag, -1);
CHECK_NE(cvn->tag, -1);
......@@ -729,11 +711,10 @@ class Interpreter :
}
bool VisitPattern_(const PatternTupleNode* op, const ObjectRef& v) final {
const TupleValueNode* tvn = v.as<TupleValueNode>();
CHECK(tvn) << "need to be a tuple for match";
CHECK_EQ(op->patterns.size(), tvn->fields.size());
auto adt = Downcast<ADT>(v);
CHECK_EQ(op->patterns.size(), adt.size());
for (size_t i = 0; i < op->patterns.size(); ++i) {
if (!VisitPattern(op->patterns[i], tvn->fields[i])) {
if (!VisitPattern(op->patterns[i], adt[i])) {
return false;
}
}
......@@ -750,12 +731,12 @@ class Interpreter :
}
InterpreterState get_state(Expr e = Expr()) const {
InterpreterStateNode::Stack stack;
InterpreterStateObj::Stack stack;
for (auto fr : this->stack_.frames) {
InterpreterStateNode::Frame frame = fr.locals;
InterpreterStateObj::Frame frame = fr.locals;
stack.push_back(frame);
}
auto state = InterpreterStateNode::make(e, stack);
auto state = InterpreterStateObj::make(e, stack);
return state;
}
......@@ -804,8 +785,5 @@ CreateInterpreter(
TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter")
.set_body_typed(CreateInterpreter);
TVM_REGISTER_NODE_TYPE(ClosureNode);
TVM_REGISTER_NODE_TYPE(TupleValueNode);
} // namespace relay
} // namespace tvm
......@@ -29,6 +29,7 @@
#include <tvm/relay/transform.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/container.h>
#include "pattern_util.h"
namespace tvm {
......@@ -187,10 +188,11 @@ class ConstantFolder : public ExprMutator {
<< "invalid dimension after constant eval";
}
return ConstantNode::make(nd_array);
} else if (const auto* val = value.as<TupleValueNode>()) {
} else if (const auto* val = value.as<runtime::ADTObj>()) {
runtime::ADT adt = GetRef<runtime::ADT>(val);
Array<Expr> fields;
for (ObjectRef field : val->fields) {
fields.push_back(ObjectToExpr(field));
for (size_t i = 0; i < adt.size(); ++i) {
fields.push_back(ObjectToExpr(adt[i]));
}
return TupleNode::make(fields);
} else {
......
......@@ -935,11 +935,12 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
if (v->IsInstance<runtime::NDArray::ContainerType>()) {
auto nd_array = Downcast<runtime::NDArray>(v);
return HasStatic(MkSTensor(nd_array), ll->Push(ConstantNode::make(nd_array)));
} else if (const TupleValueNode* op = v.as<TupleValueNode>()) {
} else if (const runtime::ADTObj* op = v.as<runtime::ADTObj>()) {
std::vector<PStatic> fields;
tvm::Array<Expr> fields_dyn;
for (const ObjectRef& field : op->fields) {
PStatic ps = Reify(field, ll);
auto adt = GetRef<runtime::ADT>(op);
for (size_t i = 0; i < adt.size(); ++i) {
PStatic ps = Reify(adt[i], ll);
fields.push_back(ps);
fields_dyn.push_back(ps->dynamic);
}
......
......@@ -18,38 +18,28 @@
*/
/*!
* \file src/runtime/vm/object.cc
* \brief VM related objects.
* \file src/runtime/container.cc
* \brief Implementations of common plain old data (POD) containers.
*/
#include <tvm/support/logging.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/vm.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/c_runtime_api.h>
#include "../runtime_base.h"
namespace tvm {
namespace runtime {
namespace vm {
Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) {
auto ptr = make_object<ClosureObj>();
ptr->func_index = func_index;
ptr->free_vars = std::move(free_vars);
data_ = std::move(ptr);
}
using namespace vm;
TVM_REGISTER_GLOBAL("_vmobj.GetADTTag")
TVM_REGISTER_GLOBAL("container._GetADTTag")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj);
*rv = static_cast<int64_t>(adt.tag());
});
TVM_REGISTER_GLOBAL("_vmobj.GetADTNumberOfFields")
TVM_REGISTER_GLOBAL("container._GetADTSize")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj);
......@@ -57,7 +47,7 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTNumberOfFields")
});
TVM_REGISTER_GLOBAL("_vmobj.GetADTFields")
TVM_REGISTER_GLOBAL("container._GetADTFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
int idx = args[1];
......@@ -66,7 +56,7 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTFields")
*rv = adt[idx];
});
TVM_REGISTER_GLOBAL("_vmobj.Tuple")
TVM_REGISTER_GLOBAL("container._Tuple")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<ObjectRef> fields;
for (auto i = 0; i < args.size(); ++i) {
......@@ -75,7 +65,7 @@ TVM_REGISTER_GLOBAL("_vmobj.Tuple")
*rv = ADT::Tuple(fields);
});
TVM_REGISTER_GLOBAL("_vmobj.ADT")
TVM_REGISTER_GLOBAL("container._ADT")
.set_body([](TVMArgs args, TVMRetValue* rv) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
......@@ -88,15 +78,6 @@ TVM_REGISTER_GLOBAL("_vmobj.ADT")
TVM_REGISTER_OBJECT_TYPE(ADTObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj);
} // namespace vm
} // namespace runtime
} // namespace tvm
using namespace tvm::runtime;
int TVMGetObjectTag(TVMObjectHandle handle, int* tag) {
API_BEGIN();
int res = static_cast<int>(static_cast<Object*>(handle)->type_index());
*tag = res;
API_END();
}
......@@ -45,6 +45,12 @@ namespace tvm {
namespace runtime {
namespace vm {
VMClosure::VMClosure(size_t func_index, std::vector<ObjectRef> free_vars) {
auto ptr = make_object<VMClosureObj>();
ptr->func_index = func_index;
ptr->free_vars = std::move(free_vars);
data_ = std::move(ptr);
}
inline Storage make_storage(size_t size, size_t alignment, DLDataType dtype_hint, TVMContext ctx) {
// We could put cache in here, from ctx to storage allocator.
......@@ -906,7 +912,7 @@ void VirtualMachine::RunLoop() {
}
case Opcode::InvokeClosure: {
auto object = ReadRegister(instr.closure);
const auto* closure = object.as<ClosureObj>();
const auto* closure = object.as<VMClosureObj>();
std::vector<ObjectRef> args;
for (auto free_var : closure->free_vars) {
......@@ -1008,7 +1014,7 @@ void VirtualMachine::RunLoop() {
for (Index i = 0; i < instr.num_freevar; i++) {
free_vars.push_back(ReadRegister(instr.free_vars[i]));
}
WriteRegister(instr.dst, Closure(instr.func_index, free_vars));
WriteRegister(instr.dst, VMClosure(instr.func_index, free_vars));
pc_++;
goto main_loop;
}
......
......@@ -62,16 +62,11 @@ tf_dtypes = {
def vmobj_to_list(o):
if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vmobj.ADT):
elif isinstance(o, tvm.container.ADT):
result = []
for f in o:
result.extend(vmobj_to_list(f))
return result
elif isinstance(o, tvm.relay.backend.interpreter.TupleValue):
result = []
for f in o.fields:
result.append(vmobj_to_list(f))
return result
elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
if o.constructor.name_hint == 'Cons':
tl = vmobj_to_list(o.fields[1])
......
......@@ -19,10 +19,9 @@ import numpy as np
import tvm
from tvm.contrib import graph_runtime
from tvm import relay
from tvm import relay, container
from tvm.relay import testing
from tvm.relay import vm
from tvm.relay import vmobj as _obj
def benchmark_execution(mod,
......@@ -69,7 +68,7 @@ def benchmark_execution(mod,
ftimer = rly_vm.mod.time_evaluator("invoke", ctx, number=number,
repeat=repeat)
# Measure in millisecond.
prof_res = np.array(ftimer("main", _obj.Tensor(data)).results) * 1000
prof_res = np.array(ftimer("main", data).results) * 1000
print("Mean vm inference time (std dev): %.2f ms (%.2f ms)" %
(np.mean(prof_res), np.std(prof_res)))
......
......@@ -117,7 +117,7 @@ def tree_to_dict(t):
def vmobj_to_list(o, dtype="float32"):
if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vmobj.ADT):
elif isinstance(o, tvm.container.ADT):
if len(o) == 0:
tensor_nil = p.get_var("tensor_nil", dtype=dtype)
if tensor_nil.tag == o.tag:
......
......@@ -18,8 +18,7 @@ import numpy as np
import tvm
import tvm.testing
from tvm import nd
from tvm import relay
from tvm.relay.backend.interpreter import TupleValue
from tvm import relay, container
from tvm.relay.backend.interpreter import RefValue, ConstructorValue
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing, create_executor
......@@ -39,7 +38,8 @@ def check_eval(expr, args, expected_result, mod=None, rtol=1e-07):
def test_tuple_value():
tv = TupleValue(relay.const(1), relay.const(2), relay.const(3))
tv = container.tuple_object([relay.const(1), relay.const(2),
relay.const(3)])
np.testing.assert_allclose(tv[0].data.asnumpy(), 1)
np.testing.assert_allclose(tv[1].data.asnumpy(), 2)
np.testing.assert_allclose(tv[2].data.asnumpy(), 3)
......@@ -178,7 +178,7 @@ def test_function_taking_adt_ref_tuple():
], prelude.cons)
ref_value = RefValue(nd.array(np.random.rand(1, 10).astype('float32')))
tuple_value = TupleValue(*[
tuple_value = container.tuple_object([
nd.array(np.random.rand(1, 10).astype('float32')) for _ in range(10)
])
......@@ -202,8 +202,8 @@ def test_function_taking_adt_ref_tuple():
res_tuple = id_func(tuple_value)
for i in range(10):
tvm.testing.assert_allclose(res_tuple.fields[i].asnumpy(),
tuple_value.fields[i].asnumpy())
tvm.testing.assert_allclose(res_tuple[i].asnumpy(),
tuple_value[i].asnumpy())
def test_tuple_passing():
x = relay.var('x', type_annotation=relay.ty.TupleType([
......@@ -224,7 +224,8 @@ def test_tuple_passing():
out = f((10, 8))
tvm.testing.assert_allclose(out.asnumpy(), np.array(10))
# Second use a tuple value.
value_tuple = TupleValue(nd.array(np.array(11)), nd.array(np.array(12)))
value_tuple = container.tuple_object([nd.array(np.array(11)),
nd.array(np.array(12))])
out = f(value_tuple)
tvm.testing.assert_allclose(out.asnumpy(), np.array(11))
......
......@@ -19,7 +19,8 @@ import tvm
from tvm import relay
from tvm.relay.testing import to_python, run_as_python
from tvm.relay.prelude import Prelude
from tvm.relay.backend.interpreter import TupleValue, RefValue, ConstructorValue
from tvm.container import ADT
from tvm.relay.backend.interpreter import RefValue, ConstructorValue
# helper: uses a dummy let binding to sequence a list
# of expressions: expr1; expr2; expr3, etc.
......@@ -45,10 +46,10 @@ def assert_tensor_value(candidate, val):
assert np.array_equal(candidate.asnumpy(), np.array(val))
# assert that the candidate is a TupleValue with the indicate number of fields
def assert_tuple_value(candidate, fields):
assert isinstance(candidate, TupleValue)
assert len(candidate.fields) == fields
# assert that the candidate is an ADT with the indicated number of fields
def assert_adt_len(candidate, fields):
assert isinstance(candidate, ADT)
assert len(candidate) == fields
# assert that the candidate is a ConstructorValue with the approrpaite constructor
......@@ -62,7 +63,7 @@ def assert_constructor_value(candidate, constructor, fields):
def test_create_empty_tuple():
empty = relay.Tuple([])
tup_val = run_as_python(empty)
assert_tuple_value(tup_val, 0)
assert_adt_len(tup_val, 0)
def test_create_scalar():
......@@ -87,12 +88,12 @@ def test_create_nested_tuple():
])
])
tup_val = run_as_python(relay_tup)
assert_tuple_value(tup_val, 3)
assert_adt_len(tup_val, 3)
for i in range(2):
assert_tensor_value(tup_val.fields[i], i + 1)
assert_tuple_value(tup_val.fields[2], 2)
assert_tensor_value(tup_val[i], i + 1)
assert_adt_len(tup_val[2], 2)
for i in range(2):
assert_tensor_value(tup_val.fields[2].fields[i], i + 3)
assert_tensor_value(tup_val[2][i], i + 3)
def test_tuple_get_item():
......@@ -118,23 +119,23 @@ def test_create_let():
v = relay.Var('v')
let = relay.Let(v, relay.Tuple([]), relay.Tuple([v, v]))
tup_val = run_as_python(let)
assert_tuple_value(tup_val, 2)
assert_tuple_value(tup_val.fields[0], 0)
assert_tuple_value(tup_val.fields[1], 0)
assert_adt_len(tup_val, 2)
assert_adt_len(tup_val[0], 0)
assert_adt_len(tup_val[1], 0)
def test_create_ref():
relay_ref = relay.RefCreate(relay.Tuple([]))
ref_val = run_as_python(relay_ref)
assert isinstance(ref_val, RefValue)
assert_tuple_value(ref_val.value, 0)
assert_adt_len(ref_val.value, 0)
def test_ref_read():
v = relay.Var('v')
assign = relay.Let(v, relay.RefCreate(relay.Tuple([])), relay.RefRead(v))
read_val = run_as_python(assign)
assert_tuple_value(read_val, 0)
assert_adt_len(read_val, 0)
def test_ref_write():
......@@ -143,7 +144,7 @@ def test_ref_write():
initial_write = relay.Let(v, relay.RefCreate(relay.Tuple([relay.const(1)])),
relay.RefWrite(v, relay.Tuple([relay.const(2)])))
write_val = run_as_python(initial_write)
assert_tuple_value(write_val, 0)
assert_adt_len(write_val, 0)
# now ensure that the value, once written, can be read back
# (we read the value before and after mutation)
......@@ -155,11 +156,11 @@ def test_ref_write():
seq(relay.RefWrite(v, relay.Tuple([relay.const(2)])),
relay.Tuple([relay.RefRead(w), relay.RefRead(v)]))))
read_val = run_as_python(read_after_write)
assert_tuple_value(read_val, 2)
assert_tuple_value(read_val.fields[0], 1)
assert_tuple_value(read_val.fields[1], 1)
assert_tensor_value(read_val.fields[0].fields[0], 1)
assert_tensor_value(read_val.fields[1].fields[0], 2)
assert_adt_len(read_val, 2)
assert_adt_len(read_val[0], 1)
assert_adt_len(read_val[1], 1)
assert_tensor_value(read_val[0][0], 1)
assert_tensor_value(read_val[1][0], 2)
def test_if():
......@@ -191,7 +192,7 @@ def test_local_function():
call2 = relay.Let(f, ident, f(relay.const(2)))
call_val1 = run_as_python(call1)
assert_tuple_value(call_val1, 0)
assert_adt_len(call_val1, 0)
call_val2 = run_as_python(call2)
assert_tensor_value(call_val2, 2)
......@@ -211,9 +212,9 @@ def test_global_function():
assert_tensor_value(call_val1, 1)
call_val2 = run_as_python(call2, mod)
assert_tuple_value(call_val2, 2)
assert_tensor_value(call_val2.fields[0], 2)
assert_tensor_value(call_val2.fields[1], 2)
assert_adt_len(call_val2, 2)
assert_tensor_value(call_val2[0], 2)
assert_tensor_value(call_val2[1], 2)
def test_constructor():
......@@ -230,7 +231,7 @@ def test_constructor():
box_val_tup = run_as_python(init_box_tup, mod)
assert_constructor_value(box_val_tup, box_ctor, 1)
assert_tuple_value(box_val_tup.fields[0], 0)
assert_adt_len(box_val_tup.fields[0], 0)
def test_match_wildcard():
......@@ -372,7 +373,7 @@ def test_global_recursion():
call2 = copy_def(p.cons(relay.Tuple([]), p.nil()))
val2 = run_as_python(call2, mod)
assert_constructor_value(val2, p.cons, 2)
assert_tuple_value(val2.fields[0], 0)
assert_adt_len(val2.fields[0], 0)
assert_constructor_value(val2.fields[1], p.nil, 0)
......@@ -437,10 +438,10 @@ def test_arbitrary_let_nesting():
])
tup_val = run_as_python(expr, mod)
assert_tuple_value(tup_val, 3)
assert_tensor_value(tup_val.fields[0], 2)
assert_tensor_value(tup_val.fields[1], 3)
assert_tensor_value(tup_val.fields[2], 4)
assert_adt_len(tup_val, 3)
assert_tensor_value(tup_val[0], 2)
assert_tensor_value(tup_val[1], 3)
assert_tensor_value(tup_val[2], 4)
def test_ref_execution_order():
......@@ -475,12 +476,12 @@ def test_ref_execution_order():
])))
tup_val = run_as_python(expr)
assert_tuple_value(tup_val, 5)
assert_tensor_value(tup_val.fields[0], 1)
assert_tensor_value(tup_val.fields[1], 2)
assert_tensor_value(tup_val.fields[2], 3)
assert_tensor_value(tup_val.fields[3], 4)
assert_tensor_value(tup_val.fields[4], 5)
assert_adt_len(tup_val, 5)
assert_tensor_value(tup_val[0], 1)
assert_tensor_value(tup_val[1], 2)
assert_tensor_value(tup_val[2], 3)
assert_tensor_value(tup_val[3], 4)
assert_tensor_value(tup_val[4], 5)
def test_op_add():
......@@ -501,6 +502,7 @@ def test_op_stack():
args.append(relay.const(data))
call = relay.stack(relay.Tuple(args), axis)
call_val = run_as_python(call)
type(call_val)
assert_tensor_value(call_val, ref_res)
verify_stack([(2,), (2,), (2,)], -1)
......@@ -517,9 +519,9 @@ def test_split():
ref_res = np.split(x, indices_or_sections, axis=axis)
call = relay.split(relay.const(x), indices_or_sections, axis=axis)
call_val = run_as_python(call)
assert_tuple_value(call_val, len(ref_res))
assert_adt_len(call_val, len(ref_res))
for i in range(len(ref_res)):
assert_tensor_value(call_val.fields[i], ref_res[i])
assert_tensor_value(call_val[i], ref_res[i])
verify_split((2, 3), 2)
verify_split((5, 3), [3])
......
......@@ -58,7 +58,7 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
def vmobj_to_list(o):
if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vm.ADT):
elif isinstance(o, tvm.container.ADT):
result = []
for f in o:
result.extend(vmobj_to_list(f))
......
......@@ -17,18 +17,44 @@
import numpy as np
import tvm
from tvm.relay import vm
from tvm import nd, relay
from tvm import container as _container
def test_adt():
arr = tvm.nd.array([1,2,3])
y = vm.ADT(0, [arr, arr])
def test_adt_constructor():
arr = nd.array([1, 2, 3])
fields = [arr, arr]
y = _container.ADT(0, [arr, arr])
assert len(y) == 2
assert isinstance(y, vm.ADT)
assert isinstance(y, _container.ADT)
y[0:1][-1] == arr
assert y.tag == 0
assert isinstance(arr, tvm.nd.NDArray)
assert isinstance(arr, nd.NDArray)
def test_tuple_object():
x = relay.var(
'x',
type_annotation=relay.ty.TupleType([
relay.ty.TensorType((), 'int32'),
relay.ty.TensorType((), 'int32')
]))
fn = relay.Function([x], relay.expr.TupleGetItem(x, 0))
mod = relay.Module.from_expr(fn)
exe = relay.create_executor(
kind="vm", mod=mod, ctx=nd.cpu(), target="llvm")
f = exe.evaluate()
value_tuple = _container.tuple_object(
[nd.array(np.array(11)),
nd.array(np.array(12))])
# pass an ADT object to evaluate
out = f(value_tuple)
tvm.testing.assert_allclose(out.asnumpy(), np.array(11))
if __name__ == "__main__":
test_adt()
test_adt_constructor()
test_tuple_object()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment