Commit 81334be3 by Junru Shao Committed by Tianqi Chen

[RUNTIME][NDArray] Allowing External Libraries to Subclass NDArrays (#2613)

parent 79abd2c3
...@@ -6,7 +6,7 @@ PKG_CFLAGS = -std=c++11 -O2 -fPIC\ ...@@ -6,7 +6,7 @@ PKG_CFLAGS = -std=c++11 -O2 -fPIC\
-I${TVM_ROOT}/3rdparty/dlpack/include\ -I${TVM_ROOT}/3rdparty/dlpack/include\
-I${TVM_ROOT}/3rdparty/HalideIR/src -I${TVM_ROOT}/3rdparty/HalideIR/src
PKG_LDFLAGS =-L${TVM_ROOT}/lib PKG_LDFLAGS =-L${TVM_ROOT}/build
UNAME_S := $(shell uname -s) UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S), Darwin) ifeq ($(UNAME_S), Darwin)
......
...@@ -31,7 +31,7 @@ class IntVec(object): ...@@ -31,7 +31,7 @@ class IntVec(object):
def __del__(self): def __del__(self):
# You can also call your own customized # You can also call your own customized
# deleter if you can free it via your own FFI. # deleter if you can free it via your own FFI.
tvm.nd.free_extension_handle(self.handle, 17) tvm.nd.free_extension_handle(self.handle, self.__class__._tvm_tcode)
@property @property
def _tvm_handle(self): def _tvm_handle(self):
...@@ -42,3 +42,30 @@ class IntVec(object): ...@@ -42,3 +42,30 @@ class IntVec(object):
# Register IntVec extension on python side. # Register IntVec extension on python side.
tvm.register_extension(IntVec, IntVec) tvm.register_extension(IntVec, IntVec)
nd_create = tvm.get_global_func("tvm_ext.nd_create")
nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two")
nd_get_addtional_info = tvm.get_global_func("tvm_ext.nd_get_addtional_info")
class NDSubClass(tvm.nd.NDArrayBase):
"""Example for subclassing TVM's NDArray infrastructure.
By inheriting TMV's NDArray, external libraries could
leverage TVM's FFI without any modification.
"""
# Should be consistent with the type-trait set in the backend
_array_type_code = 1
@staticmethod
def create(addtional_info):
return nd_create(addtional_info)
@property
def addtional_info(self):
return nd_get_addtional_info(self)
def __add__(self, other):
return nd_add_two(self, other)
tvm.register_extension(NDSubClass, NDSubClass)
...@@ -7,18 +7,25 @@ ...@@ -7,18 +7,25 @@
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h>
namespace tvm_ext { namespace tvm_ext {
using IntVector = std::vector<int>; using IntVector = std::vector<int>;
class NDSubClass;
} // namespace tvm_ext } // namespace tvm_ext
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
template<> template<>
struct extension_class_info<tvm_ext::IntVector> { struct extension_type_info<tvm_ext::IntVector> {
static const int code = 17; static const int code = 17;
}; };
template<>
struct array_type_info<tvm_ext::NDSubClass> {
static const int code = 1;
};
} // namespace tvm } // namespace tvm
} // namespace runtime } // namespace runtime
...@@ -26,6 +33,62 @@ using namespace tvm; ...@@ -26,6 +33,62 @@ using namespace tvm;
using namespace tvm::runtime; using namespace tvm::runtime;
namespace tvm_ext { namespace tvm_ext {
/*!
* \brief A subclass of TVM's NDArray.
*
* To use this extension, an external library should
*
* 1) Inherit TVM's NDArray and NDArray container,
* and define the trait `array_type_info` for this class.
*
* 2) Define a constructor in the inherited class that accepts
* a pointer to TVM's Container, which is nullable.
*
* 3) On Python frontend, inherit `tvm.nd.NDArrayBase`,
* define the class attribute `_array_type_code` consistent to
* the C++ type trait, and register the subclass using `tvm.register_extension`.
*/
class NDSubClass : public tvm::runtime::NDArray {
public:
class SubContainer : public NDArray::Container {
public:
SubContainer(int addtional_info) :
addtional_info_(addtional_info) {
array_type_code_ = array_type_info<NDSubClass>::code;
}
static bool Is(NDArray::Container *container) {
SubContainer *c = static_cast<SubContainer*>(container);
return c->array_type_code_ == array_type_info<NDSubClass>::code;
}
int addtional_info_{0};
};
NDSubClass(NDArray::Container *container) {
if (container == nullptr) {
data_ = nullptr;
return;
}
CHECK(SubContainer::Is(container));
container->IncRef();
data_ = container;
}
~NDSubClass() {
this->reset();
}
NDSubClass AddWith(const NDSubClass &other) const {
SubContainer *a = static_cast<SubContainer*>(data_);
SubContainer *b = static_cast<SubContainer*>(other.data_);
CHECK(a != nullptr && b != nullptr);
return NDSubClass(new SubContainer(a->addtional_info_ + b->addtional_info_));
}
int get_additional_info() const {
SubContainer *self = static_cast<SubContainer*>(data_);
CHECK(self != nullptr);
return self->addtional_info_;
}
};
} // namespace tvm_ext
namespace tvm_ext {
TVM_REGISTER_EXT_TYPE(IntVector); TVM_REGISTER_EXT_TYPE(IntVector);
...@@ -64,6 +127,26 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev") ...@@ -64,6 +127,26 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = (*tvm::runtime::Registry::Get("device_api.cpu"))(); *rv = (*tvm::runtime::Registry::Get("device_api.cpu"))();
}); });
TVM_REGISTER_GLOBAL("tvm_ext.nd_create")
.set_body([](TVMArgs args, TVMRetValue *rv) {
int addtional_info = args[0];
*rv = NDSubClass(new NDSubClass::SubContainer(addtional_info));
});
TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two")
.set_body([](TVMArgs args, TVMRetValue *rv) {
NDSubClass a = args[0];
NDSubClass b = args[1];
*rv = a.AddWith(b);
});
TVM_REGISTER_GLOBAL("tvm_ext.nd_get_addtional_info")
.set_body([](TVMArgs args, TVMRetValue *rv) {
NDSubClass a = args[0];
*rv = a.get_additional_info();
});
} // namespace tvm_ext } // namespace tvm_ext
// External function exposed to runtime. // External function exposed to runtime.
......
...@@ -32,6 +32,7 @@ def test_sym_add(): ...@@ -32,6 +32,7 @@ def test_sym_add():
c = tvm_ext.sym_add(a, b) c = tvm_ext.sym_add(a, b)
assert c.a == a and c.b == b assert c.a == a and c.b == b
def test_ext_vec(): def test_ext_vec():
ivec = tvm_ext.ivec_create(1, 2, 3) ivec = tvm_ext.ivec_create(1, 2, 3)
assert(isinstance(ivec, tvm_ext.IntVec)) assert(isinstance(ivec, tvm_ext.IntVec))
...@@ -44,6 +45,7 @@ def test_ext_vec(): ...@@ -44,6 +45,7 @@ def test_ext_vec():
tvm.convert(ivec_cb)(ivec) tvm.convert(ivec_cb)(ivec)
def test_extract_ext(): def test_extract_ext():
fdict = tvm.extract_ext_funcs(tvm_ext._LIB.TVMExtDeclare) fdict = tvm.extract_ext_funcs(tvm_ext._LIB.TVMExtDeclare)
assert fdict["mul"](3, 4) == 12 assert fdict["mul"](3, 4) == 12
...@@ -68,7 +70,21 @@ def test_extern_call(): ...@@ -68,7 +70,21 @@ def test_extern_call():
check_llvm() check_llvm()
def test_nd_subclass():
a = tvm_ext.NDSubClass.create(addtional_info=3)
b = tvm_ext.NDSubClass.create(addtional_info=5)
c = a + b
d = a + a
e = b + b
assert(a.addtional_info == 3)
assert(b.addtional_info == 5)
assert(c.addtional_info == 8)
assert(d.addtional_info == 6)
assert(e.addtional_info == 10)
if __name__ == "__main__": if __name__ == "__main__":
test_nd_subclass()
test_extern_call() test_extern_call()
test_ext_dev() test_ext_dev()
test_ext_vec() test_ext_vec()
......
...@@ -178,11 +178,31 @@ class NDArray { ...@@ -178,11 +178,31 @@ class NDArray {
Container* data_{nullptr}; Container* data_{nullptr};
// enable internal functions // enable internal functions
friend struct Internal; friend struct Internal;
friend class TVMPODValue_;
friend class TVMArgValue;
friend class TVMRetValue; friend class TVMRetValue;
friend class TVMArgsSetter; friend class TVMArgsSetter;
}; };
/*! /*!
* \brief The type trait indicates subclass of TVM's NDArray.
* For irrelavant classes, code = -1.
* For TVM NDArray itself, code = 0.
* All subclasses of NDArray should override code > 0.
*/
template<typename T>
struct array_type_info {
/*! \brief the value of the traits */
static const int code = -1;
};
// Overrides the type trait for tvm's NDArray.
template<>
struct array_type_info<NDArray> {
static const int code = 0;
};
/*!
* \brief Save a DLTensor to stream * \brief Save a DLTensor to stream
* \param strm The outpu stream * \param strm The outpu stream
* \param tensor The tensor to be saved. * \param tensor The tensor to be saved.
...@@ -196,7 +216,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor); ...@@ -196,7 +216,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor);
* the pointer to the NDArrayContainer can be directly * the pointer to the NDArrayContainer can be directly
* interpreted as a DLTensor* * interpreted as a DLTensor*
* *
* \note: do not use this function directly, use NDArray. * \note do not use this function directly, use NDArray.
*/ */
class NDArray::Container { class NDArray::Container {
public: public:
...@@ -228,16 +248,19 @@ class NDArray::Container { ...@@ -228,16 +248,19 @@ class NDArray::Container {
protected: protected:
friend class NDArray; friend class NDArray;
friend class TVMPODValue_;
friend class TVMArgValue;
friend class TVMRetValue;
friend class RPCWrappedFunc; friend class RPCWrappedFunc;
/*! /*!
* \brief Type flag used to indicate subclass. * \brief Type flag used to indicate subclass.
* Default value 0 means normal NDArray::Conatainer. * Default value 0 means normal NDArray::Conatainer.
* *
* We can extend a more specialized NDArray::Container * We can extend a more specialized NDArray::Container
* and use the array_type_index_ to indicate * and use the array_type_code_ to indicate
* the specific array subclass. * the specific array subclass.
*/ */
uint32_t array_type_index_{0}; int32_t array_type_code_{0};
/*! \brief The internal reference counter */ /*! \brief The internal reference counter */
std::atomic<int> ref_counter_{0}; std::atomic<int> ref_counter_{0};
/*! /*!
......
...@@ -362,7 +362,7 @@ inline std::string TVMType2String(TVMType t); ...@@ -362,7 +362,7 @@ inline std::string TVMType2String(TVMType t);
* \tparam T the typename * \tparam T the typename
*/ */
template<typename T> template<typename T>
struct extension_class_info { struct extension_type_info {
static const int code = 0; static const int code = 0;
}; };
...@@ -455,6 +455,15 @@ class TVMPODValue_ { ...@@ -455,6 +455,15 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
return value_.v_ctx; return value_.v_ctx;
} }
template<typename TNDArray,
typename = typename std::enable_if<
std::is_base_of<NDArray, TNDArray>::value>::type>
TNDArray AsNDArray() const {
if (type_code_ == kNull) return TNDArray(nullptr);
auto *container = static_cast<NDArray::Container*>(value_.v_handle);
CHECK_EQ(container->array_type_code_, array_type_info<TNDArray>::code);
return TNDArray(container);
}
template<typename TExtension> template<typename TExtension>
const TExtension& AsExtension() const { const TExtension& AsExtension() const {
CHECK_LT(type_code_, kExtEnd); CHECK_LT(type_code_, kExtEnd);
...@@ -561,7 +570,7 @@ class TVMArgValue : public TVMPODValue_ { ...@@ -561,7 +570,7 @@ class TVMArgValue : public TVMPODValue_ {
inline TNodeRef AsNodeRef() const; inline TNodeRef AsNodeRef() const;
template<typename T, template<typename T,
typename = typename std::enable_if< typename = typename std::enable_if<
std::is_class<T>::value>::type> std::is_class<T>::value>::type>
inline operator T() const; inline operator T() const;
template<typename TNodeRef, template<typename TNodeRef,
typename = typename std::enable_if< typename = typename std::enable_if<
...@@ -727,10 +736,10 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -727,10 +736,10 @@ class TVMRetValue : public TVMPODValue_ {
} }
template<typename T, template<typename T,
typename = typename std::enable_if< typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type> extension_type_info<T>::code != 0>::type>
TVMRetValue& operator=(const T& other) { TVMRetValue& operator=(const T& other) {
this->SwitchToClass<T>( this->SwitchToClass<T>(
extension_class_info<T>::code, other); extension_type_info<T>::code, other);
return *this; return *this;
} }
/*! /*!
...@@ -1094,7 +1103,7 @@ class TVMArgsSetter { ...@@ -1094,7 +1103,7 @@ class TVMArgsSetter {
// extension // extension
template<typename T, template<typename T,
typename = typename std::enable_if< typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type> extension_type_info<T>::code != 0>::type>
inline void operator()(size_t i, const T& value) const; inline void operator()(size_t i, const T& value) const;
// NodeRef related extenstions: in tvm/packed_func_ext.h // NodeRef related extenstions: in tvm/packed_func_ext.h
inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*) inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*)
...@@ -1212,40 +1221,53 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const { ...@@ -1212,40 +1221,53 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
// extension and node type handling // extension and node type handling
namespace detail { namespace detail {
template<typename T, typename TSrc, bool is_ext> template<typename T, typename TSrc, bool is_ext, bool is_nd>
struct TVMValueCast { struct TVMValueCast {
static T Apply(const TSrc* self) { static T Apply(const TSrc* self) {
static_assert(!is_ext && !is_nd, "The default case accepts only non-extensions");
return self->template AsNodeRef<T>(); return self->template AsNodeRef<T>();
} }
}; };
template<typename T, typename TSrc> template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, true> { struct TVMValueCast<T, TSrc, true, false> {
static T Apply(const TSrc* self) { static T Apply(const TSrc* self) {
return self->template AsExtension<T>(); return self->template AsExtension<T>();
} }
}; };
template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, false, true> {
static T Apply(const TSrc* self) {
return self->template AsNDArray<T>();
}
};
} // namespace detail } // namespace detail
template<typename T, typename> template<typename T, typename>
inline TVMArgValue::operator T() const { inline TVMArgValue::operator T() const {
return detail:: return detail::
TVMValueCast<T, TVMArgValue, extension_class_info<T>::code != 0> TVMValueCast<T, TVMArgValue,
(extension_type_info<T>::code != 0),
(array_type_info<T>::code > 0)>
::Apply(this); ::Apply(this);
} }
template<typename T, typename> template<typename T, typename>
inline TVMRetValue::operator T() const { inline TVMRetValue::operator T() const {
return detail:: return detail::
TVMValueCast<T, TVMRetValue, extension_class_info<T>::code != 0> TVMValueCast<T, TVMRetValue,
(extension_type_info<T>::code != 0),
(array_type_info<T>::code > 0)>
::Apply(this); ::Apply(this);
} }
template<typename T, typename> template<typename T, typename>
inline void TVMArgsSetter::operator()(size_t i, const T& value) const { inline void TVMArgsSetter::operator()(size_t i, const T& value) const {
static_assert(extension_class_info<T>::code != 0, static_assert(extension_type_info<T>::code != 0,
"Need to have extesion code"); "Need to have extesion code");
type_codes_[i] = extension_class_info<T>::code; type_codes_[i] = extension_type_info<T>::code;
values_[i].v_handle = const_cast<T*>(&value); values_[i].v_handle = const_cast<T*>(&value);
} }
...@@ -1262,9 +1284,9 @@ struct ExtTypeInfo { ...@@ -1262,9 +1284,9 @@ struct ExtTypeInfo {
template<typename T> template<typename T>
inline ExtTypeVTable* ExtTypeVTable::Register_() { inline ExtTypeVTable* ExtTypeVTable::Register_() {
const int code = extension_class_info<T>::code; const int code = extension_type_info<T>::code;
static_assert(code != 0, static_assert(code != 0,
"require extension_class_info traits to be declared with non-zero code"); "require extension_type_info traits to be declared with non-zero code");
ExtTypeVTable vt; ExtTypeVTable vt;
vt.clone = ExtTypeInfo<T>::clone; vt.clone = ExtTypeInfo<T>::clone;
vt.destroy = ExtTypeInfo<T>::destroy; vt.destroy = ExtTypeInfo<T>::destroy;
......
...@@ -133,7 +133,7 @@ class Registry { ...@@ -133,7 +133,7 @@ class Registry {
/*! /*!
* \brief Macro to register extension type. * \brief Macro to register extension type.
* This must be registered in a cc file * This must be registered in a cc file
* after the trait extension_class_info is defined. * after the trait extension_type_info is defined.
*/ */
#define TVM_REGISTER_EXT_TYPE(T) \ #define TVM_REGISTER_EXT_TYPE(T) \
TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) = \ TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) = \
......
...@@ -40,17 +40,17 @@ namespace tvm { ...@@ -40,17 +40,17 @@ namespace tvm {
namespace runtime { namespace runtime {
template<> template<>
struct extension_class_info<nnvm::Symbol> { struct extension_type_info<nnvm::Symbol> {
static const int code = 16; static const int code = 16;
}; };
template<> template<>
struct extension_class_info<nnvm::Graph> { struct extension_type_info<nnvm::Graph> {
static const int code = 17; static const int code = 17;
}; };
template<> template<>
struct extension_class_info<nnvm::compiler::AttrDict> { struct extension_type_info<nnvm::compiler::AttrDict> {
static const int code = 18; static const int code = 18;
}; };
......
...@@ -76,8 +76,8 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._register_alter_op_layout") ...@@ -76,8 +76,8 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._register_alter_op_layout")
if (ret.type_code() == TVMTypeCode::kNull) { if (ret.type_code() == TVMTypeCode::kNull) {
return false; return false;
} }
CHECK_EQ(ret.type_code(), tvm::runtime::extension_class_info<Symbol>::code) CHECK_EQ(ret.type_code(), tvm::runtime::extension_type_info<Symbol>::code)
<< " expected " << "Symbol (code = " << tvm::runtime::extension_class_info<Symbol>::code << " expected " << "Symbol (code = " << tvm::runtime::extension_type_info<Symbol>::code
<< ") but get code = " << ret.type_code(); << ") but get code = " << ret.type_code();
*ret_symbol = *(static_cast<Symbol*>(ret.value().v_handle)); *ret_symbol = *(static_cast<Symbol*>(ret.value().v_handle));
return true; return true;
......
...@@ -223,13 +223,13 @@ def _handle_return_func(x): ...@@ -223,13 +223,13 @@ def _handle_return_func(x):
_node.__init_by_constructor__ = __init_handle_by_constructor__ _node.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False) RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True)
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func( C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE) _handle_return_func, TypeCode.FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func( C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
_return_module, TypeCode.MODULE_HANDLE) _return_module, TypeCode.MODULE_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(x.v_handle, True) C_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(x.v_handle, True, False)
C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False) C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True)
_CLASS_MODULE = None _CLASS_MODULE = None
_CLASS_FUNCTION = None _CLASS_FUNCTION = None
......
...@@ -4,7 +4,7 @@ from __future__ import absolute_import ...@@ -4,7 +4,7 @@ from __future__ import absolute_import
import ctypes import ctypes
from ..base import _LIB, check_call, c_str from ..base import _LIB, check_call, c_str
from ..runtime_ctypes import TVMArrayHandle from ..runtime_ctypes import TVMArrayHandle, TVMNDArrayContainerHandle
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle
...@@ -28,7 +28,7 @@ def _from_dlpack(dltensor): ...@@ -28,7 +28,7 @@ def _from_dlpack(dltensor):
check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle))) check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle)))
ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor) ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0)) ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0))
return _make_array(handle, False) return _make_array(handle, False, False)
raise ValueError("Expect a dltensor field, PyCapsule can only be consumed once") raise ValueError("Expect a dltensor field, PyCapsule can only be consumed once")
...@@ -77,9 +77,15 @@ class NDArrayBase(object): ...@@ -77,9 +77,15 @@ class NDArrayBase(object):
return ctypes.pythonapi.PyCapsule_New(handle, _c_str_dltensor, _c_dlpack_deleter) return ctypes.pythonapi.PyCapsule_New(handle, _c_str_dltensor, _c_dlpack_deleter)
def _make_array(handle, is_view): def _make_array(handle, is_view, is_container):
global _TVM_ND_CLS
handle = ctypes.cast(handle, TVMArrayHandle) handle = ctypes.cast(handle, TVMArrayHandle)
return _CLASS_NDARRAY(handle, is_view) fcreate = _CLASS_NDARRAY
if is_container and _TVM_ND_CLS:
array_type_info = ctypes.cast(handle, TVMNDArrayContainerHandle).array_type_info.value
if array_type_info > 0:
fcreate = _TVM_ND_CLS[array_type_info]
return fcreate(handle, is_view)
_TVM_COMPATS = () _TVM_COMPATS = ()
...@@ -91,6 +97,11 @@ def _reg_extension(cls, fcreate): ...@@ -91,6 +97,11 @@ def _reg_extension(cls, fcreate):
RETURN_SWITCH[cls._tvm_tcode] = fret RETURN_SWITCH[cls._tvm_tcode] = fret
C_TO_PY_ARG_SWITCH[cls._tvm_tcode] = _wrap_arg_func(fret, cls._tvm_tcode) C_TO_PY_ARG_SWITCH[cls._tvm_tcode] = _wrap_arg_func(fret, cls._tvm_tcode)
_TVM_ND_CLS = {}
def _reg_ndarray(cls, fcreate):
global _TVM_ND_CLS
_TVM_ND_CLS[cls._array_type_code] = fcreate
_CLASS_NDARRAY = None _CLASS_NDARRAY = None
......
...@@ -2,7 +2,7 @@ from ..base import TVMError ...@@ -2,7 +2,7 @@ from ..base import TVMError
from libcpp.vector cimport vector from libcpp.vector cimport vector
from cpython.version cimport PY_MAJOR_VERSION from cpython.version cimport PY_MAJOR_VERSION
from cpython cimport pycapsule from cpython cimport pycapsule
from libc.stdint cimport int64_t, uint64_t, uint8_t, uint16_t from libc.stdint cimport int32_t, int64_t, uint64_t, uint8_t, uint16_t
import ctypes import ctypes
cdef enum TVMTypeCode: cdef enum TVMTypeCode:
...@@ -61,6 +61,14 @@ ctypedef void* TVMRetValueHandle ...@@ -61,6 +61,14 @@ ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle ctypedef void* TVMFunctionHandle
ctypedef void* NodeHandle ctypedef void* NodeHandle
ctypedef struct TVMNDArrayContainer:
DLTensor dl_tensor
void* manager_ctx
void (*deleter)(DLManagedTensor* self)
int32_t array_type_info
ctypedef TVMNDArrayContainer* TVMNDArrayContainerHandle
ctypedef int (*TVMPackedCFunc)( ctypedef int (*TVMPackedCFunc)(
TVMValue* args, TVMValue* args,
int* type_codes, int* type_codes,
......
...@@ -33,7 +33,7 @@ cdef int tvm_callback(TVMValue* args, ...@@ -33,7 +33,7 @@ cdef int tvm_callback(TVMValue* args,
if tcode != kArrayHandle: if tcode != kArrayHandle:
pyargs.append(make_ret(value, tcode)) pyargs.append(make_ret(value, tcode))
else: else:
pyargs.append(c_make_array(value.v_handle, True)) pyargs.append(c_make_array(value.v_handle, True, False))
try: try:
rv = local_pyfunc(*pyargs) rv = local_pyfunc(*pyargs)
except Exception: except Exception:
...@@ -175,7 +175,7 @@ cdef inline object make_ret(TVMValue value, int tcode): ...@@ -175,7 +175,7 @@ cdef inline object make_ret(TVMValue value, int tcode):
elif tcode == kFloat: elif tcode == kFloat:
return value.v_float64 return value.v_float64
elif tcode == kNDArrayContainer: elif tcode == kNDArrayContainer:
return c_make_array(value.v_handle, False) return c_make_array(value.v_handle, False, True)
elif tcode == kStr: elif tcode == kStr:
return py_str(value.v_str) return py_str(value.v_str)
elif tcode == kBytes: elif tcode == kBytes:
......
...@@ -20,7 +20,7 @@ def _from_dlpack(object dltensor): ...@@ -20,7 +20,7 @@ def _from_dlpack(object dltensor):
# set name and destructor to be empty # set name and destructor to be empty
pycapsule.PyCapsule_SetDestructor(dltensor, NULL) pycapsule.PyCapsule_SetDestructor(dltensor, NULL)
pycapsule.PyCapsule_SetName(dltensor, _c_str_used_dltensor) pycapsule.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
return c_make_array(chandle, 0) return c_make_array(chandle, False, False)
raise ValueError("Expect a dltensor field, pycapsule.PyCapsule can only be consumed once") raise ValueError("Expect a dltensor field, pycapsule.PyCapsule can only be consumed once")
...@@ -73,8 +73,15 @@ cdef class NDArrayBase: ...@@ -73,8 +73,15 @@ cdef class NDArrayBase:
return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter) return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter)
cdef c_make_array(void* chandle, is_view): cdef c_make_array(void* chandle, is_view, is_container):
ret = _CLASS_NDARRAY(None, is_view) global _TVM_ND_CLS
cdef int32_t array_type_info
fcreate = _CLASS_NDARRAY
if is_container and len(_TVM_ND_CLS) > 0:
array_type_info = (<TVMNDArrayContainerHandle>chandle).array_type_info
if array_type_info > 0:
fcreate = _TVM_ND_CLS[array_type_info]
ret = fcreate(None, is_view)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle (<NDArrayBase>ret).chandle = <DLTensor*>chandle
return ret return ret
...@@ -89,11 +96,16 @@ def _reg_extension(cls, fcreate): ...@@ -89,11 +96,16 @@ def _reg_extension(cls, fcreate):
if fcreate: if fcreate:
_TVM_EXT_RET[cls._tvm_tcode] = fcreate _TVM_EXT_RET[cls._tvm_tcode] = fcreate
cdef _TVM_ND_CLS = {}
def _make_array(handle, is_view): def _reg_ndarray(cls, fcreate):
global _TVM_ND_CLS
_TVM_ND_CLS[cls._array_type_code] = fcreate
def _make_array(handle, is_view, is_container):
cdef unsigned long long ptr cdef unsigned long long ptr
ptr = ctypes.cast(handle, ctypes.c_void_p).value ptr = ctypes.cast(handle, ctypes.c_void_p).value
return c_make_array(<void*>ptr, is_view) return c_make_array(<void*>ptr, is_view, is_container)
cdef object _CLASS_NDARRAY = None cdef object _CLASS_NDARRAY = None
......
...@@ -17,15 +17,18 @@ try: ...@@ -17,15 +17,18 @@ try:
if _FFI_MODE == "ctypes": if _FFI_MODE == "ctypes":
raise ImportError() raise ImportError()
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack
from ._cy3.core import NDArrayBase as _NDArrayBase from ._cy3.core import NDArrayBase as _NDArrayBase
from ._cy3.core import _reg_extension, _reg_ndarray
else: else:
from ._cy2.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack from ._cy2.core import _set_class_ndarray, _make_array, _from_dlpack
from ._cy2.core import NDArrayBase as _NDArrayBase from ._cy2.core import NDArrayBase as _NDArrayBase
from ._cy2.core import _reg_extension, _reg_ndarray
except IMPORT_EXCEPT: except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from ._ctypes.ndarray import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack from ._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
from ._ctypes.ndarray import _reg_extension, _reg_ndarray
def context(dev_type, dev_id=0): def context(dev_type, dev_id=0):
...@@ -111,7 +114,7 @@ def empty(shape, dtype="float32", ctx=context(1, 0)): ...@@ -111,7 +114,7 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
ctx.device_type, ctx.device_type,
ctx.device_id, ctx.device_id,
ctypes.byref(handle))) ctypes.byref(handle)))
return _make_array(handle, False) return _make_array(handle, False, False)
def from_dlpack(dltensor): def from_dlpack(dltensor):
...@@ -295,6 +298,7 @@ def free_extension_handle(handle, type_code): ...@@ -295,6 +298,7 @@ def free_extension_handle(handle, type_code):
""" """
check_call(_LIB.TVMExtTypeFree(handle, ctypes.c_int(type_code))) check_call(_LIB.TVMExtTypeFree(handle, ctypes.c_int(type_code)))
def register_extension(cls, fcreate=None): def register_extension(cls, fcreate=None):
"""Register a extension class to TVM. """Register a extension class to TVM.
...@@ -306,21 +310,26 @@ def register_extension(cls, fcreate=None): ...@@ -306,21 +310,26 @@ def register_extension(cls, fcreate=None):
cls : class cls : class
The class object to be registered as extension. The class object to be registered as extension.
fcreate : function, optional
The creation function to create a class object given handle value.
Note Note
---- ----
The registered class is requires one property: _tvm_handle and a class attribute _tvm_tcode. The registered class is requires one property: _tvm_handle.
If the registered class is a subclass of NDArray,
it is required to have a class attribute _array_type_code.
Otherwise, it is required to have a class attribute _tvm_tcode.
- ```_tvm_handle``` returns integer represents the address of the handle. - ```_tvm_handle``` returns integer represents the address of the handle.
- ```_tvm_tcode``` gives integer represents type code of the class. - ```_tvm_tcode``` or ```_array_type_code``` gives integer represents type
code of the class.
Returns Returns
------- -------
cls : class cls : class
The class being registered. The class being registered.
fcreate : function, optional
The creation function to create a class object given handle value.
Example Example
------- -------
The following code registers user defined class The following code registers user defined class
...@@ -339,7 +348,13 @@ def register_extension(cls, fcreate=None): ...@@ -339,7 +348,13 @@ def register_extension(cls, fcreate=None):
def _tvm_handle(self): def _tvm_handle(self):
return self.handle.value return self.handle.value
""" """
if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN: if issubclass(cls, _NDArrayBase):
raise ValueError("Cannot register create when extension tcode is same as buildin") assert fcreate is not None
_reg_extension(cls, fcreate) assert hasattr(cls, "_array_type_code")
_reg_ndarray(cls, fcreate)
else:
assert hasattr(cls, "_tvm_tcode")
if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin")
_reg_extension(cls, fcreate)
return cls return cls
...@@ -240,3 +240,12 @@ class TVMArray(ctypes.Structure): ...@@ -240,3 +240,12 @@ class TVMArray(ctypes.Structure):
("byte_offset", ctypes.c_uint64)] ("byte_offset", ctypes.c_uint64)]
TVMArrayHandle = ctypes.POINTER(TVMArray) TVMArrayHandle = ctypes.POINTER(TVMArray)
class TVMNDArrayContainer(ctypes.Structure):
"""TVM NDArray::Container"""
_fields_ = [("dl_tensor", TVMArray),
("manager_ctx", ctypes.c_void_p),
("deleter", ctypes.c_void_p),
("array_type_info", ctypes.c_int32)]
TVMNDArrayContainerHandle = ctypes.POINTER(TVMNDArrayContainer)
...@@ -15,7 +15,7 @@ from ._ffi.ndarray import register_extension, free_extension_handle ...@@ -15,7 +15,7 @@ from ._ffi.ndarray import register_extension, free_extension_handle
class NDArray(NDArrayBase): class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime. """Lightweight NDArray class of TVM runtime.
Strictly this is only an Array Container(a buffer object) Strictly this is only an Array Container (a buffer object)
No arthimetic operations are defined. No arthimetic operations are defined.
All operations are performed by TVM functions. All operations are performed by TVM functions.
......
...@@ -168,7 +168,7 @@ namespace tvm { ...@@ -168,7 +168,7 @@ namespace tvm {
namespace runtime { namespace runtime {
template<> template<>
struct extension_class_info<test::IntVector> { struct extension_type_info<test::IntVector> {
static const int code = kExtBegin + 1; static const int code = kExtBegin + 1;
}; };
} // runtime } // runtime
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment