Unverified Commit 55bd786f by Tianqi Chen Committed by GitHub

[REFACTOR][RUNTIME] Update NDArray use the Unified Object System (#4581)

* [REFACTOR][RUNTIME] Move NDArray to Object System.

Previously NDArray has its own object reference counting mechanism.
This PR migrates NDArray to the unified object protocol.

The calling convention of NDArray remained intact.
That means NDArray still has its own type_code and
its handle is still DLTensor compatible.

In order to do so, this PR added a few minimum runtime type
detection in TVMArgValue and RetValue only when the corresponding
type is a base type(ObjectRef) that could also refer to NDArray.

This means that even if we return a base reference object ObjectRef
which refers to the NDArray. The type_code will still be translated
correctly as kNDArrayContainer.
If we assign a non-base type(say Expr) that we know is not compatible
with NDArray during compile time, no runtime type detection will be performed.

This PR also adopts the object protocol for NDArray sub-classing and
removed the legacy NDArray subclass protocol.
Examples in apps/extension are now updated to reflect that.

Making NDArray as an Object brings all the benefits of the object system.
For example, we can now use the Array container to store NDArrays.

* Address review comments
parent 4072396e
...@@ -51,26 +51,23 @@ class IntVec(tvm.Object): ...@@ -51,26 +51,23 @@ class IntVec(tvm.Object):
nd_create = tvm.get_global_func("tvm_ext.nd_create") nd_create = tvm.get_global_func("tvm_ext.nd_create")
nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two") nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two")
nd_get_addtional_info = tvm.get_global_func("tvm_ext.nd_get_addtional_info") nd_get_additional_info = tvm.get_global_func("tvm_ext.nd_get_additional_info")
@tvm.register_object("tvm_ext.NDSubClass")
class NDSubClass(tvm.nd.NDArrayBase): class NDSubClass(tvm.nd.NDArrayBase):
"""Example for subclassing TVM's NDArray infrastructure. """Example for subclassing TVM's NDArray infrastructure.
By inheriting TMV's NDArray, external libraries could By inheriting TMV's NDArray, external libraries could
leverage TVM's FFI without any modification. leverage TVM's FFI without any modification.
""" """
# Should be consistent with the type-trait set in the backend
_array_type_code = 1
@staticmethod @staticmethod
def create(addtional_info): def create(additional_info):
return nd_create(addtional_info) return nd_create(additional_info)
@property @property
def addtional_info(self): def additional_info(self):
return nd_get_addtional_info(self) return nd_get_additional_info(self)
def __add__(self, other): def __add__(self, other):
return nd_add_two(self, other) return nd_add_two(self, other)
tvm.register_extension(NDSubClass, NDSubClass)
...@@ -29,19 +29,6 @@ ...@@ -29,19 +29,6 @@
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
namespace tvm_ext {
class NDSubClass;
} // namespace tvm_ext
namespace tvm {
namespace runtime {
template<>
struct array_type_info<tvm_ext::NDSubClass> {
static const int code = 1;
};
} // namespace tvm
} // namespace runtime
using namespace tvm; using namespace tvm;
using namespace tvm::runtime; using namespace tvm::runtime;
...@@ -52,54 +39,55 @@ namespace tvm_ext { ...@@ -52,54 +39,55 @@ namespace tvm_ext {
* To use this extension, an external library should * To use this extension, an external library should
* *
* 1) Inherit TVM's NDArray and NDArray container, * 1) Inherit TVM's NDArray and NDArray container,
* and define the trait `array_type_info` for this class.
* *
* 2) Define a constructor in the inherited class that accepts * 2) Follow the new object protocol to define new NDArray as a reference class.
* a pointer to TVM's Container, which is nullable.
* *
* 3) On Python frontend, inherit `tvm.nd.NDArrayBase`, * 3) On Python frontend, inherit `tvm.nd.NDArray`,
* define the class attribute `_array_type_code` consistent to * register the type using tvm.register_object
* the C++ type trait, and register the subclass using `tvm.register_extension`.
*/ */
class NDSubClass : public tvm::runtime::NDArray { class NDSubClass : public tvm::runtime::NDArray {
public: public:
class SubContainer : public NDArray::Container { class SubContainer : public NDArray::Container {
public: public:
SubContainer(int addtional_info) : SubContainer(int additional_info) :
addtional_info_(addtional_info) { additional_info_(additional_info) {
array_type_code_ = array_type_info<NDSubClass>::code; type_index_ = SubContainer::RuntimeTypeIndex();
}
static bool Is(NDArray::Container *container) {
SubContainer *c = static_cast<SubContainer*>(container);
return c->array_type_code_ == array_type_info<NDSubClass>::code;
} }
int addtional_info_{0}; int additional_info_{0};
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "tvm_ext.NDSubClass";
TVM_DECLARE_FINAL_OBJECT_INFO(SubContainer, NDArray::Container);
}; };
NDSubClass(NDArray::Container *container) {
if (container == nullptr) { static void SubContainerDeleter(Object* obj) {
data_ = nullptr; auto* ptr = static_cast<SubContainer*>(obj);
return; delete ptr;
}
CHECK(SubContainer::Is(container));
container->IncRef();
data_ = container;
} }
~NDSubClass() {
this->reset(); NDSubClass() {}
explicit NDSubClass(ObjectPtr<Object> n) : NDArray(n) {}
explicit NDSubClass(int additional_info) {
SubContainer* ptr = new SubContainer(additional_info);
ptr->SetDeleter(SubContainerDeleter);
data_ = GetObjectPtr<Object>(ptr);
} }
NDSubClass AddWith(const NDSubClass &other) const { NDSubClass AddWith(const NDSubClass &other) const {
SubContainer *a = static_cast<SubContainer*>(data_); SubContainer *a = static_cast<SubContainer*>(get_mutable());
SubContainer *b = static_cast<SubContainer*>(other.data_); SubContainer *b = static_cast<SubContainer*>(other.get_mutable());
CHECK(a != nullptr && b != nullptr); CHECK(a != nullptr && b != nullptr);
return NDSubClass(new SubContainer(a->addtional_info_ + b->addtional_info_)); return NDSubClass(a->additional_info_ + b->additional_info_);
} }
int get_additional_info() const { int get_additional_info() const {
SubContainer *self = static_cast<SubContainer*>(data_); SubContainer *self = static_cast<SubContainer*>(get_mutable());
CHECK(self != nullptr); CHECK(self != nullptr);
return self->addtional_info_; return self->additional_info_;
} }
using ContainerType = SubContainer;
}; };
TVM_REGISTER_OBJECT_TYPE(NDSubClass::SubContainer);
/*! /*!
* \brief Introduce additional extension data structures * \brief Introduce additional extension data structures
...@@ -166,8 +154,10 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev") ...@@ -166,8 +154,10 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev")
TVM_REGISTER_GLOBAL("tvm_ext.nd_create") TVM_REGISTER_GLOBAL("tvm_ext.nd_create")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
int addtional_info = args[0]; int additional_info = args[0];
*rv = NDSubClass(new NDSubClass::SubContainer(addtional_info)); *rv = NDSubClass(additional_info);
CHECK_EQ(rv->type_code(), kNDArrayContainer);
}); });
TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two") TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two")
...@@ -177,7 +167,7 @@ TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two") ...@@ -177,7 +167,7 @@ TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two")
*rv = a.AddWith(b); *rv = a.AddWith(b);
}); });
TVM_REGISTER_GLOBAL("tvm_ext.nd_get_addtional_info") TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
NDSubClass a = args[0]; NDSubClass a = args[0];
*rv = a.get_additional_info(); *rv = a.get_additional_info();
......
...@@ -87,16 +87,17 @@ def test_extern_call(): ...@@ -87,16 +87,17 @@ def test_extern_call():
def test_nd_subclass(): def test_nd_subclass():
a = tvm_ext.NDSubClass.create(addtional_info=3) a = tvm_ext.NDSubClass.create(additional_info=3)
b = tvm_ext.NDSubClass.create(addtional_info=5) b = tvm_ext.NDSubClass.create(additional_info=5)
assert isinstance(a, tvm_ext.NDSubClass)
c = a + b c = a + b
d = a + a d = a + a
e = b + b e = b + b
assert(a.addtional_info == 3) assert(a.additional_info == 3)
assert(b.addtional_info == 5) assert(b.additional_info == 5)
assert(c.addtional_info == 8) assert(c.additional_info == 8)
assert(d.addtional_info == 6) assert(d.additional_info == 6)
assert(e.addtional_info == 10) assert(e.additional_info == 10)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -23,14 +23,14 @@ ...@@ -23,14 +23,14 @@
#ifndef TVM_NODE_CONTAINER_H_ #ifndef TVM_NODE_CONTAINER_H_
#define TVM_NODE_CONTAINER_H_ #define TVM_NODE_CONTAINER_H_
#include <tvm/node/node.h>
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
#include <initializer_list> #include <initializer_list>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <string> #include <string>
#include "node.h"
#include "memory.h"
namespace tvm { namespace tvm {
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#ifndef TVM_PACKED_FUNC_EXT_H_ #ifndef TVM_PACKED_FUNC_EXT_H_
#define TVM_PACKED_FUNC_EXT_H_ #define TVM_PACKED_FUNC_EXT_H_
#include <sstream>
#include <string> #include <string>
#include <memory> #include <memory>
#include <limits> #include <limits>
...@@ -43,22 +42,7 @@ using runtime::TVMRetValue; ...@@ -43,22 +42,7 @@ using runtime::TVMRetValue;
using runtime::PackedFunc; using runtime::PackedFunc;
namespace runtime { namespace runtime {
/*!
* \brief Runtime type checker for node type.
* \tparam T the type to be checked.
*/
template<typename T>
struct ObjectTypeChecker {
static bool Check(const Object* ptr) {
using ContainerType = typename T::ContainerType;
if (ptr == nullptr) return true;
return ptr->IsInstance<ContainerType>();
}
static void PrintName(std::ostream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType;
os << ContainerType::_type_key;
}
};
template<typename T> template<typename T>
struct ObjectTypeChecker<Array<T> > { struct ObjectTypeChecker<Array<T> > {
...@@ -73,10 +57,8 @@ struct ObjectTypeChecker<Array<T> > { ...@@ -73,10 +57,8 @@ struct ObjectTypeChecker<Array<T> > {
} }
return true; return true;
} }
static void PrintName(std::ostream& os) { // NOLINT(*) static std::string TypeName() {
os << "List["; return "List[" + ObjectTypeChecker<T>::TypeName() + "]";
ObjectTypeChecker<T>::PrintName(os);
os << "]";
} }
}; };
...@@ -91,11 +73,9 @@ struct ObjectTypeChecker<Map<std::string, V> > { ...@@ -91,11 +73,9 @@ struct ObjectTypeChecker<Map<std::string, V> > {
} }
return true; return true;
} }
static void PrintName(std::ostream& os) { // NOLINT(*) static std::string TypeName() {
os << "Map[str"; return "Map[str, " +
os << ','; ObjectTypeChecker<V>::TypeName()+ ']';
ObjectTypeChecker<V>::PrintName(os);
os << ']';
} }
}; };
...@@ -111,39 +91,16 @@ struct ObjectTypeChecker<Map<K, V> > { ...@@ -111,39 +91,16 @@ struct ObjectTypeChecker<Map<K, V> > {
} }
return true; return true;
} }
static void PrintName(std::ostringstream& os) { // NOLINT(*) static std::string TypeName() {
os << "Map["; return "Map[" +
ObjectTypeChecker<K>::PrintName(os); ObjectTypeChecker<K>::TypeName() +
os << ','; ", " +
ObjectTypeChecker<V>::PrintName(os); ObjectTypeChecker<V>::TypeName()+ ']';
os << ']';
} }
}; };
template<typename T>
inline std::string ObjectTypeName() {
std::ostringstream os;
ObjectTypeChecker<T>::PrintName(os);
return os.str();
}
// extensions for tvm arg value // extensions for tvm arg value
inline TVMPODValue_::operator tvm::Expr() const {
template<typename TObjectRef>
inline TObjectRef TVMArgValue::AsObjectRef() const {
static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
if (type_code_ == kNull) return TObjectRef(NodePtr<Node>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expected type " << ObjectTypeName<TObjectRef>()
<< " but get " << ptr->GetTypeKey();
return TObjectRef(ObjectPtr<Node>(ptr));
}
inline TVMArgValue::operator tvm::Expr() const {
if (type_code_ == kNull) return Expr(); if (type_code_ == kNull) return Expr();
if (type_code_ == kDLInt) { if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max()); CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
...@@ -164,12 +121,12 @@ inline TVMArgValue::operator tvm::Expr() const { ...@@ -164,12 +121,12 @@ inline TVMArgValue::operator tvm::Expr() const {
return Tensor(ObjectPtr<Node>(ptr))(); return Tensor(ObjectPtr<Node>(ptr))();
} }
CHECK(ObjectTypeChecker<Expr>::Check(ptr)) CHECK(ObjectTypeChecker<Expr>::Check(ptr))
<< "Expected type " << ObjectTypeName<Expr>() << "Expect type " << ObjectTypeChecker<Expr>::TypeName()
<< " but get " << ptr->GetTypeKey(); << " but get " << ptr->GetTypeKey();
return Expr(ObjectPtr<Node>(ptr)); return Expr(ObjectPtr<Node>(ptr));
} }
inline TVMArgValue::operator tvm::Integer() const { inline TVMPODValue_::operator tvm::Integer() const {
if (type_code_ == kNull) return Integer(); if (type_code_ == kNull) return Integer();
if (type_code_ == kDLInt) { if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max()); CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
...@@ -179,35 +136,10 @@ inline TVMArgValue::operator tvm::Integer() const { ...@@ -179,35 +136,10 @@ inline TVMArgValue::operator tvm::Integer() const {
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle); Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<Integer>::Check(ptr)) CHECK(ObjectTypeChecker<Integer>::Check(ptr))
<< "Expected type " << ObjectTypeName<Expr>() << "Expect type " << ObjectTypeChecker<Expr>::TypeName()
<< " but get " << ptr->GetTypeKey(); << " but get " << ptr->GetTypeKey();
return Integer(ObjectPtr<Node>(ptr)); return Integer(ObjectPtr<Node>(ptr));
} }
template<typename TObjectRef, typename>
inline bool TVMPODValue_::IsObjectRef() const {
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
return ObjectTypeChecker<TObjectRef>::Check(ptr);
}
// extensions for TVMRetValue
template<typename TObjectRef>
inline TObjectRef TVMRetValue::AsObjectRef() const {
static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
if (type_code_ == kNull) return TObjectRef();
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expected type " << ObjectTypeName<TObjectRef>()
<< " but get " << ptr->GetTypeKey();
return TObjectRef(ObjectPtr<Object>(ptr));
}
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_PACKED_FUNC_EXT_H_ #endif // TVM_PACKED_FUNC_EXT_H_
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
*/ */
#ifndef TVM_RUNTIME_CONTAINER_H_ #ifndef TVM_RUNTIME_CONTAINER_H_
#define TVM_RUNTIME_CONTAINER_H_ #define TVM_RUNTIME_CONTAINER_H_
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <tvm/runtime/memory.h> #include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h> #include <tvm/runtime/object.h>
......
...@@ -24,11 +24,13 @@ ...@@ -24,11 +24,13 @@
#ifndef TVM_RUNTIME_NDARRAY_H_ #ifndef TVM_RUNTIME_NDARRAY_H_
#define TVM_RUNTIME_NDARRAY_H_ #define TVM_RUNTIME_NDARRAY_H_
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/serializer.h>
#include <atomic> #include <atomic>
#include <vector> #include <vector>
#include <utility> #include <utility>
#include "c_runtime_api.h"
#include "serializer.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -37,72 +39,23 @@ namespace runtime { ...@@ -37,72 +39,23 @@ namespace runtime {
* \brief Managed NDArray. * \brief Managed NDArray.
* The array is backed by reference counted blocks. * The array is backed by reference counted blocks.
*/ */
class NDArray { class NDArray : public ObjectRef {
public: public:
// internal container type /*! \brief ContainerBase used to back the TVMArrayHandle */
class ContainerBase;
/*! \brief NDArray internal container type */
class Container; class Container;
/*! \brief Container type for Object system. */
using ContainerType = Container;
/*! \brief default constructor */ /*! \brief default constructor */
NDArray() {} NDArray() {}
/*! /*!
* \brief cosntruct a NDArray that refers to data * \brief constructor.
* \param data The data this NDArray refers to * \param data ObjectPtr to the data container.
*/
explicit inline NDArray(Container* data);
/*!
* \brief copy constructor.
*
* It does not make a copy, but the reference count of the input NDArray is incremented
*
* \param other NDArray that shares internal data with the input NDArray.
*/
inline NDArray(const NDArray& other); // NOLINT(*)
/*!
* \brief move constructor
* \param other The value to be moved
*/
NDArray(NDArray&& other) // NOLINT(*)
: data_(other.data_) {
other.data_ = nullptr;
}
/*! \brief destructor */
~NDArray() {
this->reset();
}
/*!
* \brief Swap this array with another NDArray
* \param other The other NDArray
*/ */
void swap(NDArray& other) { // NOLINT(*) explicit NDArray(ObjectPtr<Object> data)
std::swap(data_, other.data_); : ObjectRef(data) {}
}
/*!
* \brief copy assignmemt
* \param other The value to be assigned.
* \return reference to self.
*/
NDArray& operator=(const NDArray& other) { // NOLINT(*)
// copy-and-swap idiom
NDArray(other).swap(*this); // NOLINT(*)
return *this;
}
/*!
* \brief move assignmemt
* \param other The value to be assigned.
* \return reference to self.
*/
NDArray& operator=(NDArray&& other) { // NOLINT(*)
// copy-and-swap idiom
NDArray(std::move(other)).swap(*this); // NOLINT(*)
return *this;
}
/*! \return If NDArray is defined */
bool defined() const {
return data_ != nullptr;
}
/*! \return If both NDArray reference the same container */
bool same_as(const NDArray& other) const {
return data_ == other.data_;
}
/*! \brief reset the content of NDArray to be nullptr */ /*! \brief reset the content of NDArray to be nullptr */
inline void reset(); inline void reset();
/*! /*!
...@@ -191,36 +144,40 @@ class NDArray { ...@@ -191,36 +144,40 @@ class NDArray {
const DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr); const DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr);
TVM_DLL std::vector<int64_t> Shape() const; TVM_DLL std::vector<int64_t> Shape() const;
// internal namespace // internal namespace
struct Internal; struct Internal;
protected: protected:
/*! \brief Internal Data content */
Container* data_{nullptr};
// enable internal functions
friend struct Internal;
friend class TVMPODValue_; friend class TVMPODValue_;
friend class TVMArgValue;
friend class TVMRetValue; friend class TVMRetValue;
friend class TVMArgsSetter; friend class TVMArgsSetter;
}; /*!
* \brief Get mutable internal container pointer.
/*! * \return a mutable container pointer.
* \brief The type trait indicates subclass of TVM's NDArray. */
* For irrelavant classes, code = -1. inline Container* get_mutable() const;
* For TVM NDArray itself, code = 0. // Helper functions for FFI handling.
* All subclasses of NDArray should override code > 0. /*!
*/ * \brief Construct NDArray's Data field from array handle in FFI.
template<typename T> * \param handle The array handle.
struct array_type_info { * \return The corresponding ObjectPtr to the constructed container object.
/*! \brief the value of the traits */ *
static const int code = -1; * \note We keep a special calling convention for NDArray by passing
}; * ContainerBase pointer in FFI.
* As a result, the argument is compatible to DLTensor*.
// Overrides the type trait for tvm's NDArray. */
template<> inline static ObjectPtr<Object> FFIDataFromHandle(TVMArrayHandle handle);
struct array_type_info<NDArray> { /*!
static const int code = 0; * \brief DecRef resource managed by an FFI array handle.
* \param handle The array handle.
*/
inline static void FFIDecRef(TVMArrayHandle handle);
/*!
* \brief Get FFI Array handle from ndarray.
* \param nd The object with ndarray type.
* \return The result array handle.
*/
inline static TVMArrayHandle FFIGetHandle(const ObjectRef& nd);
}; };
/*! /*!
...@@ -231,19 +188,14 @@ struct array_type_info<NDArray> { ...@@ -231,19 +188,14 @@ struct array_type_info<NDArray> {
inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor); inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor);
/*! /*!
* \brief Reference counted Container object used to back NDArray. * \brief The container base structure
* contains all the fields except for the Object header.
* *
* This object is DLTensor compatible: * \note We explicitly declare this structure in order to pass
* the pointer to the NDArrayContainer can be directly * PackedFunc argument using ContainerBase*.
* interpreted as a DLTensor*
*
* \note do not use this function directly, use NDArray.
*/ */
class NDArray::Container { class NDArray::ContainerBase {
public: public:
// NOTE: the first part of this structure is the same as
// DLManagedTensor, note that, however, the deleter
// is only called when the reference counter goes to 0
/*! /*!
* \brief The corresponding dl_tensor field. * \brief The corresponding dl_tensor field.
* \note it is important that the first field is DLTensor * \note it is important that the first field is DLTensor
...@@ -259,42 +211,27 @@ class NDArray::Container { ...@@ -259,42 +211,27 @@ class NDArray::Container {
* (e.g. reference to original memory when creating views). * (e.g. reference to original memory when creating views).
*/ */
void* manager_ctx{nullptr}; void* manager_ctx{nullptr};
/*!
* \brief Customized deleter
*
* \note The customized deleter is helpful to enable
* different ways of memory allocator that are not
* currently defined by the system.
*/
void (*deleter)(Container* self) = nullptr;
protected: protected:
friend class NDArray;
friend class TVMPODValue_;
friend class TVMArgValue;
friend class TVMRetValue;
friend class RPCWrappedFunc;
/*!
* \brief Type flag used to indicate subclass.
* Default value 0 means normal NDArray::Conatainer.
*
* We can extend a more specialized NDArray::Container
* and use the array_type_code_ to indicate
* the specific array subclass.
*/
int32_t array_type_code_{0};
/*! \brief The internal reference counter */
std::atomic<int> ref_counter_{0};
/*! /*!
* \brief The shape container, * \brief The shape container,
* can be used used for shape data. * can be used used for shape data.
*/ */
std::vector<int64_t> shape_; std::vector<int64_t> shape_;
};
/*!
* \brief Object container class that backs NDArray.
* \note do not use this function directly, use NDArray.
*/
class NDArray::Container :
public Object,
public NDArray::ContainerBase {
public: public:
/*! \brief default constructor */ /*! \brief default constructor */
Container() { Container() {
// Initialize the type index.
type_index_ = Container::RuntimeTypeIndex();
dl_tensor.data = nullptr; dl_tensor.data = nullptr;
dl_tensor.ndim = 0; dl_tensor.ndim = 0;
dl_tensor.shape = nullptr; dl_tensor.shape = nullptr;
...@@ -306,6 +243,8 @@ class NDArray::Container { ...@@ -306,6 +243,8 @@ class NDArray::Container {
std::vector<int64_t> shape, std::vector<int64_t> shape,
DLDataType dtype, DLDataType dtype,
DLContext ctx) { DLContext ctx) {
// Initialize the type index.
type_index_ = Container::RuntimeTypeIndex();
dl_tensor.data = data; dl_tensor.data = data;
shape_ = std::move(shape); shape_ = std::move(shape);
dl_tensor.ndim = static_cast<int>(shape_.size()); dl_tensor.ndim = static_cast<int>(shape_.size());
...@@ -315,49 +254,36 @@ class NDArray::Container { ...@@ -315,49 +254,36 @@ class NDArray::Container {
dl_tensor.byte_offset = 0; dl_tensor.byte_offset = 0;
dl_tensor.ctx = ctx; dl_tensor.ctx = ctx;
} }
/*!
/*! \brief developer function, increases reference counter */ * \brief Set the deleter field.
void IncRef() { * \param deleter The deleter.
ref_counter_.fetch_add(1, std::memory_order_relaxed); */
} void SetDeleter(FDeleter deleter) {
/*! \brief developer function, decrease reference counter */ deleter_ = deleter;
void DecRef() {
if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
std::atomic_thread_fence(std::memory_order_acquire);
if (this->deleter != nullptr) {
(*this->deleter)(this);
}
}
} }
};
// implementations of inline functions // Expose DecRef and IncRef as public function
// the usages of functions are documented in place. // NOTE: they are only for developer purposes only.
inline NDArray::NDArray(Container* data) using Object::DecRef;
: data_(data) { using Object::IncRef;
if (data != nullptr) {
data_->IncRef();
}
}
inline NDArray::NDArray(const NDArray& other) // Information for object protocol.
: data_(other.data_) { static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
if (data_ != nullptr) { static constexpr const uint32_t _type_child_slots = 0;
data_->IncRef(); static constexpr const uint32_t _type_child_slots_can_overflow = true;
} static constexpr const char* _type_key = "NDArray";
} TVM_DECLARE_BASE_OBJECT_INFO(NDArray::Container, Object);
inline void NDArray::reset() { protected:
if (data_ != nullptr) { friend class RPCWrappedFunc;
data_->DecRef(); friend class NDArray;
data_ = nullptr; };
}
}
/*! \brief return the size of data the DLTensor hold, in term of number of bytes // implementations of inline functions
/*!
* \brief return the size of data the DLTensor hold, in term of number of bytes
* *
* \param arr the input DLTensor * \param arr the input DLTensor
*
* \return number of bytes of data in the DLTensor. * \return number of bytes of data in the DLTensor.
*/ */
inline size_t GetDataSize(const DLTensor& arr) { inline size_t GetDataSize(const DLTensor& arr) {
...@@ -371,24 +297,24 @@ inline size_t GetDataSize(const DLTensor& arr) { ...@@ -371,24 +297,24 @@ inline size_t GetDataSize(const DLTensor& arr) {
inline void NDArray::CopyFrom(const DLTensor* other) { inline void NDArray::CopyFrom(const DLTensor* other) {
CHECK(data_ != nullptr); CHECK(data_ != nullptr);
CopyFromTo(other, &(data_->dl_tensor)); CopyFromTo(other, &(get_mutable()->dl_tensor));
} }
inline void NDArray::CopyFrom(const NDArray& other) { inline void NDArray::CopyFrom(const NDArray& other) {
CHECK(data_ != nullptr); CHECK(data_ != nullptr);
CHECK(other.data_ != nullptr); CHECK(other.data_ != nullptr);
CopyFromTo(&(other.data_->dl_tensor), &(data_->dl_tensor)); CopyFromTo(&(other.get_mutable()->dl_tensor), &(get_mutable()->dl_tensor));
} }
inline void NDArray::CopyTo(DLTensor* other) const { inline void NDArray::CopyTo(DLTensor* other) const {
CHECK(data_ != nullptr); CHECK(data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), other); CopyFromTo(&(get_mutable()->dl_tensor), other);
} }
inline void NDArray::CopyTo(const NDArray& other) const { inline void NDArray::CopyTo(const NDArray& other) const {
CHECK(data_ != nullptr); CHECK(data_ != nullptr);
CHECK(other.data_ != nullptr); CHECK(other.data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), &(other.data_->dl_tensor)); CopyFromTo(&(get_mutable()->dl_tensor), &(other.get_mutable()->dl_tensor));
} }
inline NDArray NDArray::CopyTo(const DLContext& ctx) const { inline NDArray NDArray::CopyTo(const DLContext& ctx) const {
...@@ -401,12 +327,39 @@ inline NDArray NDArray::CopyTo(const DLContext& ctx) const { ...@@ -401,12 +327,39 @@ inline NDArray NDArray::CopyTo(const DLContext& ctx) const {
} }
inline int NDArray::use_count() const { inline int NDArray::use_count() const {
if (data_ == nullptr) return 0; return data_.use_count();
return data_->ref_counter_.load(std::memory_order_relaxed);
} }
inline const DLTensor* NDArray::operator->() const { inline const DLTensor* NDArray::operator->() const {
return &(data_->dl_tensor); return &(get_mutable()->dl_tensor);
}
inline NDArray::Container* NDArray::get_mutable() const {
return static_cast<NDArray::Container*>(data_.get());
}
inline ObjectPtr<Object> NDArray::FFIDataFromHandle(TVMArrayHandle handle) {
return GetObjectPtr<Object>(static_cast<NDArray::Container*>(
reinterpret_cast<NDArray::ContainerBase*>(handle)));
}
inline TVMArrayHandle NDArray::FFIGetHandle(const ObjectRef& nd) {
// NOTE: it is necessary to cast to container then to base
// so that the FFI handle uses the ContainerBase address.
return reinterpret_cast<TVMArrayHandle>(
static_cast<NDArray::ContainerBase*>(
static_cast<NDArray::Container*>(
const_cast<Object*>(nd.get()))));
}
inline void NDArray::FFIDecRef(TVMArrayHandle handle) {
static_cast<NDArray::Container*>(
reinterpret_cast<NDArray::ContainerBase*>(handle))->DecRef();
}
inline Object* TVMArrayHandleToObjectHandle(TVMArrayHandle handle) {
return static_cast<NDArray::Container*>(
reinterpret_cast<NDArray::ContainerBase*>(handle));
} }
/*! \brief Magic number for NDArray file */ /*! \brief Magic number for NDArray file */
......
...@@ -24,10 +24,11 @@ ...@@ -24,10 +24,11 @@
#define TVM_RUNTIME_OBJECT_H_ #define TVM_RUNTIME_OBJECT_H_
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
#include <type_traits> #include <type_traits>
#include <string> #include <string>
#include <utility> #include <utility>
#include "c_runtime_api.h"
/*! /*!
* \brief Whether or not use atomic reference counter. * \brief Whether or not use atomic reference counter.
...@@ -581,6 +582,14 @@ class ObjectRef { ...@@ -581,6 +582,14 @@ class ObjectRef {
return T(std::move(ref.data_)); return T(std::move(ref.data_));
} }
/*! /*!
* \brief Clear the object ref data field without DecRef
* after we successfully moved the field.
* \param ref The reference data.
*/
static void FFIClearAfterMove(ObjectRef* ref) {
ref->data_.data_ = nullptr;
}
/*!
* \brief Internal helper function get data_ as ObjectPtr of ObjectType. * \brief Internal helper function get data_ as ObjectPtr of ObjectType.
* \note only used for internal dev purpose. * \note only used for internal dev purpose.
* \tparam ObjectType The corresponding object type. * \tparam ObjectType The corresponding object type.
...@@ -648,7 +657,7 @@ struct ObjectEqual { ...@@ -648,7 +657,7 @@ struct ObjectEqual {
return _GetOrAllocRuntimeTypeIndex(); \ return _GetOrAllocRuntimeTypeIndex(); \
} \ } \
static const uint32_t _GetOrAllocRuntimeTypeIndex() { \ static const uint32_t _GetOrAllocRuntimeTypeIndex() { \
static uint32_t tidx = GetOrAllocRuntimeTypeIndex( \ static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex( \
TypeName::_type_key, \ TypeName::_type_key, \
TypeName::_type_index, \ TypeName::_type_index, \
ParentType::_GetOrAllocRuntimeTypeIndex(), \ ParentType::_GetOrAllocRuntimeTypeIndex(), \
...@@ -668,6 +677,19 @@ struct ObjectEqual { ...@@ -668,6 +677,19 @@ struct ObjectEqual {
TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
/*! \brief helper macro to supress unused warning */
#if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define TVM_ATTRIBUTE_UNUSED
#endif
#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
#define TVM_OBJECT_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid
/*! /*!
* \brief Helper macro to register the object type to runtime. * \brief Helper macro to register the object type to runtime.
* Makes sure that the runtime type table is correctly populated. * Makes sure that the runtime type table is correctly populated.
...@@ -675,7 +697,7 @@ struct ObjectEqual { ...@@ -675,7 +697,7 @@ struct ObjectEqual {
* Use this macro in the cc file for each terminal class. * Use this macro in the cc file for each terminal class.
*/ */
#define TVM_REGISTER_OBJECT_TYPE(TypeName) \ #define TVM_REGISTER_OBJECT_TYPE(TypeName) \
static DMLC_ATTRIBUTE_UNUSED uint32_t __make_Object_tidx ## _ ## TypeName ## __ = \ TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = \
TypeName::_GetOrAllocRuntimeTypeIndex() TypeName::_GetOrAllocRuntimeTypeIndex()
...@@ -691,14 +713,14 @@ struct ObjectEqual { ...@@ -691,14 +713,14 @@ struct ObjectEqual {
using ContainerType = ObjectName; using ContainerType = ObjectName;
#define TVM_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \ #define TVM_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \
TypeName() {} \ TypeName() {} \
explicit TypeName( \ explicit TypeName( \
::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
: ParentType(n) {} \ : ParentType(n) {} \
ObjectName* operator->() { \ ObjectName* operator->() { \
return static_cast<ObjectName*>(data_.get()); \ return static_cast<ObjectName*>(data_.get()); \
} \ } \
operator bool() const { return data_ != nullptr; } \ operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName; using ContainerType = ObjectName;
// Implementations details below // Implementations details below
......
...@@ -387,20 +387,22 @@ inline std::string TVMType2String(TVMType t); ...@@ -387,20 +387,22 @@ inline std::string TVMType2String(TVMType t);
#define TVM_CHECK_TYPE_CODE(CODE, T) \ #define TVM_CHECK_TYPE_CODE(CODE, T) \
CHECK_EQ(CODE, T) << " expected " \ CHECK_EQ(CODE, T) << " expected " \
<< TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \ << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \
/*! /*!
* \brief Type traits to mark if a class is tvm extension type. * \brief Type traits for runtime type check during FFI conversion.
* * \tparam T the type to be checked.
* To enable extension type in C++ must be registered via marco.
* TVM_REGISTER_EXT_TYPE(TypeName) after defining this with this traits.
*
* Extension class can be passed and returned via PackedFunc in all tvm runtime.
* Internally extension class is stored as T*.
*
* \tparam T the typename
*/ */
template<typename T> template<typename T>
struct extension_type_info { struct ObjectTypeChecker {
static const int code = 0; static bool Check(const Object* ptr) {
using ContainerType = typename T::ContainerType;
if (ptr == nullptr) return true;
return ptr->IsInstance<ContainerType>();
}
static std::string TypeName() {
using ContainerType = typename T::ContainerType;
return ContainerType::_type_key;
}
}; };
/*! /*!
...@@ -449,24 +451,17 @@ class TVMPODValue_ { ...@@ -449,24 +451,17 @@ class TVMPODValue_ {
return static_cast<DLTensor*>(value_.v_handle); return static_cast<DLTensor*>(value_.v_handle);
} else { } else {
if (type_code_ == kNull) return nullptr; if (type_code_ == kNull) return nullptr;
LOG(FATAL) << "Expected " LOG(FATAL) << "Expect "
<< "DLTensor* or NDArray but get " << "DLTensor* or NDArray but get "
<< TypeCode2Str(type_code_); << TypeCode2Str(type_code_);
return nullptr; return nullptr;
} }
} }
operator NDArray() const { operator NDArray() const {
if (type_code_ == kNull) return NDArray(); if (type_code_ == kNull) return NDArray(ObjectPtr<Object>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer); TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer);
return NDArray(static_cast<NDArray::Container*>(value_.v_handle)); return NDArray(NDArray::FFIDataFromHandle(
} static_cast<TVMArrayHandle>(value_.v_handle)));
operator ObjectRef() const {
if (type_code_ == kNull) {
return ObjectRef(ObjectPtr<Object>(nullptr));
}
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
return ObjectRef(
ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
} }
operator Module() const { operator Module() const {
if (type_code_ == kNull) { if (type_code_ == kNull) {
...@@ -480,23 +475,9 @@ class TVMPODValue_ { ...@@ -480,23 +475,9 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
return value_.v_ctx; return value_.v_ctx;
} }
template<typename TNDArray,
typename = typename std::enable_if<
std::is_base_of<NDArray, TNDArray>::value>::type>
TNDArray AsNDArray() const {
if (type_code_ == kNull) return TNDArray(nullptr);
auto *container = static_cast<NDArray::Container*>(value_.v_handle);
CHECK_EQ(container->array_type_code_, array_type_info<TNDArray>::code);
return TNDArray(container);
}
template<typename TObjectRef,
typename = typename std::enable_if<
std::is_class<TObjectRef>::value>::type>
inline bool IsObjectRef() const;
int type_code() const { int type_code() const {
return type_code_; return type_code_;
} }
/*! /*!
* \brief return handle as specific pointer type. * \brief return handle as specific pointer type.
* \tparam T the data type. * \tparam T the data type.
...@@ -506,6 +487,16 @@ class TVMPODValue_ { ...@@ -506,6 +487,16 @@ class TVMPODValue_ {
T* ptr() const { T* ptr() const {
return static_cast<T*>(value_.v_handle); return static_cast<T*>(value_.v_handle);
} }
// ObjectRef handling
template<typename TObjectRef,
typename = typename std::enable_if<
std::is_base_of<ObjectRef, TObjectRef>::value>::type>
inline bool IsObjectRef() const;
template<typename TObjectRef>
inline TObjectRef AsObjectRef() const;
// ObjectRef Specializations
inline operator tvm::Expr() const;
inline operator tvm::Integer() const;
protected: protected:
friend class TVMArgsSetter; friend class TVMArgsSetter;
...@@ -548,9 +539,11 @@ class TVMArgValue : public TVMPODValue_ { ...@@ -548,9 +539,11 @@ class TVMArgValue : public TVMPODValue_ {
using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator NDArray; using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator ObjectRef;
using TVMPODValue_::operator Module; using TVMPODValue_::operator Module;
using TVMPODValue_::IsObjectRef; using TVMPODValue_::IsObjectRef;
using TVMPODValue_::AsObjectRef;
using TVMPODValue_::operator tvm::Expr;
using TVMPODValue_::operator tvm::Integer;
// conversion operator. // conversion operator.
operator std::string() const { operator std::string() const {
...@@ -577,6 +570,9 @@ class TVMArgValue : public TVMPODValue_ { ...@@ -577,6 +570,9 @@ class TVMArgValue : public TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kTVMType); TVM_CHECK_TYPE_CODE(type_code_, kTVMType);
return value_.v_type; return value_.v_type;
} }
operator DataType() const {
return DataType(operator DLDataType());
}
operator PackedFunc() const { operator PackedFunc() const {
if (type_code_ == kNull) return PackedFunc(); if (type_code_ == kNull) return PackedFunc();
TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle); TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle);
...@@ -589,16 +585,10 @@ class TVMArgValue : public TVMPODValue_ { ...@@ -589,16 +585,10 @@ class TVMArgValue : public TVMPODValue_ {
const TVMValue& value() const { const TVMValue& value() const {
return value_; return value_;
} }
// Deferred extension handler.
template<typename TObjectRef>
inline TObjectRef AsObjectRef() const;
template<typename T, template<typename T,
typename = typename std::enable_if< typename = typename std::enable_if<
std::is_class<T>::value>::type> std::is_class<T>::value>::type>
inline operator T() const; inline operator T() const;
inline operator DataType() const;
inline operator tvm::Expr() const;
inline operator tvm::Integer() const;
}; };
/*! /*!
...@@ -636,9 +626,11 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -636,9 +626,11 @@ class TVMRetValue : public TVMPODValue_ {
using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator NDArray; using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator ObjectRef;
using TVMPODValue_::operator Module; using TVMPODValue_::operator Module;
using TVMPODValue_::IsObjectRef; using TVMPODValue_::IsObjectRef;
using TVMPODValue_::AsObjectRef;
using TVMPODValue_::operator tvm::Expr;
using TVMPODValue_::operator tvm::Integer;
TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { TVMRetValue(const TVMRetValue& other) : TVMPODValue_() {
this->Assign(other); this->Assign(other);
...@@ -660,6 +652,9 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -660,6 +652,9 @@ class TVMRetValue : public TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kTVMType); TVM_CHECK_TYPE_CODE(type_code_, kTVMType);
return value_.v_type; return value_.v_type;
} }
operator DataType() const {
return DataType(operator DLDataType());
}
operator PackedFunc() const { operator PackedFunc() const {
if (type_code_ == kNull) return PackedFunc(); if (type_code_ == kNull) return PackedFunc();
TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle); TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle);
...@@ -712,6 +707,9 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -712,6 +707,9 @@ class TVMRetValue : public TVMPODValue_ {
value_.v_type = t; value_.v_type = t;
return *this; return *this;
} }
TVMRetValue& operator=(const DataType& other) {
return operator=(other.operator DLDataType());
}
TVMRetValue& operator=(bool value) { TVMRetValue& operator=(bool value) {
this->SwitchToPOD(kDLInt); this->SwitchToPOD(kDLInt);
value_.v_int64 = value; value_.v_int64 = value;
...@@ -726,24 +724,20 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -726,24 +724,20 @@ class TVMRetValue : public TVMPODValue_ {
return *this; return *this;
} }
TVMRetValue& operator=(NDArray other) { TVMRetValue& operator=(NDArray other) {
this->Clear(); if (other.data_ != nullptr) {
type_code_ = kNDArrayContainer; this->Clear();
value_.v_handle = other.data_; type_code_ = kNDArrayContainer;
other.data_ = nullptr; value_.v_handle = NDArray::FFIGetHandle(other);
ObjectRef::FFIClearAfterMove(&other);
} else {
SwitchToPOD(kNull);
}
return *this; return *this;
} }
TVMRetValue& operator=(ObjectRef other) {
return operator=(std::move(other.data_));
}
TVMRetValue& operator=(Module m) { TVMRetValue& operator=(Module m) {
SwitchToObject(kModuleHandle, std::move(m.data_)); SwitchToObject(kModuleHandle, std::move(m.data_));
return *this; return *this;
} }
template<typename T>
TVMRetValue& operator=(ObjectPtr<T> other) {
SwitchToObject(kObjectHandle, std::move(other));
return *this;
}
TVMRetValue& operator=(PackedFunc f) { TVMRetValue& operator=(PackedFunc f) {
this->SwitchToClass(kFuncHandle, f); this->SwitchToClass(kFuncHandle, f);
return *this; return *this;
...@@ -760,14 +754,6 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -760,14 +754,6 @@ class TVMRetValue : public TVMPODValue_ {
this->Assign(other); this->Assign(other);
return *this; return *this;
} }
template<typename T,
typename = typename std::enable_if<
extension_type_info<T>::code != 0>::type>
TVMRetValue& operator=(const T& other) {
this->SwitchToClass<T>(
extension_type_info<T>::code, other);
return *this;
}
/*! /*!
* \brief Move the value back to front-end via C API. * \brief Move the value back to front-end via C API.
* This marks the current container as null. * This marks the current container as null.
...@@ -793,16 +779,15 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -793,16 +779,15 @@ class TVMRetValue : public TVMPODValue_ {
type_code_ != kStr) << "TVMRetValue.value can only be used for POD data"; type_code_ != kStr) << "TVMRetValue.value can only be used for POD data";
return value_; return value_;
} }
// ObjectRef related extenstions: in tvm/packed_func_ext.h // ObjectRef handling
template<typename TObjectRef,
typename = typename std::enable_if<
std::is_base_of<ObjectRef, TObjectRef>::value>::type>
inline TVMRetValue& operator=(TObjectRef other);
template<typename T, template<typename T,
typename = typename std::enable_if< typename = typename std::enable_if<
std::is_class<T>::value>::type> std::is_class<T>::value>::type>
inline operator T() const; inline operator T() const;
template<typename TObjectRef>
inline TObjectRef AsObjectRef() const;
// type related
inline operator DataType() const;
inline TVMRetValue& operator=(const DataType& other);
private: private:
template<typename T> template<typename T>
...@@ -829,7 +814,10 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -829,7 +814,10 @@ class TVMRetValue : public TVMPODValue_ {
break; break;
} }
case kObjectHandle: { case kObjectHandle: {
*this = other.operator ObjectRef(); // Avoid operator ObjectRef as we already know it is not NDArray/Module
SwitchToObject(
kObjectHandle, GetObjectPtr<Object>(
static_cast<Object*>(other.value_.v_handle)));
break; break;
} }
default: { default: {
...@@ -873,7 +861,7 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -873,7 +861,7 @@ class TVMRetValue : public TVMPODValue_ {
case kStr: delete ptr<std::string>(); break; case kStr: delete ptr<std::string>(); break;
case kFuncHandle: delete ptr<PackedFunc>(); break; case kFuncHandle: delete ptr<PackedFunc>(); break;
case kNDArrayContainer: { case kNDArrayContainer: {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef(); NDArray::FFIDecRef(static_cast<TVMArrayHandle>(value_.v_handle));
break; break;
} }
case kModuleHandle: { case kModuleHandle: {
...@@ -905,7 +893,7 @@ inline const char* TypeCode2Str(int type_code) { ...@@ -905,7 +893,7 @@ inline const char* TypeCode2Str(int type_code) {
case kFuncHandle: return "FunctionHandle"; case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle"; case kModuleHandle: return "ModuleHandle";
case kNDArrayContainer: return "NDArrayContainer"; case kNDArrayContainer: return "NDArrayContainer";
case kObjectHandle: return "ObjectCell"; case kObjectHandle: return "Object";
default: LOG(FATAL) << "unknown type_code=" default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return ""; << static_cast<int>(type_code); return "";
} }
...@@ -929,6 +917,10 @@ inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*) ...@@ -929,6 +917,10 @@ inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*)
return os; return os;
} }
inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
return os << dtype.operator DLDataType();
}
#endif #endif
inline std::string TVMType2String(TVMType t) { inline std::string TVMType2String(TVMType t) {
...@@ -996,10 +988,6 @@ inline TVMType String2TVMType(std::string s) { ...@@ -996,10 +988,6 @@ inline TVMType String2TVMType(std::string s) {
return t; return t;
} }
inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
return os << dtype.operator DLDataType();
}
inline TVMArgValue TVMArgs::operator[](int i) const { inline TVMArgValue TVMArgs::operator[](int i) const {
CHECK_LT(i, num_args) CHECK_LT(i, num_args)
<< "not enough argument passed, " << "not enough argument passed, "
...@@ -1092,50 +1080,31 @@ class TVMArgsSetter { ...@@ -1092,50 +1080,31 @@ class TVMArgsSetter {
values_[i].v_type = value; values_[i].v_type = value;
type_codes_[i] = kTVMType; type_codes_[i] = kTVMType;
} }
void operator()(size_t i, DataType dtype) const {
operator()(i, dtype.operator DLDataType());
}
void operator()(size_t i, const char* value) const { void operator()(size_t i, const char* value) const {
values_[i].v_str = value; values_[i].v_str = value;
type_codes_[i] = kStr; type_codes_[i] = kStr;
} }
// setters for container type // setters for container types
// They must be reference(instead of const ref) void operator()(size_t i, const std::string& value) const {
// to make sure they are alive in the tuple(instead of getting converted)
void operator()(size_t i, const std::string& value) const { // NOLINT(*)
values_[i].v_str = value.c_str(); values_[i].v_str = value.c_str();
type_codes_[i] = kStr; type_codes_[i] = kStr;
} }
void operator()(size_t i, const TVMByteArray& value) const { // NOLINT(*) void operator()(size_t i, const TVMByteArray& value) const {
values_[i].v_handle = const_cast<TVMByteArray*>(&value); values_[i].v_handle = const_cast<TVMByteArray*>(&value);
type_codes_[i] = kBytes; type_codes_[i] = kBytes;
} }
void operator()(size_t i, const PackedFunc& value) const { // NOLINT(*) void operator()(size_t i, const PackedFunc& value) const {
values_[i].v_handle = const_cast<PackedFunc*>(&value); values_[i].v_handle = const_cast<PackedFunc*>(&value);
type_codes_[i] = kFuncHandle; type_codes_[i] = kFuncHandle;
} }
template<typename FType> template<typename FType>
void operator()(size_t i, const TypedPackedFunc<FType>& value) const { // NOLINT(*) void operator()(size_t i, const TypedPackedFunc<FType>& value) const {
operator()(i, value.packed()); operator()(i, value.packed());
} }
void operator()(size_t i, const Module& value) const { // NOLINT(*) void operator()(size_t i, const TVMRetValue& value) const {
if (value.defined()) {
values_[i].v_handle = value.data_.data_;
type_codes_[i] = kModuleHandle;
} else {
type_codes_[i] = kNull;
}
}
void operator()(size_t i, const NDArray& value) const { // NOLINT(*)
values_[i].v_handle = value.data_;
type_codes_[i] = kNDArrayContainer;
}
void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*)
if (value.defined()) {
values_[i].v_handle = value.data_.data_;
type_codes_[i] = kObjectHandle;
} else {
type_codes_[i] = kNull;
}
}
void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*)
if (value.type_code() == kStr) { if (value.type_code() == kStr) {
values_[i].v_str = value.ptr<std::string>()->c_str(); values_[i].v_str = value.ptr<std::string>()->c_str();
type_codes_[i] = kStr; type_codes_[i] = kStr;
...@@ -1145,12 +1114,11 @@ class TVMArgsSetter { ...@@ -1145,12 +1114,11 @@ class TVMArgsSetter {
type_codes_[i] = value.type_code(); type_codes_[i] = value.type_code();
} }
} }
// extension // ObjectRef handling
template<typename T, template<typename TObjectRef,
typename = typename std::enable_if< typename = typename std::enable_if<
extension_type_info<T>::code != 0>::type> std::is_base_of<ObjectRef, TObjectRef>::value>::type>
inline void operator()(size_t i, const T& value) const; inline void operator()(size_t i, const TObjectRef& value) const;
inline void operator()(size_t i, const DataType& t) const;
private: private:
/*! \brief The values fields */ /*! \brief The values fields */
...@@ -1262,57 +1230,131 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const { ...@@ -1262,57 +1230,131 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
::run(packed_, std::forward<Args>(args)...); ::run(packed_, std::forward<Args>(args)...);
} }
// extension and node type handling // ObjectRef related conversion handling
namespace detail { // Object can have three possible type codes:
template<typename T, typename TSrc, bool is_nd> // kNDArrayContainer, kModuleHandle, kObjectHandle
struct TVMValueCast { //
static T Apply(const TSrc* self) { // We use type traits to eliminate un-necessary checks.
static_assert(!is_nd, "The default case accepts only non-extensions"); template<typename TObjectRef, typename>
return self->template AsObjectRef<T>(); inline void TVMArgsSetter::operator()(size_t i, const TObjectRef& value) const {
} if (value.defined()) {
}; Object* ptr = value.data_.data_;
if (std::is_base_of<NDArray, TObjectRef>::value ||
template<typename T, typename TSrc> (std::is_base_of<TObjectRef, NDArray>::value &&
struct TVMValueCast<T, TSrc, true> { ptr->IsInstance<NDArray::ContainerType>())) {
static T Apply(const TSrc* self) { values_[i].v_handle = NDArray::FFIGetHandle(value);
return self->template AsNDArray<T>(); type_codes_[i] = kNDArrayContainer;
} else if (std::is_base_of<Module, TObjectRef>::value ||
(std::is_base_of<TObjectRef, Module>::value &&
ptr->IsInstance<Module::ContainerType>())) {
values_[i].v_handle = ptr;
type_codes_[i] = kModuleHandle;
} else {
values_[i].v_handle = ptr;
type_codes_[i] = kObjectHandle;
}
} else {
type_codes_[i] = kNull;
} }
};
} // namespace detail
template<typename T, typename>
inline TVMArgValue::operator T() const {
return detail::
TVMValueCast<T, TVMArgValue,
(array_type_info<T>::code > 0)>
::Apply(this);
} }
template<typename T, typename> template<typename TObjectRef, typename>
inline TVMRetValue::operator T() const { inline bool TVMPODValue_::IsObjectRef() const {
return detail:: using ContainerType = typename TObjectRef::ContainerType;
TVMValueCast<T, TVMRetValue, // NOTE: the following code can be optimized by constant folding.
(array_type_info<T>::code > 0)> if (std::is_base_of<NDArray, TObjectRef>::value) {
::Apply(this); return type_code_ == kNDArrayContainer &&
TVMArrayHandleToObjectHandle(
static_cast<TVMArrayHandle>(value_.v_handle))->IsInstance<ContainerType>();
}
if (std::is_base_of<Module, TObjectRef>::value) {
return type_code_ == kModuleHandle &&
static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
}
return
(std::is_base_of<TObjectRef, NDArray>::value && type_code_ == kNDArrayContainer) ||
(std::is_base_of<TObjectRef, Module>::value && type_code_ == kModuleHandle) ||
(type_code_ == kObjectHandle &&
ObjectTypeChecker<TObjectRef>::Check(static_cast<Object*>(value_.v_handle)));
} }
// PackedFunc support template<typename TObjectRef>
inline TVMRetValue& TVMRetValue::operator=(const DataType& t) { inline TObjectRef TVMPODValue_::AsObjectRef() const {
return this->operator=(t.operator DLDataType()); static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
using ContainerType = typename TObjectRef::ContainerType;
if (type_code_ == kNull) return TObjectRef(ObjectPtr<Object>(nullptr));
// NOTE: the following code can be optimized by constant folding.
if (std::is_base_of<NDArray, TObjectRef>::value) {
// Casting to a sub-class of NDArray
TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer);
ObjectPtr<Object> data = NDArray::FFIDataFromHandle(
static_cast<TVMArrayHandle>(value_.v_handle));
CHECK(data->IsInstance<ContainerType>())
<< "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey();
return TObjectRef(data);
}
if (std::is_base_of<Module, TObjectRef>::value) {
// Casting to a sub-class of Module
TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle));
CHECK(data->IsInstance<ContainerType>())
<< "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey();
return TObjectRef(data);
}
if (type_code_ == kObjectHandle) {
// normal object type check.
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expect " << ObjectTypeChecker<TObjectRef>::TypeName()
<< " but get " << ptr->GetTypeKey();
return TObjectRef(GetObjectPtr<Object>(ptr));
} else if (std::is_base_of<TObjectRef, NDArray>::value &&
type_code_ == kNDArrayContainer) {
// Casting to a base class that NDArray can sub-class
ObjectPtr<Object> data = NDArray::FFIDataFromHandle(
static_cast<TVMArrayHandle>(value_.v_handle));
return TObjectRef(data);
} else if (std::is_base_of<TObjectRef, Module>::value &&
type_code_ == kModuleHandle) {
// Casting to a base class that Module can sub-class
return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
} else {
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
return TObjectRef(ObjectPtr<Object>(nullptr));
}
} }
inline TVMRetValue::operator DataType() const { template<typename TObjectRef, typename>
return DataType(operator DLDataType()); inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) {
const Object* ptr = other.get();
if (ptr != nullptr) {
if (std::is_base_of<NDArray, TObjectRef>::value ||
(std::is_base_of<TObjectRef, NDArray>::value &&
ptr->IsInstance<NDArray::ContainerType>())) {
return operator=(NDArray(std::move(other.data_)));
}
if (std::is_base_of<Module, TObjectRef>::value ||
(std::is_base_of<TObjectRef, Module>::value &&
ptr->IsInstance<Module::ContainerType>())) {
return operator=(Module(std::move(other.data_)));
}
SwitchToObject(kObjectHandle, std::move(other.data_));
} else {
SwitchToPOD(kNull);
}
return *this;
} }
inline TVMArgValue::operator DataType() const { template<typename T, typename>
return DataType(operator DLDataType()); inline TVMArgValue::operator T() const {
return AsObjectRef<T>();
} }
inline void TVMArgsSetter::operator()( template<typename T, typename>
size_t i, const DataType& t) const { inline TVMRetValue::operator T() const {
this->operator()(i, t.operator DLDataType()); return AsObjectRef<T>();
} }
inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) { inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) {
......
...@@ -43,9 +43,9 @@ ...@@ -43,9 +43,9 @@
#ifndef TVM_RUNTIME_REGISTRY_H_ #ifndef TVM_RUNTIME_REGISTRY_H_
#define TVM_RUNTIME_REGISTRY_H_ #define TVM_RUNTIME_REGISTRY_H_
#include <tvm/runtime/packed_func.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include "packed_func.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -283,22 +283,9 @@ class Registry { ...@@ -283,22 +283,9 @@ class Registry {
friend struct Manager; friend struct Manager;
}; };
/*! \brief helper macro to supress unused warning */
#if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define TVM_ATTRIBUTE_UNUSED
#endif
#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
#define TVM_FUNC_REG_VAR_DEF \ #define TVM_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM
#define TVM_TYPE_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::ExtTypeVTable* __mk_ ## TVMT
/*! /*!
* \brief Register a function globally. * \brief Register a function globally.
* \code * \code
......
...@@ -96,6 +96,7 @@ def config_cython(): ...@@ -96,6 +96,7 @@ def config_cython():
"../3rdparty/dmlc-core/include", "../3rdparty/dmlc-core/include",
"../3rdparty/dlpack/include", "../3rdparty/dlpack/include",
], ],
extra_compile_args=["-std=c++11"],
library_dirs=library_dirs, library_dirs=library_dirs,
libraries=libraries, libraries=libraries,
language="c++")) language="c++"))
......
...@@ -20,7 +20,7 @@ from __future__ import absolute_import ...@@ -20,7 +20,7 @@ from __future__ import absolute_import
import ctypes import ctypes
from ..base import _LIB, check_call, c_str from ..base import _LIB, check_call, c_str
from ..runtime_ctypes import TVMArrayHandle, TVMNDArrayContainerHandle from ..runtime_ctypes import TVMArrayHandle
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle
...@@ -110,12 +110,17 @@ class NDArrayBase(object): ...@@ -110,12 +110,17 @@ class NDArrayBase(object):
def _make_array(handle, is_view, is_container): def _make_array(handle, is_view, is_container):
global _TVM_ND_CLS global _TVM_ND_CLS
handle = ctypes.cast(handle, TVMArrayHandle) handle = ctypes.cast(handle, TVMArrayHandle)
fcreate = _CLASS_NDARRAY if is_container:
if is_container and _TVM_ND_CLS: tindex = ctypes.c_uint()
array_type_info = ctypes.cast(handle, TVMNDArrayContainerHandle).array_type_info.value check_call(_LIB.TVMArrayGetTypeIndex(handle, ctypes.byref(tindex)))
if array_type_info > 0: cls = _TVM_ND_CLS.get(tindex.value, _CLASS_NDARRAY)
fcreate = _TVM_ND_CLS[array_type_info] else:
return fcreate(handle, is_view) cls = _CLASS_NDARRAY
ret = cls.__new__(cls)
ret.handle = handle
ret.is_view = is_view
return ret
_TVM_COMPATS = () _TVM_COMPATS = ()
...@@ -129,9 +134,9 @@ def _reg_extension(cls, fcreate): ...@@ -129,9 +134,9 @@ def _reg_extension(cls, fcreate):
_TVM_ND_CLS = {} _TVM_ND_CLS = {}
def _reg_ndarray(cls, fcreate): def _register_ndarray(index, cls):
global _TVM_ND_CLS global _TVM_ND_CLS
_TVM_ND_CLS[cls._array_type_code] = fcreate _TVM_ND_CLS[index] = cls
_CLASS_NDARRAY = None _CLASS_NDARRAY = None
......
...@@ -21,7 +21,7 @@ from __future__ import absolute_import ...@@ -21,7 +21,7 @@ from __future__ import absolute_import
import ctypes import ctypes
from ..base import _LIB, check_call from ..base import _LIB, check_call
from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from ..node_generic import _set_class_node_base from .ndarray import _register_ndarray, NDArrayBase
ObjectHandle = ctypes.c_void_p ObjectHandle = ctypes.c_void_p
...@@ -39,6 +39,9 @@ def _set_class_node(node_class): ...@@ -39,6 +39,9 @@ def _set_class_node(node_class):
def _register_object(index, cls): def _register_object(index, cls):
"""register object class""" """register object class"""
if issubclass(cls, NDArrayBase):
_register_ndarray(index, cls)
return
OBJECT_TYPE[index] = cls OBJECT_TYPE[index] = cls
...@@ -91,6 +94,3 @@ class ObjectBase(object): ...@@ -91,6 +94,3 @@ class ObjectBase(object):
if not isinstance(handle, ObjectHandle): if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle) handle = ObjectHandle(handle)
self.handle = handle self.handle = handle
_set_class_node_base(ObjectBase)
...@@ -19,7 +19,7 @@ from ..base import get_last_ffi_error ...@@ -19,7 +19,7 @@ from ..base import get_last_ffi_error
from libcpp.vector cimport vector from libcpp.vector cimport vector
from cpython.version cimport PY_MAJOR_VERSION from cpython.version cimport PY_MAJOR_VERSION
from cpython cimport pycapsule from cpython cimport pycapsule
from libc.stdint cimport int32_t, int64_t, uint64_t, uint8_t, uint16_t from libc.stdint cimport int32_t, int64_t, uint64_t, uint32_t, uint8_t, uint16_t
import ctypes import ctypes
cdef enum TVMTypeCode: cdef enum TVMTypeCode:
...@@ -78,14 +78,11 @@ ctypedef void* TVMRetValueHandle ...@@ -78,14 +78,11 @@ ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle ctypedef void* TVMFunctionHandle
ctypedef void* ObjectHandle ctypedef void* ObjectHandle
ctypedef struct TVMObject:
uint32_t type_index_
int32_t ref_counter_
void (*deleter_)(TVMObject* self)
ctypedef struct TVMNDArrayContainer:
DLTensor dl_tensor
void* manager_ctx
void (*deleter)(DLManagedTensor* self)
int32_t array_type_info
ctypedef TVMNDArrayContainer* TVMNDArrayContainerHandle
ctypedef int (*TVMPackedCFunc)( ctypedef int (*TVMPackedCFunc)(
TVMValue* args, TVMValue* args,
......
...@@ -100,17 +100,34 @@ cdef class NDArrayBase: ...@@ -100,17 +100,34 @@ cdef class NDArrayBase:
return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter) return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter)
# Import limited object-related function from C++ side to improve the speed
# NOTE: can only use POD-C compatible object in FFI.
cdef extern from "tvm/runtime/ndarray.h" namespace "tvm::runtime":
cdef void* TVMArrayHandleToObjectHandle(DLTensorHandle handle)
cdef c_make_array(void* chandle, is_view, is_container): cdef c_make_array(void* chandle, is_view, is_container):
global _TVM_ND_CLS global _TVM_ND_CLS
cdef int32_t array_type_info
fcreate = _CLASS_NDARRAY if is_container:
if is_container and len(_TVM_ND_CLS) > 0: tindex = (
array_type_info = (<TVMNDArrayContainerHandle>chandle).array_type_info <TVMObject*>TVMArrayHandleToObjectHandle(<DLTensorHandle>chandle)).type_index_
if array_type_info > 0: if tindex < len(_TVM_ND_CLS):
fcreate = _TVM_ND_CLS[array_type_info] cls = _TVM_ND_CLS[tindex]
ret = fcreate(None, is_view) if cls is not None:
(<NDArrayBase>ret).chandle = <DLTensor*>chandle ret = cls.__new__(cls)
return ret else:
ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY)
else:
ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
(<NDArrayBase>ret).c_is_view = <int>is_view
return ret
else:
ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
(<NDArrayBase>ret).c_is_view = <int>is_view
return ret
cdef _TVM_COMPATS = () cdef _TVM_COMPATS = ()
...@@ -123,11 +140,16 @@ def _reg_extension(cls, fcreate): ...@@ -123,11 +140,16 @@ def _reg_extension(cls, fcreate):
if fcreate: if fcreate:
_TVM_EXT_RET[cls._tvm_tcode] = fcreate _TVM_EXT_RET[cls._tvm_tcode] = fcreate
cdef _TVM_ND_CLS = {} cdef list _TVM_ND_CLS = []
def _reg_ndarray(cls, fcreate): cdef _register_ndarray(int index, object cls):
"""register object class"""
global _TVM_ND_CLS global _TVM_ND_CLS
_TVM_ND_CLS[cls._array_type_code] = fcreate while len(_TVM_ND_CLS) <= index:
_TVM_ND_CLS.append(None)
_TVM_ND_CLS[index] = cls
def _make_array(handle, is_view, is_container): def _make_array(handle, is_view, is_container):
cdef unsigned long long ptr cdef unsigned long long ptr
......
...@@ -16,12 +16,15 @@ ...@@ -16,12 +16,15 @@
# under the License. # under the License.
"""Maps object type to its constructor""" """Maps object type to its constructor"""
from ..node_generic import _set_class_node_base cdef list OBJECT_TYPE = []
OBJECT_TYPE = []
def _register_object(int index, object cls): def _register_object(int index, object cls):
"""register object class""" """register object class"""
if issubclass(cls, NDArrayBase):
_register_ndarray(index, cls)
return
global OBJECT_TYPE
while len(OBJECT_TYPE) <= index: while len(OBJECT_TYPE) <= index:
OBJECT_TYPE.append(None) OBJECT_TYPE.append(None)
OBJECT_TYPE[index] = cls OBJECT_TYPE[index] = cls
...@@ -31,14 +34,13 @@ cdef inline object make_ret_object(void* chandle): ...@@ -31,14 +34,13 @@ cdef inline object make_ret_object(void* chandle):
global OBJECT_TYPE global OBJECT_TYPE
global _CLASS_NODE global _CLASS_NODE
cdef unsigned tindex cdef unsigned tindex
cdef list object_type
cdef object cls cdef object cls
cdef object handle cdef object handle
object_type = OBJECT_TYPE object_type = OBJECT_TYPE
handle = ctypes_handle(chandle) handle = ctypes_handle(chandle)
CALL(TVMObjectGetTypeIndex(chandle, &tindex)) CALL(TVMObjectGetTypeIndex(chandle, &tindex))
if tindex < len(object_type): if tindex < len(OBJECT_TYPE):
cls = object_type[tindex] cls = OBJECT_TYPE[tindex]
if cls is not None: if cls is not None:
obj = cls.__new__(cls) obj = cls.__new__(cls)
else: else:
...@@ -99,6 +101,3 @@ cdef class ObjectBase: ...@@ -99,6 +101,3 @@ cdef class ObjectBase:
(<FunctionBase>fconstructor).chandle, (<FunctionBase>fconstructor).chandle,
kObjectHandle, args, &chandle) kObjectHandle, args, &chandle)
self.chandle = chandle self.chandle = chandle
_set_class_node_base(ObjectBase)
...@@ -22,6 +22,7 @@ from __future__ import absolute_import ...@@ -22,6 +22,7 @@ from __future__ import absolute_import
import sys import sys
import ctypes import ctypes
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
from .node_generic import _set_class_objects
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
...@@ -32,15 +33,21 @@ try: ...@@ -32,15 +33,21 @@ try:
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_function, _set_class_module from ._cy3.core import _set_class_function, _set_class_module
from ._cy3.core import FunctionBase as _FunctionBase from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import NDArrayBase as _NDArrayBase
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import convert_to_tvm_func from ._cy3.core import convert_to_tvm_func
else: else:
from ._cy2.core import _set_class_function, _set_class_module from ._cy2.core import _set_class_function, _set_class_module
from ._cy2.core import FunctionBase as _FunctionBase from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import NDArrayBase as _NDArrayBase
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import convert_to_tvm_func from ._cy2.core import convert_to_tvm_func
except IMPORT_EXCEPT: except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_function, _set_class_module from ._ctypes.function import _set_class_function, _set_class_module
from ._ctypes.function import FunctionBase as _FunctionBase from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
from ._ctypes.object import ObjectBase as _ObjectBase
from ._ctypes.function import convert_to_tvm_func from ._ctypes.function import convert_to_tvm_func
FunctionHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p
...@@ -325,3 +332,4 @@ def _init_api_prefix(module_name, prefix): ...@@ -325,3 +332,4 @@ def _init_api_prefix(module_name, prefix):
setattr(target_module, ff.__name__, ff) setattr(target_module, ff.__name__, ff)
_set_class_function(Function) _set_class_function(Function)
_set_class_objects((_ObjectBase, _NDArrayBase, ModuleBase))
...@@ -35,16 +35,16 @@ try: ...@@ -35,16 +35,16 @@ try:
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack
from ._cy3.core import NDArrayBase as _NDArrayBase from ._cy3.core import NDArrayBase as _NDArrayBase
from ._cy3.core import _reg_extension, _reg_ndarray from ._cy3.core import _reg_extension
else: else:
from ._cy2.core import _set_class_ndarray, _make_array, _from_dlpack from ._cy2.core import _set_class_ndarray, _make_array, _from_dlpack
from ._cy2.core import NDArrayBase as _NDArrayBase from ._cy2.core import NDArrayBase as _NDArrayBase
from ._cy2.core import _reg_extension, _reg_ndarray from ._cy2.core import _reg_extension
except IMPORT_EXCEPT: except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from ._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack from ._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
from ._ctypes.ndarray import _reg_extension, _reg_ndarray from ._ctypes.ndarray import _reg_extension
def context(dev_type, dev_id=0): def context(dev_type, dev_id=0):
...@@ -348,13 +348,8 @@ def register_extension(cls, fcreate=None): ...@@ -348,13 +348,8 @@ def register_extension(cls, fcreate=None):
def _tvm_handle(self): def _tvm_handle(self):
return self.handle.value return self.handle.value
""" """
if issubclass(cls, _NDArrayBase): assert hasattr(cls, "_tvm_tcode")
assert fcreate is not None if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
assert hasattr(cls, "_array_type_code") raise ValueError("Cannot register create when extension tcode is same as buildin")
_reg_ndarray(cls, fcreate) _reg_extension(cls, fcreate)
else:
assert hasattr(cls, "_tvm_tcode")
if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin")
_reg_extension(cls, fcreate)
return cls return cls
...@@ -23,11 +23,11 @@ from .. import _api_internal ...@@ -23,11 +23,11 @@ from .. import _api_internal
from .base import string_types from .base import string_types
# Node base class # Node base class
_CLASS_NODE_BASE = None _CLASS_OBJECTS = None
def _set_class_node_base(cls): def _set_class_objects(cls):
global _CLASS_NODE_BASE global _CLASS_OBJECTS
_CLASS_NODE_BASE = cls _CLASS_OBJECTS = cls
def _scalar_type_inference(value): def _scalar_type_inference(value):
...@@ -67,7 +67,7 @@ def convert_to_node(value): ...@@ -67,7 +67,7 @@ def convert_to_node(value):
node : Node node : Node
The corresponding node value. The corresponding node value.
""" """
if isinstance(value, _CLASS_NODE_BASE): if isinstance(value, _CLASS_OBJECTS):
return value return value
if isinstance(value, bool): if isinstance(value, bool):
return const(value, 'uint1x1') return const(value, 'uint1x1')
...@@ -81,7 +81,7 @@ def convert_to_node(value): ...@@ -81,7 +81,7 @@ def convert_to_node(value):
if isinstance(value, dict): if isinstance(value, dict):
vlist = [] vlist = []
for item in value.items(): for item in value.items():
if (not isinstance(item[0], _CLASS_NODE_BASE) and if (not isinstance(item[0], _CLASS_OBJECTS) and
not isinstance(item[0], string_types)): not isinstance(item[0], string_types)):
raise ValueError("key of map must already been a container type") raise ValueError("key of map must already been a container type")
vlist.append(item[0]) vlist.append(item[0])
......
...@@ -271,12 +271,3 @@ class TVMArray(ctypes.Structure): ...@@ -271,12 +271,3 @@ class TVMArray(ctypes.Structure):
("byte_offset", ctypes.c_uint64)] ("byte_offset", ctypes.c_uint64)]
TVMArrayHandle = ctypes.POINTER(TVMArray) TVMArrayHandle = ctypes.POINTER(TVMArray)
class TVMNDArrayContainer(ctypes.Structure):
"""TVM NDArray::Container"""
_fields_ = [("dl_tensor", TVMArray),
("manager_ctx", ctypes.c_void_p),
("deleter", ctypes.c_void_p),
("array_type_info", ctypes.c_int32)]
TVMNDArrayContainerHandle = ctypes.POINTER(TVMNDArrayContainer)
...@@ -27,7 +27,10 @@ from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase ...@@ -27,7 +27,10 @@ from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty, from_dlpack from ._ffi.ndarray import context, empty, from_dlpack
from ._ffi.ndarray import _set_class_ndarray from ._ffi.ndarray import _set_class_ndarray
from ._ffi.ndarray import register_extension from ._ffi.ndarray import register_extension
from ._ffi.object import register_object
@register_object
class NDArray(NDArrayBase): class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime. """Lightweight NDArray class of TVM runtime.
......
...@@ -67,7 +67,7 @@ TVM_REGISTER_API("_Array") ...@@ -67,7 +67,7 @@ TVM_REGISTER_API("_Array")
} }
auto node = make_node<ArrayNode>(); auto node = make_node<ArrayNode>();
node->data = std::move(data); node->data = std::move(data);
*ret = runtime::ObjectRef(node); *ret = Array<ObjectRef>(node);
}); });
TVM_REGISTER_API("_ArrayGetItem") TVM_REGISTER_API("_ArrayGetItem")
...@@ -100,28 +100,28 @@ TVM_REGISTER_API("_Map") ...@@ -100,28 +100,28 @@ TVM_REGISTER_API("_Map")
for (int i = 0; i < args.num_args; i += 2) { for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kStr) CHECK(args[i].type_code() == kStr)
<< "key of str map need to be str"; << "key of str map need to be str";
CHECK(args[i + 1].type_code() == kObjectHandle) CHECK(args[i + 1].IsObjectRef<ObjectRef>())
<< "value of the map to be NodeRef"; << "value of the map to be NodeRef";
data.emplace(std::make_pair(args[i].operator std::string(), data.emplace(std::make_pair(args[i].operator std::string(),
args[i + 1].operator ObjectRef())); args[i + 1].operator ObjectRef()));
} }
auto node = make_node<StrMapNode>(); auto node = make_node<StrMapNode>();
node->data = std::move(data); node->data = std::move(data);
*ret = node; *ret = Map<ObjectRef, ObjectRef>(node);
} else { } else {
// Container node. // Container node.
MapNode::ContainerType data; MapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) { for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kObjectHandle) CHECK(args[i].IsObjectRef<ObjectRef>())
<< "key of str map need to be str"; << "key of str map need to be object";
CHECK(args[i + 1].type_code() == kObjectHandle) CHECK(args[i + 1].IsObjectRef<ObjectRef>())
<< "value of map to be NodeRef"; << "value of map to be NodeRef";
data.emplace(std::make_pair(args[i].operator ObjectRef(), data.emplace(std::make_pair(args[i].operator ObjectRef(),
args[i + 1].operator ObjectRef())); args[i + 1].operator ObjectRef()));
} }
auto node = make_node<MapNode>(); auto node = make_node<MapNode>();
node->data = std::move(data); node->data = std::move(data);
*ret = node; *ret = Map<ObjectRef, ObjectRef>(node);
} }
}); });
...@@ -191,7 +191,7 @@ TVM_REGISTER_API("_MapItems") ...@@ -191,7 +191,7 @@ TVM_REGISTER_API("_MapItems")
rkvs->data.push_back(kv.first); rkvs->data.push_back(kv.first);
rkvs->data.push_back(kv.second); rkvs->data.push_back(kv.second);
} }
*ret = rkvs; *ret = Array<ObjectRef>(rkvs);
} else { } else {
auto* n = static_cast<const StrMapNode*>(ptr); auto* n = static_cast<const StrMapNode*>(ptr);
auto rkvs = make_node<ArrayNode>(); auto rkvs = make_node<ArrayNode>();
...@@ -199,7 +199,7 @@ TVM_REGISTER_API("_MapItems") ...@@ -199,7 +199,7 @@ TVM_REGISTER_API("_MapItems")
rkvs->data.push_back(ir::StringImm::make(kv.first)); rkvs->data.push_back(ir::StringImm::make(kv.first));
rkvs->data.push_back(kv.second); rkvs->data.push_back(kv.second);
} }
*ret = rkvs; *ret = Array<ObjectRef>(rkvs);
} }
}); });
......
...@@ -27,8 +27,12 @@ ...@@ -27,8 +27,12 @@
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include "runtime_base.h" #include "runtime_base.h"
// deleter for arrays used by DLPack exporter extern "C" {
extern "C" void NDArrayDLPackDeleter(DLManagedTensor* tensor); // C-mangled dlpack deleter.
static void TVMNDArrayDLPackDeleter(DLManagedTensor* tensor);
// helper function to get NDArray's type index, only used by ctypes.
TVM_DLL int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex);
}
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -53,8 +57,8 @@ inline size_t GetDataAlignment(const DLTensor& arr) { ...@@ -53,8 +57,8 @@ inline size_t GetDataAlignment(const DLTensor& arr) {
struct NDArray::Internal { struct NDArray::Internal {
// Default deleter for the container // Default deleter for the container
static void DefaultDeleter(NDArray::Container* ptr) { static void DefaultDeleter(Object* ptr_obj) {
using tvm::runtime::NDArray; auto* ptr = static_cast<NDArray::Container*>(ptr_obj);
if (ptr->manager_ctx != nullptr) { if (ptr->manager_ctx != nullptr) {
static_cast<NDArray::Container*>(ptr->manager_ctx)->DecRef(); static_cast<NDArray::Container*>(ptr->manager_ctx)->DecRef();
} else if (ptr->dl_tensor.data != nullptr) { } else if (ptr->dl_tensor.data != nullptr) {
...@@ -68,7 +72,8 @@ struct NDArray::Internal { ...@@ -68,7 +72,8 @@ struct NDArray::Internal {
// that are not allocated inside of TVM. // that are not allocated inside of TVM.
// This enables us to create NDArray from memory allocated by other // This enables us to create NDArray from memory allocated by other
// frameworks that are DLPack compatible // frameworks that are DLPack compatible
static void DLPackDeleter(NDArray::Container* ptr) { static void DLPackDeleter(Object* ptr_obj) {
auto* ptr = static_cast<NDArray::Container*>(ptr_obj);
DLManagedTensor* tensor = static_cast<DLManagedTensor*>(ptr->manager_ctx); DLManagedTensor* tensor = static_cast<DLManagedTensor*>(ptr->manager_ctx);
if (tensor->deleter != nullptr) { if (tensor->deleter != nullptr) {
(*tensor->deleter)(tensor); (*tensor->deleter)(tensor);
...@@ -81,12 +86,13 @@ struct NDArray::Internal { ...@@ -81,12 +86,13 @@ struct NDArray::Internal {
DLDataType dtype, DLDataType dtype,
DLContext ctx) { DLContext ctx) {
VerifyDataType(dtype); VerifyDataType(dtype);
// critical zone
// critical zone: construct header
NDArray::Container* data = new NDArray::Container(); NDArray::Container* data = new NDArray::Container();
data->deleter = DefaultDeleter; data->SetDeleter(DefaultDeleter);
NDArray ret(data);
ret.data_ = data;
// RAII now in effect // RAII now in effect
NDArray ret(GetObjectPtr<Object>(data));
// setup shape // setup shape
data->shape_ = std::move(shape); data->shape_ = std::move(shape);
data->dl_tensor.shape = dmlc::BeginPtr(data->shape_); data->dl_tensor.shape = dmlc::BeginPtr(data->shape_);
...@@ -98,45 +104,57 @@ struct NDArray::Internal { ...@@ -98,45 +104,57 @@ struct NDArray::Internal {
return ret; return ret;
} }
// Implementation of API function // Implementation of API function
static DLTensor* MoveAsDLTensor(NDArray arr) { static DLTensor* MoveToFFIHandle(NDArray arr) {
DLTensor* tensor = const_cast<DLTensor*>(arr.operator->()); DLTensor* handle = NDArray::FFIGetHandle(arr);
CHECK(reinterpret_cast<DLTensor*>(arr.data_) == tensor); ObjectRef::FFIClearAfterMove(&arr);
arr.data_ = nullptr; return handle;
return tensor; }
static void FFIDecRef(TVMArrayHandle tensor) {
NDArray::FFIDecRef(tensor);
} }
// Container to DLManagedTensor // Container to DLManagedTensor
static DLManagedTensor* ToDLPack(TVMArrayHandle handle) {
auto* from = static_cast<NDArray::Container*>(
reinterpret_cast<NDArray::ContainerBase*>(handle));
return ToDLPack(from);
}
static DLManagedTensor* ToDLPack(NDArray::Container* from) { static DLManagedTensor* ToDLPack(NDArray::Container* from) {
CHECK(from != nullptr); CHECK(from != nullptr);
DLManagedTensor* ret = new DLManagedTensor(); DLManagedTensor* ret = new DLManagedTensor();
ret->dl_tensor = from->dl_tensor; ret->dl_tensor = from->dl_tensor;
ret->manager_ctx = from; ret->manager_ctx = from;
from->IncRef(); from->IncRef();
ret->deleter = NDArrayDLPackDeleter; ret->deleter = TVMNDArrayDLPackDeleter;
return ret; return ret;
} }
// Delete dlpack object.
static void NDArrayDLPackDeleter(DLManagedTensor* tensor) {
static_cast<NDArray::Container*>(tensor->manager_ctx)->DecRef();
delete tensor;
}
}; };
NDArray NDArray::CreateView(std::vector<int64_t> shape, NDArray NDArray::CreateView(std::vector<int64_t> shape, DLDataType dtype) {
DLDataType dtype) {
CHECK(data_ != nullptr); CHECK(data_ != nullptr);
CHECK(data_->dl_tensor.strides == nullptr) CHECK(get_mutable()->dl_tensor.strides == nullptr)
<< "Can only create view for compact tensor"; << "Can only create view for compact tensor";
NDArray ret = Internal::Create(shape, dtype, data_->dl_tensor.ctx); NDArray ret = Internal::Create(shape, dtype, get_mutable()->dl_tensor.ctx);
ret.data_->dl_tensor.byte_offset = ret.get_mutable()->dl_tensor.byte_offset =
this->data_->dl_tensor.byte_offset; this->get_mutable()->dl_tensor.byte_offset;
size_t curr_size = GetDataSize(this->data_->dl_tensor); size_t curr_size = GetDataSize(this->get_mutable()->dl_tensor);
size_t view_size = GetDataSize(ret.data_->dl_tensor); size_t view_size = GetDataSize(ret.get_mutable()->dl_tensor);
CHECK_LE(view_size, curr_size) CHECK_LE(view_size, curr_size)
<< "Tries to create a view that has bigger memory than current one"; << "Tries to create a view that has bigger memory than current one";
// increase ref count // increase ref count
this->data_->IncRef(); get_mutable()->IncRef();
ret.data_->manager_ctx = this->data_; ret.get_mutable()->manager_ctx = get_mutable();
ret.data_->dl_tensor.data = this->data_->dl_tensor.data; ret.get_mutable()->dl_tensor.data = get_mutable()->dl_tensor.data;
return ret; return ret;
} }
DLManagedTensor* NDArray::ToDLPack() const { DLManagedTensor* NDArray::ToDLPack() const {
return Internal::ToDLPack(data_); return Internal::ToDLPack(get_mutable());
} }
NDArray NDArray::Empty(std::vector<int64_t> shape, NDArray NDArray::Empty(std::vector<int64_t> shape,
...@@ -144,9 +162,9 @@ NDArray NDArray::Empty(std::vector<int64_t> shape, ...@@ -144,9 +162,9 @@ NDArray NDArray::Empty(std::vector<int64_t> shape,
DLContext ctx) { DLContext ctx) {
NDArray ret = Internal::Create(shape, dtype, ctx); NDArray ret = Internal::Create(shape, dtype, ctx);
// setup memory content // setup memory content
size_t size = GetDataSize(ret.data_->dl_tensor); size_t size = GetDataSize(ret.get_mutable()->dl_tensor);
size_t alignment = GetDataAlignment(ret.data_->dl_tensor); size_t alignment = GetDataAlignment(ret.get_mutable()->dl_tensor);
ret.data_->dl_tensor.data = ret.get_mutable()->dl_tensor.data =
DeviceAPI::Get(ret->ctx)->AllocDataSpace( DeviceAPI::Get(ret->ctx)->AllocDataSpace(
ret->ctx, size, alignment, ret->dtype); ret->ctx, size, alignment, ret->dtype);
return ret; return ret;
...@@ -154,10 +172,12 @@ NDArray NDArray::Empty(std::vector<int64_t> shape, ...@@ -154,10 +172,12 @@ NDArray NDArray::Empty(std::vector<int64_t> shape,
NDArray NDArray::FromDLPack(DLManagedTensor* tensor) { NDArray NDArray::FromDLPack(DLManagedTensor* tensor) {
NDArray::Container* data = new NDArray::Container(); NDArray::Container* data = new NDArray::Container();
data->deleter = Internal::DLPackDeleter; // construct header
data->SetDeleter(Internal::DLPackDeleter);
// fill up content.
data->manager_ctx = tensor; data->manager_ctx = tensor;
data->dl_tensor = tensor->dl_tensor; data->dl_tensor = tensor->dl_tensor;
return NDArray(data); return NDArray(GetObjectPtr<Object>(data));
} }
void NDArray::CopyFromTo(const DLTensor* from, void NDArray::CopyFromTo(const DLTensor* from,
...@@ -184,17 +204,24 @@ void NDArray::CopyFromTo(const DLTensor* from, ...@@ -184,17 +204,24 @@ void NDArray::CopyFromTo(const DLTensor* from,
} }
std::vector<int64_t> NDArray::Shape() const { std::vector<int64_t> NDArray::Shape() const {
return data_->shape_; return get_mutable()->shape_;
} }
TVM_REGISTER_OBJECT_TYPE(NDArray::Container);
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
using namespace tvm::runtime; using namespace tvm::runtime;
void NDArrayDLPackDeleter(DLManagedTensor* tensor) { void TVMNDArrayDLPackDeleter(DLManagedTensor* tensor) {
static_cast<NDArray::Container*>(tensor->manager_ctx)->DecRef(); NDArray::Internal::NDArrayDLPackDeleter(tensor);
delete tensor; }
int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex) {
API_BEGIN();
*out_tindex = TVMArrayHandleToObjectHandle(handle)->type_index();
API_END();
} }
int TVMArrayAlloc(const tvm_index_t* shape, int TVMArrayAlloc(const tvm_index_t* shape,
...@@ -213,14 +240,14 @@ int TVMArrayAlloc(const tvm_index_t* shape, ...@@ -213,14 +240,14 @@ int TVMArrayAlloc(const tvm_index_t* shape,
DLContext ctx; DLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
*out = NDArray::Internal::MoveAsDLTensor( *out = NDArray::Internal::MoveToFFIHandle(
NDArray::Empty(std::vector<int64_t>(shape, shape + ndim), dtype, ctx)); NDArray::Empty(std::vector<int64_t>(shape, shape + ndim), dtype, ctx));
API_END(); API_END();
} }
int TVMArrayFree(TVMArrayHandle handle) { int TVMArrayFree(TVMArrayHandle handle) {
API_BEGIN(); API_BEGIN();
reinterpret_cast<NDArray::Container*>(handle)->DecRef(); NDArray::Internal::FFIDecRef(handle);
API_END(); API_END();
} }
...@@ -235,14 +262,14 @@ int TVMArrayCopyFromTo(TVMArrayHandle from, ...@@ -235,14 +262,14 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
int TVMArrayFromDLPack(DLManagedTensor* from, int TVMArrayFromDLPack(DLManagedTensor* from,
TVMArrayHandle* out) { TVMArrayHandle* out) {
API_BEGIN(); API_BEGIN();
*out = NDArray::Internal::MoveAsDLTensor(NDArray::FromDLPack(from)); *out = NDArray::Internal::MoveToFFIHandle(NDArray::FromDLPack(from));
API_END(); API_END();
} }
int TVMArrayToDLPack(TVMArrayHandle from, int TVMArrayToDLPack(TVMArrayHandle from,
DLManagedTensor** out) { DLManagedTensor** out) {
API_BEGIN(); API_BEGIN();
*out = NDArray::Internal::ToDLPack(reinterpret_cast<NDArray::Container*>(from)); *out = NDArray::Internal::ToDLPack(from);
API_END(); API_END();
} }
......
...@@ -59,7 +59,8 @@ class RPCWrappedFunc { ...@@ -59,7 +59,8 @@ class RPCWrappedFunc {
const TVMArgValue& arg); const TVMArgValue& arg);
// deleter of RPC remote array // deleter of RPC remote array
static void RemoteNDArrayDeleter(NDArray::Container* ptr) { static void RemoteNDArrayDeleter(Object* obj) {
auto* ptr = static_cast<NDArray::Container*>(obj);
RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data); RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data);
space->sess->CallRemote(RPCCode::kNDArrayFree, ptr->manager_ctx); space->sess->CallRemote(RPCCode::kNDArrayFree, ptr->manager_ctx);
delete space; delete space;
...@@ -71,12 +72,12 @@ class RPCWrappedFunc { ...@@ -71,12 +72,12 @@ class RPCWrappedFunc {
void* nd_handle) { void* nd_handle) {
NDArray::Container* data = new NDArray::Container(); NDArray::Container* data = new NDArray::Container();
data->manager_ctx = nd_handle; data->manager_ctx = nd_handle;
data->deleter = RemoteNDArrayDeleter; data->SetDeleter(RemoteNDArrayDeleter);
RemoteSpace* space = new RemoteSpace(); RemoteSpace* space = new RemoteSpace();
space->sess = sess; space->sess = sess;
space->data = tensor->data; space->data = tensor->data;
data->dl_tensor.data = space; data->dl_tensor.data = space;
NDArray ret(data); NDArray ret(GetObjectPtr<Object>(data));
// RAII now in effect // RAII now in effect
data->shape_ = std::vector<int64_t>( data->shape_ = std::vector<int64_t>(
tensor->shape, tensor->shape + tensor->ndim); tensor->shape, tensor->shape + tensor->ndim);
......
...@@ -787,9 +787,7 @@ class RPCSession::EventHandler : public dmlc::Stream { ...@@ -787,9 +787,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
TVMValue ret_value_pack[2]; TVMValue ret_value_pack[2];
int ret_tcode_pack[2]; int ret_tcode_pack[2];
rv.MoveToCHost(&ret_value_pack[0], &ret_tcode_pack[0]); rv.MoveToCHost(&ret_value_pack[0], &ret_tcode_pack[0]);
ret_value_pack[1].v_handle = ret_value_pack[0].v_handle;
NDArray::Container* nd = static_cast<NDArray::Container*>(ret_value_pack[0].v_handle);
ret_value_pack[1].v_handle = nd;
ret_tcode_pack[1] = kHandle; ret_tcode_pack[1] = kHandle;
SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, false, nullptr, true); SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, false, nullptr, true);
} else { } else {
...@@ -1190,7 +1188,8 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) { ...@@ -1190,7 +1188,8 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) { void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) {
void* handle = args[0]; void* handle = args[0];
static_cast<NDArray::Container*>(handle)->DecRef(); static_cast<NDArray::Container*>(
reinterpret_cast<NDArray::ContainerBase*>(handle))->DecRef();
} }
void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) { void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) {
......
...@@ -31,7 +31,8 @@ namespace tvm { ...@@ -31,7 +31,8 @@ namespace tvm {
namespace runtime { namespace runtime {
namespace vm { namespace vm {
static void BufferDeleter(NDArray::Container* ptr) { static void BufferDeleter(Object* obj) {
auto* ptr = static_cast<NDArray::Container*>(obj);
CHECK(ptr->manager_ctx != nullptr); CHECK(ptr->manager_ctx != nullptr);
Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx); Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx);
MemoryManager::Global()->GetAllocator(buffer->ctx)-> MemoryManager::Global()->GetAllocator(buffer->ctx)->
...@@ -40,7 +41,8 @@ static void BufferDeleter(NDArray::Container* ptr) { ...@@ -40,7 +41,8 @@ static void BufferDeleter(NDArray::Container* ptr) {
delete ptr; delete ptr;
} }
void StorageObj::Deleter(NDArray::Container* ptr) { void StorageObj::Deleter(Object* obj) {
auto* ptr = static_cast<NDArray::Container*>(obj);
// When invoking AllocNDArray we don't own the underlying allocation // When invoking AllocNDArray we don't own the underlying allocation
// and should not delete the buffer, but instead let it be reclaimed // and should not delete the buffer, but instead let it be reclaimed
// by the storage object's destructor. // by the storage object's destructor.
...@@ -77,16 +79,23 @@ NDArray StorageObj::AllocNDArray(size_t offset, std::vector<int64_t> shape, DLDa ...@@ -77,16 +79,23 @@ NDArray StorageObj::AllocNDArray(size_t offset, std::vector<int64_t> shape, DLDa
// TODO(@jroesch): generalize later to non-overlapping allocations. // TODO(@jroesch): generalize later to non-overlapping allocations.
CHECK_EQ(offset, 0u); CHECK_EQ(offset, 0u);
VerifyDataType(dtype); VerifyDataType(dtype);
// crtical zone: allocate header, cannot throw
NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, this->buffer.ctx); NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, this->buffer.ctx);
container->deleter = StorageObj::Deleter;
container->SetDeleter(StorageObj::Deleter);
size_t needed_size = GetDataSize(container->dl_tensor); size_t needed_size = GetDataSize(container->dl_tensor);
// TODO(@jroesch): generalize later to non-overlapping allocations.
CHECK(needed_size == this->buffer.size)
<< "size mistmatch required " << needed_size << " found " << this->buffer.size;
this->IncRef(); this->IncRef();
container->manager_ctx = reinterpret_cast<void*>(this); container->manager_ctx = reinterpret_cast<void*>(this);
container->dl_tensor.data = this->buffer.data; container->dl_tensor.data = this->buffer.data;
return NDArray(container); NDArray ret(GetObjectPtr<Object>(container));
// RAII in effect, now run the check.
// TODO(@jroesch): generalize later to non-overlapping allocations.
CHECK(needed_size == this->buffer.size)
<< "size mistmatch required " << needed_size << " found " << this->buffer.size;
return ret;
} }
MemoryManager* MemoryManager::Global() { MemoryManager* MemoryManager::Global() {
...@@ -108,14 +117,14 @@ Allocator* MemoryManager::GetAllocator(TVMContext ctx) { ...@@ -108,14 +117,14 @@ Allocator* MemoryManager::GetAllocator(TVMContext ctx) {
NDArray Allocator::Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) { NDArray Allocator::Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) {
VerifyDataType(dtype); VerifyDataType(dtype);
NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, ctx); NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, ctx);
container->deleter = BufferDeleter; container->SetDeleter(BufferDeleter);
size_t size = GetDataSize(container->dl_tensor); size_t size = GetDataSize(container->dl_tensor);
size_t alignment = GetDataAlignment(container->dl_tensor); size_t alignment = GetDataAlignment(container->dl_tensor);
Buffer *buffer = new Buffer; Buffer *buffer = new Buffer;
*buffer = this->Alloc(size, alignment, dtype); *buffer = this->Alloc(size, alignment, dtype);
container->manager_ctx = reinterpret_cast<void*>(buffer); container->manager_ctx = reinterpret_cast<void*>(buffer);
container->dl_tensor.data = buffer->data; container->dl_tensor.data = buffer->data;
return NDArray(container); return NDArray(GetObjectPtr<Object>(container));
} }
} // namespace vm } // namespace vm
......
...@@ -120,7 +120,7 @@ class StorageObj : public Object { ...@@ -120,7 +120,7 @@ class StorageObj : public Object {
DLDataType dtype); DLDataType dtype);
/*! \brief The deleter for an NDArray when allocated from underlying storage. */ /*! \brief The deleter for an NDArray when allocated from underlying storage. */
static void Deleter(NDArray::Container* ptr); static void Deleter(Object* ptr);
~StorageObj() { ~StorageObj() {
auto alloc = MemoryManager::Global()->GetAllocator(buffer.ctx); auto alloc = MemoryManager::Global()->GetAllocator(buffer.ctx);
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/runtime/registry.h>
#include <tvm/ir.h> #include <tvm/ir.h>
TEST(PackedFunc, Basic) { TEST(PackedFunc, Basic) {
...@@ -178,6 +179,69 @@ TEST(TypedPackedFunc, HighOrder) { ...@@ -178,6 +179,69 @@ TEST(TypedPackedFunc, HighOrder) {
CHECK_EQ(f1(3), 4); CHECK_EQ(f1(3), 4);
} }
TEST(PackedFunc, ObjectConversion) {
using namespace tvm;
using namespace tvm::runtime;
TVMRetValue rv;
auto x = NDArray::Empty(
{}, String2TVMType("float32"),
TVMContext{kDLCPU, 0});
// assign null
rv = ObjectRef();
CHECK_EQ(rv.type_code(), kNull);
// Can assign NDArray to ret type
rv = x;
CHECK_EQ(rv.type_code(), kNDArrayContainer);
// Even if we assign base type it still shows as NDArray
rv = ObjectRef(x);
CHECK_EQ(rv.type_code(), kNDArrayContainer);
// Check convert back
CHECK(rv.operator NDArray().same_as(x));
CHECK(rv.operator ObjectRef().same_as(x));
CHECK(!rv.IsObjectRef<Expr>());
auto pf1 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args[0].type_code(), kNDArrayContainer);
CHECK(args[0].operator NDArray().same_as(x));
CHECK(args[0].operator ObjectRef().same_as(x));
CHECK(args[1].operator ObjectRef().get() == nullptr);
CHECK(args[1].operator NDArray().get() == nullptr);
CHECK(args[1].operator Module().get() == nullptr);
CHECK(args[1].operator Array<NDArray>().get() == nullptr);
CHECK(!args[0].IsObjectRef<Expr>());
});
pf1(x, ObjectRef());
pf1(ObjectRef(x), NDArray());
// testcases for modules
auto* pf = tvm::runtime::Registry::Get("module.source_module_create");
CHECK(pf != nullptr);
Module m = (*pf)("", "xyz");
rv = m;
CHECK_EQ(rv.type_code(), kModuleHandle);
// Even if we assign base type it still shows as NDArray
rv = ObjectRef(m);
CHECK_EQ(rv.type_code(), kModuleHandle);
// Check convert back
CHECK(rv.operator Module().same_as(m));
CHECK(rv.operator ObjectRef().same_as(m));
CHECK(!rv.IsObjectRef<NDArray>());
auto pf2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args[0].type_code(), kModuleHandle);
CHECK(args[0].operator Module().same_as(m));
CHECK(args[0].operator ObjectRef().same_as(m));
CHECK(args[1].operator ObjectRef().get() == nullptr);
CHECK(args[1].operator NDArray().get() == nullptr);
CHECK(args[1].operator Module().get() == nullptr);
CHECK(!args[0].IsObjectRef<Expr>());
});
pf2(m, ObjectRef());
pf2(ObjectRef(m), Module());
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe"; testing::FLAGS_gtest_death_test_style = "threadsafe";
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import tvm import tvm
import numpy as np
def test_array(): def test_array():
a = tvm.convert([1,2,3]) a = tvm.convert([1,2,3])
...@@ -71,6 +72,14 @@ def test_in_container(): ...@@ -71,6 +72,14 @@ def test_in_container():
assert tvm.make.StringImm('a') in arr assert tvm.make.StringImm('a') in arr
assert 'd' not in arr assert 'd' not in arr
def test_ndarray_container():
x = tvm.nd.array([1,2,3])
arr = tvm.convert([x, x])
assert arr[0].same_as(x)
assert arr[1].same_as(x)
assert isinstance(arr[0], tvm.nd.NDArray)
if __name__ == "__main__": if __name__ == "__main__":
test_str_map() test_str_map()
test_array() test_array()
...@@ -78,3 +87,4 @@ if __name__ == "__main__": ...@@ -78,3 +87,4 @@ if __name__ == "__main__":
test_array_save_load_json() test_array_save_load_json()
test_map_save_load_json() test_map_save_load_json()
test_in_container() test_in_container()
test_ndarray_container()
...@@ -32,7 +32,7 @@ def gen_engine_header(): ...@@ -32,7 +32,7 @@ def gen_engine_header():
#include <vector> #include <vector>
class Engine { class Engine {
}; };
#endif #endif
''' '''
header_file = header_file_dir_path.relpath("gcc_engine.h") header_file = header_file_dir_path.relpath("gcc_engine.h")
...@@ -45,7 +45,7 @@ def generate_engine_module(): ...@@ -45,7 +45,7 @@ def generate_engine_module():
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <dlpack/dlpack.h> #include <dlpack/dlpack.h>
#include "gcc_engine.h" #include "gcc_engine.h"
extern "C" void gcc_1_(float* gcc_input4, float* gcc_input5, extern "C" void gcc_1_(float* gcc_input4, float* gcc_input5,
float* gcc_input6, float* gcc_input7, float* out) { float* gcc_input6, float* gcc_input7, float* out) {
Engine engine; Engine engine;
......
...@@ -35,7 +35,8 @@ rm -rf lib ...@@ -35,7 +35,8 @@ rm -rf lib
make make
cd ../.. cd ../..
python3 -m pytest -v apps/extension/tests TVM_FFI=cython python3 -m pytest -v apps/extension/tests
TVM_FFI=ctypes python3 -m pytest -v apps/extension/tests
TVM_FFI=ctypes python3 -m pytest -v tests/python/integration TVM_FFI=ctypes python3 -m pytest -v tests/python/integration
TVM_FFI=ctypes python3 -m pytest -v tests/python/contrib TVM_FFI=ctypes python3 -m pytest -v tests/python/contrib
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment