Commit a0bd3786 by Tianqi Chen Committed by Zhi

[RFC][RUNTIME] Introduce new object protocol. (#4115)

* [RUNTIME] Introduce new object protocol.

This PR introduces a new object protocol to unify the node and object systems.
We also updated the existing runtime::vm code to make use of the new system.

The update to the node system will be done in a follow-up PR.

Other changes:

- Remove object-related code in the JSON serializer, as that logic was incomplete
  and we have a separate serializer for the VM; we can revisit this later.

* address review comments

* Fix the child slot logic
parent 68472596
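For orientation, a minimal sketch (not part of the diff itself) of what the new protocol looks like to a user. The type names here are hypothetical; the macros and helpers are defined in the headers changed below:

#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>

using namespace tvm::runtime;

// Hypothetical container type following the new protocol.
class MyObj : public Object {
 public:
  int value;
  // object properties required by the protocol
  static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
  static constexpr const char* _type_key = "test.MyObj";
  TVM_DECLARE_FINAL_OBJECT_INFO(MyObj, Object);
};

// The matching reference type.
class MyRef : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(MyRef, ObjectRef, MyObj);
};

// In a cc file: populate the runtime type table.
TVM_REGISTER_OBJECT_TYPE(MyObj);

void Demo() {
  MyRef ref(make_object<MyObj>());     // allocation goes through memory.h
  const MyObj* ptr = ref.as<MyObj>();  // checked downcast; nullptr on failure
  CHECK(ptr != nullptr);               // CHECK comes from dmlc logging
}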
......@@ -70,7 +70,7 @@ cpplint:
python3 3rdparty/dmlc-core/scripts/lint.py vta cpp vta/include vta/src
python3 3rdparty/dmlc-core/scripts/lint.py topi cpp topi/include;
python3 3rdparty/dmlc-core/scripts/lint.py nnvm cpp nnvm/include nnvm/src;
python3 3rdparty/dmlc-core/scripts/lint.py tvm cpp include src verilog\
python3 3rdparty/dmlc-core/scripts/lint.py tvm cpp include src \
examples/extension/src examples/graph_executor/src
pylint:
......
......@@ -42,7 +42,7 @@ namespace runtime {
// forward declaration
class NDArray;
// forward declaration
class Object;
class ObjectRef;
} // namespace runtime
/*!
......@@ -63,7 +63,7 @@ class TVM_DLL AttrVisitor {
virtual void Visit(const char* key, DataType* value) = 0;
virtual void Visit(const char* key, NodeRef* value) = 0;
virtual void Visit(const char* key, runtime::NDArray* value) = 0;
virtual void Visit(const char* key, runtime::Object* value) = 0;
virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/runtime/memory.h
* \brief Runtime memory management.
*/
#ifndef TVM_RUNTIME_MEMORY_H_
#define TVM_RUNTIME_MEMORY_H_
#include <utility>
#include <type_traits>
#include "object.h"
namespace tvm {
namespace runtime {
/*!
* \brief Allocate an object using default allocator.
* \param args arguments to the constructor.
* \tparam T the node type.
* \return The NodePtr to the allocated object.
*/
template<typename T, typename... Args>
inline ObjectPtr<T> make_object(Args&&... args);
// Detail implementations after this
//
// The current design allows swapping the
// allocator pattern when necessary.
//
// Possible future allocator optimizations:
// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr)
// - Thread-local object pools: one pool per size and alignment requirement.
// - Can specialize by type of object to give the specific allocator to each object.
/*!
* \brief Base class of object allocators that implements make.
* Use curiously recurring template pattern.
*
* \tparam Derived The derived class.
*/
template<typename Derived>
class ObjAllocatorBase {
public:
/*!
* \tparam T The type to be allocated.
* \tparam Args The constructor signature.
* \param args The arguments.
*/
template<typename T, typename... Args>
inline ObjectPtr<T> make(Args&&... args) {
using Handler = typename Derived::template Handler<T>;
static_assert(std::is_base_of<Object, T>::value,
"make_node can only be used to create NodeBase");
T* ptr = Handler::New(static_cast<Derived*>(this),
std::forward<Args>(args)...);
ptr->type_index_ = T::type_index();
ptr->deleter_ = Handler::Deleter();
return ObjectPtr<T>(ptr);
}
};
// Simple allocator that uses new/delete.
class SimpleObjAllocator :
public ObjAllocatorBase<SimpleObjAllocator> {
public:
template<typename T>
class Handler {
public:
template<typename... Args>
static T* New(SimpleObjAllocator*, Args&&... args) {
// NOTE: the first argument is not needed for SimpleObjAllocator
// It is reserved for special allocators that needs to recycle
// the object to itself (e.g. in the case of object pool).
//
// In the case of an object pool, an allocator needs to create
// a special chunk memory that hides reference to the allocator
// and call allocator's release function in the deleter.
return new T(std::forward<Args>(args)...);
}
static Object::FDeleter Deleter() {
return Deleter_;
}
private:
static void Deleter_(Object* ptr) {
delete static_cast<T*>(ptr);
}
};
};
template<typename T, typename... Args>
inline ObjectPtr<T> make_object(Args&&... args) {
return SimpleObjAllocator().make<T>(std::forward<Args>(args)...);
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_MEMORY_H_
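The comments above note that the allocator pattern can be swapped. As an illustration only (not part of this PR), a hypothetical allocator that follows the same Handler contract as SimpleObjAllocator while counting live objects might look like this, assuming it lives in namespace tvm::runtime:

#include <atomic>

class CountingObjAllocator :
      public ObjAllocatorBase<CountingObjAllocator> {
 public:
  // Number of objects currently alive through this allocator (hypothetical).
  static std::atomic<int64_t> live_count;

  template<typename T>
  class Handler {
   public:
    template<typename... Args>
    static T* New(CountingObjAllocator*, Args&&... args) {
      ++live_count;
      return new T(std::forward<Args>(args)...);
    }
    static Object::FDeleter Deleter() {
      return Deleter_;
    }

   private:
    static void Deleter_(Object* ptr) {
      --live_count;
      delete static_cast<T*>(ptr);
    }
  };
};

// In a cc file:
// std::atomic<int64_t> CountingObjAllocator::live_count{0};
//
// Usage: CountingObjAllocator().make<SomeObj>(args...) returns an ObjectPtr
// whose deleter decrements live_count when the last reference goes away.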
......@@ -16,117 +16,249 @@
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file tvm/runtime/object.h
* \brief A managed object in the TVM runtime.
*/
#ifndef TVM_RUNTIME_OBJECT_H_
#define TVM_RUNTIME_OBJECT_H_
#include <tvm/runtime/ndarray.h>
#include <memory>
#include <type_traits>
#include <string>
#include <utility>
#include <vector>
#include "c_runtime_api.h"
/*!
* \brief Whether or not use atomic reference counter.
* If the reference counter is not atomic,
* an object cannot be owned by multiple threads.
* We can, however, move an object across threads
*/
#ifndef TVM_OBJECT_ATOMIC_REF_COUNTER
#define TVM_OBJECT_ATOMIC_REF_COUNTER 1
#endif
#if TVM_OBJECT_ATOMIC_REF_COUNTER
#include <atomic>
#endif // TVM_OBJECT_ATOMIC_REF_COUNTER
namespace tvm {
namespace runtime {
template <typename T>
class ObjectPtr;
class Object;
enum struct ObjectTag {
/*! \brief The tag of a tensor. */
kTensor = 0U,
/*! \brief The tag of a closure. */
kClosure = 1U,
/*! \brief The tag of a structure. */
kDatatype = 2U,
/*! \brief List of predefined type indices. */
enum TypeIndex {
/*! \brief Root object type. */
kRoot = 0,
kVMTensor = 1,
kVMClosure = 2,
kVMDatatype = 3,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
};
std::ostream& operator<<(std::ostream& os, const ObjectTag&);
struct ObjectCell {
public:
/*!
* \brief The type of object deleter.
* \param The self pointer to the ObjectCell.
*/
typedef void (*FDeleter)(ObjectCell* self);
/*! \brief The tag of the object.
/*!
* \brief base class of all object containers.
*
* Sub-class of objects should declare the following static constexpr fields:
*
* - _type_index:
* Static type index of the object, if assigned to TypeIndex::kDynamic
* the type index will be assigned during runtime.
* Runtime type index can be accessed by ObjectType::type_index();
* - _type_key:
* The unique string identifier of the type.
* - _type_final:
* Whether the type is a terminal type (there is no subclass of the type in the object system).
* This field is automatically set by the macro TVM_DECLARE_FINAL_OBJECT_INFO.
* It is still OK to sub-class a terminal object type T and construct it using make_object.
* But the IsInstance check will only show that the object type is T (instead of the sub-class).
*
* The following two fields are necessary for base classes that can be sub-classed.
*
* - _type_child_slots:
* Number of reserved type index slots for child classes.
* Used for runtime optimization for type checking in IsInstance.
* If an object's type_index is within the range [type_index, type_index + _type_child_slots],
* then the object can be quickly decided as a sub-class of the current object class.
* If not, a fallback mechanism is used to check the global type table.
* Recommendation: set to the estimated number of children needed.
* - _type_child_slots_can_overflow:
* Whether we can add additional child classes even if the number of child classes
* exceeds the _type_child_slots. A fallback mechanism to check global type table will be used.
* Recommendation: set to false for optimal runtime speed if we know the exact number of children.
*
* Two macros are used to declare helper functions in the object:
* - Use TVM_DECLARE_BASE_OBJECT_INFO for object classes that can be sub-classed.
* - Use TVM_DECLARE_FINAL_OBJECT_INFO for object classes that cannot be sub-classed.
*
* New objects can be created using the make_object function,
* which will automatically populate the type_index and deleter of the object.
*
* \sa make_object
* \sa ObjectPtr
* \sa ObjectRef
*
* \code
*
* // Create a base object
* class BaseObj : public Object {
* public:
* // object fields
* int field0;
*
* Describes which type of value
* is represented by this object.
* // object properties
* static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
* static constexpr const char* _type_key = "test.BaseObj";
* TVM_DECLARE_BASE_OBJECT_INFO(BaseObj, Object);
* };
*
* class LeafObj : public BaseObj {
* public:
* // fields
* int child_field0;
* // object properties
* static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
* static constexpr const char* _type_key = "test.LeafObj";
* TVM_DECLARE_BASE_OBJECT_INFO(LeafObj, BaseObj);
* };
*
* // The following code should be put into a cc file.
* TVM_REGISTER_OBJECT_TYPE(BaseObj);
* TVM_REGISTER_OBJECT_TYPE(LeafObj);
*
* // Usage example.
* void TestObjects() {
* // create an object
* ObjectRef leaf_ref(make_object<LeafObj>());
* // cast to a specific instance
* const LeafObj* leaf_ptr = leaf_ref.as<LeafObj>();
* CHECK(leaf_ptr != nullptr);
* // can also cast to the base class.
* CHECK(leaf_ref.as<BaseObj>() != nullptr);
* }
*
* \endcode
*/
ObjectTag tag;
class Object {
public:
/*!
* \brief Increment the reference count.
* \brief Object deleter
* \param self pointer to the Object.
*/
void IncRef() { ref_counter_.fetch_add(1, std::memory_order_relaxed); }
typedef void (*FDeleter)(Object* self);
/*! \return The internal type index of the object. */
uint32_t type_index() const {
return type_index_;
}
/*!
* \brief Decrement the reference count.
* Check if the object is an instance of TargetType.
* \tparam TargetType The target type to be checked.
* \return Whether the object is an instance of TargetType.
*/
void DecRef() {
if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
std::atomic_thread_fence(std::memory_order_acquire);
if (this->deleter_ != nullptr) {
(*this->deleter_)(this);
}
}
template<typename TargetType>
inline bool IsInstance() const;
#if TVM_OBJECT_ATOMIC_REF_COUNTER
using RefCounterType = std::atomic<int32_t>;
#else
using RefCounterType = int32_t;
#endif
// Object type properties
static constexpr const char* _type_key = "Object";
static constexpr bool _type_final = false;
static constexpr uint32_t _type_child_slots = 0;
static constexpr bool _type_child_slots_can_overflow = true;
static const uint32_t _GetOrAllocRuntimeTypeIndex() {
return 0;
}
protected:
// default constructor and copy constructor
ObjectCell() {}
explicit ObjectCell(ObjectTag tag) : tag(tag) {}
// override the copy and assign constructors to do nothing.
// This is to make sure only contents, but not deleter and ref_counter
// are copied when a child class copies itself.
ObjectCell(const ObjectCell& other) { // NOLINT(*)
}
ObjectCell(ObjectCell&& other) { // NOLINT(*)
}
ObjectCell& operator=(const ObjectCell& other) { // NOLINT(*)
return *this;
}
ObjectCell& operator=(ObjectCell&& other) { // NOLINT(*)
return *this;
}
private:
/*! \brief Internal reference counter */
std::atomic<int> ref_counter_{0};
// The fields of the base object cell.
/*! \brief Type index(tag) that indicates the type of the object. */
uint32_t type_index_{0};
/*! \brief The internal reference counter */
RefCounterType ref_counter_{0};
/*!
* \brief deleter of this object to enable customized allocation.
* If the deleter is nullptr, no deletion will be performed.
* The creator of the Node must always set the deleter field properly.
* The creator of the object must always set the deleter field properly.
*/
FDeleter deleter_ = nullptr;
// Invariant checks.
static_assert(sizeof(int32_t) == sizeof(RefCounterType) &&
alignof(int32_t) == alignof(RefCounterType),
"RefCounter ABI check.");
/*!
* \brief Get the type index using type key.
*
* When the function is called for the first time for a type,
* it will register the type to the type table in the runtime.
* If the static_tindex is TypeIndex::kDynamic, the function will
* allocate a runtime type index.
* Otherwise, we will populate the type table and return the static index.
*
* \param key the type key.
* \param static_tindex The current _type_index field.
* can be TypeIndex::kDynamic.
* \param parent_tindex The index of the parent.
* \param type_child_slots Number of slots reserved for its children.
* \param type_child_slots_can_overflow Whether to allow child to overflow the slots.
* \return The allocated type index.
*/
TVM_DLL static uint32_t GetOrAllocRuntimeTypeIndex(
const char* key,
uint32_t static_tindex,
uint32_t parent_tindex,
uint32_t type_child_slots,
bool type_child_slots_can_overflow);
int use_count() const { return ref_counter_.load(std::memory_order_relaxed); }
/*!
* \brief Get the type key of the corresponding index from runtime.
* \param tindex The type index.
*/
TVM_DLL static std::string TypeIndex2Key(uint32_t tindex);
// friend declaration
template <typename>
friend class ObjectPtr;
/*!
* \brief Get the type index of the corresponding key from runtime.
* \param key The type key.
*/
TVM_DLL static uint32_t TypeKey2Index(const char* key);
template <typename Y, typename... Args>
friend ObjectPtr<Y> MakeObject(Args&&...);
private:
// reference counter related operations
/*! \brief developer function, increases reference counter. */
inline void IncRef();
/*!
* \brief developer function, decrease reference counter.
* \note The deleter will be called when ref_counter_ becomes zero.
*/
inline void DecRef();
/*!
* \return The usage count of the cell.
* \note We use stl style naming to be consistent with known API in shared_ptr.
*/
inline int use_count() const;
/*!
* \brief Check if this object is derived from the parent.
* \param parent_tindex The parent type index.
* \return The derivation results.
*/
TVM_DLL bool DerivedFrom(uint32_t parent_tindex) const;
// friend classes
template<typename>
friend class ObjAllocatorBase;
template<typename>
friend class ObjectPtr;
friend class TVMRetValue;
};
/*!
* \brief A custom smart pointer for Object.
* \tparam T the content data type; must be a subclass of Object.
* \sa make_object
*/
template <typename T>
class ObjectPtr {
......@@ -159,7 +291,6 @@ class ObjectPtr {
: data_(other.data_) {
other.data_ = nullptr;
}
/*!
* \brief move constructor
* \param other The value to be moved
......@@ -171,10 +302,10 @@ class ObjectPtr {
"can only assign of child class ObjectPtr to parent");
other.data_ = nullptr;
}
/*! \brief destructor */
~ObjectPtr() { this->reset(); }
~ObjectPtr() {
this->reset();
}
/*!
* \brief Swap this array with another Object
* \param other The other Object
......@@ -182,24 +313,24 @@ class ObjectPtr {
void swap(ObjectPtr<T>& other) { // NOLINT(*)
std::swap(data_, other.data_);
}
/*!
* \return Get the content of the pointer
*/
T* get() const { return static_cast<T*>(data_); }
T* get() const {
return static_cast<T*>(data_);
}
/*!
* \return The pointer
*/
T* operator->() const { return get(); }
T* operator->() const {
return get();
}
/*!
* \return The reference
*/
T& operator*() const { // NOLINT(*)
return *get();
}
/*!
* \brief copy assignment
* \param other The value to be assigned.
......@@ -211,7 +342,6 @@ class ObjectPtr {
ObjectPtr(other).swap(*this); // NOLINT(*)
return *this;
}
/*!
* \brief move assignment
* \param other The value to be assigned.
......@@ -222,7 +352,6 @@ class ObjectPtr {
ObjectPtr(std::move(other)).swap(*this); // NOLINT(*)
return *this;
}
/*! \brief reset the content of ptr to be nullptr */
void reset() {
if (data_ != nullptr) {
......@@ -230,163 +359,238 @@ class ObjectPtr {
data_ = nullptr;
}
}
/*! \return The use count of the ptr, for debug purposes */
int use_count() const { return data_ != nullptr ? data_->use_count() : 0; }
int use_count() const {
return data_ != nullptr ? data_->use_count() : 0;
}
/*! \return whether the reference is unique */
bool unique() const { return data_ != nullptr && data_->use_count() == 1; }
bool unique() const {
return data_ != nullptr && data_->use_count() == 1;
}
/*! \return Whether two ObjectPtr are equal to each other */
bool operator==(const ObjectPtr<T>& other) const { return data_ == other.data_; }
bool operator==(const ObjectPtr<T>& other) const {
return data_ == other.data_;
}
/*! \return Whether two ObjectPtr are not equal to each other */
bool operator!=(const ObjectPtr<T>& other) const { return data_ != other.data_; }
bool operator!=(const ObjectPtr<T>& other) const {
return data_ != other.data_;
}
/*! \return Whether the pointer is nullptr */
bool operator==(std::nullptr_t null) const { return data_ == nullptr; }
/*! \return Whether the pointer is not nullptr */
bool operator!=(std::nullptr_t null) const { return data_ != nullptr; }
/* ObjectPtr's support custom allocators.
*
* The below allocator represents the simplest
* possible impl. It can be easily swapped
* for customized executor's, different allocation
* strategies, and so on.
*
* See memory.h for more discussion on NodePtr's
* allocator.
*/
class StdAllocator {
public:
template <typename... Args>
static T* New(Args&&... args) {
return new T(std::forward<Args>(args)...);
bool operator==(std::nullptr_t null) const {
return data_ == nullptr;
}
static ObjectCell::FDeleter Deleter() { return Deleter_; }
private:
static void Deleter_(ObjectCell* ptr) { delete static_cast<T*>(ptr); }
};
template <typename U>
ObjectPtr<U> As() const {
auto ptr = reinterpret_cast<U*>(get());
return ObjectPtr<U>(ptr);
/*! \return Whether the pointer is not nullptr */
bool operator!=(std::nullptr_t null) const {
return data_ != nullptr;
}
private:
/*! \brief internal pointer field */
ObjectCell* data_{nullptr};
Object* data_{nullptr};
/*!
* \brief constructor from NodeBase
* \param data The node base pointer
* \param data The data pointer
*/
// TODO(jroesch): NodePtr design doesn't really work here due to the passing.
public:
explicit ObjectPtr(ObjectCell* data) : data_(data) {
explicit ObjectPtr(Object* data) : data_(data) {
if (data != nullptr) {
data_->IncRef();
}
}
private:
template <typename Y, typename... Args>
friend ObjectPtr<Y> MakeObject(Args&&...);
template <typename>
// friend classes
friend class Object;
friend class ObjectRef;
template<typename>
friend class ObjectPtr;
friend class NDArray;
template<typename>
friend class ObjAllocatorBase;
friend class TVMPODValue_;
friend class TVMArgValue;
friend class TVMArgsSetter;
friend class TVMRetValue;
friend class RPCWrappedFunc;
};
struct TensorCell;
struct DatatypeCell;
struct ClosureCell;
/*!
* \brief A managed object in the TVM runtime.
/*! \brief Base class of all object references. */
class ObjectRef {
public:
/*! \brief default constructor */
ObjectRef() = default;
/*! \brief Constructor from existing object ptr */
explicit ObjectRef(ObjectPtr<Object> data) : data_(data) {}
/*! \return the internal object pointer */
inline const Object* get() const;
/*! \return the internal object pointer */
inline const Object* operator->() const;
/*!
* \brief Try to downcast the internal Object to a
* raw pointer of a corresponding type.
*
* For example a tuple, list, closure, and so on.
* The function will return a nullptr if the cast failed.
*
* Maintains a reference count for the object.
* if (const Add* add = node_ref.as<Add>()) {
* // This is an add node
* }
* \tparam ObjectType the target type, must be a subtype of Object.
*/
class Object {
public:
ObjectPtr<ObjectCell> ptr_;
explicit Object(ObjectPtr<ObjectCell> ptr) : ptr_(ptr) {}
explicit Object(ObjectCell* ptr) : ptr_(ptr) {}
Object() : ptr_() {}
Object(const Object& obj) : ptr_(obj.ptr_) {}
ObjectCell* operator->() { return this->ptr_.operator->(); }
const ObjectCell* operator->() const { return this->ptr_.operator->(); }
/*! \brief Construct a tensor object. */
static Object Tensor(const NDArray& data);
/*! \brief Construct a datatype object. */
static Object Datatype(size_t tag, const std::vector<Object>& fields);
/*! \brief Construct a tuple object. */
static Object Tuple(const std::vector<Object>& fields);
/*! \brief Construct a closure object. */
static Object Closure(size_t func_index, const std::vector<Object>& free_vars);
ObjectPtr<TensorCell> AsTensor() const;
ObjectPtr<DatatypeCell> AsDatatype() const;
ObjectPtr<ClosureCell> AsClosure() const;
};
template <typename ObjectType>
inline const ObjectType* as() const;
/*! \brief An object containing an NDArray. */
struct TensorCell : public ObjectCell {
/*! \brief The NDArray. */
NDArray data;
explicit TensorCell(const NDArray& data) : ObjectCell(ObjectTag::kTensor), data(data) {}
};
/*! \brief Type indicating the container type. */
using ContainerType = Object;
/*! \brief An object representing a structure or enumeration. */
struct DatatypeCell : public ObjectCell {
/*! \brief The tag representing the constructor used. */
size_t tag;
/*! \brief The fields of the structure. */
std::vector<Object> fields;
DatatypeCell(size_t tag, const std::vector<Object>& fields)
: ObjectCell(ObjectTag::kDatatype), tag(tag), fields(fields) {}
protected:
/*! \brief Internal pointer that backs the reference. */
ObjectPtr<Object> data_;
// friend classes.
friend class TVMRetValue;
friend class TVMArgsSetter;
};
/*! \brief An object representing a closure. */
struct ClosureCell : public ObjectCell {
/*! \brief The index into the VM function table. */
size_t func_index;
/*! \brief The free variables of the closure. */
std::vector<Object> free_vars;
/*!
* \brief Helper macro to declare a base object type that can be inherited.
* \param TypeName The name of the current type.
* \param ParentType The name of the parent type.
*/
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
static const uint32_t type_index() { \
if (_type_index != TypeIndex::kDynamic) return _type_index; \
return _GetOrAllocRuntimeTypeIndex(); \
} \
static const uint32_t _GetOrAllocRuntimeTypeIndex() { \
static uint32_t tidx = GetOrAllocRuntimeTypeIndex( \
TypeName::_type_key, \
TypeName::_type_index, \
ParentType::_GetOrAllocRuntimeTypeIndex(), \
TypeName::_type_child_slots, \
TypeName::_type_child_slots_can_overflow); \
return tidx; \
} \
ClosureCell(size_t func_index, const std::vector<Object>& free_vars)
: ObjectCell(ObjectTag::kClosure), func_index(func_index), free_vars(free_vars) {}
};
/*!
* \brief helper macro to declare type information in a final class.
* \param TypeName The name of the current type.
* \param ParentType The name of the parent type.
*/
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \
static const constexpr bool _type_final = true; \
static const constexpr int _type_child_slots = 0; \
TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
/*! \brief Extract the NDArray from a tensor object. */
NDArray ToNDArray(const Object& obj);
/*!
* \brief Allocate a node object.
* \param args arguments to the constructor.
* \tparam T the node type.
* \return The NodePtr to the allocated object.
* \brief Helper macro to register the object type to runtime.
* Makes sure that the runtime type table is correctly populated.
*
* Use this macro in the cc file for each terminal class.
*/
template <typename T, typename... Args>
inline ObjectPtr<T> MakeObject(Args&&... args) {
using Allocator = typename ObjectPtr<T>::StdAllocator;
static_assert(std::is_base_of<ObjectCell, T>::value, "MakeObject can only be used to create ");
T* node = Allocator::New(std::forward<Args>(args)...);
node->deleter_ = Allocator::Deleter();
return ObjectPtr<T>(node);
#define TVM_REGISTER_OBJECT_TYPE(TypeName) \
static DMLC_ATTRIBUTE_UNUSED uint32_t __make_Object_tidx ## _ ## TypeName ## __ = \
TypeName::_GetOrAllocRuntimeTypeIndex()
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
TypeName() {} \
explicit TypeName( \
::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
: ParentType(n) {} \
const ObjectName* operator->() const { \
return static_cast<const ObjectName*>(data_.get()); \
} \
operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName;
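For concreteness, this is approximately what the macro generates when applied as TVM_DEFINE_OBJECT_REF_METHODS(Tensor, ObjectRef, TensorObj) in the VM header later in this diff (hand expansion, sketch only):

// Members injected into class Tensor : public ObjectRef:
Tensor() {}
explicit Tensor(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n)
    : ObjectRef(n) {}
const TensorObj* operator->() const {
  return static_cast<const TensorObj*>(data_.get());
}
operator bool() const { return data_ != nullptr; }
using ContainerType = TensorObj;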
// Implementations details below
// Object reference counting.
#if TVM_OBJECT_ATOMIC_REF_COUNTER
inline void Object::IncRef() {
ref_counter_.fetch_add(1, std::memory_order_relaxed);
}
inline void Object::DecRef() {
if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
std::atomic_thread_fence(std::memory_order_acquire);
if (this->deleter_ != nullptr) {
(*this->deleter_)(this);
}
}
}
inline int Object::use_count() const {
return ref_counter_.load(std::memory_order_relaxed);
}
#else
inline void Object::IncRef() {
++ref_counter_;
}
inline void Object::DecRef() {
if (--ref_counter_ == 0) {
if (this->deleter_ != nullptr) {
(*this->deleter_)(this);
}
}
}
inline int Object::use_count() const {
return ref_counter_;
}
#endif // TVM_OBJECT_ATOMIC_REF_COUNTER
template<typename TargetType>
inline bool Object::IsInstance() const {
const Object* self = this;
// NOTE: the following code can be optimized by
// compiler dead-code elimination for already known constants.
if (self != nullptr) {
// Everything is a subclass of object.
if (std::is_same<TargetType, Object>::value) return true;
if (TargetType::_type_final) {
// if the target type is a final type
// then we only need to check the equivalence.
return self->type_index_ == TargetType::type_index();
} else {
// if target type is a non-leaf type
// Check if type index falls into the range of reserved slots.
uint32_t begin = TargetType::type_index();
// The condition will be optimized by constant-folding.
if (TargetType::_type_child_slots != 0) {
uint32_t end = begin + TargetType::_type_child_slots;
if (self->type_index_ >= begin && self->type_index_ < end) return true;
} else {
if (self->type_index_ == begin) return true;
}
if (!TargetType::_type_child_slots_can_overflow) return false;
// Invariant: the parent's index is always smaller than the child's.
if (self->type_index_ < TargetType::type_index()) return false;
// The rare slower path: check the type hierarchy.
return self->DerivedFrom(TargetType::type_index());
}
} else {
return false;
}
}
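A hypothetical numeric walk-through of the fast path above (all indices invented): suppose BaseObj was assigned type index 10 with _type_child_slots = 3, so the reserved-slot check uses begin = 10 and end = 13.

// IsInstance<BaseObj>() on an object whose:
//   type_index_ == 10 -> true  (10 >= 10 && 10 < 13; BaseObj itself)
//   type_index_ == 12 -> true  (a child allocated inside the reserved slots)
//   type_index_ == 42 -> fast path fails; if _type_child_slots_can_overflow,
//                        fall back to DerivedFrom(10), which walks the
//                        parent_index chain in the global type table.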
inline const Object* ObjectRef::get() const {
return data_.data_;
}
inline const Object* ObjectRef::operator->() const {
return get();
}
template <typename ObjectType>
inline const ObjectType* ObjectRef::as() const {
if (data_ != nullptr &&
data_->IsInstance<ObjectType>()) {
return static_cast<ObjectType*>(data_.get());
} else {
return nullptr;
}
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_OBJECT_H_
......@@ -489,10 +489,10 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer);
return NDArray(static_cast<NDArray::Container*>(value_.v_handle));
}
operator Object() const {
if (type_code_ == kNull) return Object();
operator ObjectRef() const {
if (type_code_ == kNull) return ObjectRef(ObjectPtr<Object>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kObjectCell);
return Object(static_cast<ObjectCell*>(value_.v_handle));
return ObjectRef(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
}
operator TVMContext() const {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
......@@ -566,7 +566,7 @@ class TVMArgValue : public TVMPODValue_ {
using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator Object;
using TVMPODValue_::operator ObjectRef;
// conversion operator.
operator std::string() const {
......@@ -662,7 +662,7 @@ class TVMRetValue : public TVMPODValue_ {
using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator Object;
using TVMPODValue_::operator ObjectRef;
TVMRetValue(const TVMRetValue& other) : TVMPODValue_() {
this->Assign(other);
}
......@@ -759,11 +759,12 @@ class TVMRetValue : public TVMPODValue_ {
other.data_ = nullptr;
return *this;
}
TVMRetValue& operator=(Object other) {
TVMRetValue& operator=(ObjectRef other) {
this->Clear();
type_code_ = kObjectCell;
value_.v_handle = other.ptr_.data_;
other.ptr_.data_ = nullptr;
// move the handle out
value_.v_handle = other.data_.data_;
other.data_.data_ = nullptr;
return *this;
}
TVMRetValue& operator=(PackedFunc f) {
......@@ -862,7 +863,7 @@ class TVMRetValue : public TVMPODValue_ {
break;
}
case kObjectCell: {
*this = other.operator Object();
*this = other.operator ObjectRef();
break;
}
default: {
......@@ -913,7 +914,7 @@ class TVMRetValue : public TVMPODValue_ {
break;
}
case kObjectCell: {
static_cast<ObjectCell*>(value_.v_handle)->DecRef();
static_cast<Object*>(value_.v_handle)->DecRef();
break;
}
}
......@@ -1161,6 +1162,10 @@ class TVMArgsSetter {
values_[i].v_handle = value.data_;
type_codes_[i] = kNDArrayContainer;
}
void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*)
values_[i].v_handle = value.data_.data_;
type_codes_[i] = kObjectCell;
}
void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*)
if (value.type_code() == kStr) {
values_[i].v_str = value.ptr<std::string>()->c_str();
......
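A hedged sketch of how the conversions in this header compose: an ObjectRef handed to a PackedFunc travels as a kObjectCell handle and comes back out intact. It assumes the vm::Tensor type defined later in this diff; error handling elided.

#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/vm.h>

using namespace tvm::runtime;

void RoundTrip(const NDArray& arr) {
  PackedFunc identity([](TVMArgs args, TVMRetValue* rv) {
    ObjectRef obj = args[0];  // TVMArgValue -> ObjectRef conversion above
    *rv = obj;                // TVMRetValue::operator=(ObjectRef) keeps the handle
  });
  vm::Tensor tensor(arr);
  ObjectRef out = identity(tensor);  // TVMArgsSetter stores the ptr as kObjectCell
  const vm::TensorObj* cell = out.as<vm::TensorObj>();
  CHECK(cell != nullptr);  // the type-index check succeeds for "vm.Tensor"
}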
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/runtime/vm.h
* \brief A virtual machine for executing Relay programs.
*/
......@@ -36,6 +35,75 @@ namespace tvm {
namespace runtime {
namespace vm {
/*! \brief An object containing an NDArray. */
class TensorObj : public Object {
public:
/*! \brief The NDArray. */
NDArray data;
static constexpr const uint32_t _type_index = TypeIndex::kVMTensor;
static constexpr const char* _type_key = "vm.Tensor";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorObj, Object);
};
/*! \brief reference to tensor. */
class Tensor : public ObjectRef {
public:
explicit Tensor(NDArray data);
TVM_DEFINE_OBJECT_REF_METHODS(Tensor, ObjectRef, TensorObj);
};
/*! \brief An object representing a structure or enumeration. */
class DatatypeObj : public Object {
public:
/*! \brief The tag representing the constructor used. */
size_t tag;
/*! \brief The fields of the structure. */
std::vector<ObjectRef> fields;
static constexpr const uint32_t _type_index = TypeIndex::kVMDatatype;
static constexpr const char* _type_key = "vm.Datatype";
TVM_DECLARE_FINAL_OBJECT_INFO(DatatypeObj, Object);
};
/*! \brief reference to data type. */
class Datatype : public ObjectRef {
public:
Datatype(size_t tag, std::vector<ObjectRef> fields);
/*!
* \brief construct a tuple object.
* \param fields The fields of the tuple.
* \return The constructed tuple type.
*/
static Datatype Tuple(std::vector<ObjectRef> fields);
TVM_DEFINE_OBJECT_REF_METHODS(Datatype, ObjectRef, DatatypeObj);
};
/*! \brief An object representing a closure. */
class ClosureObj : public Object {
public:
/*! \brief The index into the VM function table. */
size_t func_index;
/*! \brief The free variables of the closure. */
std::vector<ObjectRef> free_vars;
static constexpr const uint32_t _type_index = TypeIndex::kVMClosure;
static constexpr const char* _type_key = "vm.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureObj, Object);
};
/*! \brief reference to closure. */
class Closure : public ObjectRef {
public:
Closure(size_t func_index, std::vector<ObjectRef> free_vars);
TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
};
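A short usage sketch of the wrappers above (assumes an NDArray arr is in scope; CHECK comes from dmlc logging):

using namespace tvm::runtime;
using namespace tvm::runtime::vm;

void BuildVMValues(const NDArray& arr) {
  Tensor tensor(arr);                                  // wraps the NDArray
  Datatype tuple = Datatype::Tuple({tensor, tensor});  // tag-0 tuple
  Closure closure(/*func_index=*/0, /*free_vars=*/{tensor});

  // Downcasts go through ObjectRef::as<T>(), backed by the type-index check.
  const TensorObj* t = tensor.as<TensorObj>();
  const DatatypeObj* d = tuple.as<DatatypeObj>();
  CHECK(t != nullptr && d != nullptr && d->fields.size() == 2);
}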
/*! \brief Magic number for NDArray list file */
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
......@@ -348,7 +416,7 @@ struct VMFrame {
const Instruction* code;
/*! \brief Statically allocated space for objects */
std::vector<Object> register_file;
std::vector<ObjectRef> register_file;
/*! \brief Register in caller's frame to put return value */
RegName caller_return_register;
......@@ -406,8 +474,11 @@ class VirtualMachine : public runtime::ModuleNode {
*
* \note The return value will be stored in the last output_size slots of args.
*/
virtual void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
Index output_size, const std::vector<Object>& args);
virtual void InvokePacked(Index packed_index,
const PackedFunc& func,
Index arg_count,
Index output_size,
const std::vector<ObjectRef>& args);
virtual ~VirtualMachine() {}
......@@ -424,7 +495,7 @@ class VirtualMachine : public runtime::ModuleNode {
/*! \brief The current stack of call frames. */
std::vector<VMFrame> frames;
/*! \brief The global constant pool. */
std::vector<Object> constants;
std::vector<ObjectRef> constants;
/*! \brief The function table index of the current function. */
Index func_index;
/*! \brief The current pointer to the code section. */
......@@ -433,7 +504,7 @@ class VirtualMachine : public runtime::ModuleNode {
Index pc;
/*! \brief The special return register. */
Object return_register;
ObjectRef return_register;
/*! \brief The set of TVM contexts the VM is currently executing on. */
std::vector<TVMContext> ctxs;
......@@ -449,13 +520,13 @@ class VirtualMachine : public runtime::ModuleNode {
* \param reg The register to write to.
* \param obj The object to write to.
*/
inline void WriteRegister(RegName reg, const Object& obj);
inline void WriteRegister(RegName reg, const ObjectRef& obj);
/*! \brief Read a VM register.
* \param reg The register to read from.
* \return The read object.
*/
inline Object ReadRegister(RegName reg) const;
inline ObjectRef ReadRegister(RegName reg) const;
/*! \brief Read a VM register and cast it to int32_t
* \param reg The register to read from.
......@@ -468,15 +539,16 @@ class VirtualMachine : public runtime::ModuleNode {
* \param args The arguments to the function.
* \return The object representing the result.
*/
Object Invoke(const VMFunction& func, const std::vector<Object>& args);
ObjectRef Invoke(const VMFunction& func, const std::vector<ObjectRef>& args);
// TODO(@jroesch): I really would like this to be a global variable.
/*! \brief Invoke a VM function by name.
/*!
* \brief Invoke a VM function by name.
* \param name The function's name.
* \param args The arguments to the function.
* \return The object representing the result.
*/
Object Invoke(const std::string& name, const std::vector<Object>& args);
ObjectRef Invoke(const std::string& name, const std::vector<ObjectRef>& args);
VirtualMachine() : functions(), frames(), func_index(0), code(nullptr), pc(0) {}
......@@ -513,11 +585,10 @@ class VirtualMachine : public runtime::ModuleNode {
*
* This does not begin execution of the VM.
*/
void InvokeGlobal(const VMFunction& func, const std::vector<Object>& args);
void InvokeGlobal(const VMFunction& func, const std::vector<ObjectRef>& args);
/*! \brief The parameter name to data mapping. */
std::unordered_map<std::string, Object> params_;
std::unordered_map<std::string, ObjectRef> params_;
};
} // namespace vm
......
......@@ -44,9 +44,9 @@ except IMPORT_EXCEPT:
class ObjectTag(object):
"""Type code used in API calls"""
TENSOR = 0
CLOSURE = 1
DATATYPE = 2
TENSOR = 1
CLOSURE = 2
DATATYPE = 3
class Object(_ObjectBase):
......
......@@ -92,7 +92,7 @@ struct APIAttrGetter : public AttrVisitor {
found_ref_object = true;
}
}
void Visit(const char* key, runtime::Object* value) final {
void Visit(const char* key, runtime::ObjectRef* value) final {
if (skey == key) {
*ret = value[0];
found_ref_object = true;
......@@ -133,7 +133,7 @@ struct APIAttrDir : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final {
names->push_back(key);
}
void Visit(const char* key, runtime::Object* value) final {
void Visit(const char* key, runtime::ObjectRef* value) final {
names->push_back(key);
}
};
......
......@@ -54,7 +54,7 @@ inline Type String2Type(std::string s) {
}
using runtime::Object;
using runtime::ObjectCell;
using runtime::ObjectRef;
// indexer to index all the nodes
class NodeIndexer : public AttrVisitor {
......@@ -63,8 +63,6 @@ class NodeIndexer : public AttrVisitor {
std::vector<Node*> node_list{nullptr};
std::unordered_map<DLTensor*, size_t> tensor_index;
std::vector<DLTensor*> tensor_list;
std::unordered_map<ObjectCell*, size_t> vm_obj_index;
std::vector<ObjectCell*> vm_obj_list;
void Visit(const char* key, double* value) final {}
void Visit(const char* key, int64_t* value) final {}
......@@ -86,12 +84,8 @@ class NodeIndexer : public AttrVisitor {
tensor_list.push_back(ptr);
}
void Visit(const char* key, Object* value) final {
ObjectCell* ptr = value->ptr_.get();
if (vm_obj_index.count(ptr)) return;
CHECK_EQ(vm_obj_index.size(), vm_obj_list.size());
vm_obj_index[ptr] = vm_obj_list.size();
vm_obj_list.push_back(ptr);
void Visit(const char* key, ObjectRef* value) final {
LOG(FATAL) << "Do not support json serialize non-node object";
}
// make index of all the children of node
......@@ -177,7 +171,6 @@ class JSONAttrGetter : public AttrVisitor {
public:
const std::unordered_map<Node*, size_t>* node_index_;
const std::unordered_map<DLTensor*, size_t>* tensor_index_;
const std::unordered_map<ObjectCell*, size_t>* vm_obj_index_;
JSONNode* node_;
void Visit(const char* key, double* value) final {
......@@ -212,9 +205,8 @@ class JSONAttrGetter : public AttrVisitor {
node_->attrs[key] = std::to_string(
tensor_index_->at(const_cast<DLTensor*>((*value).operator->())));
}
void Visit(const char* key, Object* value) final {
node_->attrs[key] = std::to_string(
vm_obj_index_->at(value->ptr_.get()));
void Visit(const char* key, ObjectRef* value) final {
LOG(FATAL) << "Do not support json serialize non-node object";
}
// Get the node
void Get(Node* node) {
......@@ -269,7 +261,6 @@ class JSONAttrSetter : public AttrVisitor {
public:
const std::vector<NodePtr<Node> >* node_list_;
const std::vector<runtime::NDArray>* tensor_list_;
const std::vector<Object>* vm_obj_list_;
JSONNode* node_;
......@@ -325,11 +316,8 @@ class JSONAttrSetter : public AttrVisitor {
CHECK_LE(index, tensor_list_->size());
*value = tensor_list_->at(index);
}
void Visit(const char* key, Object* value) final {
size_t index;
ParseValue(key, &index);
CHECK_LE(index, vm_obj_list_->size());
*value = vm_obj_list_->at(index);
void Visit(const char* key, ObjectRef* value) final {
LOG(FATAL) << "Do not support json serialize non-node object";
}
// set node to be current JSONNode
void Set(Node* node) {
......@@ -508,8 +496,8 @@ class NodeAttrSetter : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final {
*value = GetAttr(key).operator runtime::NDArray();
}
void Visit(const char* key, Object* value) final {
*value = GetAttr(key).operator Object();
void Visit(const char* key, ObjectRef* value) final {
*value = GetAttr(key).operator ObjectRef();
}
private:
......
......@@ -885,7 +885,7 @@ void VMCompiler::Compile(Module mod,
// populate constants
for (auto data : context_.constants) {
vm_->constants.push_back(Object::Tensor(data));
vm_->constants.push_back(runtime::vm::Tensor(data));
}
LibraryCodegen();
......
......@@ -102,7 +102,7 @@ void Deserializer::DeserializeConstantSection() {
for (size_t i = 0; i < size; i++) {
runtime::NDArray constant;
STREAM_CHECK(constant.Load(strm_), "constant");
runtime::Object obj = runtime::Object::Tensor(constant);
runtime::ObjectRef obj = runtime::vm::Tensor(constant);
vm_->constants.push_back(obj);
}
}
......
......@@ -98,8 +98,8 @@ std::string Serializer::Stats() const {
// Get the number of constants and the shape of each of them.
oss << " Constant shapes (# " << vm_->constants.size() << "): [";
for (const auto& it : vm_->constants) {
auto cell = it.AsTensor();
CHECK(cell.operator->());
auto* cell = it.as<runtime::vm::TensorObj>();
CHECK(cell != nullptr);
runtime::NDArray data = cell->data;
const auto& shape = data.Shape();
......@@ -175,7 +175,8 @@ void Serializer::SerializeGlobalSection() {
void Serializer::SerializeConstantSection() {
std::vector<DLTensor*> arrays;
for (const auto& obj : vm_->constants) {
auto cell = obj.AsTensor();
const auto* cell = obj.as<runtime::vm::TensorObj>();
CHECK(cell != nullptr);
runtime::NDArray data = cell->data;
arrays.push_back(const_cast<DLTensor*>(data.operator->()));
}
......
......@@ -930,7 +930,7 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final {
LOG(FATAL) << "do not allow NDarray as argument";
}
void Visit(const char* key, runtime::Object* obj) final {
void Visit(const char* key, runtime::ObjectRef* obj) final {
LOG(FATAL) << "do not allow Object as argument";
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* \file src/runtime/object.cc
* \brief Object type management system.
*/
#include <dmlc/logging.h>
#include <tvm/runtime/object.h>
#include <mutex>
#include <string>
#include <vector>
#include <unordered_map>
namespace tvm {
namespace runtime {
/*! \brief Type information */
struct TypeInfo {
/*! \brief The current index. */
uint32_t index{0};
/*! \brief Index of the parent in the type hierarchy. */
uint32_t parent_index{0};
// NOTE: the indices in [index, index + num_reserved_slots) are
// reserved for the child-class of this type.
/*! \brief Total number of slots reserved for the type and its children. */
uint32_t num_slots{0};
/*! \brief number of allocated child slots. */
uint32_t allocated_slots{0};
/*! \brief Whether child can overflow. */
bool child_slots_can_overflow{true};
/*! \brief name of the type. */
std::string name;
};
/*!
* \brief Type context that manages the type hierarchy information.
*/
class TypeContext {
public:
// NOTE: this is a relatively slow path for child checking
// Most types are already checked by the fast-path via reserved slot checking.
bool DerivedFrom(uint32_t child_tindex, uint32_t parent_tindex) {
// invariant: a child's type index is always bigger than its parent's.
if (child_tindex < parent_tindex) return false;
if (child_tindex == parent_tindex) return true;
{
std::lock_guard<std::mutex> lock(mutex_);
CHECK_LT(child_tindex, type_table_.size());
while (child_tindex > parent_tindex) {
child_tindex = type_table_[child_tindex].parent_index;
}
}
return child_tindex == parent_tindex;
}
uint32_t GetOrAllocRuntimeTypeIndex(const char* key,
uint32_t static_tindex,
uint32_t parent_tindex,
uint32_t num_child_slots,
bool child_slots_can_overflow) {
std::lock_guard<std::mutex> lock(mutex_);
std::string skey = key;
auto it = type_key2index_.find(skey);
if (it != type_key2index_.end()) {
return it->second;
}
// try to allocate from parent's type table.
CHECK_LT(parent_tindex, type_table_.size());
TypeInfo& pinfo = type_table_[parent_tindex];
CHECK_EQ(pinfo.index, parent_tindex);
// if parent cannot overflow, then this class cannot.
if (!pinfo.child_slots_can_overflow) {
child_slots_can_overflow = false;
}
// total number of slots include the type itself.
uint32_t num_slots = num_child_slots + 1;
uint32_t allocated_tindex;
if (static_tindex != TypeIndex::kDynamic) {
// statically assigned type
allocated_tindex = static_tindex;
CHECK_LT(static_tindex, type_table_.size());
CHECK_EQ(type_table_[allocated_tindex].allocated_slots, 0U)
<< "Conflicting static index " << static_tindex
<< " between " << type_table_[allocated_tindex].name
<< " and "
<< key;
} else if (pinfo.allocated_slots + num_slots < pinfo.num_slots) {
// allocate the slot from parent's reserved pool
allocated_tindex = parent_tindex + pinfo.allocated_slots;
// update parent's state
pinfo.allocated_slots += num_slots;
} else {
CHECK(pinfo.child_slots_can_overflow)
<< "Reach maximum number of sub-classes for " << pinfo.name;
// allocate new entries.
allocated_tindex = type_counter_;
type_counter_ += num_slots;
CHECK_LE(type_table_.size(), allocated_tindex);
type_table_.resize(allocated_tindex + 1, TypeInfo());
}
CHECK_GT(allocated_tindex, parent_tindex);
// initialize the slot.
type_table_[allocated_tindex].index = allocated_tindex;
type_table_[allocated_tindex].parent_index = parent_tindex;
type_table_[allocated_tindex].num_slots = num_slots;
type_table_[allocated_tindex].allocated_slots = 1;
type_table_[allocated_tindex].child_slots_can_overflow =
child_slots_can_overflow;
type_table_[allocated_tindex].name = skey;
// update the key2index mapping.
type_key2index_[skey] = allocated_tindex;
return allocated_tindex;
}
std::string TypeIndex2Key(uint32_t tindex) {
std::lock_guard<std::mutex> lock(mutex_);
CHECK(tindex < type_table_.size() &&
type_table_[tindex].allocated_slots != 0)
<< "Unknown type index " << tindex;
return type_table_[tindex].name;
}
uint32_t TypeKey2Index(const char* key) {
std::string skey = key;
auto it = type_key2index_.find(skey);
CHECK(it != type_key2index_.end())
<< "Cannot find type " << key;
return it->second;
}
static TypeContext* Global() {
static TypeContext inst;
return &inst;
}
private:
TypeContext() {
type_table_.resize(TypeIndex::kStaticIndexEnd, TypeInfo());
}
// mutex to avoid registration from multiple threads.
std::mutex mutex_;
std::atomic<uint32_t> type_counter_{TypeIndex::kStaticIndexEnd};
std::vector<TypeInfo> type_table_;
std::unordered_map<std::string, uint32_t> type_key2index_;
};
uint32_t Object::GetOrAllocRuntimeTypeIndex(const char* key,
uint32_t static_tindex,
uint32_t parent_tindex,
uint32_t num_child_slots,
bool child_slots_can_overflow) {
return TypeContext::Global()->GetOrAllocRuntimeTypeIndex(
key, static_tindex, parent_tindex, num_child_slots, child_slots_can_overflow);
}
bool Object::DerivedFrom(uint32_t parent_tindex) const {
return TypeContext::Global()->DerivedFrom(
this->type_index_, parent_tindex);
}
std::string Object::TypeIndex2Key(uint32_t tindex) {
return TypeContext::Global()->TypeIndex2Key(tindex);
}
uint32_t Object::TypeKey2Index(const char* key) {
return TypeContext::Global()->TypeKey2Index(key);
}
} // namespace runtime
} // namespace tvm
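To make the slot accounting in GetOrAllocRuntimeTypeIndex concrete, a hypothetical allocation trace (type keys and indices invented; assumes a fresh table, so type_counter_ starts at kStaticIndexEnd = 4):

// Register "test.Base" (kDynamic, parent = kRoot, child_slots = 2, overflow ok):
//   kRoot reserves no pool, so allocate from type_counter_:
//   allocated_tindex = 4, type_counter_ += (2 + 1) -> 7;
//   indices 5 and 6 are now reserved for children of test.Base.
//
// Register "test.Leaf" (kDynamic, parent = test.Base, child_slots = 0):
//   fits in the parent's reserved pool: allocated_tindex = 4 + 1 = 5,
//   and test.Base's allocated_slots becomes 2.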
......@@ -18,134 +18,110 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file object.cc
* \brief A managed object in the TVM runtime.
* \file src/runtime/vm/object.cc
* \brief VM related objects.
*/
#include <tvm/logging.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/vm.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/c_runtime_api.h>
#include <iostream>
#include "../runtime_base.h"
namespace tvm {
namespace runtime {
namespace vm {
std::ostream& operator<<(std::ostream& os, const ObjectTag& tag) {
switch (tag) {
case ObjectTag::kClosure:
os << "Closure";
break;
case ObjectTag::kDatatype:
os << "Datatype";
break;
case ObjectTag::kTensor:
os << "Tensor";
break;
default:
LOG(FATAL) << "Invalid object tag: found " << static_cast<int>(tag);
}
return os;
}
Object Object::Tensor(const NDArray& data) {
ObjectPtr<ObjectCell> ptr = MakeObject<TensorCell>(data);
return Object(ptr);
}
Object Object::Datatype(size_t tag, const std::vector<Object>& fields) {
ObjectPtr<ObjectCell> ptr = MakeObject<DatatypeCell>(tag, fields);
return Object(ptr);
Tensor::Tensor(NDArray data) {
auto ptr = make_object<TensorObj>();
ptr->data = std::move(data);
data_ = std::move(ptr);
}
Object Object::Tuple(const std::vector<Object>& fields) { return Object::Datatype(0, fields); }
Object Object::Closure(size_t func_index, const std::vector<Object>& free_vars) {
ObjectPtr<ObjectCell> ptr = MakeObject<ClosureCell>(func_index, free_vars);
return Object(ptr);
Datatype::Datatype(size_t tag, std::vector<ObjectRef> fields) {
auto ptr = make_object<DatatypeObj>();
ptr->tag = tag;
ptr->fields = std::move(fields);
data_ = std::move(ptr);
}
ObjectPtr<TensorCell> Object::AsTensor() const {
CHECK(ptr_.get());
CHECK(ptr_.get()->tag == ObjectTag::kTensor);
return ptr_.As<TensorCell>();
Datatype Datatype::Tuple(std::vector<ObjectRef> fields) {
return Datatype(0, fields);
}
ObjectPtr<DatatypeCell> Object::AsDatatype() const {
CHECK(ptr_.get());
CHECK(ptr_.get()->tag == ObjectTag::kDatatype);
return ptr_.As<DatatypeCell>();
Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) {
auto ptr = make_object<ClosureObj>();
ptr->func_index = func_index;
ptr->free_vars = std::move(free_vars);
data_ = std::move(ptr);
}
ObjectPtr<ClosureCell> Object::AsClosure() const {
CHECK(ptr_.get());
CHECK(ptr_.get()->tag == ObjectTag::kClosure);
return ptr_.As<ClosureCell>();
}
NDArray ToNDArray(const Object& obj) {
auto tensor = obj.AsTensor();
return tensor->data;
}
TVM_REGISTER_GLOBAL("_vmobj.GetTensorData")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0];
auto cell = obj.AsTensor();
ObjectRef obj = args[0];
const auto* cell = obj.as<TensorObj>();
CHECK(cell != nullptr);
*rv = cell->data;
});
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeTag")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0];
auto cell = obj.AsDatatype();
*rv = static_cast<int>(cell->tag);
ObjectRef obj = args[0];
const auto* cell = obj.as<DatatypeObj>();
CHECK(cell != nullptr);
*rv = static_cast<int64_t>(cell->tag);
});
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeNumberOfFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0];
auto cell = obj.AsDatatype();
*rv = static_cast<int>(cell->fields.size());
ObjectRef obj = args[0];
const auto* cell = obj.as<DatatypeObj>();
CHECK(cell != nullptr);
*rv = static_cast<int64_t>(cell->fields.size());
});
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0];
ObjectRef obj = args[0];
int idx = args[1];
auto cell = obj.AsDatatype();
const auto* cell = obj.as<DatatypeObj>();
CHECK(cell != nullptr);
CHECK_LT(idx, cell->fields.size());
*rv = cell->fields[idx];
});
TVM_REGISTER_GLOBAL("_vmobj.Tensor")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = Object::Tensor(args[0]);
*rv = Tensor(args[0].operator NDArray());
});
TVM_REGISTER_GLOBAL("_vmobj.Tuple")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<Object> fields;
std::vector<ObjectRef> fields;
for (auto i = 0; i < args.size(); ++i) {
fields.push_back(args[i]);
}
*rv = Object::Tuple(fields);
*rv = Datatype::Tuple(fields);
});
TVM_REGISTER_GLOBAL("_vmobj.Datatype")
.set_body([](TVMArgs args, TVMRetValue* rv) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
std::vector<Object> fields;
std::vector<ObjectRef> fields;
for (int i = 1; i < args.size(); i++) {
fields.push_back(args[i]);
}
*rv = Object::Datatype(tag, fields);
*rv = Datatype(tag, fields);
});
TVM_REGISTER_OBJECT_TYPE(TensorObj);
TVM_REGISTER_OBJECT_TYPE(DatatypeObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj);
} // namespace vm
} // namespace runtime
} // namespace tvm
......@@ -153,6 +129,7 @@ using namespace tvm::runtime;
int TVMGetObjectTag(TVMObjectHandle handle, int* tag) {
API_BEGIN();
*tag = static_cast<int>(static_cast<ObjectCell*>(handle)->tag);
int res = static_cast<int>(static_cast<Object*>(handle)->type_index());
*tag = res;
API_END();
}
......@@ -96,7 +96,7 @@ void VirtualMachineDebug::Init(const std::vector<TVMContext>& ctxs) {
void VirtualMachineDebug::InvokePacked(Index packed_index,
const PackedFunc& func, Index arg_count,
Index output_size,
const std::vector<Object>& args) {
const std::vector<ObjectRef>& args) {
auto ctx = VirtualMachine::GetParamsContext();
// warmup
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
......
......@@ -45,7 +45,7 @@ class VirtualMachineDebug : public VirtualMachine {
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
Index output_size, const std::vector<Object>& args) final;
Index output_size, const std::vector<ObjectRef>& args) final;
~VirtualMachineDebug() {}
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/runtime/vm/vm.cc
* \brief The Relay virtual machine.
*/
......@@ -558,12 +557,12 @@ std::ostream& operator<<(std::ostream& os, const VMFunction& vm_func) {
return os;
}
Object CopyTo(Object src, const DLContext& ctx) {
if (src->tag == ObjectTag::kTensor) {
auto tensor = ToNDArray(src);
ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) {
if (const TensorObj* obj = src.as<TensorObj>()) {
auto tensor = obj->data;
if (tensor->ctx.device_type != ctx.device_type) {
auto copy = tensor.CopyTo(ctx);
return Object::Tensor(copy);
return Tensor(copy);
} else {
return src;
}
......@@ -585,7 +584,7 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
auto ctx = this->GetParamsContext();
// Prepare the func args
std::vector<Object> func_args(param_names.size());
std::vector<ObjectRef> func_args(param_names.size());
std::vector<size_t> empty_slots;
for (size_t i = 0; i < param_names.size(); ++i) {
......@@ -599,7 +598,7 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
CHECK_EQ(empty_slots.size(), args.size() - 1)
<< "The number of provided parameters doesn't match the number of arguments";
for (int i = 1; i < args.size(); ++i) {
Object obj = CopyTo(args[i], ctx);
ObjectRef obj = CopyTo(args[i], ctx);
func_args[empty_slots[i - 1]] = obj;
}
......@@ -660,7 +659,7 @@ void VirtualMachine::LoadParams(const std::string& params) {
for (size_t i = 0; i < size; i++) {
NDArray arr;
CHECK(arr.Load(strm)) << "Invalid parameter file";
runtime::Object obj = runtime::Object::Tensor(arr);
ObjectRef obj = Tensor(arr);
auto copy = CopyTo(obj, ctx);
params_.emplace(std::make_pair(names[i], copy));
}
......@@ -682,7 +681,7 @@ Index VirtualMachine::PopFrame() {
return call_stack_size;
}
void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<Object>& args) {
void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<ObjectRef>& args) {
DLOG(INFO) << "Invoking global " << func.name << " " << args.size();
PushFrame(func.params.size(), this->pc + 1, func);
......@@ -695,7 +694,7 @@ void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<Obje
pc = 0;
}
Object VirtualMachine::Invoke(const VMFunction& func, const std::vector<Object>& args) {
ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector<ObjectRef>& args) {
DLOG(INFO) << "Executing Function: " << std::endl << func;
InvokeGlobal(func, args);
......@@ -705,7 +704,7 @@ Object VirtualMachine::Invoke(const VMFunction& func, const std::vector<Object>&
return return_register;
}
Object VirtualMachine::Invoke(const std::string& name, const std::vector<Object>& args) {
ObjectRef VirtualMachine::Invoke(const std::string& name, const std::vector<ObjectRef>& args) {
auto func_index = this->global_map[name];
DLOG(INFO) << "Invoke Global " << name << " at index " << func_index;
return Invoke(this->functions[func_index], args);
......@@ -713,11 +712,11 @@ Object VirtualMachine::Invoke(const std::string& name, const std::vector<Object>
void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
Index arg_count, Index output_size,
const std::vector<Object>& args) {
const std::vector<ObjectRef>& args) {
size_t arity = 0;
for (Index i = 0; i < arg_count; i++) {
if (args[i].ptr_->tag == ObjectTag::kDatatype) {
arity += args[i].AsDatatype()->fields.size();
if (const auto* obj = args[i].as<DatatypeObj>()) {
arity += obj->fields.size();
} else {
++arity;
}
......@@ -728,15 +727,16 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
runtime::TVMArgsSetter setter(values.data(), codes.data());
int idx = 0;
for (Index i = 0; i < arg_count; i++) {
if (args[i].ptr_->tag == ObjectTag::kDatatype) {
auto dt_cell = args[i].AsDatatype();
if (const auto* dt_cell = args[i].as<DatatypeObj>()) {
for (auto obj : dt_cell->fields) {
NDArray data = ToNDArray(obj);
setter(idx++, data);
const auto* tensor = obj.as<TensorObj>();
CHECK(tensor != nullptr);
setter(idx++, tensor->data);
}
} else {
NDArray data = ToNDArray(args[i]);
setter(idx++, data);
const auto* tensor = args[i].as<TensorObj>();
CHECK(tensor != nullptr);
setter(idx++, tensor->data);
}
}
......@@ -761,18 +761,20 @@ void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) {
}
}
inline void VirtualMachine::WriteRegister(Index r, const Object& val) {
inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) {
frames.back().register_file[r] = val;
}
inline Object VirtualMachine::ReadRegister(Index r) const {
inline ObjectRef VirtualMachine::ReadRegister(Index r) const {
return frames.back().register_file[r];
}
inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
int32_t result;
const auto& obj = ReadRegister(r);
NDArray array = ToNDArray(obj).CopyTo({kDLCPU, 0});
const auto* tensor = obj.as<TensorObj>();
CHECK(tensor != nullptr);
NDArray array = tensor->data.CopyTo({kDLCPU, 0});
if (array->dtype.bits <= 8) {
result = reinterpret_cast<int8_t*>(array->data)[0];
......@@ -798,7 +800,7 @@ void VirtualMachine::RunLoop() {
switch (instr.op) {
case Opcode::Move: {
Object from_obj;
ObjectRef from_obj;
from_obj = ReadRegister(instr.from);
WriteRegister(instr.dst, from_obj);
pc++;
......@@ -817,12 +819,12 @@ void VirtualMachine::RunLoop() {
case Opcode::LoadConsti: {
auto tensor = NDArray::Empty({1}, {kDLInt, 64, 1}, {kDLCPU, 0});
reinterpret_cast<int64_t*>(tensor->data)[0] = instr.load_consti.val;
WriteRegister(instr.dst, Object::Tensor(tensor));
WriteRegister(instr.dst, Tensor(tensor));
pc++;
goto main_loop;
}
case Opcode::Invoke: {
std::vector<Object> args;
std::vector<ObjectRef> args;
for (Index i = 0; i < instr.num_args; ++i) {
args.push_back(ReadRegister(instr.invoke_args_registers[i]));
}
......@@ -833,7 +835,7 @@ void VirtualMachine::RunLoop() {
case Opcode::InvokePacked: {
const auto& func = packed_funcs[instr.packed_index];
const auto& arity = instr.arity;
std::vector<Object> args;
std::vector<ObjectRef> args;
for (Index i = 0; i < arity; ++i) {
args.push_back(ReadRegister(instr.packed_args[i]));
}
......@@ -847,8 +849,9 @@ void VirtualMachine::RunLoop() {
}
case Opcode::InvokeClosure: {
auto object = ReadRegister(instr.closure);
const auto& closure = object.AsClosure();
std::vector<Object> args;
const auto* closure = object.as<ClosureObj>();
std::vector<ObjectRef> args;
for (auto free_var : closure->free_vars) {
args.push_back(free_var);
}
......@@ -861,10 +864,10 @@ void VirtualMachine::RunLoop() {
}
case Opcode::GetField: {
auto object = ReadRegister(instr.object);
      CHECK(object->tag == ObjectTag::kDatatype)
          << "Object is not data type object, register " << instr.object << ", Object tag "
          << static_cast<int>(object->tag);
      const auto& tuple = object.AsDatatype();
      const auto* tuple = object.as<DatatypeObj>();
      CHECK(tuple != nullptr)
          << "Object is not data type object, register " << instr.object << ", Object tag "
          << object->type_index();
auto field = tuple->fields[instr.field_index];
WriteRegister(instr.dst, field);
pc++;
......@@ -872,15 +875,15 @@ void VirtualMachine::RunLoop() {
}
case Opcode::GetTag: {
auto object = ReadRegister(instr.get_tag.object);
      CHECK(object->tag == ObjectTag::kDatatype)
          << "Object is not data type object, register "
          << instr.get_tag.object << ", Object tag "
          << static_cast<int>(object->tag);
      const auto& data = object.AsDatatype();
      const auto* data = object.as<DatatypeObj>();
      CHECK(data != nullptr)
          << "Object is not data type object, register "
          << instr.get_tag.object << ", Object tag "
          << object->type_index();
auto tag = data->tag;
auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0});
reinterpret_cast<int32_t*>(tag_tensor->data)[0] = tag;
WriteRegister(instr.dst, Object::Tensor(tag_tensor));
WriteRegister(instr.dst, Tensor(tag_tensor));
pc++;
goto main_loop;
}
......@@ -909,7 +912,7 @@ void VirtualMachine::RunLoop() {
}
auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]);
auto data = allocator->Empty(shape, instr.alloc_tensor.dtype, ctxs[0]);
auto obj = Object::Tensor(data);
auto obj = Tensor(data);
WriteRegister(instr.dst, obj);
pc++;
goto main_loop;
......@@ -920,7 +923,9 @@ void VirtualMachine::RunLoop() {
cpu_ctx.device_id = 0;
auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register);
NDArray shape_tensor = ToNDArray(shape_tensor_obj).CopyTo(cpu_ctx);
const auto* tensor = shape_tensor_obj.as<TensorObj>();
CHECK(tensor != nullptr);
NDArray shape_tensor = tensor->data.CopyTo(cpu_ctx);
int64_t* dims = static_cast<int64_t*>(shape_tensor->data);
auto num_dims = shape_tensor->shape[0];
......@@ -928,27 +933,27 @@ void VirtualMachine::RunLoop() {
shape.assign(dims, dims + num_dims);
auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]);
auto data = allocator->Empty(shape, instr.alloc_tensor_reg.dtype, ctxs[0]);
auto obj = Object::Tensor(data);
auto obj = Tensor(data);
WriteRegister(instr.dst, obj);
pc++;
goto main_loop;
}
case Opcode::AllocDatatype: {
std::vector<Object> fields;
std::vector<ObjectRef> fields;
for (Index i = 0; i < instr.num_fields; ++i) {
fields.push_back(ReadRegister(instr.datatype_fields[i]));
}
Object obj = Object::Datatype(instr.constructor_tag, fields);
ObjectRef obj = Datatype(instr.constructor_tag, fields);
WriteRegister(instr.dst, obj);
pc++;
goto main_loop;
}
case Opcode::AllocClosure: {
std::vector<Object> free_vars;
std::vector<ObjectRef> free_vars;
for (Index i = 0; i < instr.num_freevar; i++) {
free_vars.push_back(ReadRegister(instr.free_vars[i]));
}
WriteRegister(instr.dst, Object::Closure(instr.func_index, free_vars));
WriteRegister(instr.dst, Closure(instr.func_index, free_vars));
pc++;
goto main_loop;
}
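Like Tensor above, Datatype(tag, fields) and Closure(func_index, free_vars) now build ObjectRefs directly. A sketch of one such constructor under the new protocol (DatatypeObj's layout is inferred from the tag and fields accesses in this diff; the body is an assumption, not the shipped code):

// Sketch: building an ADT value under the new object protocol.
inline ObjectRef Datatype(size_t tag, std::vector<ObjectRef> fields) {
  auto obj = make_object<DatatypeObj>();
  obj->tag = tag;                    // constructor tag, read back by GetTag
  obj->fields = std::move(fields);   // element objects, read back by GetField
  return ObjectRef(obj);
}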
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
namespace tvm {
namespace test {
using namespace tvm::runtime;
class ObjBase : public Object {
public:
  // dynamically allocate the type index (slower than a static index)
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const uint32_t _type_child_slots = 1;
static constexpr const char* _type_key = "test.ObjBase";
TVM_DECLARE_BASE_OBJECT_INFO(ObjBase, Object);
};
class ObjA : public ObjBase {
public:
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const uint32_t _type_child_slots = 0;
static constexpr const char* _type_key = "test.ObjA";
TVM_DECLARE_BASE_OBJECT_INFO(ObjA, ObjBase);
};
class ObjB : public ObjBase {
public:
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "test.ObjB";
TVM_DECLARE_BASE_OBJECT_INFO(ObjB, ObjBase);
};
class ObjAA : public ObjA {
public:
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "test.ObjAA";
TVM_DECLARE_FINAL_OBJECT_INFO(ObjAA, ObjA);
};
TVM_REGISTER_OBJECT_TYPE(ObjBase);
TVM_REGISTER_OBJECT_TYPE(ObjA);
TVM_REGISTER_OBJECT_TYPE(ObjB);
TVM_REGISTER_OBJECT_TYPE(ObjAA);
} // namespace test
} // namespace tvm
TEST(ObjectHierarchy, Basic) {
using namespace tvm::runtime;
using namespace tvm::test;
ObjectRef refA(make_object<ObjA>());
CHECK_EQ(refA->type_index(), ObjA::type_index());
CHECK(refA.as<Object>() != nullptr);
CHECK(refA.as<ObjA>() != nullptr);
CHECK(refA.as<ObjBase>() != nullptr);
CHECK(refA.as<ObjB>() == nullptr);
CHECK(refA.as<ObjAA>() == nullptr);
ObjectRef refAA(make_object<ObjAA>());
CHECK_EQ(refAA->type_index(), ObjAA::type_index());
CHECK(refAA.as<Object>() != nullptr);
CHECK(refAA.as<ObjBase>() != nullptr);
CHECK(refAA.as<ObjA>() != nullptr);
CHECK(refAA.as<ObjAA>() != nullptr);
CHECK(refAA.as<ObjB>() == nullptr);
ObjectRef refB(make_object<ObjB>());
CHECK_EQ(refB->type_index(), ObjB::type_index());
CHECK(refB.as<Object>() != nullptr);
CHECK(refB.as<ObjBase>() != nullptr);
CHECK(refB.as<ObjA>() == nullptr);
CHECK(refB.as<ObjAA>() == nullptr);
CHECK(refB.as<ObjB>() != nullptr);
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
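This fixture exercises the child-slot logic called out in the commit message: ObjBase reserves a single child slot but acquires two subclasses (ObjA and ObjB), and ObjA reserves none yet has ObjAA, so the casts above must succeed even when a hierarchy overflows its reserved slots. With contiguous type indices, the reserved-slot fast path is a constant-time range check; a rough, illustrative sketch (names and the exact bound are assumptions, and the real implementation also handles the overflow fallback):

// Sketch: subtype test via index ranges. A type occupies index `begin`
// and reserves `num_child_slots` consecutive indices for its children,
// so membership in [begin, begin + num_child_slots + 1) implies "is-a".
inline bool InSubtypeRange(uint32_t tindex, uint32_t begin,
                           uint32_t num_child_slots) {
  return tindex >= begin && tindex < begin + num_child_slots + 1;
}

Types that land outside the reserved range, like ObjB and ObjAA here, must instead be resolved by walking the parent chain, which is the slower path this test pins down.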