Commit 86092de0 by Zhi Committed by Jared Roesch

[REFACTOR] Replace TensorObj and TensorValue with NDArray (#4643)

* replace TensorObj and TensorValue with NDArray

* NodeBase to Object in Python

* rebase
parent dcf7fbf1
......@@ -20,17 +20,14 @@ Developer API
This page contains modules that are used by developers of TVM.
Many of these APIs are PackedFunc registered in C++ backend.
tvm.node
~~~~~~~~
.. automodule:: tvm.node
.. autoclass:: tvm.node.NodeBase
:members:
tvm.object
~~~~~~~~~~
.. automodule:: tvm.object
.. autoclass:: tvm.node.Node
.. autoclass:: tvm.object.Object
:members:
.. autofunction:: tvm.register_node
.. autofunction:: tvm.register_object
tvm.expr
~~~~~~~~
......
......@@ -55,18 +55,18 @@ We use a simple example that uses the low level TVM API directly. The example is
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``NodeBase``.
Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``Object``.
::
@register_node
class Tensor(NodeBase, _expr.ExprOp):
@register_object
class Tensor(Object, _expr.ExprOp):
"""Tensor object, to construct, see function.Tensor"""
def __call__(self, *indices):
...
The Node system is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document <https://docs.tvm.ai/dev/runtime.html#tvm-node-and-compiler-stack>`_, and details are in ``python/tvm/_ffi/`` if you are interested.
The object protocol is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document <https://docs.tvm.ai/dev/runtime.html#tvm-node-and-compiler-stack>`_, and details are in ``python/tvm/_ffi/`` if you are interested.
``Tensor`` is created by functions in ``python/tvm/api.py``, which in turn calls into C++ functions exposed in ``src/api/api_lang.cc``. All C++ functions that are callable from Python are exposed in the ``src/api`` subdirectory. For example, the ``tvm.compute()`` function above calls into ``_ComputeOp`` API exposed in ``src/api/api_lang.cc``:
......
......@@ -37,16 +37,12 @@
#include <tvm/build_module.h>
#include <tvm/relay/module.h>
#include <tvm/relay/expr.h>
#include <tvm/runtime/object.h>
namespace tvm {
namespace relay {
/*!
* \brief A Relay value.
*/
class Value;
/*!
*\brief Create a Interpreter function that can
* evaluate an expression and produce a value.
*
......@@ -65,39 +61,21 @@ class Value;
* \param target Compiler target flag to compile the functions on the context.
* \return A function that takes in an expression and returns a value.
*/
runtime::TypedPackedFunc<Value(Expr)>
runtime::TypedPackedFunc<ObjectRef(Expr)>
CreateInterpreter(Module mod, DLContext context, Target target);
/*! \brief The base container type of Relay values. */
class ValueNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Value";
TVM_DECLARE_BASE_OBJECT_INFO(ValueNode, RelayNode);
};
class Value : public ObjectRef {
public:
Value() {}
explicit Value(ObjectPtr<Object> n) : ObjectRef(n) {}
const ValueNode* operator->() const {
return static_cast<const ValueNode*>(get());
}
using ContainerType = ValueNode;
};
/*! \brief A Relay closure, i.e a scope and a function. */
class Closure;
/*! \brief The container type of Closures. */
class ClosureNode : public ValueNode {
class ClosureNode : public Object {
public:
/*! \brief The set of free variables in the closure.
*
* These are the captured variables which are required for
* evaluation when we call the closure.
*/
tvm::Map<Var, Value> env;
tvm::Map<Var, ObjectRef> env;
/*! \brief The function which implements the closure.
*
* \note May reference the variables contained in the env.
......@@ -111,22 +89,22 @@ class ClosureNode : public ValueNode {
v->Visit("func", &func);
}
TVM_DLL static Closure make(tvm::Map<Var, Value> env, Function func);
TVM_DLL static Closure make(tvm::Map<Var, ObjectRef> env, Function func);
static constexpr const char* _type_key = "relay.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, Object);
};
class Closure : public Value {
class Closure : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Closure, Value, ClosureNode);
TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureNode);
};
/*! \brief A Relay Recursive Closure. A closure that has a name. */
class RecClosure;
/*! \brief The container type of RecClosure. */
class RecClosureNode : public ValueNode {
class RecClosureNode : public Object {
public:
/*! \brief The closure. */
Closure clos;
......@@ -143,64 +121,41 @@ class RecClosureNode : public ValueNode {
TVM_DLL static RecClosure make(Closure clos, Var bind);
static constexpr const char* _type_key = "relay.RecClosure";
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, Object);
};
class RecClosure : public Value {
class RecClosure : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, Value, RecClosureNode);
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureNode);
};
/*! \brief A tuple value. */
class TupleValue;
/*! \brief Tuple (x, ... y). */
struct TupleValueNode : ValueNode {
tvm::Array<Value> fields;
struct TupleValueNode : Object {
tvm::Array<ObjectRef> fields;
TupleValueNode() {}
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); }
TVM_DLL static TupleValue make(tvm::Array<Value> value);
TVM_DLL static TupleValue make(tvm::Array<ObjectRef> value);
static constexpr const char* _type_key = "relay.TupleValue";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, ValueNode);
};
class TupleValue : public Value {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, Value, TupleValueNode);
};
/*! \brief A tensor value. */
class TensorValue;
/*! \brief The tensor value container, wrapping an NDArray. */
struct TensorValueNode : ValueNode {
runtime::NDArray data;
TensorValueNode() {}
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); }
/*! \brief Build a value from an NDArray. */
TVM_DLL static TensorValue make(runtime::NDArray data);
static constexpr const char* _type_key = "relay.TensorValue";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorValueNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, Object);
};
class TensorValue : public Value {
class TupleValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TensorValue, Value, TensorValueNode);
TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, ObjectRef, TupleValueNode);
};
/*! \brief A reference value. */
class RefValue;
struct RefValueNode : ValueNode {
mutable Value value;
struct RefValueNode : Object {
mutable ObjectRef value;
RefValueNode() {}
......@@ -208,24 +163,24 @@ struct RefValueNode : ValueNode {
v->Visit("value", &value);
}
TVM_DLL static RefValue make(Value val);
TVM_DLL static RefValue make(ObjectRef val);
static constexpr const char* _type_key = "relay.RefValue";
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, Object);
};
class RefValue : public Value {
class RefValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, Value, RefValueNode);
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueNode);
};
/*! \brief An ADT constructor value. */
class ConstructorValue;
struct ConstructorValueNode : ValueNode {
struct ConstructorValueNode : Object {
int32_t tag;
tvm::Array<Value> fields;
tvm::Array<ObjectRef> fields;
/*! \brief Optional field tracking ADT constructor. */
Constructor constructor;
......@@ -237,16 +192,16 @@ struct ConstructorValueNode : ValueNode {
}
TVM_DLL static ConstructorValue make(int32_t tag,
tvm::Array<Value> fields,
tvm::Array<ObjectRef> fields,
Constructor construtor = {});
static constexpr const char* _type_key = "relay.ConstructorValue";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, Object);
};
class ConstructorValue : public Value {
class ConstructorValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, Value, ConstructorValueNode);
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueNode);
};
} // namespace relay
......
......@@ -36,25 +36,6 @@ namespace tvm {
namespace runtime {
namespace vm {
/*! \brief An object containing an NDArray. */
class TensorObj : public Object {
public:
/*! \brief The NDArray. */
NDArray data;
static constexpr const uint32_t _type_index = TypeIndex::kVMTensor;
static constexpr const char* _type_key = "vm.Tensor";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorObj, Object);
};
/*! \brief reference to tensor. */
class Tensor : public ObjectRef {
public:
explicit Tensor(NDArray data);
TVM_DEFINE_OBJECT_REF_METHODS(Tensor, ObjectRef, TensorObj);
};
/*! \brief An object representing a closure. */
class ClosureObj : public Object {
public:
......
......@@ -34,7 +34,7 @@ from . import codegen
from . import container
from . import schedule
from . import module
from . import node
from . import object
from . import attrs
from . import ir_builder
from . import target
......@@ -55,7 +55,7 @@ from ._ffi.base import TVMError, __version__
from .api import *
from .intrin import *
from .tensor_intrin import decl_tensor_intrin
from .node import register_node
from .object import register_object
from .ndarray import register_extension
from .schedule import create_schedule
from .build_module import build, lower, build_config
......
......@@ -25,14 +25,14 @@ from numbers import Number, Integral
from ..base import _LIB, get_last_ffi_error, py2cerror
from ..base import c_str, string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext
from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
from .object import ObjectBase, _set_class_node
from .object import ObjectBase, _set_class_object
from . import object as _object
FunctionHandle = ctypes.c_void_p
......@@ -144,8 +144,8 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, string_types):
values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR
elif isinstance(arg, (list, tuple, dict, NodeGeneric)):
arg = convert_to_node(arg)
elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):
arg = convert_to_object(arg)
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.OBJECT_HANDLE
temp_args.append(arg)
......@@ -256,7 +256,6 @@ C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handl
_CLASS_MODULE = None
_CLASS_FUNCTION = None
_CLASS_OBJECT = None
def _set_class_module(module_class):
"""Initialize the module."""
......@@ -266,7 +265,3 @@ def _set_class_module(module_class):
def _set_class_function(func_class):
global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class
def _set_class_object(obj_class):
global _CLASS_OBJECT
_CLASS_OBJECT = obj_class
......@@ -30,11 +30,11 @@ __init_by_constructor__ = None
"""Maps object type to its constructor"""
OBJECT_TYPE = {}
_CLASS_NODE = None
_CLASS_OBJECT = None
def _set_class_node(node_class):
global _CLASS_NODE
_CLASS_NODE = node_class
def _set_class_object(object_class):
global _CLASS_OBJECT
_CLASS_OBJECT = object_class
def _register_object(index, cls):
......@@ -51,7 +51,7 @@ def _return_object(x):
handle = ObjectHandle(handle)
tindex = ctypes.c_uint()
check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
cls = OBJECT_TYPE.get(tindex.value, _CLASS_NODE)
cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT)
# Avoid calling __init__ of cls, instead directly call __new__
# This allows child class to implement their own __init__
obj = cls.__new__(cls)
......
......@@ -20,7 +20,7 @@ import traceback
from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types, py2cerror
from ..node_generic import convert_to_node, NodeGeneric
from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray
......@@ -149,8 +149,8 @@ cdef inline int make_arg(object arg,
value[0].v_str = tstr
tcode[0] = kStr
temp_args.append(tstr)
elif isinstance(arg, (list, tuple, dict, NodeGeneric)):
arg = convert_to_node(arg)
elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):
arg = convert_to_object(arg)
value[0].v_handle = (<ObjectBase>arg).chandle
tcode[0] = kObjectHandle
temp_args.append(arg)
......@@ -308,7 +308,6 @@ cdef class FunctionBase:
_CLASS_FUNCTION = None
_CLASS_MODULE = None
_CLASS_OBJECT = None
_CLASS_NODE = None
def _set_class_module(module_class):
"""Initialize the module."""
......@@ -322,7 +321,3 @@ def _set_class_function(func_class):
def _set_class_object(obj_class):
global _CLASS_OBJECT
_CLASS_OBJECT = obj_class
def _set_class_node(node_class):
global _CLASS_NODE
_CLASS_NODE = node_class
......@@ -32,7 +32,7 @@ def _register_object(int index, object cls):
cdef inline object make_ret_object(void* chandle):
global OBJECT_TYPE
global _CLASS_NODE
global _CLASS_OBJECT
cdef unsigned tindex
cdef object cls
cdef object handle
......@@ -44,11 +44,9 @@ cdef inline object make_ret_object(void* chandle):
if cls is not None:
obj = cls.__new__(cls)
else:
# default use node base class
# TODO(tqchen) change to object after Node unifies with Object
obj = _CLASS_NODE.__new__(_CLASS_NODE)
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
else:
obj = _CLASS_NODE.__new__(_CLASS_NODE)
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
(<ObjectBase>obj).chandle = chandle
return obj
......
......@@ -22,7 +22,7 @@ from __future__ import absolute_import
import sys
import ctypes
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
from .node_generic import _set_class_objects
from .object_generic import _set_class_objects
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Node namespace"""
# pylint: disable=unused-import
from __future__ import absolute_import
import ctypes
import sys
from .. import _api_internal
from .object import Object, register_object, _set_class_node
from .node_generic import NodeGeneric, convert_to_node, const
def _new_object(cls):
"""Helper function for pickle"""
return cls.__new__(cls)
class NodeBase(Object):
"""NodeBase is the base class of all TVM language AST object."""
def __repr__(self):
return _api_internal._format_str(self)
def __dir__(self):
fnames = _api_internal._NodeListAttrNames(self)
size = fnames(-1)
return [fnames(i) for i in range(size)]
def __getattr__(self, name):
try:
return _api_internal._NodeGetAttr(self, name)
except AttributeError:
raise AttributeError(
"%s has no attribute %s" % (str(type(self)), name))
def __hash__(self):
return _api_internal._raw_ptr(self)
def __eq__(self, other):
return self.same_as(other)
def __ne__(self, other):
return not self.__eq__(other)
def __reduce__(self):
cls = type(self)
return (_new_object, (cls, ), self.__getstate__())
def __getstate__(self):
handle = self.handle
if handle is not None:
return {'handle': _api_internal._save_json(self)}
return {'handle': None}
def __setstate__(self, state):
# pylint: disable=assigning-non-slot
handle = state['handle']
if handle is not None:
json_str = handle
other = _api_internal._load_json(json_str)
self.handle = other.handle
other.handle = None
else:
self.handle = None
def same_as(self, other):
"""check object identity equality"""
if not isinstance(other, NodeBase):
return False
return self.__hash__() == other.__hash__()
# pylint: disable=invalid-name
register_node = register_object
_set_class_node(NodeBase)
......@@ -14,13 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
# pylint: disable=invalid-name, unused-import
"""Runtime Object API"""
from __future__ import absolute_import
import sys
import ctypes
from .. import _api_internal
from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str
from .object_generic import ObjectGeneric, convert_to_object, const
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
......@@ -29,23 +31,77 @@ try:
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_object, _set_class_node
from ._cy3.core import _set_class_object
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import _register_object
else:
from ._cy2.core import _set_class_object, _set_class_node
from ._cy2.core import _set_class_object
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import _register_object
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position,unused-import
from ._ctypes.function import _set_class_object, _set_class_node
from ._ctypes.function import _set_class_object
from ._ctypes.object import ObjectBase as _ObjectBase
from ._ctypes.object import _register_object
def _new_object(cls):
"""Helper function for pickle"""
return cls.__new__(cls)
class Object(_ObjectBase):
"""Base class for all tvm's runtime objects."""
pass
def __repr__(self):
return _api_internal._format_str(self)
def __dir__(self):
fnames = _api_internal._NodeListAttrNames(self)
size = fnames(-1)
return [fnames(i) for i in range(size)]
def __getattr__(self, name):
try:
return _api_internal._NodeGetAttr(self, name)
except AttributeError:
raise AttributeError(
"%s has no attribute %s" % (str(type(self)), name))
def __hash__(self):
return _api_internal._raw_ptr(self)
def __eq__(self, other):
return self.same_as(other)
def __ne__(self, other):
return not self.__eq__(other)
def __reduce__(self):
cls = type(self)
return (_new_object, (cls, ), self.__getstate__())
def __getstate__(self):
handle = self.handle
if handle is not None:
return {'handle': _api_internal._save_json(self)}
return {'handle': None}
def __setstate__(self, state):
# pylint: disable=assigning-non-slot
handle = state['handle']
if handle is not None:
json_str = handle
other = _api_internal._load_json(json_str)
self.handle = other.handle
other.handle = None
else:
self.handle = None
def same_as(self, other):
"""check object identity equality"""
if not isinstance(other, Object):
return False
return self.__hash__() == other.__hash__()
def register_object(type_key=None):
......
......@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Common implementation of Node generic related logic"""
"""Common implementation of object generic related logic"""
# pylint: disable=unused-import
from __future__ import absolute_import
......@@ -22,7 +22,7 @@ from numbers import Number, Integral
from .. import _api_internal
from .base import string_types
# Node base class
# Object base class
_CLASS_OBJECTS = None
def _set_class_objects(cls):
......@@ -47,15 +47,15 @@ def _scalar_type_inference(value):
return dtype
class NodeGeneric(object):
"""Base class for all classes that can be converted to node."""
def asnode(self):
"""Convert value to node"""
class ObjectGeneric(object):
"""Base class for all classes that can be converted to object."""
def asobject(self):
"""Convert value to object"""
raise NotImplementedError()
def convert_to_node(value):
"""Convert a python value to corresponding node type.
def convert_to_object(value):
"""Convert a python value to corresponding object type.
Parameters
----------
......@@ -64,8 +64,8 @@ def convert_to_node(value):
Returns
-------
node : Node
The corresponding node value.
obj : Object
The corresponding object value.
"""
if isinstance(value, _CLASS_OBJECTS):
return value
......@@ -76,7 +76,7 @@ def convert_to_node(value):
if isinstance(value, string_types):
return _api_internal._str(value)
if isinstance(value, (list, tuple)):
value = [convert_to_node(x) for x in value]
value = [convert_to_object(x) for x in value]
return _api_internal._Array(*value)
if isinstance(value, dict):
vlist = []
......@@ -85,14 +85,14 @@ def convert_to_node(value):
not isinstance(item[0], string_types)):
raise ValueError("key of map must already been a container type")
vlist.append(item[0])
vlist.append(convert_to_node(item[1]))
vlist.append(convert_to_object(item[1]))
return _api_internal._Map(*vlist)
if isinstance(value, NodeGeneric):
return value.asnode()
if isinstance(value, ObjectGeneric):
return value.asobject()
if value is None:
return None
raise ValueError("don't know how to convert type %s to node" % type(value))
raise ValueError("don't know how to convert type %s to object" % type(value))
def const(value, dtype=None):
......
......@@ -22,9 +22,8 @@ from numbers import Integral as _Integral
from ._ffi.base import string_types
from ._ffi.object import register_object, Object
from ._ffi.node import register_node, NodeBase
from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.node_generic import _scalar_type_inference
from ._ffi.object import convert_to_object as _convert_to_object
from ._ffi.object_generic import _scalar_type_inference
from ._ffi.function import Function
from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
......@@ -111,7 +110,7 @@ def get_env_func(name):
Note
----
EnvFunc is a Node wrapper around
EnvFunc is a Object wrapper around
global function that can be serialized via its name.
This can be used to serialize function field in the language.
"""
......@@ -127,16 +126,16 @@ def convert(value):
Returns
-------
tvm_val : Node or Function
tvm_val : Object or Function
Converted value in TVM
"""
if isinstance(value, (Function, NodeBase)):
if isinstance(value, (Function, Object)):
return value
if callable(value):
return _convert_tvm_func(value)
return _convert_to_node(value)
return _convert_to_object(value)
def load_json(json_str):
......@@ -149,7 +148,7 @@ def load_json(json_str):
Returns
-------
node : Node
node : Object
The loaded tvm node.
"""
return _api_internal._load_json(json_str)
......@@ -160,8 +159,8 @@ def save_json(node):
Parameters
----------
node : Node
A TVM Node object to be saved.
node : Object
A TVM object to be saved.
Returns
-------
......
......@@ -17,11 +17,11 @@
"""Arithmetic data structure and utility"""
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node
from ._ffi.object import Object, register_object
from ._ffi.function import _init_api
from . import _api_internal
class IntSet(NodeBase):
class IntSet(Object):
"""Represent a set of integer in one dimension."""
def is_nothing(self):
"""Whether the set represent nothing"""
......@@ -32,7 +32,7 @@ class IntSet(NodeBase):
return _api_internal._IntSetIsEverything(self)
@register_node("arith.IntervalSet")
@register_object("arith.IntervalSet")
class IntervalSet(IntSet):
"""Represent set of continuous interval [min_value, max_value]
......@@ -49,16 +49,16 @@ class IntervalSet(IntSet):
_make_IntervalSet, min_value, max_value)
@register_node("arith.ModularSet")
class ModularSet(NodeBase):
@register_object("arith.ModularSet")
class ModularSet(Object):
"""Represent range of (coeff * x + base) for x in Z """
def __init__(self, coeff, base):
self.__init_handle_by_constructor__(
_make_ModularSet, coeff, base)
@register_node("arith.ConstIntBound")
class ConstIntBound(NodeBase):
@register_object("arith.ConstIntBound")
class ConstIntBound(Object):
"""Represent constant integer bound
Parameters
......@@ -245,7 +245,7 @@ class Analyzer:
var : tvm.Var
The variable.
info : tvm.NodeBase
info : tvm.Object
Related information.
override : bool
......
......@@ -15,13 +15,13 @@
# specific language governing permissions and limitations
# under the License.
""" TVM Attribute module, which is mainly used for defining attributes of operators"""
from ._ffi.node import NodeBase, register_node as _register_tvm_node
from ._ffi.object import Object, register_object
from ._ffi.function import _init_api
from . import _api_internal
@_register_tvm_node
class Attrs(NodeBase):
@register_object
class Attrs(Object):
"""Attribute node, which is mainly use for defining attributes of relay operators.
Used by function registered in python side, such as compute, schedule and alter_layout.
......
......@@ -23,7 +23,7 @@ from __future__ import absolute_import as _abs
import warnings
from ._ffi.function import Function
from ._ffi.node import NodeBase, register_node
from ._ffi.object import Object, register_object
from . import api
from . import _api_internal
from . import tensor
......@@ -115,22 +115,22 @@ class DumpIR(object):
DumpIR.scope_level -= 1
@register_node
class BuildConfig(NodeBase):
@register_object
class BuildConfig(Object):
"""Configuration scope to set a build config option.
Note
----
This object is backed by node system in C++, with arguments that can be
This object is backed by object protocol in C++, with arguments that can be
exchanged between python and C++.
Do not construct directly, use build_config instead.
The fields that are backed by the C++ node are immutable once an instance
is constructed. See _node_defaults for the fields.
The fields that are backed by the C++ object are immutable once an instance
is constructed. See _object_defaults for the fields.
"""
_node_defaults = {
_object_defaults = {
"auto_unroll_max_step": 0,
"auto_unroll_max_depth": 8,
"auto_unroll_max_extent": 0,
......@@ -191,7 +191,7 @@ class BuildConfig(NodeBase):
_api_internal._ExitBuildConfigScope(self)
def __setattr__(self, name, value):
if name in BuildConfig._node_defaults:
if name in BuildConfig._object_defaults:
raise AttributeError(
"'%s' object cannot set attribute '%s'" % (str(type(self)), name))
return super(BuildConfig, self).__setattr__(name, value)
......@@ -257,7 +257,7 @@ def build_config(**kwargs):
The build configuration
"""
node_args = {k: v if k not in kwargs else kwargs[k]
for k, v in BuildConfig._node_defaults.items()}
for k, v in BuildConfig._object_defaults.items()}
config = make.node("BuildConfig", **node_args)
if "add_lower_pass" in kwargs:
......
......@@ -16,11 +16,11 @@
# under the License.
"""Container data structures used in TVM DSL."""
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node
from ._ffi.object import Object, register_object
from . import _api_internal
@register_node
class Array(NodeBase):
@register_object
class Array(Object):
"""Array container of TVM.
You do not need to create Array explicitly.
......@@ -50,8 +50,8 @@ class Array(NodeBase):
return _api_internal._ArraySize(self)
@register_node
class EnvFunc(NodeBase):
@register_object
class EnvFunc(Object):
"""Environment function.
This is a global function object that can be serialized by its name.
......@@ -64,13 +64,13 @@ class EnvFunc(NodeBase):
return _api_internal._EnvFuncGetPackedFunc(self)
@register_node
class Map(NodeBase):
@register_object
class Map(Object):
"""Map container of TVM.
You do not need to create Map explicitly.
Normally python dict will be converted automaticall to Map during tvm function call.
You can use convert to create a dict[NodeBase-> NodeBase] into a Map
You can use convert to create a dict[Object-> Object] into a Map
"""
def __getitem__(self, k):
return _api_internal._MapGetItem(self, k)
......@@ -87,11 +87,11 @@ class Map(NodeBase):
return _api_internal._MapSize(self)
@register_node
@register_object
class StrMap(Map):
"""A special map container that has str as key.
You can use convert to create a dict[str->NodeBase] into a Map.
You can use convert to create a dict[str->Object] into a Map.
"""
def items(self):
"""Get the items from the map"""
......@@ -99,8 +99,8 @@ class StrMap(Map):
return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)]
@register_node
class Range(NodeBase):
@register_object
class Range(Object):
"""Represent a range in TVM.
You do not need to create a Range explicitly.
......@@ -108,8 +108,8 @@ class Range(NodeBase):
"""
@register_node
class LoweredFunc(NodeBase):
@register_object
class LoweredFunc(Object):
"""Represent a LoweredFunc in TVM."""
MixedFunc = 0
HostFunc = 1
......
......@@ -32,7 +32,7 @@ For example, you can use addexp.a to get the left operand of an Add node.
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, NodeGeneric, register_node
from ._ffi.object import Object, register_object, ObjectGeneric
from ._ffi.runtime_ctypes import TVMType, TypeCode
from . import make as _make
from . import generic as _generic
......@@ -178,11 +178,11 @@ class ExprOp(object):
return _generic.cast(self, dtype)
class EqualOp(NodeGeneric, ExprOp):
class EqualOp(ObjectGeneric, ExprOp):
"""Deferred equal operator.
This is used to support sugar that a == b can either
mean NodeBase.same_as or NodeBase.equal.
mean Object.same_as or Object.equal.
Parameters
----------
......@@ -205,16 +205,16 @@ class EqualOp(NodeGeneric, ExprOp):
def __bool__(self):
return self.__nonzero__()
def asnode(self):
"""Convert node."""
def asobject(self):
"""Convert object."""
return _make._OpEQ(self.a, self.b)
class NotEqualOp(NodeGeneric, ExprOp):
class NotEqualOp(ObjectGeneric, ExprOp):
"""Deferred NE operator.
This is used to support sugar that a != b can either
mean not NodeBase.same_as or make.NE.
mean not Object.same_as or make.NE.
Parameters
----------
......@@ -237,16 +237,16 @@ class NotEqualOp(NodeGeneric, ExprOp):
def __bool__(self):
return self.__nonzero__()
def asnode(self):
"""Convert node."""
def asobject(self):
"""Convert object."""
return _make._OpNE(self.a, self.b)
class PrimExpr(ExprOp, NodeBase):
class PrimExpr(ExprOp, Object):
"""Base class of all tvm Expressions"""
# In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__
# https://docs.python.org/3.1/reference/datamodel.html#object.__hash__
__hash__ = NodeBase.__hash__
__hash__ = Object.__hash__
class ConstExpr(PrimExpr):
......@@ -261,7 +261,7 @@ class CmpExpr(PrimExpr):
class LogicalExpr(PrimExpr):
pass
@register_node("Variable")
@register_object("Variable")
class Var(PrimExpr):
"""Symbolic variable.
......@@ -278,7 +278,7 @@ class Var(PrimExpr):
_api_internal._Var, name, dtype)
@register_node
@register_object
class Reduce(PrimExpr):
"""Reduce node.
......@@ -305,7 +305,7 @@ class Reduce(PrimExpr):
condition, value_index)
@register_node
@register_object
class FloatImm(ConstExpr):
"""Float constant.
......@@ -321,7 +321,7 @@ class FloatImm(ConstExpr):
self.__init_handle_by_constructor__(
_make.FloatImm, dtype, value)
@register_node
@register_object
class IntImm(ConstExpr):
"""Int constant.
......@@ -341,7 +341,7 @@ class IntImm(ConstExpr):
return self.value
@register_node
@register_object
class UIntImm(ConstExpr):
"""UInt constant.
......@@ -358,7 +358,7 @@ class UIntImm(ConstExpr):
_make.UIntImm, dtype, value)
@register_node
@register_object
class StringImm(ConstExpr):
"""String constant.
......@@ -382,7 +382,7 @@ class StringImm(ConstExpr):
return self.value != other
@register_node
@register_object
class Cast(PrimExpr):
"""Cast expression.
......@@ -399,7 +399,7 @@ class Cast(PrimExpr):
_make.Cast, dtype, value)
@register_node
@register_object
class Add(BinaryOpExpr):
"""Add node.
......@@ -416,7 +416,7 @@ class Add(BinaryOpExpr):
_make.Add, a, b)
@register_node
@register_object
class Sub(BinaryOpExpr):
"""Sub node.
......@@ -433,7 +433,7 @@ class Sub(BinaryOpExpr):
_make.Sub, a, b)
@register_node
@register_object
class Mul(BinaryOpExpr):
"""Mul node.
......@@ -450,7 +450,7 @@ class Mul(BinaryOpExpr):
_make.Mul, a, b)
@register_node
@register_object
class Div(BinaryOpExpr):
"""Div node.
......@@ -467,7 +467,7 @@ class Div(BinaryOpExpr):
_make.Div, a, b)
@register_node
@register_object
class Mod(BinaryOpExpr):
"""Mod node.
......@@ -484,7 +484,7 @@ class Mod(BinaryOpExpr):
_make.Mod, a, b)
@register_node
@register_object
class FloorDiv(BinaryOpExpr):
"""FloorDiv node.
......@@ -501,7 +501,7 @@ class FloorDiv(BinaryOpExpr):
_make.FloorDiv, a, b)
@register_node
@register_object
class FloorMod(BinaryOpExpr):
"""FloorMod node.
......@@ -518,7 +518,7 @@ class FloorMod(BinaryOpExpr):
_make.FloorMod, a, b)
@register_node
@register_object
class Min(BinaryOpExpr):
"""Min node.
......@@ -535,7 +535,7 @@ class Min(BinaryOpExpr):
_make.Min, a, b)
@register_node
@register_object
class Max(BinaryOpExpr):
"""Max node.
......@@ -552,7 +552,7 @@ class Max(BinaryOpExpr):
_make.Max, a, b)
@register_node
@register_object
class EQ(CmpExpr):
"""EQ node.
......@@ -569,7 +569,7 @@ class EQ(CmpExpr):
_make.EQ, a, b)
@register_node
@register_object
class NE(CmpExpr):
"""NE node.
......@@ -586,7 +586,7 @@ class NE(CmpExpr):
_make.NE, a, b)
@register_node
@register_object
class LT(CmpExpr):
"""LT node.
......@@ -603,7 +603,7 @@ class LT(CmpExpr):
_make.LT, a, b)
@register_node
@register_object
class LE(CmpExpr):
"""LE node.
......@@ -620,7 +620,7 @@ class LE(CmpExpr):
_make.LE, a, b)
@register_node
@register_object
class GT(CmpExpr):
"""GT node.
......@@ -637,7 +637,7 @@ class GT(CmpExpr):
_make.GT, a, b)
@register_node
@register_object
class GE(CmpExpr):
"""GE node.
......@@ -654,7 +654,7 @@ class GE(CmpExpr):
_make.GE, a, b)
@register_node
@register_object
class And(LogicalExpr):
"""And node.
......@@ -671,7 +671,7 @@ class And(LogicalExpr):
_make.And, a, b)
@register_node
@register_object
class Or(LogicalExpr):
"""Or node.
......@@ -688,7 +688,7 @@ class Or(LogicalExpr):
_make.Or, a, b)
@register_node
@register_object
class Not(LogicalExpr):
"""Not node.
......@@ -702,7 +702,7 @@ class Not(LogicalExpr):
_make.Not, a)
@register_node
@register_object
class Select(PrimExpr):
"""Select node.
......@@ -730,7 +730,7 @@ class Select(PrimExpr):
_make.Select, condition, true_value, false_value)
@register_node
@register_object
class Load(PrimExpr):
"""Load node.
......@@ -753,7 +753,7 @@ class Load(PrimExpr):
_make.Load, dtype, buffer_var, index, predicate)
@register_node
@register_object
class Ramp(PrimExpr):
"""Ramp node.
......@@ -773,7 +773,7 @@ class Ramp(PrimExpr):
_make.Ramp, base, stride, lanes)
@register_node
@register_object
class Broadcast(PrimExpr):
"""Broadcast node.
......@@ -790,7 +790,7 @@ class Broadcast(PrimExpr):
_make.Broadcast, value, lanes)
@register_node
@register_object
class Shuffle(PrimExpr):
"""Shuffle node.
......@@ -807,7 +807,7 @@ class Shuffle(PrimExpr):
_make.Shuffle, vectors, indices)
@register_node
@register_object
class Call(PrimExpr):
"""Call node.
......@@ -842,7 +842,7 @@ class Call(PrimExpr):
_make.Call, dtype, name, args, call_type, func, value_index)
@register_node
@register_object
class Let(PrimExpr):
"""Let node.
......
......@@ -24,7 +24,7 @@ from . import make as _make
from . import ir_pass as _pass
from . import container as _container
from ._ffi.base import string_types
from ._ffi.node import NodeGeneric
from ._ffi.object import ObjectGeneric
from ._ffi.runtime_ctypes import TVMType
from .expr import Call as _Call
......@@ -41,7 +41,7 @@ class WithScope(object):
self._exit_cb()
class BufferVar(NodeGeneric):
class BufferVar(ObjectGeneric):
"""Buffer variable with content type, makes load store easily.
Do not create it directly, create use IRBuilder.
......@@ -70,7 +70,7 @@ class BufferVar(NodeGeneric):
self._buffer_var = buffer_var
self._content_type = content_type
def asnode(self):
def asobject(self):
return self._buffer_var
@property
......
......@@ -20,6 +20,4 @@ Normally user do not need to touch this api.
"""
# pylint: disable=unused-import
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node
Node = NodeBase
from ._ffi.object import Object, register_object
......@@ -16,7 +16,7 @@
# under the License.
from typing import Union, Tuple, Dict, List
from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId
from relay.ir import GlobalId, OperatorId, Item, Object, Span, FileId
from relay.ir import ShapeExtension, Operator, Defn
class Module(NodeBase): ...
class Module(Object): ...
......@@ -16,7 +16,7 @@
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""Algebraic data types in Relay."""
from .base import RelayNode, register_relay_node, NodeBase
from .base import RelayNode, register_relay_node, Object
from . import _make
from .ty import Type
from .expr import Expr, Call
......@@ -184,7 +184,7 @@ class TypeData(Type):
@register_relay_node
class Clause(NodeBase):
class Clause(Object):
"""Clause for pattern matching in Relay."""
def __init__(self, lhs, rhs):
......
......@@ -17,19 +17,19 @@
"""Backend code generation engine."""
from __future__ import absolute_import
from ..base import register_relay_node, NodeBase
from ..base import register_relay_node, Object
from ... import target as _target
from .. import expr as _expr
from . import _backend
@register_relay_node
class CachedFunc(NodeBase):
class CachedFunc(Object):
"""Low-level tensor function to back a relay primitive function.
"""
@register_relay_node
class CCacheKey(NodeBase):
class CCacheKey(Object):
"""Key in the CompileEngine.
Parameters
......@@ -46,7 +46,7 @@ class CCacheKey(NodeBase):
@register_relay_node
class CCacheValue(NodeBase):
class CCacheValue(Object):
"""Value in the CompileEngine, including usage statistics.
"""
......@@ -64,7 +64,7 @@ def _get_cache_key(source_func, target):
@register_relay_node
class CompileEngine(NodeBase):
class CompileEngine(Object):
"""CompileEngine to get lowered code.
"""
def __init__(self):
......
......@@ -23,27 +23,13 @@ import numpy as np
from . import _backend
from .. import _make, analysis, transform
from .. import module
from ... import register_func, nd
from ..base import NodeBase, register_relay_node
from ... import nd
from ..base import Object, register_relay_node
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
from ..scope_builder import ScopeBuilder
from . import _vm
class Value(NodeBase):
"""Base class of all values.
"""
@staticmethod
@register_func("relay.from_scalar")
def from_scalar(value, dtype=None):
"""Convert a Python scalar to a Relay scalar."""
return TensorValue(const(value, dtype).data)
def to_vm(self):
return _vm._ValueToVM(self)
@register_relay_node
class TupleValue(Value):
class TupleValue(Object):
"""A tuple value produced by the interpreter."""
def __init__(self, *fields):
self.__init_handle_by_constructor__(
......@@ -68,60 +54,32 @@ class TupleValue(Value):
@register_relay_node
class Closure(Value):
class Closure(Object):
"""A closure produced by the interpreter."""
@register_relay_node
class RecClosure(Value):
class RecClosure(Object):
"""A recursive closure produced by the interpreter."""
@register_relay_node
class ConstructorValue(Value):
class ConstructorValue(Object):
def __init__(self, tag, fields, constructor):
self.__init_handle_by_constructor__(
_make.ConstructorValue, tag, fields, constructor)
@register_relay_node
class TensorValue(Value):
"""A Tensor value produced by the interpreter."""
def __init__(self, data):
"""Allocate a new TensorValue and copy the data from `array` into
the new array.
"""
if isinstance(data, np.ndarray):
data = nd.array(data)
self.__init_handle_by_constructor__(
_make.TensorValue, data)
def asnumpy(self):
"""Convert a Relay TensorValue into a numpy.ndarray."""
return self.data.asnumpy()
def __eq__(self, other):
return self.data == other.data
def __repr__(self):
return repr(self.data)
def __str__(self):
return str(self.data)
@register_relay_node
class RefValue(Value):
class RefValue(Object):
def __init__(self, value):
self.__init_handle_by_constructor__(
_make.RefValue, value)
def _arg_to_ast(mod, arg):
if isinstance(arg, TensorValue):
return Constant(arg.data.copyto(nd.cpu(0)))
if isinstance(arg, nd.NDArray):
return Constant(arg.copyto(nd.cpu(0)))
elif isinstance(arg, TupleValue):
return Tuple([_arg_to_ast(mod, field) for field in arg.fields])
elif isinstance(arg, tuple):
......@@ -231,7 +189,7 @@ class Executor(object):
Returns
-------
val : Union[function, Value]
val : Union[function, Object]
The evaluation result.
"""
if binds:
......
......@@ -31,16 +31,18 @@ from . import _vm
from . import vmobj as _obj
from .interpreter import Executor
Tensor = _obj.Tensor
ADT = _obj.ADT
def _convert(arg, cargs):
if isinstance(arg, _expr.Constant):
cargs.append(_obj.Tensor(arg.data))
cargs.append(arg.data)
elif isinstance(arg, _obj.Object):
cargs.append(arg)
elif isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
cargs.append(_obj.Tensor(arg))
elif isinstance(arg, np.ndarray):
nd_arr = tvm.nd.array(arg, ctx=tvm.cpu(0))
cargs.append(nd_arr)
elif isinstance(arg, tvm.nd.NDArray):
cargs.append(arg)
elif isinstance(arg, (tuple, list)):
field_args = []
for field in arg:
......@@ -48,7 +50,7 @@ def _convert(arg, cargs):
cargs.append(_obj.tuple_object(field_args))
elif isinstance(arg, (_base.numeric_types, bool)):
dtype = "int32" if isinstance(arg, (int, bool)) else "float32"
value = _obj.Tensor(np.array(arg, dtype=dtype))
value = tvm.nd.array(np.array(arg, dtype=dtype), ctx=tvm.cpu(0))
cargs.append(value)
else:
raise TypeError("Unsupported type: %s" % (type(arg)))
......
......@@ -16,51 +16,12 @@
# under the License.
"""TVM Runtime Object API."""
from __future__ import absolute_import as _abs
import numpy as _np
from tvm._ffi.object import Object, register_object, getitem_helper
from tvm import ndarray as _nd
from . import _vmobj
@register_object("vm.Tensor")
class Tensor(Object):
"""Tensor object.
Parameters
----------
arr : numpy.ndarray or tvm.nd.NDArray
The source array.
ctx : TVMContext, optional
The device context to create the array
"""
def __init__(self, arr, ctx=None):
if isinstance(arr, _np.ndarray):
ctx = ctx if ctx else _nd.cpu(0)
self.__init_handle_by_constructor__(
_vmobj.Tensor, _nd.array(arr, ctx=ctx))
elif isinstance(arr, _nd.NDArray):
self.__init_handle_by_constructor__(
_vmobj.Tensor, arr)
else:
raise RuntimeError("Unsupported type for tensor object.")
@property
def data(self):
return _vmobj.GetTensorData(self)
def asnumpy(self):
"""Convert data to numpy array
Returns
-------
np_arr : numpy.ndarray
The corresponding numpy array.
"""
return self.data.asnumpy()
@register_object("vm.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.
......@@ -75,7 +36,8 @@ class ADT(Object):
"""
def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, Object)
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or "
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(
_vmobj.ADT, tag, *fields)
......@@ -105,5 +67,6 @@ def tuple_object(fields):
The created object.
"""
for f in fields:
assert isinstance(f, Object)
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm "
"NDArray type, but received : {0}".format(type(f))
return _vmobj.Tuple(*fields)
......@@ -17,12 +17,13 @@
# pylint: disable=no-else-return, unidiomatic-typecheck
"""The base node types for the Relay language."""
from __future__ import absolute_import as _abs
from .._ffi.node import NodeBase, register_node as _register_tvm_node
from .._ffi.object import register_object as _register_tvm_node
from .._ffi.object import Object
from . import _make
from . import _expr
from . import _base
NodeBase = NodeBase
Object = Object
def register_relay_node(type_key=None):
"""Register a Relay node type.
......@@ -52,7 +53,7 @@ def register_relay_attr_node(type_key=None):
return _register_tvm_node(type_key)
class RelayNode(NodeBase):
class RelayNode(Object):
"""Base class of all Relay nodes."""
def astext(self, show_meta_data=True, annotate=None):
"""Get the text format of the expression.
......@@ -102,7 +103,7 @@ class SourceName(RelayNode):
self.__init_handle_by_constructor__(_make.SourceName, name)
@register_relay_node
class Id(NodeBase):
class Id(Object):
"""Unique identifier(name) used in Var.
Guaranteed to be stable across all passes.
"""
......
......@@ -17,12 +17,12 @@
from typing import List
import tvm
from .base import Span, NodeBase
from .base import Span, Object
from .ty import Type, TypeParam
from ._analysis import _get_checked_type
class Expr(NodeBase):
class Expr(Object):
def checked_type(self):
...
......
......@@ -22,7 +22,7 @@ from ._calibrate import calibrate
from .. import expr as _expr
from .. import transform as _transform
from ... import make as _make
from ..base import NodeBase, register_relay_node
from ..base import Object, register_relay_node
class QAnnotateKind(object):
......@@ -53,7 +53,7 @@ def _forward_op(ref_call, args):
@register_relay_node("relay.quantize.QConfig")
class QConfig(NodeBase):
class QConfig(Object):
"""Configure the quantization behavior by setting config variables.
Note
......
......@@ -32,15 +32,16 @@ OUTPUT_VAR_NAME = '_py_out'
# import numpy
# import tvm
# from tvm import relay
# from tvm.relay.backend.interpreter import RefValue, TupleValue, TensorValue, ConstructorValue
# from tvm import nd
# from tvm.relay.backend.interpreter import RefValue, TupleValue, ConstructorValue
PROLOGUE = [
ast.Import([alias('numpy', None)]),
ast.Import([alias('tvm', None)]),
ast.ImportFrom('tvm', [alias('relay', None)], 0),
ast.ImportFrom('tvm', [alias('nd', None)], 0),
ast.ImportFrom('tvm.relay.backend.interpreter',
[alias('RefValue', None),
alias('TupleValue', None),
alias('TensorValue', None),
alias('ConstructorValue', None)],
0)
]
......@@ -245,7 +246,7 @@ class PythonConverter(ExprFunctor):
a tensor or tuple (returns list of inputs to the lowered op call)"""
# equivalent: input.data
if isinstance(arg_type, relay.TensorType):
return [ast.Attribute(py_input, 'data', Load())]
return [py_input]
assert isinstance(arg_type, relay.TupleType)
# convert each input.fields[i]
ret = []
......@@ -265,15 +266,13 @@ class PythonConverter(ExprFunctor):
output_var_name = self.generate_var_name('_out')
output_var = Name(output_var_name, Load())
shape = ast.Tuple([Num(dim) for dim in ret_type.concrete_shape], Load())
# create a new TensorValue of the right shape and dtype
# create a new NDArray of the right shape and dtype
assign_output = Assign(
[Name(output_var_name, Store())],
self.create_call('TensorValue', [
self.create_call('nd.array', [
self.create_call('numpy.empty', [shape, Str(ret_type.dtype)])
]))
# we pass the data field as an argument
extra_arg = ast.Attribute(output_var, 'data', Load())
return ([assign_output], [extra_arg], output_var)
return ([assign_output], [output_var], output_var)
assert isinstance(ret_type, relay.TupleType)
assignments = []
extra_args = []
......@@ -459,7 +458,7 @@ class PythonConverter(ExprFunctor):
true_body, true_defs = self.visit(if_block.true_branch)
false_body, false_defs = self.visit(if_block.false_branch)
# need to get the value out of a TensorValue to check the condition
# need to get the value out of a NDArray to check the condition
# equvialent to: val.asnumpy()
cond_check = ast.Call(ast.Attribute(cond_body, 'asnumpy', Load()), [], [])
ret = ast.IfExp(cond_check, true_body, false_body)
......@@ -474,7 +473,7 @@ class PythonConverter(ExprFunctor):
const_expr = ast.Call(ast.Attribute(Name('numpy', Load()), 'array', Load()),
[self.parse_numpy_array(value)],
[ast.keyword('dtype', Str(constant.checked_type.dtype))])
return (self.create_call('TensorValue', [const_expr]), [])
return (self.create_call('nd.array', [const_expr]), [])
def visit_function(self, func: Expr):
......
......@@ -16,14 +16,14 @@
# under the License.
import tvm
from .base import NodeBase
from .base import Object
class PassContext(NodeBase):
class PassContext(Object):
def __init__(self):
...
class PassInfo(NodeBase):
class PassInfo(Object):
name = ... # type: str
opt_level = ... # type: int
required = ... # type: list
......@@ -32,7 +32,7 @@ class PassInfo(NodeBase):
# type: (str, int, list) -> None
class Pass(NodeBase):
class Pass(Object):
def __init__(self):
...
......
......@@ -18,11 +18,11 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The type nodes of the Relay language."""
from enum import IntEnum
from .base import NodeBase, register_relay_node
from .base import Object, register_relay_node
from . import _make
class Type(NodeBase):
class Type(Object):
"""The base type for all Relay types."""
def __eq__(self, other):
......
......@@ -17,8 +17,8 @@
"""The computation schedule api of TVM."""
from __future__ import absolute_import as _abs
from ._ffi.base import string_types
from ._ffi.node import NodeBase, register_node
from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.object import Object, register_object
from ._ffi.object import convert_to_object as _convert_to_object
from ._ffi.function import _init_api, Function
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
from . import _api_internal
......@@ -27,7 +27,7 @@ from . import expr as _expr
from . import container as _container
def convert(value):
"""Convert value to TVM node or function.
"""Convert value to TVM object or function.
Parameters
----------
......@@ -35,19 +35,19 @@ def convert(value):
Returns
-------
tvm_val : Node or Function
tvm_val : Object or Function
Converted value in TVM
"""
if isinstance(value, (Function, NodeBase)):
if isinstance(value, (Function, Object)):
return value
if callable(value):
return _convert_tvm_func(value)
return _convert_to_node(value)
return _convert_to_object(value)
@register_node
class Buffer(NodeBase):
@register_object
class Buffer(Object):
"""Symbolic data buffer in TVM.
Buffer provide a way to represent data layout
......@@ -156,23 +156,23 @@ class Buffer(NodeBase):
return _api_internal._BufferVStore(self, begin, value)
@register_node
class Split(NodeBase):
@register_object
class Split(Object):
"""Split operation on axis."""
@register_node
class Fuse(NodeBase):
@register_object
class Fuse(Object):
"""Fuse operation on axis."""
@register_node
class Singleton(NodeBase):
@register_object
class Singleton(Object):
"""Singleton axis."""
@register_node
class IterVar(NodeBase, _expr.ExprOp):
@register_object
class IterVar(Object, _expr.ExprOp):
"""Represent iteration variable.
IterVar is normally created by Operation, to represent
......@@ -214,8 +214,8 @@ def create_schedule(ops):
return _api_internal._CreateSchedule(ops)
@register_node
class Schedule(NodeBase):
@register_object
class Schedule(Object):
"""Schedule for all the stages."""
def __getitem__(self, k):
if isinstance(k, _tensor.Tensor):
......@@ -348,8 +348,8 @@ class Schedule(NodeBase):
return factored[0] if len(factored) == 1 else factored
@register_node
class Stage(NodeBase):
@register_object
class Stage(Object):
"""A Stage represents schedule for one operation."""
def split(self, parent, factor=None, nparts=None):
"""Split the stage either by factor providing outer scope, or both
......
......@@ -30,14 +30,14 @@ Each statement node have subfields that can be visited from python side.
assert(st.buffer_var == a)
"""
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node
from ._ffi.object import Object, register_object
from . import make as _make
class Stmt(NodeBase):
class Stmt(Object):
pass
@register_node
@register_object
class LetStmt(Stmt):
"""LetStmt node.
......@@ -57,7 +57,7 @@ class LetStmt(Stmt):
_make.LetStmt, var, value, body)
@register_node
@register_object
class AssertStmt(Stmt):
"""AssertStmt node.
......@@ -77,7 +77,7 @@ class AssertStmt(Stmt):
_make.AssertStmt, condition, message, body)
@register_node
@register_object
class ProducerConsumer(Stmt):
"""ProducerConsumer node.
......@@ -97,7 +97,7 @@ class ProducerConsumer(Stmt):
_make.ProducerConsumer, func, is_producer, body)
@register_node
@register_object
class For(Stmt):
"""For node.
......@@ -137,7 +137,7 @@ class For(Stmt):
for_type, device_api, body)
@register_node
@register_object
class Store(Stmt):
"""Store node.
......@@ -160,7 +160,7 @@ class Store(Stmt):
_make.Store, buffer_var, value, index, predicate)
@register_node
@register_object
class Provide(Stmt):
"""Provide node.
......@@ -183,7 +183,7 @@ class Provide(Stmt):
_make.Provide, func, value_index, value, args)
@register_node
@register_object
class Allocate(Stmt):
"""Allocate node.
......@@ -215,7 +215,7 @@ class Allocate(Stmt):
extents, condition, body)
@register_node
@register_object
class AttrStmt(Stmt):
"""AttrStmt node.
......@@ -238,7 +238,7 @@ class AttrStmt(Stmt):
_make.AttrStmt, node, attr_key, value, body)
@register_node
@register_object
class Free(Stmt):
"""Free node.
......@@ -252,7 +252,7 @@ class Free(Stmt):
_make.Free, buffer_var)
@register_node
@register_object
class Realize(Stmt):
"""Realize node.
......@@ -288,7 +288,7 @@ class Realize(Stmt):
bounds, condition, body)
@register_node
@register_object
class SeqStmt(Stmt):
"""Sequence of statements.
......@@ -308,7 +308,7 @@ class SeqStmt(Stmt):
return len(self.seq)
@register_node
@register_object
class IfThenElse(Stmt):
"""IfThenElse node.
......@@ -328,7 +328,7 @@ class IfThenElse(Stmt):
_make.IfThenElse, condition, then_case, else_case)
@register_node
@register_object
class Evaluate(Stmt):
"""Evaluate node.
......@@ -342,7 +342,7 @@ class Evaluate(Stmt):
_make.Evaluate, value)
@register_node
@register_object
class Prefetch(Stmt):
"""Prefetch node.
......
......@@ -59,7 +59,7 @@ from __future__ import absolute_import
import warnings
from ._ffi.base import _LIB_NAME
from ._ffi.node import NodeBase, register_node
from ._ffi.object import Object, register_object
from . import _api_internal
try:
......@@ -80,8 +80,8 @@ def _merge_opts(opts, new_opts):
return opts
@register_node
class Target(NodeBase):
@register_object
class Target(Object):
"""Target device information, use through TVM API.
Note
......@@ -97,7 +97,7 @@ class Target(NodeBase):
"""
def __new__(cls):
# Always override new to enable class
obj = NodeBase.__new__(cls)
obj = Object.__new__(cls)
obj._keys = None
obj._options = None
obj._libs = None
......@@ -146,8 +146,8 @@ class Target(NodeBase):
_api_internal._ExitTargetScope(self)
@register_node
class GenericFunc(NodeBase):
@register_object
class GenericFunc(Object):
"""GenericFunc node reference. This represents a generic function
that may be specialized for different targets. When this object is
called, a specialization is chosen based on the current target.
......
......@@ -17,13 +17,14 @@
"""Tensor and Operation class for computation declaration."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, NodeGeneric, register_node, convert_to_node
from ._ffi.object import Object, register_object, ObjectGeneric, \
convert_to_object
from . import _api_internal
from . import make as _make
from . import expr as _expr
class TensorSlice(NodeGeneric, _expr.ExprOp):
class TensorSlice(ObjectGeneric, _expr.ExprOp):
"""Auxiliary data structure for enable slicing syntax from tensor."""
def __init__(self, tensor, indices):
......@@ -37,8 +38,8 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
indices = (indices,)
return TensorSlice(self.tensor, self.indices + indices)
def asnode(self):
"""Convert slice to node."""
def asobject(self):
"""Convert slice to object."""
return self.tensor(*self.indices)
@property
......@@ -46,23 +47,23 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
"""Data content of the tensor."""
return self.tensor.dtype
@register_node
class TensorIntrinCall(NodeBase):
@register_object
class TensorIntrinCall(Object):
"""Intermediate structure for calling a tensor intrinsic."""
itervar_cls = None
@register_node
class Tensor(NodeBase, _expr.ExprOp):
@register_object
class Tensor(Object, _expr.ExprOp):
"""Tensor object, to construct, see function.Tensor"""
def __call__(self, *indices):
ndim = self.ndim
if len(indices) != ndim:
raise ValueError("Need to provide %d index in tensor slice" % ndim)
indices = convert_to_node(indices)
indices = convert_to_object(indices)
args = []
for x in indices:
if isinstance(x, _expr.PrimExpr):
......@@ -127,7 +128,7 @@ class Tensor(NodeBase, _expr.ExprOp):
class Operation(NodeBase):
class Operation(Object):
"""Represent an operation that generates a tensor"""
def output(self, index):
......@@ -156,12 +157,12 @@ class Operation(NodeBase):
return _api_internal._OpInputTensors(self)
@register_node
@register_object
class PlaceholderOp(Operation):
"""Placeholder operation."""
@register_node
@register_object
class BaseComputeOp(Operation):
"""Compute operation."""
@property
......@@ -175,18 +176,18 @@ class BaseComputeOp(Operation):
return self.__getattr__("reduce_axis")
@register_node
@register_object
class ComputeOp(BaseComputeOp):
"""Scalar operation."""
pass
@register_node
@register_object
class TensorComputeOp(BaseComputeOp):
"""Tensor operation."""
@register_node
@register_object
class ScanOp(Operation):
"""Scan operation."""
@property
......@@ -195,12 +196,12 @@ class ScanOp(Operation):
return self.__getattr__("scan_axis")
@register_node
@register_object
class ExternOp(Operation):
"""External operation."""
@register_node
@register_object
class HybridOp(Operation):
"""Hybrid operation."""
@property
......@@ -209,8 +210,8 @@ class HybridOp(Operation):
return self.__getattr__("axis")
@register_node
class Layout(NodeBase):
@register_object
class Layout(Object):
"""Layout is composed of upper cases, lower cases and numbers,
where upper case indicates a primal axis and
the corresponding lower case with factor size indicates the subordinate axis.
......@@ -269,8 +270,8 @@ class Layout(NodeBase):
return _api_internal._LayoutFactorOf(self, axis)
@register_node
class BijectiveLayout(NodeBase):
@register_object
class BijectiveLayout(Object):
"""Bijective mapping for two layouts (src-layout and dst-layout).
It provides shape and index conversion between each other.
......
......@@ -24,7 +24,7 @@ from . import make as _make
from . import tensor as _tensor
from . import schedule as _schedule
from .build_module import current_build_config
from ._ffi.node import NodeBase, register_node
from ._ffi.object import Object, register_object
def _get_region(tslice):
......@@ -41,8 +41,8 @@ def _get_region(tslice):
region.append(_make.range_by_min_extent(begin, 1))
return region
@register_node
class TensorIntrin(NodeBase):
@register_object
class TensorIntrin(Object):
"""Tensor intrinsic functions for certain computation.
See Also
......
......@@ -43,8 +43,8 @@ inline const PackedFunc& GetPackedFunc(const std::string& name) {
return *pf;
}
/* Value Implementation */
Closure ClosureNode::make(tvm::Map<Var, Value> env, Function func) {
/* Object Implementation */
Closure ClosureNode::make(tvm::Map<Var, ObjectRef> env, Function func) {
ObjectPtr<ClosureNode> n = make_object<ClosureNode>();
n->env = std::move(env);
n->func = std::move(func);
......@@ -62,7 +62,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
// TODO(@jroesch): this doesn't support mutual letrec
/* Value Implementation */
/* Object Implementation */
RecClosure RecClosureNode::make(Closure clos, Var bind) {
ObjectPtr<RecClosureNode> n = make_object<RecClosureNode>();
n->clos = std::move(clos);
......@@ -79,7 +79,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << "RecClosureNode(" << node->clos << ")";
});
TupleValue TupleValueNode::make(tvm::Array<Value> value) {
TupleValue TupleValueNode::make(tvm::Array<ObjectRef> value) {
ObjectPtr<TupleValueNode> n = make_object<TupleValueNode>();
n->fields = value;
return TupleValue(n);
......@@ -94,24 +94,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << "TupleValueNode(" << node->fields << ")";
});
TensorValue TensorValueNode::make(runtime::NDArray data) {
ObjectPtr<TensorValueNode> n = make_object<TensorValueNode>();
n->data = std::move(data);
return TensorValue(n);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TensorValueNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TensorValueNode*>(ref.get());
auto to_str = GetPackedFunc("relay._tensor_value_repr");
std::string data_str = to_str(GetRef<TensorValue>(node));
p->stream << "TensorValueNode(" << data_str << ")";
});
TVM_REGISTER_GLOBAL("relay._make.TensorValue")
.set_body_typed(TensorValueNode::make);
RefValue RefValueNode::make(Value value) {
RefValue RefValueNode::make(ObjectRef value) {
ObjectPtr<RefValueNode> n = make_object<RefValueNode>();
n->value = value;
return RefValue(n);
......@@ -129,7 +113,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
});
ConstructorValue ConstructorValueNode::make(int32_t tag,
tvm::Array<Value> fields,
tvm::Array<ObjectRef> fields,
Constructor constructor) {
ObjectPtr<ConstructorValueNode> n = make_object<ConstructorValueNode>();
n->tag = tag;
......@@ -153,13 +137,13 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
/*!
* \brief A stack frame in the Relay interpreter.
*
* Contains a mapping from relay::Var to relay::Value.
* Contains a mapping from relay::Var to relay::ObjectRef.
*/
struct Frame {
/*! \brief The set of local variables and arguments for the frame. */
tvm::Map<Var, Value> locals;
tvm::Map<Var, ObjectRef> locals;
explicit Frame(tvm::Map<Var, Value> locals) : locals(locals) {}
explicit Frame(tvm::Map<Var, ObjectRef> locals) : locals(locals) {}
};
/*!
......@@ -175,7 +159,7 @@ struct Stack {
Frame& current_frame() { return frames.back(); }
Value Lookup(const Var& local) {
ObjectRef Lookup(const Var& local) {
for (auto frame = frames.rbegin(); frame != frames.rend(); frame++) {
auto elem = frame->locals.find(local);
if (elem != frame->locals.end()) {
......@@ -185,7 +169,7 @@ struct Stack {
LOG(FATAL) << "could not find variable binding for " << local
<< "address= " << local.operator->();
return Value();
return ObjectRef();
}
/*!
* A wrapper around Frame to add RAII semantics to pushing and popping
......@@ -206,7 +190,7 @@ class InterpreterState;
/*! \brief A container capturing the state of the interpreter. */
class InterpreterStateNode : public Object {
public:
using Frame = tvm::Map<Var, Value>;
using Frame = tvm::Map<Var, ObjectRef>;
using Stack = tvm::Array<Frame>;
/*! \brief The current expression under evaluation. */
......@@ -246,8 +230,8 @@ InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) {
//
// Conversion to ANF is recommended before running the interpretation.
class Interpreter :
public ExprFunctor<Value(const Expr& n)>,
PatternFunctor<bool(const Pattern& p, const Value& v)> {
public ExprFunctor<ObjectRef(const Expr& n)>,
PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
public:
Interpreter(Module mod, DLContext context, Target target)
: mod_(mod),
......@@ -264,56 +248,56 @@ class Interpreter :
return f();
}
void extend(const Var& id, Value v) {
void extend(const Var& id, ObjectRef v) {
stack_.current_frame().locals.Set(id, v);
}
Value Lookup(const Var& local) {
ObjectRef Lookup(const Var& local) {
return stack_.Lookup(local);
}
Value Eval(const Expr& expr) {
ObjectRef Eval(const Expr& expr) {
return VisitExpr(expr);
}
Value VisitExpr(const Expr& expr) final {
auto ret = ExprFunctor<Value(const Expr& n)>::VisitExpr(expr);
ObjectRef VisitExpr(const Expr& expr) final {
auto ret = ExprFunctor<ObjectRef(const Expr& n)>::VisitExpr(expr);
return ret;
}
Value VisitExpr_(const VarNode* var_node) final {
ObjectRef VisitExpr_(const VarNode* var_node) final {
return Lookup(GetRef<Var>(var_node));
}
Value VisitExpr_(const GlobalVarNode* op) final {
ObjectRef VisitExpr_(const GlobalVarNode* op) final {
return Eval(mod_->Lookup(GetRef<GlobalVar>(op)));
}
Value VisitExpr_(const OpNode* id) override {
ObjectRef VisitExpr_(const OpNode* id) override {
// TODO(@jroesch): Eta-expand and return in this case.
LOG(FATAL) << "internal error, need to wrap intrinsic into call synthetic call node "
<< "in "
<< "this case, eta expand";
return Value();
return ObjectRef();
}
Value VisitExpr_(const ConstantNode* op) final {
return TensorValueNode::make(op->data.CopyTo(context_));
ObjectRef VisitExpr_(const ConstantNode* op) final {
return op->data.CopyTo(context_);
}
Value VisitExpr_(const TupleNode* op) final {
std::vector<Value> values;
ObjectRef VisitExpr_(const TupleNode* op) final {
std::vector<ObjectRef> values;
for (const auto& field : op->fields) {
Value field_value = Eval(field);
ObjectRef field_value = Eval(field);
values.push_back(field_value);
}
return TupleValueNode::make(values);
}
Value MakeClosure(const Function& func, Var letrec_name = Var()) {
tvm::Map<Var, Value> captured_mod;
ObjectRef MakeClosure(const Function& func, Var letrec_name = Var()) {
tvm::Map<Var, ObjectRef> captured_mod;
Array<Var> free_vars = FreeVars(func);
for (const auto& var : free_vars) {
......@@ -334,13 +318,13 @@ class Interpreter :
return std::move(closure);
}
Value VisitExpr_(const FunctionNode* func_node) final {
ObjectRef VisitExpr_(const FunctionNode* func_node) final {
auto func = GetRef<Function>(func_node);
return MakeClosure(func);
}
Array<Shape> ComputeDynamicShape(const Function& func,
const Array<Value>& args) {
const Array<ObjectRef>& args) {
auto key = CCacheKeyNode::make(func, Target::Create("llvm"));
auto cfunc = engine_->LowerShapeFunc(key);
size_t arity = cfunc->inputs.size() + cfunc->outputs.size();
......@@ -355,11 +339,10 @@ class Interpreter :
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
auto fset_input = [&](size_t i, Value val, bool need_shape) {
const TensorValueNode* tv = val.as<TensorValueNode>();
CHECK(tv != nullptr) << "expect Tensor argument";
auto fset_input = [&](size_t i, ObjectRef val, bool need_shape) {
auto nd_array = Downcast<NDArray>(val);
if (need_shape) {
int64_t ndim = tv->data.Shape().size();
int64_t ndim = nd_array.Shape().size();
NDArray shape_arr;
if (ndim == 0) {
shape_arr = NDArray::Empty({}, DataType::Int(64), cpu_ctx);
......@@ -367,13 +350,13 @@ class Interpreter :
shape_arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_ctx);
int64_t* data = reinterpret_cast<int64_t*>(shape_arr->data);
for (auto j = 0; j < ndim; ++j) {
data[j] = tv->data.Shape()[j];
data[j] = nd_array.Shape()[j];
}
}
inputs[i] = shape_arr;
setter(i, shape_arr);
} else {
auto arr = tv->data.CopyTo(cpu_ctx);
auto arr = nd_array.CopyTo(cpu_ctx);
inputs[i] = arr;
setter(i, arr);
}
......@@ -384,7 +367,7 @@ class Interpreter :
auto arg = args[i];
auto param = func->params[i];
int state = cfunc->shape_func_param_states[i]->value;
if (arg.as<TensorValueNode>()) {
if (arg->IsInstance<runtime::NDArray::ContainerType>()) {
if (state & kNeedInputData) {
fset_input(arg_counter++, arg, false);
}
......@@ -457,8 +440,8 @@ class Interpreter :
return out_shapes;
}
Value InvokePrimitiveOp(const Function& func,
const Array<Value>& args) {
ObjectRef InvokePrimitiveOp(const Function& func,
const Array<ObjectRef>& args) {
const auto* call_node = func->body.as<CallNode>();
if (call_node && call_node->op == debug_op_) {
......@@ -478,7 +461,7 @@ class Interpreter :
// Handle tuple input/output by flattening them.
size_t arg_len = 0;
for (size_t i = 0; i < args.size(); ++i) {
if (args[i].as<TensorValueNode>()) {
if (args[i]->IsInstance<NDArray::ContainerType>()) {
++arg_len;
} else {
const auto* tvalue = args[i].as<TupleValueNode>();
......@@ -497,11 +480,10 @@ class Interpreter :
std::vector<int> codes(arg_len);
TVMArgsSetter setter(values.data(), codes.data());
auto fset_input = [&](size_t i, Value val) {
const TensorValueNode* tv = val.as<TensorValueNode>();
CHECK(tv != nullptr) << "expect Tensor argument";
setter(i, tv->data);
DLContext arg_ctx = tv->data->ctx;
auto fset_input = [&](size_t i, ObjectRef val) {
const auto nd_array = Downcast<NDArray>(val);
setter(i, nd_array);
DLContext arg_ctx = nd_array->ctx;
CHECK(arg_ctx.device_type == context_.device_type &&
arg_ctx.device_id == context_.device_id)
<< "Interpreter expect context to be "
......@@ -509,8 +491,8 @@ class Interpreter :
};
int arg_counter = 0;
for (Value arg : args) {
if (arg.as<TensorValueNode>()) {
for (ObjectRef arg : args) {
if (arg->IsInstance<NDArray::ContainerType>()) {
fset_input(arg_counter++, arg);
} else {
const TupleValueNode* tuple = arg.as<TupleValueNode>();
......@@ -536,10 +518,9 @@ class Interpreter :
shape.push_back(ivalue[0]);
}
DLDataType dtype = rtype->dtype;
auto out_tensor = TensorValueNode::make(
NDArray::Empty(shape, dtype, context_));
setter(num_inputs + i, out_tensor->data);
return out_tensor;
NDArray nd_array = NDArray::Empty(shape, dtype, context_);
setter(num_inputs + i, nd_array);
return nd_array;
};
Array<Shape> out_shapes;
......@@ -560,7 +541,7 @@ class Interpreter :
TVMRetValue rv;
if (const TupleTypeNode* rtype = func->body->checked_type().as<TupleTypeNode>()) {
CHECK(!is_dyn || out_shapes.size() == rtype->fields.size());
Array<Value> fields;
Array<ObjectRef> fields;
for (size_t i = 0; i < rtype->fields.size(); ++i) {
if (is_dyn) {
auto sh = out_shapes[i];
......@@ -573,7 +554,7 @@ class Interpreter :
packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv);
return TupleValueNode::make(fields);
} else {
Value out_tensor;
ObjectRef out_tensor;
if (is_dyn) {
CHECK_EQ(out_shapes.size(), 1);
auto sh = out_shapes[0];
......@@ -588,14 +569,16 @@ class Interpreter :
}
// Invoke the closure
Value Invoke(const Closure& closure, const tvm::Array<Value>& args, const Var& bind = Var()) {
ObjectRef Invoke(const Closure& closure,
const tvm::Array<ObjectRef>& args,
const Var& bind = Var()) {
// Get a reference to the function inside the closure.
if (closure->func->IsPrimitive()) {
return InvokePrimitiveOp(closure->func, args);
}
auto func = closure->func;
// Allocate a frame with the parameters and free variables.
tvm::Map<Var, Value> locals;
tvm::Map<Var, ObjectRef> locals;
CHECK_EQ(func->params.size(), args.size());
......@@ -614,11 +597,11 @@ class Interpreter :
locals.Set(bind, RecClosureNode::make(closure, bind));
}
return WithFrame<Value>(Frame(locals), [&]() { return Eval(func->body); });
return WithFrame<ObjectRef>(Frame(locals), [&]() { return Eval(func->body); });
}
Value VisitExpr_(const CallNode* call) final {
tvm::Array<Value> args;
ObjectRef VisitExpr_(const CallNode* call) final {
tvm::Array<ObjectRef> args;
for (auto arg : call->args) {
args.push_back(Eval(arg));
}
......@@ -636,7 +619,7 @@ class Interpreter :
return ConstructorValueNode::make(con->tag, args, GetRef<Constructor>(con));
}
// Now we just evaluate and expect to find a closure.
Value fn_val = Eval(call->op);
ObjectRef fn_val = Eval(call->op);
if (const ClosureNode* closure_node = fn_val.as<ClosureNode>()) {
auto closure = GetRef<Closure>(closure_node);
return this->Invoke(closure, args);
......@@ -645,11 +628,11 @@ class Interpreter :
} else {
LOG(FATAL) << "internal error: type error, expected function value in the call "
<< "position";
return Value();
return ObjectRef();
}
}
Value VisitExpr_(const LetNode* let) final {
ObjectRef VisitExpr_(const LetNode* let) final {
if (auto func = let->value.as<FunctionNode>()) {
auto clo = MakeClosure(GetRef<Function>(func), let->var);
this->extend(let->var, clo);
......@@ -661,8 +644,8 @@ class Interpreter :
return Eval(let->body);
}
Value VisitExpr_(const TupleGetItemNode* op) final {
Value val = Eval(op->tuple);
ObjectRef VisitExpr_(const TupleGetItemNode* op) final {
ObjectRef val = Eval(op->tuple);
auto product_node = val.as<TupleValueNode>();
CHECK(product_node)
<< "interal error: when evaluating TupleGetItem expected a tuple value";
......@@ -671,13 +654,14 @@ class Interpreter :
return product_node->fields[op->index];
}
Value VisitExpr_(const IfNode* op) final {
Value v = Eval(op->cond);
if (const TensorValueNode* bv = v.as<TensorValueNode>()) {
ObjectRef VisitExpr_(const IfNode* op) final {
ObjectRef v = Eval(op->cond);
if (v->IsInstance<NDArray::ContainerType>()) {
auto nd_array = Downcast<NDArray>(v);
DLContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
NDArray cpu_array = bv->data.CopyTo(cpu_ctx);
NDArray cpu_array = nd_array.CopyTo(cpu_ctx);
CHECK_EQ(DataType(cpu_array->dtype), DataType::Bool());
// TODO(@jroesch, @MK): Refactor code into helper from DCE.
if (reinterpret_cast<uint8_t*>(cpu_array->data)[0]) {
......@@ -687,47 +671,47 @@ class Interpreter :
}
} else {
LOG(FATAL) << "type error, type system should have caught this";
return Value();
return ObjectRef();
}
}
Value VisitExpr_(const RefWriteNode* op) final {
Value r = Eval(op->ref);
ObjectRef VisitExpr_(const RefWriteNode* op) final {
ObjectRef r = Eval(op->ref);
if (const RefValueNode* rv = r.as<RefValueNode>()) {
rv->value = Eval(op->value);
return TupleValueNode::make({});
} else {
LOG(FATAL) << "type error, type system should have caught this";
return Value();
return ObjectRef();
}
}
Value VisitExpr_(const RefCreateNode* op) final {
ObjectRef VisitExpr_(const RefCreateNode* op) final {
return RefValueNode::make(Eval(op->value));
}
Value VisitExpr_(const RefReadNode* op) final {
Value r = Eval(op->ref);
ObjectRef VisitExpr_(const RefReadNode* op) final {
ObjectRef r = Eval(op->ref);
if (const RefValueNode* rv = r.as<RefValueNode>()) {
return rv->value;
} else {
LOG(FATAL) << "type error, type system should have caught this";
return Value();
return ObjectRef();
}
}
Value VisitExpr_(const MatchNode* op) final {
Value v = Eval(op->data);
ObjectRef VisitExpr_(const MatchNode* op) final {
ObjectRef v = Eval(op->data);
for (const Clause& c : op->clauses) {
if (VisitPattern(c->lhs, v)) {
return VisitExpr(c->rhs);
}
}
LOG(FATAL) << "did not find any match";
return Value();
return ObjectRef();
}
bool VisitPattern_(const PatternConstructorNode* op, const Value& v) final {
bool VisitPattern_(const PatternConstructorNode* op, const ObjectRef& v) final {
const ConstructorValueNode* cvn = v.as<ConstructorValueNode>();
CHECK(cvn) << "need to be a constructor for match";
CHECK_NE(op->constructor->tag, -1);
......@@ -744,7 +728,7 @@ class Interpreter :
return false;
}
bool VisitPattern_(const PatternTupleNode* op, const Value& v) final {
bool VisitPattern_(const PatternTupleNode* op, const ObjectRef& v) final {
const TupleValueNode* tvn = v.as<TupleValueNode>();
CHECK(tvn) << "need to be a tuple for match";
CHECK_EQ(op->patterns.size(), tvn->fields.size());
......@@ -756,11 +740,11 @@ class Interpreter :
return true;
}
bool VisitPattern_(const PatternWildcardNode* op, const Value& v) final {
bool VisitPattern_(const PatternWildcardNode* op, const ObjectRef& v) final {
return true;
}
bool VisitPattern_(const PatternVarNode* op, const Value& v) final {
bool VisitPattern_(const PatternVarNode* op, const ObjectRef& v) final {
extend(op->var, v);
return true;
}
......@@ -783,7 +767,7 @@ class Interpreter :
DLContext context_;
// Target parameter being used by the interpreter.
Target target_;
// Value stack.
// Object stack.
Stack stack_;
// Backend compile engine.
CompileEngine engine_;
......@@ -793,7 +777,7 @@ class Interpreter :
};
TypedPackedFunc<Value(Expr)>
TypedPackedFunc<ObjectRef(Expr)>
CreateInterpreter(
Module mod,
DLContext context,
......@@ -814,7 +798,7 @@ CreateInterpreter(
CHECK(f.is_subset_of(FeatureSet::All() - fGraph));
return intrp->Eval(expr);
};
return TypedPackedFunc<Value(Expr)>(packed);
return TypedPackedFunc<ObjectRef(Expr)>(packed);
}
TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter")
......@@ -822,7 +806,6 @@ TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter")
TVM_REGISTER_NODE_TYPE(ClosureNode);
TVM_REGISTER_NODE_TYPE(TupleValueNode);
TVM_REGISTER_NODE_TYPE(TensorValueNode);
} // namespace relay
} // namespace tvm
......@@ -854,7 +854,7 @@ void VMCompiler::Lower(Module mod,
// populate constants
for (auto data : context_.constants) {
exec_->constants.push_back(vm::Tensor(data));
exec_->constants.push_back(data);
}
// update global function map
......
......@@ -27,12 +27,14 @@
#include <tvm/relay/interpreter.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/transform.h>
#include "./pattern_util.h"
#include <tvm/runtime/object.h>
#include <tvm/runtime/ndarray.h>
#include "pattern_util.h"
namespace tvm {
namespace relay {
using FInterpreter = runtime::TypedPackedFunc<Value(Expr)>;
using FInterpreter = runtime::TypedPackedFunc<ObjectRef(Expr)>;
class ConstantChecker : private ExprVisitor {
public:
......@@ -177,17 +179,18 @@ class ConstantFolder : public ExprMutator {
const Op& cast_op_;
// Convert value to expression.
Expr ValueToExpr(Value value) {
if (const auto* val = value.as<TensorValueNode>()) {
for (auto dim : val->data.Shape()) {
Expr ObjectToExpr(const ObjectRef& value) {
if (value->IsInstance<runtime::NDArray::ContainerType>()) {
auto nd_array = Downcast<runtime::NDArray>(value);
for (auto dim : nd_array.Shape()) {
CHECK_GT(dim, 0)
<< "invalid dimension after constant eval";
}
return ConstantNode::make(val->data);
return ConstantNode::make(nd_array);
} else if (const auto* val = value.as<TupleValueNode>()) {
Array<Expr> fields;
for (Value field : val->fields) {
fields.push_back(ValueToExpr(field));
for (ObjectRef field : val->fields) {
fields.push_back(ObjectToExpr(field));
}
return TupleNode::make(fields);
} else {
......@@ -216,7 +219,7 @@ class ConstantFolder : public ExprMutator {
mod = seq(mod);
auto entry_func = mod->Lookup("main");
expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
return ValueToExpr(executor_(expr));
return ObjectToExpr(executor_(expr));
}
// Evaluate a call to the shape_of operator for tensors with constant
......@@ -258,7 +261,7 @@ class ConstantFolder : public ExprMutator {
}
}
Constant shape = Downcast<Constant>(ValueToExpr(TensorValueNode::make(value)));
Constant shape = Downcast<Constant>(ObjectToExpr(value));
if (shape->data.Shape().size() == 0 && GetScalarFromConstant<int32_t>(shape) == 0) {
auto ndarray = runtime::NDArray::Empty({}, cdtype, ctx);
......@@ -283,8 +286,7 @@ Expr FoldConstant(const Expr& expr, const Module& mod) {
// in case we are already in a build context.
With<BuildConfig> fresh_build_ctx(BuildConfig::Create());
return ConstantFolder(CreateInterpreter(
mod, ctx, target), mod).Mutate(expr);
return ConstantFolder(CreateInterpreter(mod, ctx, target), mod).Mutate(expr);
}
namespace transform {
......
......@@ -403,7 +403,7 @@ Fuel MkFTop() {
/*!
* \brief A stack frame in the Relay interpreter.
*
* Contains a mapping from relay::Var to relay::Value.
* Contains a mapping from relay::Var to relay::Object.
*/
struct Frame {
/*! \brief The set of local variables and arguments for the frame. */
......@@ -554,7 +554,7 @@ bool StatefulOp(const Expr& e) {
return sov.stateful;
}
using FInterpreter = runtime::TypedPackedFunc<Value(Expr)>;
using FInterpreter = runtime::TypedPackedFunc<ObjectRef(Expr)>;
DLContext CPUContext() {
DLContext ctx;
......@@ -925,13 +925,14 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
}
}
PStatic Reify(const Value& v, LetList* ll) const {
if (const TensorValueNode* op = v.as<TensorValueNode>()) {
return HasStatic(MkSTensor(op->data), ll->Push(ConstantNode::make(op->data)));
PStatic Reify(const ObjectRef& v, LetList* ll) const {
if (v->IsInstance<runtime::NDArray::ContainerType>()) {
auto nd_array = Downcast<runtime::NDArray>(v);
return HasStatic(MkSTensor(nd_array), ll->Push(ConstantNode::make(nd_array)));
} else if (const TupleValueNode* op = v.as<TupleValueNode>()) {
std::vector<PStatic> fields;
tvm::Array<Expr> fields_dyn;
for (const Value& field : op->fields) {
for (const ObjectRef& field : op->fields) {
PStatic ps = Reify(field, ll);
fields.push_back(ps);
fields_dyn.push_back(ps->dynamic);
......
......@@ -150,10 +150,8 @@ std::string Executable::Stats() const {
// Get the number of constants and the shape of each of them.
oss << " Constant shapes (# " << constants.size() << "): [";
for (const auto& it : constants) {
const auto* cell = it.as<TensorObj>();
CHECK(cell);
runtime::NDArray data = cell->data;
const auto& shape = data.Shape();
const auto constant = Downcast<NDArray>(it);
const auto& shape = constant.Shape();
// Scalar
if (shape.empty()) {
......@@ -250,10 +248,8 @@ void Executable::SaveGlobalSection(dmlc::Stream* strm) {
void Executable::SaveConstantSection(dmlc::Stream* strm) {
std::vector<DLTensor*> arrays;
for (const auto& obj : this->constants) {
const auto* cell = obj.as<runtime::vm::TensorObj>();
CHECK(cell != nullptr);
runtime::NDArray data = cell->data;
arrays.push_back(const_cast<DLTensor*>(data.operator->()));
const auto cell = Downcast<runtime::NDArray>(obj);
arrays.push_back(const_cast<DLTensor*>(cell.operator->()));
}
strm->Write(static_cast<uint64_t>(this->constants.size()));
for (const auto& it : arrays) {
......@@ -513,8 +509,7 @@ void Executable::LoadConstantSection(dmlc::Stream* strm) {
for (size_t i = 0; i < size; i++) {
runtime::NDArray constant;
STREAM_CHECK(constant.Load(strm), "constant");
runtime::ObjectRef obj = runtime::vm::Tensor(constant);
this->constants.push_back(obj);
this->constants.push_back(constant);
}
}
......
......@@ -34,12 +34,6 @@ namespace tvm {
namespace runtime {
namespace vm {
Tensor::Tensor(NDArray data) {
auto ptr = make_object<TensorObj>();
ptr->data = std::move(data);
data_ = std::move(ptr);
}
Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) {
auto ptr = make_object<ClosureObj>();
ptr->func_index = func_index;
......@@ -48,14 +42,6 @@ Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) {
}
TVM_REGISTER_GLOBAL("_vmobj.GetTensorData")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto* cell = obj.as<TensorObj>();
CHECK(cell != nullptr);
*rv = cell->data;
});
TVM_REGISTER_GLOBAL("_vmobj.GetADTTag")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
......@@ -80,11 +66,6 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTFields")
*rv = adt[idx];
});
TVM_REGISTER_GLOBAL("_vmobj.Tensor")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = Tensor(args[0].operator NDArray());
});
TVM_REGISTER_GLOBAL("_vmobj.Tuple")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<ObjectRef> fields;
......@@ -105,7 +86,6 @@ TVM_REGISTER_GLOBAL("_vmobj.ADT")
*rv = ADT(tag, fields);
});
TVM_REGISTER_OBJECT_TYPE(TensorObj);
TVM_REGISTER_OBJECT_TYPE(ADTObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj);
} // namespace vm
......
......@@ -613,18 +613,14 @@ std::ostream& operator<<(std::ostream& os, const VMFunction& vm_func) {
return os;
}
ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) {
if (const TensorObj* obj = src.as<TensorObj>()) {
auto tensor = obj->data;
if (tensor->ctx.device_type != ctx.device_type) {
auto copy = tensor.CopyTo(ctx);
return Tensor(copy);
} else {
return src;
inline ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) {
if (src->IsInstance<NDArray::ContainerType>()) {
auto nd_array = Downcast<NDArray>(src);
if (nd_array->ctx.device_type != ctx.device_type) {
return nd_array.CopyTo(ctx);
}
} else {
return src;
}
return src;
}
PackedFunc VirtualMachine::GetFunction(const std::string& name,
......@@ -770,16 +766,12 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
if (const auto* dt_cell = args[i].as<ADTObj>()) {
for (size_t fi = 0; fi < dt_cell->size; ++fi) {
auto obj = (*dt_cell)[fi];
const auto* tensor = obj.as<TensorObj>();
CHECK(tensor != nullptr) << "Expect tensor object, but received: "
<< obj->GetTypeKey();
setter(idx++, tensor->data);
auto nd_array = Downcast<NDArray>(obj);
setter(idx++, nd_array);
}
} else {
const auto* tensor = args[i].as<TensorObj>();
CHECK(tensor != nullptr) << "Expect tensor object, but received: "
<< args[i]->GetTypeKey();
setter(idx++, tensor->data);
auto nd_array = Downcast<NDArray>(args[i]);
setter(idx++, nd_array);
}
}
......@@ -824,10 +816,8 @@ inline ObjectRef VirtualMachine::ReadRegister(Index r) const {
inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
int32_t result;
const auto& obj = ReadRegister(r);
const auto* tensor = obj.as<TensorObj>();
CHECK(tensor != nullptr) << "Expect tensor object, but received: "
<< obj->GetTypeKey();
NDArray array = tensor->data.CopyTo({kDLCPU, 0});
auto nd_array = Downcast<NDArray>(obj);
NDArray array = nd_array.CopyTo({kDLCPU, 0});
if (array->dtype.bits <= 8) {
result = reinterpret_cast<int8_t*>(array->data)[0];
......@@ -883,7 +873,7 @@ void VirtualMachine::RunLoop() {
case Opcode::LoadConsti: {
auto tensor = NDArray::Empty({1}, {kDLInt, 64, 1}, {kDLCPU, 0});
reinterpret_cast<int64_t*>(tensor->data)[0] = instr.load_consti.val;
WriteRegister(instr.dst, Tensor(tensor));
WriteRegister(instr.dst, tensor);
pc_++;
goto main_loop;
}
......@@ -943,7 +933,7 @@ void VirtualMachine::RunLoop() {
auto tag = adt.tag();
auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0});
reinterpret_cast<int32_t*>(tag_tensor->data)[0] = tag;
WriteRegister(instr.dst, Tensor(tag_tensor));
WriteRegister(instr.dst, tag_tensor);
pc_++;
goto main_loop;
}
......@@ -974,9 +964,8 @@ void VirtualMachine::RunLoop() {
auto storage_obj = ReadRegister(instr.alloc_tensor.storage);
auto storage = Downcast<Storage>(storage_obj);
auto data = storage->AllocNDArray(0, shape, instr.alloc_tensor.dtype);
auto obj = storage->AllocNDArray(0, shape, instr.alloc_tensor.dtype);
auto obj = Tensor(data);
WriteRegister(instr.dst, obj);
pc_++;
goto main_loop;
......@@ -986,10 +975,8 @@ void VirtualMachine::RunLoop() {
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register);
const auto* tensor = shape_tensor_obj.as<TensorObj>();
CHECK(tensor != nullptr) << "Expect tensor object, but received: "
<< shape_tensor_obj->GetTypeKey();
NDArray shape_tensor = tensor->data.CopyTo(cpu_ctx);
const auto shape_arr = Downcast<NDArray>(shape_tensor_obj);
NDArray shape_tensor = shape_arr.CopyTo(cpu_ctx);
const DLTensor* dl_tensor = shape_tensor.operator->();
CHECK_EQ(dl_tensor->dtype.code, 0u);
CHECK_LE(dl_tensor->dtype.bits, 64);
......@@ -1000,9 +987,8 @@ void VirtualMachine::RunLoop() {
auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage);
auto storage = Downcast<Storage>(storage_obj);
auto data = storage->AllocNDArray(0, shape, instr.alloc_tensor_reg.dtype);
auto obj = storage->AllocNDArray(0, shape, instr.alloc_tensor_reg.dtype);
auto obj = Tensor(data);
WriteRegister(instr.dst, obj);
pc_++;
goto main_loop;
......
......@@ -18,6 +18,7 @@
import pytest
import tensorflow as tf
import numpy as np
from tvm import nd
from tvm import relay
from tvm.relay.frontend.tensorflow import from_tensorflow
......@@ -26,7 +27,7 @@ def check_equal(graph, tf_out):
mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
ex = relay.create_executor('vm', mod=mod)
relay_out = ex.evaluate()(**params)
if isinstance(relay_out, relay.vmobj.Tensor):
if isinstance(relay_out, nd.NDArray):
np.testing.assert_allclose(tf_out, relay_out.asnumpy())
else:
if not isinstance(tf_out, list):
......
......@@ -60,7 +60,7 @@ tf_dtypes = {
}
def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.vmobj.Tensor):
if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vmobj.ADT):
result = []
......@@ -87,8 +87,6 @@ def vmobj_to_list(o):
else:
raise RuntimeError("Unknown object type: %s" %
o.constructor.name_hint)
elif isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.data.asnumpy()]
else:
raise RuntimeError("Unknown object type: %s" % type(o))
......
......@@ -115,10 +115,8 @@ def tree_to_dict(t):
def vmobj_to_list(o, dtype="float32"):
if isinstance(o, tvm.relay.backend.vmobj.Tensor):
if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.asnumpy()]
elif isinstance(o, tvm.relay.backend.vmobj.ADT):
if len(o) == 0:
tensor_nil = p.get_var("tensor_nil", dtype=dtype)
......
......@@ -17,8 +17,9 @@
import numpy as np
import tvm
import tvm.testing
from tvm import nd
from tvm import relay
from tvm.relay.backend.interpreter import Value, TupleValue, TensorValue
from tvm.relay.backend.interpreter import TupleValue
from tvm.relay.backend.interpreter import RefValue, ConstructorValue
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing, create_executor
......@@ -37,18 +38,11 @@ def check_eval(expr, args, expected_result, mod=None, rtol=1e-07):
result.asnumpy(), expected_result, rtol=rtol)
def test_from_scalar():
np.testing.assert_allclose(Value.from_scalar(1, 'int32').asnumpy(), 1)
np.testing.assert_allclose(Value.from_scalar(10.0, 'float32').asnumpy(), 10.0)
np.testing.assert_allclose(Value.from_scalar(True).asnumpy(), True)
def test_tuple_value():
tv = TupleValue(Value.from_scalar(
1), Value.from_scalar(2), Value.from_scalar(3))
np.testing.assert_allclose(tv[0].asnumpy(), 1)
np.testing.assert_allclose(tv[1].asnumpy(), 2)
np.testing.assert_allclose(tv[2].asnumpy(), 3)
tv = TupleValue(relay.const(1), relay.const(2), relay.const(3))
np.testing.assert_allclose(tv[0].data.asnumpy(), 1)
np.testing.assert_allclose(tv[1].data.asnumpy(), 2)
np.testing.assert_allclose(tv[2].data.asnumpy(), 3)
def test_tuple_getitem():
......@@ -158,12 +152,6 @@ def test_binds():
tvm.testing.assert_allclose(xx + xx, res)
def test_tensor_value():
x = relay.var("x", shape=(1, 10))
xx = np.ones((1, 10)).astype("float32")
check_eval(relay.Function([x], x), [TensorValue(xx)], xx)
def test_kwargs_params():
x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(1, 10))
......@@ -174,7 +162,7 @@ def test_kwargs_params():
z_data = np.random.rand(1, 10).astype('float32')
params = { 'y': y_data, 'z': z_data }
intrp = create_executor("debug")
res = intrp.evaluate(f)(x_data, **params).data
res = intrp.evaluate(f)(x_data, **params)
tvm.testing.assert_allclose(res.asnumpy(), x_data + y_data + z_data)
......@@ -185,13 +173,13 @@ def test_function_taking_adt_ref_tuple():
nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil)
cons_value = ConstructorValue(prelude.cons.tag, [
TensorValue(np.random.rand(1, 10).astype('float32')),
nd.array(np.random.rand(1, 10).astype('float32')),
nil_value
], prelude.cons)
ref_value = RefValue(TensorValue(np.random.rand(1, 10).astype('float32')))
ref_value = RefValue(nd.array(np.random.rand(1, 10).astype('float32')))
tuple_value = TupleValue(*[
TensorValue(np.random.rand(1, 10).astype('float32')) for _ in range(10)
nd.array(np.random.rand(1, 10).astype('float32')) for _ in range(10)
])
id_func = intrp.evaluate(prelude.id)
......@@ -236,9 +224,7 @@ def test_tuple_passing():
out = f((10, 8))
tvm.testing.assert_allclose(out.asnumpy(), np.array(10))
# Second use a tuple value.
value_tuple = TupleValue(
TensorValue(np.array(11)),
TensorValue(np.array(12)))
value_tuple = TupleValue(nd.array(np.array(11)), nd.array(np.array(12)))
out = f(value_tuple)
tvm.testing.assert_allclose(out.asnumpy(), np.array(11))
......@@ -252,7 +238,6 @@ if __name__ == "__main__":
test_binds()
test_kwargs_params()
test_ref()
test_tensor_value()
test_tuple_value()
test_tuple_getitem()
test_function_taking_adt_ref_tuple()
......
......@@ -19,7 +19,7 @@ import tvm
from tvm import relay
from tvm.relay.testing import to_python, run_as_python
from tvm.relay.prelude import Prelude
from tvm.relay.backend.interpreter import TensorValue, TupleValue, RefValue, ConstructorValue
from tvm.relay.backend.interpreter import TupleValue, RefValue, ConstructorValue
# helper: uses a dummy let binding to sequence a list
# of expressions: expr1; expr2; expr3, etc.
......@@ -39,9 +39,9 @@ def init_box_adt(mod):
return (box, box_ctor)
# assert that the candidate is a TensorValue with value val
# assert that the candidate is a NDArray with value val
def assert_tensor_value(candidate, val):
assert isinstance(candidate, TensorValue)
assert isinstance(candidate, tvm.nd.NDArray)
assert np.array_equal(candidate.asnumpy(), np.array(val))
......@@ -68,6 +68,7 @@ def test_create_empty_tuple():
def test_create_scalar():
scalar = relay.const(1)
tensor_val = run_as_python(scalar)
print(type(tensor_val))
assert_tensor_value(tensor_val, 1)
......@@ -544,7 +545,7 @@ def test_batch_norm():
# there will be a change in accuracy so we need to check
# approximate equality
assert isinstance(call_val, TensorValue)
assert isinstance(call_val, tvm.nd.NDArray)
tvm.testing.assert_allclose(call_val.asnumpy(), ref_res, atol=eps, rtol=eps)
verify_batch_norm([(10, 20), (20,), (20,), (20,), (20,)])
......
......@@ -56,7 +56,7 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
return vm.invoke("main", *args)
def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.vm.Tensor):
if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vm.ADT):
result = []
......
......@@ -19,28 +19,16 @@ import numpy as np
import tvm
from tvm.relay import vm
def test_tensor():
arr = tvm.nd.array([1,2,3])
x = vm.Tensor(arr)
assert isinstance(x, vm.Tensor)
assert x.asnumpy()[0] == 1
assert x.asnumpy()[-1] == 3
assert isinstance(x.data, tvm.nd.NDArray)
def test_adt():
arr = tvm.nd.array([1,2,3])
x = vm.Tensor(arr)
y = vm.ADT(0, [x, x])
y = vm.ADT(0, [arr, arr])
assert len(y) == 2
assert isinstance(y, vm.ADT)
y[0:1][-1].data == x.data
y[0:1][-1] == arr
assert y.tag == 0
assert isinstance(x.data, tvm.nd.NDArray)
assert isinstance(arr, tvm.nd.NDArray)
if __name__ == "__main__":
test_tensor()
test_adt()
......@@ -28,7 +28,7 @@ def test_double_buffer():
with ib.for_range(0, n) as i:
B = ib.allocate("float32", m, name="B", scope="shared")
with ib.new_scope():
ib.scope_attr(B.asnode(), "double_buffer_scope", 1)
ib.scope_attr(B.asobject(), "double_buffer_scope", 1)
with ib.for_range(0, m) as j:
B[j] = A[i * 4 + j]
with ib.for_range(0, m) as j:
......@@ -39,7 +39,7 @@ def test_double_buffer():
stmt = tvm.ir_pass.Simplify(stmt)
assert isinstance(stmt.body.body, tvm.stmt.Allocate)
assert stmt.body.body.extents[0].value == 2
f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asnode(), C.asnode()], 2, True)
f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
f = tvm.ir_pass.ThreadSync(f, "shared")
count = [0]
def count_sync(op):
......
......@@ -32,7 +32,7 @@ def test_vthread():
ib.scope_attr(ty, "virtual_thread", nthread)
B = ib.allocate("float32", m, name="B", scope="shared")
B[i] = A[i * nthread + tx]
bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode())
bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asobject())
ib.emit(tvm.call_extern("int32", "Run",
bbuffer.access_ptr("r"),
tvm.call_pure_intrin("int32", "tvm_context_id")))
......@@ -60,9 +60,9 @@ def test_vthread_extern():
A = ib.allocate("float32", m, name="A", scope="shared")
B = ib.allocate("float32", m, name="B", scope="shared")
C = ib.allocate("float32", m, name="C", scope="shared")
cbuffer = tvm.decl_buffer((m,), dtype=C.dtype, data=C.asnode())
abuffer = tvm.decl_buffer((m,), dtype=A.dtype, data=A.asnode())
bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode())
cbuffer = tvm.decl_buffer((m,), dtype=C.dtype, data=C.asobject())
abuffer = tvm.decl_buffer((m,), dtype=A.dtype, data=A.asobject())
bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asobject())
A[tx] = tx + 1.0
B[ty] = ty + 1.0
ib.emit(tvm.call_extern("int32", "Run",
......
......@@ -79,7 +79,7 @@ def test_flatten_double_buffer():
with ib.for_range(0, n) as i:
B = ib.allocate("float32", m, name="B", scope="shared")
with ib.new_scope():
ib.scope_attr(B.asnode(), "double_buffer_scope", 1)
ib.scope_attr(B.asobject(), "double_buffer_scope", 1)
with ib.for_range(0, m) as j:
B[j] = A[i * 4 + j]
with ib.for_range(0, m) as j:
......@@ -91,7 +91,7 @@ def test_flatten_double_buffer():
stmt = tvm.ir_pass.Simplify(stmt)
assert isinstance(stmt.body.body, tvm.stmt.Allocate)
assert stmt.body.body.extents[0].value == 2
f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asnode(), C.asnode()], 2, True)
f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
f = tvm.ir_pass.ThreadSync(f, "shared")
count = [0]
def count_sync(op):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment