Unverified Commit 02c1e117 by Tianqi Chen Committed by GitHub

[RUNTIME] Refactor object python FFI to new protocol. (#4128)

* [RUNTIME] Refactor object python FFI to new protocol.

This is a pre-req to bring the Node system under object protocol.
Most of the code reflects the current code in the Node system.

- Use new instead of init so subclass can define their own constructors
- Allow register via name, besides type idnex
- Introduce necessary runtime C API functions
- Refactored Tensor and Datatype to directly use constructor.

* address review comments
parent e3fbdc8b
......@@ -104,7 +104,7 @@ typedef enum {
kStr = 11U,
kBytes = 12U,
kNDArrayContainer = 13U,
kObjectCell = 14U,
kObjectHandle = 14U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
......@@ -549,13 +549,31 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type,
TVMStreamHandle dst);
/*!
* \brief Get the tag from an object.
* \brief Get the type_index from an object.
*
* \param obj The object handle.
* \param tag The tag of object.
* \param out_tindex the output type index.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMGetObjectTag(TVMObjectHandle obj, int* tag);
TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex);
/*!
* \brief Convert type key to type index.
* \param type_key The key of the type.
* \param out_tindex the corresponding type index.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex);
/*!
* \brief Free the object.
*
* \param obj The object handle.
* \note Internally we decrease the reference counter of the object.
* The object will be freed when every reference to the object are removed.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMObjectFree(TVMObjectHandle obj);
#ifdef __cplusplus
} // TVM_EXTERN_C
......
......@@ -253,6 +253,7 @@ class Object {
template<typename>
friend class ObjectPtr;
friend class TVMRetValue;
friend class TVMObjectCAPI;
};
/*!
......
......@@ -491,7 +491,7 @@ class TVMPODValue_ {
}
operator ObjectRef() const {
if (type_code_ == kNull) return ObjectRef(ObjectPtr<Object>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kObjectCell);
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
return ObjectRef(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
}
operator TVMContext() const {
......@@ -761,7 +761,7 @@ class TVMRetValue : public TVMPODValue_ {
}
TVMRetValue& operator=(ObjectRef other) {
this->Clear();
type_code_ = kObjectCell;
type_code_ = kObjectHandle;
// move the handle out
value_.v_handle = other.data_.data_;
other.data_.data_ = nullptr;
......@@ -862,7 +862,7 @@ class TVMRetValue : public TVMPODValue_ {
kNodeHandle, *other.template ptr<NodePtr<Node> >());
break;
}
case kObjectCell: {
case kObjectHandle: {
*this = other.operator ObjectRef();
break;
}
......@@ -913,7 +913,7 @@ class TVMRetValue : public TVMPODValue_ {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
break;
}
case kObjectCell: {
case kObjectHandle: {
static_cast<Object*>(value_.v_handle)->DecRef();
break;
}
......@@ -946,7 +946,7 @@ inline const char* TypeCode2Str(int type_code) {
case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle";
case kNDArrayContainer: return "NDArrayContainer";
case kObjectCell: return "ObjectCell";
case kObjectHandle: return "ObjectCell";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
......@@ -1164,7 +1164,7 @@ class TVMArgsSetter {
}
void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*)
values_[i].v_handle = value.data_.data_;
type_codes_[i] = kObjectCell;
type_codes_[i] = kObjectHandle;
}
void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*)
if (value.type_code() == kStr) {
......
......@@ -33,6 +33,7 @@ from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
from .node import NodeBase
from . import object as _object
from . import node as _node
FunctionHandle = ctypes.c_void_p
......@@ -165,7 +166,7 @@ def _make_tvm_args(args, temp_args):
temp_args.append(arg)
elif isinstance(arg, _CLASS_OBJECT):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.OBJECT_CELL
type_codes[i] = TypeCode.OBJECT_HANDLE
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return values, type_codes, num_args
......@@ -225,7 +226,7 @@ def __init_handle_by_constructor__(fconstructor, args):
raise get_last_ffi_error()
_ = temp_args
_ = args
assert ret_tcode.value == TypeCode.NODE_HANDLE
assert ret_tcode.value in (TypeCode.NODE_HANDLE, TypeCode.OBJECT_HANDLE)
handle = ret_val.v_handle
return handle
......@@ -247,6 +248,7 @@ def _handle_return_func(x):
# setup return handle for function type
_node.__init_by_constructor__ = __init_handle_by_constructor__
_object.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True)
......
......@@ -20,9 +20,11 @@ from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call
from .types import TypeCode, RETURN_SWITCH
from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
ObjectHandle = ctypes.c_void_p
__init_by_constructor__ = None
"""Maps object type to its constructor"""
OBJECT_TYPE = {}
......@@ -36,17 +38,48 @@ def _return_object(x):
handle = x.v_handle
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
tag = ctypes.c_int()
check_call(_LIB.TVMGetObjectTag(handle, ctypes.byref(tag)))
cls = OBJECT_TYPE.get(tag.value, ObjectBase)
obj = cls(handle)
tindex = ctypes.c_uint()
check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
cls = OBJECT_TYPE.get(tindex.value, ObjectBase)
# Avoid calling __init__ of cls, instead directly call __new__
# This allows child class to implement their own __init__
obj = cls.__new__(cls)
obj.handle = handle
return obj
RETURN_SWITCH[TypeCode.OBJECT_CELL] = _return_object
RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object
C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func(
_return_object, TypeCode.OBJECT_HANDLE)
class ObjectBase(object):
"""Base object for all object types"""
__slots__ = ["handle"]
def __init__(self, handle):
def __del__(self):
if _LIB is not None:
check_call(_LIB.TVMObjectFree(self.handle))
def __init_handle_by_constructor__(self, fconstructor, *args):
"""Initialize the handle by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return handle is directly set into the Node object
instead of creating a new Node.
"""
# assign handle first to avoid error raising
self.handle = None
handle = __init_by_constructor__(fconstructor, args)
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
self.handle = handle
......@@ -37,7 +37,7 @@ cdef enum TVMTypeCode:
kStr = 11
kBytes = 12
kNDArrayContainer = 13
kObjectCell = 14
kObjectHandle = 14
kExtBegin = 15
cdef extern from "tvm/runtime/c_runtime_api.h":
......@@ -130,7 +130,9 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
int TVMArrayToDLPack(DLTensorHandle arr_from,
DLManagedTensor** out)
void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
int TVMGetObjectTag(ObjectHandle obj, int* tag)
int TVMObjectFree(ObjectHandle obj)
int TVMObjectGetTypeIndex(ObjectHandle obj, unsigned* out_index)
cdef extern from "tvm/c_dsl_api.h":
int TVMNodeFree(NodeHandle handle)
......
......@@ -16,7 +16,8 @@
# under the License.
include "./base.pxi"
include "./object.pxi"
include "./node.pxi"
include "./function.pxi"
include "./ndarray.pxi"
include "./vmobj.pxi"
......@@ -44,7 +44,7 @@ cdef int tvm_callback(TVMValue* args,
if (tcode == kNodeHandle or
tcode == kFuncHandle or
tcode == kModuleHandle or
tcode == kObjectCell or
tcode == kObjectHandle or
tcode > kExtBegin):
CALL(TVMCbArgToReturn(&value, tcode))
......@@ -155,12 +155,12 @@ cdef inline int make_arg(object arg,
value[0].v_handle = (<NodeBase>arg).chandle
tcode[0] = kNodeHandle
temp_args.append(arg)
elif isinstance(arg, _CLASS_OBJECT):
value[0].v_handle = (<ObjectBase>arg).chandle
tcode[0] = kObjectHandle
elif isinstance(arg, _CLASS_MODULE):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kModuleHandle
elif isinstance(arg, _CLASS_OBJECT):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kObjectCell
elif isinstance(arg, FunctionBase):
value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle
......@@ -190,6 +190,8 @@ cdef inline object make_ret(TVMValue value, int tcode):
"""convert result to return value."""
if tcode == kNodeHandle:
return make_ret_node(value.v_handle)
elif tcode == kObjectHandle:
return make_ret_object(value.v_handle)
elif tcode == kNull:
return None
elif tcode == kInt:
......@@ -212,8 +214,6 @@ cdef inline object make_ret(TVMValue value, int tcode):
fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle
return fobj
elif tcode == kObjectCell:
return make_ret_object(value.v_handle)
elif tcode in _TVM_EXT_RET:
return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
......
......@@ -19,7 +19,7 @@
OBJECT_TYPE = []
def _register_object(int index, object cls):
"""register node class"""
"""register object class"""
while len(OBJECT_TYPE) <= index:
OBJECT_TYPE.append(None)
OBJECT_TYPE[index] = cls
......@@ -27,41 +27,70 @@ def _register_object(int index, object cls):
cdef inline object make_ret_object(void* chandle):
global OBJECT_TYPE
cdef int tag
cdef unsigned tindex
cdef list object_type
cdef object cls
cdef object handle
object_type = OBJECT_TYPE
handle = ctypes_handle(chandle)
CALL(TVMGetObjectTag(chandle, &tag))
if tag < len(object_type):
cls = object_type[tag]
CALL(TVMObjectGetTypeIndex(chandle, &tindex))
if tindex < len(object_type):
cls = object_type[tindex]
if cls is not None:
obj = cls(handle)
obj = cls.__new__(cls)
else:
obj = ObjectBase(handle)
obj = ObjectBase.__new__(ObjectBase)
else:
obj = ObjectBase(handle)
obj = ObjectBase.__new__(ObjectBase)
(<ObjectBase>obj).chandle = chandle
return obj
cdef class ObjectBase:
cdef ObjectHandle chandle
cdef void* chandle
cdef inline _set_handle(self, handle):
cdef unsigned long long ptr
if handle is None:
self.chandle = NULL
else:
self.chandle = c_handle(handle)
ptr = handle.value
self.chandle = <void*>(ptr)
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
return ctypes_handle(self.chandle)
def __set__(self, value):
self._set_handle(value)
def __init__(self, handle):
self._set_handle(handle)
def __dealloc__(self):
CALL(TVMObjectFree(self.chandle))
def __init_handle_by_constructor__(self, fconstructor, *args):
"""Initialize the handle by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return handle is directly set into the Node object
instead of creating a new Node.
"""
# avoid error raised during construction.
self.chandle = NULL
cdef void* chandle
ConstructorCall(
(<FunctionBase>fconstructor).chandle,
kObjectHandle, args, &chandle)
self.chandle = chandle
......@@ -22,7 +22,6 @@ from __future__ import absolute_import
import sys
import ctypes
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
from . import vmobj as _vmobj
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Runtime Object API"""
from __future__ import absolute_import
import sys
import ctypes
from .base import _FFI_MODE, check_call, _LIB, c_str
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_object
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import _register_object
else:
from ._cy2.core import _set_class_object
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import _register_object
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_object
from ._ctypes.object import ObjectBase as _ObjectBase
from ._ctypes.object import _register_object
class Object(_ObjectBase):
"""Base class for all tvm's runtime objects."""
pass
def register_object(type_key=None):
"""register object type.
Parameters
----------
type_key : str or cls
The type key of the node
Examples
--------
The following code registers MyObject
using type key "test.MyObject"
.. code-block:: python
@tvm.register_object("test.MyObject")
class MyObject(Object):
pass
"""
object_name = type_key if isinstance(type_key, str) else type_key.__name__
def register(cls):
"""internal register function"""
if hasattr(cls, "_type_index"):
tindex = cls._type_index
else:
tidx = ctypes.c_uint()
check_call(_LIB.TVMObjectTypeKey2Index(
c_str(object_name), ctypes.byref(tidx)))
tindex = tidx.value
_register_object(tindex, cls)
return cls
if isinstance(type_key, str):
return register
return register(type_key)
def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
Parameters
----------
obj: object
The original object
elem_getter : function
A simple function that takes index and return a single element.
length : int
The size of the array
idx : int or slice
The argument passed to getitem
Returns
-------
result : object
The result of getitem
"""
if isinstance(idx, slice):
start = idx.start if idx.start is not None else 0
stop = idx.stop if idx.stop is not None else length
step = idx.step if idx.step is not None else 1
if start < 0:
start += length
if stop < 0:
stop += length
return [elem_getter(obj, i) for i in range(start, stop, step)]
if idx < -length or idx >= length:
raise IndexError("Index out of range. size: {}, got index {}"
.format(length, idx))
if idx < 0:
idx += length
return elem_getter(obj, idx)
_set_class_object(Object)
......@@ -42,7 +42,7 @@ class TypeCode(object):
STR = 11
BYTES = 12
NDARRAY_CONTAINER = 13
OBJECT_CELL = 14
OBJECT_HANDLE = 14
EXT_BEGIN = 15
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Runtime Object api"""
from __future__ import absolute_import
import sys
from .base import _FFI_MODE
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_object
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import _register_object
else:
from ._cy2.core import _set_class_object
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import _register_object
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_object
from ._ctypes.vmobj import ObjectBase as _ObjectBase
from ._ctypes.vmobj import _register_object
class ObjectTag(object):
"""Type code used in API calls"""
TENSOR = 1
CLOSURE = 2
DATATYPE = 3
class Object(_ObjectBase):
"""The VM Object used in Relay virtual machine."""
def register_object(cls):
_register_object(cls.tag, cls)
return cls
_set_class_object(Object)
......@@ -21,6 +21,7 @@ from __future__ import absolute_import as _abs
from numbers import Integral as _Integral
from ._ffi.base import string_types
from ._ffi.object import register_object, Object
from ._ffi.node import register_node, NodeBase
from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.node_generic import _scalar_type_inference
......
......@@ -30,9 +30,12 @@ from . import _vm
from . import vmobj as _obj
from .interpreter import Executor
Tensor = _obj.Tensor
Datatype = _obj.Datatype
def _convert(arg, cargs):
if isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
cargs.append(_obj.tensor_object(arg))
cargs.append(_obj.Tensor(arg))
elif isinstance(arg, (tuple, list)):
field_args = []
for field in arg:
......
......@@ -18,32 +18,37 @@
from __future__ import absolute_import as _abs
import numpy as _np
from tvm._ffi.vmobj import Object, ObjectTag, register_object
from tvm._ffi.object import Object, register_object, getitem_helper
from tvm import ndarray as _nd
from . import _vmobj
# TODO(@icemelon9): Add ClosureObject
@register_object
class TensorObject(Object):
"""Tensor object."""
tag = ObjectTag.TENSOR
@register_object("vm.Tensor")
class Tensor(Object):
"""Tensor object.
def __init__(self, handle):
"""Constructs a Tensor object
Parameters
----------
handle : object
Object handle
Parameters
----------
arr : numpy.ndarray or tvm.nd.NDArray
The source array.
Returns
-------
obj : TensorObject
A tensor object.
"""
super(TensorObject, self).__init__(handle)
self.data = _vmobj.GetTensorData(self)
ctx : TVMContext, optional
The device context to create the array
"""
def __init__(self, arr, ctx=None):
if isinstance(arr, _np.ndarray):
ctx = ctx if ctx else _nd.cpu(0)
self.__init_handle_by_constructor__(
_vmobj.Tensor, _nd.array(arr, ctx=ctx))
elif isinstance(arr, _nd.NDArray):
self.__init_handle_by_constructor__(
_vmobj.Tensor, arr)
else:
raise RuntimeError("Unsupported type for tensor object.")
@property
def data(self):
return _vmobj.GetTensorData(self)
def asnumpy(self):
"""Convert data to numpy array
......@@ -56,65 +61,34 @@ class TensorObject(Object):
return self.data.asnumpy()
@register_object
class DatatypeObject(Object):
"""Datatype object."""
tag = ObjectTag.DATATYPE
@register_object("vm.Datatype")
class Datatype(Object):
"""Datatype object.
def __init__(self, handle):
"""Constructs a Datatype object
Parameters
----------
tag : int
The tag of datatype.
Parameters
----------
handle : object
Object handle
fields : list[Object] or tuple[Object]
The source tuple.
"""
def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, Object)
self.__init_handle_by_constructor__(
_vmobj.Datatype, tag, *fields)
Returns
-------
obj : DatatypeObject
A Datatype object.
"""
super(DatatypeObject, self).__init__(handle)
self.tag = _vmobj.GetDatatypeTag(self)
num_fields = _vmobj.GetDatatypeNumberOfFields(self)
self.fields = []
for i in range(num_fields):
self.fields.append(_vmobj.GetDatatypeFields(self, i))
@property
def tag(self):
return _vmobj.GetDatatypeTag(self)
def __getitem__(self, idx):
return self.fields[idx]
return getitem_helper(
self, _vmobj.GetDatatypeFields, len(self), idx)
def __len__(self):
return len(self.fields)
def __iter__(self):
return iter(self.fields)
# TODO(icemelon9): Add closure object
def tensor_object(arr, ctx=_nd.cpu(0)):
"""Create a tensor object from source arr.
Parameters
----------
arr : numpy.ndarray or tvm.nd.NDArray
The source array.
ctx : TVMContext, optional
The device context to create the array
Returns
-------
ret : TensorObject
The created object.
"""
if isinstance(arr, _np.ndarray):
tensor = _vmobj.Tensor(_nd.array(arr, ctx))
elif isinstance(arr, _nd.NDArray):
tensor = _vmobj.Tensor(arr)
else:
raise RuntimeError("Unsupported type for tensor object.")
return tensor
return _vmobj.GetDatatypeNumberOfFields(self)
def tuple_object(fields):
......@@ -127,30 +101,9 @@ def tuple_object(fields):
Returns
-------
ret : DatatypeObject
ret : Datatype
The created object.
"""
for f in fields:
assert isinstance(f, Object)
return _vmobj.Tuple(*fields)
def datatype_object(tag, fields):
"""Create a datatype object from tag and source fields.
Parameters
----------
tag : int
The tag of datatype.
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : DatatypeObject
The created object.
"""
for f in fields:
assert isinstance(f, Object)
return _vmobj.Datatype(tag, *fields)
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
......@@ -26,6 +26,7 @@
#include <string>
#include <vector>
#include <unordered_map>
#include "runtime_base.h"
namespace tvm {
namespace runtime {
......@@ -184,5 +185,35 @@ std::string Object::TypeIndex2Key(uint32_t tindex) {
uint32_t Object::TypeKey2Index(const char* key) {
return TypeContext::Global()->TypeKey2Index(key);
}
class TVMObjectCAPI {
public:
static void Free(TVMObjectHandle obj) {
static_cast<Object*>(obj)->DecRef();
}
static uint32_t TypeKey2Index(const char* type_key) {
return Object::TypeKey2Index(type_key);
}
};
} // namespace runtime
} // namespace tvm
int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) {
API_BEGIN();
out_tindex[0] = static_cast<tvm::runtime::Object*>(obj)->type_index();
API_END();
}
int TVMObjectFree(TVMObjectHandle obj) {
API_BEGIN();
tvm::runtime::TVMObjectCAPI::Free(obj);
API_END();
}
int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) {
API_BEGIN();
out_tindex[0] = tvm::runtime::TVMObjectCAPI::TypeKey2Index(
type_key);
API_END();
}
......@@ -47,9 +47,9 @@ def convert_to_list(x):
return x
def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.vmobj.TensorObject):
if isinstance(o, tvm.relay.backend.vmobj.Tensor):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject):
elif isinstance(o, tvm.relay.backend.vmobj.Datatype):
result = []
for f in o:
result.extend(vmobj_to_list(f))
......
......@@ -59,9 +59,9 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
return ret
def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.vmobj.TensorObject):
if isinstance(o, tvm.relay.backend.vm.Tensor):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject):
elif isinstance(o, tvm.relay.backend.vm.Datatype):
result = []
for f in o:
result.extend(vmobj_to_list(f))
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from tvm.relay import vm
def test_tensor():
arr = tvm.nd.array([1,2,3])
x = vm.Tensor(arr)
assert isinstance(x, vm.Tensor)
assert x.asnumpy()[0] == 1
assert x.asnumpy()[-1] == 3
assert isinstance(x.data, tvm.nd.NDArray)
def test_datatype():
arr = tvm.nd.array([1,2,3])
x = vm.Tensor(arr)
y = vm.Datatype(0, [x, x])
assert len(y) == 2
assert isinstance(y, vm.Datatype)
y[0:1][-1].data == x.data
assert y.tag == 0
assert isinstance(x.data, tvm.nd.NDArray)
if __name__ == "__main__":
test_tensor()
test_datatype()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment