Commit acbf8851 by Zhi Committed by Haichen Shen

[runtime][refactor] Unify vm and interpreter objects (#4693)

* unify vm and interpreter objects

* move closure back vm

* adt/closure back to vm.adt/vm.closure

* closure base
parent 2630ffcb
...@@ -103,11 +103,9 @@ def __lldb_init_module(debugger, _): ...@@ -103,11 +103,9 @@ def __lldb_init_module(debugger, _):
"tvm::relay::Span", "tvm::relay::Span",
"tvm::relay::TempExpr", "tvm::relay::TempExpr",
"tvm::relay::TensorType", "tvm::relay::TensorType",
"tvm::relay::TensorValue",
"tvm::relay::Tuple", "tvm::relay::Tuple",
"tvm::relay::TupleGetItem", "tvm::relay::TupleGetItem",
"tvm::relay::TupleType", "tvm::relay::TupleType",
"tvm::relay::TupleValue",
"tvm::relay::Type", "tvm::relay::Type",
"tvm::relay::TypeCall", "tvm::relay::TypeCall",
"tvm::relay::TypeConstraint", "tvm::relay::TypeConstraint",
......
...@@ -25,8 +25,8 @@ ...@@ -25,8 +25,8 @@
* Given a Relay module, and a Relay expression it produces a value. * Given a Relay module, and a Relay expression it produces a value.
* *
* The interpreter's values are a naive representation of the values that * The interpreter's values are a naive representation of the values that
* can be produced by a Relay program and are exposed via tvm::Node's * can be produced by a Relay program and are exposed via TVM's object
* system to Python for introspection and debugging. * protocol to Python for introspection and debugging.
* *
* The interpreter's intent is to serve as a reference semantics for the Relay IR, * The interpreter's intent is to serve as a reference semantics for the Relay IR,
* as well as for debugging and testing. * as well as for debugging and testing.
...@@ -38,6 +38,8 @@ ...@@ -38,6 +38,8 @@
#include <tvm/ir/module.h> #include <tvm/ir/module.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/runtime/object.h> #include <tvm/runtime/object.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/vm.h>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -64,11 +66,8 @@ namespace relay { ...@@ -64,11 +66,8 @@ namespace relay {
runtime::TypedPackedFunc<ObjectRef(Expr)> runtime::TypedPackedFunc<ObjectRef(Expr)>
CreateInterpreter(IRModule mod, DLContext context, Target target); CreateInterpreter(IRModule mod, DLContext context, Target target);
/*! \brief A Relay closure, i.e a scope and a function. */ /*! \brief The container type of Closures used by the interpreter. */
class Closure; class InterpreterClosureObj : public runtime::vm::ClosureObj {
/*! \brief The container type of Closures. */
class ClosureNode : public Object {
public: public:
/*! \brief The set of free variables in the closure. /*! \brief The set of free variables in the closure.
* *
...@@ -82,102 +81,69 @@ class ClosureNode : public Object { ...@@ -82,102 +81,69 @@ class ClosureNode : public Object {
*/ */
Function func; Function func;
ClosureNode() {} InterpreterClosureObj() {}
void VisitAttrs(tvm::AttrVisitor* v) { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("env", &env); v->Visit("env", &env);
v->Visit("func", &func); v->Visit("func", &func);
} }
TVM_DLL static Closure make(tvm::Map<Var, ObjectRef> env, Function func); static constexpr const char* _type_key = "interpreter.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterClosureObj, runtime::vm::ClosureObj);
static constexpr const char* _type_key = "relay.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, Object);
}; };
class Closure : public ObjectRef { class InterpreterClosure : public runtime::vm::Closure {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureNode); TVM_DLL InterpreterClosure(tvm::Map<Var, ObjectRef> env, Function func);
TVM_DEFINE_OBJECT_REF_METHODS(InterpreterClosure, runtime::vm::Closure,
InterpreterClosureObj);
}; };
/*! \brief A Relay Recursive Closure. A closure that has a name. */
class RecClosure;
/*! \brief The container type of RecClosure. */ /*! \brief The container type of RecClosure. */
class RecClosureNode : public Object { class RecClosureObj : public Object {
public: public:
/*! \brief The closure. */ /*! \brief The closure. */
Closure clos; InterpreterClosure clos;
/*! \brief variable the closure bind to. */ /*! \brief variable the closure bind to. */
Var bind; Var bind;
RecClosureNode() {} RecClosureObj() {}
void VisitAttrs(tvm::AttrVisitor* v) { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("clos", &clos); v->Visit("clos", &clos);
v->Visit("bind", &bind); v->Visit("bind", &bind);
} }
TVM_DLL static RecClosure make(Closure clos, Var bind); static constexpr const char* _type_key = "interpreter.RecClosure";
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureObj, Object);
static constexpr const char* _type_key = "relay.RecClosure";
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, Object);
}; };
class RecClosure : public ObjectRef { class RecClosure : public ObjectRef {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureNode); TVM_DLL RecClosure(InterpreterClosure clos, Var bind);
}; TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureObj);
/*! \brief A tuple value. */
class TupleValue;
/*! \brief Tuple (x, ... y). */
struct TupleValueNode : Object {
tvm::Array<ObjectRef> fields;
TupleValueNode() {}
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); }
TVM_DLL static TupleValue make(tvm::Array<ObjectRef> value);
static constexpr const char* _type_key = "relay.TupleValue";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, Object);
};
class TupleValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, ObjectRef, TupleValueNode);
}; };
/*! \brief A reference value. */ struct RefValueObj : Object {
class RefValue;
struct RefValueNode : Object {
mutable ObjectRef value; mutable ObjectRef value;
RefValueNode() {} RefValueObj() {}
void VisitAttrs(tvm::AttrVisitor* v) { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value); v->Visit("value", &value);
} }
TVM_DLL static RefValue make(ObjectRef val);
static constexpr const char* _type_key = "relay.RefValue"; static constexpr const char* _type_key = "relay.RefValue";
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(RefValueObj, Object);
}; };
class RefValue : public ObjectRef { class RefValue : public ObjectRef {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueNode); TVM_DLL RefValue(ObjectRef val);
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueObj);
}; };
/*! \brief An ADT constructor value. */ struct ConstructorValueObj : Object {
class ConstructorValue;
struct ConstructorValueNode : Object {
int32_t tag; int32_t tag;
tvm::Array<ObjectRef> fields; tvm::Array<ObjectRef> fields;
...@@ -191,17 +157,17 @@ struct ConstructorValueNode : Object { ...@@ -191,17 +157,17 @@ struct ConstructorValueNode : Object {
v->Visit("constructor", &constructor); v->Visit("constructor", &constructor);
} }
TVM_DLL static ConstructorValue make(int32_t tag,
tvm::Array<ObjectRef> fields,
Constructor construtor = {});
static constexpr const char* _type_key = "relay.ConstructorValue"; static constexpr const char* _type_key = "relay.ConstructorValue";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueObj, Object);
}; };
class ConstructorValue : public ObjectRef { class ConstructorValue : public ObjectRef {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueNode); TVM_DLL ConstructorValue(int32_t tag,
tvm::Array<ObjectRef> fields,
Constructor construtor = {});
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj);
}; };
} // namespace relay } // namespace relay
......
...@@ -50,10 +50,9 @@ namespace runtime { ...@@ -50,10 +50,9 @@ namespace runtime {
enum TypeIndex { enum TypeIndex {
/*! \brief Root object type. */ /*! \brief Root object type. */
kRoot = 0, kRoot = 0,
kVMTensor = 1, kClosure = 1,
kVMClosure = 2, kVMADT = 2,
kVMADT = 3, kRuntimeModule = 3,
kRuntimeModule = 4,
kStaticIndexEnd, kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */ /*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd kDynamic = kStaticIndexEnd
......
...@@ -25,36 +25,58 @@ ...@@ -25,36 +25,58 @@
#define TVM_RUNTIME_VM_H_ #define TVM_RUNTIME_VM_H_
#include <tvm/runtime/object.h> #include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <utility>
#include <vector> #include <vector>
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
namespace vm { namespace vm {
/*! \brief An object representing a closure. */ /*!
* \brief An object representing a closure. This object is used by both the
* Relay VM and interpreter.
*/
class ClosureObj : public Object { class ClosureObj : public Object {
public: public:
/*! \brief The index into the VM function table. */ static constexpr const uint32_t _type_index = TypeIndex::kClosure;
static constexpr const char* _type_key = "Closure";
TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object);
};
/*! \brief reference to closure. */
class Closure : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
};
/*!
* \brief An object representing a vm closure.
*/
class VMClosureObj : public ClosureObj {
public:
/*!
* \brief The index into the function list. The function could be any
* function object that is compatible to the VM runtime.
*/
size_t func_index; size_t func_index;
/*! \brief The free variables of the closure. */ /*! \brief The free variables of the closure. */
std::vector<ObjectRef> free_vars; std::vector<ObjectRef> free_vars;
static constexpr const uint32_t _type_index = TypeIndex::kVMClosure;
static constexpr const char* _type_key = "vm.Closure"; static constexpr const char* _type_key = "vm.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureObj, Object); TVM_DECLARE_FINAL_OBJECT_INFO(VMClosureObj, ClosureObj);
}; };
/*! \brief reference to closure. */ /*! \brief reference to closure. */
class Closure : public ObjectRef { class VMClosure : public Closure {
public: public:
Closure(size_t func_index, std::vector<ObjectRef> free_vars); VMClosure(size_t func_index, std::vector<ObjectRef> free_vars);
TVM_DEFINE_OBJECT_REF_METHODS(VMClosure, Closure, VMClosureObj);
TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
}; };
/*! \brief Magic number for NDArray list file */ /*! \brief Magic number for NDArray list file */
......
...@@ -16,8 +16,10 @@ ...@@ -16,8 +16,10 @@
# under the License. # under the License.
"""Container data structures used in TVM DSL.""" """Container data structures used in TVM DSL."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ffi.object import Object, register_object from tvm import ndarray as _nd
from . import _api_internal from . import _api_internal
from ._ffi.object import Object, register_object, getitem_helper
from ._ffi.function import _init_api
@register_object @register_object
class Array(Object): class Array(Object):
...@@ -114,3 +116,56 @@ class LoweredFunc(Object): ...@@ -114,3 +116,56 @@ class LoweredFunc(Object):
MixedFunc = 0 MixedFunc = 0
HostFunc = 1 HostFunc = 1
DeviceFunc = 2 DeviceFunc = 2
@register_object("vm.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.
Parameters
----------
tag : int
The tag of ADT.
fields : list[Object] or tuple[Object]
The source tuple.
"""
def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(_ADT, tag, *fields)
@property
def tag(self):
return _GetADTTag(self)
def __getitem__(self, idx):
return getitem_helper(
self, _GetADTFields, len(self), idx)
def __len__(self):
return _GetADTSize(self)
def tuple_object(fields=None):
"""Create a ADT object from source tuple.
Parameters
----------
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : ADT
The created object.
"""
fields = fields if fields else []
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f))
return _Tuple(*fields)
_init_api("tvm.container")
...@@ -38,7 +38,6 @@ from . import param_dict ...@@ -38,7 +38,6 @@ from . import param_dict
from . import feature from . import feature
from .backend import vm from .backend import vm
from .backend import profiler_vm from .backend import profiler_vm
from .backend import vmobj
# Root operators # Root operators
from .op import Op from .op import Op
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The VM Object FFI namespace."""
from tvm._ffi.function import _init_api
_init_api("_vmobj", __name__)
...@@ -20,6 +20,7 @@ from __future__ import absolute_import ...@@ -20,6 +20,7 @@ from __future__ import absolute_import
import numpy as np import numpy as np
from tvm import container
from . import _backend from . import _backend
from .. import _make, analysis, transform from .. import _make, analysis, transform
from .. import module from .. import module
...@@ -28,40 +29,6 @@ from ..base import Object, register_relay_node ...@@ -28,40 +29,6 @@ from ..base import Object, register_relay_node
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
from ..scope_builder import ScopeBuilder from ..scope_builder import ScopeBuilder
@register_relay_node
class TupleValue(Object):
"""A tuple value produced by the interpreter."""
def __init__(self, *fields):
self.__init_handle_by_constructor__(
_make.TupleValue, fields)
def __getitem__(self, field_no):
return self.fields[field_no]
def __len__(self):
return len(self.fields)
def __str__(self):
body = ','.join(str(f) for f in self.fields)
return '({0})'.format(body)
def __repr__(self):
body = ','.join(repr(f) for f in self.fields)
return '({0})'.format(body)
def __iter__(self):
return iter(self.fields)
@register_relay_node
class Closure(Object):
"""A closure produced by the interpreter."""
@register_relay_node
class RecClosure(Object):
"""A recursive closure produced by the interpreter."""
@register_relay_node @register_relay_node
class ConstructorValue(Object): class ConstructorValue(Object):
...@@ -80,8 +47,8 @@ class RefValue(Object): ...@@ -80,8 +47,8 @@ class RefValue(Object):
def _arg_to_ast(mod, arg): def _arg_to_ast(mod, arg):
if isinstance(arg, nd.NDArray): if isinstance(arg, nd.NDArray):
return Constant(arg.copyto(nd.cpu(0))) return Constant(arg.copyto(nd.cpu(0)))
elif isinstance(arg, TupleValue): elif isinstance(arg, container.ADT):
return Tuple([_arg_to_ast(mod, field) for field in arg.fields]) return Tuple([_arg_to_ast(mod, field) for field in arg])
elif isinstance(arg, tuple): elif isinstance(arg, tuple):
return Tuple([_arg_to_ast(mod, field) for field in arg]) return Tuple([_arg_to_ast(mod, field) for field in arg])
elif isinstance(arg, RefValue): elif isinstance(arg, RefValue):
......
...@@ -23,20 +23,18 @@ Implements a Python interface to compiling and executing on the Relay VM. ...@@ -23,20 +23,18 @@ Implements a Python interface to compiling and executing on the Relay VM.
import numpy as np import numpy as np
import tvm import tvm
from tvm import autotvm from tvm import autotvm, container
from tvm.object import Object
from tvm.relay import expr as _expr from tvm.relay import expr as _expr
from tvm._ffi.runtime_ctypes import TVMByteArray from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm._ffi import base as _base from tvm._ffi import base as _base
from . import _vm from . import _vm
from . import vmobj as _obj
from .interpreter import Executor from .interpreter import Executor
ADT = _obj.ADT
def _convert(arg, cargs): def _convert(arg, cargs):
if isinstance(arg, _expr.Constant): if isinstance(arg, _expr.Constant):
cargs.append(arg.data) cargs.append(arg.data)
elif isinstance(arg, _obj.Object): elif isinstance(arg, Object):
cargs.append(arg) cargs.append(arg)
elif isinstance(arg, np.ndarray): elif isinstance(arg, np.ndarray):
nd_arr = tvm.nd.array(arg, ctx=tvm.cpu(0)) nd_arr = tvm.nd.array(arg, ctx=tvm.cpu(0))
...@@ -47,7 +45,7 @@ def _convert(arg, cargs): ...@@ -47,7 +45,7 @@ def _convert(arg, cargs):
field_args = [] field_args = []
for field in arg: for field in arg:
_convert(field, field_args) _convert(field, field_args)
cargs.append(_obj.tuple_object(field_args)) cargs.append(container.tuple_object(field_args))
elif isinstance(arg, (_base.numeric_types, bool)): elif isinstance(arg, (_base.numeric_types, bool)):
dtype = "int32" if isinstance(arg, (int, bool)) else "float32" dtype = "int32" if isinstance(arg, (int, bool)) else "float32"
value = tvm.nd.array(np.array(arg, dtype=dtype), ctx=tvm.cpu(0)) value = tvm.nd.array(np.array(arg, dtype=dtype), ctx=tvm.cpu(0))
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Runtime Object API."""
from __future__ import absolute_import as _abs
from tvm._ffi.object import Object, register_object, getitem_helper
from tvm import ndarray as _nd
from . import _vmobj
@register_object("vm.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.
Parameters
----------
tag : int
The tag of ADT.
fields : list[Object] or tuple[Object]
The source tuple.
"""
def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or "
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(
_vmobj.ADT, tag, *fields)
@property
def tag(self):
return _vmobj.GetADTTag(self)
def __getitem__(self, idx):
return getitem_helper(
self, _vmobj.GetADTFields, len(self), idx)
def __len__(self):
return _vmobj.GetADTNumberOfFields(self)
def tuple_object(fields):
"""Create a ADT object from source tuple.
Parameters
----------
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : ADT
The created object.
"""
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm "
"NDArray type, but received : {0}".format(type(f))
return _vmobj.Tuple(*fields)
...@@ -17,9 +17,9 @@ ...@@ -17,9 +17,9 @@
"""Common utilities""" """Common utilities"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import logging import logging
import numpy as np
import tvm import tvm
import numpy as np
from topi.util import get_const_tuple from topi.util import get_const_tuple
from .. import expr as _expr from .. import expr as _expr
from .. import module as _module from .. import module as _module
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#pylint: disable=invalid-name #pylint: disable=invalid-name
"""Utilities for testing and benchmarks""" """Utilities for testing and benchmarks"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import numpy as np
import tvm import tvm
import tvm.relay as relay import tvm.relay as relay
...@@ -24,7 +25,6 @@ import tvm.relay.op as op ...@@ -24,7 +25,6 @@ import tvm.relay.op as op
from tvm.relay import transform from tvm.relay import transform
from tvm.relay import Function, GlobalVar, ScopeBuilder, Tuple, TupleGetItem, create_executor from tvm.relay import Function, GlobalVar, ScopeBuilder, Tuple, TupleGetItem, create_executor
from tvm.relay import TensorType, TupleType from tvm.relay import TensorType, TupleType
import numpy as np
from . import mlp from . import mlp
from . import resnet from . import resnet
......
...@@ -32,18 +32,20 @@ OUTPUT_VAR_NAME = '_py_out' ...@@ -32,18 +32,20 @@ OUTPUT_VAR_NAME = '_py_out'
# import numpy # import numpy
# import tvm # import tvm
# from tvm import relay # from tvm import relay
# from tvm import import container as _container
# from tvm import nd # from tvm import nd
# from tvm.relay.backend.interpreter import RefValue, TupleValue, ConstructorValue # from tvm.relay.backend.interpreter import RefValue, ConstructorValue
PROLOGUE = [ PROLOGUE = [
ast.Import([alias('numpy', None)]), ast.Import([alias('numpy', None)]),
ast.Import([alias('tvm', None)]), ast.Import([alias('tvm', None)]),
ast.ImportFrom('tvm', [alias('relay', None)], 0), ast.ImportFrom('tvm', [alias('relay', None)], 0),
ast.ImportFrom('tvm', [alias('nd', None)], 0), ast.ImportFrom('tvm', [alias('nd', None)], 0),
ast.ImportFrom('tvm', [alias('container', '_container')],
0),
ast.ImportFrom('tvm.relay.backend.interpreter', ast.ImportFrom('tvm.relay.backend.interpreter',
[alias('RefValue', None), [alias('RefValue', None),
alias('TupleValue', None),
alias('ConstructorValue', None)], alias('ConstructorValue', None)],
0) 0),
] ]
class PythonConverter(ExprFunctor): class PythonConverter(ExprFunctor):
...@@ -253,7 +255,7 @@ class PythonConverter(ExprFunctor): ...@@ -253,7 +255,7 @@ class PythonConverter(ExprFunctor):
for i in range(len(arg_type.fields)): for i in range(len(arg_type.fields)):
ret += convert_input( ret += convert_input(
ast.Subscript( ast.Subscript(
ast.Attribute(py_input, 'fields', Load()), py_input,
ast.Index(Num(i)), Load()), ast.Index(Num(i)), Load()),
arg_type.fields[i]) arg_type.fields[i])
return ret return ret
...@@ -282,7 +284,8 @@ class PythonConverter(ExprFunctor): ...@@ -282,7 +284,8 @@ class PythonConverter(ExprFunctor):
assignments += inner_assignments assignments += inner_assignments
extra_args += inner_args extra_args += inner_args
fields.append(inner_output) fields.append(inner_output)
return (assignments, extra_args, self.create_call('TupleValue', fields)) fields = [ast.List(fields, Load())]
return (assignments, extra_args, self.create_call('_container.tuple_object', fields))
# create a function to wrap the call of the lowered op and return # create a function to wrap the call of the lowered op and return
# a call to that function # a call to that function
...@@ -444,7 +447,8 @@ class PythonConverter(ExprFunctor): ...@@ -444,7 +447,8 @@ class PythonConverter(ExprFunctor):
def visit_tuple(self, tup: Expr): def visit_tuple(self, tup: Expr):
fields, ret_defs = self.convert_fields(tup.fields) fields, ret_defs = self.convert_fields(tup.fields)
return (self.create_call('TupleValue', fields), ret_defs) fields = [ast.List(fields, Load())]
return (self.create_call('_container.tuple_object', fields), ret_defs)
def visit_tuple_getitem(self, tgi: Expr): def visit_tuple_getitem(self, tgi: Expr):
...@@ -534,7 +538,7 @@ class PythonConverter(ExprFunctor): ...@@ -534,7 +538,7 @@ class PythonConverter(ExprFunctor):
thunk_name, [], thunk_name, [],
ref_defs + val_defs + [ ref_defs + val_defs + [
Assign([ast.Attribute(ref, 'value', Store())], val), Assign([ast.Attribute(ref, 'value', Store())], val),
Return(self.create_call('TupleValue', [])) Return(self.create_call('_container.tuple_object', []))
]) ])
return (self.create_call(thunk_name, []), [thunk]) return (self.create_call(thunk_name, []), [thunk])
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <tvm/runtime/object.h> #include <tvm/runtime/object.h>
#include <tvm/runtime/ndarray.h> #include <tvm/runtime/ndarray.h>
#include <tvm/runtime/container.h>
#include "pattern_util.h" #include "pattern_util.h"
namespace tvm { namespace tvm {
...@@ -187,10 +188,11 @@ class ConstantFolder : public ExprMutator { ...@@ -187,10 +188,11 @@ class ConstantFolder : public ExprMutator {
<< "invalid dimension after constant eval"; << "invalid dimension after constant eval";
} }
return ConstantNode::make(nd_array); return ConstantNode::make(nd_array);
} else if (const auto* val = value.as<TupleValueNode>()) { } else if (const auto* val = value.as<runtime::ADTObj>()) {
runtime::ADT adt = GetRef<runtime::ADT>(val);
Array<Expr> fields; Array<Expr> fields;
for (ObjectRef field : val->fields) { for (size_t i = 0; i < adt.size(); ++i) {
fields.push_back(ObjectToExpr(field)); fields.push_back(ObjectToExpr(adt[i]));
} }
return TupleNode::make(fields); return TupleNode::make(fields);
} else { } else {
......
...@@ -935,11 +935,12 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -935,11 +935,12 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
if (v->IsInstance<runtime::NDArray::ContainerType>()) { if (v->IsInstance<runtime::NDArray::ContainerType>()) {
auto nd_array = Downcast<runtime::NDArray>(v); auto nd_array = Downcast<runtime::NDArray>(v);
return HasStatic(MkSTensor(nd_array), ll->Push(ConstantNode::make(nd_array))); return HasStatic(MkSTensor(nd_array), ll->Push(ConstantNode::make(nd_array)));
} else if (const TupleValueNode* op = v.as<TupleValueNode>()) { } else if (const runtime::ADTObj* op = v.as<runtime::ADTObj>()) {
std::vector<PStatic> fields; std::vector<PStatic> fields;
tvm::Array<Expr> fields_dyn; tvm::Array<Expr> fields_dyn;
for (const ObjectRef& field : op->fields) { auto adt = GetRef<runtime::ADT>(op);
PStatic ps = Reify(field, ll); for (size_t i = 0; i < adt.size(); ++i) {
PStatic ps = Reify(adt[i], ll);
fields.push_back(ps); fields.push_back(ps);
fields_dyn.push_back(ps->dynamic); fields_dyn.push_back(ps->dynamic);
} }
......
...@@ -18,38 +18,28 @@ ...@@ -18,38 +18,28 @@
*/ */
/*! /*!
* \file src/runtime/vm/object.cc * \file src/runtime/container.cc
* \brief VM related objects. * \brief Implementations of common plain old data (POD) containers.
*/ */
#include <tvm/support/logging.h>
#include <tvm/runtime/container.h> #include <tvm/runtime/container.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h> #include <tvm/runtime/object.h>
#include <tvm/runtime/vm.h> #include <tvm/runtime/vm.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/c_runtime_api.h>
#include "../runtime_base.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
namespace vm {
Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) {
auto ptr = make_object<ClosureObj>();
ptr->func_index = func_index;
ptr->free_vars = std::move(free_vars);
data_ = std::move(ptr);
}
using namespace vm;
TVM_REGISTER_GLOBAL("_vmobj.GetADTTag") TVM_REGISTER_GLOBAL("container._GetADTTag")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0]; ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj); const auto& adt = Downcast<ADT>(obj);
*rv = static_cast<int64_t>(adt.tag()); *rv = static_cast<int64_t>(adt.tag());
}); });
TVM_REGISTER_GLOBAL("_vmobj.GetADTNumberOfFields") TVM_REGISTER_GLOBAL("container._GetADTSize")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0]; ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj); const auto& adt = Downcast<ADT>(obj);
...@@ -57,7 +47,7 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTNumberOfFields") ...@@ -57,7 +47,7 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTNumberOfFields")
}); });
TVM_REGISTER_GLOBAL("_vmobj.GetADTFields") TVM_REGISTER_GLOBAL("container._GetADTFields")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0]; ObjectRef obj = args[0];
int idx = args[1]; int idx = args[1];
...@@ -66,7 +56,7 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTFields") ...@@ -66,7 +56,7 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTFields")
*rv = adt[idx]; *rv = adt[idx];
}); });
TVM_REGISTER_GLOBAL("_vmobj.Tuple") TVM_REGISTER_GLOBAL("container._Tuple")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<ObjectRef> fields; std::vector<ObjectRef> fields;
for (auto i = 0; i < args.size(); ++i) { for (auto i = 0; i < args.size(); ++i) {
...@@ -75,7 +65,7 @@ TVM_REGISTER_GLOBAL("_vmobj.Tuple") ...@@ -75,7 +65,7 @@ TVM_REGISTER_GLOBAL("_vmobj.Tuple")
*rv = ADT::Tuple(fields); *rv = ADT::Tuple(fields);
}); });
TVM_REGISTER_GLOBAL("_vmobj.ADT") TVM_REGISTER_GLOBAL("container._ADT")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
int itag = args[0]; int itag = args[0];
size_t tag = static_cast<size_t>(itag); size_t tag = static_cast<size_t>(itag);
...@@ -88,15 +78,6 @@ TVM_REGISTER_GLOBAL("_vmobj.ADT") ...@@ -88,15 +78,6 @@ TVM_REGISTER_GLOBAL("_vmobj.ADT")
TVM_REGISTER_OBJECT_TYPE(ADTObj); TVM_REGISTER_OBJECT_TYPE(ADTObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj); TVM_REGISTER_OBJECT_TYPE(ClosureObj);
} // namespace vm
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
using namespace tvm::runtime;
int TVMGetObjectTag(TVMObjectHandle handle, int* tag) {
API_BEGIN();
int res = static_cast<int>(static_cast<Object*>(handle)->type_index());
*tag = res;
API_END();
}
...@@ -45,6 +45,12 @@ namespace tvm { ...@@ -45,6 +45,12 @@ namespace tvm {
namespace runtime { namespace runtime {
namespace vm { namespace vm {
VMClosure::VMClosure(size_t func_index, std::vector<ObjectRef> free_vars) {
auto ptr = make_object<VMClosureObj>();
ptr->func_index = func_index;
ptr->free_vars = std::move(free_vars);
data_ = std::move(ptr);
}
inline Storage make_storage(size_t size, size_t alignment, DLDataType dtype_hint, TVMContext ctx) { inline Storage make_storage(size_t size, size_t alignment, DLDataType dtype_hint, TVMContext ctx) {
// We could put cache in here, from ctx to storage allocator. // We could put cache in here, from ctx to storage allocator.
...@@ -906,7 +912,7 @@ void VirtualMachine::RunLoop() { ...@@ -906,7 +912,7 @@ void VirtualMachine::RunLoop() {
} }
case Opcode::InvokeClosure: { case Opcode::InvokeClosure: {
auto object = ReadRegister(instr.closure); auto object = ReadRegister(instr.closure);
const auto* closure = object.as<ClosureObj>(); const auto* closure = object.as<VMClosureObj>();
std::vector<ObjectRef> args; std::vector<ObjectRef> args;
for (auto free_var : closure->free_vars) { for (auto free_var : closure->free_vars) {
...@@ -1008,7 +1014,7 @@ void VirtualMachine::RunLoop() { ...@@ -1008,7 +1014,7 @@ void VirtualMachine::RunLoop() {
for (Index i = 0; i < instr.num_freevar; i++) { for (Index i = 0; i < instr.num_freevar; i++) {
free_vars.push_back(ReadRegister(instr.free_vars[i])); free_vars.push_back(ReadRegister(instr.free_vars[i]));
} }
WriteRegister(instr.dst, Closure(instr.func_index, free_vars)); WriteRegister(instr.dst, VMClosure(instr.func_index, free_vars));
pc_++; pc_++;
goto main_loop; goto main_loop;
} }
......
...@@ -62,16 +62,11 @@ tf_dtypes = { ...@@ -62,16 +62,11 @@ tf_dtypes = {
def vmobj_to_list(o): def vmobj_to_list(o):
if isinstance(o, tvm.nd.NDArray): if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()] return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vmobj.ADT): elif isinstance(o, tvm.container.ADT):
result = [] result = []
for f in o: for f in o:
result.extend(vmobj_to_list(f)) result.extend(vmobj_to_list(f))
return result return result
elif isinstance(o, tvm.relay.backend.interpreter.TupleValue):
result = []
for f in o.fields:
result.append(vmobj_to_list(f))
return result
elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue): elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
if o.constructor.name_hint == 'Cons': if o.constructor.name_hint == 'Cons':
tl = vmobj_to_list(o.fields[1]) tl = vmobj_to_list(o.fields[1])
......
...@@ -19,10 +19,9 @@ import numpy as np ...@@ -19,10 +19,9 @@ import numpy as np
import tvm import tvm
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
from tvm import relay from tvm import relay, container
from tvm.relay import testing from tvm.relay import testing
from tvm.relay import vm from tvm.relay import vm
from tvm.relay import vmobj as _obj
def benchmark_execution(mod, def benchmark_execution(mod,
...@@ -69,7 +68,7 @@ def benchmark_execution(mod, ...@@ -69,7 +68,7 @@ def benchmark_execution(mod,
ftimer = rly_vm.mod.time_evaluator("invoke", ctx, number=number, ftimer = rly_vm.mod.time_evaluator("invoke", ctx, number=number,
repeat=repeat) repeat=repeat)
# Measure in millisecond. # Measure in millisecond.
prof_res = np.array(ftimer("main", _obj.Tensor(data)).results) * 1000 prof_res = np.array(ftimer("main", data).results) * 1000
print("Mean vm inference time (std dev): %.2f ms (%.2f ms)" % print("Mean vm inference time (std dev): %.2f ms (%.2f ms)" %
(np.mean(prof_res), np.std(prof_res))) (np.mean(prof_res), np.std(prof_res)))
......
...@@ -117,7 +117,7 @@ def tree_to_dict(t): ...@@ -117,7 +117,7 @@ def tree_to_dict(t):
def vmobj_to_list(o, dtype="float32"): def vmobj_to_list(o, dtype="float32"):
if isinstance(o, tvm.nd.NDArray): if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()] return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vmobj.ADT): elif isinstance(o, tvm.container.ADT):
if len(o) == 0: if len(o) == 0:
tensor_nil = p.get_var("tensor_nil", dtype=dtype) tensor_nil = p.get_var("tensor_nil", dtype=dtype)
if tensor_nil.tag == o.tag: if tensor_nil.tag == o.tag:
......
...@@ -18,8 +18,7 @@ import numpy as np ...@@ -18,8 +18,7 @@ import numpy as np
import tvm import tvm
import tvm.testing import tvm.testing
from tvm import nd from tvm import nd
from tvm import relay from tvm import relay, container
from tvm.relay.backend.interpreter import TupleValue
from tvm.relay.backend.interpreter import RefValue, ConstructorValue from tvm.relay.backend.interpreter import RefValue, ConstructorValue
from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing, create_executor from tvm.relay import testing, create_executor
...@@ -39,7 +38,8 @@ def check_eval(expr, args, expected_result, mod=None, rtol=1e-07): ...@@ -39,7 +38,8 @@ def check_eval(expr, args, expected_result, mod=None, rtol=1e-07):
def test_tuple_value(): def test_tuple_value():
tv = TupleValue(relay.const(1), relay.const(2), relay.const(3)) tv = container.tuple_object([relay.const(1), relay.const(2),
relay.const(3)])
np.testing.assert_allclose(tv[0].data.asnumpy(), 1) np.testing.assert_allclose(tv[0].data.asnumpy(), 1)
np.testing.assert_allclose(tv[1].data.asnumpy(), 2) np.testing.assert_allclose(tv[1].data.asnumpy(), 2)
np.testing.assert_allclose(tv[2].data.asnumpy(), 3) np.testing.assert_allclose(tv[2].data.asnumpy(), 3)
...@@ -178,7 +178,7 @@ def test_function_taking_adt_ref_tuple(): ...@@ -178,7 +178,7 @@ def test_function_taking_adt_ref_tuple():
], prelude.cons) ], prelude.cons)
ref_value = RefValue(nd.array(np.random.rand(1, 10).astype('float32'))) ref_value = RefValue(nd.array(np.random.rand(1, 10).astype('float32')))
tuple_value = TupleValue(*[ tuple_value = container.tuple_object([
nd.array(np.random.rand(1, 10).astype('float32')) for _ in range(10) nd.array(np.random.rand(1, 10).astype('float32')) for _ in range(10)
]) ])
...@@ -202,8 +202,8 @@ def test_function_taking_adt_ref_tuple(): ...@@ -202,8 +202,8 @@ def test_function_taking_adt_ref_tuple():
res_tuple = id_func(tuple_value) res_tuple = id_func(tuple_value)
for i in range(10): for i in range(10):
tvm.testing.assert_allclose(res_tuple.fields[i].asnumpy(), tvm.testing.assert_allclose(res_tuple[i].asnumpy(),
tuple_value.fields[i].asnumpy()) tuple_value[i].asnumpy())
def test_tuple_passing(): def test_tuple_passing():
x = relay.var('x', type_annotation=relay.ty.TupleType([ x = relay.var('x', type_annotation=relay.ty.TupleType([
...@@ -224,7 +224,8 @@ def test_tuple_passing(): ...@@ -224,7 +224,8 @@ def test_tuple_passing():
out = f((10, 8)) out = f((10, 8))
tvm.testing.assert_allclose(out.asnumpy(), np.array(10)) tvm.testing.assert_allclose(out.asnumpy(), np.array(10))
# Second use a tuple value. # Second use a tuple value.
value_tuple = TupleValue(nd.array(np.array(11)), nd.array(np.array(12))) value_tuple = container.tuple_object([nd.array(np.array(11)),
nd.array(np.array(12))])
out = f(value_tuple) out = f(value_tuple)
tvm.testing.assert_allclose(out.asnumpy(), np.array(11)) tvm.testing.assert_allclose(out.asnumpy(), np.array(11))
......
...@@ -19,7 +19,8 @@ import tvm ...@@ -19,7 +19,8 @@ import tvm
from tvm import relay from tvm import relay
from tvm.relay.testing import to_python, run_as_python from tvm.relay.testing import to_python, run_as_python
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from tvm.relay.backend.interpreter import TupleValue, RefValue, ConstructorValue from tvm.container import ADT
from tvm.relay.backend.interpreter import RefValue, ConstructorValue
# helper: uses a dummy let binding to sequence a list # helper: uses a dummy let binding to sequence a list
# of expressions: expr1; expr2; expr3, etc. # of expressions: expr1; expr2; expr3, etc.
...@@ -45,10 +46,10 @@ def assert_tensor_value(candidate, val): ...@@ -45,10 +46,10 @@ def assert_tensor_value(candidate, val):
assert np.array_equal(candidate.asnumpy(), np.array(val)) assert np.array_equal(candidate.asnumpy(), np.array(val))
# assert that the candidate is a TupleValue with the indicate number of fields # assert that the candidate is an ADT with the indicated number of fields
def assert_tuple_value(candidate, fields): def assert_adt_len(candidate, fields):
assert isinstance(candidate, TupleValue) assert isinstance(candidate, ADT)
assert len(candidate.fields) == fields assert len(candidate) == fields
# assert that the candidate is a ConstructorValue with the approrpaite constructor # assert that the candidate is a ConstructorValue with the approrpaite constructor
...@@ -62,7 +63,7 @@ def assert_constructor_value(candidate, constructor, fields): ...@@ -62,7 +63,7 @@ def assert_constructor_value(candidate, constructor, fields):
def test_create_empty_tuple(): def test_create_empty_tuple():
empty = relay.Tuple([]) empty = relay.Tuple([])
tup_val = run_as_python(empty) tup_val = run_as_python(empty)
assert_tuple_value(tup_val, 0) assert_adt_len(tup_val, 0)
def test_create_scalar(): def test_create_scalar():
...@@ -87,12 +88,12 @@ def test_create_nested_tuple(): ...@@ -87,12 +88,12 @@ def test_create_nested_tuple():
]) ])
]) ])
tup_val = run_as_python(relay_tup) tup_val = run_as_python(relay_tup)
assert_tuple_value(tup_val, 3) assert_adt_len(tup_val, 3)
for i in range(2): for i in range(2):
assert_tensor_value(tup_val.fields[i], i + 1) assert_tensor_value(tup_val[i], i + 1)
assert_tuple_value(tup_val.fields[2], 2) assert_adt_len(tup_val[2], 2)
for i in range(2): for i in range(2):
assert_tensor_value(tup_val.fields[2].fields[i], i + 3) assert_tensor_value(tup_val[2][i], i + 3)
def test_tuple_get_item(): def test_tuple_get_item():
...@@ -118,23 +119,23 @@ def test_create_let(): ...@@ -118,23 +119,23 @@ def test_create_let():
v = relay.Var('v') v = relay.Var('v')
let = relay.Let(v, relay.Tuple([]), relay.Tuple([v, v])) let = relay.Let(v, relay.Tuple([]), relay.Tuple([v, v]))
tup_val = run_as_python(let) tup_val = run_as_python(let)
assert_tuple_value(tup_val, 2) assert_adt_len(tup_val, 2)
assert_tuple_value(tup_val.fields[0], 0) assert_adt_len(tup_val[0], 0)
assert_tuple_value(tup_val.fields[1], 0) assert_adt_len(tup_val[1], 0)
def test_create_ref(): def test_create_ref():
relay_ref = relay.RefCreate(relay.Tuple([])) relay_ref = relay.RefCreate(relay.Tuple([]))
ref_val = run_as_python(relay_ref) ref_val = run_as_python(relay_ref)
assert isinstance(ref_val, RefValue) assert isinstance(ref_val, RefValue)
assert_tuple_value(ref_val.value, 0) assert_adt_len(ref_val.value, 0)
def test_ref_read(): def test_ref_read():
v = relay.Var('v') v = relay.Var('v')
assign = relay.Let(v, relay.RefCreate(relay.Tuple([])), relay.RefRead(v)) assign = relay.Let(v, relay.RefCreate(relay.Tuple([])), relay.RefRead(v))
read_val = run_as_python(assign) read_val = run_as_python(assign)
assert_tuple_value(read_val, 0) assert_adt_len(read_val, 0)
def test_ref_write(): def test_ref_write():
...@@ -143,7 +144,7 @@ def test_ref_write(): ...@@ -143,7 +144,7 @@ def test_ref_write():
initial_write = relay.Let(v, relay.RefCreate(relay.Tuple([relay.const(1)])), initial_write = relay.Let(v, relay.RefCreate(relay.Tuple([relay.const(1)])),
relay.RefWrite(v, relay.Tuple([relay.const(2)]))) relay.RefWrite(v, relay.Tuple([relay.const(2)])))
write_val = run_as_python(initial_write) write_val = run_as_python(initial_write)
assert_tuple_value(write_val, 0) assert_adt_len(write_val, 0)
# now ensure that the value, once written, can be read back # now ensure that the value, once written, can be read back
# (we read the value before and after mutation) # (we read the value before and after mutation)
...@@ -155,11 +156,11 @@ def test_ref_write(): ...@@ -155,11 +156,11 @@ def test_ref_write():
seq(relay.RefWrite(v, relay.Tuple([relay.const(2)])), seq(relay.RefWrite(v, relay.Tuple([relay.const(2)])),
relay.Tuple([relay.RefRead(w), relay.RefRead(v)])))) relay.Tuple([relay.RefRead(w), relay.RefRead(v)]))))
read_val = run_as_python(read_after_write) read_val = run_as_python(read_after_write)
assert_tuple_value(read_val, 2) assert_adt_len(read_val, 2)
assert_tuple_value(read_val.fields[0], 1) assert_adt_len(read_val[0], 1)
assert_tuple_value(read_val.fields[1], 1) assert_adt_len(read_val[1], 1)
assert_tensor_value(read_val.fields[0].fields[0], 1) assert_tensor_value(read_val[0][0], 1)
assert_tensor_value(read_val.fields[1].fields[0], 2) assert_tensor_value(read_val[1][0], 2)
def test_if(): def test_if():
...@@ -191,7 +192,7 @@ def test_local_function(): ...@@ -191,7 +192,7 @@ def test_local_function():
call2 = relay.Let(f, ident, f(relay.const(2))) call2 = relay.Let(f, ident, f(relay.const(2)))
call_val1 = run_as_python(call1) call_val1 = run_as_python(call1)
assert_tuple_value(call_val1, 0) assert_adt_len(call_val1, 0)
call_val2 = run_as_python(call2) call_val2 = run_as_python(call2)
assert_tensor_value(call_val2, 2) assert_tensor_value(call_val2, 2)
...@@ -211,9 +212,9 @@ def test_global_function(): ...@@ -211,9 +212,9 @@ def test_global_function():
assert_tensor_value(call_val1, 1) assert_tensor_value(call_val1, 1)
call_val2 = run_as_python(call2, mod) call_val2 = run_as_python(call2, mod)
assert_tuple_value(call_val2, 2) assert_adt_len(call_val2, 2)
assert_tensor_value(call_val2.fields[0], 2) assert_tensor_value(call_val2[0], 2)
assert_tensor_value(call_val2.fields[1], 2) assert_tensor_value(call_val2[1], 2)
def test_constructor(): def test_constructor():
...@@ -230,7 +231,7 @@ def test_constructor(): ...@@ -230,7 +231,7 @@ def test_constructor():
box_val_tup = run_as_python(init_box_tup, mod) box_val_tup = run_as_python(init_box_tup, mod)
assert_constructor_value(box_val_tup, box_ctor, 1) assert_constructor_value(box_val_tup, box_ctor, 1)
assert_tuple_value(box_val_tup.fields[0], 0) assert_adt_len(box_val_tup.fields[0], 0)
def test_match_wildcard(): def test_match_wildcard():
...@@ -372,7 +373,7 @@ def test_global_recursion(): ...@@ -372,7 +373,7 @@ def test_global_recursion():
call2 = copy_def(p.cons(relay.Tuple([]), p.nil())) call2 = copy_def(p.cons(relay.Tuple([]), p.nil()))
val2 = run_as_python(call2, mod) val2 = run_as_python(call2, mod)
assert_constructor_value(val2, p.cons, 2) assert_constructor_value(val2, p.cons, 2)
assert_tuple_value(val2.fields[0], 0) assert_adt_len(val2.fields[0], 0)
assert_constructor_value(val2.fields[1], p.nil, 0) assert_constructor_value(val2.fields[1], p.nil, 0)
...@@ -437,10 +438,10 @@ def test_arbitrary_let_nesting(): ...@@ -437,10 +438,10 @@ def test_arbitrary_let_nesting():
]) ])
tup_val = run_as_python(expr, mod) tup_val = run_as_python(expr, mod)
assert_tuple_value(tup_val, 3) assert_adt_len(tup_val, 3)
assert_tensor_value(tup_val.fields[0], 2) assert_tensor_value(tup_val[0], 2)
assert_tensor_value(tup_val.fields[1], 3) assert_tensor_value(tup_val[1], 3)
assert_tensor_value(tup_val.fields[2], 4) assert_tensor_value(tup_val[2], 4)
def test_ref_execution_order(): def test_ref_execution_order():
...@@ -475,12 +476,12 @@ def test_ref_execution_order(): ...@@ -475,12 +476,12 @@ def test_ref_execution_order():
]))) ])))
tup_val = run_as_python(expr) tup_val = run_as_python(expr)
assert_tuple_value(tup_val, 5) assert_adt_len(tup_val, 5)
assert_tensor_value(tup_val.fields[0], 1) assert_tensor_value(tup_val[0], 1)
assert_tensor_value(tup_val.fields[1], 2) assert_tensor_value(tup_val[1], 2)
assert_tensor_value(tup_val.fields[2], 3) assert_tensor_value(tup_val[2], 3)
assert_tensor_value(tup_val.fields[3], 4) assert_tensor_value(tup_val[3], 4)
assert_tensor_value(tup_val.fields[4], 5) assert_tensor_value(tup_val[4], 5)
def test_op_add(): def test_op_add():
...@@ -501,6 +502,7 @@ def test_op_stack(): ...@@ -501,6 +502,7 @@ def test_op_stack():
args.append(relay.const(data)) args.append(relay.const(data))
call = relay.stack(relay.Tuple(args), axis) call = relay.stack(relay.Tuple(args), axis)
call_val = run_as_python(call) call_val = run_as_python(call)
type(call_val)
assert_tensor_value(call_val, ref_res) assert_tensor_value(call_val, ref_res)
verify_stack([(2,), (2,), (2,)], -1) verify_stack([(2,), (2,), (2,)], -1)
...@@ -517,9 +519,9 @@ def test_split(): ...@@ -517,9 +519,9 @@ def test_split():
ref_res = np.split(x, indices_or_sections, axis=axis) ref_res = np.split(x, indices_or_sections, axis=axis)
call = relay.split(relay.const(x), indices_or_sections, axis=axis) call = relay.split(relay.const(x), indices_or_sections, axis=axis)
call_val = run_as_python(call) call_val = run_as_python(call)
assert_tuple_value(call_val, len(ref_res)) assert_adt_len(call_val, len(ref_res))
for i in range(len(ref_res)): for i in range(len(ref_res)):
assert_tensor_value(call_val.fields[i], ref_res[i]) assert_tensor_value(call_val[i], ref_res[i])
verify_split((2, 3), 2) verify_split((2, 3), 2)
verify_split((5, 3), [3]) verify_split((5, 3), [3])
......
...@@ -58,7 +58,7 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): ...@@ -58,7 +58,7 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
def vmobj_to_list(o): def vmobj_to_list(o):
if isinstance(o, tvm.nd.NDArray): if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()] return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vm.ADT): elif isinstance(o, tvm.container.ADT):
result = [] result = []
for f in o: for f in o:
result.extend(vmobj_to_list(f)) result.extend(vmobj_to_list(f))
......
...@@ -17,18 +17,44 @@ ...@@ -17,18 +17,44 @@
import numpy as np import numpy as np
import tvm import tvm
from tvm.relay import vm from tvm import nd, relay
from tvm import container as _container
def test_adt():
arr = tvm.nd.array([1,2,3]) def test_adt_constructor():
y = vm.ADT(0, [arr, arr]) arr = nd.array([1, 2, 3])
fields = [arr, arr]
y = _container.ADT(0, [arr, arr])
assert len(y) == 2 assert len(y) == 2
assert isinstance(y, vm.ADT) assert isinstance(y, _container.ADT)
y[0:1][-1] == arr y[0:1][-1] == arr
assert y.tag == 0 assert y.tag == 0
assert isinstance(arr, tvm.nd.NDArray) assert isinstance(arr, nd.NDArray)
def test_tuple_object():
x = relay.var(
'x',
type_annotation=relay.ty.TupleType([
relay.ty.TensorType((), 'int32'),
relay.ty.TensorType((), 'int32')
]))
fn = relay.Function([x], relay.expr.TupleGetItem(x, 0))
mod = relay.Module.from_expr(fn)
exe = relay.create_executor(
kind="vm", mod=mod, ctx=nd.cpu(), target="llvm")
f = exe.evaluate()
value_tuple = _container.tuple_object(
[nd.array(np.array(11)),
nd.array(np.array(12))])
# pass an ADT object to evaluate
out = f(value_tuple)
tvm.testing.assert_allclose(out.asnumpy(), np.array(11))
if __name__ == "__main__": if __name__ == "__main__":
test_adt() test_adt_constructor()
test_tuple_object()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment