Unverified Commit 02c1e117 by Tianqi Chen Committed by GitHub

[RUNTIME] Refactor object python FFI to new protocol. (#4128)

* [RUNTIME] Refactor object python FFI to new protocol.

This is a pre-req to bring the Node system under object protocol.
Most of the code reflects the current code in the Node system.

- Use new instead of init so subclass can define their own constructors
- Allow register via name, besides type idnex
- Introduce necessary runtime C API functions
- Refactored Tensor and Datatype to directly use constructor.

* address review comments
parent e3fbdc8b
...@@ -104,7 +104,7 @@ typedef enum { ...@@ -104,7 +104,7 @@ typedef enum {
kStr = 11U, kStr = 11U,
kBytes = 12U, kBytes = 12U,
kNDArrayContainer = 13U, kNDArrayContainer = 13U,
kObjectCell = 14U, kObjectHandle = 14U,
// Extension codes for other frameworks to integrate TVM PackedFunc. // Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and // To make sure each framework's id do not conflict, use first and
// last sections to mark ranges. // last sections to mark ranges.
...@@ -549,13 +549,31 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type, ...@@ -549,13 +549,31 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type,
TVMStreamHandle dst); TVMStreamHandle dst);
/*! /*!
* \brief Get the tag from an object. * \brief Get the type_index from an object.
* *
* \param obj The object handle. * \param obj The object handle.
* \param tag The tag of object. * \param out_tindex the output type index.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMGetObjectTag(TVMObjectHandle obj, int* tag); TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex);
/*!
* \brief Convert type key to type index.
* \param type_key The key of the type.
* \param out_tindex the corresponding type index.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex);
/*!
* \brief Free the object.
*
* \param obj The object handle.
* \note Internally we decrease the reference counter of the object.
* The object will be freed when every reference to the object are removed.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMObjectFree(TVMObjectHandle obj);
#ifdef __cplusplus #ifdef __cplusplus
} // TVM_EXTERN_C } // TVM_EXTERN_C
......
...@@ -253,6 +253,7 @@ class Object { ...@@ -253,6 +253,7 @@ class Object {
template<typename> template<typename>
friend class ObjectPtr; friend class ObjectPtr;
friend class TVMRetValue; friend class TVMRetValue;
friend class TVMObjectCAPI;
}; };
/*! /*!
......
...@@ -491,7 +491,7 @@ class TVMPODValue_ { ...@@ -491,7 +491,7 @@ class TVMPODValue_ {
} }
operator ObjectRef() const { operator ObjectRef() const {
if (type_code_ == kNull) return ObjectRef(ObjectPtr<Object>(nullptr)); if (type_code_ == kNull) return ObjectRef(ObjectPtr<Object>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kObjectCell); TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
return ObjectRef(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle))); return ObjectRef(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
} }
operator TVMContext() const { operator TVMContext() const {
...@@ -761,7 +761,7 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -761,7 +761,7 @@ class TVMRetValue : public TVMPODValue_ {
} }
TVMRetValue& operator=(ObjectRef other) { TVMRetValue& operator=(ObjectRef other) {
this->Clear(); this->Clear();
type_code_ = kObjectCell; type_code_ = kObjectHandle;
// move the handle out // move the handle out
value_.v_handle = other.data_.data_; value_.v_handle = other.data_.data_;
other.data_.data_ = nullptr; other.data_.data_ = nullptr;
...@@ -862,7 +862,7 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -862,7 +862,7 @@ class TVMRetValue : public TVMPODValue_ {
kNodeHandle, *other.template ptr<NodePtr<Node> >()); kNodeHandle, *other.template ptr<NodePtr<Node> >());
break; break;
} }
case kObjectCell: { case kObjectHandle: {
*this = other.operator ObjectRef(); *this = other.operator ObjectRef();
break; break;
} }
...@@ -913,7 +913,7 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -913,7 +913,7 @@ class TVMRetValue : public TVMPODValue_ {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef(); static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
break; break;
} }
case kObjectCell: { case kObjectHandle: {
static_cast<Object*>(value_.v_handle)->DecRef(); static_cast<Object*>(value_.v_handle)->DecRef();
break; break;
} }
...@@ -946,7 +946,7 @@ inline const char* TypeCode2Str(int type_code) { ...@@ -946,7 +946,7 @@ inline const char* TypeCode2Str(int type_code) {
case kFuncHandle: return "FunctionHandle"; case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle"; case kModuleHandle: return "ModuleHandle";
case kNDArrayContainer: return "NDArrayContainer"; case kNDArrayContainer: return "NDArrayContainer";
case kObjectCell: return "ObjectCell"; case kObjectHandle: return "ObjectCell";
default: LOG(FATAL) << "unknown type_code=" default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return ""; << static_cast<int>(type_code); return "";
} }
...@@ -1164,7 +1164,7 @@ class TVMArgsSetter { ...@@ -1164,7 +1164,7 @@ class TVMArgsSetter {
} }
void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*) void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*)
values_[i].v_handle = value.data_.data_; values_[i].v_handle = value.data_.data_;
type_codes_[i] = kObjectCell; type_codes_[i] = kObjectHandle;
} }
void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*) void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*)
if (value.type_code() == kStr) { if (value.type_code() == kStr) {
......
...@@ -33,6 +33,7 @@ from .types import TVMValue, TypeCode ...@@ -33,6 +33,7 @@ from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64 from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
from .node import NodeBase from .node import NodeBase
from . import object as _object
from . import node as _node from . import node as _node
FunctionHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p
...@@ -165,7 +166,7 @@ def _make_tvm_args(args, temp_args): ...@@ -165,7 +166,7 @@ def _make_tvm_args(args, temp_args):
temp_args.append(arg) temp_args.append(arg)
elif isinstance(arg, _CLASS_OBJECT): elif isinstance(arg, _CLASS_OBJECT):
values[i].v_handle = arg.handle values[i].v_handle = arg.handle
type_codes[i] = TypeCode.OBJECT_CELL type_codes[i] = TypeCode.OBJECT_HANDLE
else: else:
raise TypeError("Don't know how to handle type %s" % type(arg)) raise TypeError("Don't know how to handle type %s" % type(arg))
return values, type_codes, num_args return values, type_codes, num_args
...@@ -225,7 +226,7 @@ def __init_handle_by_constructor__(fconstructor, args): ...@@ -225,7 +226,7 @@ def __init_handle_by_constructor__(fconstructor, args):
raise get_last_ffi_error() raise get_last_ffi_error()
_ = temp_args _ = temp_args
_ = args _ = args
assert ret_tcode.value == TypeCode.NODE_HANDLE assert ret_tcode.value in (TypeCode.NODE_HANDLE, TypeCode.OBJECT_HANDLE)
handle = ret_val.v_handle handle = ret_val.v_handle
return handle return handle
...@@ -247,6 +248,7 @@ def _handle_return_func(x): ...@@ -247,6 +248,7 @@ def _handle_return_func(x):
# setup return handle for function type # setup return handle for function type
_node.__init_by_constructor__ = __init_handle_by_constructor__ _node.__init_by_constructor__ = __init_handle_by_constructor__
_object.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True) RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True)
......
...@@ -20,9 +20,11 @@ from __future__ import absolute_import ...@@ -20,9 +20,11 @@ from __future__ import absolute_import
import ctypes import ctypes
from ..base import _LIB, check_call from ..base import _LIB, check_call
from .types import TypeCode, RETURN_SWITCH from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
ObjectHandle = ctypes.c_void_p ObjectHandle = ctypes.c_void_p
__init_by_constructor__ = None
"""Maps object type to its constructor""" """Maps object type to its constructor"""
OBJECT_TYPE = {} OBJECT_TYPE = {}
...@@ -36,17 +38,48 @@ def _return_object(x): ...@@ -36,17 +38,48 @@ def _return_object(x):
handle = x.v_handle handle = x.v_handle
if not isinstance(handle, ObjectHandle): if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle) handle = ObjectHandle(handle)
tag = ctypes.c_int() tindex = ctypes.c_uint()
check_call(_LIB.TVMGetObjectTag(handle, ctypes.byref(tag))) check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
cls = OBJECT_TYPE.get(tag.value, ObjectBase) cls = OBJECT_TYPE.get(tindex.value, ObjectBase)
obj = cls(handle) # Avoid calling __init__ of cls, instead directly call __new__
# This allows child class to implement their own __init__
obj = cls.__new__(cls)
obj.handle = handle
return obj return obj
RETURN_SWITCH[TypeCode.OBJECT_CELL] = _return_object RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object
C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func(
_return_object, TypeCode.OBJECT_HANDLE)
class ObjectBase(object): class ObjectBase(object):
"""Base object for all object types"""
__slots__ = ["handle"] __slots__ = ["handle"]
def __init__(self, handle): def __del__(self):
if _LIB is not None:
check_call(_LIB.TVMObjectFree(self.handle))
def __init_handle_by_constructor__(self, fconstructor, *args):
"""Initialize the handle by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return handle is directly set into the Node object
instead of creating a new Node.
"""
# assign handle first to avoid error raising
self.handle = None
handle = __init_by_constructor__(fconstructor, args)
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
self.handle = handle self.handle = handle
...@@ -37,7 +37,7 @@ cdef enum TVMTypeCode: ...@@ -37,7 +37,7 @@ cdef enum TVMTypeCode:
kStr = 11 kStr = 11
kBytes = 12 kBytes = 12
kNDArrayContainer = 13 kNDArrayContainer = 13
kObjectCell = 14 kObjectHandle = 14
kExtBegin = 15 kExtBegin = 15
cdef extern from "tvm/runtime/c_runtime_api.h": cdef extern from "tvm/runtime/c_runtime_api.h":
...@@ -130,7 +130,9 @@ cdef extern from "tvm/runtime/c_runtime_api.h": ...@@ -130,7 +130,9 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
int TVMArrayToDLPack(DLTensorHandle arr_from, int TVMArrayToDLPack(DLTensorHandle arr_from,
DLManagedTensor** out) DLManagedTensor** out)
void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor) void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
int TVMGetObjectTag(ObjectHandle obj, int* tag) int TVMObjectFree(ObjectHandle obj)
int TVMObjectGetTypeIndex(ObjectHandle obj, unsigned* out_index)
cdef extern from "tvm/c_dsl_api.h": cdef extern from "tvm/c_dsl_api.h":
int TVMNodeFree(NodeHandle handle) int TVMNodeFree(NodeHandle handle)
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
# under the License. # under the License.
include "./base.pxi" include "./base.pxi"
include "./object.pxi"
include "./node.pxi" include "./node.pxi"
include "./function.pxi" include "./function.pxi"
include "./ndarray.pxi" include "./ndarray.pxi"
include "./vmobj.pxi"
...@@ -44,7 +44,7 @@ cdef int tvm_callback(TVMValue* args, ...@@ -44,7 +44,7 @@ cdef int tvm_callback(TVMValue* args,
if (tcode == kNodeHandle or if (tcode == kNodeHandle or
tcode == kFuncHandle or tcode == kFuncHandle or
tcode == kModuleHandle or tcode == kModuleHandle or
tcode == kObjectCell or tcode == kObjectHandle or
tcode > kExtBegin): tcode > kExtBegin):
CALL(TVMCbArgToReturn(&value, tcode)) CALL(TVMCbArgToReturn(&value, tcode))
...@@ -155,12 +155,12 @@ cdef inline int make_arg(object arg, ...@@ -155,12 +155,12 @@ cdef inline int make_arg(object arg,
value[0].v_handle = (<NodeBase>arg).chandle value[0].v_handle = (<NodeBase>arg).chandle
tcode[0] = kNodeHandle tcode[0] = kNodeHandle
temp_args.append(arg) temp_args.append(arg)
elif isinstance(arg, _CLASS_OBJECT):
value[0].v_handle = (<ObjectBase>arg).chandle
tcode[0] = kObjectHandle
elif isinstance(arg, _CLASS_MODULE): elif isinstance(arg, _CLASS_MODULE):
value[0].v_handle = c_handle(arg.handle) value[0].v_handle = c_handle(arg.handle)
tcode[0] = kModuleHandle tcode[0] = kModuleHandle
elif isinstance(arg, _CLASS_OBJECT):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kObjectCell
elif isinstance(arg, FunctionBase): elif isinstance(arg, FunctionBase):
value[0].v_handle = (<FunctionBase>arg).chandle value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle tcode[0] = kFuncHandle
...@@ -190,6 +190,8 @@ cdef inline object make_ret(TVMValue value, int tcode): ...@@ -190,6 +190,8 @@ cdef inline object make_ret(TVMValue value, int tcode):
"""convert result to return value.""" """convert result to return value."""
if tcode == kNodeHandle: if tcode == kNodeHandle:
return make_ret_node(value.v_handle) return make_ret_node(value.v_handle)
elif tcode == kObjectHandle:
return make_ret_object(value.v_handle)
elif tcode == kNull: elif tcode == kNull:
return None return None
elif tcode == kInt: elif tcode == kInt:
...@@ -212,8 +214,6 @@ cdef inline object make_ret(TVMValue value, int tcode): ...@@ -212,8 +214,6 @@ cdef inline object make_ret(TVMValue value, int tcode):
fobj = _CLASS_FUNCTION(None, False) fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle (<FunctionBase>fobj).chandle = value.v_handle
return fobj return fobj
elif tcode == kObjectCell:
return make_ret_object(value.v_handle)
elif tcode in _TVM_EXT_RET: elif tcode in _TVM_EXT_RET:
return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle)) return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
OBJECT_TYPE = [] OBJECT_TYPE = []
def _register_object(int index, object cls): def _register_object(int index, object cls):
"""register node class""" """register object class"""
while len(OBJECT_TYPE) <= index: while len(OBJECT_TYPE) <= index:
OBJECT_TYPE.append(None) OBJECT_TYPE.append(None)
OBJECT_TYPE[index] = cls OBJECT_TYPE[index] = cls
...@@ -27,41 +27,70 @@ def _register_object(int index, object cls): ...@@ -27,41 +27,70 @@ def _register_object(int index, object cls):
cdef inline object make_ret_object(void* chandle): cdef inline object make_ret_object(void* chandle):
global OBJECT_TYPE global OBJECT_TYPE
cdef int tag cdef unsigned tindex
cdef list object_type cdef list object_type
cdef object cls cdef object cls
cdef object handle cdef object handle
object_type = OBJECT_TYPE object_type = OBJECT_TYPE
handle = ctypes_handle(chandle) handle = ctypes_handle(chandle)
CALL(TVMGetObjectTag(chandle, &tag)) CALL(TVMObjectGetTypeIndex(chandle, &tindex))
if tag < len(object_type): if tindex < len(object_type):
cls = object_type[tag] cls = object_type[tindex]
if cls is not None: if cls is not None:
obj = cls(handle) obj = cls.__new__(cls)
else: else:
obj = ObjectBase(handle) obj = ObjectBase.__new__(ObjectBase)
else: else:
obj = ObjectBase(handle) obj = ObjectBase.__new__(ObjectBase)
(<ObjectBase>obj).chandle = chandle
return obj return obj
cdef class ObjectBase: cdef class ObjectBase:
cdef ObjectHandle chandle cdef void* chandle
cdef inline _set_handle(self, handle): cdef inline _set_handle(self, handle):
cdef unsigned long long ptr
if handle is None: if handle is None:
self.chandle = NULL self.chandle = NULL
else: else:
self.chandle = c_handle(handle) ptr = handle.value
self.chandle = <void*>(ptr)
property handle: property handle:
def __get__(self): def __get__(self):
if self.chandle == NULL: if self.chandle == NULL:
return None return None
else: else:
return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p) return ctypes_handle(self.chandle)
def __set__(self, value): def __set__(self, value):
self._set_handle(value) self._set_handle(value)
def __init__(self, handle): def __dealloc__(self):
self._set_handle(handle) CALL(TVMObjectFree(self.chandle))
def __init_handle_by_constructor__(self, fconstructor, *args):
"""Initialize the handle by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return handle is directly set into the Node object
instead of creating a new Node.
"""
# avoid error raised during construction.
self.chandle = NULL
cdef void* chandle
ConstructorCall(
(<FunctionBase>fconstructor).chandle,
kObjectHandle, args, &chandle)
self.chandle = chandle
...@@ -22,7 +22,6 @@ from __future__ import absolute_import ...@@ -22,7 +22,6 @@ from __future__ import absolute_import
import sys import sys
import ctypes import ctypes
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
from . import vmobj as _vmobj
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Runtime Object API"""
from __future__ import absolute_import
import sys
import ctypes
from .base import _FFI_MODE, check_call, _LIB, c_str
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_object
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import _register_object
else:
from ._cy2.core import _set_class_object
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import _register_object
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_object
from ._ctypes.object import ObjectBase as _ObjectBase
from ._ctypes.object import _register_object
class Object(_ObjectBase):
"""Base class for all tvm's runtime objects."""
pass
def register_object(type_key=None):
"""register object type.
Parameters
----------
type_key : str or cls
The type key of the node
Examples
--------
The following code registers MyObject
using type key "test.MyObject"
.. code-block:: python
@tvm.register_object("test.MyObject")
class MyObject(Object):
pass
"""
object_name = type_key if isinstance(type_key, str) else type_key.__name__
def register(cls):
"""internal register function"""
if hasattr(cls, "_type_index"):
tindex = cls._type_index
else:
tidx = ctypes.c_uint()
check_call(_LIB.TVMObjectTypeKey2Index(
c_str(object_name), ctypes.byref(tidx)))
tindex = tidx.value
_register_object(tindex, cls)
return cls
if isinstance(type_key, str):
return register
return register(type_key)
def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
Parameters
----------
obj: object
The original object
elem_getter : function
A simple function that takes index and return a single element.
length : int
The size of the array
idx : int or slice
The argument passed to getitem
Returns
-------
result : object
The result of getitem
"""
if isinstance(idx, slice):
start = idx.start if idx.start is not None else 0
stop = idx.stop if idx.stop is not None else length
step = idx.step if idx.step is not None else 1
if start < 0:
start += length
if stop < 0:
stop += length
return [elem_getter(obj, i) for i in range(start, stop, step)]
if idx < -length or idx >= length:
raise IndexError("Index out of range. size: {}, got index {}"
.format(length, idx))
if idx < 0:
idx += length
return elem_getter(obj, idx)
_set_class_object(Object)
...@@ -42,7 +42,7 @@ class TypeCode(object): ...@@ -42,7 +42,7 @@ class TypeCode(object):
STR = 11 STR = 11
BYTES = 12 BYTES = 12
NDARRAY_CONTAINER = 13 NDARRAY_CONTAINER = 13
OBJECT_CELL = 14 OBJECT_HANDLE = 14
EXT_BEGIN = 15 EXT_BEGIN = 15
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Runtime Object api"""
from __future__ import absolute_import
import sys
from .base import _FFI_MODE
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_object
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import _register_object
else:
from ._cy2.core import _set_class_object
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import _register_object
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_object
from ._ctypes.vmobj import ObjectBase as _ObjectBase
from ._ctypes.vmobj import _register_object
class ObjectTag(object):
"""Type code used in API calls"""
TENSOR = 1
CLOSURE = 2
DATATYPE = 3
class Object(_ObjectBase):
"""The VM Object used in Relay virtual machine."""
def register_object(cls):
_register_object(cls.tag, cls)
return cls
_set_class_object(Object)
...@@ -21,6 +21,7 @@ from __future__ import absolute_import as _abs ...@@ -21,6 +21,7 @@ from __future__ import absolute_import as _abs
from numbers import Integral as _Integral from numbers import Integral as _Integral
from ._ffi.base import string_types from ._ffi.base import string_types
from ._ffi.object import register_object, Object
from ._ffi.node import register_node, NodeBase from ._ffi.node import register_node, NodeBase
from ._ffi.node import convert_to_node as _convert_to_node from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.node_generic import _scalar_type_inference from ._ffi.node_generic import _scalar_type_inference
......
...@@ -30,9 +30,12 @@ from . import _vm ...@@ -30,9 +30,12 @@ from . import _vm
from . import vmobj as _obj from . import vmobj as _obj
from .interpreter import Executor from .interpreter import Executor
Tensor = _obj.Tensor
Datatype = _obj.Datatype
def _convert(arg, cargs): def _convert(arg, cargs):
if isinstance(arg, (np.ndarray, tvm.nd.NDArray)): if isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
cargs.append(_obj.tensor_object(arg)) cargs.append(_obj.Tensor(arg))
elif isinstance(arg, (tuple, list)): elif isinstance(arg, (tuple, list)):
field_args = [] field_args = []
for field in arg: for field in arg:
......
...@@ -18,32 +18,37 @@ ...@@ -18,32 +18,37 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import numpy as _np import numpy as _np
from tvm._ffi.vmobj import Object, ObjectTag, register_object from tvm._ffi.object import Object, register_object, getitem_helper
from tvm import ndarray as _nd from tvm import ndarray as _nd
from . import _vmobj from . import _vmobj
# TODO(@icemelon9): Add ClosureObject
@register_object @register_object("vm.Tensor")
class TensorObject(Object): class Tensor(Object):
"""Tensor object.""" """Tensor object.
tag = ObjectTag.TENSOR
def __init__(self, handle): Parameters
"""Constructs a Tensor object ----------
arr : numpy.ndarray or tvm.nd.NDArray
Parameters The source array.
----------
handle : object
Object handle
Returns ctx : TVMContext, optional
------- The device context to create the array
obj : TensorObject """
A tensor object. def __init__(self, arr, ctx=None):
""" if isinstance(arr, _np.ndarray):
super(TensorObject, self).__init__(handle) ctx = ctx if ctx else _nd.cpu(0)
self.data = _vmobj.GetTensorData(self) self.__init_handle_by_constructor__(
_vmobj.Tensor, _nd.array(arr, ctx=ctx))
elif isinstance(arr, _nd.NDArray):
self.__init_handle_by_constructor__(
_vmobj.Tensor, arr)
else:
raise RuntimeError("Unsupported type for tensor object.")
@property
def data(self):
return _vmobj.GetTensorData(self)
def asnumpy(self): def asnumpy(self):
"""Convert data to numpy array """Convert data to numpy array
...@@ -56,65 +61,34 @@ class TensorObject(Object): ...@@ -56,65 +61,34 @@ class TensorObject(Object):
return self.data.asnumpy() return self.data.asnumpy()
@register_object @register_object("vm.Datatype")
class DatatypeObject(Object): class Datatype(Object):
"""Datatype object.""" """Datatype object.
tag = ObjectTag.DATATYPE
def __init__(self, handle): Parameters
"""Constructs a Datatype object ----------
tag : int
The tag of datatype.
Parameters fields : list[Object] or tuple[Object]
---------- The source tuple.
handle : object """
Object handle def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, Object)
self.__init_handle_by_constructor__(
_vmobj.Datatype, tag, *fields)
Returns @property
------- def tag(self):
obj : DatatypeObject return _vmobj.GetDatatypeTag(self)
A Datatype object.
"""
super(DatatypeObject, self).__init__(handle)
self.tag = _vmobj.GetDatatypeTag(self)
num_fields = _vmobj.GetDatatypeNumberOfFields(self)
self.fields = []
for i in range(num_fields):
self.fields.append(_vmobj.GetDatatypeFields(self, i))
def __getitem__(self, idx): def __getitem__(self, idx):
return self.fields[idx] return getitem_helper(
self, _vmobj.GetDatatypeFields, len(self), idx)
def __len__(self): def __len__(self):
return len(self.fields) return _vmobj.GetDatatypeNumberOfFields(self)
def __iter__(self):
return iter(self.fields)
# TODO(icemelon9): Add closure object
def tensor_object(arr, ctx=_nd.cpu(0)):
"""Create a tensor object from source arr.
Parameters
----------
arr : numpy.ndarray or tvm.nd.NDArray
The source array.
ctx : TVMContext, optional
The device context to create the array
Returns
-------
ret : TensorObject
The created object.
"""
if isinstance(arr, _np.ndarray):
tensor = _vmobj.Tensor(_nd.array(arr, ctx))
elif isinstance(arr, _nd.NDArray):
tensor = _vmobj.Tensor(arr)
else:
raise RuntimeError("Unsupported type for tensor object.")
return tensor
def tuple_object(fields): def tuple_object(fields):
...@@ -127,30 +101,9 @@ def tuple_object(fields): ...@@ -127,30 +101,9 @@ def tuple_object(fields):
Returns Returns
------- -------
ret : DatatypeObject ret : Datatype
The created object. The created object.
""" """
for f in fields: for f in fields:
assert isinstance(f, Object) assert isinstance(f, Object)
return _vmobj.Tuple(*fields) return _vmobj.Tuple(*fields)
def datatype_object(tag, fields):
"""Create a datatype object from tag and source fields.
Parameters
----------
tag : int
The tag of datatype.
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : DatatypeObject
The created object.
"""
for f in fields:
assert isinstance(f, Object)
return _vmobj.Datatype(tag, *fields)
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include "runtime_base.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -184,5 +185,35 @@ std::string Object::TypeIndex2Key(uint32_t tindex) { ...@@ -184,5 +185,35 @@ std::string Object::TypeIndex2Key(uint32_t tindex) {
uint32_t Object::TypeKey2Index(const char* key) { uint32_t Object::TypeKey2Index(const char* key) {
return TypeContext::Global()->TypeKey2Index(key); return TypeContext::Global()->TypeKey2Index(key);
} }
class TVMObjectCAPI {
public:
static void Free(TVMObjectHandle obj) {
static_cast<Object*>(obj)->DecRef();
}
static uint32_t TypeKey2Index(const char* type_key) {
return Object::TypeKey2Index(type_key);
}
};
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) {
API_BEGIN();
out_tindex[0] = static_cast<tvm::runtime::Object*>(obj)->type_index();
API_END();
}
int TVMObjectFree(TVMObjectHandle obj) {
API_BEGIN();
tvm::runtime::TVMObjectCAPI::Free(obj);
API_END();
}
int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) {
API_BEGIN();
out_tindex[0] = tvm::runtime::TVMObjectCAPI::TypeKey2Index(
type_key);
API_END();
}
...@@ -47,9 +47,9 @@ def convert_to_list(x): ...@@ -47,9 +47,9 @@ def convert_to_list(x):
return x return x
def vmobj_to_list(o): def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.vmobj.TensorObject): if isinstance(o, tvm.relay.backend.vmobj.Tensor):
return [o.asnumpy().tolist()] return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject): elif isinstance(o, tvm.relay.backend.vmobj.Datatype):
result = [] result = []
for f in o: for f in o:
result.extend(vmobj_to_list(f)) result.extend(vmobj_to_list(f))
......
...@@ -59,9 +59,9 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): ...@@ -59,9 +59,9 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
return ret return ret
def vmobj_to_list(o): def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.vmobj.TensorObject): if isinstance(o, tvm.relay.backend.vm.Tensor):
return [o.asnumpy().tolist()] return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject): elif isinstance(o, tvm.relay.backend.vm.Datatype):
result = [] result = []
for f in o: for f in o:
result.extend(vmobj_to_list(f)) result.extend(vmobj_to_list(f))
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from tvm.relay import vm
def test_tensor():
arr = tvm.nd.array([1,2,3])
x = vm.Tensor(arr)
assert isinstance(x, vm.Tensor)
assert x.asnumpy()[0] == 1
assert x.asnumpy()[-1] == 3
assert isinstance(x.data, tvm.nd.NDArray)
def test_datatype():
arr = tvm.nd.array([1,2,3])
x = vm.Tensor(arr)
y = vm.Datatype(0, [x, x])
assert len(y) == 2
assert isinstance(y, vm.Datatype)
y[0:1][-1].data == x.data
assert y.tag == 0
assert isinstance(x.data, tvm.nd.NDArray)
if __name__ == "__main__":
test_tensor()
test_datatype()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment