Unverified Commit 55bd786f by Tianqi Chen Committed by GitHub

[REFACTOR][RUNTIME] Update NDArray use the Unified Object System (#4581)

* [REFACTOR][RUNTIME] Move NDArray to Object System.

Previously NDArray has its own object reference counting mechanism.
This PR migrates NDArray to the unified object protocol.

The calling convention of NDArray remained intact.
That means NDArray still has its own type_code and
its handle is still DLTensor compatible.

In order to do so, this PR added a few minimum runtime type
detection in TVMArgValue and RetValue only when the corresponding
type is a base type(ObjectRef) that could also refer to NDArray.

This means that even if we return a base reference object ObjectRef
which refers to the NDArray. The type_code will still be translated
correctly as kNDArrayContainer.
If we assign a non-base type(say Expr) that we know is not compatible
with NDArray during compile time, no runtime type detection will be performed.

This PR also adopts the object protocol for NDArray sub-classing and
removed the legacy NDArray subclass protocol.
Examples in apps/extension are now updated to reflect that.

Making NDArray as an Object brings all the benefits of the object system.
For example, we can now use the Array container to store NDArrays.

* Address review comments
parent 4072396e
...@@ -51,26 +51,23 @@ class IntVec(tvm.Object): ...@@ -51,26 +51,23 @@ class IntVec(tvm.Object):
nd_create = tvm.get_global_func("tvm_ext.nd_create") nd_create = tvm.get_global_func("tvm_ext.nd_create")
nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two") nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two")
nd_get_addtional_info = tvm.get_global_func("tvm_ext.nd_get_addtional_info") nd_get_additional_info = tvm.get_global_func("tvm_ext.nd_get_additional_info")
@tvm.register_object("tvm_ext.NDSubClass")
class NDSubClass(tvm.nd.NDArrayBase): class NDSubClass(tvm.nd.NDArrayBase):
"""Example for subclassing TVM's NDArray infrastructure. """Example for subclassing TVM's NDArray infrastructure.
By inheriting TMV's NDArray, external libraries could By inheriting TMV's NDArray, external libraries could
leverage TVM's FFI without any modification. leverage TVM's FFI without any modification.
""" """
# Should be consistent with the type-trait set in the backend
_array_type_code = 1
@staticmethod @staticmethod
def create(addtional_info): def create(additional_info):
return nd_create(addtional_info) return nd_create(additional_info)
@property @property
def addtional_info(self): def additional_info(self):
return nd_get_addtional_info(self) return nd_get_additional_info(self)
def __add__(self, other): def __add__(self, other):
return nd_add_two(self, other) return nd_add_two(self, other)
tvm.register_extension(NDSubClass, NDSubClass)
...@@ -29,19 +29,6 @@ ...@@ -29,19 +29,6 @@
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
namespace tvm_ext {
class NDSubClass;
} // namespace tvm_ext
namespace tvm {
namespace runtime {
template<>
struct array_type_info<tvm_ext::NDSubClass> {
static const int code = 1;
};
} // namespace tvm
} // namespace runtime
using namespace tvm; using namespace tvm;
using namespace tvm::runtime; using namespace tvm::runtime;
...@@ -52,54 +39,55 @@ namespace tvm_ext { ...@@ -52,54 +39,55 @@ namespace tvm_ext {
* To use this extension, an external library should * To use this extension, an external library should
* *
* 1) Inherit TVM's NDArray and NDArray container, * 1) Inherit TVM's NDArray and NDArray container,
* and define the trait `array_type_info` for this class.
* *
* 2) Define a constructor in the inherited class that accepts * 2) Follow the new object protocol to define new NDArray as a reference class.
* a pointer to TVM's Container, which is nullable.
* *
* 3) On Python frontend, inherit `tvm.nd.NDArrayBase`, * 3) On Python frontend, inherit `tvm.nd.NDArray`,
* define the class attribute `_array_type_code` consistent to * register the type using tvm.register_object
* the C++ type trait, and register the subclass using `tvm.register_extension`.
*/ */
class NDSubClass : public tvm::runtime::NDArray { class NDSubClass : public tvm::runtime::NDArray {
public: public:
class SubContainer : public NDArray::Container { class SubContainer : public NDArray::Container {
public: public:
SubContainer(int addtional_info) : SubContainer(int additional_info) :
addtional_info_(addtional_info) { additional_info_(additional_info) {
array_type_code_ = array_type_info<NDSubClass>::code; type_index_ = SubContainer::RuntimeTypeIndex();
}
static bool Is(NDArray::Container *container) {
SubContainer *c = static_cast<SubContainer*>(container);
return c->array_type_code_ == array_type_info<NDSubClass>::code;
} }
int addtional_info_{0}; int additional_info_{0};
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "tvm_ext.NDSubClass";
TVM_DECLARE_FINAL_OBJECT_INFO(SubContainer, NDArray::Container);
}; };
NDSubClass(NDArray::Container *container) {
if (container == nullptr) { static void SubContainerDeleter(Object* obj) {
data_ = nullptr; auto* ptr = static_cast<SubContainer*>(obj);
return; delete ptr;
}
CHECK(SubContainer::Is(container));
container->IncRef();
data_ = container;
} }
~NDSubClass() {
this->reset(); NDSubClass() {}
explicit NDSubClass(ObjectPtr<Object> n) : NDArray(n) {}
explicit NDSubClass(int additional_info) {
SubContainer* ptr = new SubContainer(additional_info);
ptr->SetDeleter(SubContainerDeleter);
data_ = GetObjectPtr<Object>(ptr);
} }
NDSubClass AddWith(const NDSubClass &other) const { NDSubClass AddWith(const NDSubClass &other) const {
SubContainer *a = static_cast<SubContainer*>(data_); SubContainer *a = static_cast<SubContainer*>(get_mutable());
SubContainer *b = static_cast<SubContainer*>(other.data_); SubContainer *b = static_cast<SubContainer*>(other.get_mutable());
CHECK(a != nullptr && b != nullptr); CHECK(a != nullptr && b != nullptr);
return NDSubClass(new SubContainer(a->addtional_info_ + b->addtional_info_)); return NDSubClass(a->additional_info_ + b->additional_info_);
} }
int get_additional_info() const { int get_additional_info() const {
SubContainer *self = static_cast<SubContainer*>(data_); SubContainer *self = static_cast<SubContainer*>(get_mutable());
CHECK(self != nullptr); CHECK(self != nullptr);
return self->addtional_info_; return self->additional_info_;
} }
using ContainerType = SubContainer;
}; };
TVM_REGISTER_OBJECT_TYPE(NDSubClass::SubContainer);
/*! /*!
* \brief Introduce additional extension data structures * \brief Introduce additional extension data structures
...@@ -166,8 +154,10 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev") ...@@ -166,8 +154,10 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev")
TVM_REGISTER_GLOBAL("tvm_ext.nd_create") TVM_REGISTER_GLOBAL("tvm_ext.nd_create")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
int addtional_info = args[0]; int additional_info = args[0];
*rv = NDSubClass(new NDSubClass::SubContainer(addtional_info)); *rv = NDSubClass(additional_info);
CHECK_EQ(rv->type_code(), kNDArrayContainer);
}); });
TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two") TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two")
...@@ -177,7 +167,7 @@ TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two") ...@@ -177,7 +167,7 @@ TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two")
*rv = a.AddWith(b); *rv = a.AddWith(b);
}); });
TVM_REGISTER_GLOBAL("tvm_ext.nd_get_addtional_info") TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
NDSubClass a = args[0]; NDSubClass a = args[0];
*rv = a.get_additional_info(); *rv = a.get_additional_info();
......
...@@ -87,16 +87,17 @@ def test_extern_call(): ...@@ -87,16 +87,17 @@ def test_extern_call():
def test_nd_subclass(): def test_nd_subclass():
a = tvm_ext.NDSubClass.create(addtional_info=3) a = tvm_ext.NDSubClass.create(additional_info=3)
b = tvm_ext.NDSubClass.create(addtional_info=5) b = tvm_ext.NDSubClass.create(additional_info=5)
assert isinstance(a, tvm_ext.NDSubClass)
c = a + b c = a + b
d = a + a d = a + a
e = b + b e = b + b
assert(a.addtional_info == 3) assert(a.additional_info == 3)
assert(b.addtional_info == 5) assert(b.additional_info == 5)
assert(c.addtional_info == 8) assert(c.additional_info == 8)
assert(d.addtional_info == 6) assert(d.additional_info == 6)
assert(e.addtional_info == 10) assert(e.additional_info == 10)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -23,14 +23,14 @@ ...@@ -23,14 +23,14 @@
#ifndef TVM_NODE_CONTAINER_H_ #ifndef TVM_NODE_CONTAINER_H_
#define TVM_NODE_CONTAINER_H_ #define TVM_NODE_CONTAINER_H_
#include <tvm/node/node.h>
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
#include <initializer_list> #include <initializer_list>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <string> #include <string>
#include "node.h"
#include "memory.h"
namespace tvm { namespace tvm {
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#ifndef TVM_PACKED_FUNC_EXT_H_ #ifndef TVM_PACKED_FUNC_EXT_H_
#define TVM_PACKED_FUNC_EXT_H_ #define TVM_PACKED_FUNC_EXT_H_
#include <sstream>
#include <string> #include <string>
#include <memory> #include <memory>
#include <limits> #include <limits>
...@@ -43,22 +42,7 @@ using runtime::TVMRetValue; ...@@ -43,22 +42,7 @@ using runtime::TVMRetValue;
using runtime::PackedFunc; using runtime::PackedFunc;
namespace runtime { namespace runtime {
/*!
* \brief Runtime type checker for node type.
* \tparam T the type to be checked.
*/
template<typename T>
struct ObjectTypeChecker {
static bool Check(const Object* ptr) {
using ContainerType = typename T::ContainerType;
if (ptr == nullptr) return true;
return ptr->IsInstance<ContainerType>();
}
static void PrintName(std::ostream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType;
os << ContainerType::_type_key;
}
};
template<typename T> template<typename T>
struct ObjectTypeChecker<Array<T> > { struct ObjectTypeChecker<Array<T> > {
...@@ -73,10 +57,8 @@ struct ObjectTypeChecker<Array<T> > { ...@@ -73,10 +57,8 @@ struct ObjectTypeChecker<Array<T> > {
} }
return true; return true;
} }
static void PrintName(std::ostream& os) { // NOLINT(*) static std::string TypeName() {
os << "List["; return "List[" + ObjectTypeChecker<T>::TypeName() + "]";
ObjectTypeChecker<T>::PrintName(os);
os << "]";
} }
}; };
...@@ -91,11 +73,9 @@ struct ObjectTypeChecker<Map<std::string, V> > { ...@@ -91,11 +73,9 @@ struct ObjectTypeChecker<Map<std::string, V> > {
} }
return true; return true;
} }
static void PrintName(std::ostream& os) { // NOLINT(*) static std::string TypeName() {
os << "Map[str"; return "Map[str, " +
os << ','; ObjectTypeChecker<V>::TypeName()+ ']';
ObjectTypeChecker<V>::PrintName(os);
os << ']';
} }
}; };
...@@ -111,39 +91,16 @@ struct ObjectTypeChecker<Map<K, V> > { ...@@ -111,39 +91,16 @@ struct ObjectTypeChecker<Map<K, V> > {
} }
return true; return true;
} }
static void PrintName(std::ostringstream& os) { // NOLINT(*) static std::string TypeName() {
os << "Map["; return "Map[" +
ObjectTypeChecker<K>::PrintName(os); ObjectTypeChecker<K>::TypeName() +
os << ','; ", " +
ObjectTypeChecker<V>::PrintName(os); ObjectTypeChecker<V>::TypeName()+ ']';
os << ']';
} }
}; };
template<typename T>
inline std::string ObjectTypeName() {
std::ostringstream os;
ObjectTypeChecker<T>::PrintName(os);
return os.str();
}
// extensions for tvm arg value // extensions for tvm arg value
inline TVMPODValue_::operator tvm::Expr() const {
template<typename TObjectRef>
inline TObjectRef TVMArgValue::AsObjectRef() const {
static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
if (type_code_ == kNull) return TObjectRef(NodePtr<Node>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expected type " << ObjectTypeName<TObjectRef>()
<< " but get " << ptr->GetTypeKey();
return TObjectRef(ObjectPtr<Node>(ptr));
}
inline TVMArgValue::operator tvm::Expr() const {
if (type_code_ == kNull) return Expr(); if (type_code_ == kNull) return Expr();
if (type_code_ == kDLInt) { if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max()); CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
...@@ -164,12 +121,12 @@ inline TVMArgValue::operator tvm::Expr() const { ...@@ -164,12 +121,12 @@ inline TVMArgValue::operator tvm::Expr() const {
return Tensor(ObjectPtr<Node>(ptr))(); return Tensor(ObjectPtr<Node>(ptr))();
} }
CHECK(ObjectTypeChecker<Expr>::Check(ptr)) CHECK(ObjectTypeChecker<Expr>::Check(ptr))
<< "Expected type " << ObjectTypeName<Expr>() << "Expect type " << ObjectTypeChecker<Expr>::TypeName()
<< " but get " << ptr->GetTypeKey(); << " but get " << ptr->GetTypeKey();
return Expr(ObjectPtr<Node>(ptr)); return Expr(ObjectPtr<Node>(ptr));
} }
inline TVMArgValue::operator tvm::Integer() const { inline TVMPODValue_::operator tvm::Integer() const {
if (type_code_ == kNull) return Integer(); if (type_code_ == kNull) return Integer();
if (type_code_ == kDLInt) { if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max()); CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
...@@ -179,35 +136,10 @@ inline TVMArgValue::operator tvm::Integer() const { ...@@ -179,35 +136,10 @@ inline TVMArgValue::operator tvm::Integer() const {
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle); Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<Integer>::Check(ptr)) CHECK(ObjectTypeChecker<Integer>::Check(ptr))
<< "Expected type " << ObjectTypeName<Expr>() << "Expect type " << ObjectTypeChecker<Expr>::TypeName()
<< " but get " << ptr->GetTypeKey(); << " but get " << ptr->GetTypeKey();
return Integer(ObjectPtr<Node>(ptr)); return Integer(ObjectPtr<Node>(ptr));
} }
template<typename TObjectRef, typename>
inline bool TVMPODValue_::IsObjectRef() const {
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
return ObjectTypeChecker<TObjectRef>::Check(ptr);
}
// extensions for TVMRetValue
template<typename TObjectRef>
inline TObjectRef TVMRetValue::AsObjectRef() const {
static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
if (type_code_ == kNull) return TObjectRef();
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expected type " << ObjectTypeName<TObjectRef>()
<< " but get " << ptr->GetTypeKey();
return TObjectRef(ObjectPtr<Object>(ptr));
}
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_PACKED_FUNC_EXT_H_ #endif // TVM_PACKED_FUNC_EXT_H_
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
*/ */
#ifndef TVM_RUNTIME_CONTAINER_H_ #ifndef TVM_RUNTIME_CONTAINER_H_
#define TVM_RUNTIME_CONTAINER_H_ #define TVM_RUNTIME_CONTAINER_H_
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <tvm/runtime/memory.h> #include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h> #include <tvm/runtime/object.h>
......
...@@ -24,10 +24,11 @@ ...@@ -24,10 +24,11 @@
#define TVM_RUNTIME_OBJECT_H_ #define TVM_RUNTIME_OBJECT_H_
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
#include <type_traits> #include <type_traits>
#include <string> #include <string>
#include <utility> #include <utility>
#include "c_runtime_api.h"
/*! /*!
* \brief Whether or not use atomic reference counter. * \brief Whether or not use atomic reference counter.
...@@ -581,6 +582,14 @@ class ObjectRef { ...@@ -581,6 +582,14 @@ class ObjectRef {
return T(std::move(ref.data_)); return T(std::move(ref.data_));
} }
/*! /*!
* \brief Clear the object ref data field without DecRef
* after we successfully moved the field.
* \param ref The reference data.
*/
static void FFIClearAfterMove(ObjectRef* ref) {
ref->data_.data_ = nullptr;
}
/*!
* \brief Internal helper function get data_ as ObjectPtr of ObjectType. * \brief Internal helper function get data_ as ObjectPtr of ObjectType.
* \note only used for internal dev purpose. * \note only used for internal dev purpose.
* \tparam ObjectType The corresponding object type. * \tparam ObjectType The corresponding object type.
...@@ -648,7 +657,7 @@ struct ObjectEqual { ...@@ -648,7 +657,7 @@ struct ObjectEqual {
return _GetOrAllocRuntimeTypeIndex(); \ return _GetOrAllocRuntimeTypeIndex(); \
} \ } \
static const uint32_t _GetOrAllocRuntimeTypeIndex() { \ static const uint32_t _GetOrAllocRuntimeTypeIndex() { \
static uint32_t tidx = GetOrAllocRuntimeTypeIndex( \ static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex( \
TypeName::_type_key, \ TypeName::_type_key, \
TypeName::_type_index, \ TypeName::_type_index, \
ParentType::_GetOrAllocRuntimeTypeIndex(), \ ParentType::_GetOrAllocRuntimeTypeIndex(), \
...@@ -668,6 +677,19 @@ struct ObjectEqual { ...@@ -668,6 +677,19 @@ struct ObjectEqual {
TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
/*! \brief helper macro to supress unused warning */
#if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define TVM_ATTRIBUTE_UNUSED
#endif
#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
#define TVM_OBJECT_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid
/*! /*!
* \brief Helper macro to register the object type to runtime. * \brief Helper macro to register the object type to runtime.
* Makes sure that the runtime type table is correctly populated. * Makes sure that the runtime type table is correctly populated.
...@@ -675,7 +697,7 @@ struct ObjectEqual { ...@@ -675,7 +697,7 @@ struct ObjectEqual {
* Use this macro in the cc file for each terminal class. * Use this macro in the cc file for each terminal class.
*/ */
#define TVM_REGISTER_OBJECT_TYPE(TypeName) \ #define TVM_REGISTER_OBJECT_TYPE(TypeName) \
static DMLC_ATTRIBUTE_UNUSED uint32_t __make_Object_tidx ## _ ## TypeName ## __ = \ TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = \
TypeName::_GetOrAllocRuntimeTypeIndex() TypeName::_GetOrAllocRuntimeTypeIndex()
...@@ -691,14 +713,14 @@ struct ObjectEqual { ...@@ -691,14 +713,14 @@ struct ObjectEqual {
using ContainerType = ObjectName; using ContainerType = ObjectName;
#define TVM_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \ #define TVM_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \
TypeName() {} \ TypeName() {} \
explicit TypeName( \ explicit TypeName( \
::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
: ParentType(n) {} \ : ParentType(n) {} \
ObjectName* operator->() { \ ObjectName* operator->() { \
return static_cast<ObjectName*>(data_.get()); \ return static_cast<ObjectName*>(data_.get()); \
} \ } \
operator bool() const { return data_ != nullptr; } \ operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName; using ContainerType = ObjectName;
// Implementations details below // Implementations details below
......
...@@ -43,9 +43,9 @@ ...@@ -43,9 +43,9 @@
#ifndef TVM_RUNTIME_REGISTRY_H_ #ifndef TVM_RUNTIME_REGISTRY_H_
#define TVM_RUNTIME_REGISTRY_H_ #define TVM_RUNTIME_REGISTRY_H_
#include <tvm/runtime/packed_func.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include "packed_func.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -283,22 +283,9 @@ class Registry { ...@@ -283,22 +283,9 @@ class Registry {
friend struct Manager; friend struct Manager;
}; };
/*! \brief helper macro to supress unused warning */
#if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define TVM_ATTRIBUTE_UNUSED
#endif
#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
#define TVM_FUNC_REG_VAR_DEF \ #define TVM_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM
#define TVM_TYPE_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::ExtTypeVTable* __mk_ ## TVMT
/*! /*!
* \brief Register a function globally. * \brief Register a function globally.
* \code * \code
......
...@@ -96,6 +96,7 @@ def config_cython(): ...@@ -96,6 +96,7 @@ def config_cython():
"../3rdparty/dmlc-core/include", "../3rdparty/dmlc-core/include",
"../3rdparty/dlpack/include", "../3rdparty/dlpack/include",
], ],
extra_compile_args=["-std=c++11"],
library_dirs=library_dirs, library_dirs=library_dirs,
libraries=libraries, libraries=libraries,
language="c++")) language="c++"))
......
...@@ -20,7 +20,7 @@ from __future__ import absolute_import ...@@ -20,7 +20,7 @@ from __future__ import absolute_import
import ctypes import ctypes
from ..base import _LIB, check_call, c_str from ..base import _LIB, check_call, c_str
from ..runtime_ctypes import TVMArrayHandle, TVMNDArrayContainerHandle from ..runtime_ctypes import TVMArrayHandle
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle
...@@ -110,12 +110,17 @@ class NDArrayBase(object): ...@@ -110,12 +110,17 @@ class NDArrayBase(object):
def _make_array(handle, is_view, is_container): def _make_array(handle, is_view, is_container):
global _TVM_ND_CLS global _TVM_ND_CLS
handle = ctypes.cast(handle, TVMArrayHandle) handle = ctypes.cast(handle, TVMArrayHandle)
fcreate = _CLASS_NDARRAY if is_container:
if is_container and _TVM_ND_CLS: tindex = ctypes.c_uint()
array_type_info = ctypes.cast(handle, TVMNDArrayContainerHandle).array_type_info.value check_call(_LIB.TVMArrayGetTypeIndex(handle, ctypes.byref(tindex)))
if array_type_info > 0: cls = _TVM_ND_CLS.get(tindex.value, _CLASS_NDARRAY)
fcreate = _TVM_ND_CLS[array_type_info] else:
return fcreate(handle, is_view) cls = _CLASS_NDARRAY
ret = cls.__new__(cls)
ret.handle = handle
ret.is_view = is_view
return ret
_TVM_COMPATS = () _TVM_COMPATS = ()
...@@ -129,9 +134,9 @@ def _reg_extension(cls, fcreate): ...@@ -129,9 +134,9 @@ def _reg_extension(cls, fcreate):
_TVM_ND_CLS = {} _TVM_ND_CLS = {}
def _reg_ndarray(cls, fcreate): def _register_ndarray(index, cls):
global _TVM_ND_CLS global _TVM_ND_CLS
_TVM_ND_CLS[cls._array_type_code] = fcreate _TVM_ND_CLS[index] = cls
_CLASS_NDARRAY = None _CLASS_NDARRAY = None
......
...@@ -21,7 +21,7 @@ from __future__ import absolute_import ...@@ -21,7 +21,7 @@ from __future__ import absolute_import
import ctypes import ctypes
from ..base import _LIB, check_call from ..base import _LIB, check_call
from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from ..node_generic import _set_class_node_base from .ndarray import _register_ndarray, NDArrayBase
ObjectHandle = ctypes.c_void_p ObjectHandle = ctypes.c_void_p
...@@ -39,6 +39,9 @@ def _set_class_node(node_class): ...@@ -39,6 +39,9 @@ def _set_class_node(node_class):
def _register_object(index, cls): def _register_object(index, cls):
"""register object class""" """register object class"""
if issubclass(cls, NDArrayBase):
_register_ndarray(index, cls)
return
OBJECT_TYPE[index] = cls OBJECT_TYPE[index] = cls
...@@ -91,6 +94,3 @@ class ObjectBase(object): ...@@ -91,6 +94,3 @@ class ObjectBase(object):
if not isinstance(handle, ObjectHandle): if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle) handle = ObjectHandle(handle)
self.handle = handle self.handle = handle
_set_class_node_base(ObjectBase)
...@@ -19,7 +19,7 @@ from ..base import get_last_ffi_error ...@@ -19,7 +19,7 @@ from ..base import get_last_ffi_error
from libcpp.vector cimport vector from libcpp.vector cimport vector
from cpython.version cimport PY_MAJOR_VERSION from cpython.version cimport PY_MAJOR_VERSION
from cpython cimport pycapsule from cpython cimport pycapsule
from libc.stdint cimport int32_t, int64_t, uint64_t, uint8_t, uint16_t from libc.stdint cimport int32_t, int64_t, uint64_t, uint32_t, uint8_t, uint16_t
import ctypes import ctypes
cdef enum TVMTypeCode: cdef enum TVMTypeCode:
...@@ -78,14 +78,11 @@ ctypedef void* TVMRetValueHandle ...@@ -78,14 +78,11 @@ ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle ctypedef void* TVMFunctionHandle
ctypedef void* ObjectHandle ctypedef void* ObjectHandle
ctypedef struct TVMObject:
uint32_t type_index_
int32_t ref_counter_
void (*deleter_)(TVMObject* self)
ctypedef struct TVMNDArrayContainer:
DLTensor dl_tensor
void* manager_ctx
void (*deleter)(DLManagedTensor* self)
int32_t array_type_info
ctypedef TVMNDArrayContainer* TVMNDArrayContainerHandle
ctypedef int (*TVMPackedCFunc)( ctypedef int (*TVMPackedCFunc)(
TVMValue* args, TVMValue* args,
......
...@@ -100,17 +100,34 @@ cdef class NDArrayBase: ...@@ -100,17 +100,34 @@ cdef class NDArrayBase:
return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter) return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter)
# Import limited object-related function from C++ side to improve the speed
# NOTE: can only use POD-C compatible object in FFI.
cdef extern from "tvm/runtime/ndarray.h" namespace "tvm::runtime":
cdef void* TVMArrayHandleToObjectHandle(DLTensorHandle handle)
cdef c_make_array(void* chandle, is_view, is_container): cdef c_make_array(void* chandle, is_view, is_container):
global _TVM_ND_CLS global _TVM_ND_CLS
cdef int32_t array_type_info
fcreate = _CLASS_NDARRAY if is_container:
if is_container and len(_TVM_ND_CLS) > 0: tindex = (
array_type_info = (<TVMNDArrayContainerHandle>chandle).array_type_info <TVMObject*>TVMArrayHandleToObjectHandle(<DLTensorHandle>chandle)).type_index_
if array_type_info > 0: if tindex < len(_TVM_ND_CLS):
fcreate = _TVM_ND_CLS[array_type_info] cls = _TVM_ND_CLS[tindex]
ret = fcreate(None, is_view) if cls is not None:
(<NDArrayBase>ret).chandle = <DLTensor*>chandle ret = cls.__new__(cls)
return ret else:
ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY)
else:
ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
(<NDArrayBase>ret).c_is_view = <int>is_view
return ret
else:
ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
(<NDArrayBase>ret).c_is_view = <int>is_view
return ret
cdef _TVM_COMPATS = () cdef _TVM_COMPATS = ()
...@@ -123,11 +140,16 @@ def _reg_extension(cls, fcreate): ...@@ -123,11 +140,16 @@ def _reg_extension(cls, fcreate):
if fcreate: if fcreate:
_TVM_EXT_RET[cls._tvm_tcode] = fcreate _TVM_EXT_RET[cls._tvm_tcode] = fcreate
cdef _TVM_ND_CLS = {} cdef list _TVM_ND_CLS = []
def _reg_ndarray(cls, fcreate): cdef _register_ndarray(int index, object cls):
"""register object class"""
global _TVM_ND_CLS global _TVM_ND_CLS
_TVM_ND_CLS[cls._array_type_code] = fcreate while len(_TVM_ND_CLS) <= index:
_TVM_ND_CLS.append(None)
_TVM_ND_CLS[index] = cls
def _make_array(handle, is_view, is_container): def _make_array(handle, is_view, is_container):
cdef unsigned long long ptr cdef unsigned long long ptr
......
...@@ -16,12 +16,15 @@ ...@@ -16,12 +16,15 @@
# under the License. # under the License.
"""Maps object type to its constructor""" """Maps object type to its constructor"""
from ..node_generic import _set_class_node_base cdef list OBJECT_TYPE = []
OBJECT_TYPE = []
def _register_object(int index, object cls): def _register_object(int index, object cls):
"""register object class""" """register object class"""
if issubclass(cls, NDArrayBase):
_register_ndarray(index, cls)
return
global OBJECT_TYPE
while len(OBJECT_TYPE) <= index: while len(OBJECT_TYPE) <= index:
OBJECT_TYPE.append(None) OBJECT_TYPE.append(None)
OBJECT_TYPE[index] = cls OBJECT_TYPE[index] = cls
...@@ -31,14 +34,13 @@ cdef inline object make_ret_object(void* chandle): ...@@ -31,14 +34,13 @@ cdef inline object make_ret_object(void* chandle):
global OBJECT_TYPE global OBJECT_TYPE
global _CLASS_NODE global _CLASS_NODE
cdef unsigned tindex cdef unsigned tindex
cdef list object_type
cdef object cls cdef object cls
cdef object handle cdef object handle
object_type = OBJECT_TYPE object_type = OBJECT_TYPE
handle = ctypes_handle(chandle) handle = ctypes_handle(chandle)
CALL(TVMObjectGetTypeIndex(chandle, &tindex)) CALL(TVMObjectGetTypeIndex(chandle, &tindex))
if tindex < len(object_type): if tindex < len(OBJECT_TYPE):
cls = object_type[tindex] cls = OBJECT_TYPE[tindex]
if cls is not None: if cls is not None:
obj = cls.__new__(cls) obj = cls.__new__(cls)
else: else:
...@@ -99,6 +101,3 @@ cdef class ObjectBase: ...@@ -99,6 +101,3 @@ cdef class ObjectBase:
(<FunctionBase>fconstructor).chandle, (<FunctionBase>fconstructor).chandle,
kObjectHandle, args, &chandle) kObjectHandle, args, &chandle)
self.chandle = chandle self.chandle = chandle
_set_class_node_base(ObjectBase)
...@@ -22,6 +22,7 @@ from __future__ import absolute_import ...@@ -22,6 +22,7 @@ from __future__ import absolute_import
import sys import sys
import ctypes import ctypes
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
from .node_generic import _set_class_objects
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
...@@ -32,15 +33,21 @@ try: ...@@ -32,15 +33,21 @@ try:
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_function, _set_class_module from ._cy3.core import _set_class_function, _set_class_module
from ._cy3.core import FunctionBase as _FunctionBase from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import NDArrayBase as _NDArrayBase
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import convert_to_tvm_func from ._cy3.core import convert_to_tvm_func
else: else:
from ._cy2.core import _set_class_function, _set_class_module from ._cy2.core import _set_class_function, _set_class_module
from ._cy2.core import FunctionBase as _FunctionBase from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import NDArrayBase as _NDArrayBase
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import convert_to_tvm_func from ._cy2.core import convert_to_tvm_func
except IMPORT_EXCEPT: except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_function, _set_class_module from ._ctypes.function import _set_class_function, _set_class_module
from ._ctypes.function import FunctionBase as _FunctionBase from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
from ._ctypes.object import ObjectBase as _ObjectBase
from ._ctypes.function import convert_to_tvm_func from ._ctypes.function import convert_to_tvm_func
FunctionHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p
...@@ -325,3 +332,4 @@ def _init_api_prefix(module_name, prefix): ...@@ -325,3 +332,4 @@ def _init_api_prefix(module_name, prefix):
setattr(target_module, ff.__name__, ff) setattr(target_module, ff.__name__, ff)
_set_class_function(Function) _set_class_function(Function)
_set_class_objects((_ObjectBase, _NDArrayBase, ModuleBase))
...@@ -35,16 +35,16 @@ try: ...@@ -35,16 +35,16 @@ try:
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack
from ._cy3.core import NDArrayBase as _NDArrayBase from ._cy3.core import NDArrayBase as _NDArrayBase
from ._cy3.core import _reg_extension, _reg_ndarray from ._cy3.core import _reg_extension
else: else:
from ._cy2.core import _set_class_ndarray, _make_array, _from_dlpack from ._cy2.core import _set_class_ndarray, _make_array, _from_dlpack
from ._cy2.core import NDArrayBase as _NDArrayBase from ._cy2.core import NDArrayBase as _NDArrayBase
from ._cy2.core import _reg_extension, _reg_ndarray from ._cy2.core import _reg_extension
except IMPORT_EXCEPT: except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from ._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack from ._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
from ._ctypes.ndarray import _reg_extension, _reg_ndarray from ._ctypes.ndarray import _reg_extension
def context(dev_type, dev_id=0): def context(dev_type, dev_id=0):
...@@ -348,13 +348,8 @@ def register_extension(cls, fcreate=None): ...@@ -348,13 +348,8 @@ def register_extension(cls, fcreate=None):
def _tvm_handle(self): def _tvm_handle(self):
return self.handle.value return self.handle.value
""" """
if issubclass(cls, _NDArrayBase): assert hasattr(cls, "_tvm_tcode")
assert fcreate is not None if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
assert hasattr(cls, "_array_type_code") raise ValueError("Cannot register create when extension tcode is same as buildin")
_reg_ndarray(cls, fcreate) _reg_extension(cls, fcreate)
else:
assert hasattr(cls, "_tvm_tcode")
if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin")
_reg_extension(cls, fcreate)
return cls return cls
...@@ -23,11 +23,11 @@ from .. import _api_internal ...@@ -23,11 +23,11 @@ from .. import _api_internal
from .base import string_types from .base import string_types
# Node base class # Node base class
_CLASS_NODE_BASE = None _CLASS_OBJECTS = None
def _set_class_node_base(cls): def _set_class_objects(cls):
global _CLASS_NODE_BASE global _CLASS_OBJECTS
_CLASS_NODE_BASE = cls _CLASS_OBJECTS = cls
def _scalar_type_inference(value): def _scalar_type_inference(value):
...@@ -67,7 +67,7 @@ def convert_to_node(value): ...@@ -67,7 +67,7 @@ def convert_to_node(value):
node : Node node : Node
The corresponding node value. The corresponding node value.
""" """
if isinstance(value, _CLASS_NODE_BASE): if isinstance(value, _CLASS_OBJECTS):
return value return value
if isinstance(value, bool): if isinstance(value, bool):
return const(value, 'uint1x1') return const(value, 'uint1x1')
...@@ -81,7 +81,7 @@ def convert_to_node(value): ...@@ -81,7 +81,7 @@ def convert_to_node(value):
if isinstance(value, dict): if isinstance(value, dict):
vlist = [] vlist = []
for item in value.items(): for item in value.items():
if (not isinstance(item[0], _CLASS_NODE_BASE) and if (not isinstance(item[0], _CLASS_OBJECTS) and
not isinstance(item[0], string_types)): not isinstance(item[0], string_types)):
raise ValueError("key of map must already been a container type") raise ValueError("key of map must already been a container type")
vlist.append(item[0]) vlist.append(item[0])
......
...@@ -271,12 +271,3 @@ class TVMArray(ctypes.Structure): ...@@ -271,12 +271,3 @@ class TVMArray(ctypes.Structure):
("byte_offset", ctypes.c_uint64)] ("byte_offset", ctypes.c_uint64)]
TVMArrayHandle = ctypes.POINTER(TVMArray) TVMArrayHandle = ctypes.POINTER(TVMArray)
class TVMNDArrayContainer(ctypes.Structure):
"""TVM NDArray::Container"""
_fields_ = [("dl_tensor", TVMArray),
("manager_ctx", ctypes.c_void_p),
("deleter", ctypes.c_void_p),
("array_type_info", ctypes.c_int32)]
TVMNDArrayContainerHandle = ctypes.POINTER(TVMNDArrayContainer)
...@@ -27,7 +27,10 @@ from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase ...@@ -27,7 +27,10 @@ from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty, from_dlpack from ._ffi.ndarray import context, empty, from_dlpack
from ._ffi.ndarray import _set_class_ndarray from ._ffi.ndarray import _set_class_ndarray
from ._ffi.ndarray import register_extension from ._ffi.ndarray import register_extension
from ._ffi.object import register_object
@register_object
class NDArray(NDArrayBase): class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime. """Lightweight NDArray class of TVM runtime.
......
...@@ -67,7 +67,7 @@ TVM_REGISTER_API("_Array") ...@@ -67,7 +67,7 @@ TVM_REGISTER_API("_Array")
} }
auto node = make_node<ArrayNode>(); auto node = make_node<ArrayNode>();
node->data = std::move(data); node->data = std::move(data);
*ret = runtime::ObjectRef(node); *ret = Array<ObjectRef>(node);
}); });
TVM_REGISTER_API("_ArrayGetItem") TVM_REGISTER_API("_ArrayGetItem")
...@@ -100,28 +100,28 @@ TVM_REGISTER_API("_Map") ...@@ -100,28 +100,28 @@ TVM_REGISTER_API("_Map")
for (int i = 0; i < args.num_args; i += 2) { for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kStr) CHECK(args[i].type_code() == kStr)
<< "key of str map need to be str"; << "key of str map need to be str";
CHECK(args[i + 1].type_code() == kObjectHandle) CHECK(args[i + 1].IsObjectRef<ObjectRef>())
<< "value of the map to be NodeRef"; << "value of the map to be NodeRef";
data.emplace(std::make_pair(args[i].operator std::string(), data.emplace(std::make_pair(args[i].operator std::string(),
args[i + 1].operator ObjectRef())); args[i + 1].operator ObjectRef()));
} }
auto node = make_node<StrMapNode>(); auto node = make_node<StrMapNode>();
node->data = std::move(data); node->data = std::move(data);
*ret = node; *ret = Map<ObjectRef, ObjectRef>(node);
} else { } else {
// Container node. // Container node.
MapNode::ContainerType data; MapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) { for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kObjectHandle) CHECK(args[i].IsObjectRef<ObjectRef>())
<< "key of str map need to be str"; << "key of str map need to be object";
CHECK(args[i + 1].type_code() == kObjectHandle) CHECK(args[i + 1].IsObjectRef<ObjectRef>())
<< "value of map to be NodeRef"; << "value of map to be NodeRef";
data.emplace(std::make_pair(args[i].operator ObjectRef(), data.emplace(std::make_pair(args[i].operator ObjectRef(),
args[i + 1].operator ObjectRef())); args[i + 1].operator ObjectRef()));
} }
auto node = make_node<MapNode>(); auto node = make_node<MapNode>();
node->data = std::move(data); node->data = std::move(data);
*ret = node; *ret = Map<ObjectRef, ObjectRef>(node);
} }
}); });
...@@ -191,7 +191,7 @@ TVM_REGISTER_API("_MapItems") ...@@ -191,7 +191,7 @@ TVM_REGISTER_API("_MapItems")
rkvs->data.push_back(kv.first); rkvs->data.push_back(kv.first);
rkvs->data.push_back(kv.second); rkvs->data.push_back(kv.second);
} }
*ret = rkvs; *ret = Array<ObjectRef>(rkvs);
} else { } else {
auto* n = static_cast<const StrMapNode*>(ptr); auto* n = static_cast<const StrMapNode*>(ptr);
auto rkvs = make_node<ArrayNode>(); auto rkvs = make_node<ArrayNode>();
...@@ -199,7 +199,7 @@ TVM_REGISTER_API("_MapItems") ...@@ -199,7 +199,7 @@ TVM_REGISTER_API("_MapItems")
rkvs->data.push_back(ir::StringImm::make(kv.first)); rkvs->data.push_back(ir::StringImm::make(kv.first));
rkvs->data.push_back(kv.second); rkvs->data.push_back(kv.second);
} }
*ret = rkvs; *ret = Array<ObjectRef>(rkvs);
} }
}); });
......
...@@ -27,8 +27,12 @@ ...@@ -27,8 +27,12 @@
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include "runtime_base.h" #include "runtime_base.h"
// deleter for arrays used by DLPack exporter extern "C" {
extern "C" void NDArrayDLPackDeleter(DLManagedTensor* tensor); // C-mangled dlpack deleter.
static void TVMNDArrayDLPackDeleter(DLManagedTensor* tensor);
// helper function to get NDArray's type index, only used by ctypes.
TVM_DLL int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex);
}
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -53,8 +57,8 @@ inline size_t GetDataAlignment(const DLTensor& arr) { ...@@ -53,8 +57,8 @@ inline size_t GetDataAlignment(const DLTensor& arr) {
struct NDArray::Internal { struct NDArray::Internal {
// Default deleter for the container // Default deleter for the container
static void DefaultDeleter(NDArray::Container* ptr) { static void DefaultDeleter(Object* ptr_obj) {
using tvm::runtime::NDArray; auto* ptr = static_cast<NDArray::Container*>(ptr_obj);
if (ptr->manager_ctx != nullptr) { if (ptr->manager_ctx != nullptr) {
static_cast<NDArray::Container*>(ptr->manager_ctx)->DecRef(); static_cast<NDArray::Container*>(ptr->manager_ctx)->DecRef();
} else if (ptr->dl_tensor.data != nullptr) { } else if (ptr->dl_tensor.data != nullptr) {
...@@ -68,7 +72,8 @@ struct NDArray::Internal { ...@@ -68,7 +72,8 @@ struct NDArray::Internal {
// that are not allocated inside of TVM. // that are not allocated inside of TVM.
// This enables us to create NDArray from memory allocated by other // This enables us to create NDArray from memory allocated by other
// frameworks that are DLPack compatible // frameworks that are DLPack compatible
static void DLPackDeleter(NDArray::Container* ptr) { static void DLPackDeleter(Object* ptr_obj) {
auto* ptr = static_cast<NDArray::Container*>(ptr_obj);
DLManagedTensor* tensor = static_cast<DLManagedTensor*>(ptr->manager_ctx); DLManagedTensor* tensor = static_cast<DLManagedTensor*>(ptr->manager_ctx);
if (tensor->deleter != nullptr) { if (tensor->deleter != nullptr) {
(*tensor->deleter)(tensor); (*tensor->deleter)(tensor);
...@@ -81,12 +86,13 @@ struct NDArray::Internal { ...@@ -81,12 +86,13 @@ struct NDArray::Internal {
DLDataType dtype, DLDataType dtype,
DLContext ctx) { DLContext ctx) {
VerifyDataType(dtype); VerifyDataType(dtype);
// critical zone
// critical zone: construct header
NDArray::Container* data = new NDArray::Container(); NDArray::Container* data = new NDArray::Container();
data->deleter = DefaultDeleter; data->SetDeleter(DefaultDeleter);
NDArray ret(data);
ret.data_ = data;
// RAII now in effect // RAII now in effect
NDArray ret(GetObjectPtr<Object>(data));
// setup shape // setup shape
data->shape_ = std::move(shape); data->shape_ = std::move(shape);
data->dl_tensor.shape = dmlc::BeginPtr(data->shape_); data->dl_tensor.shape = dmlc::BeginPtr(data->shape_);
...@@ -98,45 +104,57 @@ struct NDArray::Internal { ...@@ -98,45 +104,57 @@ struct NDArray::Internal {
return ret; return ret;
} }
// Implementation of API function // Implementation of API function
static DLTensor* MoveAsDLTensor(NDArray arr) { static DLTensor* MoveToFFIHandle(NDArray arr) {
DLTensor* tensor = const_cast<DLTensor*>(arr.operator->()); DLTensor* handle = NDArray::FFIGetHandle(arr);
CHECK(reinterpret_cast<DLTensor*>(arr.data_) == tensor); ObjectRef::FFIClearAfterMove(&arr);
arr.data_ = nullptr; return handle;
return tensor; }
static void FFIDecRef(TVMArrayHandle tensor) {
NDArray::FFIDecRef(tensor);
} }
// Container to DLManagedTensor // Container to DLManagedTensor
static DLManagedTensor* ToDLPack(TVMArrayHandle handle) {
auto* from = static_cast<NDArray::Container*>(
reinterpret_cast<NDArray::ContainerBase*>(handle));
return ToDLPack(from);
}
static DLManagedTensor* ToDLPack(NDArray::Container* from) { static DLManagedTensor* ToDLPack(NDArray::Container* from) {
CHECK(from != nullptr); CHECK(from != nullptr);
DLManagedTensor* ret = new DLManagedTensor(); DLManagedTensor* ret = new DLManagedTensor();
ret->dl_tensor = from->dl_tensor; ret->dl_tensor = from->dl_tensor;
ret->manager_ctx = from; ret->manager_ctx = from;
from->IncRef(); from->IncRef();
ret->deleter = NDArrayDLPackDeleter; ret->deleter = TVMNDArrayDLPackDeleter;
return ret; return ret;
} }
// Delete dlpack object.
static void NDArrayDLPackDeleter(DLManagedTensor* tensor) {
static_cast<NDArray::Container*>(tensor->manager_ctx)->DecRef();
delete tensor;
}
}; };
NDArray NDArray::CreateView(std::vector<int64_t> shape, NDArray NDArray::CreateView(std::vector<int64_t> shape, DLDataType dtype) {
DLDataType dtype) {
CHECK(data_ != nullptr); CHECK(data_ != nullptr);
CHECK(data_->dl_tensor.strides == nullptr) CHECK(get_mutable()->dl_tensor.strides == nullptr)
<< "Can only create view for compact tensor"; << "Can only create view for compact tensor";
NDArray ret = Internal::Create(shape, dtype, data_->dl_tensor.ctx); NDArray ret = Internal::Create(shape, dtype, get_mutable()->dl_tensor.ctx);
ret.data_->dl_tensor.byte_offset = ret.get_mutable()->dl_tensor.byte_offset =
this->data_->dl_tensor.byte_offset; this->get_mutable()->dl_tensor.byte_offset;
size_t curr_size = GetDataSize(this->data_->dl_tensor); size_t curr_size = GetDataSize(this->get_mutable()->dl_tensor);
size_t view_size = GetDataSize(ret.data_->dl_tensor); size_t view_size = GetDataSize(ret.get_mutable()->dl_tensor);
CHECK_LE(view_size, curr_size) CHECK_LE(view_size, curr_size)
<< "Tries to create a view that has bigger memory than current one"; << "Tries to create a view that has bigger memory than current one";
// increase ref count // increase ref count
this->data_->IncRef(); get_mutable()->IncRef();
ret.data_->manager_ctx = this->data_; ret.get_mutable()->manager_ctx = get_mutable();
ret.data_->dl_tensor.data = this->data_->dl_tensor.data; ret.get_mutable()->dl_tensor.data = get_mutable()->dl_tensor.data;
return ret; return ret;
} }
DLManagedTensor* NDArray::ToDLPack() const { DLManagedTensor* NDArray::ToDLPack() const {
return Internal::ToDLPack(data_); return Internal::ToDLPack(get_mutable());
} }
NDArray NDArray::Empty(std::vector<int64_t> shape, NDArray NDArray::Empty(std::vector<int64_t> shape,
...@@ -144,9 +162,9 @@ NDArray NDArray::Empty(std::vector<int64_t> shape, ...@@ -144,9 +162,9 @@ NDArray NDArray::Empty(std::vector<int64_t> shape,
DLContext ctx) { DLContext ctx) {
NDArray ret = Internal::Create(shape, dtype, ctx); NDArray ret = Internal::Create(shape, dtype, ctx);
// setup memory content // setup memory content
size_t size = GetDataSize(ret.data_->dl_tensor); size_t size = GetDataSize(ret.get_mutable()->dl_tensor);
size_t alignment = GetDataAlignment(ret.data_->dl_tensor); size_t alignment = GetDataAlignment(ret.get_mutable()->dl_tensor);
ret.data_->dl_tensor.data = ret.get_mutable()->dl_tensor.data =
DeviceAPI::Get(ret->ctx)->AllocDataSpace( DeviceAPI::Get(ret->ctx)->AllocDataSpace(
ret->ctx, size, alignment, ret->dtype); ret->ctx, size, alignment, ret->dtype);
return ret; return ret;
...@@ -154,10 +172,12 @@ NDArray NDArray::Empty(std::vector<int64_t> shape, ...@@ -154,10 +172,12 @@ NDArray NDArray::Empty(std::vector<int64_t> shape,
NDArray NDArray::FromDLPack(DLManagedTensor* tensor) { NDArray NDArray::FromDLPack(DLManagedTensor* tensor) {
NDArray::Container* data = new NDArray::Container(); NDArray::Container* data = new NDArray::Container();
data->deleter = Internal::DLPackDeleter; // construct header
data->SetDeleter(Internal::DLPackDeleter);
// fill up content.
data->manager_ctx = tensor; data->manager_ctx = tensor;
data->dl_tensor = tensor->dl_tensor; data->dl_tensor = tensor->dl_tensor;
return NDArray(data); return NDArray(GetObjectPtr<Object>(data));
} }
void NDArray::CopyFromTo(const DLTensor* from, void NDArray::CopyFromTo(const DLTensor* from,
...@@ -184,17 +204,24 @@ void NDArray::CopyFromTo(const DLTensor* from, ...@@ -184,17 +204,24 @@ void NDArray::CopyFromTo(const DLTensor* from,
} }
std::vector<int64_t> NDArray::Shape() const { std::vector<int64_t> NDArray::Shape() const {
return data_->shape_; return get_mutable()->shape_;
} }
TVM_REGISTER_OBJECT_TYPE(NDArray::Container);
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
using namespace tvm::runtime; using namespace tvm::runtime;
void NDArrayDLPackDeleter(DLManagedTensor* tensor) { void TVMNDArrayDLPackDeleter(DLManagedTensor* tensor) {
static_cast<NDArray::Container*>(tensor->manager_ctx)->DecRef(); NDArray::Internal::NDArrayDLPackDeleter(tensor);
delete tensor; }
int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex) {
API_BEGIN();
*out_tindex = TVMArrayHandleToObjectHandle(handle)->type_index();
API_END();
} }
int TVMArrayAlloc(const tvm_index_t* shape, int TVMArrayAlloc(const tvm_index_t* shape,
...@@ -213,14 +240,14 @@ int TVMArrayAlloc(const tvm_index_t* shape, ...@@ -213,14 +240,14 @@ int TVMArrayAlloc(const tvm_index_t* shape,
DLContext ctx; DLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
*out = NDArray::Internal::MoveAsDLTensor( *out = NDArray::Internal::MoveToFFIHandle(
NDArray::Empty(std::vector<int64_t>(shape, shape + ndim), dtype, ctx)); NDArray::Empty(std::vector<int64_t>(shape, shape + ndim), dtype, ctx));
API_END(); API_END();
} }
int TVMArrayFree(TVMArrayHandle handle) { int TVMArrayFree(TVMArrayHandle handle) {
API_BEGIN(); API_BEGIN();
reinterpret_cast<NDArray::Container*>(handle)->DecRef(); NDArray::Internal::FFIDecRef(handle);
API_END(); API_END();
} }
...@@ -235,14 +262,14 @@ int TVMArrayCopyFromTo(TVMArrayHandle from, ...@@ -235,14 +262,14 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
int TVMArrayFromDLPack(DLManagedTensor* from, int TVMArrayFromDLPack(DLManagedTensor* from,
TVMArrayHandle* out) { TVMArrayHandle* out) {
API_BEGIN(); API_BEGIN();
*out = NDArray::Internal::MoveAsDLTensor(NDArray::FromDLPack(from)); *out = NDArray::Internal::MoveToFFIHandle(NDArray::FromDLPack(from));
API_END(); API_END();
} }
int TVMArrayToDLPack(TVMArrayHandle from, int TVMArrayToDLPack(TVMArrayHandle from,
DLManagedTensor** out) { DLManagedTensor** out) {
API_BEGIN(); API_BEGIN();
*out = NDArray::Internal::ToDLPack(reinterpret_cast<NDArray::Container*>(from)); *out = NDArray::Internal::ToDLPack(from);
API_END(); API_END();
} }
......
...@@ -59,7 +59,8 @@ class RPCWrappedFunc { ...@@ -59,7 +59,8 @@ class RPCWrappedFunc {
const TVMArgValue& arg); const TVMArgValue& arg);
// deleter of RPC remote array // deleter of RPC remote array
static void RemoteNDArrayDeleter(NDArray::Container* ptr) { static void RemoteNDArrayDeleter(Object* obj) {
auto* ptr = static_cast<NDArray::Container*>(obj);
RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data); RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data);
space->sess->CallRemote(RPCCode::kNDArrayFree, ptr->manager_ctx); space->sess->CallRemote(RPCCode::kNDArrayFree, ptr->manager_ctx);
delete space; delete space;
...@@ -71,12 +72,12 @@ class RPCWrappedFunc { ...@@ -71,12 +72,12 @@ class RPCWrappedFunc {
void* nd_handle) { void* nd_handle) {
NDArray::Container* data = new NDArray::Container(); NDArray::Container* data = new NDArray::Container();
data->manager_ctx = nd_handle; data->manager_ctx = nd_handle;
data->deleter = RemoteNDArrayDeleter; data->SetDeleter(RemoteNDArrayDeleter);
RemoteSpace* space = new RemoteSpace(); RemoteSpace* space = new RemoteSpace();
space->sess = sess; space->sess = sess;
space->data = tensor->data; space->data = tensor->data;
data->dl_tensor.data = space; data->dl_tensor.data = space;
NDArray ret(data); NDArray ret(GetObjectPtr<Object>(data));
// RAII now in effect // RAII now in effect
data->shape_ = std::vector<int64_t>( data->shape_ = std::vector<int64_t>(
tensor->shape, tensor->shape + tensor->ndim); tensor->shape, tensor->shape + tensor->ndim);
......
...@@ -787,9 +787,7 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -787,9 +787,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
TVMValue ret_value_pack[2]; TVMValue ret_value_pack[2];
int ret_tcode_pack[2]; int ret_tcode_pack[2];
rv.MoveToCHost(&ret_value_pack[0], &ret_tcode_pack[0]); rv.MoveToCHost(&ret_value_pack[0], &ret_tcode_pack[0]);
ret_value_pack[1].v_handle = ret_value_pack[0].v_handle;
NDArray::Container* nd = static_cast<NDArray::Container*>(ret_value_pack[0].v_handle);
ret_value_pack[1].v_handle = nd;
ret_tcode_pack[1] = kHandle; ret_tcode_pack[1] = kHandle;
SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, false, nullptr, true); SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, false, nullptr, true);
} else { } else {
...@@ -1190,7 +1188,8 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) { ...@@ -1190,7 +1188,8 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) { void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) {
void* handle = args[0]; void* handle = args[0];
static_cast<NDArray::Container*>(handle)->DecRef(); static_cast<NDArray::Container*>(
reinterpret_cast<NDArray::ContainerBase*>(handle))->DecRef();
} }
void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) { void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) {
......
...@@ -31,7 +31,8 @@ namespace tvm { ...@@ -31,7 +31,8 @@ namespace tvm {
namespace runtime { namespace runtime {
namespace vm { namespace vm {
static void BufferDeleter(NDArray::Container* ptr) { static void BufferDeleter(Object* obj) {
auto* ptr = static_cast<NDArray::Container*>(obj);
CHECK(ptr->manager_ctx != nullptr); CHECK(ptr->manager_ctx != nullptr);
Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx); Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx);
MemoryManager::Global()->GetAllocator(buffer->ctx)-> MemoryManager::Global()->GetAllocator(buffer->ctx)->
...@@ -40,7 +41,8 @@ static void BufferDeleter(NDArray::Container* ptr) { ...@@ -40,7 +41,8 @@ static void BufferDeleter(NDArray::Container* ptr) {
delete ptr; delete ptr;
} }
void StorageObj::Deleter(NDArray::Container* ptr) { void StorageObj::Deleter(Object* obj) {
auto* ptr = static_cast<NDArray::Container*>(obj);
// When invoking AllocNDArray we don't own the underlying allocation // When invoking AllocNDArray we don't own the underlying allocation
// and should not delete the buffer, but instead let it be reclaimed // and should not delete the buffer, but instead let it be reclaimed
// by the storage object's destructor. // by the storage object's destructor.
...@@ -77,16 +79,23 @@ NDArray StorageObj::AllocNDArray(size_t offset, std::vector<int64_t> shape, DLDa ...@@ -77,16 +79,23 @@ NDArray StorageObj::AllocNDArray(size_t offset, std::vector<int64_t> shape, DLDa
// TODO(@jroesch): generalize later to non-overlapping allocations. // TODO(@jroesch): generalize later to non-overlapping allocations.
CHECK_EQ(offset, 0u); CHECK_EQ(offset, 0u);
VerifyDataType(dtype); VerifyDataType(dtype);
// crtical zone: allocate header, cannot throw
NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, this->buffer.ctx); NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, this->buffer.ctx);
container->deleter = StorageObj::Deleter;
container->SetDeleter(StorageObj::Deleter);
size_t needed_size = GetDataSize(container->dl_tensor); size_t needed_size = GetDataSize(container->dl_tensor);
// TODO(@jroesch): generalize later to non-overlapping allocations.
CHECK(needed_size == this->buffer.size)
<< "size mistmatch required " << needed_size << " found " << this->buffer.size;
this->IncRef(); this->IncRef();
container->manager_ctx = reinterpret_cast<void*>(this); container->manager_ctx = reinterpret_cast<void*>(this);
container->dl_tensor.data = this->buffer.data; container->dl_tensor.data = this->buffer.data;
return NDArray(container); NDArray ret(GetObjectPtr<Object>(container));
// RAII in effect, now run the check.
// TODO(@jroesch): generalize later to non-overlapping allocations.
CHECK(needed_size == this->buffer.size)
<< "size mistmatch required " << needed_size << " found " << this->buffer.size;
return ret;
} }
MemoryManager* MemoryManager::Global() { MemoryManager* MemoryManager::Global() {
...@@ -108,14 +117,14 @@ Allocator* MemoryManager::GetAllocator(TVMContext ctx) { ...@@ -108,14 +117,14 @@ Allocator* MemoryManager::GetAllocator(TVMContext ctx) {
NDArray Allocator::Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) { NDArray Allocator::Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) {
VerifyDataType(dtype); VerifyDataType(dtype);
NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, ctx); NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, ctx);
container->deleter = BufferDeleter; container->SetDeleter(BufferDeleter);
size_t size = GetDataSize(container->dl_tensor); size_t size = GetDataSize(container->dl_tensor);
size_t alignment = GetDataAlignment(container->dl_tensor); size_t alignment = GetDataAlignment(container->dl_tensor);
Buffer *buffer = new Buffer; Buffer *buffer = new Buffer;
*buffer = this->Alloc(size, alignment, dtype); *buffer = this->Alloc(size, alignment, dtype);
container->manager_ctx = reinterpret_cast<void*>(buffer); container->manager_ctx = reinterpret_cast<void*>(buffer);
container->dl_tensor.data = buffer->data; container->dl_tensor.data = buffer->data;
return NDArray(container); return NDArray(GetObjectPtr<Object>(container));
} }
} // namespace vm } // namespace vm
......
...@@ -120,7 +120,7 @@ class StorageObj : public Object { ...@@ -120,7 +120,7 @@ class StorageObj : public Object {
DLDataType dtype); DLDataType dtype);
/*! \brief The deleter for an NDArray when allocated from underlying storage. */ /*! \brief The deleter for an NDArray when allocated from underlying storage. */
static void Deleter(NDArray::Container* ptr); static void Deleter(Object* ptr);
~StorageObj() { ~StorageObj() {
auto alloc = MemoryManager::Global()->GetAllocator(buffer.ctx); auto alloc = MemoryManager::Global()->GetAllocator(buffer.ctx);
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/runtime/registry.h>
#include <tvm/ir.h> #include <tvm/ir.h>
TEST(PackedFunc, Basic) { TEST(PackedFunc, Basic) {
...@@ -178,6 +179,69 @@ TEST(TypedPackedFunc, HighOrder) { ...@@ -178,6 +179,69 @@ TEST(TypedPackedFunc, HighOrder) {
CHECK_EQ(f1(3), 4); CHECK_EQ(f1(3), 4);
} }
TEST(PackedFunc, ObjectConversion) {
using namespace tvm;
using namespace tvm::runtime;
TVMRetValue rv;
auto x = NDArray::Empty(
{}, String2TVMType("float32"),
TVMContext{kDLCPU, 0});
// assign null
rv = ObjectRef();
CHECK_EQ(rv.type_code(), kNull);
// Can assign NDArray to ret type
rv = x;
CHECK_EQ(rv.type_code(), kNDArrayContainer);
// Even if we assign base type it still shows as NDArray
rv = ObjectRef(x);
CHECK_EQ(rv.type_code(), kNDArrayContainer);
// Check convert back
CHECK(rv.operator NDArray().same_as(x));
CHECK(rv.operator ObjectRef().same_as(x));
CHECK(!rv.IsObjectRef<Expr>());
auto pf1 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args[0].type_code(), kNDArrayContainer);
CHECK(args[0].operator NDArray().same_as(x));
CHECK(args[0].operator ObjectRef().same_as(x));
CHECK(args[1].operator ObjectRef().get() == nullptr);
CHECK(args[1].operator NDArray().get() == nullptr);
CHECK(args[1].operator Module().get() == nullptr);
CHECK(args[1].operator Array<NDArray>().get() == nullptr);
CHECK(!args[0].IsObjectRef<Expr>());
});
pf1(x, ObjectRef());
pf1(ObjectRef(x), NDArray());
// testcases for modules
auto* pf = tvm::runtime::Registry::Get("module.source_module_create");
CHECK(pf != nullptr);
Module m = (*pf)("", "xyz");
rv = m;
CHECK_EQ(rv.type_code(), kModuleHandle);
// Even if we assign base type it still shows as NDArray
rv = ObjectRef(m);
CHECK_EQ(rv.type_code(), kModuleHandle);
// Check convert back
CHECK(rv.operator Module().same_as(m));
CHECK(rv.operator ObjectRef().same_as(m));
CHECK(!rv.IsObjectRef<NDArray>());
auto pf2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args[0].type_code(), kModuleHandle);
CHECK(args[0].operator Module().same_as(m));
CHECK(args[0].operator ObjectRef().same_as(m));
CHECK(args[1].operator ObjectRef().get() == nullptr);
CHECK(args[1].operator NDArray().get() == nullptr);
CHECK(args[1].operator Module().get() == nullptr);
CHECK(!args[0].IsObjectRef<Expr>());
});
pf2(m, ObjectRef());
pf2(ObjectRef(m), Module());
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe"; testing::FLAGS_gtest_death_test_style = "threadsafe";
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import tvm import tvm
import numpy as np
def test_array(): def test_array():
a = tvm.convert([1,2,3]) a = tvm.convert([1,2,3])
...@@ -71,6 +72,14 @@ def test_in_container(): ...@@ -71,6 +72,14 @@ def test_in_container():
assert tvm.make.StringImm('a') in arr assert tvm.make.StringImm('a') in arr
assert 'd' not in arr assert 'd' not in arr
def test_ndarray_container():
x = tvm.nd.array([1,2,3])
arr = tvm.convert([x, x])
assert arr[0].same_as(x)
assert arr[1].same_as(x)
assert isinstance(arr[0], tvm.nd.NDArray)
if __name__ == "__main__": if __name__ == "__main__":
test_str_map() test_str_map()
test_array() test_array()
...@@ -78,3 +87,4 @@ if __name__ == "__main__": ...@@ -78,3 +87,4 @@ if __name__ == "__main__":
test_array_save_load_json() test_array_save_load_json()
test_map_save_load_json() test_map_save_load_json()
test_in_container() test_in_container()
test_ndarray_container()
...@@ -32,7 +32,7 @@ def gen_engine_header(): ...@@ -32,7 +32,7 @@ def gen_engine_header():
#include <vector> #include <vector>
class Engine { class Engine {
}; };
#endif #endif
''' '''
header_file = header_file_dir_path.relpath("gcc_engine.h") header_file = header_file_dir_path.relpath("gcc_engine.h")
...@@ -45,7 +45,7 @@ def generate_engine_module(): ...@@ -45,7 +45,7 @@ def generate_engine_module():
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <dlpack/dlpack.h> #include <dlpack/dlpack.h>
#include "gcc_engine.h" #include "gcc_engine.h"
extern "C" void gcc_1_(float* gcc_input4, float* gcc_input5, extern "C" void gcc_1_(float* gcc_input4, float* gcc_input5,
float* gcc_input6, float* gcc_input7, float* out) { float* gcc_input6, float* gcc_input7, float* out) {
Engine engine; Engine engine;
......
...@@ -35,7 +35,8 @@ rm -rf lib ...@@ -35,7 +35,8 @@ rm -rf lib
make make
cd ../.. cd ../..
python3 -m pytest -v apps/extension/tests TVM_FFI=cython python3 -m pytest -v apps/extension/tests
TVM_FFI=ctypes python3 -m pytest -v apps/extension/tests
TVM_FFI=ctypes python3 -m pytest -v tests/python/integration TVM_FFI=ctypes python3 -m pytest -v tests/python/integration
TVM_FFI=ctypes python3 -m pytest -v tests/python/contrib TVM_FFI=ctypes python3 -m pytest -v tests/python/contrib
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment