Unverified Commit 55bd786f by Tianqi Chen Committed by GitHub

[REFACTOR][RUNTIME] Update NDArray use the Unified Object System (#4581)

* [REFACTOR][RUNTIME] Move NDArray to Object System.

Previously, NDArray had its own object reference counting mechanism.
This PR migrates NDArray to the unified object protocol.

The calling convention of NDArray remained intact.
That means NDArray still has its own type_code and
its handle is still DLTensor compatible.

In order to do so, this PR adds a minimal amount of runtime type
detection in TVMArgValue and RetValue, performed only when the corresponding
type is a base type (ObjectRef) that could also refer to an NDArray.

This means that even if we return a base reference object (ObjectRef)
that refers to an NDArray, the type_code will still be translated
correctly as kNDArrayContainer.
If we assign a non-base type (say, Expr) that we know at compile time is not
compatible with NDArray, no runtime type detection will be performed.

This PR also adopts the object protocol for NDArray sub-classing and
removed the legacy NDArray subclass protocol.
Examples in apps/extension are now updated to reflect that.

Making NDArray an Object brings all the benefits of the object system.
For example, we can now use the Array container to store NDArrays.

* Address review comments
parent 4072396e
......@@ -51,26 +51,23 @@ class IntVec(tvm.Object):
nd_create = tvm.get_global_func("tvm_ext.nd_create")
nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two")
nd_get_addtional_info = tvm.get_global_func("tvm_ext.nd_get_addtional_info")
nd_get_additional_info = tvm.get_global_func("tvm_ext.nd_get_additional_info")
@tvm.register_object("tvm_ext.NDSubClass")
class NDSubClass(tvm.nd.NDArrayBase):
"""Example for subclassing TVM's NDArray infrastructure.
    By inheriting TVM's NDArray, external libraries could
    leverage TVM's FFI without any modification.
"""
# Should be consistent with the type-trait set in the backend
_array_type_code = 1
@staticmethod
def create(addtional_info):
return nd_create(addtional_info)
def create(additional_info):
return nd_create(additional_info)
@property
def addtional_info(self):
return nd_get_addtional_info(self)
def additional_info(self):
return nd_get_additional_info(self)
def __add__(self, other):
return nd_add_two(self, other)
tvm.register_extension(NDSubClass, NDSubClass)
......@@ -29,19 +29,6 @@
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h>
namespace tvm_ext {
class NDSubClass;
} // namespace tvm_ext
namespace tvm {
namespace runtime {
template<>
struct array_type_info<tvm_ext::NDSubClass> {
static const int code = 1;
};
} // namespace tvm
} // namespace runtime
using namespace tvm;
using namespace tvm::runtime;
......@@ -52,54 +39,55 @@ namespace tvm_ext {
* To use this extension, an external library should
*
* 1) Inherit TVM's NDArray and NDArray container,
* and define the trait `array_type_info` for this class.
*
* 2) Define a constructor in the inherited class that accepts
* a pointer to TVM's Container, which is nullable.
* 2) Follow the new object protocol to define new NDArray as a reference class.
*
* 3) On Python frontend, inherit `tvm.nd.NDArrayBase`,
* define the class attribute `_array_type_code` consistent to
* the C++ type trait, and register the subclass using `tvm.register_extension`.
* 3) On Python frontend, inherit `tvm.nd.NDArray`,
* register the type using tvm.register_object
*/
class NDSubClass : public tvm::runtime::NDArray {
public:
class SubContainer : public NDArray::Container {
public:
SubContainer(int addtional_info) :
addtional_info_(addtional_info) {
array_type_code_ = array_type_info<NDSubClass>::code;
}
static bool Is(NDArray::Container *container) {
SubContainer *c = static_cast<SubContainer*>(container);
return c->array_type_code_ == array_type_info<NDSubClass>::code;
SubContainer(int additional_info) :
additional_info_(additional_info) {
type_index_ = SubContainer::RuntimeTypeIndex();
}
int addtional_info_{0};
int additional_info_{0};
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "tvm_ext.NDSubClass";
TVM_DECLARE_FINAL_OBJECT_INFO(SubContainer, NDArray::Container);
};
NDSubClass(NDArray::Container *container) {
if (container == nullptr) {
data_ = nullptr;
return;
}
CHECK(SubContainer::Is(container));
container->IncRef();
data_ = container;
static void SubContainerDeleter(Object* obj) {
auto* ptr = static_cast<SubContainer*>(obj);
delete ptr;
}
~NDSubClass() {
this->reset();
NDSubClass() {}
explicit NDSubClass(ObjectPtr<Object> n) : NDArray(n) {}
explicit NDSubClass(int additional_info) {
SubContainer* ptr = new SubContainer(additional_info);
ptr->SetDeleter(SubContainerDeleter);
data_ = GetObjectPtr<Object>(ptr);
}
NDSubClass AddWith(const NDSubClass &other) const {
SubContainer *a = static_cast<SubContainer*>(data_);
SubContainer *b = static_cast<SubContainer*>(other.data_);
SubContainer *a = static_cast<SubContainer*>(get_mutable());
SubContainer *b = static_cast<SubContainer*>(other.get_mutable());
CHECK(a != nullptr && b != nullptr);
return NDSubClass(new SubContainer(a->addtional_info_ + b->addtional_info_));
return NDSubClass(a->additional_info_ + b->additional_info_);
}
int get_additional_info() const {
SubContainer *self = static_cast<SubContainer*>(data_);
SubContainer *self = static_cast<SubContainer*>(get_mutable());
CHECK(self != nullptr);
return self->addtional_info_;
return self->additional_info_;
}
using ContainerType = SubContainer;
};
TVM_REGISTER_OBJECT_TYPE(NDSubClass::SubContainer);
/*!
* \brief Introduce additional extension data structures
......@@ -166,8 +154,10 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev")
TVM_REGISTER_GLOBAL("tvm_ext.nd_create")
.set_body([](TVMArgs args, TVMRetValue *rv) {
int addtional_info = args[0];
*rv = NDSubClass(new NDSubClass::SubContainer(addtional_info));
int additional_info = args[0];
*rv = NDSubClass(additional_info);
CHECK_EQ(rv->type_code(), kNDArrayContainer);
});
TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two")
......@@ -177,7 +167,7 @@ TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two")
*rv = a.AddWith(b);
});
TVM_REGISTER_GLOBAL("tvm_ext.nd_get_addtional_info")
TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info")
.set_body([](TVMArgs args, TVMRetValue *rv) {
NDSubClass a = args[0];
*rv = a.get_additional_info();
......
......@@ -87,16 +87,17 @@ def test_extern_call():
def test_nd_subclass():
a = tvm_ext.NDSubClass.create(addtional_info=3)
b = tvm_ext.NDSubClass.create(addtional_info=5)
a = tvm_ext.NDSubClass.create(additional_info=3)
b = tvm_ext.NDSubClass.create(additional_info=5)
assert isinstance(a, tvm_ext.NDSubClass)
c = a + b
d = a + a
e = b + b
assert(a.addtional_info == 3)
assert(b.addtional_info == 5)
assert(c.addtional_info == 8)
assert(d.addtional_info == 6)
assert(e.addtional_info == 10)
assert(a.additional_info == 3)
assert(b.additional_info == 5)
assert(c.additional_info == 8)
assert(d.additional_info == 6)
assert(e.additional_info == 10)
if __name__ == "__main__":
......
......@@ -23,14 +23,14 @@
#ifndef TVM_NODE_CONTAINER_H_
#define TVM_NODE_CONTAINER_H_
#include <tvm/node/node.h>
#include <type_traits>
#include <vector>
#include <initializer_list>
#include <unordered_map>
#include <utility>
#include <string>
#include "node.h"
#include "memory.h"
namespace tvm {
......
......@@ -25,7 +25,6 @@
#ifndef TVM_PACKED_FUNC_EXT_H_
#define TVM_PACKED_FUNC_EXT_H_
#include <sstream>
#include <string>
#include <memory>
#include <limits>
......@@ -43,22 +42,7 @@ using runtime::TVMRetValue;
using runtime::PackedFunc;
namespace runtime {
/*!
* \brief Runtime type checker for node type.
* \tparam T the type to be checked.
*/
template<typename T>
struct ObjectTypeChecker {
static bool Check(const Object* ptr) {
using ContainerType = typename T::ContainerType;
if (ptr == nullptr) return true;
return ptr->IsInstance<ContainerType>();
}
static void PrintName(std::ostream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType;
os << ContainerType::_type_key;
}
};
template<typename T>
struct ObjectTypeChecker<Array<T> > {
......@@ -73,10 +57,8 @@ struct ObjectTypeChecker<Array<T> > {
}
return true;
}
static void PrintName(std::ostream& os) { // NOLINT(*)
os << "List[";
ObjectTypeChecker<T>::PrintName(os);
os << "]";
static std::string TypeName() {
return "List[" + ObjectTypeChecker<T>::TypeName() + "]";
}
};
......@@ -91,11 +73,9 @@ struct ObjectTypeChecker<Map<std::string, V> > {
}
return true;
}
static void PrintName(std::ostream& os) { // NOLINT(*)
os << "Map[str";
os << ',';
ObjectTypeChecker<V>::PrintName(os);
os << ']';
static std::string TypeName() {
return "Map[str, " +
ObjectTypeChecker<V>::TypeName()+ ']';
}
};
......@@ -111,39 +91,16 @@ struct ObjectTypeChecker<Map<K, V> > {
}
return true;
}
static void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "Map[";
ObjectTypeChecker<K>::PrintName(os);
os << ',';
ObjectTypeChecker<V>::PrintName(os);
os << ']';
static std::string TypeName() {
return "Map[" +
ObjectTypeChecker<K>::TypeName() +
", " +
ObjectTypeChecker<V>::TypeName()+ ']';
}
};
template<typename T>
inline std::string ObjectTypeName() {
std::ostringstream os;
ObjectTypeChecker<T>::PrintName(os);
return os.str();
}
// extensions for tvm arg value
template<typename TObjectRef>
inline TObjectRef TVMArgValue::AsObjectRef() const {
static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
if (type_code_ == kNull) return TObjectRef(NodePtr<Node>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expected type " << ObjectTypeName<TObjectRef>()
<< " but get " << ptr->GetTypeKey();
return TObjectRef(ObjectPtr<Node>(ptr));
}
inline TVMArgValue::operator tvm::Expr() const {
inline TVMPODValue_::operator tvm::Expr() const {
if (type_code_ == kNull) return Expr();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
......@@ -164,12 +121,12 @@ inline TVMArgValue::operator tvm::Expr() const {
return Tensor(ObjectPtr<Node>(ptr))();
}
CHECK(ObjectTypeChecker<Expr>::Check(ptr))
<< "Expected type " << ObjectTypeName<Expr>()
<< "Expect type " << ObjectTypeChecker<Expr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return Expr(ObjectPtr<Node>(ptr));
}
inline TVMArgValue::operator tvm::Integer() const {
inline TVMPODValue_::operator tvm::Integer() const {
if (type_code_ == kNull) return Integer();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
......@@ -179,35 +136,10 @@ inline TVMArgValue::operator tvm::Integer() const {
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<Integer>::Check(ptr))
<< "Expected type " << ObjectTypeName<Expr>()
<< "Expect type " << ObjectTypeChecker<Expr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return Integer(ObjectPtr<Node>(ptr));
}
template<typename TObjectRef, typename>
inline bool TVMPODValue_::IsObjectRef() const {
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
return ObjectTypeChecker<TObjectRef>::Check(ptr);
}
// extensions for TVMRetValue
template<typename TObjectRef>
inline TObjectRef TVMRetValue::AsObjectRef() const {
static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
if (type_code_ == kNull) return TObjectRef();
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expected type " << ObjectTypeName<TObjectRef>()
<< " but get " << ptr->GetTypeKey();
return TObjectRef(ObjectPtr<Object>(ptr));
}
} // namespace runtime
} // namespace tvm
#endif // TVM_PACKED_FUNC_EXT_H_
......@@ -23,6 +23,7 @@
*/
#ifndef TVM_RUNTIME_CONTAINER_H_
#define TVM_RUNTIME_CONTAINER_H_
#include <dmlc/logging.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
......
......@@ -24,10 +24,11 @@
#define TVM_RUNTIME_OBJECT_H_
#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
#include <type_traits>
#include <string>
#include <utility>
#include "c_runtime_api.h"
/*!
* \brief Whether or not use atomic reference counter.
......@@ -581,6 +582,14 @@ class ObjectRef {
return T(std::move(ref.data_));
}
/*!
* \brief Clear the object ref data field without DecRef
* after we successfully moved the field.
* \param ref The reference data.
*/
static void FFIClearAfterMove(ObjectRef* ref) {
ref->data_.data_ = nullptr;
}
/*!
* \brief Internal helper function get data_ as ObjectPtr of ObjectType.
* \note only used for internal dev purpose.
* \tparam ObjectType The corresponding object type.
......@@ -648,7 +657,7 @@ struct ObjectEqual {
return _GetOrAllocRuntimeTypeIndex(); \
} \
static const uint32_t _GetOrAllocRuntimeTypeIndex() { \
static uint32_t tidx = GetOrAllocRuntimeTypeIndex( \
static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex( \
TypeName::_type_key, \
TypeName::_type_index, \
ParentType::_GetOrAllocRuntimeTypeIndex(), \
......@@ -668,6 +677,19 @@ struct ObjectEqual {
TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
/*! \brief helper macro to supress unused warning */
#if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define TVM_ATTRIBUTE_UNUSED
#endif
#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
#define TVM_OBJECT_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid
/*!
* \brief Helper macro to register the object type to runtime.
* Makes sure that the runtime type table is correctly populated.
......@@ -675,7 +697,7 @@ struct ObjectEqual {
* Use this macro in the cc file for each terminal class.
*/
#define TVM_REGISTER_OBJECT_TYPE(TypeName) \
static DMLC_ATTRIBUTE_UNUSED uint32_t __make_Object_tidx ## _ ## TypeName ## __ = \
TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = \
TypeName::_GetOrAllocRuntimeTypeIndex()
......@@ -691,14 +713,14 @@ struct ObjectEqual {
using ContainerType = ObjectName;
#define TVM_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \
TypeName() {} \
explicit TypeName( \
::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
: ParentType(n) {} \
ObjectName* operator->() { \
return static_cast<ObjectName*>(data_.get()); \
} \
operator bool() const { return data_ != nullptr; } \
TypeName() {} \
explicit TypeName( \
::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
: ParentType(n) {} \
ObjectName* operator->() { \
return static_cast<ObjectName*>(data_.get()); \
} \
operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName;
// Implementations details below
......
......@@ -43,9 +43,9 @@
#ifndef TVM_RUNTIME_REGISTRY_H_
#define TVM_RUNTIME_REGISTRY_H_
#include <tvm/runtime/packed_func.h>
#include <string>
#include <vector>
#include "packed_func.h"
namespace tvm {
namespace runtime {
......@@ -283,22 +283,9 @@ class Registry {
friend struct Manager;
};
/*! \brief helper macro to supress unused warning */
#if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define TVM_ATTRIBUTE_UNUSED
#endif
#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
#define TVM_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM
#define TVM_TYPE_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::ExtTypeVTable* __mk_ ## TVMT
/*!
* \brief Register a function globally.
* \code
......
......@@ -96,6 +96,7 @@ def config_cython():
"../3rdparty/dmlc-core/include",
"../3rdparty/dlpack/include",
],
extra_compile_args=["-std=c++11"],
library_dirs=library_dirs,
libraries=libraries,
language="c++"))
......
......@@ -20,7 +20,7 @@ from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call, c_str
from ..runtime_ctypes import TVMArrayHandle, TVMNDArrayContainerHandle
from ..runtime_ctypes import TVMArrayHandle
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle
......@@ -110,12 +110,17 @@ class NDArrayBase(object):
def _make_array(handle, is_view, is_container):
global _TVM_ND_CLS
handle = ctypes.cast(handle, TVMArrayHandle)
fcreate = _CLASS_NDARRAY
if is_container and _TVM_ND_CLS:
array_type_info = ctypes.cast(handle, TVMNDArrayContainerHandle).array_type_info.value
if array_type_info > 0:
fcreate = _TVM_ND_CLS[array_type_info]
return fcreate(handle, is_view)
if is_container:
tindex = ctypes.c_uint()
check_call(_LIB.TVMArrayGetTypeIndex(handle, ctypes.byref(tindex)))
cls = _TVM_ND_CLS.get(tindex.value, _CLASS_NDARRAY)
else:
cls = _CLASS_NDARRAY
ret = cls.__new__(cls)
ret.handle = handle
ret.is_view = is_view
return ret
_TVM_COMPATS = ()
......@@ -129,9 +134,9 @@ def _reg_extension(cls, fcreate):
_TVM_ND_CLS = {}
def _reg_ndarray(cls, fcreate):
def _register_ndarray(index, cls):
global _TVM_ND_CLS
_TVM_ND_CLS[cls._array_type_code] = fcreate
_TVM_ND_CLS[index] = cls
_CLASS_NDARRAY = None
......
......@@ -21,7 +21,7 @@ from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call
from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from ..node_generic import _set_class_node_base
from .ndarray import _register_ndarray, NDArrayBase
ObjectHandle = ctypes.c_void_p
......@@ -39,6 +39,9 @@ def _set_class_node(node_class):
def _register_object(index, cls):
"""register object class"""
if issubclass(cls, NDArrayBase):
_register_ndarray(index, cls)
return
OBJECT_TYPE[index] = cls
......@@ -91,6 +94,3 @@ class ObjectBase(object):
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
self.handle = handle
_set_class_node_base(ObjectBase)
......@@ -19,7 +19,7 @@ from ..base import get_last_ffi_error
from libcpp.vector cimport vector
from cpython.version cimport PY_MAJOR_VERSION
from cpython cimport pycapsule
from libc.stdint cimport int32_t, int64_t, uint64_t, uint8_t, uint16_t
from libc.stdint cimport int32_t, int64_t, uint64_t, uint32_t, uint8_t, uint16_t
import ctypes
cdef enum TVMTypeCode:
......@@ -78,14 +78,11 @@ ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle
ctypedef void* ObjectHandle
ctypedef struct TVMObject:
uint32_t type_index_
int32_t ref_counter_
void (*deleter_)(TVMObject* self)
ctypedef struct TVMNDArrayContainer:
DLTensor dl_tensor
void* manager_ctx
void (*deleter)(DLManagedTensor* self)
int32_t array_type_info
ctypedef TVMNDArrayContainer* TVMNDArrayContainerHandle
ctypedef int (*TVMPackedCFunc)(
TVMValue* args,
......
......@@ -100,17 +100,34 @@ cdef class NDArrayBase:
return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter)
# Import limited object-related function from C++ side to improve the speed
# NOTE: can only use POD-C compatible object in FFI.
cdef extern from "tvm/runtime/ndarray.h" namespace "tvm::runtime":
cdef void* TVMArrayHandleToObjectHandle(DLTensorHandle handle)
cdef c_make_array(void* chandle, is_view, is_container):
global _TVM_ND_CLS
cdef int32_t array_type_info
fcreate = _CLASS_NDARRAY
if is_container and len(_TVM_ND_CLS) > 0:
array_type_info = (<TVMNDArrayContainerHandle>chandle).array_type_info
if array_type_info > 0:
fcreate = _TVM_ND_CLS[array_type_info]
ret = fcreate(None, is_view)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
return ret
if is_container:
tindex = (
<TVMObject*>TVMArrayHandleToObjectHandle(<DLTensorHandle>chandle)).type_index_
if tindex < len(_TVM_ND_CLS):
cls = _TVM_ND_CLS[tindex]
if cls is not None:
ret = cls.__new__(cls)
else:
ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY)
else:
ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
(<NDArrayBase>ret).c_is_view = <int>is_view
return ret
else:
ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
(<NDArrayBase>ret).c_is_view = <int>is_view
return ret
cdef _TVM_COMPATS = ()
......@@ -123,11 +140,16 @@ def _reg_extension(cls, fcreate):
if fcreate:
_TVM_EXT_RET[cls._tvm_tcode] = fcreate
cdef _TVM_ND_CLS = {}
cdef list _TVM_ND_CLS = []
def _reg_ndarray(cls, fcreate):
cdef _register_ndarray(int index, object cls):
"""register object class"""
global _TVM_ND_CLS
_TVM_ND_CLS[cls._array_type_code] = fcreate
while len(_TVM_ND_CLS) <= index:
_TVM_ND_CLS.append(None)
_TVM_ND_CLS[index] = cls
def _make_array(handle, is_view, is_container):
cdef unsigned long long ptr
......
......@@ -16,12 +16,15 @@
# under the License.
"""Maps object type to its constructor"""
from ..node_generic import _set_class_node_base
OBJECT_TYPE = []
cdef list OBJECT_TYPE = []
def _register_object(int index, object cls):
"""register object class"""
if issubclass(cls, NDArrayBase):
_register_ndarray(index, cls)
return
global OBJECT_TYPE
while len(OBJECT_TYPE) <= index:
OBJECT_TYPE.append(None)
OBJECT_TYPE[index] = cls
......@@ -31,14 +34,13 @@ cdef inline object make_ret_object(void* chandle):
global OBJECT_TYPE
global _CLASS_NODE
cdef unsigned tindex
cdef list object_type
cdef object cls
cdef object handle
object_type = OBJECT_TYPE
handle = ctypes_handle(chandle)
CALL(TVMObjectGetTypeIndex(chandle, &tindex))
if tindex < len(object_type):
cls = object_type[tindex]
if tindex < len(OBJECT_TYPE):
cls = OBJECT_TYPE[tindex]
if cls is not None:
obj = cls.__new__(cls)
else:
......@@ -99,6 +101,3 @@ cdef class ObjectBase:
(<FunctionBase>fconstructor).chandle,
kObjectHandle, args, &chandle)
self.chandle = chandle
_set_class_node_base(ObjectBase)
......@@ -22,6 +22,7 @@ from __future__ import absolute_import
import sys
import ctypes
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
from .node_generic import _set_class_objects
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
......@@ -32,15 +33,21 @@ try:
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_function, _set_class_module
from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import NDArrayBase as _NDArrayBase
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import convert_to_tvm_func
else:
from ._cy2.core import _set_class_function, _set_class_module
from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import NDArrayBase as _NDArrayBase
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import convert_to_tvm_func
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_function, _set_class_module
from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
from ._ctypes.object import ObjectBase as _ObjectBase
from ._ctypes.function import convert_to_tvm_func
FunctionHandle = ctypes.c_void_p
......@@ -325,3 +332,4 @@ def _init_api_prefix(module_name, prefix):
setattr(target_module, ff.__name__, ff)
_set_class_function(Function)
_set_class_objects((_ObjectBase, _NDArrayBase, ModuleBase))
......@@ -35,16 +35,16 @@ try:
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack
from ._cy3.core import NDArrayBase as _NDArrayBase
from ._cy3.core import _reg_extension, _reg_ndarray
from ._cy3.core import _reg_extension
else:
from ._cy2.core import _set_class_ndarray, _make_array, _from_dlpack
from ._cy2.core import NDArrayBase as _NDArrayBase
from ._cy2.core import _reg_extension, _reg_ndarray
from ._cy2.core import _reg_extension
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
from ._ctypes.ndarray import _reg_extension, _reg_ndarray
from ._ctypes.ndarray import _reg_extension
def context(dev_type, dev_id=0):
......@@ -348,13 +348,8 @@ def register_extension(cls, fcreate=None):
def _tvm_handle(self):
return self.handle.value
"""
if issubclass(cls, _NDArrayBase):
assert fcreate is not None
assert hasattr(cls, "_array_type_code")
_reg_ndarray(cls, fcreate)
else:
assert hasattr(cls, "_tvm_tcode")
if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin")
_reg_extension(cls, fcreate)
assert hasattr(cls, "_tvm_tcode")
if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin")
_reg_extension(cls, fcreate)
return cls
......@@ -23,11 +23,11 @@ from .. import _api_internal
from .base import string_types
# Node base class
_CLASS_NODE_BASE = None
_CLASS_OBJECTS = None
def _set_class_node_base(cls):
global _CLASS_NODE_BASE
_CLASS_NODE_BASE = cls
def _set_class_objects(cls):
global _CLASS_OBJECTS
_CLASS_OBJECTS = cls
def _scalar_type_inference(value):
......@@ -67,7 +67,7 @@ def convert_to_node(value):
node : Node
The corresponding node value.
"""
if isinstance(value, _CLASS_NODE_BASE):
if isinstance(value, _CLASS_OBJECTS):
return value
if isinstance(value, bool):
return const(value, 'uint1x1')
......@@ -81,7 +81,7 @@ def convert_to_node(value):
if isinstance(value, dict):
vlist = []
for item in value.items():
if (not isinstance(item[0], _CLASS_NODE_BASE) and
if (not isinstance(item[0], _CLASS_OBJECTS) and
not isinstance(item[0], string_types)):
raise ValueError("key of map must already been a container type")
vlist.append(item[0])
......
......@@ -271,12 +271,3 @@ class TVMArray(ctypes.Structure):
("byte_offset", ctypes.c_uint64)]
TVMArrayHandle = ctypes.POINTER(TVMArray)
class TVMNDArrayContainer(ctypes.Structure):
"""TVM NDArray::Container"""
_fields_ = [("dl_tensor", TVMArray),
("manager_ctx", ctypes.c_void_p),
("deleter", ctypes.c_void_p),
("array_type_info", ctypes.c_int32)]
TVMNDArrayContainerHandle = ctypes.POINTER(TVMNDArrayContainer)
......@@ -27,7 +27,10 @@ from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty, from_dlpack
from ._ffi.ndarray import _set_class_ndarray
from ._ffi.ndarray import register_extension
from ._ffi.object import register_object
@register_object
class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime.
......
......@@ -67,7 +67,7 @@ TVM_REGISTER_API("_Array")
}
auto node = make_node<ArrayNode>();
node->data = std::move(data);
*ret = runtime::ObjectRef(node);
*ret = Array<ObjectRef>(node);
});
TVM_REGISTER_API("_ArrayGetItem")
......@@ -100,28 +100,28 @@ TVM_REGISTER_API("_Map")
for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kStr)
<< "key of str map need to be str";
CHECK(args[i + 1].type_code() == kObjectHandle)
CHECK(args[i + 1].IsObjectRef<ObjectRef>())
<< "value of the map to be NodeRef";
data.emplace(std::make_pair(args[i].operator std::string(),
args[i + 1].operator ObjectRef()));
}
auto node = make_node<StrMapNode>();
node->data = std::move(data);
*ret = node;
*ret = Map<ObjectRef, ObjectRef>(node);
} else {
// Container node.
MapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kObjectHandle)
<< "key of str map need to be str";
CHECK(args[i + 1].type_code() == kObjectHandle)
CHECK(args[i].IsObjectRef<ObjectRef>())
<< "key of str map need to be object";
CHECK(args[i + 1].IsObjectRef<ObjectRef>())
<< "value of map to be NodeRef";
data.emplace(std::make_pair(args[i].operator ObjectRef(),
args[i + 1].operator ObjectRef()));
}
auto node = make_node<MapNode>();
node->data = std::move(data);
*ret = node;
*ret = Map<ObjectRef, ObjectRef>(node);
}
});
......@@ -191,7 +191,7 @@ TVM_REGISTER_API("_MapItems")
rkvs->data.push_back(kv.first);
rkvs->data.push_back(kv.second);
}
*ret = rkvs;
*ret = Array<ObjectRef>(rkvs);
} else {
auto* n = static_cast<const StrMapNode*>(ptr);
auto rkvs = make_node<ArrayNode>();
......@@ -199,7 +199,7 @@ TVM_REGISTER_API("_MapItems")
rkvs->data.push_back(ir::StringImm::make(kv.first));
rkvs->data.push_back(kv.second);
}
*ret = rkvs;
*ret = Array<ObjectRef>(rkvs);
}
});
......
......@@ -27,8 +27,12 @@
#include <tvm/runtime/device_api.h>
#include "runtime_base.h"
// deleter for arrays used by DLPack exporter
extern "C" void NDArrayDLPackDeleter(DLManagedTensor* tensor);
extern "C" {
// C-mangled dlpack deleter.
static void TVMNDArrayDLPackDeleter(DLManagedTensor* tensor);
// helper function to get NDArray's type index, only used by ctypes.
TVM_DLL int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex);
}
namespace tvm {
namespace runtime {
......@@ -53,8 +57,8 @@ inline size_t GetDataAlignment(const DLTensor& arr) {
struct NDArray::Internal {
// Default deleter for the container
static void DefaultDeleter(NDArray::Container* ptr) {
using tvm::runtime::NDArray;
static void DefaultDeleter(Object* ptr_obj) {
auto* ptr = static_cast<NDArray::Container*>(ptr_obj);
if (ptr->manager_ctx != nullptr) {
static_cast<NDArray::Container*>(ptr->manager_ctx)->DecRef();
} else if (ptr->dl_tensor.data != nullptr) {
......@@ -68,7 +72,8 @@ struct NDArray::Internal {
// that are not allocated inside of TVM.
// This enables us to create NDArray from memory allocated by other
// frameworks that are DLPack compatible
static void DLPackDeleter(NDArray::Container* ptr) {
static void DLPackDeleter(Object* ptr_obj) {
auto* ptr = static_cast<NDArray::Container*>(ptr_obj);
DLManagedTensor* tensor = static_cast<DLManagedTensor*>(ptr->manager_ctx);
if (tensor->deleter != nullptr) {
(*tensor->deleter)(tensor);
......@@ -81,12 +86,13 @@ struct NDArray::Internal {
DLDataType dtype,
DLContext ctx) {
VerifyDataType(dtype);
// critical zone
// critical zone: construct header
NDArray::Container* data = new NDArray::Container();
data->deleter = DefaultDeleter;
NDArray ret(data);
ret.data_ = data;
data->SetDeleter(DefaultDeleter);
// RAII now in effect
NDArray ret(GetObjectPtr<Object>(data));
// setup shape
data->shape_ = std::move(shape);
data->dl_tensor.shape = dmlc::BeginPtr(data->shape_);
......@@ -98,45 +104,57 @@ struct NDArray::Internal {
return ret;
}
// Implementation of API function
static DLTensor* MoveAsDLTensor(NDArray arr) {
DLTensor* tensor = const_cast<DLTensor*>(arr.operator->());
CHECK(reinterpret_cast<DLTensor*>(arr.data_) == tensor);
arr.data_ = nullptr;
return tensor;
static DLTensor* MoveToFFIHandle(NDArray arr) {
DLTensor* handle = NDArray::FFIGetHandle(arr);
ObjectRef::FFIClearAfterMove(&arr);
return handle;
}
static void FFIDecRef(TVMArrayHandle tensor) {
NDArray::FFIDecRef(tensor);
}
// Container to DLManagedTensor
static DLManagedTensor* ToDLPack(TVMArrayHandle handle) {
auto* from = static_cast<NDArray::Container*>(
reinterpret_cast<NDArray::ContainerBase*>(handle));
return ToDLPack(from);
}
static DLManagedTensor* ToDLPack(NDArray::Container* from) {
CHECK(from != nullptr);
DLManagedTensor* ret = new DLManagedTensor();
ret->dl_tensor = from->dl_tensor;
ret->manager_ctx = from;
from->IncRef();
ret->deleter = NDArrayDLPackDeleter;
ret->deleter = TVMNDArrayDLPackDeleter;
return ret;
}
// Delete dlpack object.
static void NDArrayDLPackDeleter(DLManagedTensor* tensor) {
static_cast<NDArray::Container*>(tensor->manager_ctx)->DecRef();
delete tensor;
}
};
NDArray NDArray::CreateView(std::vector<int64_t> shape,
DLDataType dtype) {
NDArray NDArray::CreateView(std::vector<int64_t> shape, DLDataType dtype) {
CHECK(data_ != nullptr);
CHECK(data_->dl_tensor.strides == nullptr)
CHECK(get_mutable()->dl_tensor.strides == nullptr)
<< "Can only create view for compact tensor";
NDArray ret = Internal::Create(shape, dtype, data_->dl_tensor.ctx);
ret.data_->dl_tensor.byte_offset =
this->data_->dl_tensor.byte_offset;
size_t curr_size = GetDataSize(this->data_->dl_tensor);
size_t view_size = GetDataSize(ret.data_->dl_tensor);
NDArray ret = Internal::Create(shape, dtype, get_mutable()->dl_tensor.ctx);
ret.get_mutable()->dl_tensor.byte_offset =
this->get_mutable()->dl_tensor.byte_offset;
size_t curr_size = GetDataSize(this->get_mutable()->dl_tensor);
size_t view_size = GetDataSize(ret.get_mutable()->dl_tensor);
CHECK_LE(view_size, curr_size)
<< "Tries to create a view that has bigger memory than current one";
// increase ref count
this->data_->IncRef();
ret.data_->manager_ctx = this->data_;
ret.data_->dl_tensor.data = this->data_->dl_tensor.data;
get_mutable()->IncRef();
ret.get_mutable()->manager_ctx = get_mutable();
ret.get_mutable()->dl_tensor.data = get_mutable()->dl_tensor.data;
return ret;
}
DLManagedTensor* NDArray::ToDLPack() const {
return Internal::ToDLPack(data_);
return Internal::ToDLPack(get_mutable());
}
NDArray NDArray::Empty(std::vector<int64_t> shape,
......@@ -144,9 +162,9 @@ NDArray NDArray::Empty(std::vector<int64_t> shape,
DLContext ctx) {
NDArray ret = Internal::Create(shape, dtype, ctx);
// setup memory content
size_t size = GetDataSize(ret.data_->dl_tensor);
size_t alignment = GetDataAlignment(ret.data_->dl_tensor);
ret.data_->dl_tensor.data =
size_t size = GetDataSize(ret.get_mutable()->dl_tensor);
size_t alignment = GetDataAlignment(ret.get_mutable()->dl_tensor);
ret.get_mutable()->dl_tensor.data =
DeviceAPI::Get(ret->ctx)->AllocDataSpace(
ret->ctx, size, alignment, ret->dtype);
return ret;
......@@ -154,10 +172,12 @@ NDArray NDArray::Empty(std::vector<int64_t> shape,
NDArray NDArray::FromDLPack(DLManagedTensor* tensor) {
NDArray::Container* data = new NDArray::Container();
data->deleter = Internal::DLPackDeleter;
// construct header
data->SetDeleter(Internal::DLPackDeleter);
// fill up content.
data->manager_ctx = tensor;
data->dl_tensor = tensor->dl_tensor;
return NDArray(data);
return NDArray(GetObjectPtr<Object>(data));
}
void NDArray::CopyFromTo(const DLTensor* from,
......@@ -184,17 +204,24 @@ void NDArray::CopyFromTo(const DLTensor* from,
}
std::vector<int64_t> NDArray::Shape() const {
return data_->shape_;
return get_mutable()->shape_;
}
TVM_REGISTER_OBJECT_TYPE(NDArray::Container);
} // namespace runtime
} // namespace tvm
using namespace tvm::runtime;
void NDArrayDLPackDeleter(DLManagedTensor* tensor) {
static_cast<NDArray::Container*>(tensor->manager_ctx)->DecRef();
delete tensor;
void TVMNDArrayDLPackDeleter(DLManagedTensor* tensor) {
NDArray::Internal::NDArrayDLPackDeleter(tensor);
}
int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex) {
API_BEGIN();
*out_tindex = TVMArrayHandleToObjectHandle(handle)->type_index();
API_END();
}
int TVMArrayAlloc(const tvm_index_t* shape,
......@@ -213,14 +240,14 @@ int TVMArrayAlloc(const tvm_index_t* shape,
DLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
*out = NDArray::Internal::MoveAsDLTensor(
*out = NDArray::Internal::MoveToFFIHandle(
NDArray::Empty(std::vector<int64_t>(shape, shape + ndim), dtype, ctx));
API_END();
}
int TVMArrayFree(TVMArrayHandle handle) {
API_BEGIN();
reinterpret_cast<NDArray::Container*>(handle)->DecRef();
NDArray::Internal::FFIDecRef(handle);
API_END();
}
......@@ -235,14 +262,14 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
int TVMArrayFromDLPack(DLManagedTensor* from,
TVMArrayHandle* out) {
API_BEGIN();
*out = NDArray::Internal::MoveAsDLTensor(NDArray::FromDLPack(from));
*out = NDArray::Internal::MoveToFFIHandle(NDArray::FromDLPack(from));
API_END();
}
int TVMArrayToDLPack(TVMArrayHandle from,
DLManagedTensor** out) {
API_BEGIN();
*out = NDArray::Internal::ToDLPack(reinterpret_cast<NDArray::Container*>(from));
*out = NDArray::Internal::ToDLPack(from);
API_END();
}
......
......@@ -59,7 +59,8 @@ class RPCWrappedFunc {
const TVMArgValue& arg);
// deleter of RPC remote array
static void RemoteNDArrayDeleter(NDArray::Container* ptr) {
static void RemoteNDArrayDeleter(Object* obj) {
auto* ptr = static_cast<NDArray::Container*>(obj);
RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data);
space->sess->CallRemote(RPCCode::kNDArrayFree, ptr->manager_ctx);
delete space;
......@@ -71,12 +72,12 @@ class RPCWrappedFunc {
void* nd_handle) {
NDArray::Container* data = new NDArray::Container();
data->manager_ctx = nd_handle;
data->deleter = RemoteNDArrayDeleter;
data->SetDeleter(RemoteNDArrayDeleter);
RemoteSpace* space = new RemoteSpace();
space->sess = sess;
space->data = tensor->data;
data->dl_tensor.data = space;
NDArray ret(data);
NDArray ret(GetObjectPtr<Object>(data));
// RAII now in effect
data->shape_ = std::vector<int64_t>(
tensor->shape, tensor->shape + tensor->ndim);
......
......@@ -787,9 +787,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
TVMValue ret_value_pack[2];
int ret_tcode_pack[2];
rv.MoveToCHost(&ret_value_pack[0], &ret_tcode_pack[0]);
NDArray::Container* nd = static_cast<NDArray::Container*>(ret_value_pack[0].v_handle);
ret_value_pack[1].v_handle = nd;
ret_value_pack[1].v_handle = ret_value_pack[0].v_handle;
ret_tcode_pack[1] = kHandle;
SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, false, nullptr, true);
} else {
......@@ -1190,7 +1188,8 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) {
void* handle = args[0];
static_cast<NDArray::Container*>(handle)->DecRef();
static_cast<NDArray::Container*>(
reinterpret_cast<NDArray::ContainerBase*>(handle))->DecRef();
}
void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) {
......
......@@ -31,7 +31,8 @@ namespace tvm {
namespace runtime {
namespace vm {
static void BufferDeleter(NDArray::Container* ptr) {
static void BufferDeleter(Object* obj) {
auto* ptr = static_cast<NDArray::Container*>(obj);
CHECK(ptr->manager_ctx != nullptr);
Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx);
MemoryManager::Global()->GetAllocator(buffer->ctx)->
......@@ -40,7 +41,8 @@ static void BufferDeleter(NDArray::Container* ptr) {
delete ptr;
}
void StorageObj::Deleter(NDArray::Container* ptr) {
void StorageObj::Deleter(Object* obj) {
auto* ptr = static_cast<NDArray::Container*>(obj);
// When invoking AllocNDArray we don't own the underlying allocation
// and should not delete the buffer, but instead let it be reclaimed
// by the storage object's destructor.
......@@ -77,16 +79,23 @@ NDArray StorageObj::AllocNDArray(size_t offset, std::vector<int64_t> shape, DLDa
// TODO(@jroesch): generalize later to non-overlapping allocations.
CHECK_EQ(offset, 0u);
VerifyDataType(dtype);
// crtical zone: allocate header, cannot throw
NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, this->buffer.ctx);
container->deleter = StorageObj::Deleter;
container->SetDeleter(StorageObj::Deleter);
size_t needed_size = GetDataSize(container->dl_tensor);
// TODO(@jroesch): generalize later to non-overlapping allocations.
CHECK(needed_size == this->buffer.size)
<< "size mistmatch required " << needed_size << " found " << this->buffer.size;
this->IncRef();
container->manager_ctx = reinterpret_cast<void*>(this);
container->dl_tensor.data = this->buffer.data;
return NDArray(container);
NDArray ret(GetObjectPtr<Object>(container));
// RAII in effect, now run the check.
// TODO(@jroesch): generalize later to non-overlapping allocations.
CHECK(needed_size == this->buffer.size)
<< "size mistmatch required " << needed_size << " found " << this->buffer.size;
return ret;
}
MemoryManager* MemoryManager::Global() {
......@@ -108,14 +117,14 @@ Allocator* MemoryManager::GetAllocator(TVMContext ctx) {
NDArray Allocator::Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) {
VerifyDataType(dtype);
NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, ctx);
container->deleter = BufferDeleter;
container->SetDeleter(BufferDeleter);
size_t size = GetDataSize(container->dl_tensor);
size_t alignment = GetDataAlignment(container->dl_tensor);
Buffer *buffer = new Buffer;
*buffer = this->Alloc(size, alignment, dtype);
container->manager_ctx = reinterpret_cast<void*>(buffer);
container->dl_tensor.data = buffer->data;
return NDArray(container);
return NDArray(GetObjectPtr<Object>(container));
}
} // namespace vm
......
......@@ -120,7 +120,7 @@ class StorageObj : public Object {
DLDataType dtype);
/*! \brief The deleter for an NDArray when allocated from underlying storage. */
static void Deleter(NDArray::Container* ptr);
static void Deleter(Object* ptr);
~StorageObj() {
auto alloc = MemoryManager::Global()->GetAllocator(buffer.ctx);
......
......@@ -22,6 +22,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/registry.h>
#include <tvm/ir.h>
TEST(PackedFunc, Basic) {
......@@ -178,6 +179,69 @@ TEST(TypedPackedFunc, HighOrder) {
CHECK_EQ(f1(3), 4);
}
TEST(PackedFunc, ObjectConversion) {
using namespace tvm;
using namespace tvm::runtime;
TVMRetValue rv;
auto x = NDArray::Empty(
{}, String2TVMType("float32"),
TVMContext{kDLCPU, 0});
// assign null
rv = ObjectRef();
CHECK_EQ(rv.type_code(), kNull);
// Can assign NDArray to ret type
rv = x;
CHECK_EQ(rv.type_code(), kNDArrayContainer);
// Even if we assign base type it still shows as NDArray
rv = ObjectRef(x);
CHECK_EQ(rv.type_code(), kNDArrayContainer);
// Check convert back
CHECK(rv.operator NDArray().same_as(x));
CHECK(rv.operator ObjectRef().same_as(x));
CHECK(!rv.IsObjectRef<Expr>());
auto pf1 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args[0].type_code(), kNDArrayContainer);
CHECK(args[0].operator NDArray().same_as(x));
CHECK(args[0].operator ObjectRef().same_as(x));
CHECK(args[1].operator ObjectRef().get() == nullptr);
CHECK(args[1].operator NDArray().get() == nullptr);
CHECK(args[1].operator Module().get() == nullptr);
CHECK(args[1].operator Array<NDArray>().get() == nullptr);
CHECK(!args[0].IsObjectRef<Expr>());
});
pf1(x, ObjectRef());
pf1(ObjectRef(x), NDArray());
// testcases for modules
auto* pf = tvm::runtime::Registry::Get("module.source_module_create");
CHECK(pf != nullptr);
Module m = (*pf)("", "xyz");
rv = m;
CHECK_EQ(rv.type_code(), kModuleHandle);
// Even if we assign base type it still shows as NDArray
rv = ObjectRef(m);
CHECK_EQ(rv.type_code(), kModuleHandle);
// Check convert back
CHECK(rv.operator Module().same_as(m));
CHECK(rv.operator ObjectRef().same_as(m));
CHECK(!rv.IsObjectRef<NDArray>());
auto pf2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args[0].type_code(), kModuleHandle);
CHECK(args[0].operator Module().same_as(m));
CHECK(args[0].operator ObjectRef().same_as(m));
CHECK(args[1].operator ObjectRef().get() == nullptr);
CHECK(args[1].operator NDArray().get() == nullptr);
CHECK(args[1].operator Module().get() == nullptr);
CHECK(!args[0].IsObjectRef<Expr>());
});
pf2(m, ObjectRef());
pf2(ObjectRef(m), Module());
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
......
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np
def test_array():
a = tvm.convert([1,2,3])
......@@ -71,6 +72,14 @@ def test_in_container():
assert tvm.make.StringImm('a') in arr
assert 'd' not in arr
def test_ndarray_container():
x = tvm.nd.array([1,2,3])
arr = tvm.convert([x, x])
assert arr[0].same_as(x)
assert arr[1].same_as(x)
assert isinstance(arr[0], tvm.nd.NDArray)
if __name__ == "__main__":
test_str_map()
test_array()
......@@ -78,3 +87,4 @@ if __name__ == "__main__":
test_array_save_load_json()
test_map_save_load_json()
test_in_container()
test_ndarray_container()
......@@ -32,7 +32,7 @@ def gen_engine_header():
#include <vector>
class Engine {
};
#endif
'''
header_file = header_file_dir_path.relpath("gcc_engine.h")
......@@ -45,7 +45,7 @@ def generate_engine_module():
#include <tvm/runtime/c_runtime_api.h>
#include <dlpack/dlpack.h>
#include "gcc_engine.h"
extern "C" void gcc_1_(float* gcc_input4, float* gcc_input5,
float* gcc_input6, float* gcc_input7, float* out) {
Engine engine;
......
......@@ -35,7 +35,8 @@ rm -rf lib
make
cd ../..
python3 -m pytest -v apps/extension/tests
TVM_FFI=cython python3 -m pytest -v apps/extension/tests
TVM_FFI=ctypes python3 -m pytest -v apps/extension/tests
TVM_FFI=ctypes python3 -m pytest -v tests/python/integration
TVM_FFI=ctypes python3 -m pytest -v tests/python/contrib
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment