Commit 81334be3 by Junru Shao Committed by Tianqi Chen

[RUNTIME][NDArray] Allowing External Libraries to Subclass NDArrays (#2613)

parent 79abd2c3
......@@ -6,7 +6,7 @@ PKG_CFLAGS = -std=c++11 -O2 -fPIC\
-I${TVM_ROOT}/3rdparty/dlpack/include\
-I${TVM_ROOT}/3rdparty/HalideIR/src
PKG_LDFLAGS =-L${TVM_ROOT}/lib
PKG_LDFLAGS =-L${TVM_ROOT}/build
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S), Darwin)
......
......@@ -31,7 +31,7 @@ class IntVec(object):
def __del__(self):
# You can also call your own customized
# deleter if you can free it via your own FFI.
tvm.nd.free_extension_handle(self.handle, 17)
tvm.nd.free_extension_handle(self.handle, self.__class__._tvm_tcode)
@property
def _tvm_handle(self):
......@@ -42,3 +42,30 @@ class IntVec(object):
# Register IntVec extension on python side.
tvm.register_extension(IntVec, IntVec)
nd_create = tvm.get_global_func("tvm_ext.nd_create")
nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two")
nd_get_addtional_info = tvm.get_global_func("tvm_ext.nd_get_addtional_info")
class NDSubClass(tvm.nd.NDArrayBase):
"""Example for subclassing TVM's NDArray infrastructure.
By inheriting TMV's NDArray, external libraries could
leverage TVM's FFI without any modification.
"""
# Should be consistent with the type-trait set in the backend
_array_type_code = 1
@staticmethod
def create(addtional_info):
return nd_create(addtional_info)
@property
def addtional_info(self):
return nd_get_addtional_info(self)
def __add__(self, other):
return nd_add_two(self, other)
tvm.register_extension(NDSubClass, NDSubClass)
......@@ -7,18 +7,25 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h>
namespace tvm_ext {
using IntVector = std::vector<int>;
class NDSubClass;
} // namespace tvm_ext
namespace tvm {
namespace runtime {
template<>
struct extension_class_info<tvm_ext::IntVector> {
struct extension_type_info<tvm_ext::IntVector> {
static const int code = 17;
};
template<>
struct array_type_info<tvm_ext::NDSubClass> {
static const int code = 1;
};
} // namespace tvm
} // namespace runtime
......@@ -26,6 +33,62 @@ using namespace tvm;
using namespace tvm::runtime;
namespace tvm_ext {
/*!
* \brief A subclass of TVM's NDArray.
*
* To use this extension, an external library should
*
* 1) Inherit TVM's NDArray and NDArray container,
* and define the trait `array_type_info` for this class.
*
* 2) Define a constructor in the inherited class that accepts
* a pointer to TVM's Container, which is nullable.
*
* 3) On Python frontend, inherit `tvm.nd.NDArrayBase`,
* define the class attribute `_array_type_code` consistent to
* the C++ type trait, and register the subclass using `tvm.register_extension`.
*/
class NDSubClass : public tvm::runtime::NDArray {
public:
class SubContainer : public NDArray::Container {
public:
SubContainer(int addtional_info) :
addtional_info_(addtional_info) {
array_type_code_ = array_type_info<NDSubClass>::code;
}
static bool Is(NDArray::Container *container) {
SubContainer *c = static_cast<SubContainer*>(container);
return c->array_type_code_ == array_type_info<NDSubClass>::code;
}
int addtional_info_{0};
};
NDSubClass(NDArray::Container *container) {
if (container == nullptr) {
data_ = nullptr;
return;
}
CHECK(SubContainer::Is(container));
container->IncRef();
data_ = container;
}
~NDSubClass() {
this->reset();
}
NDSubClass AddWith(const NDSubClass &other) const {
SubContainer *a = static_cast<SubContainer*>(data_);
SubContainer *b = static_cast<SubContainer*>(other.data_);
CHECK(a != nullptr && b != nullptr);
return NDSubClass(new SubContainer(a->addtional_info_ + b->addtional_info_));
}
int get_additional_info() const {
SubContainer *self = static_cast<SubContainer*>(data_);
CHECK(self != nullptr);
return self->addtional_info_;
}
};
} // namespace tvm_ext
namespace tvm_ext {
TVM_REGISTER_EXT_TYPE(IntVector);
......@@ -64,6 +127,26 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = (*tvm::runtime::Registry::Get("device_api.cpu"))();
});
TVM_REGISTER_GLOBAL("tvm_ext.nd_create")
.set_body([](TVMArgs args, TVMRetValue *rv) {
int addtional_info = args[0];
*rv = NDSubClass(new NDSubClass::SubContainer(addtional_info));
});
TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two")
.set_body([](TVMArgs args, TVMRetValue *rv) {
NDSubClass a = args[0];
NDSubClass b = args[1];
*rv = a.AddWith(b);
});
TVM_REGISTER_GLOBAL("tvm_ext.nd_get_addtional_info")
.set_body([](TVMArgs args, TVMRetValue *rv) {
NDSubClass a = args[0];
*rv = a.get_additional_info();
});
} // namespace tvm_ext
// External function exposed to runtime.
......
......@@ -32,6 +32,7 @@ def test_sym_add():
c = tvm_ext.sym_add(a, b)
assert c.a == a and c.b == b
def test_ext_vec():
ivec = tvm_ext.ivec_create(1, 2, 3)
assert(isinstance(ivec, tvm_ext.IntVec))
......@@ -44,6 +45,7 @@ def test_ext_vec():
tvm.convert(ivec_cb)(ivec)
def test_extract_ext():
fdict = tvm.extract_ext_funcs(tvm_ext._LIB.TVMExtDeclare)
assert fdict["mul"](3, 4) == 12
......@@ -68,7 +70,21 @@ def test_extern_call():
check_llvm()
def test_nd_subclass():
a = tvm_ext.NDSubClass.create(addtional_info=3)
b = tvm_ext.NDSubClass.create(addtional_info=5)
c = a + b
d = a + a
e = b + b
assert(a.addtional_info == 3)
assert(b.addtional_info == 5)
assert(c.addtional_info == 8)
assert(d.addtional_info == 6)
assert(e.addtional_info == 10)
if __name__ == "__main__":
test_nd_subclass()
test_extern_call()
test_ext_dev()
test_ext_vec()
......
......@@ -178,11 +178,31 @@ class NDArray {
Container* data_{nullptr};
// enable internal functions
friend struct Internal;
friend class TVMPODValue_;
friend class TVMArgValue;
friend class TVMRetValue;
friend class TVMArgsSetter;
};
/*!
* \brief The type trait indicates subclass of TVM's NDArray.
* For irrelavant classes, code = -1.
* For TVM NDArray itself, code = 0.
* All subclasses of NDArray should override code > 0.
*/
template<typename T>
struct array_type_info {
/*! \brief the value of the traits */
static const int code = -1;
};
// Overrides the type trait for tvm's NDArray.
template<>
struct array_type_info<NDArray> {
static const int code = 0;
};
/*!
* \brief Save a DLTensor to stream
* \param strm The outpu stream
* \param tensor The tensor to be saved.
......@@ -196,7 +216,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor);
* the pointer to the NDArrayContainer can be directly
* interpreted as a DLTensor*
*
* \note: do not use this function directly, use NDArray.
* \note do not use this function directly, use NDArray.
*/
class NDArray::Container {
public:
......@@ -228,16 +248,19 @@ class NDArray::Container {
protected:
friend class NDArray;
friend class TVMPODValue_;
friend class TVMArgValue;
friend class TVMRetValue;
friend class RPCWrappedFunc;
/*!
* \brief Type flag used to indicate subclass.
* Default value 0 means normal NDArray::Conatainer.
*
* We can extend a more specialized NDArray::Container
* and use the array_type_index_ to indicate
* and use the array_type_code_ to indicate
* the specific array subclass.
*/
uint32_t array_type_index_{0};
int32_t array_type_code_{0};
/*! \brief The internal reference counter */
std::atomic<int> ref_counter_{0};
/*!
......
......@@ -362,7 +362,7 @@ inline std::string TVMType2String(TVMType t);
* \tparam T the typename
*/
template<typename T>
struct extension_class_info {
struct extension_type_info {
static const int code = 0;
};
......@@ -455,6 +455,15 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
return value_.v_ctx;
}
template<typename TNDArray,
typename = typename std::enable_if<
std::is_base_of<NDArray, TNDArray>::value>::type>
TNDArray AsNDArray() const {
if (type_code_ == kNull) return TNDArray(nullptr);
auto *container = static_cast<NDArray::Container*>(value_.v_handle);
CHECK_EQ(container->array_type_code_, array_type_info<TNDArray>::code);
return TNDArray(container);
}
template<typename TExtension>
const TExtension& AsExtension() const {
CHECK_LT(type_code_, kExtEnd);
......@@ -561,7 +570,7 @@ class TVMArgValue : public TVMPODValue_ {
inline TNodeRef AsNodeRef() const;
template<typename T,
typename = typename std::enable_if<
std::is_class<T>::value>::type>
std::is_class<T>::value>::type>
inline operator T() const;
template<typename TNodeRef,
typename = typename std::enable_if<
......@@ -727,10 +736,10 @@ class TVMRetValue : public TVMPODValue_ {
}
template<typename T,
typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
extension_type_info<T>::code != 0>::type>
TVMRetValue& operator=(const T& other) {
this->SwitchToClass<T>(
extension_class_info<T>::code, other);
extension_type_info<T>::code, other);
return *this;
}
/*!
......@@ -1094,7 +1103,7 @@ class TVMArgsSetter {
// extension
template<typename T,
typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
extension_type_info<T>::code != 0>::type>
inline void operator()(size_t i, const T& value) const;
// NodeRef related extenstions: in tvm/packed_func_ext.h
inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*)
......@@ -1212,40 +1221,53 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
// extension and node type handling
namespace detail {
template<typename T, typename TSrc, bool is_ext>
template<typename T, typename TSrc, bool is_ext, bool is_nd>
struct TVMValueCast {
static T Apply(const TSrc* self) {
static_assert(!is_ext && !is_nd, "The default case accepts only non-extensions");
return self->template AsNodeRef<T>();
}
};
template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, true> {
struct TVMValueCast<T, TSrc, true, false> {
static T Apply(const TSrc* self) {
return self->template AsExtension<T>();
}
};
template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, false, true> {
static T Apply(const TSrc* self) {
return self->template AsNDArray<T>();
}
};
} // namespace detail
template<typename T, typename>
inline TVMArgValue::operator T() const {
return detail::
TVMValueCast<T, TVMArgValue, extension_class_info<T>::code != 0>
TVMValueCast<T, TVMArgValue,
(extension_type_info<T>::code != 0),
(array_type_info<T>::code > 0)>
::Apply(this);
}
template<typename T, typename>
inline TVMRetValue::operator T() const {
return detail::
TVMValueCast<T, TVMRetValue, extension_class_info<T>::code != 0>
TVMValueCast<T, TVMRetValue,
(extension_type_info<T>::code != 0),
(array_type_info<T>::code > 0)>
::Apply(this);
}
template<typename T, typename>
inline void TVMArgsSetter::operator()(size_t i, const T& value) const {
static_assert(extension_class_info<T>::code != 0,
static_assert(extension_type_info<T>::code != 0,
"Need to have extesion code");
type_codes_[i] = extension_class_info<T>::code;
type_codes_[i] = extension_type_info<T>::code;
values_[i].v_handle = const_cast<T*>(&value);
}
......@@ -1262,9 +1284,9 @@ struct ExtTypeInfo {
template<typename T>
inline ExtTypeVTable* ExtTypeVTable::Register_() {
const int code = extension_class_info<T>::code;
const int code = extension_type_info<T>::code;
static_assert(code != 0,
"require extension_class_info traits to be declared with non-zero code");
"require extension_type_info traits to be declared with non-zero code");
ExtTypeVTable vt;
vt.clone = ExtTypeInfo<T>::clone;
vt.destroy = ExtTypeInfo<T>::destroy;
......
......@@ -133,7 +133,7 @@ class Registry {
/*!
* \brief Macro to register extension type.
* This must be registered in a cc file
* after the trait extension_class_info is defined.
* after the trait extension_type_info is defined.
*/
#define TVM_REGISTER_EXT_TYPE(T) \
TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) = \
......
......@@ -40,17 +40,17 @@ namespace tvm {
namespace runtime {
template<>
struct extension_class_info<nnvm::Symbol> {
struct extension_type_info<nnvm::Symbol> {
static const int code = 16;
};
template<>
struct extension_class_info<nnvm::Graph> {
struct extension_type_info<nnvm::Graph> {
static const int code = 17;
};
template<>
struct extension_class_info<nnvm::compiler::AttrDict> {
struct extension_type_info<nnvm::compiler::AttrDict> {
static const int code = 18;
};
......
......@@ -76,8 +76,8 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._register_alter_op_layout")
if (ret.type_code() == TVMTypeCode::kNull) {
return false;
}
CHECK_EQ(ret.type_code(), tvm::runtime::extension_class_info<Symbol>::code)
<< " expected " << "Symbol (code = " << tvm::runtime::extension_class_info<Symbol>::code
CHECK_EQ(ret.type_code(), tvm::runtime::extension_type_info<Symbol>::code)
<< " expected " << "Symbol (code = " << tvm::runtime::extension_type_info<Symbol>::code
<< ") but get code = " << ret.type_code();
*ret_symbol = *(static_cast<Symbol*>(ret.value().v_handle));
return true;
......
......@@ -223,13 +223,13 @@ def _handle_return_func(x):
_node.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False)
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True)
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
_return_module, TypeCode.MODULE_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(x.v_handle, True)
C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False)
C_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(x.v_handle, True, False)
C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True)
_CLASS_MODULE = None
_CLASS_FUNCTION = None
......
......@@ -4,7 +4,7 @@ from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call, c_str
from ..runtime_ctypes import TVMArrayHandle
from ..runtime_ctypes import TVMArrayHandle, TVMNDArrayContainerHandle
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle
......@@ -28,7 +28,7 @@ def _from_dlpack(dltensor):
check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle)))
ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0))
return _make_array(handle, False)
return _make_array(handle, False, False)
raise ValueError("Expect a dltensor field, PyCapsule can only be consumed once")
......@@ -77,9 +77,15 @@ class NDArrayBase(object):
return ctypes.pythonapi.PyCapsule_New(handle, _c_str_dltensor, _c_dlpack_deleter)
def _make_array(handle, is_view):
def _make_array(handle, is_view, is_container):
global _TVM_ND_CLS
handle = ctypes.cast(handle, TVMArrayHandle)
return _CLASS_NDARRAY(handle, is_view)
fcreate = _CLASS_NDARRAY
if is_container and _TVM_ND_CLS:
array_type_info = ctypes.cast(handle, TVMNDArrayContainerHandle).array_type_info.value
if array_type_info > 0:
fcreate = _TVM_ND_CLS[array_type_info]
return fcreate(handle, is_view)
_TVM_COMPATS = ()
......@@ -91,6 +97,11 @@ def _reg_extension(cls, fcreate):
RETURN_SWITCH[cls._tvm_tcode] = fret
C_TO_PY_ARG_SWITCH[cls._tvm_tcode] = _wrap_arg_func(fret, cls._tvm_tcode)
_TVM_ND_CLS = {}
def _reg_ndarray(cls, fcreate):
global _TVM_ND_CLS
_TVM_ND_CLS[cls._array_type_code] = fcreate
_CLASS_NDARRAY = None
......
......@@ -2,7 +2,7 @@ from ..base import TVMError
from libcpp.vector cimport vector
from cpython.version cimport PY_MAJOR_VERSION
from cpython cimport pycapsule
from libc.stdint cimport int64_t, uint64_t, uint8_t, uint16_t
from libc.stdint cimport int32_t, int64_t, uint64_t, uint8_t, uint16_t
import ctypes
cdef enum TVMTypeCode:
......@@ -61,6 +61,14 @@ ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle
ctypedef void* NodeHandle
ctypedef struct TVMNDArrayContainer:
DLTensor dl_tensor
void* manager_ctx
void (*deleter)(DLManagedTensor* self)
int32_t array_type_info
ctypedef TVMNDArrayContainer* TVMNDArrayContainerHandle
ctypedef int (*TVMPackedCFunc)(
TVMValue* args,
int* type_codes,
......
......@@ -33,7 +33,7 @@ cdef int tvm_callback(TVMValue* args,
if tcode != kArrayHandle:
pyargs.append(make_ret(value, tcode))
else:
pyargs.append(c_make_array(value.v_handle, True))
pyargs.append(c_make_array(value.v_handle, True, False))
try:
rv = local_pyfunc(*pyargs)
except Exception:
......@@ -175,7 +175,7 @@ cdef inline object make_ret(TVMValue value, int tcode):
elif tcode == kFloat:
return value.v_float64
elif tcode == kNDArrayContainer:
return c_make_array(value.v_handle, False)
return c_make_array(value.v_handle, False, True)
elif tcode == kStr:
return py_str(value.v_str)
elif tcode == kBytes:
......
......@@ -20,7 +20,7 @@ def _from_dlpack(object dltensor):
# set name and destructor to be empty
pycapsule.PyCapsule_SetDestructor(dltensor, NULL)
pycapsule.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
return c_make_array(chandle, 0)
return c_make_array(chandle, False, False)
raise ValueError("Expect a dltensor field, pycapsule.PyCapsule can only be consumed once")
......@@ -73,8 +73,15 @@ cdef class NDArrayBase:
return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter)
cdef c_make_array(void* chandle, is_view):
ret = _CLASS_NDARRAY(None, is_view)
cdef c_make_array(void* chandle, is_view, is_container):
global _TVM_ND_CLS
cdef int32_t array_type_info
fcreate = _CLASS_NDARRAY
if is_container and len(_TVM_ND_CLS) > 0:
array_type_info = (<TVMNDArrayContainerHandle>chandle).array_type_info
if array_type_info > 0:
fcreate = _TVM_ND_CLS[array_type_info]
ret = fcreate(None, is_view)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
return ret
......@@ -89,11 +96,16 @@ def _reg_extension(cls, fcreate):
if fcreate:
_TVM_EXT_RET[cls._tvm_tcode] = fcreate
cdef _TVM_ND_CLS = {}
def _make_array(handle, is_view):
def _reg_ndarray(cls, fcreate):
global _TVM_ND_CLS
_TVM_ND_CLS[cls._array_type_code] = fcreate
def _make_array(handle, is_view, is_container):
cdef unsigned long long ptr
ptr = ctypes.cast(handle, ctypes.c_void_p).value
return c_make_array(<void*>ptr, is_view)
return c_make_array(<void*>ptr, is_view, is_container)
cdef object _CLASS_NDARRAY = None
......
......@@ -17,15 +17,18 @@ try:
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack
from ._cy3.core import NDArrayBase as _NDArrayBase
from ._cy3.core import _reg_extension, _reg_ndarray
else:
from ._cy2.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._cy2.core import _set_class_ndarray, _make_array, _from_dlpack
from ._cy2.core import NDArrayBase as _NDArrayBase
from ._cy2.core import _reg_extension, _reg_ndarray
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.ndarray import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
from ._ctypes.ndarray import _reg_extension, _reg_ndarray
def context(dev_type, dev_id=0):
......@@ -111,7 +114,7 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
ctx.device_type,
ctx.device_id,
ctypes.byref(handle)))
return _make_array(handle, False)
return _make_array(handle, False, False)
def from_dlpack(dltensor):
......@@ -295,6 +298,7 @@ def free_extension_handle(handle, type_code):
"""
check_call(_LIB.TVMExtTypeFree(handle, ctypes.c_int(type_code)))
def register_extension(cls, fcreate=None):
"""Register a extension class to TVM.
......@@ -306,21 +310,26 @@ def register_extension(cls, fcreate=None):
cls : class
The class object to be registered as extension.
fcreate : function, optional
The creation function to create a class object given handle value.
Note
----
The registered class is requires one property: _tvm_handle and a class attribute _tvm_tcode.
The registered class is requires one property: _tvm_handle.
If the registered class is a subclass of NDArray,
it is required to have a class attribute _array_type_code.
Otherwise, it is required to have a class attribute _tvm_tcode.
- ```_tvm_handle``` returns integer represents the address of the handle.
- ```_tvm_tcode``` gives integer represents type code of the class.
- ```_tvm_tcode``` or ```_array_type_code``` gives integer represents type
code of the class.
Returns
-------
cls : class
The class being registered.
fcreate : function, optional
The creation function to create a class object given handle value.
Example
-------
The following code registers user defined class
......@@ -339,7 +348,13 @@ def register_extension(cls, fcreate=None):
def _tvm_handle(self):
return self.handle.value
"""
if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin")
_reg_extension(cls, fcreate)
if issubclass(cls, _NDArrayBase):
assert fcreate is not None
assert hasattr(cls, "_array_type_code")
_reg_ndarray(cls, fcreate)
else:
assert hasattr(cls, "_tvm_tcode")
if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin")
_reg_extension(cls, fcreate)
return cls
......@@ -240,3 +240,12 @@ class TVMArray(ctypes.Structure):
("byte_offset", ctypes.c_uint64)]
TVMArrayHandle = ctypes.POINTER(TVMArray)
class TVMNDArrayContainer(ctypes.Structure):
"""TVM NDArray::Container"""
_fields_ = [("dl_tensor", TVMArray),
("manager_ctx", ctypes.c_void_p),
("deleter", ctypes.c_void_p),
("array_type_info", ctypes.c_int32)]
TVMNDArrayContainerHandle = ctypes.POINTER(TVMNDArrayContainer)
......@@ -15,7 +15,7 @@ from ._ffi.ndarray import register_extension, free_extension_handle
class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime.
Strictly this is only an Array Container(a buffer object)
Strictly this is only an Array Container (a buffer object)
No arthimetic operations are defined.
All operations are performed by TVM functions.
......
......@@ -168,7 +168,7 @@ namespace tvm {
namespace runtime {
template<>
struct extension_class_info<test::IntVector> {
struct extension_type_info<test::IntVector> {
static const int code = kExtBegin + 1;
};
} // runtime
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment