Commit 86092de0 by Zhi Committed by Jared Roesch

[REFACTOR] Replace TensorObj and TensorValue with NDArray (#4643)

* replace TensorObj and TensorValue with NDArray

* NodeBase to Object in Python

* rebase
parent dcf7fbf1
...@@ -20,17 +20,14 @@ Developer API ...@@ -20,17 +20,14 @@ Developer API
This page contains modules that are used by developers of TVM. This page contains modules that are used by developers of TVM.
Many of these APIs are PackedFunc registered in C++ backend. Many of these APIs are PackedFunc registered in C++ backend.
tvm.node tvm.object
~~~~~~~~ ~~~~~~~~~~
.. automodule:: tvm.node .. automodule:: tvm.object
.. autoclass:: tvm.node.NodeBase
:members:
.. autoclass:: tvm.node.Node .. autoclass:: tvm.object.Object
:members: :members:
.. autofunction:: tvm.register_node .. autofunction:: tvm.register_object
tvm.expr tvm.expr
~~~~~~~~ ~~~~~~~~
......
...@@ -55,18 +55,18 @@ We use a simple example that uses the low level TVM API directly. The example is ...@@ -55,18 +55,18 @@ We use a simple example that uses the low level TVM API directly. The example is
B = tvm.placeholder((n,), name='B') B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C") C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``NodeBase``. Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``Object``.
:: ::
@register_node @register_object
class Tensor(NodeBase, _expr.ExprOp): class Tensor(Object, _expr.ExprOp):
"""Tensor object, to construct, see function.Tensor""" """Tensor object, to construct, see function.Tensor"""
def __call__(self, *indices): def __call__(self, *indices):
... ...
The Node system is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document <https://docs.tvm.ai/dev/runtime.html#tvm-node-and-compiler-stack>`_, and details are in ``python/tvm/_ffi/`` if you are interested. The object protocol is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document <https://docs.tvm.ai/dev/runtime.html#tvm-node-and-compiler-stack>`_, and details are in ``python/tvm/_ffi/`` if you are interested.
``Tensor`` is created by functions in ``python/tvm/api.py``, which in turn calls into C++ functions exposed in ``src/api/api_lang.cc``. All C++ functions that are callable from Python are exposed in the ``src/api`` subdirectory. For example, the ``tvm.compute()`` function above calls into ``_ComputeOp`` API exposed in ``src/api/api_lang.cc``: ``Tensor`` is created by functions in ``python/tvm/api.py``, which in turn calls into C++ functions exposed in ``src/api/api_lang.cc``. All C++ functions that are callable from Python are exposed in the ``src/api`` subdirectory. For example, the ``tvm.compute()`` function above calls into ``_ComputeOp`` API exposed in ``src/api/api_lang.cc``:
......
...@@ -37,16 +37,12 @@ ...@@ -37,16 +37,12 @@
#include <tvm/build_module.h> #include <tvm/build_module.h>
#include <tvm/relay/module.h> #include <tvm/relay/module.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/runtime/object.h>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
/*! /*!
* \brief A Relay value.
*/
class Value;
/*!
*\brief Create an Interpreter function that can *\brief Create an Interpreter function that can
* evaluate an expression and produce a value. * evaluate an expression and produce a value.
* *
...@@ -65,39 +61,21 @@ class Value; ...@@ -65,39 +61,21 @@ class Value;
* \param target Compiler target flag to compile the functions on the context. * \param target Compiler target flag to compile the functions on the context.
* \return A function that takes in an expression and returns a value. * \return A function that takes in an expression and returns a value.
*/ */
runtime::TypedPackedFunc<Value(Expr)> runtime::TypedPackedFunc<ObjectRef(Expr)>
CreateInterpreter(Module mod, DLContext context, Target target); CreateInterpreter(Module mod, DLContext context, Target target);
/*! \brief The base container type of Relay values. */
class ValueNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Value";
TVM_DECLARE_BASE_OBJECT_INFO(ValueNode, RelayNode);
};
class Value : public ObjectRef {
public:
Value() {}
explicit Value(ObjectPtr<Object> n) : ObjectRef(n) {}
const ValueNode* operator->() const {
return static_cast<const ValueNode*>(get());
}
using ContainerType = ValueNode;
};
/*! \brief A Relay closure, i.e a scope and a function. */ /*! \brief A Relay closure, i.e a scope and a function. */
class Closure; class Closure;
/*! \brief The container type of Closures. */ /*! \brief The container type of Closures. */
class ClosureNode : public ValueNode { class ClosureNode : public Object {
public: public:
/*! \brief The set of free variables in the closure. /*! \brief The set of free variables in the closure.
* *
* These are the captured variables which are required for * These are the captured variables which are required for
* evaluation when we call the closure. * evaluation when we call the closure.
*/ */
tvm::Map<Var, Value> env; tvm::Map<Var, ObjectRef> env;
/*! \brief The function which implements the closure. /*! \brief The function which implements the closure.
* *
* \note May reference the variables contained in the env. * \note May reference the variables contained in the env.
...@@ -111,22 +89,22 @@ class ClosureNode : public ValueNode { ...@@ -111,22 +89,22 @@ class ClosureNode : public ValueNode {
v->Visit("func", &func); v->Visit("func", &func);
} }
TVM_DLL static Closure make(tvm::Map<Var, Value> env, Function func); TVM_DLL static Closure make(tvm::Map<Var, ObjectRef> env, Function func);
static constexpr const char* _type_key = "relay.Closure"; static constexpr const char* _type_key = "relay.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, ValueNode); TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, Object);
}; };
class Closure : public Value { class Closure : public ObjectRef {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(Closure, Value, ClosureNode); TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureNode);
}; };
/*! \brief A Relay Recursive Closure. A closure that has a name. */ /*! \brief A Relay Recursive Closure. A closure that has a name. */
class RecClosure; class RecClosure;
/*! \brief The container type of RecClosure. */ /*! \brief The container type of RecClosure. */
class RecClosureNode : public ValueNode { class RecClosureNode : public Object {
public: public:
/*! \brief The closure. */ /*! \brief The closure. */
Closure clos; Closure clos;
...@@ -143,64 +121,41 @@ class RecClosureNode : public ValueNode { ...@@ -143,64 +121,41 @@ class RecClosureNode : public ValueNode {
TVM_DLL static RecClosure make(Closure clos, Var bind); TVM_DLL static RecClosure make(Closure clos, Var bind);
static constexpr const char* _type_key = "relay.RecClosure"; static constexpr const char* _type_key = "relay.RecClosure";
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, ValueNode); TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, Object);
}; };
class RecClosure : public Value { class RecClosure : public ObjectRef {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, Value, RecClosureNode); TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureNode);
}; };
/*! \brief A tuple value. */ /*! \brief A tuple value. */
class TupleValue; class TupleValue;
/*! \brief Tuple (x, ... y). */ /*! \brief Tuple (x, ... y). */
struct TupleValueNode : ValueNode { struct TupleValueNode : Object {
tvm::Array<Value> fields; tvm::Array<ObjectRef> fields;
TupleValueNode() {} TupleValueNode() {}
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); }
TVM_DLL static TupleValue make(tvm::Array<Value> value); TVM_DLL static TupleValue make(tvm::Array<ObjectRef> value);
static constexpr const char* _type_key = "relay.TupleValue"; static constexpr const char* _type_key = "relay.TupleValue";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, ValueNode); TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, Object);
};
class TupleValue : public Value {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, Value, TupleValueNode);
};
/*! \brief A tensor value. */
class TensorValue;
/*! \brief The tensor value container, wrapping an NDArray. */
struct TensorValueNode : ValueNode {
runtime::NDArray data;
TensorValueNode() {}
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); }
/*! \brief Build a value from an NDArray. */
TVM_DLL static TensorValue make(runtime::NDArray data);
static constexpr const char* _type_key = "relay.TensorValue";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorValueNode, ValueNode);
}; };
class TensorValue : public Value { class TupleValue : public ObjectRef {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(TensorValue, Value, TensorValueNode); TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, ObjectRef, TupleValueNode);
}; };
/*! \brief A reference value. */ /*! \brief A reference value. */
class RefValue; class RefValue;
struct RefValueNode : ValueNode { struct RefValueNode : Object {
mutable Value value; mutable ObjectRef value;
RefValueNode() {} RefValueNode() {}
...@@ -208,24 +163,24 @@ struct RefValueNode : ValueNode { ...@@ -208,24 +163,24 @@ struct RefValueNode : ValueNode {
v->Visit("value", &value); v->Visit("value", &value);
} }
TVM_DLL static RefValue make(Value val); TVM_DLL static RefValue make(ObjectRef val);
static constexpr const char* _type_key = "relay.RefValue"; static constexpr const char* _type_key = "relay.RefValue";
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, ValueNode); TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, Object);
}; };
class RefValue : public Value { class RefValue : public ObjectRef {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, Value, RefValueNode); TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueNode);
}; };
/*! \brief An ADT constructor value. */ /*! \brief An ADT constructor value. */
class ConstructorValue; class ConstructorValue;
struct ConstructorValueNode : ValueNode { struct ConstructorValueNode : Object {
int32_t tag; int32_t tag;
tvm::Array<Value> fields; tvm::Array<ObjectRef> fields;
/*! \brief Optional field tracking ADT constructor. */ /*! \brief Optional field tracking ADT constructor. */
Constructor constructor; Constructor constructor;
...@@ -237,16 +192,16 @@ struct ConstructorValueNode : ValueNode { ...@@ -237,16 +192,16 @@ struct ConstructorValueNode : ValueNode {
} }
TVM_DLL static ConstructorValue make(int32_t tag, TVM_DLL static ConstructorValue make(int32_t tag,
tvm::Array<Value> fields, tvm::Array<ObjectRef> fields,
Constructor constructor = {}); Constructor constructor = {});
static constexpr const char* _type_key = "relay.ConstructorValue"; static constexpr const char* _type_key = "relay.ConstructorValue";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, ValueNode); TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, Object);
}; };
class ConstructorValue : public Value { class ConstructorValue : public ObjectRef {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, Value, ConstructorValueNode); TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueNode);
}; };
} // namespace relay } // namespace relay
......
...@@ -36,25 +36,6 @@ namespace tvm { ...@@ -36,25 +36,6 @@ namespace tvm {
namespace runtime { namespace runtime {
namespace vm { namespace vm {
/*! \brief An object containing an NDArray. */
class TensorObj : public Object {
public:
/*! \brief The NDArray. */
NDArray data;
static constexpr const uint32_t _type_index = TypeIndex::kVMTensor;
static constexpr const char* _type_key = "vm.Tensor";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorObj, Object);
};
/*! \brief reference to tensor. */
class Tensor : public ObjectRef {
public:
explicit Tensor(NDArray data);
TVM_DEFINE_OBJECT_REF_METHODS(Tensor, ObjectRef, TensorObj);
};
/*! \brief An object representing a closure. */ /*! \brief An object representing a closure. */
class ClosureObj : public Object { class ClosureObj : public Object {
public: public:
......
...@@ -34,7 +34,7 @@ from . import codegen ...@@ -34,7 +34,7 @@ from . import codegen
from . import container from . import container
from . import schedule from . import schedule
from . import module from . import module
from . import node from . import object
from . import attrs from . import attrs
from . import ir_builder from . import ir_builder
from . import target from . import target
...@@ -55,7 +55,7 @@ from ._ffi.base import TVMError, __version__ ...@@ -55,7 +55,7 @@ from ._ffi.base import TVMError, __version__
from .api import * from .api import *
from .intrin import * from .intrin import *
from .tensor_intrin import decl_tensor_intrin from .tensor_intrin import decl_tensor_intrin
from .node import register_node from .object import register_object
from .ndarray import register_extension from .ndarray import register_extension
from .schedule import create_schedule from .schedule import create_schedule
from .build_module import build, lower, build_config from .build_module import build, lower, build_config
......
...@@ -25,14 +25,14 @@ from numbers import Number, Integral ...@@ -25,14 +25,14 @@ from numbers import Number, Integral
from ..base import _LIB, get_last_ffi_error, py2cerror from ..base import _LIB, get_last_ffi_error, py2cerror
from ..base import c_str, string_types from ..base import c_str, string_types
from ..node_generic import convert_to_node, NodeGeneric from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext
from . import ndarray as _nd from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64 from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
from .object import ObjectBase, _set_class_node from .object import ObjectBase, _set_class_object
from . import object as _object from . import object as _object
FunctionHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p
...@@ -144,8 +144,8 @@ def _make_tvm_args(args, temp_args): ...@@ -144,8 +144,8 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, string_types): elif isinstance(arg, string_types):
values[i].v_str = c_str(arg) values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR type_codes[i] = TypeCode.STR
elif isinstance(arg, (list, tuple, dict, NodeGeneric)): elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):
arg = convert_to_node(arg) arg = convert_to_object(arg)
values[i].v_handle = arg.handle values[i].v_handle = arg.handle
type_codes[i] = TypeCode.OBJECT_HANDLE type_codes[i] = TypeCode.OBJECT_HANDLE
temp_args.append(arg) temp_args.append(arg)
...@@ -256,7 +256,6 @@ C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handl ...@@ -256,7 +256,6 @@ C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handl
_CLASS_MODULE = None _CLASS_MODULE = None
_CLASS_FUNCTION = None _CLASS_FUNCTION = None
_CLASS_OBJECT = None
def _set_class_module(module_class): def _set_class_module(module_class):
"""Initialize the module.""" """Initialize the module."""
...@@ -266,7 +265,3 @@ def _set_class_module(module_class): ...@@ -266,7 +265,3 @@ def _set_class_module(module_class):
def _set_class_function(func_class): def _set_class_function(func_class):
global _CLASS_FUNCTION global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class _CLASS_FUNCTION = func_class
def _set_class_object(obj_class):
global _CLASS_OBJECT
_CLASS_OBJECT = obj_class
...@@ -30,11 +30,11 @@ __init_by_constructor__ = None ...@@ -30,11 +30,11 @@ __init_by_constructor__ = None
"""Maps object type to its constructor""" """Maps object type to its constructor"""
OBJECT_TYPE = {} OBJECT_TYPE = {}
_CLASS_NODE = None _CLASS_OBJECT = None
def _set_class_node(node_class): def _set_class_object(object_class):
global _CLASS_NODE global _CLASS_OBJECT
_CLASS_NODE = node_class _CLASS_OBJECT = object_class
def _register_object(index, cls): def _register_object(index, cls):
...@@ -51,7 +51,7 @@ def _return_object(x): ...@@ -51,7 +51,7 @@ def _return_object(x):
handle = ObjectHandle(handle) handle = ObjectHandle(handle)
tindex = ctypes.c_uint() tindex = ctypes.c_uint()
check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex))) check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
cls = OBJECT_TYPE.get(tindex.value, _CLASS_NODE) cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT)
# Avoid calling __init__ of cls, instead directly call __new__ # Avoid calling __init__ of cls, instead directly call __new__
# This allows child class to implement their own __init__ # This allows child class to implement their own __init__
obj = cls.__new__(cls) obj = cls.__new__(cls)
......
...@@ -20,7 +20,7 @@ import traceback ...@@ -20,7 +20,7 @@ import traceback
from cpython cimport Py_INCREF, Py_DECREF from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral from numbers import Number, Integral
from ..base import string_types, py2cerror from ..base import string_types, py2cerror
from ..node_generic import convert_to_node, NodeGeneric from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray
...@@ -149,8 +149,8 @@ cdef inline int make_arg(object arg, ...@@ -149,8 +149,8 @@ cdef inline int make_arg(object arg,
value[0].v_str = tstr value[0].v_str = tstr
tcode[0] = kStr tcode[0] = kStr
temp_args.append(tstr) temp_args.append(tstr)
elif isinstance(arg, (list, tuple, dict, NodeGeneric)): elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):
arg = convert_to_node(arg) arg = convert_to_object(arg)
value[0].v_handle = (<ObjectBase>arg).chandle value[0].v_handle = (<ObjectBase>arg).chandle
tcode[0] = kObjectHandle tcode[0] = kObjectHandle
temp_args.append(arg) temp_args.append(arg)
...@@ -308,7 +308,6 @@ cdef class FunctionBase: ...@@ -308,7 +308,6 @@ cdef class FunctionBase:
_CLASS_FUNCTION = None _CLASS_FUNCTION = None
_CLASS_MODULE = None _CLASS_MODULE = None
_CLASS_OBJECT = None _CLASS_OBJECT = None
_CLASS_NODE = None
def _set_class_module(module_class): def _set_class_module(module_class):
"""Initialize the module.""" """Initialize the module."""
...@@ -322,7 +321,3 @@ def _set_class_function(func_class): ...@@ -322,7 +321,3 @@ def _set_class_function(func_class):
def _set_class_object(obj_class): def _set_class_object(obj_class):
global _CLASS_OBJECT global _CLASS_OBJECT
_CLASS_OBJECT = obj_class _CLASS_OBJECT = obj_class
def _set_class_node(node_class):
global _CLASS_NODE
_CLASS_NODE = node_class
...@@ -32,7 +32,7 @@ def _register_object(int index, object cls): ...@@ -32,7 +32,7 @@ def _register_object(int index, object cls):
cdef inline object make_ret_object(void* chandle): cdef inline object make_ret_object(void* chandle):
global OBJECT_TYPE global OBJECT_TYPE
global _CLASS_NODE global _CLASS_OBJECT
cdef unsigned tindex cdef unsigned tindex
cdef object cls cdef object cls
cdef object handle cdef object handle
...@@ -44,11 +44,9 @@ cdef inline object make_ret_object(void* chandle): ...@@ -44,11 +44,9 @@ cdef inline object make_ret_object(void* chandle):
if cls is not None: if cls is not None:
obj = cls.__new__(cls) obj = cls.__new__(cls)
else: else:
# default use node base class obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
# TODO(tqchen) change to object after Node unifies with Object
obj = _CLASS_NODE.__new__(_CLASS_NODE)
else: else:
obj = _CLASS_NODE.__new__(_CLASS_NODE) obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
(<ObjectBase>obj).chandle = chandle (<ObjectBase>obj).chandle = chandle
return obj return obj
......
...@@ -22,7 +22,7 @@ from __future__ import absolute_import ...@@ -22,7 +22,7 @@ from __future__ import absolute_import
import sys import sys
import ctypes import ctypes
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
from .node_generic import _set_class_objects from .object_generic import _set_class_objects
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Node namespace"""
# pylint: disable=unused-import
from __future__ import absolute_import
import ctypes
import sys
from .. import _api_internal
from .object import Object, register_object, _set_class_node
from .node_generic import NodeGeneric, convert_to_node, const
def _new_object(cls):
"""Helper function for pickle"""
return cls.__new__(cls)
class NodeBase(Object):
    """NodeBase is the base class of all TVM language AST object.

    Instances are thin Python handles to C++ node objects; printing,
    attribute access and (de)serialization are delegated to the C++
    side through ``_api_internal``.
    """

    def __repr__(self):
        # Delegate formatting to the C++ pretty-printer.
        return _api_internal._format_str(self)

    def __dir__(self):
        # _NodeListAttrNames returns a packed function: calling it with -1
        # yields the attribute count, calling it with i the i-th name.
        fnames = _api_internal._NodeListAttrNames(self)
        size = fnames(-1)
        return [fnames(i) for i in range(size)]

    def __getattr__(self, name):
        # Look the attribute up on the underlying C++ node; translate a
        # miss into a standard AttributeError with the Python type name.
        try:
            return _api_internal._NodeGetAttr(self, name)
        except AttributeError:
            raise AttributeError(
                "%s has no attribute %s" % (str(type(self)), name))

    def __hash__(self):
        # Hash by the raw pointer of the underlying C++ object, so equal
        # hash means same underlying node (see same_as).
        return _api_internal._raw_ptr(self)

    def __eq__(self, other):
        # Equality is object identity of the underlying C++ node.
        return self.same_as(other)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __reduce__(self):
        # Pickle support: recreate an empty instance via _new_object, then
        # restore its state through __setstate__.
        cls = type(self)
        return (_new_object, (cls, ), self.__getstate__())

    def __getstate__(self):
        handle = self.handle
        if handle is not None:
            # Serialize the node to JSON on the C++ side.
            return {'handle': _api_internal._save_json(self)}
        return {'handle': None}

    def __setstate__(self, state):
        # pylint: disable=assigning-non-slot
        handle = state['handle']
        if handle is not None:
            json_str = handle
            # Rebuild the node from JSON, then take over its handle and
            # null out the temporary wrapper's handle so it is not freed
            # twice.
            other = _api_internal._load_json(json_str)
            self.handle = other.handle
            other.handle = None
        else:
            self.handle = None

    def same_as(self, other):
        """check object identity equality"""
        if not isinstance(other, NodeBase):
            return False
        return self.__hash__() == other.__hash__()
# pylint: disable=invalid-name
# Backward-compatible alias: node registration now goes through the
# unified object registration path.
register_node = register_object
# Install NodeBase as the fallback class used when a returned C++ object
# has no specialized Python wrapper registered.
_set_class_node(NodeBase)
...@@ -14,13 +14,15 @@ ...@@ -14,13 +14,15 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=invalid-name # pylint: disable=invalid-name, unused-import
"""Runtime Object API""" """Runtime Object API"""
from __future__ import absolute_import from __future__ import absolute_import
import sys import sys
import ctypes import ctypes
from .. import _api_internal
from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str
from .object_generic import ObjectGeneric, convert_to_object, const
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
...@@ -29,23 +31,77 @@ try: ...@@ -29,23 +31,77 @@ try:
if _FFI_MODE == "ctypes": if _FFI_MODE == "ctypes":
raise ImportError() raise ImportError()
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_object, _set_class_node from ._cy3.core import _set_class_object
from ._cy3.core import ObjectBase as _ObjectBase from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import _register_object from ._cy3.core import _register_object
else: else:
from ._cy2.core import _set_class_object, _set_class_node from ._cy2.core import _set_class_object
from ._cy2.core import ObjectBase as _ObjectBase from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import _register_object from ._cy2.core import _register_object
except IMPORT_EXCEPT: except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position,unused-import # pylint: disable=wrong-import-position,unused-import
from ._ctypes.function import _set_class_object, _set_class_node from ._ctypes.function import _set_class_object
from ._ctypes.object import ObjectBase as _ObjectBase from ._ctypes.object import ObjectBase as _ObjectBase
from ._ctypes.object import _register_object from ._ctypes.object import _register_object
def _new_object(cls):
    """Helper function for pickle"""
    # Bypass __init__: state is restored separately via __setstate__.
    return cls.__new__(cls)
class Object(_ObjectBase): class Object(_ObjectBase):
"""Base class for all tvm's runtime objects.""" """Base class for all tvm's runtime objects."""
pass def __repr__(self):
return _api_internal._format_str(self)
def __dir__(self):
fnames = _api_internal._NodeListAttrNames(self)
size = fnames(-1)
return [fnames(i) for i in range(size)]
def __getattr__(self, name):
try:
return _api_internal._NodeGetAttr(self, name)
except AttributeError:
raise AttributeError(
"%s has no attribute %s" % (str(type(self)), name))
def __hash__(self):
return _api_internal._raw_ptr(self)
def __eq__(self, other):
return self.same_as(other)
def __ne__(self, other):
return not self.__eq__(other)
def __reduce__(self):
cls = type(self)
return (_new_object, (cls, ), self.__getstate__())
def __getstate__(self):
handle = self.handle
if handle is not None:
return {'handle': _api_internal._save_json(self)}
return {'handle': None}
def __setstate__(self, state):
# pylint: disable=assigning-non-slot
handle = state['handle']
if handle is not None:
json_str = handle
other = _api_internal._load_json(json_str)
self.handle = other.handle
other.handle = None
else:
self.handle = None
def same_as(self, other):
"""check object identity equality"""
if not isinstance(other, Object):
return False
return self.__hash__() == other.__hash__()
def register_object(type_key=None): def register_object(type_key=None):
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Common implementation of Node generic related logic""" """Common implementation of object generic related logic"""
# pylint: disable=unused-import # pylint: disable=unused-import
from __future__ import absolute_import from __future__ import absolute_import
...@@ -22,7 +22,7 @@ from numbers import Number, Integral ...@@ -22,7 +22,7 @@ from numbers import Number, Integral
from .. import _api_internal from .. import _api_internal
from .base import string_types from .base import string_types
# Node base class # Object base class
_CLASS_OBJECTS = None _CLASS_OBJECTS = None
def _set_class_objects(cls): def _set_class_objects(cls):
...@@ -47,15 +47,15 @@ def _scalar_type_inference(value): ...@@ -47,15 +47,15 @@ def _scalar_type_inference(value):
return dtype return dtype
class NodeGeneric(object): class ObjectGeneric(object):
"""Base class for all classes that can be converted to node.""" """Base class for all classes that can be converted to object."""
def asnode(self): def asobject(self):
"""Convert value to node""" """Convert value to object"""
raise NotImplementedError() raise NotImplementedError()
def convert_to_node(value): def convert_to_object(value):
"""Convert a python value to corresponding node type. """Convert a python value to corresponding object type.
Parameters Parameters
---------- ----------
...@@ -64,8 +64,8 @@ def convert_to_node(value): ...@@ -64,8 +64,8 @@ def convert_to_node(value):
Returns Returns
------- -------
node : Node obj : Object
The corresponding node value. The corresponding object value.
""" """
if isinstance(value, _CLASS_OBJECTS): if isinstance(value, _CLASS_OBJECTS):
return value return value
...@@ -76,7 +76,7 @@ def convert_to_node(value): ...@@ -76,7 +76,7 @@ def convert_to_node(value):
if isinstance(value, string_types): if isinstance(value, string_types):
return _api_internal._str(value) return _api_internal._str(value)
if isinstance(value, (list, tuple)): if isinstance(value, (list, tuple)):
value = [convert_to_node(x) for x in value] value = [convert_to_object(x) for x in value]
return _api_internal._Array(*value) return _api_internal._Array(*value)
if isinstance(value, dict): if isinstance(value, dict):
vlist = [] vlist = []
...@@ -85,14 +85,14 @@ def convert_to_node(value): ...@@ -85,14 +85,14 @@ def convert_to_node(value):
not isinstance(item[0], string_types)): not isinstance(item[0], string_types)):
raise ValueError("key of map must already been a container type") raise ValueError("key of map must already been a container type")
vlist.append(item[0]) vlist.append(item[0])
vlist.append(convert_to_node(item[1])) vlist.append(convert_to_object(item[1]))
return _api_internal._Map(*vlist) return _api_internal._Map(*vlist)
if isinstance(value, NodeGeneric): if isinstance(value, ObjectGeneric):
return value.asnode() return value.asobject()
if value is None: if value is None:
return None return None
raise ValueError("don't know how to convert type %s to node" % type(value)) raise ValueError("don't know how to convert type %s to object" % type(value))
def const(value, dtype=None): def const(value, dtype=None):
......
...@@ -22,9 +22,8 @@ from numbers import Integral as _Integral ...@@ -22,9 +22,8 @@ from numbers import Integral as _Integral
from ._ffi.base import string_types from ._ffi.base import string_types
from ._ffi.object import register_object, Object from ._ffi.object import register_object, Object
from ._ffi.node import register_node, NodeBase from ._ffi.object import convert_to_object as _convert_to_object
from ._ffi.node import convert_to_node as _convert_to_node from ._ffi.object_generic import _scalar_type_inference
from ._ffi.node_generic import _scalar_type_inference
from ._ffi.function import Function from ._ffi.function import Function
from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
...@@ -111,7 +110,7 @@ def get_env_func(name): ...@@ -111,7 +110,7 @@ def get_env_func(name):
Note Note
---- ----
EnvFunc is a Node wrapper around EnvFunc is a Object wrapper around
global function that can be serialized via its name. global function that can be serialized via its name.
This can be used to serialize function field in the language. This can be used to serialize function field in the language.
""" """
...@@ -127,16 +126,16 @@ def convert(value): ...@@ -127,16 +126,16 @@ def convert(value):
Returns Returns
------- -------
tvm_val : Node or Function tvm_val : Object or Function
Converted value in TVM Converted value in TVM
""" """
if isinstance(value, (Function, NodeBase)): if isinstance(value, (Function, Object)):
return value return value
if callable(value): if callable(value):
return _convert_tvm_func(value) return _convert_tvm_func(value)
return _convert_to_node(value) return _convert_to_object(value)
def load_json(json_str): def load_json(json_str):
...@@ -149,7 +148,7 @@ def load_json(json_str): ...@@ -149,7 +148,7 @@ def load_json(json_str):
Returns Returns
------- -------
node : Node node : Object
The loaded tvm node. The loaded tvm node.
""" """
return _api_internal._load_json(json_str) return _api_internal._load_json(json_str)
...@@ -160,8 +159,8 @@ def save_json(node): ...@@ -160,8 +159,8 @@ def save_json(node):
Parameters Parameters
---------- ----------
node : Node node : Object
A TVM Node object to be saved. A TVM object to be saved.
Returns Returns
------- -------
......
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
"""Arithmetic data structure and utility""" """Arithmetic data structure and utility"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node from ._ffi.object import Object, register_object
from ._ffi.function import _init_api from ._ffi.function import _init_api
from . import _api_internal from . import _api_internal
class IntSet(NodeBase): class IntSet(Object):
"""Represent a set of integer in one dimension.""" """Represent a set of integer in one dimension."""
def is_nothing(self): def is_nothing(self):
"""Whether the set represent nothing""" """Whether the set represent nothing"""
...@@ -32,7 +32,7 @@ class IntSet(NodeBase): ...@@ -32,7 +32,7 @@ class IntSet(NodeBase):
return _api_internal._IntSetIsEverything(self) return _api_internal._IntSetIsEverything(self)
@register_node("arith.IntervalSet") @register_object("arith.IntervalSet")
class IntervalSet(IntSet): class IntervalSet(IntSet):
"""Represent set of continuous interval [min_value, max_value] """Represent set of continuous interval [min_value, max_value]
...@@ -49,16 +49,16 @@ class IntervalSet(IntSet): ...@@ -49,16 +49,16 @@ class IntervalSet(IntSet):
_make_IntervalSet, min_value, max_value) _make_IntervalSet, min_value, max_value)
@register_node("arith.ModularSet") @register_object("arith.ModularSet")
class ModularSet(NodeBase): class ModularSet(Object):
"""Represent range of (coeff * x + base) for x in Z """ """Represent range of (coeff * x + base) for x in Z """
def __init__(self, coeff, base): def __init__(self, coeff, base):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make_ModularSet, coeff, base) _make_ModularSet, coeff, base)
@register_node("arith.ConstIntBound") @register_object("arith.ConstIntBound")
class ConstIntBound(NodeBase): class ConstIntBound(Object):
"""Represent constant integer bound """Represent constant integer bound
Parameters Parameters
...@@ -245,7 +245,7 @@ class Analyzer: ...@@ -245,7 +245,7 @@ class Analyzer:
var : tvm.Var var : tvm.Var
The variable. The variable.
info : tvm.NodeBase info : tvm.Object
Related information. Related information.
override : bool override : bool
......
...@@ -15,13 +15,13 @@ ...@@ -15,13 +15,13 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
""" TVM Attribute module, which is mainly used for defining attributes of operators""" """ TVM Attribute module, which is mainly used for defining attributes of operators"""
from ._ffi.node import NodeBase, register_node as _register_tvm_node from ._ffi.object import Object, register_object
from ._ffi.function import _init_api from ._ffi.function import _init_api
from . import _api_internal from . import _api_internal
@_register_tvm_node @register_object
class Attrs(NodeBase): class Attrs(Object):
"""Attribute node, which is mainly use for defining attributes of relay operators. """Attribute node, which is mainly use for defining attributes of relay operators.
Used by function registered in python side, such as compute, schedule and alter_layout. Used by function registered in python side, such as compute, schedule and alter_layout.
......
...@@ -23,7 +23,7 @@ from __future__ import absolute_import as _abs ...@@ -23,7 +23,7 @@ from __future__ import absolute_import as _abs
import warnings import warnings
from ._ffi.function import Function from ._ffi.function import Function
from ._ffi.node import NodeBase, register_node from ._ffi.object import Object, register_object
from . import api from . import api
from . import _api_internal from . import _api_internal
from . import tensor from . import tensor
...@@ -115,22 +115,22 @@ class DumpIR(object): ...@@ -115,22 +115,22 @@ class DumpIR(object):
DumpIR.scope_level -= 1 DumpIR.scope_level -= 1
@register_node @register_object
class BuildConfig(NodeBase): class BuildConfig(Object):
"""Configuration scope to set a build config option. """Configuration scope to set a build config option.
Note Note
---- ----
This object is backed by node system in C++, with arguments that can be This object is backed by object protocol in C++, with arguments that can be
exchanged between python and C++. exchanged between python and C++.
Do not construct directly, use build_config instead. Do not construct directly, use build_config instead.
The fields that are backed by the C++ node are immutable once an instance The fields that are backed by the C++ object are immutable once an instance
is constructed. See _node_defaults for the fields. is constructed. See _object_defaults for the fields.
""" """
_node_defaults = { _object_defaults = {
"auto_unroll_max_step": 0, "auto_unroll_max_step": 0,
"auto_unroll_max_depth": 8, "auto_unroll_max_depth": 8,
"auto_unroll_max_extent": 0, "auto_unroll_max_extent": 0,
...@@ -191,7 +191,7 @@ class BuildConfig(NodeBase): ...@@ -191,7 +191,7 @@ class BuildConfig(NodeBase):
_api_internal._ExitBuildConfigScope(self) _api_internal._ExitBuildConfigScope(self)
def __setattr__(self, name, value): def __setattr__(self, name, value):
if name in BuildConfig._node_defaults: if name in BuildConfig._object_defaults:
raise AttributeError( raise AttributeError(
"'%s' object cannot set attribute '%s'" % (str(type(self)), name)) "'%s' object cannot set attribute '%s'" % (str(type(self)), name))
return super(BuildConfig, self).__setattr__(name, value) return super(BuildConfig, self).__setattr__(name, value)
...@@ -257,7 +257,7 @@ def build_config(**kwargs): ...@@ -257,7 +257,7 @@ def build_config(**kwargs):
The build configuration The build configuration
""" """
node_args = {k: v if k not in kwargs else kwargs[k] node_args = {k: v if k not in kwargs else kwargs[k]
for k, v in BuildConfig._node_defaults.items()} for k, v in BuildConfig._object_defaults.items()}
config = make.node("BuildConfig", **node_args) config = make.node("BuildConfig", **node_args)
if "add_lower_pass" in kwargs: if "add_lower_pass" in kwargs:
......
...@@ -16,11 +16,11 @@ ...@@ -16,11 +16,11 @@
# under the License. # under the License.
"""Container data structures used in TVM DSL.""" """Container data structures used in TVM DSL."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node from ._ffi.object import Object, register_object
from . import _api_internal from . import _api_internal
@register_node @register_object
class Array(NodeBase): class Array(Object):
"""Array container of TVM. """Array container of TVM.
You do not need to create Array explicitly. You do not need to create Array explicitly.
...@@ -50,8 +50,8 @@ class Array(NodeBase): ...@@ -50,8 +50,8 @@ class Array(NodeBase):
return _api_internal._ArraySize(self) return _api_internal._ArraySize(self)
@register_node @register_object
class EnvFunc(NodeBase): class EnvFunc(Object):
"""Environment function. """Environment function.
This is a global function object that can be serialized by its name. This is a global function object that can be serialized by its name.
...@@ -64,13 +64,13 @@ class EnvFunc(NodeBase): ...@@ -64,13 +64,13 @@ class EnvFunc(NodeBase):
return _api_internal._EnvFuncGetPackedFunc(self) return _api_internal._EnvFuncGetPackedFunc(self)
@register_node @register_object
class Map(NodeBase): class Map(Object):
"""Map container of TVM. """Map container of TVM.
You do not need to create Map explicitly. You do not need to create Map explicitly.
Normally python dict will be converted automaticall to Map during tvm function call. Normally python dict will be converted automaticall to Map during tvm function call.
You can use convert to create a dict[NodeBase-> NodeBase] into a Map You can use convert to create a dict[Object-> Object] into a Map
""" """
def __getitem__(self, k): def __getitem__(self, k):
return _api_internal._MapGetItem(self, k) return _api_internal._MapGetItem(self, k)
...@@ -87,11 +87,11 @@ class Map(NodeBase): ...@@ -87,11 +87,11 @@ class Map(NodeBase):
return _api_internal._MapSize(self) return _api_internal._MapSize(self)
@register_node @register_object
class StrMap(Map): class StrMap(Map):
"""A special map container that has str as key. """A special map container that has str as key.
You can use convert to create a dict[str->NodeBase] into a Map. You can use convert to create a dict[str->Object] into a Map.
""" """
def items(self): def items(self):
"""Get the items from the map""" """Get the items from the map"""
...@@ -99,8 +99,8 @@ class StrMap(Map): ...@@ -99,8 +99,8 @@ class StrMap(Map):
return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)] return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)]
@register_node @register_object
class Range(NodeBase): class Range(Object):
"""Represent a range in TVM. """Represent a range in TVM.
You do not need to create a Range explicitly. You do not need to create a Range explicitly.
...@@ -108,8 +108,8 @@ class Range(NodeBase): ...@@ -108,8 +108,8 @@ class Range(NodeBase):
""" """
@register_node @register_object
class LoweredFunc(NodeBase): class LoweredFunc(Object):
"""Represent a LoweredFunc in TVM.""" """Represent a LoweredFunc in TVM."""
MixedFunc = 0 MixedFunc = 0
HostFunc = 1 HostFunc = 1
......
...@@ -32,7 +32,7 @@ For example, you can use addexp.a to get the left operand of an Add node. ...@@ -32,7 +32,7 @@ For example, you can use addexp.a to get the left operand of an Add node.
""" """
# pylint: disable=missing-docstring # pylint: disable=missing-docstring
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, NodeGeneric, register_node from ._ffi.object import Object, register_object, ObjectGeneric
from ._ffi.runtime_ctypes import TVMType, TypeCode from ._ffi.runtime_ctypes import TVMType, TypeCode
from . import make as _make from . import make as _make
from . import generic as _generic from . import generic as _generic
...@@ -178,11 +178,11 @@ class ExprOp(object): ...@@ -178,11 +178,11 @@ class ExprOp(object):
return _generic.cast(self, dtype) return _generic.cast(self, dtype)
class EqualOp(NodeGeneric, ExprOp): class EqualOp(ObjectGeneric, ExprOp):
"""Deferred equal operator. """Deferred equal operator.
This is used to support sugar that a == b can either This is used to support sugar that a == b can either
mean NodeBase.same_as or NodeBase.equal. mean Object.same_as or Object.equal.
Parameters Parameters
---------- ----------
...@@ -205,16 +205,16 @@ class EqualOp(NodeGeneric, ExprOp): ...@@ -205,16 +205,16 @@ class EqualOp(NodeGeneric, ExprOp):
def __bool__(self): def __bool__(self):
return self.__nonzero__() return self.__nonzero__()
def asnode(self): def asobject(self):
"""Convert node.""" """Convert object."""
return _make._OpEQ(self.a, self.b) return _make._OpEQ(self.a, self.b)
class NotEqualOp(NodeGeneric, ExprOp): class NotEqualOp(ObjectGeneric, ExprOp):
"""Deferred NE operator. """Deferred NE operator.
This is used to support sugar that a != b can either This is used to support sugar that a != b can either
mean not NodeBase.same_as or make.NE. mean not Object.same_as or make.NE.
Parameters Parameters
---------- ----------
...@@ -237,16 +237,16 @@ class NotEqualOp(NodeGeneric, ExprOp): ...@@ -237,16 +237,16 @@ class NotEqualOp(NodeGeneric, ExprOp):
def __bool__(self): def __bool__(self):
return self.__nonzero__() return self.__nonzero__()
def asnode(self): def asobject(self):
"""Convert node.""" """Convert object."""
return _make._OpNE(self.a, self.b) return _make._OpNE(self.a, self.b)
class PrimExpr(ExprOp, NodeBase): class PrimExpr(ExprOp, Object):
"""Base class of all tvm Expressions""" """Base class of all tvm Expressions"""
# In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__ # In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__
# https://docs.python.org/3.1/reference/datamodel.html#object.__hash__ # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__
__hash__ = NodeBase.__hash__ __hash__ = Object.__hash__
class ConstExpr(PrimExpr): class ConstExpr(PrimExpr):
...@@ -261,7 +261,7 @@ class CmpExpr(PrimExpr): ...@@ -261,7 +261,7 @@ class CmpExpr(PrimExpr):
class LogicalExpr(PrimExpr): class LogicalExpr(PrimExpr):
pass pass
@register_node("Variable") @register_object("Variable")
class Var(PrimExpr): class Var(PrimExpr):
"""Symbolic variable. """Symbolic variable.
...@@ -278,7 +278,7 @@ class Var(PrimExpr): ...@@ -278,7 +278,7 @@ class Var(PrimExpr):
_api_internal._Var, name, dtype) _api_internal._Var, name, dtype)
@register_node @register_object
class Reduce(PrimExpr): class Reduce(PrimExpr):
"""Reduce node. """Reduce node.
...@@ -305,7 +305,7 @@ class Reduce(PrimExpr): ...@@ -305,7 +305,7 @@ class Reduce(PrimExpr):
condition, value_index) condition, value_index)
@register_node @register_object
class FloatImm(ConstExpr): class FloatImm(ConstExpr):
"""Float constant. """Float constant.
...@@ -321,7 +321,7 @@ class FloatImm(ConstExpr): ...@@ -321,7 +321,7 @@ class FloatImm(ConstExpr):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.FloatImm, dtype, value) _make.FloatImm, dtype, value)
@register_node @register_object
class IntImm(ConstExpr): class IntImm(ConstExpr):
"""Int constant. """Int constant.
...@@ -341,7 +341,7 @@ class IntImm(ConstExpr): ...@@ -341,7 +341,7 @@ class IntImm(ConstExpr):
return self.value return self.value
@register_node @register_object
class UIntImm(ConstExpr): class UIntImm(ConstExpr):
"""UInt constant. """UInt constant.
...@@ -358,7 +358,7 @@ class UIntImm(ConstExpr): ...@@ -358,7 +358,7 @@ class UIntImm(ConstExpr):
_make.UIntImm, dtype, value) _make.UIntImm, dtype, value)
@register_node @register_object
class StringImm(ConstExpr): class StringImm(ConstExpr):
"""String constant. """String constant.
...@@ -382,7 +382,7 @@ class StringImm(ConstExpr): ...@@ -382,7 +382,7 @@ class StringImm(ConstExpr):
return self.value != other return self.value != other
@register_node @register_object
class Cast(PrimExpr): class Cast(PrimExpr):
"""Cast expression. """Cast expression.
...@@ -399,7 +399,7 @@ class Cast(PrimExpr): ...@@ -399,7 +399,7 @@ class Cast(PrimExpr):
_make.Cast, dtype, value) _make.Cast, dtype, value)
@register_node @register_object
class Add(BinaryOpExpr): class Add(BinaryOpExpr):
"""Add node. """Add node.
...@@ -416,7 +416,7 @@ class Add(BinaryOpExpr): ...@@ -416,7 +416,7 @@ class Add(BinaryOpExpr):
_make.Add, a, b) _make.Add, a, b)
@register_node @register_object
class Sub(BinaryOpExpr): class Sub(BinaryOpExpr):
"""Sub node. """Sub node.
...@@ -433,7 +433,7 @@ class Sub(BinaryOpExpr): ...@@ -433,7 +433,7 @@ class Sub(BinaryOpExpr):
_make.Sub, a, b) _make.Sub, a, b)
@register_node @register_object
class Mul(BinaryOpExpr): class Mul(BinaryOpExpr):
"""Mul node. """Mul node.
...@@ -450,7 +450,7 @@ class Mul(BinaryOpExpr): ...@@ -450,7 +450,7 @@ class Mul(BinaryOpExpr):
_make.Mul, a, b) _make.Mul, a, b)
@register_node @register_object
class Div(BinaryOpExpr): class Div(BinaryOpExpr):
"""Div node. """Div node.
...@@ -467,7 +467,7 @@ class Div(BinaryOpExpr): ...@@ -467,7 +467,7 @@ class Div(BinaryOpExpr):
_make.Div, a, b) _make.Div, a, b)
@register_node @register_object
class Mod(BinaryOpExpr): class Mod(BinaryOpExpr):
"""Mod node. """Mod node.
...@@ -484,7 +484,7 @@ class Mod(BinaryOpExpr): ...@@ -484,7 +484,7 @@ class Mod(BinaryOpExpr):
_make.Mod, a, b) _make.Mod, a, b)
@register_node @register_object
class FloorDiv(BinaryOpExpr): class FloorDiv(BinaryOpExpr):
"""FloorDiv node. """FloorDiv node.
...@@ -501,7 +501,7 @@ class FloorDiv(BinaryOpExpr): ...@@ -501,7 +501,7 @@ class FloorDiv(BinaryOpExpr):
_make.FloorDiv, a, b) _make.FloorDiv, a, b)
@register_node @register_object
class FloorMod(BinaryOpExpr): class FloorMod(BinaryOpExpr):
"""FloorMod node. """FloorMod node.
...@@ -518,7 +518,7 @@ class FloorMod(BinaryOpExpr): ...@@ -518,7 +518,7 @@ class FloorMod(BinaryOpExpr):
_make.FloorMod, a, b) _make.FloorMod, a, b)
@register_node @register_object
class Min(BinaryOpExpr): class Min(BinaryOpExpr):
"""Min node. """Min node.
...@@ -535,7 +535,7 @@ class Min(BinaryOpExpr): ...@@ -535,7 +535,7 @@ class Min(BinaryOpExpr):
_make.Min, a, b) _make.Min, a, b)
@register_node @register_object
class Max(BinaryOpExpr): class Max(BinaryOpExpr):
"""Max node. """Max node.
...@@ -552,7 +552,7 @@ class Max(BinaryOpExpr): ...@@ -552,7 +552,7 @@ class Max(BinaryOpExpr):
_make.Max, a, b) _make.Max, a, b)
@register_node @register_object
class EQ(CmpExpr): class EQ(CmpExpr):
"""EQ node. """EQ node.
...@@ -569,7 +569,7 @@ class EQ(CmpExpr): ...@@ -569,7 +569,7 @@ class EQ(CmpExpr):
_make.EQ, a, b) _make.EQ, a, b)
@register_node @register_object
class NE(CmpExpr): class NE(CmpExpr):
"""NE node. """NE node.
...@@ -586,7 +586,7 @@ class NE(CmpExpr): ...@@ -586,7 +586,7 @@ class NE(CmpExpr):
_make.NE, a, b) _make.NE, a, b)
@register_node @register_object
class LT(CmpExpr): class LT(CmpExpr):
"""LT node. """LT node.
...@@ -603,7 +603,7 @@ class LT(CmpExpr): ...@@ -603,7 +603,7 @@ class LT(CmpExpr):
_make.LT, a, b) _make.LT, a, b)
@register_node @register_object
class LE(CmpExpr): class LE(CmpExpr):
"""LE node. """LE node.
...@@ -620,7 +620,7 @@ class LE(CmpExpr): ...@@ -620,7 +620,7 @@ class LE(CmpExpr):
_make.LE, a, b) _make.LE, a, b)
@register_node @register_object
class GT(CmpExpr): class GT(CmpExpr):
"""GT node. """GT node.
...@@ -637,7 +637,7 @@ class GT(CmpExpr): ...@@ -637,7 +637,7 @@ class GT(CmpExpr):
_make.GT, a, b) _make.GT, a, b)
@register_node @register_object
class GE(CmpExpr): class GE(CmpExpr):
"""GE node. """GE node.
...@@ -654,7 +654,7 @@ class GE(CmpExpr): ...@@ -654,7 +654,7 @@ class GE(CmpExpr):
_make.GE, a, b) _make.GE, a, b)
@register_node @register_object
class And(LogicalExpr): class And(LogicalExpr):
"""And node. """And node.
...@@ -671,7 +671,7 @@ class And(LogicalExpr): ...@@ -671,7 +671,7 @@ class And(LogicalExpr):
_make.And, a, b) _make.And, a, b)
@register_node @register_object
class Or(LogicalExpr): class Or(LogicalExpr):
"""Or node. """Or node.
...@@ -688,7 +688,7 @@ class Or(LogicalExpr): ...@@ -688,7 +688,7 @@ class Or(LogicalExpr):
_make.Or, a, b) _make.Or, a, b)
@register_node @register_object
class Not(LogicalExpr): class Not(LogicalExpr):
"""Not node. """Not node.
...@@ -702,7 +702,7 @@ class Not(LogicalExpr): ...@@ -702,7 +702,7 @@ class Not(LogicalExpr):
_make.Not, a) _make.Not, a)
@register_node @register_object
class Select(PrimExpr): class Select(PrimExpr):
"""Select node. """Select node.
...@@ -730,7 +730,7 @@ class Select(PrimExpr): ...@@ -730,7 +730,7 @@ class Select(PrimExpr):
_make.Select, condition, true_value, false_value) _make.Select, condition, true_value, false_value)
@register_node @register_object
class Load(PrimExpr): class Load(PrimExpr):
"""Load node. """Load node.
...@@ -753,7 +753,7 @@ class Load(PrimExpr): ...@@ -753,7 +753,7 @@ class Load(PrimExpr):
_make.Load, dtype, buffer_var, index, predicate) _make.Load, dtype, buffer_var, index, predicate)
@register_node @register_object
class Ramp(PrimExpr): class Ramp(PrimExpr):
"""Ramp node. """Ramp node.
...@@ -773,7 +773,7 @@ class Ramp(PrimExpr): ...@@ -773,7 +773,7 @@ class Ramp(PrimExpr):
_make.Ramp, base, stride, lanes) _make.Ramp, base, stride, lanes)
@register_node @register_object
class Broadcast(PrimExpr): class Broadcast(PrimExpr):
"""Broadcast node. """Broadcast node.
...@@ -790,7 +790,7 @@ class Broadcast(PrimExpr): ...@@ -790,7 +790,7 @@ class Broadcast(PrimExpr):
_make.Broadcast, value, lanes) _make.Broadcast, value, lanes)
@register_node @register_object
class Shuffle(PrimExpr): class Shuffle(PrimExpr):
"""Shuffle node. """Shuffle node.
...@@ -807,7 +807,7 @@ class Shuffle(PrimExpr): ...@@ -807,7 +807,7 @@ class Shuffle(PrimExpr):
_make.Shuffle, vectors, indices) _make.Shuffle, vectors, indices)
@register_node @register_object
class Call(PrimExpr): class Call(PrimExpr):
"""Call node. """Call node.
...@@ -842,7 +842,7 @@ class Call(PrimExpr): ...@@ -842,7 +842,7 @@ class Call(PrimExpr):
_make.Call, dtype, name, args, call_type, func, value_index) _make.Call, dtype, name, args, call_type, func, value_index)
@register_node @register_object
class Let(PrimExpr): class Let(PrimExpr):
"""Let node. """Let node.
......
...@@ -24,7 +24,7 @@ from . import make as _make ...@@ -24,7 +24,7 @@ from . import make as _make
from . import ir_pass as _pass from . import ir_pass as _pass
from . import container as _container from . import container as _container
from ._ffi.base import string_types from ._ffi.base import string_types
from ._ffi.node import NodeGeneric from ._ffi.object import ObjectGeneric
from ._ffi.runtime_ctypes import TVMType from ._ffi.runtime_ctypes import TVMType
from .expr import Call as _Call from .expr import Call as _Call
...@@ -41,7 +41,7 @@ class WithScope(object): ...@@ -41,7 +41,7 @@ class WithScope(object):
self._exit_cb() self._exit_cb()
class BufferVar(NodeGeneric): class BufferVar(ObjectGeneric):
"""Buffer variable with content type, makes load store easily. """Buffer variable with content type, makes load store easily.
Do not create it directly, create use IRBuilder. Do not create it directly, create use IRBuilder.
...@@ -70,7 +70,7 @@ class BufferVar(NodeGeneric): ...@@ -70,7 +70,7 @@ class BufferVar(NodeGeneric):
self._buffer_var = buffer_var self._buffer_var = buffer_var
self._content_type = content_type self._content_type = content_type
def asnode(self): def asobject(self):
return self._buffer_var return self._buffer_var
@property @property
......
...@@ -20,6 +20,4 @@ Normally user do not need to touch this api. ...@@ -20,6 +20,4 @@ Normally user do not need to touch this api.
""" """
# pylint: disable=unused-import # pylint: disable=unused-import
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node from ._ffi.object import Object, register_object
Node = NodeBase
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# under the License. # under the License.
from typing import Union, Tuple, Dict, List from typing import Union, Tuple, Dict, List
from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId from relay.ir import GlobalId, OperatorId, Item, Object, Span, FileId
from relay.ir import ShapeExtension, Operator, Defn from relay.ir import ShapeExtension, Operator, Defn
class Module(NodeBase): ... class Module(Object): ...
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""Algebraic data types in Relay.""" """Algebraic data types in Relay."""
from .base import RelayNode, register_relay_node, NodeBase from .base import RelayNode, register_relay_node, Object
from . import _make from . import _make
from .ty import Type from .ty import Type
from .expr import Expr, Call from .expr import Expr, Call
...@@ -184,7 +184,7 @@ class TypeData(Type): ...@@ -184,7 +184,7 @@ class TypeData(Type):
@register_relay_node @register_relay_node
class Clause(NodeBase): class Clause(Object):
"""Clause for pattern matching in Relay.""" """Clause for pattern matching in Relay."""
def __init__(self, lhs, rhs): def __init__(self, lhs, rhs):
......
...@@ -17,19 +17,19 @@ ...@@ -17,19 +17,19 @@
"""Backend code generation engine.""" """Backend code generation engine."""
from __future__ import absolute_import from __future__ import absolute_import
from ..base import register_relay_node, NodeBase from ..base import register_relay_node, Object
from ... import target as _target from ... import target as _target
from .. import expr as _expr from .. import expr as _expr
from . import _backend from . import _backend
@register_relay_node @register_relay_node
class CachedFunc(NodeBase): class CachedFunc(Object):
"""Low-level tensor function to back a relay primitive function. """Low-level tensor function to back a relay primitive function.
""" """
@register_relay_node @register_relay_node
class CCacheKey(NodeBase): class CCacheKey(Object):
"""Key in the CompileEngine. """Key in the CompileEngine.
Parameters Parameters
...@@ -46,7 +46,7 @@ class CCacheKey(NodeBase): ...@@ -46,7 +46,7 @@ class CCacheKey(NodeBase):
@register_relay_node @register_relay_node
class CCacheValue(NodeBase): class CCacheValue(Object):
"""Value in the CompileEngine, including usage statistics. """Value in the CompileEngine, including usage statistics.
""" """
...@@ -64,7 +64,7 @@ def _get_cache_key(source_func, target): ...@@ -64,7 +64,7 @@ def _get_cache_key(source_func, target):
@register_relay_node @register_relay_node
class CompileEngine(NodeBase): class CompileEngine(Object):
"""CompileEngine to get lowered code. """CompileEngine to get lowered code.
""" """
def __init__(self): def __init__(self):
......
...@@ -23,27 +23,13 @@ import numpy as np ...@@ -23,27 +23,13 @@ import numpy as np
from . import _backend from . import _backend
from .. import _make, analysis, transform from .. import _make, analysis, transform
from .. import module from .. import module
from ... import register_func, nd from ... import nd
from ..base import NodeBase, register_relay_node from ..base import Object, register_relay_node
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
from ..scope_builder import ScopeBuilder from ..scope_builder import ScopeBuilder
from . import _vm
class Value(NodeBase):
"""Base class of all values.
"""
@staticmethod
@register_func("relay.from_scalar")
def from_scalar(value, dtype=None):
"""Convert a Python scalar to a Relay scalar."""
return TensorValue(const(value, dtype).data)
def to_vm(self):
return _vm._ValueToVM(self)
@register_relay_node @register_relay_node
class TupleValue(Value): class TupleValue(Object):
"""A tuple value produced by the interpreter.""" """A tuple value produced by the interpreter."""
def __init__(self, *fields): def __init__(self, *fields):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
...@@ -68,60 +54,32 @@ class TupleValue(Value): ...@@ -68,60 +54,32 @@ class TupleValue(Value):
@register_relay_node @register_relay_node
class Closure(Value): class Closure(Object):
"""A closure produced by the interpreter.""" """A closure produced by the interpreter."""
@register_relay_node @register_relay_node
class RecClosure(Value): class RecClosure(Object):
"""A recursive closure produced by the interpreter.""" """A recursive closure produced by the interpreter."""
@register_relay_node @register_relay_node
class ConstructorValue(Value): class ConstructorValue(Object):
def __init__(self, tag, fields, constructor): def __init__(self, tag, fields, constructor):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.ConstructorValue, tag, fields, constructor) _make.ConstructorValue, tag, fields, constructor)
@register_relay_node @register_relay_node
class TensorValue(Value): class RefValue(Object):
"""A Tensor value produced by the interpreter."""
def __init__(self, data):
"""Allocate a new TensorValue and copy the data from `array` into
the new array.
"""
if isinstance(data, np.ndarray):
data = nd.array(data)
self.__init_handle_by_constructor__(
_make.TensorValue, data)
def asnumpy(self):
"""Convert a Relay TensorValue into a numpy.ndarray."""
return self.data.asnumpy()
def __eq__(self, other):
return self.data == other.data
def __repr__(self):
return repr(self.data)
def __str__(self):
return str(self.data)
@register_relay_node
class RefValue(Value):
def __init__(self, value): def __init__(self, value):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.RefValue, value) _make.RefValue, value)
def _arg_to_ast(mod, arg): def _arg_to_ast(mod, arg):
if isinstance(arg, TensorValue): if isinstance(arg, nd.NDArray):
return Constant(arg.data.copyto(nd.cpu(0))) return Constant(arg.copyto(nd.cpu(0)))
elif isinstance(arg, TupleValue): elif isinstance(arg, TupleValue):
return Tuple([_arg_to_ast(mod, field) for field in arg.fields]) return Tuple([_arg_to_ast(mod, field) for field in arg.fields])
elif isinstance(arg, tuple): elif isinstance(arg, tuple):
...@@ -231,7 +189,7 @@ class Executor(object): ...@@ -231,7 +189,7 @@ class Executor(object):
Returns Returns
------- -------
val : Union[function, Value] val : Union[function, Object]
The evaluation result. The evaluation result.
""" """
if binds: if binds:
......
...@@ -31,16 +31,18 @@ from . import _vm ...@@ -31,16 +31,18 @@ from . import _vm
from . import vmobj as _obj from . import vmobj as _obj
from .interpreter import Executor from .interpreter import Executor
Tensor = _obj.Tensor
ADT = _obj.ADT ADT = _obj.ADT
def _convert(arg, cargs): def _convert(arg, cargs):
if isinstance(arg, _expr.Constant): if isinstance(arg, _expr.Constant):
cargs.append(_obj.Tensor(arg.data)) cargs.append(arg.data)
elif isinstance(arg, _obj.Object): elif isinstance(arg, _obj.Object):
cargs.append(arg) cargs.append(arg)
elif isinstance(arg, (np.ndarray, tvm.nd.NDArray)): elif isinstance(arg, np.ndarray):
cargs.append(_obj.Tensor(arg)) nd_arr = tvm.nd.array(arg, ctx=tvm.cpu(0))
cargs.append(nd_arr)
elif isinstance(arg, tvm.nd.NDArray):
cargs.append(arg)
elif isinstance(arg, (tuple, list)): elif isinstance(arg, (tuple, list)):
field_args = [] field_args = []
for field in arg: for field in arg:
...@@ -48,7 +50,7 @@ def _convert(arg, cargs): ...@@ -48,7 +50,7 @@ def _convert(arg, cargs):
cargs.append(_obj.tuple_object(field_args)) cargs.append(_obj.tuple_object(field_args))
elif isinstance(arg, (_base.numeric_types, bool)): elif isinstance(arg, (_base.numeric_types, bool)):
dtype = "int32" if isinstance(arg, (int, bool)) else "float32" dtype = "int32" if isinstance(arg, (int, bool)) else "float32"
value = _obj.Tensor(np.array(arg, dtype=dtype)) value = tvm.nd.array(np.array(arg, dtype=dtype), ctx=tvm.cpu(0))
cargs.append(value) cargs.append(value)
else: else:
raise TypeError("Unsupported type: %s" % (type(arg))) raise TypeError("Unsupported type: %s" % (type(arg)))
......
...@@ -16,51 +16,12 @@ ...@@ -16,51 +16,12 @@
# under the License. # under the License.
"""TVM Runtime Object API.""" """TVM Runtime Object API."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import numpy as _np
from tvm._ffi.object import Object, register_object, getitem_helper from tvm._ffi.object import Object, register_object, getitem_helper
from tvm import ndarray as _nd from tvm import ndarray as _nd
from . import _vmobj from . import _vmobj
@register_object("vm.Tensor")
class Tensor(Object):
"""Tensor object.
Parameters
----------
arr : numpy.ndarray or tvm.nd.NDArray
The source array.
ctx : TVMContext, optional
The device context to create the array
"""
def __init__(self, arr, ctx=None):
if isinstance(arr, _np.ndarray):
ctx = ctx if ctx else _nd.cpu(0)
self.__init_handle_by_constructor__(
_vmobj.Tensor, _nd.array(arr, ctx=ctx))
elif isinstance(arr, _nd.NDArray):
self.__init_handle_by_constructor__(
_vmobj.Tensor, arr)
else:
raise RuntimeError("Unsupported type for tensor object.")
@property
def data(self):
return _vmobj.GetTensorData(self)
def asnumpy(self):
"""Convert data to numpy array
Returns
-------
np_arr : numpy.ndarray
The corresponding numpy array.
"""
return self.data.asnumpy()
@register_object("vm.ADT") @register_object("vm.ADT")
class ADT(Object): class ADT(Object):
"""Algebatic data type(ADT) object. """Algebatic data type(ADT) object.
...@@ -75,7 +36,8 @@ class ADT(Object): ...@@ -75,7 +36,8 @@ class ADT(Object):
""" """
def __init__(self, tag, fields): def __init__(self, tag, fields):
for f in fields: for f in fields:
assert isinstance(f, Object) assert isinstance(f, (Object, _nd.NDArray)), "Expect object or "
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_vmobj.ADT, tag, *fields) _vmobj.ADT, tag, *fields)
...@@ -105,5 +67,6 @@ def tuple_object(fields): ...@@ -105,5 +67,6 @@ def tuple_object(fields):
The created object. The created object.
""" """
for f in fields: for f in fields:
assert isinstance(f, Object) assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm "
"NDArray type, but received : {0}".format(type(f))
return _vmobj.Tuple(*fields) return _vmobj.Tuple(*fields)
...@@ -17,12 +17,13 @@ ...@@ -17,12 +17,13 @@
# pylint: disable=no-else-return, unidiomatic-typecheck # pylint: disable=no-else-return, unidiomatic-typecheck
"""The base node types for the Relay language.""" """The base node types for the Relay language."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from .._ffi.node import NodeBase, register_node as _register_tvm_node from .._ffi.object import register_object as _register_tvm_node
from .._ffi.object import Object
from . import _make from . import _make
from . import _expr from . import _expr
from . import _base from . import _base
NodeBase = NodeBase Object = Object
def register_relay_node(type_key=None): def register_relay_node(type_key=None):
"""Register a Relay node type. """Register a Relay node type.
...@@ -52,7 +53,7 @@ def register_relay_attr_node(type_key=None): ...@@ -52,7 +53,7 @@ def register_relay_attr_node(type_key=None):
return _register_tvm_node(type_key) return _register_tvm_node(type_key)
class RelayNode(NodeBase): class RelayNode(Object):
"""Base class of all Relay nodes.""" """Base class of all Relay nodes."""
def astext(self, show_meta_data=True, annotate=None): def astext(self, show_meta_data=True, annotate=None):
"""Get the text format of the expression. """Get the text format of the expression.
...@@ -102,7 +103,7 @@ class SourceName(RelayNode): ...@@ -102,7 +103,7 @@ class SourceName(RelayNode):
self.__init_handle_by_constructor__(_make.SourceName, name) self.__init_handle_by_constructor__(_make.SourceName, name)
@register_relay_node @register_relay_node
class Id(NodeBase): class Id(Object):
"""Unique identifier(name) used in Var. """Unique identifier(name) used in Var.
Guaranteed to be stable across all passes. Guaranteed to be stable across all passes.
""" """
......
...@@ -17,12 +17,12 @@ ...@@ -17,12 +17,12 @@
from typing import List from typing import List
import tvm import tvm
from .base import Span, NodeBase from .base import Span, Object
from .ty import Type, TypeParam from .ty import Type, TypeParam
from ._analysis import _get_checked_type from ._analysis import _get_checked_type
class Expr(NodeBase): class Expr(Object):
def checked_type(self): def checked_type(self):
... ...
......
...@@ -22,7 +22,7 @@ from ._calibrate import calibrate ...@@ -22,7 +22,7 @@ from ._calibrate import calibrate
from .. import expr as _expr from .. import expr as _expr
from .. import transform as _transform from .. import transform as _transform
from ... import make as _make from ... import make as _make
from ..base import NodeBase, register_relay_node from ..base import Object, register_relay_node
class QAnnotateKind(object): class QAnnotateKind(object):
...@@ -53,7 +53,7 @@ def _forward_op(ref_call, args): ...@@ -53,7 +53,7 @@ def _forward_op(ref_call, args):
@register_relay_node("relay.quantize.QConfig") @register_relay_node("relay.quantize.QConfig")
class QConfig(NodeBase): class QConfig(Object):
"""Configure the quantization behavior by setting config variables. """Configure the quantization behavior by setting config variables.
Note Note
......
...@@ -32,15 +32,16 @@ OUTPUT_VAR_NAME = '_py_out' ...@@ -32,15 +32,16 @@ OUTPUT_VAR_NAME = '_py_out'
# import numpy # import numpy
# import tvm # import tvm
# from tvm import relay # from tvm import relay
# from tvm.relay.backend.interpreter import RefValue, TupleValue, TensorValue, ConstructorValue # from tvm import nd
# from tvm.relay.backend.interpreter import RefValue, TupleValue, ConstructorValue
PROLOGUE = [ PROLOGUE = [
ast.Import([alias('numpy', None)]), ast.Import([alias('numpy', None)]),
ast.Import([alias('tvm', None)]), ast.Import([alias('tvm', None)]),
ast.ImportFrom('tvm', [alias('relay', None)], 0), ast.ImportFrom('tvm', [alias('relay', None)], 0),
ast.ImportFrom('tvm', [alias('nd', None)], 0),
ast.ImportFrom('tvm.relay.backend.interpreter', ast.ImportFrom('tvm.relay.backend.interpreter',
[alias('RefValue', None), [alias('RefValue', None),
alias('TupleValue', None), alias('TupleValue', None),
alias('TensorValue', None),
alias('ConstructorValue', None)], alias('ConstructorValue', None)],
0) 0)
] ]
...@@ -245,7 +246,7 @@ class PythonConverter(ExprFunctor): ...@@ -245,7 +246,7 @@ class PythonConverter(ExprFunctor):
a tensor or tuple (returns list of inputs to the lowered op call)""" a tensor or tuple (returns list of inputs to the lowered op call)"""
# equivalent: input.data # equivalent: input.data
if isinstance(arg_type, relay.TensorType): if isinstance(arg_type, relay.TensorType):
return [ast.Attribute(py_input, 'data', Load())] return [py_input]
assert isinstance(arg_type, relay.TupleType) assert isinstance(arg_type, relay.TupleType)
# convert each input.fields[i] # convert each input.fields[i]
ret = [] ret = []
...@@ -265,15 +266,13 @@ class PythonConverter(ExprFunctor): ...@@ -265,15 +266,13 @@ class PythonConverter(ExprFunctor):
output_var_name = self.generate_var_name('_out') output_var_name = self.generate_var_name('_out')
output_var = Name(output_var_name, Load()) output_var = Name(output_var_name, Load())
shape = ast.Tuple([Num(dim) for dim in ret_type.concrete_shape], Load()) shape = ast.Tuple([Num(dim) for dim in ret_type.concrete_shape], Load())
# create a new TensorValue of the right shape and dtype # create a new NDArray of the right shape and dtype
assign_output = Assign( assign_output = Assign(
[Name(output_var_name, Store())], [Name(output_var_name, Store())],
self.create_call('TensorValue', [ self.create_call('nd.array', [
self.create_call('numpy.empty', [shape, Str(ret_type.dtype)]) self.create_call('numpy.empty', [shape, Str(ret_type.dtype)])
])) ]))
# we pass the data field as an argument return ([assign_output], [output_var], output_var)
extra_arg = ast.Attribute(output_var, 'data', Load())
return ([assign_output], [extra_arg], output_var)
assert isinstance(ret_type, relay.TupleType) assert isinstance(ret_type, relay.TupleType)
assignments = [] assignments = []
extra_args = [] extra_args = []
...@@ -459,7 +458,7 @@ class PythonConverter(ExprFunctor): ...@@ -459,7 +458,7 @@ class PythonConverter(ExprFunctor):
true_body, true_defs = self.visit(if_block.true_branch) true_body, true_defs = self.visit(if_block.true_branch)
false_body, false_defs = self.visit(if_block.false_branch) false_body, false_defs = self.visit(if_block.false_branch)
# need to get the value out of a TensorValue to check the condition # need to get the value out of a NDArray to check the condition
# equvialent to: val.asnumpy() # equvialent to: val.asnumpy()
cond_check = ast.Call(ast.Attribute(cond_body, 'asnumpy', Load()), [], []) cond_check = ast.Call(ast.Attribute(cond_body, 'asnumpy', Load()), [], [])
ret = ast.IfExp(cond_check, true_body, false_body) ret = ast.IfExp(cond_check, true_body, false_body)
...@@ -474,7 +473,7 @@ class PythonConverter(ExprFunctor): ...@@ -474,7 +473,7 @@ class PythonConverter(ExprFunctor):
const_expr = ast.Call(ast.Attribute(Name('numpy', Load()), 'array', Load()), const_expr = ast.Call(ast.Attribute(Name('numpy', Load()), 'array', Load()),
[self.parse_numpy_array(value)], [self.parse_numpy_array(value)],
[ast.keyword('dtype', Str(constant.checked_type.dtype))]) [ast.keyword('dtype', Str(constant.checked_type.dtype))])
return (self.create_call('TensorValue', [const_expr]), []) return (self.create_call('nd.array', [const_expr]), [])
def visit_function(self, func: Expr): def visit_function(self, func: Expr):
......
...@@ -16,14 +16,14 @@ ...@@ -16,14 +16,14 @@
# under the License. # under the License.
import tvm import tvm
from .base import NodeBase from .base import Object
class PassContext(NodeBase): class PassContext(Object):
def __init__(self): def __init__(self):
... ...
class PassInfo(NodeBase): class PassInfo(Object):
name = ... # type: str name = ... # type: str
opt_level = ... # type: int opt_level = ... # type: int
required = ... # type: list required = ... # type: list
...@@ -32,7 +32,7 @@ class PassInfo(NodeBase): ...@@ -32,7 +32,7 @@ class PassInfo(NodeBase):
# type: (str, int, list) -> None # type: (str, int, list) -> None
class Pass(NodeBase): class Pass(Object):
def __init__(self): def __init__(self):
... ...
......
...@@ -18,11 +18,11 @@ ...@@ -18,11 +18,11 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The type nodes of the Relay language.""" """The type nodes of the Relay language."""
from enum import IntEnum from enum import IntEnum
from .base import NodeBase, register_relay_node from .base import Object, register_relay_node
from . import _make from . import _make
class Type(NodeBase): class Type(Object):
"""The base type for all Relay types.""" """The base type for all Relay types."""
def __eq__(self, other): def __eq__(self, other):
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
"""The computation schedule api of TVM.""" """The computation schedule api of TVM."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ffi.base import string_types from ._ffi.base import string_types
from ._ffi.node import NodeBase, register_node from ._ffi.object import Object, register_object
from ._ffi.node import convert_to_node as _convert_to_node from ._ffi.object import convert_to_object as _convert_to_object
from ._ffi.function import _init_api, Function from ._ffi.function import _init_api, Function
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
from . import _api_internal from . import _api_internal
...@@ -27,7 +27,7 @@ from . import expr as _expr ...@@ -27,7 +27,7 @@ from . import expr as _expr
from . import container as _container from . import container as _container
def convert(value): def convert(value):
"""Convert value to TVM node or function. """Convert value to TVM object or function.
Parameters Parameters
---------- ----------
...@@ -35,19 +35,19 @@ def convert(value): ...@@ -35,19 +35,19 @@ def convert(value):
Returns Returns
------- -------
tvm_val : Node or Function tvm_val : Object or Function
Converted value in TVM Converted value in TVM
""" """
if isinstance(value, (Function, NodeBase)): if isinstance(value, (Function, Object)):
return value return value
if callable(value): if callable(value):
return _convert_tvm_func(value) return _convert_tvm_func(value)
return _convert_to_node(value) return _convert_to_object(value)
@register_node @register_object
class Buffer(NodeBase): class Buffer(Object):
"""Symbolic data buffer in TVM. """Symbolic data buffer in TVM.
Buffer provide a way to represent data layout Buffer provide a way to represent data layout
...@@ -156,23 +156,23 @@ class Buffer(NodeBase): ...@@ -156,23 +156,23 @@ class Buffer(NodeBase):
return _api_internal._BufferVStore(self, begin, value) return _api_internal._BufferVStore(self, begin, value)
@register_node @register_object
class Split(NodeBase): class Split(Object):
"""Split operation on axis.""" """Split operation on axis."""
@register_node @register_object
class Fuse(NodeBase): class Fuse(Object):
"""Fuse operation on axis.""" """Fuse operation on axis."""
@register_node @register_object
class Singleton(NodeBase): class Singleton(Object):
"""Singleton axis.""" """Singleton axis."""
@register_node @register_object
class IterVar(NodeBase, _expr.ExprOp): class IterVar(Object, _expr.ExprOp):
"""Represent iteration variable. """Represent iteration variable.
IterVar is normally created by Operation, to represent IterVar is normally created by Operation, to represent
...@@ -214,8 +214,8 @@ def create_schedule(ops): ...@@ -214,8 +214,8 @@ def create_schedule(ops):
return _api_internal._CreateSchedule(ops) return _api_internal._CreateSchedule(ops)
@register_node @register_object
class Schedule(NodeBase): class Schedule(Object):
"""Schedule for all the stages.""" """Schedule for all the stages."""
def __getitem__(self, k): def __getitem__(self, k):
if isinstance(k, _tensor.Tensor): if isinstance(k, _tensor.Tensor):
...@@ -348,8 +348,8 @@ class Schedule(NodeBase): ...@@ -348,8 +348,8 @@ class Schedule(NodeBase):
return factored[0] if len(factored) == 1 else factored return factored[0] if len(factored) == 1 else factored
@register_node @register_object
class Stage(NodeBase): class Stage(Object):
"""A Stage represents schedule for one operation.""" """A Stage represents schedule for one operation."""
def split(self, parent, factor=None, nparts=None): def split(self, parent, factor=None, nparts=None):
"""Split the stage either by factor providing outer scope, or both """Split the stage either by factor providing outer scope, or both
......
...@@ -30,14 +30,14 @@ Each statement node have subfields that can be visited from python side. ...@@ -30,14 +30,14 @@ Each statement node have subfields that can be visited from python side.
assert(st.buffer_var == a) assert(st.buffer_var == a)
""" """
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node from ._ffi.object import Object, register_object
from . import make as _make from . import make as _make
class Stmt(NodeBase): class Stmt(Object):
pass pass
@register_node @register_object
class LetStmt(Stmt): class LetStmt(Stmt):
"""LetStmt node. """LetStmt node.
...@@ -57,7 +57,7 @@ class LetStmt(Stmt): ...@@ -57,7 +57,7 @@ class LetStmt(Stmt):
_make.LetStmt, var, value, body) _make.LetStmt, var, value, body)
@register_node @register_object
class AssertStmt(Stmt): class AssertStmt(Stmt):
"""AssertStmt node. """AssertStmt node.
...@@ -77,7 +77,7 @@ class AssertStmt(Stmt): ...@@ -77,7 +77,7 @@ class AssertStmt(Stmt):
_make.AssertStmt, condition, message, body) _make.AssertStmt, condition, message, body)
@register_node @register_object
class ProducerConsumer(Stmt): class ProducerConsumer(Stmt):
"""ProducerConsumer node. """ProducerConsumer node.
...@@ -97,7 +97,7 @@ class ProducerConsumer(Stmt): ...@@ -97,7 +97,7 @@ class ProducerConsumer(Stmt):
_make.ProducerConsumer, func, is_producer, body) _make.ProducerConsumer, func, is_producer, body)
@register_node @register_object
class For(Stmt): class For(Stmt):
"""For node. """For node.
...@@ -137,7 +137,7 @@ class For(Stmt): ...@@ -137,7 +137,7 @@ class For(Stmt):
for_type, device_api, body) for_type, device_api, body)
@register_node @register_object
class Store(Stmt): class Store(Stmt):
"""Store node. """Store node.
...@@ -160,7 +160,7 @@ class Store(Stmt): ...@@ -160,7 +160,7 @@ class Store(Stmt):
_make.Store, buffer_var, value, index, predicate) _make.Store, buffer_var, value, index, predicate)
@register_node @register_object
class Provide(Stmt): class Provide(Stmt):
"""Provide node. """Provide node.
...@@ -183,7 +183,7 @@ class Provide(Stmt): ...@@ -183,7 +183,7 @@ class Provide(Stmt):
_make.Provide, func, value_index, value, args) _make.Provide, func, value_index, value, args)
@register_node @register_object
class Allocate(Stmt): class Allocate(Stmt):
"""Allocate node. """Allocate node.
...@@ -215,7 +215,7 @@ class Allocate(Stmt): ...@@ -215,7 +215,7 @@ class Allocate(Stmt):
extents, condition, body) extents, condition, body)
@register_node @register_object
class AttrStmt(Stmt): class AttrStmt(Stmt):
"""AttrStmt node. """AttrStmt node.
...@@ -238,7 +238,7 @@ class AttrStmt(Stmt): ...@@ -238,7 +238,7 @@ class AttrStmt(Stmt):
_make.AttrStmt, node, attr_key, value, body) _make.AttrStmt, node, attr_key, value, body)
@register_node @register_object
class Free(Stmt): class Free(Stmt):
"""Free node. """Free node.
...@@ -252,7 +252,7 @@ class Free(Stmt): ...@@ -252,7 +252,7 @@ class Free(Stmt):
_make.Free, buffer_var) _make.Free, buffer_var)
@register_node @register_object
class Realize(Stmt): class Realize(Stmt):
"""Realize node. """Realize node.
...@@ -288,7 +288,7 @@ class Realize(Stmt): ...@@ -288,7 +288,7 @@ class Realize(Stmt):
bounds, condition, body) bounds, condition, body)
@register_node @register_object
class SeqStmt(Stmt): class SeqStmt(Stmt):
"""Sequence of statements. """Sequence of statements.
...@@ -308,7 +308,7 @@ class SeqStmt(Stmt): ...@@ -308,7 +308,7 @@ class SeqStmt(Stmt):
return len(self.seq) return len(self.seq)
@register_node @register_object
class IfThenElse(Stmt): class IfThenElse(Stmt):
"""IfThenElse node. """IfThenElse node.
...@@ -328,7 +328,7 @@ class IfThenElse(Stmt): ...@@ -328,7 +328,7 @@ class IfThenElse(Stmt):
_make.IfThenElse, condition, then_case, else_case) _make.IfThenElse, condition, then_case, else_case)
@register_node @register_object
class Evaluate(Stmt): class Evaluate(Stmt):
"""Evaluate node. """Evaluate node.
...@@ -342,7 +342,7 @@ class Evaluate(Stmt): ...@@ -342,7 +342,7 @@ class Evaluate(Stmt):
_make.Evaluate, value) _make.Evaluate, value)
@register_node @register_object
class Prefetch(Stmt): class Prefetch(Stmt):
"""Prefetch node. """Prefetch node.
......
...@@ -59,7 +59,7 @@ from __future__ import absolute_import ...@@ -59,7 +59,7 @@ from __future__ import absolute_import
import warnings import warnings
from ._ffi.base import _LIB_NAME from ._ffi.base import _LIB_NAME
from ._ffi.node import NodeBase, register_node from ._ffi.object import Object, register_object
from . import _api_internal from . import _api_internal
try: try:
...@@ -80,8 +80,8 @@ def _merge_opts(opts, new_opts): ...@@ -80,8 +80,8 @@ def _merge_opts(opts, new_opts):
return opts return opts
@register_node @register_object
class Target(NodeBase): class Target(Object):
"""Target device information, use through TVM API. """Target device information, use through TVM API.
Note Note
...@@ -97,7 +97,7 @@ class Target(NodeBase): ...@@ -97,7 +97,7 @@ class Target(NodeBase):
""" """
def __new__(cls): def __new__(cls):
# Always override new to enable class # Always override new to enable class
obj = NodeBase.__new__(cls) obj = Object.__new__(cls)
obj._keys = None obj._keys = None
obj._options = None obj._options = None
obj._libs = None obj._libs = None
...@@ -146,8 +146,8 @@ class Target(NodeBase): ...@@ -146,8 +146,8 @@ class Target(NodeBase):
_api_internal._ExitTargetScope(self) _api_internal._ExitTargetScope(self)
@register_node @register_object
class GenericFunc(NodeBase): class GenericFunc(Object):
"""GenericFunc node reference. This represents a generic function """GenericFunc node reference. This represents a generic function
that may be specialized for different targets. When this object is that may be specialized for different targets. When this object is
called, a specialization is chosen based on the current target. called, a specialization is chosen based on the current target.
......
...@@ -17,13 +17,14 @@ ...@@ -17,13 +17,14 @@
"""Tensor and Operation class for computation declaration.""" """Tensor and Operation class for computation declaration."""
# pylint: disable=invalid-name # pylint: disable=invalid-name
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, NodeGeneric, register_node, convert_to_node from ._ffi.object import Object, register_object, ObjectGeneric, \
convert_to_object
from . import _api_internal from . import _api_internal
from . import make as _make from . import make as _make
from . import expr as _expr from . import expr as _expr
class TensorSlice(NodeGeneric, _expr.ExprOp): class TensorSlice(ObjectGeneric, _expr.ExprOp):
"""Auxiliary data structure for enable slicing syntax from tensor.""" """Auxiliary data structure for enable slicing syntax from tensor."""
def __init__(self, tensor, indices): def __init__(self, tensor, indices):
...@@ -37,8 +38,8 @@ class TensorSlice(NodeGeneric, _expr.ExprOp): ...@@ -37,8 +38,8 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
indices = (indices,) indices = (indices,)
return TensorSlice(self.tensor, self.indices + indices) return TensorSlice(self.tensor, self.indices + indices)
def asnode(self): def asobject(self):
"""Convert slice to node.""" """Convert slice to object."""
return self.tensor(*self.indices) return self.tensor(*self.indices)
@property @property
...@@ -46,23 +47,23 @@ class TensorSlice(NodeGeneric, _expr.ExprOp): ...@@ -46,23 +47,23 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
"""Data content of the tensor.""" """Data content of the tensor."""
return self.tensor.dtype return self.tensor.dtype
@register_node @register_object
class TensorIntrinCall(NodeBase): class TensorIntrinCall(Object):
"""Intermediate structure for calling a tensor intrinsic.""" """Intermediate structure for calling a tensor intrinsic."""
itervar_cls = None itervar_cls = None
@register_node @register_object
class Tensor(NodeBase, _expr.ExprOp): class Tensor(Object, _expr.ExprOp):
"""Tensor object, to construct, see function.Tensor""" """Tensor object, to construct, see function.Tensor"""
def __call__(self, *indices): def __call__(self, *indices):
ndim = self.ndim ndim = self.ndim
if len(indices) != ndim: if len(indices) != ndim:
raise ValueError("Need to provide %d index in tensor slice" % ndim) raise ValueError("Need to provide %d index in tensor slice" % ndim)
indices = convert_to_node(indices) indices = convert_to_object(indices)
args = [] args = []
for x in indices: for x in indices:
if isinstance(x, _expr.PrimExpr): if isinstance(x, _expr.PrimExpr):
...@@ -127,7 +128,7 @@ class Tensor(NodeBase, _expr.ExprOp): ...@@ -127,7 +128,7 @@ class Tensor(NodeBase, _expr.ExprOp):
class Operation(NodeBase): class Operation(Object):
"""Represent an operation that generates a tensor""" """Represent an operation that generates a tensor"""
def output(self, index): def output(self, index):
...@@ -156,12 +157,12 @@ class Operation(NodeBase): ...@@ -156,12 +157,12 @@ class Operation(NodeBase):
return _api_internal._OpInputTensors(self) return _api_internal._OpInputTensors(self)
@register_node @register_object
class PlaceholderOp(Operation): class PlaceholderOp(Operation):
"""Placeholder operation.""" """Placeholder operation."""
@register_node @register_object
class BaseComputeOp(Operation): class BaseComputeOp(Operation):
"""Compute operation.""" """Compute operation."""
@property @property
...@@ -175,18 +176,18 @@ class BaseComputeOp(Operation): ...@@ -175,18 +176,18 @@ class BaseComputeOp(Operation):
return self.__getattr__("reduce_axis") return self.__getattr__("reduce_axis")
@register_node @register_object
class ComputeOp(BaseComputeOp): class ComputeOp(BaseComputeOp):
"""Scalar operation.""" """Scalar operation."""
pass pass
@register_node @register_object
class TensorComputeOp(BaseComputeOp): class TensorComputeOp(BaseComputeOp):
"""Tensor operation.""" """Tensor operation."""
@register_node @register_object
class ScanOp(Operation): class ScanOp(Operation):
"""Scan operation.""" """Scan operation."""
@property @property
...@@ -195,12 +196,12 @@ class ScanOp(Operation): ...@@ -195,12 +196,12 @@ class ScanOp(Operation):
return self.__getattr__("scan_axis") return self.__getattr__("scan_axis")
@register_node @register_object
class ExternOp(Operation): class ExternOp(Operation):
"""External operation.""" """External operation."""
@register_node @register_object
class HybridOp(Operation): class HybridOp(Operation):
"""Hybrid operation.""" """Hybrid operation."""
@property @property
...@@ -209,8 +210,8 @@ class HybridOp(Operation): ...@@ -209,8 +210,8 @@ class HybridOp(Operation):
return self.__getattr__("axis") return self.__getattr__("axis")
@register_node @register_object
class Layout(NodeBase): class Layout(Object):
"""Layout is composed of upper cases, lower cases and numbers, """Layout is composed of upper cases, lower cases and numbers,
where upper case indicates a primal axis and where upper case indicates a primal axis and
the corresponding lower case with factor size indicates the subordinate axis. the corresponding lower case with factor size indicates the subordinate axis.
...@@ -269,8 +270,8 @@ class Layout(NodeBase): ...@@ -269,8 +270,8 @@ class Layout(NodeBase):
return _api_internal._LayoutFactorOf(self, axis) return _api_internal._LayoutFactorOf(self, axis)
@register_node @register_object
class BijectiveLayout(NodeBase): class BijectiveLayout(Object):
"""Bijective mapping for two layouts (src-layout and dst-layout). """Bijective mapping for two layouts (src-layout and dst-layout).
It provides shape and index conversion between each other. It provides shape and index conversion between each other.
......
...@@ -24,7 +24,7 @@ from . import make as _make ...@@ -24,7 +24,7 @@ from . import make as _make
from . import tensor as _tensor from . import tensor as _tensor
from . import schedule as _schedule from . import schedule as _schedule
from .build_module import current_build_config from .build_module import current_build_config
from ._ffi.node import NodeBase, register_node from ._ffi.object import Object, register_object
def _get_region(tslice): def _get_region(tslice):
...@@ -41,8 +41,8 @@ def _get_region(tslice): ...@@ -41,8 +41,8 @@ def _get_region(tslice):
region.append(_make.range_by_min_extent(begin, 1)) region.append(_make.range_by_min_extent(begin, 1))
return region return region
@register_node @register_object
class TensorIntrin(NodeBase): class TensorIntrin(Object):
"""Tensor intrinsic functions for certain computation. """Tensor intrinsic functions for certain computation.
See Also See Also
......
...@@ -854,7 +854,7 @@ void VMCompiler::Lower(Module mod, ...@@ -854,7 +854,7 @@ void VMCompiler::Lower(Module mod,
// populate constants // populate constants
for (auto data : context_.constants) { for (auto data : context_.constants) {
exec_->constants.push_back(vm::Tensor(data)); exec_->constants.push_back(data);
} }
// update global function map // update global function map
......
...@@ -27,12 +27,14 @@ ...@@ -27,12 +27,14 @@
#include <tvm/relay/interpreter.h> #include <tvm/relay/interpreter.h>
#include <tvm/relay/attrs/transform.h> #include <tvm/relay/attrs/transform.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include "./pattern_util.h" #include <tvm/runtime/object.h>
#include <tvm/runtime/ndarray.h>
#include "pattern_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
using FInterpreter = runtime::TypedPackedFunc<Value(Expr)>; using FInterpreter = runtime::TypedPackedFunc<ObjectRef(Expr)>;
class ConstantChecker : private ExprVisitor { class ConstantChecker : private ExprVisitor {
public: public:
...@@ -177,17 +179,18 @@ class ConstantFolder : public ExprMutator { ...@@ -177,17 +179,18 @@ class ConstantFolder : public ExprMutator {
const Op& cast_op_; const Op& cast_op_;
// Convert value to expression. // Convert value to expression.
Expr ValueToExpr(Value value) { Expr ObjectToExpr(const ObjectRef& value) {
if (const auto* val = value.as<TensorValueNode>()) { if (value->IsInstance<runtime::NDArray::ContainerType>()) {
for (auto dim : val->data.Shape()) { auto nd_array = Downcast<runtime::NDArray>(value);
for (auto dim : nd_array.Shape()) {
CHECK_GT(dim, 0) CHECK_GT(dim, 0)
<< "invalid dimension after constant eval"; << "invalid dimension after constant eval";
} }
return ConstantNode::make(val->data); return ConstantNode::make(nd_array);
} else if (const auto* val = value.as<TupleValueNode>()) { } else if (const auto* val = value.as<TupleValueNode>()) {
Array<Expr> fields; Array<Expr> fields;
for (Value field : val->fields) { for (ObjectRef field : val->fields) {
fields.push_back(ValueToExpr(field)); fields.push_back(ObjectToExpr(field));
} }
return TupleNode::make(fields); return TupleNode::make(fields);
} else { } else {
...@@ -216,7 +219,7 @@ class ConstantFolder : public ExprMutator { ...@@ -216,7 +219,7 @@ class ConstantFolder : public ExprMutator {
mod = seq(mod); mod = seq(mod);
auto entry_func = mod->Lookup("main"); auto entry_func = mod->Lookup("main");
expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func; expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
return ValueToExpr(executor_(expr)); return ObjectToExpr(executor_(expr));
} }
// Evaluate a call to the shape_of operator for tensors with constant // Evaluate a call to the shape_of operator for tensors with constant
...@@ -258,7 +261,7 @@ class ConstantFolder : public ExprMutator { ...@@ -258,7 +261,7 @@ class ConstantFolder : public ExprMutator {
} }
} }
Constant shape = Downcast<Constant>(ValueToExpr(TensorValueNode::make(value))); Constant shape = Downcast<Constant>(ObjectToExpr(value));
if (shape->data.Shape().size() == 0 && GetScalarFromConstant<int32_t>(shape) == 0) { if (shape->data.Shape().size() == 0 && GetScalarFromConstant<int32_t>(shape) == 0) {
auto ndarray = runtime::NDArray::Empty({}, cdtype, ctx); auto ndarray = runtime::NDArray::Empty({}, cdtype, ctx);
...@@ -283,8 +286,7 @@ Expr FoldConstant(const Expr& expr, const Module& mod) { ...@@ -283,8 +286,7 @@ Expr FoldConstant(const Expr& expr, const Module& mod) {
// in case we are already in a build context. // in case we are already in a build context.
With<BuildConfig> fresh_build_ctx(BuildConfig::Create()); With<BuildConfig> fresh_build_ctx(BuildConfig::Create());
return ConstantFolder(CreateInterpreter( return ConstantFolder(CreateInterpreter(mod, ctx, target), mod).Mutate(expr);
mod, ctx, target), mod).Mutate(expr);
} }
namespace transform { namespace transform {
......
...@@ -403,7 +403,7 @@ Fuel MkFTop() { ...@@ -403,7 +403,7 @@ Fuel MkFTop() {
/*! /*!
* \brief A stack frame in the Relay interpreter. * \brief A stack frame in the Relay interpreter.
* *
* Contains a mapping from relay::Var to relay::Value. * Contains a mapping from relay::Var to relay::Object.
*/ */
struct Frame { struct Frame {
/*! \brief The set of local variables and arguments for the frame. */ /*! \brief The set of local variables and arguments for the frame. */
...@@ -554,7 +554,7 @@ bool StatefulOp(const Expr& e) { ...@@ -554,7 +554,7 @@ bool StatefulOp(const Expr& e) {
return sov.stateful; return sov.stateful;
} }
using FInterpreter = runtime::TypedPackedFunc<Value(Expr)>; using FInterpreter = runtime::TypedPackedFunc<ObjectRef(Expr)>;
DLContext CPUContext() { DLContext CPUContext() {
DLContext ctx; DLContext ctx;
...@@ -925,13 +925,14 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -925,13 +925,14 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
} }
} }
PStatic Reify(const Value& v, LetList* ll) const { PStatic Reify(const ObjectRef& v, LetList* ll) const {
if (const TensorValueNode* op = v.as<TensorValueNode>()) { if (v->IsInstance<runtime::NDArray::ContainerType>()) {
return HasStatic(MkSTensor(op->data), ll->Push(ConstantNode::make(op->data))); auto nd_array = Downcast<runtime::NDArray>(v);
return HasStatic(MkSTensor(nd_array), ll->Push(ConstantNode::make(nd_array)));
} else if (const TupleValueNode* op = v.as<TupleValueNode>()) { } else if (const TupleValueNode* op = v.as<TupleValueNode>()) {
std::vector<PStatic> fields; std::vector<PStatic> fields;
tvm::Array<Expr> fields_dyn; tvm::Array<Expr> fields_dyn;
for (const Value& field : op->fields) { for (const ObjectRef& field : op->fields) {
PStatic ps = Reify(field, ll); PStatic ps = Reify(field, ll);
fields.push_back(ps); fields.push_back(ps);
fields_dyn.push_back(ps->dynamic); fields_dyn.push_back(ps->dynamic);
......
...@@ -150,10 +150,8 @@ std::string Executable::Stats() const { ...@@ -150,10 +150,8 @@ std::string Executable::Stats() const {
// Get the number of constants and the shape of each of them. // Get the number of constants and the shape of each of them.
oss << " Constant shapes (# " << constants.size() << "): ["; oss << " Constant shapes (# " << constants.size() << "): [";
for (const auto& it : constants) { for (const auto& it : constants) {
const auto* cell = it.as<TensorObj>(); const auto constant = Downcast<NDArray>(it);
CHECK(cell); const auto& shape = constant.Shape();
runtime::NDArray data = cell->data;
const auto& shape = data.Shape();
// Scalar // Scalar
if (shape.empty()) { if (shape.empty()) {
...@@ -250,10 +248,8 @@ void Executable::SaveGlobalSection(dmlc::Stream* strm) { ...@@ -250,10 +248,8 @@ void Executable::SaveGlobalSection(dmlc::Stream* strm) {
void Executable::SaveConstantSection(dmlc::Stream* strm) { void Executable::SaveConstantSection(dmlc::Stream* strm) {
std::vector<DLTensor*> arrays; std::vector<DLTensor*> arrays;
for (const auto& obj : this->constants) { for (const auto& obj : this->constants) {
const auto* cell = obj.as<runtime::vm::TensorObj>(); const auto cell = Downcast<runtime::NDArray>(obj);
CHECK(cell != nullptr); arrays.push_back(const_cast<DLTensor*>(cell.operator->()));
runtime::NDArray data = cell->data;
arrays.push_back(const_cast<DLTensor*>(data.operator->()));
} }
strm->Write(static_cast<uint64_t>(this->constants.size())); strm->Write(static_cast<uint64_t>(this->constants.size()));
for (const auto& it : arrays) { for (const auto& it : arrays) {
...@@ -513,8 +509,7 @@ void Executable::LoadConstantSection(dmlc::Stream* strm) { ...@@ -513,8 +509,7 @@ void Executable::LoadConstantSection(dmlc::Stream* strm) {
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
runtime::NDArray constant; runtime::NDArray constant;
STREAM_CHECK(constant.Load(strm), "constant"); STREAM_CHECK(constant.Load(strm), "constant");
runtime::ObjectRef obj = runtime::vm::Tensor(constant); this->constants.push_back(constant);
this->constants.push_back(obj);
} }
} }
......
...@@ -34,12 +34,6 @@ namespace tvm { ...@@ -34,12 +34,6 @@ namespace tvm {
namespace runtime { namespace runtime {
namespace vm { namespace vm {
Tensor::Tensor(NDArray data) {
auto ptr = make_object<TensorObj>();
ptr->data = std::move(data);
data_ = std::move(ptr);
}
Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) { Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) {
auto ptr = make_object<ClosureObj>(); auto ptr = make_object<ClosureObj>();
ptr->func_index = func_index; ptr->func_index = func_index;
...@@ -48,14 +42,6 @@ Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) { ...@@ -48,14 +42,6 @@ Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) {
} }
TVM_REGISTER_GLOBAL("_vmobj.GetTensorData")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto* cell = obj.as<TensorObj>();
CHECK(cell != nullptr);
*rv = cell->data;
});
TVM_REGISTER_GLOBAL("_vmobj.GetADTTag") TVM_REGISTER_GLOBAL("_vmobj.GetADTTag")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0]; ObjectRef obj = args[0];
...@@ -80,11 +66,6 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTFields") ...@@ -80,11 +66,6 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTFields")
*rv = adt[idx]; *rv = adt[idx];
}); });
TVM_REGISTER_GLOBAL("_vmobj.Tensor")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = Tensor(args[0].operator NDArray());
});
TVM_REGISTER_GLOBAL("_vmobj.Tuple") TVM_REGISTER_GLOBAL("_vmobj.Tuple")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<ObjectRef> fields; std::vector<ObjectRef> fields;
...@@ -105,7 +86,6 @@ TVM_REGISTER_GLOBAL("_vmobj.ADT") ...@@ -105,7 +86,6 @@ TVM_REGISTER_GLOBAL("_vmobj.ADT")
*rv = ADT(tag, fields); *rv = ADT(tag, fields);
}); });
TVM_REGISTER_OBJECT_TYPE(TensorObj);
TVM_REGISTER_OBJECT_TYPE(ADTObj); TVM_REGISTER_OBJECT_TYPE(ADTObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj); TVM_REGISTER_OBJECT_TYPE(ClosureObj);
} // namespace vm } // namespace vm
......
...@@ -613,18 +613,14 @@ std::ostream& operator<<(std::ostream& os, const VMFunction& vm_func) { ...@@ -613,18 +613,14 @@ std::ostream& operator<<(std::ostream& os, const VMFunction& vm_func) {
return os; return os;
} }
ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) { inline ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) {
if (const TensorObj* obj = src.as<TensorObj>()) { if (src->IsInstance<NDArray::ContainerType>()) {
auto tensor = obj->data; auto nd_array = Downcast<NDArray>(src);
if (tensor->ctx.device_type != ctx.device_type) { if (nd_array->ctx.device_type != ctx.device_type) {
auto copy = tensor.CopyTo(ctx); return nd_array.CopyTo(ctx);
return Tensor(copy);
} else {
return src;
} }
} else {
return src;
} }
return src;
} }
PackedFunc VirtualMachine::GetFunction(const std::string& name, PackedFunc VirtualMachine::GetFunction(const std::string& name,
...@@ -770,16 +766,12 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, ...@@ -770,16 +766,12 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
if (const auto* dt_cell = args[i].as<ADTObj>()) { if (const auto* dt_cell = args[i].as<ADTObj>()) {
for (size_t fi = 0; fi < dt_cell->size; ++fi) { for (size_t fi = 0; fi < dt_cell->size; ++fi) {
auto obj = (*dt_cell)[fi]; auto obj = (*dt_cell)[fi];
const auto* tensor = obj.as<TensorObj>(); auto nd_array = Downcast<NDArray>(obj);
CHECK(tensor != nullptr) << "Expect tensor object, but received: " setter(idx++, nd_array);
<< obj->GetTypeKey();
setter(idx++, tensor->data);
} }
} else { } else {
const auto* tensor = args[i].as<TensorObj>(); auto nd_array = Downcast<NDArray>(args[i]);
CHECK(tensor != nullptr) << "Expect tensor object, but received: " setter(idx++, nd_array);
<< args[i]->GetTypeKey();
setter(idx++, tensor->data);
} }
} }
...@@ -824,10 +816,8 @@ inline ObjectRef VirtualMachine::ReadRegister(Index r) const { ...@@ -824,10 +816,8 @@ inline ObjectRef VirtualMachine::ReadRegister(Index r) const {
inline int32_t VirtualMachine::LoadScalarInt(Index r) const { inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
int32_t result; int32_t result;
const auto& obj = ReadRegister(r); const auto& obj = ReadRegister(r);
const auto* tensor = obj.as<TensorObj>(); auto nd_array = Downcast<NDArray>(obj);
CHECK(tensor != nullptr) << "Expect tensor object, but received: " NDArray array = nd_array.CopyTo({kDLCPU, 0});
<< obj->GetTypeKey();
NDArray array = tensor->data.CopyTo({kDLCPU, 0});
if (array->dtype.bits <= 8) { if (array->dtype.bits <= 8) {
result = reinterpret_cast<int8_t*>(array->data)[0]; result = reinterpret_cast<int8_t*>(array->data)[0];
...@@ -883,7 +873,7 @@ void VirtualMachine::RunLoop() { ...@@ -883,7 +873,7 @@ void VirtualMachine::RunLoop() {
case Opcode::LoadConsti: { case Opcode::LoadConsti: {
auto tensor = NDArray::Empty({1}, {kDLInt, 64, 1}, {kDLCPU, 0}); auto tensor = NDArray::Empty({1}, {kDLInt, 64, 1}, {kDLCPU, 0});
reinterpret_cast<int64_t*>(tensor->data)[0] = instr.load_consti.val; reinterpret_cast<int64_t*>(tensor->data)[0] = instr.load_consti.val;
WriteRegister(instr.dst, Tensor(tensor)); WriteRegister(instr.dst, tensor);
pc_++; pc_++;
goto main_loop; goto main_loop;
} }
...@@ -943,7 +933,7 @@ void VirtualMachine::RunLoop() { ...@@ -943,7 +933,7 @@ void VirtualMachine::RunLoop() {
auto tag = adt.tag(); auto tag = adt.tag();
auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0}); auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0});
reinterpret_cast<int32_t*>(tag_tensor->data)[0] = tag; reinterpret_cast<int32_t*>(tag_tensor->data)[0] = tag;
WriteRegister(instr.dst, Tensor(tag_tensor)); WriteRegister(instr.dst, tag_tensor);
pc_++; pc_++;
goto main_loop; goto main_loop;
} }
...@@ -974,9 +964,8 @@ void VirtualMachine::RunLoop() { ...@@ -974,9 +964,8 @@ void VirtualMachine::RunLoop() {
auto storage_obj = ReadRegister(instr.alloc_tensor.storage); auto storage_obj = ReadRegister(instr.alloc_tensor.storage);
auto storage = Downcast<Storage>(storage_obj); auto storage = Downcast<Storage>(storage_obj);
auto data = storage->AllocNDArray(0, shape, instr.alloc_tensor.dtype); auto obj = storage->AllocNDArray(0, shape, instr.alloc_tensor.dtype);
auto obj = Tensor(data);
WriteRegister(instr.dst, obj); WriteRegister(instr.dst, obj);
pc_++; pc_++;
goto main_loop; goto main_loop;
...@@ -986,10 +975,8 @@ void VirtualMachine::RunLoop() { ...@@ -986,10 +975,8 @@ void VirtualMachine::RunLoop() {
cpu_ctx.device_type = kDLCPU; cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0; cpu_ctx.device_id = 0;
auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register); auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register);
const auto* tensor = shape_tensor_obj.as<TensorObj>(); const auto shape_arr = Downcast<NDArray>(shape_tensor_obj);
CHECK(tensor != nullptr) << "Expect tensor object, but received: " NDArray shape_tensor = shape_arr.CopyTo(cpu_ctx);
<< shape_tensor_obj->GetTypeKey();
NDArray shape_tensor = tensor->data.CopyTo(cpu_ctx);
const DLTensor* dl_tensor = shape_tensor.operator->(); const DLTensor* dl_tensor = shape_tensor.operator->();
CHECK_EQ(dl_tensor->dtype.code, 0u); CHECK_EQ(dl_tensor->dtype.code, 0u);
CHECK_LE(dl_tensor->dtype.bits, 64); CHECK_LE(dl_tensor->dtype.bits, 64);
...@@ -1000,9 +987,8 @@ void VirtualMachine::RunLoop() { ...@@ -1000,9 +987,8 @@ void VirtualMachine::RunLoop() {
auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage); auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage);
auto storage = Downcast<Storage>(storage_obj); auto storage = Downcast<Storage>(storage_obj);
auto data = storage->AllocNDArray(0, shape, instr.alloc_tensor_reg.dtype); auto obj = storage->AllocNDArray(0, shape, instr.alloc_tensor_reg.dtype);
auto obj = Tensor(data);
WriteRegister(instr.dst, obj); WriteRegister(instr.dst, obj);
pc_++; pc_++;
goto main_loop; goto main_loop;
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
import pytest import pytest
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
from tvm import nd
from tvm import relay from tvm import relay
from tvm.relay.frontend.tensorflow import from_tensorflow from tvm.relay.frontend.tensorflow import from_tensorflow
...@@ -26,7 +27,7 @@ def check_equal(graph, tf_out): ...@@ -26,7 +27,7 @@ def check_equal(graph, tf_out):
mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
ex = relay.create_executor('vm', mod=mod) ex = relay.create_executor('vm', mod=mod)
relay_out = ex.evaluate()(**params) relay_out = ex.evaluate()(**params)
if isinstance(relay_out, relay.vmobj.Tensor): if isinstance(relay_out, nd.NDArray):
np.testing.assert_allclose(tf_out, relay_out.asnumpy()) np.testing.assert_allclose(tf_out, relay_out.asnumpy())
else: else:
if not isinstance(tf_out, list): if not isinstance(tf_out, list):
......
...@@ -60,7 +60,7 @@ tf_dtypes = { ...@@ -60,7 +60,7 @@ tf_dtypes = {
} }
def vmobj_to_list(o): def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.vmobj.Tensor): if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()] return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vmobj.ADT): elif isinstance(o, tvm.relay.backend.vmobj.ADT):
result = [] result = []
...@@ -87,8 +87,6 @@ def vmobj_to_list(o): ...@@ -87,8 +87,6 @@ def vmobj_to_list(o):
else: else:
raise RuntimeError("Unknown object type: %s" % raise RuntimeError("Unknown object type: %s" %
o.constructor.name_hint) o.constructor.name_hint)
elif isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.data.asnumpy()]
else: else:
raise RuntimeError("Unknown object type: %s" % type(o)) raise RuntimeError("Unknown object type: %s" % type(o))
......
...@@ -115,10 +115,8 @@ def tree_to_dict(t): ...@@ -115,10 +115,8 @@ def tree_to_dict(t):
def vmobj_to_list(o, dtype="float32"): def vmobj_to_list(o, dtype="float32"):
if isinstance(o, tvm.relay.backend.vmobj.Tensor): if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()] return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.asnumpy()]
elif isinstance(o, tvm.relay.backend.vmobj.ADT): elif isinstance(o, tvm.relay.backend.vmobj.ADT):
if len(o) == 0: if len(o) == 0:
tensor_nil = p.get_var("tensor_nil", dtype=dtype) tensor_nil = p.get_var("tensor_nil", dtype=dtype)
......
...@@ -17,8 +17,9 @@ ...@@ -17,8 +17,9 @@
import numpy as np import numpy as np
import tvm import tvm
import tvm.testing import tvm.testing
from tvm import nd
from tvm import relay from tvm import relay
from tvm.relay.backend.interpreter import Value, TupleValue, TensorValue from tvm.relay.backend.interpreter import TupleValue
from tvm.relay.backend.interpreter import RefValue, ConstructorValue from tvm.relay.backend.interpreter import RefValue, ConstructorValue
from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing, create_executor from tvm.relay import testing, create_executor
...@@ -37,18 +38,11 @@ def check_eval(expr, args, expected_result, mod=None, rtol=1e-07): ...@@ -37,18 +38,11 @@ def check_eval(expr, args, expected_result, mod=None, rtol=1e-07):
result.asnumpy(), expected_result, rtol=rtol) result.asnumpy(), expected_result, rtol=rtol)
def test_from_scalar():
np.testing.assert_allclose(Value.from_scalar(1, 'int32').asnumpy(), 1)
np.testing.assert_allclose(Value.from_scalar(10.0, 'float32').asnumpy(), 10.0)
np.testing.assert_allclose(Value.from_scalar(True).asnumpy(), True)
def test_tuple_value(): def test_tuple_value():
tv = TupleValue(Value.from_scalar( tv = TupleValue(relay.const(1), relay.const(2), relay.const(3))
1), Value.from_scalar(2), Value.from_scalar(3)) np.testing.assert_allclose(tv[0].data.asnumpy(), 1)
np.testing.assert_allclose(tv[0].asnumpy(), 1) np.testing.assert_allclose(tv[1].data.asnumpy(), 2)
np.testing.assert_allclose(tv[1].asnumpy(), 2) np.testing.assert_allclose(tv[2].data.asnumpy(), 3)
np.testing.assert_allclose(tv[2].asnumpy(), 3)
def test_tuple_getitem(): def test_tuple_getitem():
...@@ -158,12 +152,6 @@ def test_binds(): ...@@ -158,12 +152,6 @@ def test_binds():
tvm.testing.assert_allclose(xx + xx, res) tvm.testing.assert_allclose(xx + xx, res)
def test_tensor_value():
x = relay.var("x", shape=(1, 10))
xx = np.ones((1, 10)).astype("float32")
check_eval(relay.Function([x], x), [TensorValue(xx)], xx)
def test_kwargs_params(): def test_kwargs_params():
x = relay.var("x", shape=(1, 10)) x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(1, 10)) y = relay.var("y", shape=(1, 10))
...@@ -174,7 +162,7 @@ def test_kwargs_params(): ...@@ -174,7 +162,7 @@ def test_kwargs_params():
z_data = np.random.rand(1, 10).astype('float32') z_data = np.random.rand(1, 10).astype('float32')
params = { 'y': y_data, 'z': z_data } params = { 'y': y_data, 'z': z_data }
intrp = create_executor("debug") intrp = create_executor("debug")
res = intrp.evaluate(f)(x_data, **params).data res = intrp.evaluate(f)(x_data, **params)
tvm.testing.assert_allclose(res.asnumpy(), x_data + y_data + z_data) tvm.testing.assert_allclose(res.asnumpy(), x_data + y_data + z_data)
...@@ -185,13 +173,13 @@ def test_function_taking_adt_ref_tuple(): ...@@ -185,13 +173,13 @@ def test_function_taking_adt_ref_tuple():
nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil) nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil)
cons_value = ConstructorValue(prelude.cons.tag, [ cons_value = ConstructorValue(prelude.cons.tag, [
TensorValue(np.random.rand(1, 10).astype('float32')), nd.array(np.random.rand(1, 10).astype('float32')),
nil_value nil_value
], prelude.cons) ], prelude.cons)
ref_value = RefValue(TensorValue(np.random.rand(1, 10).astype('float32'))) ref_value = RefValue(nd.array(np.random.rand(1, 10).astype('float32')))
tuple_value = TupleValue(*[ tuple_value = TupleValue(*[
TensorValue(np.random.rand(1, 10).astype('float32')) for _ in range(10) nd.array(np.random.rand(1, 10).astype('float32')) for _ in range(10)
]) ])
id_func = intrp.evaluate(prelude.id) id_func = intrp.evaluate(prelude.id)
...@@ -236,9 +224,7 @@ def test_tuple_passing(): ...@@ -236,9 +224,7 @@ def test_tuple_passing():
out = f((10, 8)) out = f((10, 8))
tvm.testing.assert_allclose(out.asnumpy(), np.array(10)) tvm.testing.assert_allclose(out.asnumpy(), np.array(10))
# Second use a tuple value. # Second use a tuple value.
value_tuple = TupleValue( value_tuple = TupleValue(nd.array(np.array(11)), nd.array(np.array(12)))
TensorValue(np.array(11)),
TensorValue(np.array(12)))
out = f(value_tuple) out = f(value_tuple)
tvm.testing.assert_allclose(out.asnumpy(), np.array(11)) tvm.testing.assert_allclose(out.asnumpy(), np.array(11))
...@@ -252,7 +238,6 @@ if __name__ == "__main__": ...@@ -252,7 +238,6 @@ if __name__ == "__main__":
test_binds() test_binds()
test_kwargs_params() test_kwargs_params()
test_ref() test_ref()
test_tensor_value()
test_tuple_value() test_tuple_value()
test_tuple_getitem() test_tuple_getitem()
test_function_taking_adt_ref_tuple() test_function_taking_adt_ref_tuple()
......
...@@ -19,7 +19,7 @@ import tvm ...@@ -19,7 +19,7 @@ import tvm
from tvm import relay from tvm import relay
from tvm.relay.testing import to_python, run_as_python from tvm.relay.testing import to_python, run_as_python
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from tvm.relay.backend.interpreter import TensorValue, TupleValue, RefValue, ConstructorValue from tvm.relay.backend.interpreter import TupleValue, RefValue, ConstructorValue
# helper: uses a dummy let binding to sequence a list # helper: uses a dummy let binding to sequence a list
# of expressions: expr1; expr2; expr3, etc. # of expressions: expr1; expr2; expr3, etc.
...@@ -39,9 +39,9 @@ def init_box_adt(mod): ...@@ -39,9 +39,9 @@ def init_box_adt(mod):
return (box, box_ctor) return (box, box_ctor)
# assert that the candidate is a TensorValue with value val # assert that the candidate is a NDArray with value val
def assert_tensor_value(candidate, val): def assert_tensor_value(candidate, val):
assert isinstance(candidate, TensorValue) assert isinstance(candidate, tvm.nd.NDArray)
assert np.array_equal(candidate.asnumpy(), np.array(val)) assert np.array_equal(candidate.asnumpy(), np.array(val))
...@@ -68,6 +68,7 @@ def test_create_empty_tuple(): ...@@ -68,6 +68,7 @@ def test_create_empty_tuple():
def test_create_scalar(): def test_create_scalar():
scalar = relay.const(1) scalar = relay.const(1)
tensor_val = run_as_python(scalar) tensor_val = run_as_python(scalar)
print(type(tensor_val))
assert_tensor_value(tensor_val, 1) assert_tensor_value(tensor_val, 1)
...@@ -544,7 +545,7 @@ def test_batch_norm(): ...@@ -544,7 +545,7 @@ def test_batch_norm():
# there will be a change in accuracy so we need to check # there will be a change in accuracy so we need to check
# approximate equality # approximate equality
assert isinstance(call_val, TensorValue) assert isinstance(call_val, tvm.nd.NDArray)
tvm.testing.assert_allclose(call_val.asnumpy(), ref_res, atol=eps, rtol=eps) tvm.testing.assert_allclose(call_val.asnumpy(), ref_res, atol=eps, rtol=eps)
verify_batch_norm([(10, 20), (20,), (20,), (20,), (20,)]) verify_batch_norm([(10, 20), (20,), (20,), (20,), (20,)])
......
...@@ -56,7 +56,7 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): ...@@ -56,7 +56,7 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
return vm.invoke("main", *args) return vm.invoke("main", *args)
def vmobj_to_list(o): def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.vm.Tensor): if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()] return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vm.ADT): elif isinstance(o, tvm.relay.backend.vm.ADT):
result = [] result = []
......
...@@ -19,28 +19,16 @@ import numpy as np ...@@ -19,28 +19,16 @@ import numpy as np
import tvm import tvm
from tvm.relay import vm from tvm.relay import vm
def test_tensor():
arr = tvm.nd.array([1,2,3])
x = vm.Tensor(arr)
assert isinstance(x, vm.Tensor)
assert x.asnumpy()[0] == 1
assert x.asnumpy()[-1] == 3
assert isinstance(x.data, tvm.nd.NDArray)
def test_adt(): def test_adt():
arr = tvm.nd.array([1,2,3]) arr = tvm.nd.array([1,2,3])
x = vm.Tensor(arr) y = vm.ADT(0, [arr, arr])
y = vm.ADT(0, [x, x])
assert len(y) == 2 assert len(y) == 2
assert isinstance(y, vm.ADT) assert isinstance(y, vm.ADT)
y[0:1][-1].data == x.data y[0:1][-1] == arr
assert y.tag == 0 assert y.tag == 0
assert isinstance(x.data, tvm.nd.NDArray) assert isinstance(arr, tvm.nd.NDArray)
if __name__ == "__main__": if __name__ == "__main__":
test_tensor()
test_adt() test_adt()
...@@ -28,7 +28,7 @@ def test_double_buffer(): ...@@ -28,7 +28,7 @@ def test_double_buffer():
with ib.for_range(0, n) as i: with ib.for_range(0, n) as i:
B = ib.allocate("float32", m, name="B", scope="shared") B = ib.allocate("float32", m, name="B", scope="shared")
with ib.new_scope(): with ib.new_scope():
ib.scope_attr(B.asnode(), "double_buffer_scope", 1) ib.scope_attr(B.asobject(), "double_buffer_scope", 1)
with ib.for_range(0, m) as j: with ib.for_range(0, m) as j:
B[j] = A[i * 4 + j] B[j] = A[i * 4 + j]
with ib.for_range(0, m) as j: with ib.for_range(0, m) as j:
...@@ -39,7 +39,7 @@ def test_double_buffer(): ...@@ -39,7 +39,7 @@ def test_double_buffer():
stmt = tvm.ir_pass.Simplify(stmt) stmt = tvm.ir_pass.Simplify(stmt)
assert isinstance(stmt.body.body, tvm.stmt.Allocate) assert isinstance(stmt.body.body, tvm.stmt.Allocate)
assert stmt.body.body.extents[0].value == 2 assert stmt.body.body.extents[0].value == 2
f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asnode(), C.asnode()], 2, True) f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
f = tvm.ir_pass.ThreadSync(f, "shared") f = tvm.ir_pass.ThreadSync(f, "shared")
count = [0] count = [0]
def count_sync(op): def count_sync(op):
......
...@@ -32,7 +32,7 @@ def test_vthread(): ...@@ -32,7 +32,7 @@ def test_vthread():
ib.scope_attr(ty, "virtual_thread", nthread) ib.scope_attr(ty, "virtual_thread", nthread)
B = ib.allocate("float32", m, name="B", scope="shared") B = ib.allocate("float32", m, name="B", scope="shared")
B[i] = A[i * nthread + tx] B[i] = A[i * nthread + tx]
bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode()) bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asobject())
ib.emit(tvm.call_extern("int32", "Run", ib.emit(tvm.call_extern("int32", "Run",
bbuffer.access_ptr("r"), bbuffer.access_ptr("r"),
tvm.call_pure_intrin("int32", "tvm_context_id"))) tvm.call_pure_intrin("int32", "tvm_context_id")))
...@@ -60,9 +60,9 @@ def test_vthread_extern(): ...@@ -60,9 +60,9 @@ def test_vthread_extern():
A = ib.allocate("float32", m, name="A", scope="shared") A = ib.allocate("float32", m, name="A", scope="shared")
B = ib.allocate("float32", m, name="B", scope="shared") B = ib.allocate("float32", m, name="B", scope="shared")
C = ib.allocate("float32", m, name="C", scope="shared") C = ib.allocate("float32", m, name="C", scope="shared")
cbuffer = tvm.decl_buffer((m,), dtype=C.dtype, data=C.asnode()) cbuffer = tvm.decl_buffer((m,), dtype=C.dtype, data=C.asobject())
abuffer = tvm.decl_buffer((m,), dtype=A.dtype, data=A.asnode()) abuffer = tvm.decl_buffer((m,), dtype=A.dtype, data=A.asobject())
bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode()) bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asobject())
A[tx] = tx + 1.0 A[tx] = tx + 1.0
B[ty] = ty + 1.0 B[ty] = ty + 1.0
ib.emit(tvm.call_extern("int32", "Run", ib.emit(tvm.call_extern("int32", "Run",
......
...@@ -79,7 +79,7 @@ def test_flatten_double_buffer(): ...@@ -79,7 +79,7 @@ def test_flatten_double_buffer():
with ib.for_range(0, n) as i: with ib.for_range(0, n) as i:
B = ib.allocate("float32", m, name="B", scope="shared") B = ib.allocate("float32", m, name="B", scope="shared")
with ib.new_scope(): with ib.new_scope():
ib.scope_attr(B.asnode(), "double_buffer_scope", 1) ib.scope_attr(B.asobject(), "double_buffer_scope", 1)
with ib.for_range(0, m) as j: with ib.for_range(0, m) as j:
B[j] = A[i * 4 + j] B[j] = A[i * 4 + j]
with ib.for_range(0, m) as j: with ib.for_range(0, m) as j:
...@@ -91,7 +91,7 @@ def test_flatten_double_buffer(): ...@@ -91,7 +91,7 @@ def test_flatten_double_buffer():
stmt = tvm.ir_pass.Simplify(stmt) stmt = tvm.ir_pass.Simplify(stmt)
assert isinstance(stmt.body.body, tvm.stmt.Allocate) assert isinstance(stmt.body.body, tvm.stmt.Allocate)
assert stmt.body.body.extents[0].value == 2 assert stmt.body.body.extents[0].value == 2
f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asnode(), C.asnode()], 2, True) f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
f = tvm.ir_pass.ThreadSync(f, "shared") f = tvm.ir_pass.ThreadSync(f, "shared")
count = [0] count = [0]
def count_sync(op): def count_sync(op):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment