Unverified Commit 55bd786f by Tianqi Chen Committed by GitHub

[REFACTOR][RUNTIME] Update NDArray use the Unified Object System (#4581)

* [REFACTOR][RUNTIME] Move NDArray to Object System.

Previously NDArray has its own object reference counting mechanism.
This PR migrates NDArray to the unified object protocol.

The calling convention of NDArray remained intact.
That means NDArray still has its own type_code and
its handle is still DLTensor compatible.

In order to do so, this PR added a few minimum runtime type
detection in TVMArgValue and RetValue only when the corresponding
type is a base type(ObjectRef) that could also refer to NDArray.

This means that even if we return a base reference object ObjectRef
which refers to the NDArray. The type_code will still be translated
correctly as kNDArrayContainer.
If we assign a non-base type(say Expr) that we know is not compatible
with NDArray during compile time, no runtime type detection will be performed.

This PR also adopts the object protocol for NDArray sub-classing and
removed the legacy NDArray subclass protocol.
Examples in apps/extension are now updated to reflect that.

Making NDArray as an Object brings all the benefits of the object system.
For example, we can now use the Array container to store NDArrays.

* Address review comments
parent 4072396e
......@@ -51,26 +51,23 @@ class IntVec(tvm.Object):
nd_create = tvm.get_global_func("tvm_ext.nd_create")
nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two")
nd_get_addtional_info = tvm.get_global_func("tvm_ext.nd_get_addtional_info")
nd_get_additional_info = tvm.get_global_func("tvm_ext.nd_get_additional_info")
@tvm.register_object("tvm_ext.NDSubClass")
class NDSubClass(tvm.nd.NDArrayBase):
"""Example for subclassing TVM's NDArray infrastructure.
By inheriting TMV's NDArray, external libraries could
leverage TVM's FFI without any modification.
"""
# Should be consistent with the type-trait set in the backend
_array_type_code = 1
@staticmethod
def create(addtional_info):
return nd_create(addtional_info)
def create(additional_info):
return nd_create(additional_info)
@property
def addtional_info(self):
return nd_get_addtional_info(self)
def additional_info(self):
return nd_get_additional_info(self)
def __add__(self, other):
return nd_add_two(self, other)
tvm.register_extension(NDSubClass, NDSubClass)
......@@ -29,19 +29,6 @@
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h>
namespace tvm_ext {
class NDSubClass;
} // namespace tvm_ext
namespace tvm {
namespace runtime {
template<>
struct array_type_info<tvm_ext::NDSubClass> {
static const int code = 1;
};
} // namespace tvm
} // namespace runtime
using namespace tvm;
using namespace tvm::runtime;
......@@ -52,54 +39,55 @@ namespace tvm_ext {
* To use this extension, an external library should
*
* 1) Inherit TVM's NDArray and NDArray container,
* and define the trait `array_type_info` for this class.
*
* 2) Define a constructor in the inherited class that accepts
* a pointer to TVM's Container, which is nullable.
* 2) Follow the new object protocol to define new NDArray as a reference class.
*
* 3) On Python frontend, inherit `tvm.nd.NDArrayBase`,
* define the class attribute `_array_type_code` consistent to
* the C++ type trait, and register the subclass using `tvm.register_extension`.
* 3) On Python frontend, inherit `tvm.nd.NDArray`,
* register the type using tvm.register_object
*/
class NDSubClass : public tvm::runtime::NDArray {
public:
class SubContainer : public NDArray::Container {
public:
SubContainer(int addtional_info) :
addtional_info_(addtional_info) {
array_type_code_ = array_type_info<NDSubClass>::code;
}
static bool Is(NDArray::Container *container) {
SubContainer *c = static_cast<SubContainer*>(container);
return c->array_type_code_ == array_type_info<NDSubClass>::code;
SubContainer(int additional_info) :
additional_info_(additional_info) {
type_index_ = SubContainer::RuntimeTypeIndex();
}
int addtional_info_{0};
int additional_info_{0};
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "tvm_ext.NDSubClass";
TVM_DECLARE_FINAL_OBJECT_INFO(SubContainer, NDArray::Container);
};
NDSubClass(NDArray::Container *container) {
if (container == nullptr) {
data_ = nullptr;
return;
}
CHECK(SubContainer::Is(container));
container->IncRef();
data_ = container;
static void SubContainerDeleter(Object* obj) {
auto* ptr = static_cast<SubContainer*>(obj);
delete ptr;
}
~NDSubClass() {
this->reset();
NDSubClass() {}
explicit NDSubClass(ObjectPtr<Object> n) : NDArray(n) {}
explicit NDSubClass(int additional_info) {
SubContainer* ptr = new SubContainer(additional_info);
ptr->SetDeleter(SubContainerDeleter);
data_ = GetObjectPtr<Object>(ptr);
}
NDSubClass AddWith(const NDSubClass &other) const {
SubContainer *a = static_cast<SubContainer*>(data_);
SubContainer *b = static_cast<SubContainer*>(other.data_);
SubContainer *a = static_cast<SubContainer*>(get_mutable());
SubContainer *b = static_cast<SubContainer*>(other.get_mutable());
CHECK(a != nullptr && b != nullptr);
return NDSubClass(new SubContainer(a->addtional_info_ + b->addtional_info_));
return NDSubClass(a->additional_info_ + b->additional_info_);
}
int get_additional_info() const {
SubContainer *self = static_cast<SubContainer*>(data_);
SubContainer *self = static_cast<SubContainer*>(get_mutable());
CHECK(self != nullptr);
return self->addtional_info_;
return self->additional_info_;
}
using ContainerType = SubContainer;
};
TVM_REGISTER_OBJECT_TYPE(NDSubClass::SubContainer);
/*!
* \brief Introduce additional extension data structures
......@@ -166,8 +154,10 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev")
TVM_REGISTER_GLOBAL("tvm_ext.nd_create")
.set_body([](TVMArgs args, TVMRetValue *rv) {
int addtional_info = args[0];
*rv = NDSubClass(new NDSubClass::SubContainer(addtional_info));
int additional_info = args[0];
*rv = NDSubClass(additional_info);
CHECK_EQ(rv->type_code(), kNDArrayContainer);
});
TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two")
......@@ -177,7 +167,7 @@ TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two")
*rv = a.AddWith(b);
});
TVM_REGISTER_GLOBAL("tvm_ext.nd_get_addtional_info")
TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info")
.set_body([](TVMArgs args, TVMRetValue *rv) {
NDSubClass a = args[0];
*rv = a.get_additional_info();
......
......@@ -87,16 +87,17 @@ def test_extern_call():
def test_nd_subclass():
a = tvm_ext.NDSubClass.create(addtional_info=3)
b = tvm_ext.NDSubClass.create(addtional_info=5)
a = tvm_ext.NDSubClass.create(additional_info=3)
b = tvm_ext.NDSubClass.create(additional_info=5)
assert isinstance(a, tvm_ext.NDSubClass)
c = a + b
d = a + a
e = b + b
assert(a.addtional_info == 3)
assert(b.addtional_info == 5)
assert(c.addtional_info == 8)
assert(d.addtional_info == 6)
assert(e.addtional_info == 10)
assert(a.additional_info == 3)
assert(b.additional_info == 5)
assert(c.additional_info == 8)
assert(d.additional_info == 6)
assert(e.additional_info == 10)
if __name__ == "__main__":
......
......@@ -23,14 +23,14 @@
#ifndef TVM_NODE_CONTAINER_H_
#define TVM_NODE_CONTAINER_H_
#include <tvm/node/node.h>
#include <type_traits>
#include <vector>
#include <initializer_list>
#include <unordered_map>
#include <utility>
#include <string>
#include "node.h"
#include "memory.h"
namespace tvm {
......
......@@ -25,7 +25,6 @@
#ifndef TVM_PACKED_FUNC_EXT_H_
#define TVM_PACKED_FUNC_EXT_H_
#include <sstream>
#include <string>
#include <memory>
#include <limits>
......@@ -43,22 +42,7 @@ using runtime::TVMRetValue;
using runtime::PackedFunc;
namespace runtime {
/*!
* \brief Runtime type checker for node type.
* \tparam T the type to be checked.
*/
template<typename T>
struct ObjectTypeChecker {
static bool Check(const Object* ptr) {
using ContainerType = typename T::ContainerType;
if (ptr == nullptr) return true;
return ptr->IsInstance<ContainerType>();
}
static void PrintName(std::ostream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType;
os << ContainerType::_type_key;
}
};
template<typename T>
struct ObjectTypeChecker<Array<T> > {
......@@ -73,10 +57,8 @@ struct ObjectTypeChecker<Array<T> > {
}
return true;
}
static void PrintName(std::ostream& os) { // NOLINT(*)
os << "List[";
ObjectTypeChecker<T>::PrintName(os);
os << "]";
static std::string TypeName() {
return "List[" + ObjectTypeChecker<T>::TypeName() + "]";
}
};
......@@ -91,11 +73,9 @@ struct ObjectTypeChecker<Map<std::string, V> > {
}
return true;
}
static void PrintName(std::ostream& os) { // NOLINT(*)
os << "Map[str";
os << ',';
ObjectTypeChecker<V>::PrintName(os);
os << ']';
static std::string TypeName() {
return "Map[str, " +
ObjectTypeChecker<V>::TypeName()+ ']';
}
};
......@@ -111,39 +91,16 @@ struct ObjectTypeChecker<Map<K, V> > {
}
return true;
}
static void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "Map[";
ObjectTypeChecker<K>::PrintName(os);
os << ',';
ObjectTypeChecker<V>::PrintName(os);
os << ']';
static std::string TypeName() {
return "Map[" +
ObjectTypeChecker<K>::TypeName() +
", " +
ObjectTypeChecker<V>::TypeName()+ ']';
}
};
template<typename T>
inline std::string ObjectTypeName() {
std::ostringstream os;
ObjectTypeChecker<T>::PrintName(os);
return os.str();
}
// extensions for tvm arg value
template<typename TObjectRef>
inline TObjectRef TVMArgValue::AsObjectRef() const {
static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
if (type_code_ == kNull) return TObjectRef(NodePtr<Node>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expected type " << ObjectTypeName<TObjectRef>()
<< " but get " << ptr->GetTypeKey();
return TObjectRef(ObjectPtr<Node>(ptr));
}
inline TVMArgValue::operator tvm::Expr() const {
inline TVMPODValue_::operator tvm::Expr() const {
if (type_code_ == kNull) return Expr();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
......@@ -164,12 +121,12 @@ inline TVMArgValue::operator tvm::Expr() const {
return Tensor(ObjectPtr<Node>(ptr))();
}
CHECK(ObjectTypeChecker<Expr>::Check(ptr))
<< "Expected type " << ObjectTypeName<Expr>()
<< "Expect type " << ObjectTypeChecker<Expr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return Expr(ObjectPtr<Node>(ptr));
}
inline TVMArgValue::operator tvm::Integer() const {
inline TVMPODValue_::operator tvm::Integer() const {
if (type_code_ == kNull) return Integer();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
......@@ -179,35 +136,10 @@ inline TVMArgValue::operator tvm::Integer() const {
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<Integer>::Check(ptr))
<< "Expected type " << ObjectTypeName<Expr>()
<< "Expect type " << ObjectTypeChecker<Expr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return Integer(ObjectPtr<Node>(ptr));
}
template<typename TObjectRef, typename>
inline bool TVMPODValue_::IsObjectRef() const {
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
return ObjectTypeChecker<TObjectRef>::Check(ptr);
}
// extensions for TVMRetValue
template<typename TObjectRef>
inline TObjectRef TVMRetValue::AsObjectRef() const {
static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
if (type_code_ == kNull) return TObjectRef();
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expected type " << ObjectTypeName<TObjectRef>()
<< " but get " << ptr->GetTypeKey();
return TObjectRef(ObjectPtr<Object>(ptr));
}
} // namespace runtime
} // namespace tvm
#endif // TVM_PACKED_FUNC_EXT_H_
......@@ -23,6 +23,7 @@
*/
#ifndef TVM_RUNTIME_CONTAINER_H_
#define TVM_RUNTIME_CONTAINER_H_
#include <dmlc/logging.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
......
......@@ -24,11 +24,13 @@
#ifndef TVM_RUNTIME_NDARRAY_H_
#define TVM_RUNTIME_NDARRAY_H_
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/serializer.h>
#include <atomic>
#include <vector>
#include <utility>
#include "c_runtime_api.h"
#include "serializer.h"
namespace tvm {
namespace runtime {
......@@ -37,72 +39,23 @@ namespace runtime {
* \brief Managed NDArray.
* The array is backed by reference counted blocks.
*/
class NDArray {
class NDArray : public ObjectRef {
public:
// internal container type
/*! \brief ContainerBase used to back the TVMArrayHandle */
class ContainerBase;
/*! \brief NDArray internal container type */
class Container;
/*! \brief Container type for Object system. */
using ContainerType = Container;
/*! \brief default constructor */
NDArray() {}
/*!
* \brief cosntruct a NDArray that refers to data
* \param data The data this NDArray refers to
*/
explicit inline NDArray(Container* data);
/*!
* \brief copy constructor.
*
* It does not make a copy, but the reference count of the input NDArray is incremented
*
* \param other NDArray that shares internal data with the input NDArray.
*/
inline NDArray(const NDArray& other); // NOLINT(*)
/*!
* \brief move constructor
* \param other The value to be moved
*/
NDArray(NDArray&& other) // NOLINT(*)
: data_(other.data_) {
other.data_ = nullptr;
}
/*! \brief destructor */
~NDArray() {
this->reset();
}
/*!
* \brief Swap this array with another NDArray
* \param other The other NDArray
* \brief constructor.
* \param data ObjectPtr to the data container.
*/
void swap(NDArray& other) { // NOLINT(*)
std::swap(data_, other.data_);
}
/*!
* \brief copy assignmemt
* \param other The value to be assigned.
* \return reference to self.
*/
NDArray& operator=(const NDArray& other) { // NOLINT(*)
// copy-and-swap idiom
NDArray(other).swap(*this); // NOLINT(*)
return *this;
}
/*!
* \brief move assignmemt
* \param other The value to be assigned.
* \return reference to self.
*/
NDArray& operator=(NDArray&& other) { // NOLINT(*)
// copy-and-swap idiom
NDArray(std::move(other)).swap(*this); // NOLINT(*)
return *this;
}
/*! \return If NDArray is defined */
bool defined() const {
return data_ != nullptr;
}
/*! \return If both NDArray reference the same container */
bool same_as(const NDArray& other) const {
return data_ == other.data_;
}
explicit NDArray(ObjectPtr<Object> data)
: ObjectRef(data) {}
/*! \brief reset the content of NDArray to be nullptr */
inline void reset();
/*!
......@@ -191,36 +144,40 @@ class NDArray {
const DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr);
TVM_DLL std::vector<int64_t> Shape() const;
// internal namespace
struct Internal;
protected:
/*! \brief Internal Data content */
Container* data_{nullptr};
// enable internal functions
friend struct Internal;
friend class TVMPODValue_;
friend class TVMArgValue;
friend class TVMRetValue;
friend class TVMArgsSetter;
};
/*!
* \brief The type trait indicates subclass of TVM's NDArray.
* For irrelavant classes, code = -1.
* For TVM NDArray itself, code = 0.
* All subclasses of NDArray should override code > 0.
*/
template<typename T>
struct array_type_info {
/*! \brief the value of the traits */
static const int code = -1;
};
// Overrides the type trait for tvm's NDArray.
template<>
struct array_type_info<NDArray> {
static const int code = 0;
/*!
* \brief Get mutable internal container pointer.
* \return a mutable container pointer.
*/
inline Container* get_mutable() const;
// Helper functions for FFI handling.
/*!
* \brief Construct NDArray's Data field from array handle in FFI.
* \param handle The array handle.
* \return The corresponding ObjectPtr to the constructed container object.
*
* \note We keep a special calling convention for NDArray by passing
* ContainerBase pointer in FFI.
* As a result, the argument is compatible to DLTensor*.
*/
inline static ObjectPtr<Object> FFIDataFromHandle(TVMArrayHandle handle);
/*!
* \brief DecRef resource managed by an FFI array handle.
* \param handle The array handle.
*/
inline static void FFIDecRef(TVMArrayHandle handle);
/*!
* \brief Get FFI Array handle from ndarray.
* \param nd The object with ndarray type.
* \return The result array handle.
*/
inline static TVMArrayHandle FFIGetHandle(const ObjectRef& nd);
};
/*!
......@@ -231,19 +188,14 @@ struct array_type_info<NDArray> {
inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor);
/*!
* \brief Reference counted Container object used to back NDArray.
* \brief The container base structure
* contains all the fields except for the Object header.
*
* This object is DLTensor compatible:
* the pointer to the NDArrayContainer can be directly
* interpreted as a DLTensor*
*
* \note do not use this function directly, use NDArray.
* \note We explicitly declare this structure in order to pass
* PackedFunc argument using ContainerBase*.
*/
class NDArray::Container {
class NDArray::ContainerBase {
public:
// NOTE: the first part of this structure is the same as
// DLManagedTensor, note that, however, the deleter
// is only called when the reference counter goes to 0
/*!
* \brief The corresponding dl_tensor field.
* \note it is important that the first field is DLTensor
......@@ -259,42 +211,27 @@ class NDArray::Container {
* (e.g. reference to original memory when creating views).
*/
void* manager_ctx{nullptr};
/*!
* \brief Customized deleter
*
* \note The customized deleter is helpful to enable
* different ways of memory allocator that are not
* currently defined by the system.
*/
void (*deleter)(Container* self) = nullptr;
protected:
friend class NDArray;
friend class TVMPODValue_;
friend class TVMArgValue;
friend class TVMRetValue;
friend class RPCWrappedFunc;
/*!
* \brief Type flag used to indicate subclass.
* Default value 0 means normal NDArray::Conatainer.
*
* We can extend a more specialized NDArray::Container
* and use the array_type_code_ to indicate
* the specific array subclass.
*/
int32_t array_type_code_{0};
/*! \brief The internal reference counter */
std::atomic<int> ref_counter_{0};
/*!
* \brief The shape container,
* can be used used for shape data.
*/
std::vector<int64_t> shape_;
};
/*!
* \brief Object container class that backs NDArray.
* \note do not use this function directly, use NDArray.
*/
class NDArray::Container :
public Object,
public NDArray::ContainerBase {
public:
/*! \brief default constructor */
Container() {
// Initialize the type index.
type_index_ = Container::RuntimeTypeIndex();
dl_tensor.data = nullptr;
dl_tensor.ndim = 0;
dl_tensor.shape = nullptr;
......@@ -306,6 +243,8 @@ class NDArray::Container {
std::vector<int64_t> shape,
DLDataType dtype,
DLContext ctx) {
// Initialize the type index.
type_index_ = Container::RuntimeTypeIndex();
dl_tensor.data = data;
shape_ = std::move(shape);
dl_tensor.ndim = static_cast<int>(shape_.size());
......@@ -315,49 +254,36 @@ class NDArray::Container {
dl_tensor.byte_offset = 0;
dl_tensor.ctx = ctx;
}
/*! \brief developer function, increases reference counter */
void IncRef() {
ref_counter_.fetch_add(1, std::memory_order_relaxed);
}
/*! \brief developer function, decrease reference counter */
void DecRef() {
if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
std::atomic_thread_fence(std::memory_order_acquire);
if (this->deleter != nullptr) {
(*this->deleter)(this);
}
}
/*!
* \brief Set the deleter field.
* \param deleter The deleter.
*/
void SetDeleter(FDeleter deleter) {
deleter_ = deleter;
}
};
// implementations of inline functions
// the usages of functions are documented in place.
inline NDArray::NDArray(Container* data)
: data_(data) {
if (data != nullptr) {
data_->IncRef();
}
}
// Expose DecRef and IncRef as public function
// NOTE: they are only for developer purposes only.
using Object::DecRef;
using Object::IncRef;
inline NDArray::NDArray(const NDArray& other)
: data_(other.data_) {
if (data_ != nullptr) {
data_->IncRef();
}
}
// Information for object protocol.
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const uint32_t _type_child_slots = 0;
static constexpr const uint32_t _type_child_slots_can_overflow = true;
static constexpr const char* _type_key = "NDArray";
TVM_DECLARE_BASE_OBJECT_INFO(NDArray::Container, Object);
inline void NDArray::reset() {
if (data_ != nullptr) {
data_->DecRef();
data_ = nullptr;
}
}
protected:
friend class RPCWrappedFunc;
friend class NDArray;
};
/*! \brief return the size of data the DLTensor hold, in term of number of bytes
// implementations of inline functions
/*!
* \brief return the size of data the DLTensor hold, in term of number of bytes
*
* \param arr the input DLTensor
*
* \return number of bytes of data in the DLTensor.
*/
inline size_t GetDataSize(const DLTensor& arr) {
......@@ -371,24 +297,24 @@ inline size_t GetDataSize(const DLTensor& arr) {
inline void NDArray::CopyFrom(const DLTensor* other) {
CHECK(data_ != nullptr);
CopyFromTo(other, &(data_->dl_tensor));
CopyFromTo(other, &(get_mutable()->dl_tensor));
}
inline void NDArray::CopyFrom(const NDArray& other) {
CHECK(data_ != nullptr);
CHECK(other.data_ != nullptr);
CopyFromTo(&(other.data_->dl_tensor), &(data_->dl_tensor));
CopyFromTo(&(other.get_mutable()->dl_tensor), &(get_mutable()->dl_tensor));
}
inline void NDArray::CopyTo(DLTensor* other) const {
CHECK(data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), other);
CopyFromTo(&(get_mutable()->dl_tensor), other);
}
inline void NDArray::CopyTo(const NDArray& other) const {
CHECK(data_ != nullptr);
CHECK(other.data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), &(other.data_->dl_tensor));
CopyFromTo(&(get_mutable()->dl_tensor), &(other.get_mutable()->dl_tensor));
}
inline NDArray NDArray::CopyTo(const DLContext& ctx) const {
......@@ -401,12 +327,39 @@ inline NDArray NDArray::CopyTo(const DLContext& ctx) const {
}
inline int NDArray::use_count() const {
if (data_ == nullptr) return 0;
return data_->ref_counter_.load(std::memory_order_relaxed);
return data_.use_count();
}
inline const DLTensor* NDArray::operator->() const {
return &(data_->dl_tensor);
return &(get_mutable()->dl_tensor);
}
inline NDArray::Container* NDArray::get_mutable() const {
return static_cast<NDArray::Container*>(data_.get());
}
inline ObjectPtr<Object> NDArray::FFIDataFromHandle(TVMArrayHandle handle) {
return GetObjectPtr<Object>(static_cast<NDArray::Container*>(
reinterpret_cast<NDArray::ContainerBase*>(handle)));
}
inline TVMArrayHandle NDArray::FFIGetHandle(const ObjectRef& nd) {
// NOTE: it is necessary to cast to container then to base
// so that the FFI handle uses the ContainerBase address.
return reinterpret_cast<TVMArrayHandle>(
static_cast<NDArray::ContainerBase*>(
static_cast<NDArray::Container*>(
const_cast<Object*>(nd.get()))));
}
inline void NDArray::FFIDecRef(TVMArrayHandle handle) {
static_cast<NDArray::Container*>(
reinterpret_cast<NDArray::ContainerBase*>(handle))->DecRef();
}
inline Object* TVMArrayHandleToObjectHandle(TVMArrayHandle handle) {
return static_cast<NDArray::Container*>(
reinterpret_cast<NDArray::ContainerBase*>(handle));
}
/*! \brief Magic number for NDArray file */
......
......@@ -24,10 +24,11 @@
#define TVM_RUNTIME_OBJECT_H_
#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
#include <type_traits>
#include <string>
#include <utility>
#include "c_runtime_api.h"
/*!
* \brief Whether or not use atomic reference counter.
......@@ -581,6 +582,14 @@ class ObjectRef {
return T(std::move(ref.data_));
}
/*!
* \brief Clear the object ref data field without DecRef
* after we successfully moved the field.
* \param ref The reference data.
*/
static void FFIClearAfterMove(ObjectRef* ref) {
ref->data_.data_ = nullptr;
}
/*!
* \brief Internal helper function get data_ as ObjectPtr of ObjectType.
* \note only used for internal dev purpose.
* \tparam ObjectType The corresponding object type.
......@@ -648,7 +657,7 @@ struct ObjectEqual {
return _GetOrAllocRuntimeTypeIndex(); \
} \
static const uint32_t _GetOrAllocRuntimeTypeIndex() { \
static uint32_t tidx = GetOrAllocRuntimeTypeIndex( \
static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex( \
TypeName::_type_key, \
TypeName::_type_index, \
ParentType::_GetOrAllocRuntimeTypeIndex(), \
......@@ -668,6 +677,19 @@ struct ObjectEqual {
TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
/*! \brief helper macro to supress unused warning */
#if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define TVM_ATTRIBUTE_UNUSED
#endif
#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
#define TVM_OBJECT_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid
/*!
* \brief Helper macro to register the object type to runtime.
* Makes sure that the runtime type table is correctly populated.
......@@ -675,7 +697,7 @@ struct ObjectEqual {
* Use this macro in the cc file for each terminal class.
*/
#define TVM_REGISTER_OBJECT_TYPE(TypeName) \
static DMLC_ATTRIBUTE_UNUSED uint32_t __make_Object_tidx ## _ ## TypeName ## __ = \
TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = \
TypeName::_GetOrAllocRuntimeTypeIndex()
......@@ -691,14 +713,14 @@ struct ObjectEqual {
using ContainerType = ObjectName;
#define TVM_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \
TypeName() {} \
explicit TypeName( \
::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
: ParentType(n) {} \
ObjectName* operator->() { \
return static_cast<ObjectName*>(data_.get()); \
} \
operator bool() const { return data_ != nullptr; } \
TypeName() {} \
explicit TypeName( \
::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
: ParentType(n) {} \
ObjectName* operator->() { \
return static_cast<ObjectName*>(data_.get()); \
} \
operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName;
// Implementations details below
......
......@@ -387,20 +387,22 @@ inline std::string TVMType2String(TVMType t);
#define TVM_CHECK_TYPE_CODE(CODE, T) \
CHECK_EQ(CODE, T) << " expected " \
<< TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \
/*!
* \brief Type traits to mark if a class is tvm extension type.
*
* To enable extension type in C++ must be registered via marco.
* TVM_REGISTER_EXT_TYPE(TypeName) after defining this with this traits.
*
* Extension class can be passed and returned via PackedFunc in all tvm runtime.
* Internally extension class is stored as T*.
*
* \tparam T the typename
* \brief Type traits for runtime type check during FFI conversion.
* \tparam T the type to be checked.
*/
template<typename T>
struct extension_type_info {
static const int code = 0;
struct ObjectTypeChecker {
static bool Check(const Object* ptr) {
using ContainerType = typename T::ContainerType;
if (ptr == nullptr) return true;
return ptr->IsInstance<ContainerType>();
}
static std::string TypeName() {
using ContainerType = typename T::ContainerType;
return ContainerType::_type_key;
}
};
/*!
......@@ -449,24 +451,17 @@ class TVMPODValue_ {
return static_cast<DLTensor*>(value_.v_handle);
} else {
if (type_code_ == kNull) return nullptr;
LOG(FATAL) << "Expected "
LOG(FATAL) << "Expect "
<< "DLTensor* or NDArray but get "
<< TypeCode2Str(type_code_);
return nullptr;
}
}
operator NDArray() const {
if (type_code_ == kNull) return NDArray();
if (type_code_ == kNull) return NDArray(ObjectPtr<Object>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer);
return NDArray(static_cast<NDArray::Container*>(value_.v_handle));
}
operator ObjectRef() const {
if (type_code_ == kNull) {
return ObjectRef(ObjectPtr<Object>(nullptr));
}
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
return ObjectRef(
ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
return NDArray(NDArray::FFIDataFromHandle(
static_cast<TVMArrayHandle>(value_.v_handle)));
}
operator Module() const {
if (type_code_ == kNull) {
......@@ -480,23 +475,9 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
return value_.v_ctx;
}
template<typename TNDArray,
typename = typename std::enable_if<
std::is_base_of<NDArray, TNDArray>::value>::type>
TNDArray AsNDArray() const {
if (type_code_ == kNull) return TNDArray(nullptr);
auto *container = static_cast<NDArray::Container*>(value_.v_handle);
CHECK_EQ(container->array_type_code_, array_type_info<TNDArray>::code);
return TNDArray(container);
}
template<typename TObjectRef,
typename = typename std::enable_if<
std::is_class<TObjectRef>::value>::type>
inline bool IsObjectRef() const;
int type_code() const {
return type_code_;
}
/*!
* \brief return handle as specific pointer type.
* \tparam T the data type.
......@@ -506,6 +487,16 @@ class TVMPODValue_ {
T* ptr() const {
return static_cast<T*>(value_.v_handle);
}
// ObjectRef handling
template<typename TObjectRef,
typename = typename std::enable_if<
std::is_base_of<ObjectRef, TObjectRef>::value>::type>
inline bool IsObjectRef() const;
template<typename TObjectRef>
inline TObjectRef AsObjectRef() const;
// ObjectRef Specializations
inline operator tvm::Expr() const;
inline operator tvm::Integer() const;
protected:
friend class TVMArgsSetter;
......@@ -548,9 +539,11 @@ class TVMArgValue : public TVMPODValue_ {
using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator ObjectRef;
using TVMPODValue_::operator Module;
using TVMPODValue_::IsObjectRef;
using TVMPODValue_::AsObjectRef;
using TVMPODValue_::operator tvm::Expr;
using TVMPODValue_::operator tvm::Integer;
// conversion operator.
operator std::string() const {
......@@ -577,6 +570,9 @@ class TVMArgValue : public TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kTVMType);
return value_.v_type;
}
operator DataType() const {
return DataType(operator DLDataType());
}
operator PackedFunc() const {
if (type_code_ == kNull) return PackedFunc();
TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle);
......@@ -589,16 +585,10 @@ class TVMArgValue : public TVMPODValue_ {
const TVMValue& value() const {
return value_;
}
// Deferred extension handler.
template<typename TObjectRef>
inline TObjectRef AsObjectRef() const;
template<typename T,
typename = typename std::enable_if<
std::is_class<T>::value>::type>
std::is_class<T>::value>::type>
inline operator T() const;
inline operator DataType() const;
inline operator tvm::Expr() const;
inline operator tvm::Integer() const;
};
/*!
......@@ -636,9 +626,11 @@ class TVMRetValue : public TVMPODValue_ {
using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator ObjectRef;
using TVMPODValue_::operator Module;
using TVMPODValue_::IsObjectRef;
using TVMPODValue_::AsObjectRef;
using TVMPODValue_::operator tvm::Expr;
using TVMPODValue_::operator tvm::Integer;
TVMRetValue(const TVMRetValue& other) : TVMPODValue_() {
this->Assign(other);
......@@ -660,6 +652,9 @@ class TVMRetValue : public TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kTVMType);
return value_.v_type;
}
operator DataType() const {
return DataType(operator DLDataType());
}
operator PackedFunc() const {
if (type_code_ == kNull) return PackedFunc();
TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle);
......@@ -712,6 +707,9 @@ class TVMRetValue : public TVMPODValue_ {
value_.v_type = t;
return *this;
}
TVMRetValue& operator=(const DataType& other) {
return operator=(other.operator DLDataType());
}
TVMRetValue& operator=(bool value) {
this->SwitchToPOD(kDLInt);
value_.v_int64 = value;
......@@ -726,24 +724,20 @@ class TVMRetValue : public TVMPODValue_ {
return *this;
}
TVMRetValue& operator=(NDArray other) {
this->Clear();
type_code_ = kNDArrayContainer;
value_.v_handle = other.data_;
other.data_ = nullptr;
if (other.data_ != nullptr) {
this->Clear();
type_code_ = kNDArrayContainer;
value_.v_handle = NDArray::FFIGetHandle(other);
ObjectRef::FFIClearAfterMove(&other);
} else {
SwitchToPOD(kNull);
}
return *this;
}
TVMRetValue& operator=(ObjectRef other) {
return operator=(std::move(other.data_));
}
TVMRetValue& operator=(Module m) {
SwitchToObject(kModuleHandle, std::move(m.data_));
return *this;
}
template<typename T>
TVMRetValue& operator=(ObjectPtr<T> other) {
SwitchToObject(kObjectHandle, std::move(other));
return *this;
}
TVMRetValue& operator=(PackedFunc f) {
this->SwitchToClass(kFuncHandle, f);
return *this;
......@@ -760,14 +754,6 @@ class TVMRetValue : public TVMPODValue_ {
this->Assign(other);
return *this;
}
template<typename T,
typename = typename std::enable_if<
extension_type_info<T>::code != 0>::type>
TVMRetValue& operator=(const T& other) {
this->SwitchToClass<T>(
extension_type_info<T>::code, other);
return *this;
}
/*!
* \brief Move the value back to front-end via C API.
* This marks the current container as null.
......@@ -793,16 +779,15 @@ class TVMRetValue : public TVMPODValue_ {
type_code_ != kStr) << "TVMRetValue.value can only be used for POD data";
return value_;
}
// ObjectRef related extenstions: in tvm/packed_func_ext.h
// ObjectRef handling
template<typename TObjectRef,
typename = typename std::enable_if<
std::is_base_of<ObjectRef, TObjectRef>::value>::type>
inline TVMRetValue& operator=(TObjectRef other);
template<typename T,
typename = typename std::enable_if<
std::is_class<T>::value>::type>
inline operator T() const;
template<typename TObjectRef>
inline TObjectRef AsObjectRef() const;
// type related
inline operator DataType() const;
inline TVMRetValue& operator=(const DataType& other);
private:
template<typename T>
......@@ -829,7 +814,10 @@ class TVMRetValue : public TVMPODValue_ {
break;
}
case kObjectHandle: {
*this = other.operator ObjectRef();
// Avoid operator ObjectRef as we already know it is not NDArray/Module
SwitchToObject(
kObjectHandle, GetObjectPtr<Object>(
static_cast<Object*>(other.value_.v_handle)));
break;
}
default: {
......@@ -873,7 +861,7 @@ class TVMRetValue : public TVMPODValue_ {
case kStr: delete ptr<std::string>(); break;
case kFuncHandle: delete ptr<PackedFunc>(); break;
case kNDArrayContainer: {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
NDArray::FFIDecRef(static_cast<TVMArrayHandle>(value_.v_handle));
break;
}
case kModuleHandle: {
......@@ -905,7 +893,7 @@ inline const char* TypeCode2Str(int type_code) {
case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle";
case kNDArrayContainer: return "NDArrayContainer";
case kObjectHandle: return "ObjectCell";
case kObjectHandle: return "Object";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
......@@ -929,6 +917,10 @@ inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*)
return os;
}
inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
return os << dtype.operator DLDataType();
}
#endif
inline std::string TVMType2String(TVMType t) {
......@@ -996,10 +988,6 @@ inline TVMType String2TVMType(std::string s) {
return t;
}
inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
return os << dtype.operator DLDataType();
}
inline TVMArgValue TVMArgs::operator[](int i) const {
CHECK_LT(i, num_args)
<< "not enough argument passed, "
......@@ -1092,50 +1080,31 @@ class TVMArgsSetter {
values_[i].v_type = value;
type_codes_[i] = kTVMType;
}
void operator()(size_t i, DataType dtype) const {
operator()(i, dtype.operator DLDataType());
}
void operator()(size_t i, const char* value) const {
values_[i].v_str = value;
type_codes_[i] = kStr;
}
// setters for container type
// They must be reference(instead of const ref)
// to make sure they are alive in the tuple(instead of getting converted)
void operator()(size_t i, const std::string& value) const { // NOLINT(*)
// setters for container types
void operator()(size_t i, const std::string& value) const {
values_[i].v_str = value.c_str();
type_codes_[i] = kStr;
}
void operator()(size_t i, const TVMByteArray& value) const { // NOLINT(*)
void operator()(size_t i, const TVMByteArray& value) const {
values_[i].v_handle = const_cast<TVMByteArray*>(&value);
type_codes_[i] = kBytes;
}
void operator()(size_t i, const PackedFunc& value) const { // NOLINT(*)
void operator()(size_t i, const PackedFunc& value) const {
values_[i].v_handle = const_cast<PackedFunc*>(&value);
type_codes_[i] = kFuncHandle;
}
template<typename FType>
void operator()(size_t i, const TypedPackedFunc<FType>& value) const { // NOLINT(*)
void operator()(size_t i, const TypedPackedFunc<FType>& value) const {
operator()(i, value.packed());
}
void operator()(size_t i, const Module& value) const { // NOLINT(*)
if (value.defined()) {
values_[i].v_handle = value.data_.data_;
type_codes_[i] = kModuleHandle;
} else {
type_codes_[i] = kNull;
}
}
void operator()(size_t i, const NDArray& value) const { // NOLINT(*)
values_[i].v_handle = value.data_;
type_codes_[i] = kNDArrayContainer;
}
void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*)
if (value.defined()) {
values_[i].v_handle = value.data_.data_;
type_codes_[i] = kObjectHandle;
} else {
type_codes_[i] = kNull;
}
}
void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*)
void operator()(size_t i, const TVMRetValue& value) const {
if (value.type_code() == kStr) {
values_[i].v_str = value.ptr<std::string>()->c_str();
type_codes_[i] = kStr;
......@@ -1145,12 +1114,11 @@ class TVMArgsSetter {
type_codes_[i] = value.type_code();
}
}
// extension
template<typename T,
// ObjectRef handling
template<typename TObjectRef,
typename = typename std::enable_if<
extension_type_info<T>::code != 0>::type>
inline void operator()(size_t i, const T& value) const;
inline void operator()(size_t i, const DataType& t) const;
std::is_base_of<ObjectRef, TObjectRef>::value>::type>
inline void operator()(size_t i, const TObjectRef& value) const;
private:
/*! \brief The values fields */
......@@ -1262,57 +1230,131 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
::run(packed_, std::forward<Args>(args)...);
}
// extension and node type handling
namespace detail {
template<typename T, typename TSrc, bool is_nd>
struct TVMValueCast {
static T Apply(const TSrc* self) {
static_assert(!is_nd, "The default case accepts only non-extensions");
return self->template AsObjectRef<T>();
}
};
template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, true> {
static T Apply(const TSrc* self) {
return self->template AsNDArray<T>();
// ObjectRef related conversion handling
// Object can have three possible type codes:
// kNDArrayContainer, kModuleHandle, kObjectHandle
//
// We use type traits to eliminate un-necessary checks.
template<typename TObjectRef, typename>
inline void TVMArgsSetter::operator()(size_t i, const TObjectRef& value) const {
if (value.defined()) {
Object* ptr = value.data_.data_;
if (std::is_base_of<NDArray, TObjectRef>::value ||
(std::is_base_of<TObjectRef, NDArray>::value &&
ptr->IsInstance<NDArray::ContainerType>())) {
values_[i].v_handle = NDArray::FFIGetHandle(value);
type_codes_[i] = kNDArrayContainer;
} else if (std::is_base_of<Module, TObjectRef>::value ||
(std::is_base_of<TObjectRef, Module>::value &&
ptr->IsInstance<Module::ContainerType>())) {
values_[i].v_handle = ptr;
type_codes_[i] = kModuleHandle;
} else {
values_[i].v_handle = ptr;
type_codes_[i] = kObjectHandle;
}
} else {
type_codes_[i] = kNull;
}
};
} // namespace detail
template<typename T, typename>
inline TVMArgValue::operator T() const {
return detail::
TVMValueCast<T, TVMArgValue,
(array_type_info<T>::code > 0)>
::Apply(this);
}
template<typename T, typename>
inline TVMRetValue::operator T() const {
return detail::
TVMValueCast<T, TVMRetValue,
(array_type_info<T>::code > 0)>
::Apply(this);
template<typename TObjectRef, typename>
inline bool TVMPODValue_::IsObjectRef() const {
using ContainerType = typename TObjectRef::ContainerType;
// NOTE: the following code can be optimized by constant folding.
if (std::is_base_of<NDArray, TObjectRef>::value) {
return type_code_ == kNDArrayContainer &&
TVMArrayHandleToObjectHandle(
static_cast<TVMArrayHandle>(value_.v_handle))->IsInstance<ContainerType>();
}
if (std::is_base_of<Module, TObjectRef>::value) {
return type_code_ == kModuleHandle &&
static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
}
return
(std::is_base_of<TObjectRef, NDArray>::value && type_code_ == kNDArrayContainer) ||
(std::is_base_of<TObjectRef, Module>::value && type_code_ == kModuleHandle) ||
(type_code_ == kObjectHandle &&
ObjectTypeChecker<TObjectRef>::Check(static_cast<Object*>(value_.v_handle)));
}
// PackedFunc support
inline TVMRetValue& TVMRetValue::operator=(const DataType& t) {
return this->operator=(t.operator DLDataType());
template<typename TObjectRef>
inline TObjectRef TVMPODValue_::AsObjectRef() const {
static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
using ContainerType = typename TObjectRef::ContainerType;
if (type_code_ == kNull) return TObjectRef(ObjectPtr<Object>(nullptr));
// NOTE: the following code can be optimized by constant folding.
if (std::is_base_of<NDArray, TObjectRef>::value) {
// Casting to a sub-class of NDArray
TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer);
ObjectPtr<Object> data = NDArray::FFIDataFromHandle(
static_cast<TVMArrayHandle>(value_.v_handle));
CHECK(data->IsInstance<ContainerType>())
<< "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey();
return TObjectRef(data);
}
if (std::is_base_of<Module, TObjectRef>::value) {
// Casting to a sub-class of Module
TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle));
CHECK(data->IsInstance<ContainerType>())
<< "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey();
return TObjectRef(data);
}
if (type_code_ == kObjectHandle) {
// normal object type check.
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expect " << ObjectTypeChecker<TObjectRef>::TypeName()
<< " but get " << ptr->GetTypeKey();
return TObjectRef(GetObjectPtr<Object>(ptr));
} else if (std::is_base_of<TObjectRef, NDArray>::value &&
type_code_ == kNDArrayContainer) {
// Casting to a base class that NDArray can sub-class
ObjectPtr<Object> data = NDArray::FFIDataFromHandle(
static_cast<TVMArrayHandle>(value_.v_handle));
return TObjectRef(data);
} else if (std::is_base_of<TObjectRef, Module>::value &&
type_code_ == kModuleHandle) {
// Casting to a base class that Module can sub-class
return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
} else {
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
return TObjectRef(ObjectPtr<Object>(nullptr));
}
}
inline TVMRetValue::operator DataType() const {
return DataType(operator DLDataType());
template<typename TObjectRef, typename>
inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) {
const Object* ptr = other.get();
if (ptr != nullptr) {
if (std::is_base_of<NDArray, TObjectRef>::value ||
(std::is_base_of<TObjectRef, NDArray>::value &&
ptr->IsInstance<NDArray::ContainerType>())) {
return operator=(NDArray(std::move(other.data_)));
}
if (std::is_base_of<Module, TObjectRef>::value ||
(std::is_base_of<TObjectRef, Module>::value &&
ptr->IsInstance<Module::ContainerType>())) {
return operator=(Module(std::move(other.data_)));
}
SwitchToObject(kObjectHandle, std::move(other.data_));
} else {
SwitchToPOD(kNull);
}
return *this;
}
inline TVMArgValue::operator DataType() const {
return DataType(operator DLDataType());
template<typename T, typename>
inline TVMArgValue::operator T() const {
return AsObjectRef<T>();
}
inline void TVMArgsSetter::operator()(
size_t i, const DataType& t) const {
this->operator()(i, t.operator DLDataType());
template<typename T, typename>
inline TVMRetValue::operator T() const {
return AsObjectRef<T>();
}
inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) {
......
......@@ -43,9 +43,9 @@
#ifndef TVM_RUNTIME_REGISTRY_H_
#define TVM_RUNTIME_REGISTRY_H_
#include <tvm/runtime/packed_func.h>
#include <string>
#include <vector>
#include "packed_func.h"
namespace tvm {
namespace runtime {
......@@ -283,22 +283,9 @@ class Registry {
friend struct Manager;
};
/*! \brief helper macro to supress unused warning */
#if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define TVM_ATTRIBUTE_UNUSED
#endif
#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
#define TVM_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM
#define TVM_TYPE_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::ExtTypeVTable* __mk_ ## TVMT
/*!
* \brief Register a function globally.
* \code
......
......@@ -96,6 +96,7 @@ def config_cython():
"../3rdparty/dmlc-core/include",
"../3rdparty/dlpack/include",
],
extra_compile_args=["-std=c++11"],
library_dirs=library_dirs,
libraries=libraries,
language="c++"))
......
......@@ -20,7 +20,7 @@ from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call, c_str
from ..runtime_ctypes import TVMArrayHandle, TVMNDArrayContainerHandle
from ..runtime_ctypes import TVMArrayHandle
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle
......@@ -110,12 +110,17 @@ class NDArrayBase(object):
def _make_array(handle, is_view, is_container):
global _TVM_ND_CLS
handle = ctypes.cast(handle, TVMArrayHandle)
fcreate = _CLASS_NDARRAY
if is_container and _TVM_ND_CLS:
array_type_info = ctypes.cast(handle, TVMNDArrayContainerHandle).array_type_info.value
if array_type_info > 0:
fcreate = _TVM_ND_CLS[array_type_info]
return fcreate(handle, is_view)
if is_container:
tindex = ctypes.c_uint()
check_call(_LIB.TVMArrayGetTypeIndex(handle, ctypes.byref(tindex)))
cls = _TVM_ND_CLS.get(tindex.value, _CLASS_NDARRAY)
else:
cls = _CLASS_NDARRAY
ret = cls.__new__(cls)
ret.handle = handle
ret.is_view = is_view
return ret
_TVM_COMPATS = ()
......@@ -129,9 +134,9 @@ def _reg_extension(cls, fcreate):
_TVM_ND_CLS = {}
def _reg_ndarray(cls, fcreate):
def _register_ndarray(index, cls):
global _TVM_ND_CLS
_TVM_ND_CLS[cls._array_type_code] = fcreate
_TVM_ND_CLS[index] = cls
_CLASS_NDARRAY = None
......
......@@ -21,7 +21,7 @@ from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call
from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from ..node_generic import _set_class_node_base
from .ndarray import _register_ndarray, NDArrayBase
ObjectHandle = ctypes.c_void_p
......@@ -39,6 +39,9 @@ def _set_class_node(node_class):
def _register_object(index, cls):
"""register object class"""
if issubclass(cls, NDArrayBase):
_register_ndarray(index, cls)
return
OBJECT_TYPE[index] = cls
......@@ -91,6 +94,3 @@ class ObjectBase(object):
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
self.handle = handle
_set_class_node_base(ObjectBase)
......@@ -19,7 +19,7 @@ from ..base import get_last_ffi_error
from libcpp.vector cimport vector
from cpython.version cimport PY_MAJOR_VERSION
from cpython cimport pycapsule
from libc.stdint cimport int32_t, int64_t, uint64_t, uint8_t, uint16_t
from libc.stdint cimport int32_t, int64_t, uint64_t, uint32_t, uint8_t, uint16_t
import ctypes
cdef enum TVMTypeCode:
......@@ -78,14 +78,11 @@ ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle
ctypedef void* ObjectHandle
ctypedef struct TVMObject:
uint32_t type_index_
int32_t ref_counter_
void (*deleter_)(TVMObject* self)
ctypedef struct TVMNDArrayContainer:
DLTensor dl_tensor
void* manager_ctx
void (*deleter)(DLManagedTensor* self)
int32_t array_type_info
ctypedef TVMNDArrayContainer* TVMNDArrayContainerHandle
ctypedef int (*TVMPackedCFunc)(
TVMValue* args,
......
......@@ -100,17 +100,34 @@ cdef class NDArrayBase:
return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter)
# Import limited object-related function from C++ side to improve the speed
# NOTE: can only use POD-C compatible object in FFI.
cdef extern from "tvm/runtime/ndarray.h" namespace "tvm::runtime":
cdef void* TVMArrayHandleToObjectHandle(DLTensorHandle handle)
cdef c_make_array(void* chandle, is_view, is_container):
global _TVM_ND_CLS
cdef int32_t array_type_info
fcreate = _CLASS_NDARRAY
if is_container and len(_TVM_ND_CLS) > 0:
array_type_info = (<TVMNDArrayContainerHandle>chandle).array_type_info
if array_type_info > 0:
fcreate = _TVM_ND_CLS[array_type_info]
ret = fcreate(None, is_view)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
return ret
if is_container:
tindex = (
<TVMObject*>TVMArrayHandleToObjectHandle(<DLTensorHandle>chandle)).type_index_
if tindex < len(_TVM_ND_CLS):
cls = _TVM_ND_CLS[tindex]
if cls is not None:
ret = cls.__new__(cls)
else:
ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY)
else:
ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
(<NDArrayBase>ret).c_is_view = <int>is_view
return ret
else:
ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
(<NDArrayBase>ret).c_is_view = <int>is_view
return ret
cdef _TVM_COMPATS = ()
......@@ -123,11 +140,16 @@ def _reg_extension(cls, fcreate):
if fcreate:
_TVM_EXT_RET[cls._tvm_tcode] = fcreate
cdef _TVM_ND_CLS = {}
cdef list _TVM_ND_CLS = []
def _reg_ndarray(cls, fcreate):
cdef _register_ndarray(int index, object cls):
"""register object class"""
global _TVM_ND_CLS
_TVM_ND_CLS[cls._array_type_code] = fcreate
while len(_TVM_ND_CLS) <= index:
_TVM_ND_CLS.append(None)
_TVM_ND_CLS[index] = cls
def _make_array(handle, is_view, is_container):
cdef unsigned long long ptr
......
......@@ -16,12 +16,15 @@
# under the License.
"""Maps object type to its constructor"""
from ..node_generic import _set_class_node_base
OBJECT_TYPE = []
cdef list OBJECT_TYPE = []
def _register_object(int index, object cls):
"""register object class"""
if issubclass(cls, NDArrayBase):
_register_ndarray(index, cls)
return
global OBJECT_TYPE
while len(OBJECT_TYPE) <= index:
OBJECT_TYPE.append(None)
OBJECT_TYPE[index] = cls
......@@ -31,14 +34,13 @@ cdef inline object make_ret_object(void* chandle):
global OBJECT_TYPE
global _CLASS_NODE
cdef unsigned tindex
cdef list object_type
cdef object cls
cdef object handle
object_type = OBJECT_TYPE
handle = ctypes_handle(chandle)
CALL(TVMObjectGetTypeIndex(chandle, &tindex))
if tindex < len(object_type):
cls = object_type[tindex]
if tindex < len(OBJECT_TYPE):
cls = OBJECT_TYPE[tindex]
if cls is not None:
obj = cls.__new__(cls)
else:
......@@ -99,6 +101,3 @@ cdef class ObjectBase:
(<FunctionBase>fconstructor).chandle,
kObjectHandle, args, &chandle)
self.chandle = chandle
_set_class_node_base(ObjectBase)
......@@ -22,6 +22,7 @@ from __future__ import absolute_import
import sys
import ctypes
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
from .node_generic import _set_class_objects
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
......@@ -32,15 +33,21 @@ try:
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_function, _set_class_module
from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import NDArrayBase as _NDArrayBase
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import convert_to_tvm_func
else:
from ._cy2.core import _set_class_function, _set_class_module
from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import NDArrayBase as _NDArrayBase
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import convert_to_tvm_func
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_function, _set_class_module
from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
from ._ctypes.object import ObjectBase as _ObjectBase
from ._ctypes.function import convert_to_tvm_func
FunctionHandle = ctypes.c_void_p
......@@ -325,3 +332,4 @@ def _init_api_prefix(module_name, prefix):
setattr(target_module, ff.__name__, ff)
_set_class_function(Function)
_set_class_objects((_ObjectBase, _NDArrayBase, ModuleBase))
......@@ -35,16 +35,16 @@ try:
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack
from ._cy3.core import NDArrayBase as _NDArrayBase
from ._cy3.core import _reg_extension, _reg_ndarray
from ._cy3.core import _reg_extension
else:
from ._cy2.core import _set_class_ndarray, _make_array, _from_dlpack
from ._cy2.core import NDArrayBase as _NDArrayBase
from ._cy2.core import _reg_extension, _reg_ndarray
from ._cy2.core import _reg_extension
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
from ._ctypes.ndarray import _reg_extension, _reg_ndarray
from ._ctypes.ndarray import _reg_extension
def context(dev_type, dev_id=0):
......@@ -348,13 +348,8 @@ def register_extension(cls, fcreate=None):
def _tvm_handle(self):
return self.handle.value
"""
if issubclass(cls, _NDArrayBase):
assert fcreate is not None
assert hasattr(cls, "_array_type_code")
_reg_ndarray(cls, fcreate)
else:
assert hasattr(cls, "_tvm_tcode")
if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin")
_reg_extension(cls, fcreate)
assert hasattr(cls, "_tvm_tcode")
if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin")
_reg_extension(cls, fcreate)
return cls
......@@ -23,11 +23,11 @@ from .. import _api_internal
from .base import string_types
# Node base class
_CLASS_NODE_BASE = None
_CLASS_OBJECTS = None
def _set_class_node_base(cls):
global _CLASS_NODE_BASE
_CLASS_NODE_BASE = cls
def _set_class_objects(cls):
global _CLASS_OBJECTS
_CLASS_OBJECTS = cls
def _scalar_type_inference(value):
......@@ -67,7 +67,7 @@ def convert_to_node(value):
node : Node
The corresponding node value.
"""
if isinstance(value, _CLASS_NODE_BASE):
if isinstance(value, _CLASS_OBJECTS):
return value
if isinstance(value, bool):
return const(value, 'uint1x1')
......@@ -81,7 +81,7 @@ def convert_to_node(value):
if isinstance(value, dict):
vlist = []
for item in value.items():
if (not isinstance(item[0], _CLASS_NODE_BASE) and
if (not isinstance(item[0], _CLASS_OBJECTS) and
not isinstance(item[0], string_types)):
raise ValueError("key of map must already been a container type")
vlist.append(item[0])
......
......@@ -271,12 +271,3 @@ class TVMArray(ctypes.Structure):
("byte_offset", ctypes.c_uint64)]
TVMArrayHandle = ctypes.POINTER(TVMArray)
class TVMNDArrayContainer(ctypes.Structure):
"""TVM NDArray::Container"""
_fields_ = [("dl_tensor", TVMArray),
("manager_ctx", ctypes.c_void_p),
("deleter", ctypes.c_void_p),
("array_type_info", ctypes.c_int32)]
TVMNDArrayContainerHandle = ctypes.POINTER(TVMNDArrayContainer)
......@@ -27,7 +27,10 @@ from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty, from_dlpack
from ._ffi.ndarray import _set_class_ndarray
from ._ffi.ndarray import register_extension
from ._ffi.object import register_object
@register_object
class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime.
......
......@@ -67,7 +67,7 @@ TVM_REGISTER_API("_Array")
}
auto node = make_node<ArrayNode>();
node->data = std::move(data);
*ret = runtime::ObjectRef(node);
*ret = Array<ObjectRef>(node);
});
TVM_REGISTER_API("_ArrayGetItem")
......@@ -100,28 +100,28 @@ TVM_REGISTER_API("_Map")
for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kStr)
<< "key of str map need to be str";
CHECK(args[i + 1].type_code() == kObjectHandle)
CHECK(args[i + 1].IsObjectRef<ObjectRef>())
<< "value of the map to be NodeRef";
data.emplace(std::make_pair(args[i].operator std::string(),
args[i + 1].operator ObjectRef()));
}
auto node = make_node<StrMapNode>();
node->data = std::move(data);
*ret = node;
*ret = Map<ObjectRef, ObjectRef>(node);
} else {
// Container node.
MapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kObjectHandle)
<< "key of str map need to be str";
CHECK(args[i + 1].type_code() == kObjectHandle)
CHECK(args[i].IsObjectRef<ObjectRef>())
<< "key of str map need to be object";
CHECK(args[i + 1].IsObjectRef<ObjectRef>())
<< "value of map to be NodeRef";
data.emplace(std::make_pair(args[i].operator ObjectRef(),
args[i + 1].operator ObjectRef()));
}
auto node = make_node<MapNode>();
node->data = std::move(data);
*ret = node;
*ret = Map<ObjectRef, ObjectRef>(node);
}
});
......@@ -191,7 +191,7 @@ TVM_REGISTER_API("_MapItems")
rkvs->data.push_back(kv.first);
rkvs->data.push_back(kv.second);
}
*ret = rkvs;
*ret = Array<ObjectRef>(rkvs);
} else {
auto* n = static_cast<const StrMapNode*>(ptr);
auto rkvs = make_node<ArrayNode>();
......@@ -199,7 +199,7 @@ TVM_REGISTER_API("_MapItems")
rkvs->data.push_back(ir::StringImm::make(kv.first));
rkvs->data.push_back(kv.second);
}
*ret = rkvs;
*ret = Array<ObjectRef>(rkvs);
}
});
......
......@@ -27,8 +27,12 @@
#include <tvm/runtime/device_api.h>
#include "runtime_base.h"
// deleter for arrays used by DLPack exporter
extern "C" void NDArrayDLPackDeleter(DLManagedTensor* tensor);
extern "C" {
// C-mangled dlpack deleter.
static void TVMNDArrayDLPackDeleter(DLManagedTensor* tensor);
// helper function to get NDArray's type index, only used by ctypes.
TVM_DLL int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex);
}
namespace tvm {
namespace runtime {
......@@ -53,8 +57,8 @@ inline size_t GetDataAlignment(const DLTensor& arr) {
struct NDArray::Internal {
// Default deleter for the container
static void DefaultDeleter(NDArray::Container* ptr) {
using tvm::runtime::NDArray;
static void DefaultDeleter(Object* ptr_obj) {
auto* ptr = static_cast<NDArray::Container*>(ptr_obj);
if (ptr->manager_ctx != nullptr) {
static_cast<NDArray::Container*>(ptr->manager_ctx)->DecRef();
} else if (ptr->dl_tensor.data != nullptr) {
......@@ -68,7 +72,8 @@ struct NDArray::Internal {
// that are not allocated inside of TVM.
// This enables us to create NDArray from memory allocated by other
// frameworks that are DLPack compatible
static void DLPackDeleter(NDArray::Container* ptr) {
static void DLPackDeleter(Object* ptr_obj) {
auto* ptr = static_cast<NDArray::Container*>(ptr_obj);
DLManagedTensor* tensor = static_cast<DLManagedTensor*>(ptr->manager_ctx);
if (tensor->deleter != nullptr) {
(*tensor->deleter)(tensor);
......@@ -81,12 +86,13 @@ struct NDArray::Internal {
DLDataType dtype,
DLContext ctx) {
VerifyDataType(dtype);
// critical zone
// critical zone: construct header
NDArray::Container* data = new NDArray::Container();
data->deleter = DefaultDeleter;
NDArray ret(data);
ret.data_ = data;
data->SetDeleter(DefaultDeleter);
// RAII now in effect
NDArray ret(GetObjectPtr<Object>(data));
// setup shape
data->shape_ = std::move(shape);
data->dl_tensor.shape = dmlc::BeginPtr(data->shape_);
......@@ -98,45 +104,57 @@ struct NDArray::Internal {
return ret;
}
// Implementation of API function
static DLTensor* MoveAsDLTensor(NDArray arr) {
DLTensor* tensor = const_cast<DLTensor*>(arr.operator->());
CHECK(reinterpret_cast<DLTensor*>(arr.data_) == tensor);
arr.data_ = nullptr;
return tensor;
static DLTensor* MoveToFFIHandle(NDArray arr) {
DLTensor* handle = NDArray::FFIGetHandle(arr);
ObjectRef::FFIClearAfterMove(&arr);
return handle;
}
static void FFIDecRef(TVMArrayHandle tensor) {
NDArray::FFIDecRef(tensor);
}
// Container to DLManagedTensor
static DLManagedTensor* ToDLPack(TVMArrayHandle handle) {
auto* from = static_cast<NDArray::Container*>(
reinterpret_cast<NDArray::ContainerBase*>(handle));
return ToDLPack(from);
}
static DLManagedTensor* ToDLPack(NDArray::Container* from) {
CHECK(from != nullptr);
DLManagedTensor* ret = new DLManagedTensor();
ret->dl_tensor = from->dl_tensor;
ret->manager_ctx = from;
from->IncRef();
ret->deleter = NDArrayDLPackDeleter;
ret->deleter = TVMNDArrayDLPackDeleter;
return ret;
}
// Delete dlpack object.
static void NDArrayDLPackDeleter(DLManagedTensor* tensor) {
static_cast<NDArray::Container*>(tensor->manager_ctx)->DecRef();
delete tensor;
}
};
NDArray NDArray::CreateView(std::vector<int64_t> shape,
DLDataType dtype) {
NDArray NDArray::CreateView(std::vector<int64_t> shape, DLDataType dtype) {
CHECK(data_ != nullptr);
CHECK(data_->dl_tensor.strides == nullptr)
CHECK(get_mutable()->dl_tensor.strides == nullptr)
<< "Can only create view for compact tensor";
NDArray ret = Internal::Create(shape, dtype, data_->dl_tensor.ctx);
ret.data_->dl_tensor.byte_offset =
this->data_->dl_tensor.byte_offset;
size_t curr_size = GetDataSize(this->data_->dl_tensor);
size_t view_size = GetDataSize(ret.data_->dl_tensor);
NDArray ret = Internal::Create(shape, dtype, get_mutable()->dl_tensor.ctx);
ret.get_mutable()->dl_tensor.byte_offset =
this->get_mutable()->dl_tensor.byte_offset;
size_t curr_size = GetDataSize(this->get_mutable()->dl_tensor);
size_t view_size = GetDataSize(ret.get_mutable()->dl_tensor);
CHECK_LE(view_size, curr_size)
<< "Tries to create a view that has bigger memory than current one";
// increase ref count
this->data_->IncRef();
ret.data_->manager_ctx = this->data_;
ret.data_->dl_tensor.data = this->data_->dl_tensor.data;
get_mutable()->IncRef();
ret.get_mutable()->manager_ctx = get_mutable();
ret.get_mutable()->dl_tensor.data = get_mutable()->dl_tensor.data;
return ret;
}
DLManagedTensor* NDArray::ToDLPack() const {
return Internal::ToDLPack(data_);
return Internal::ToDLPack(get_mutable());
}
NDArray NDArray::Empty(std::vector<int64_t> shape,
......@@ -144,9 +162,9 @@ NDArray NDArray::Empty(std::vector<int64_t> shape,
DLContext ctx) {
NDArray ret = Internal::Create(shape, dtype, ctx);
// setup memory content
size_t size = GetDataSize(ret.data_->dl_tensor);
size_t alignment = GetDataAlignment(ret.data_->dl_tensor);
ret.data_->dl_tensor.data =
size_t size = GetDataSize(ret.get_mutable()->dl_tensor);
size_t alignment = GetDataAlignment(ret.get_mutable()->dl_tensor);
ret.get_mutable()->dl_tensor.data =
DeviceAPI::Get(ret->ctx)->AllocDataSpace(
ret->ctx, size, alignment, ret->dtype);
return ret;
......@@ -154,10 +172,12 @@ NDArray NDArray::Empty(std::vector<int64_t> shape,
NDArray NDArray::FromDLPack(DLManagedTensor* tensor) {
NDArray::Container* data = new NDArray::Container();
data->deleter = Internal::DLPackDeleter;
// construct header
data->SetDeleter(Internal::DLPackDeleter);
// fill up content.
data->manager_ctx = tensor;
data->dl_tensor = tensor->dl_tensor;
return NDArray(data);
return NDArray(GetObjectPtr<Object>(data));
}
void NDArray::CopyFromTo(const DLTensor* from,
......@@ -184,17 +204,24 @@ void NDArray::CopyFromTo(const DLTensor* from,
}
std::vector<int64_t> NDArray::Shape() const {
return data_->shape_;
return get_mutable()->shape_;
}
TVM_REGISTER_OBJECT_TYPE(NDArray::Container);
} // namespace runtime
} // namespace tvm
using namespace tvm::runtime;
void NDArrayDLPackDeleter(DLManagedTensor* tensor) {
static_cast<NDArray::Container*>(tensor->manager_ctx)->DecRef();
delete tensor;
void TVMNDArrayDLPackDeleter(DLManagedTensor* tensor) {
NDArray::Internal::NDArrayDLPackDeleter(tensor);
}
int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex) {
API_BEGIN();
*out_tindex = TVMArrayHandleToObjectHandle(handle)->type_index();
API_END();
}
int TVMArrayAlloc(const tvm_index_t* shape,
......@@ -213,14 +240,14 @@ int TVMArrayAlloc(const tvm_index_t* shape,
DLContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
*out = NDArray::Internal::MoveAsDLTensor(
*out = NDArray::Internal::MoveToFFIHandle(
NDArray::Empty(std::vector<int64_t>(shape, shape + ndim), dtype, ctx));
API_END();
}
int TVMArrayFree(TVMArrayHandle handle) {
API_BEGIN();
reinterpret_cast<NDArray::Container*>(handle)->DecRef();
NDArray::Internal::FFIDecRef(handle);
API_END();
}
......@@ -235,14 +262,14 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
int TVMArrayFromDLPack(DLManagedTensor* from,
TVMArrayHandle* out) {
API_BEGIN();
*out = NDArray::Internal::MoveAsDLTensor(NDArray::FromDLPack(from));
*out = NDArray::Internal::MoveToFFIHandle(NDArray::FromDLPack(from));
API_END();
}
int TVMArrayToDLPack(TVMArrayHandle from,
DLManagedTensor** out) {
API_BEGIN();
*out = NDArray::Internal::ToDLPack(reinterpret_cast<NDArray::Container*>(from));
*out = NDArray::Internal::ToDLPack(from);
API_END();
}
......
......@@ -59,7 +59,8 @@ class RPCWrappedFunc {
const TVMArgValue& arg);
// deleter of RPC remote array
static void RemoteNDArrayDeleter(NDArray::Container* ptr) {
static void RemoteNDArrayDeleter(Object* obj) {
auto* ptr = static_cast<NDArray::Container*>(obj);
RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data);
space->sess->CallRemote(RPCCode::kNDArrayFree, ptr->manager_ctx);
delete space;
......@@ -71,12 +72,12 @@ class RPCWrappedFunc {
void* nd_handle) {
NDArray::Container* data = new NDArray::Container();
data->manager_ctx = nd_handle;
data->deleter = RemoteNDArrayDeleter;
data->SetDeleter(RemoteNDArrayDeleter);
RemoteSpace* space = new RemoteSpace();
space->sess = sess;
space->data = tensor->data;
data->dl_tensor.data = space;
NDArray ret(data);
NDArray ret(GetObjectPtr<Object>(data));
// RAII now in effect
data->shape_ = std::vector<int64_t>(
tensor->shape, tensor->shape + tensor->ndim);
......
......@@ -787,9 +787,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
TVMValue ret_value_pack[2];
int ret_tcode_pack[2];
rv.MoveToCHost(&ret_value_pack[0], &ret_tcode_pack[0]);
NDArray::Container* nd = static_cast<NDArray::Container*>(ret_value_pack[0].v_handle);
ret_value_pack[1].v_handle = nd;
ret_value_pack[1].v_handle = ret_value_pack[0].v_handle;
ret_tcode_pack[1] = kHandle;
SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, false, nullptr, true);
} else {
......@@ -1190,7 +1188,8 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) {
void* handle = args[0];
static_cast<NDArray::Container*>(handle)->DecRef();
static_cast<NDArray::Container*>(
reinterpret_cast<NDArray::ContainerBase*>(handle))->DecRef();
}
void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) {
......
......@@ -31,7 +31,8 @@ namespace tvm {
namespace runtime {
namespace vm {
static void BufferDeleter(NDArray::Container* ptr) {
static void BufferDeleter(Object* obj) {
auto* ptr = static_cast<NDArray::Container*>(obj);
CHECK(ptr->manager_ctx != nullptr);
Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx);
MemoryManager::Global()->GetAllocator(buffer->ctx)->
......@@ -40,7 +41,8 @@ static void BufferDeleter(NDArray::Container* ptr) {
delete ptr;
}
void StorageObj::Deleter(NDArray::Container* ptr) {
void StorageObj::Deleter(Object* obj) {
auto* ptr = static_cast<NDArray::Container*>(obj);
// When invoking AllocNDArray we don't own the underlying allocation
// and should not delete the buffer, but instead let it be reclaimed
// by the storage object's destructor.
......@@ -77,16 +79,23 @@ NDArray StorageObj::AllocNDArray(size_t offset, std::vector<int64_t> shape, DLDa
// TODO(@jroesch): generalize later to non-overlapping allocations.
CHECK_EQ(offset, 0u);
VerifyDataType(dtype);
// crtical zone: allocate header, cannot throw
NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, this->buffer.ctx);
container->deleter = StorageObj::Deleter;
container->SetDeleter(StorageObj::Deleter);
size_t needed_size = GetDataSize(container->dl_tensor);
// TODO(@jroesch): generalize later to non-overlapping allocations.
CHECK(needed_size == this->buffer.size)
<< "size mistmatch required " << needed_size << " found " << this->buffer.size;
this->IncRef();
container->manager_ctx = reinterpret_cast<void*>(this);
container->dl_tensor.data = this->buffer.data;
return NDArray(container);
NDArray ret(GetObjectPtr<Object>(container));
// RAII in effect, now run the check.
// TODO(@jroesch): generalize later to non-overlapping allocations.
CHECK(needed_size == this->buffer.size)
<< "size mistmatch required " << needed_size << " found " << this->buffer.size;
return ret;
}
MemoryManager* MemoryManager::Global() {
......@@ -108,14 +117,14 @@ Allocator* MemoryManager::GetAllocator(TVMContext ctx) {
NDArray Allocator::Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) {
VerifyDataType(dtype);
NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, ctx);
container->deleter = BufferDeleter;
container->SetDeleter(BufferDeleter);
size_t size = GetDataSize(container->dl_tensor);
size_t alignment = GetDataAlignment(container->dl_tensor);
Buffer *buffer = new Buffer;
*buffer = this->Alloc(size, alignment, dtype);
container->manager_ctx = reinterpret_cast<void*>(buffer);
container->dl_tensor.data = buffer->data;
return NDArray(container);
return NDArray(GetObjectPtr<Object>(container));
}
} // namespace vm
......
......@@ -120,7 +120,7 @@ class StorageObj : public Object {
DLDataType dtype);
/*! \brief The deleter for an NDArray when allocated from underlying storage. */
static void Deleter(NDArray::Container* ptr);
static void Deleter(Object* ptr);
~StorageObj() {
auto alloc = MemoryManager::Global()->GetAllocator(buffer.ctx);
......
......@@ -22,6 +22,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/registry.h>
#include <tvm/ir.h>
TEST(PackedFunc, Basic) {
......@@ -178,6 +179,69 @@ TEST(TypedPackedFunc, HighOrder) {
CHECK_EQ(f1(3), 4);
}
TEST(PackedFunc, ObjectConversion) {
using namespace tvm;
using namespace tvm::runtime;
TVMRetValue rv;
auto x = NDArray::Empty(
{}, String2TVMType("float32"),
TVMContext{kDLCPU, 0});
// assign null
rv = ObjectRef();
CHECK_EQ(rv.type_code(), kNull);
// Can assign NDArray to ret type
rv = x;
CHECK_EQ(rv.type_code(), kNDArrayContainer);
// Even if we assign base type it still shows as NDArray
rv = ObjectRef(x);
CHECK_EQ(rv.type_code(), kNDArrayContainer);
// Check convert back
CHECK(rv.operator NDArray().same_as(x));
CHECK(rv.operator ObjectRef().same_as(x));
CHECK(!rv.IsObjectRef<Expr>());
auto pf1 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args[0].type_code(), kNDArrayContainer);
CHECK(args[0].operator NDArray().same_as(x));
CHECK(args[0].operator ObjectRef().same_as(x));
CHECK(args[1].operator ObjectRef().get() == nullptr);
CHECK(args[1].operator NDArray().get() == nullptr);
CHECK(args[1].operator Module().get() == nullptr);
CHECK(args[1].operator Array<NDArray>().get() == nullptr);
CHECK(!args[0].IsObjectRef<Expr>());
});
pf1(x, ObjectRef());
pf1(ObjectRef(x), NDArray());
// testcases for modules
auto* pf = tvm::runtime::Registry::Get("module.source_module_create");
CHECK(pf != nullptr);
Module m = (*pf)("", "xyz");
rv = m;
CHECK_EQ(rv.type_code(), kModuleHandle);
// Even if we assign base type it still shows as NDArray
rv = ObjectRef(m);
CHECK_EQ(rv.type_code(), kModuleHandle);
// Check convert back
CHECK(rv.operator Module().same_as(m));
CHECK(rv.operator ObjectRef().same_as(m));
CHECK(!rv.IsObjectRef<NDArray>());
auto pf2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args[0].type_code(), kModuleHandle);
CHECK(args[0].operator Module().same_as(m));
CHECK(args[0].operator ObjectRef().same_as(m));
CHECK(args[1].operator ObjectRef().get() == nullptr);
CHECK(args[1].operator NDArray().get() == nullptr);
CHECK(args[1].operator Module().get() == nullptr);
CHECK(!args[0].IsObjectRef<Expr>());
});
pf2(m, ObjectRef());
pf2(ObjectRef(m), Module());
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
......
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np
def test_array():
a = tvm.convert([1,2,3])
......@@ -71,6 +72,14 @@ def test_in_container():
assert tvm.make.StringImm('a') in arr
assert 'd' not in arr
def test_ndarray_container():
x = tvm.nd.array([1,2,3])
arr = tvm.convert([x, x])
assert arr[0].same_as(x)
assert arr[1].same_as(x)
assert isinstance(arr[0], tvm.nd.NDArray)
if __name__ == "__main__":
test_str_map()
test_array()
......@@ -78,3 +87,4 @@ if __name__ == "__main__":
test_array_save_load_json()
test_map_save_load_json()
test_in_container()
test_ndarray_container()
......@@ -32,7 +32,7 @@ def gen_engine_header():
#include <vector>
class Engine {
};
#endif
'''
header_file = header_file_dir_path.relpath("gcc_engine.h")
......@@ -45,7 +45,7 @@ def generate_engine_module():
#include <tvm/runtime/c_runtime_api.h>
#include <dlpack/dlpack.h>
#include "gcc_engine.h"
extern "C" void gcc_1_(float* gcc_input4, float* gcc_input5,
float* gcc_input6, float* gcc_input7, float* out) {
Engine engine;
......
......@@ -35,7 +35,8 @@ rm -rf lib
make
cd ../..
python3 -m pytest -v apps/extension/tests
TVM_FFI=cython python3 -m pytest -v apps/extension/tests
TVM_FFI=ctypes python3 -m pytest -v apps/extension/tests
TVM_FFI=ctypes python3 -m pytest -v tests/python/integration
TVM_FFI=ctypes python3 -m pytest -v tests/python/contrib
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment