/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file tvm/runtime/object.h * \brief A managed object in the TVM runtime. */ #ifndef TVM_RUNTIME_OBJECT_H_ #define TVM_RUNTIME_OBJECT_H_ #include <type_traits> #include <string> #include <utility> #include "c_runtime_api.h" /*! * \brief Whether or not use atomic reference counter. * If the reference counter is not atomic, * an object cannot be owned by multiple threads. * We can, however, move an object across threads */ #ifndef TVM_OBJECT_ATOMIC_REF_COUNTER #define TVM_OBJECT_ATOMIC_REF_COUNTER 1 #endif #if TVM_OBJECT_ATOMIC_REF_COUNTER #include <atomic> #endif // TVM_OBJECT_ATOMIC_REF_COUNTER namespace tvm { namespace runtime { /*! \brief list of the type index. */ enum TypeIndex { /*! \brief Root object type. */ kRoot = 0, kVMTensor = 1, kVMClosure = 2, kVMDatatype = 3, kStaticIndexEnd, /*! \brief Type index is allocated during runtime. */ kDynamic = kStaticIndexEnd }; /*! * \brief base class of all object containers. * * Sub-class of objects should declare the following static constexpr fields: * * - _type_index: * Static type index of the object, if assigned to TypeIndex::kDynamic * the type index will be assigned during runtime. * Runtime type index can be accessed by ObjectType::type_index(); * - _type_key: * The unique string identifier of tyep type. * - _type_final: * Whether the type is terminal type(there is no subclass of the type in the object system). * This field is automatically set by marco TVM_DECLARE_FINAL_OBJECT_INFO * It is still OK to sub-class a terminal object type T and construct it using make_object. * But IsInstance check will only show that the object type is T(instead of the sub-class). * * The following two fields are necessary for base classes that can be sub-classed. * * - _type_child_slots: * Number of reserved type index slots for child classes. * Used for runtime optimization for type checking in IsInstance. * If an object's type_index is within range of [type_index, type_index + _type_child_slots] * Then the object can be quickly decided as sub-class of the current object class. * If not, a fallback mechanism is used to check the global type table. * Recommendation: set to estimate number of children needed. * - _type_child_slots_can_overflow: * Whether we can add additional child classes even if the number of child classes * exceeds the _type_child_slots. A fallback mechanism to check global type table will be used. * Recommendation: set to false for optimal runtime speed if we know exact number of children. * * Two macros are used to declare helper functions in the object: * - Use TVM_DECLARE_BASE_OBJECT_INFO for object classes that can be sub-classed. * - Use TVM_DECLARE_FINAL_OBJECT_INFO for object classes that cannot be sub-classed. * * New objects can be created using make_object function. * Which will automatically populate the type_index and deleter of the object. * * \sa make_object * \sa ObjectPtr * \sa ObjectRef * * \code * * // Create a base object * class BaseObj : public Object { * public: * // object fields * int field0; * * // object properties * static constexpr const uint32_t _type_index = TypeIndex::kDynamic; * static constexpr const char* _type_key = "test.BaseObj"; * TVM_DECLARE_BASE_OBJECT_INFO(BaseObj, Object); * }; * * class ObjLeaf : public ObjBase { * public: * // fields * int child_field0; * // object properties * static constexpr const uint32_t _type_index = TypeIndex::kDynamic; * static constexpr const char* _type_key = "test.LeafObj"; * TVM_DECLARE_BASE_OBJECT_INFO(LeaffObj, Object); * }; * * // The following code should be put into a cc file. * TVM_REGISTER_OBJECT_TYPE(ObjBase); * TVM_REGISTER_OBJECT_TYPE(ObjLeaf); * * // Usage example. * void TestObjects() { * // create an object * ObjectRef leaf_ref(make_object<LeafObj>()); * // cast to a specific instance * const LeafObj* leaf_ptr = leaf_ref.as<LeafObj>(); * CHECK(leaf_ptr != nullptr); * // can also cast to the base class. * CHECK(leaf_ref.as<BaseObj>() != nullptr); * } * * \endcode */ class Object { public: /*! * \brief Object deleter * \param self pointer to the Object. */ typedef void (*FDeleter)(Object* self); /*! \return The internal type index of the object. */ uint32_t type_index() const { return type_index_; } /*! * Check if the object is an instance of TargetType. * \tparam TargetType The target type to be checked. * \return Whether the target type is true. */ template<typename TargetType> inline bool IsInstance() const; #if TVM_OBJECT_ATOMIC_REF_COUNTER using RefCounterType = std::atomic<int32_t>; #else using RefCounterType = int32_t; #endif // Object type properties static constexpr const char* _type_key = "Object"; static constexpr bool _type_final = false; static constexpr uint32_t _type_child_slots = 0; static constexpr bool _type_child_slots_can_overflow = true; static const uint32_t _GetOrAllocRuntimeTypeIndex() { return 0; } protected: // The fields of the base object cell. /*! \brief Type index(tag) that indicates the type of the object. */ uint32_t type_index_{0}; /*! \brief The internal reference counter */ RefCounterType ref_counter_{0}; /*! * \brief deleter of this object to enable customized allocation. * If the deleter is nullptr, no deletion will be performed. * The creator of the object must always set the deleter field properly. */ FDeleter deleter_ = nullptr; // Invariant checks. static_assert(sizeof(int32_t) == sizeof(RefCounterType) && alignof(int32_t) == sizeof(RefCounterType), "RefCounter ABI check."); /*! * \brief Get the type index using type key. * * When the function is first time called for a type, * it will register the type to the type table in the runtime. * If the static_tindex is TypeIndex::kDynamic, the function will * allocate a runtime type index. * Otherwise, we will populate the type table and return the static index. * * \param key the type key. * \param static_tindex The current _type_index field. * can be TypeIndex::kDynamic. * \param parent_tindex The index of the parent. * \param type_child_slots Number of slots reserved for its children. * \param type_child_slots_can_overflow Whether to allow child to overflow the slots. * \return The allocated type index. */ TVM_DLL static uint32_t GetOrAllocRuntimeTypeIndex( const char* key, uint32_t static_tindex, uint32_t parent_tindex, uint32_t type_child_slots, bool type_child_slots_can_overflow); /*! * \brief Get the type key of the corresponding index from runtime. * \param tindex The type index. */ TVM_DLL static std::string TypeIndex2Key(uint32_t tindex); /*! * \brief Get the type index of the corresponding key from runtime. * \param key The type key. */ TVM_DLL static uint32_t TypeKey2Index(const char* key); private: // reference counter related operations /*! \brief developer function, increases reference counter. */ inline void IncRef(); /*! * \brief developer function, decrease reference counter. * \note The deleter will be called when ref_counter_ becomes zero. */ inline void DecRef(); /*! * \return The usage count of the cell. * \note We use stl style naming to be consistent with known API in shared_ptr. */ inline int use_count() const; /*! * \brief Check of this object is derived from the parent. * \param parent_tindex The parent type index. * \return The derivation results. */ TVM_DLL bool DerivedFrom(uint32_t parent_tindex) const; // friend classes template<typename> friend class ObjAllocatorBase; template<typename> friend class ObjectPtr; friend class TVMRetValue; friend class TVMObjectCAPI; }; /*! * \brief A custom smart pointer for Object. * \tparam T the content data type. * \sa make_object */ template <typename T> class ObjectPtr { public: /*! \brief default constructor */ ObjectPtr() {} /*! \brief default constructor */ ObjectPtr(std::nullptr_t) {} // NOLINT(*) /*! * \brief copy constructor * \param other The value to be moved */ ObjectPtr(const ObjectPtr<T>& other) // NOLINT(*) : ObjectPtr(other.data_) {} /*! * \brief copy constructor * \param other The value to be moved */ template <typename U> ObjectPtr(const ObjectPtr<U>& other) // NOLINT(*) : ObjectPtr(other.data_) { static_assert(std::is_base_of<T, U>::value, "can only assign of child class ObjectPtr to parent"); } /*! * \brief move constructor * \param other The value to be moved */ ObjectPtr(ObjectPtr<T>&& other) // NOLINT(*) : data_(other.data_) { other.data_ = nullptr; } /*! * \brief move constructor * \param other The value to be moved */ template <typename Y> ObjectPtr(ObjectPtr<Y>&& other) // NOLINT(*) : data_(other.data_) { static_assert(std::is_base_of<T, Y>::value, "can only assign of child class ObjectPtr to parent"); other.data_ = nullptr; } /*! \brief destructor */ ~ObjectPtr() { this->reset(); } /*! * \brief Swap this array with another Object * \param other The other Object */ void swap(ObjectPtr<T>& other) { // NOLINT(*) std::swap(data_, other.data_); } /*! * \return Get the content of the pointer */ T* get() const { return static_cast<T*>(data_); } /*! * \return The pointer */ T* operator->() const { return get(); } /*! * \return The reference */ T& operator*() const { // NOLINT(*) return *get(); } /*! * \brief copy assignmemt * \param other The value to be assigned. * \return reference to self. */ ObjectPtr<T>& operator=(const ObjectPtr<T>& other) { // NOLINT(*) // takes in plane operator to enable copy elison. // copy-and-swap idiom ObjectPtr(other).swap(*this); // NOLINT(*) return *this; } /*! * \brief move assignmemt * \param other The value to be assigned. * \return reference to self. */ ObjectPtr<T>& operator=(ObjectPtr<T>&& other) { // NOLINT(*) // copy-and-swap idiom ObjectPtr(std::move(other)).swap(*this); // NOLINT(*) return *this; } /*! \brief reset the content of ptr to be nullptr */ void reset() { if (data_ != nullptr) { data_->DecRef(); data_ = nullptr; } } /*! \return The use count of the ptr, for debug purposes */ int use_count() const { return data_ != nullptr ? data_->use_count() : 0; } /*! \return whether the reference is unique */ bool unique() const { return data_ != nullptr && data_->use_count() == 1; } /*! \return Whether two ObjectPtr do not equal each other */ bool operator==(const ObjectPtr<T>& other) const { return data_ == other.data_; } /*! \return Whether two ObjectPtr equals each other */ bool operator!=(const ObjectPtr<T>& other) const { return data_ != other.data_; } /*! \return Whether the pointer is nullptr */ bool operator==(std::nullptr_t null) const { return data_ == nullptr; } /*! \return Whether the pointer is not nullptr */ bool operator!=(std::nullptr_t null) const { return data_ != nullptr; } private: /*! \brief internal pointer field */ Object* data_{nullptr}; /*! * \brief constructor from NodeBase * \param data The data pointer */ explicit ObjectPtr(Object* data) : data_(data) { if (data != nullptr) { data_->IncRef(); } } // friend classes friend class Object; friend class ObjectRef; template<typename> friend class ObjectPtr; template<typename> friend class ObjAllocatorBase; friend class TVMPODValue_; friend class TVMArgsSetter; friend class TVMRetValue; }; /*! \brief Base class of all object reference */ class ObjectRef { public: /*! \brief default constructor */ ObjectRef() = default; /*! \brief Constructor from existing object ptr */ explicit ObjectRef(ObjectPtr<Object> data) : data_(data) {} /*! \return the internal object pointer */ inline const Object* get() const; /*! \return the internal node pointer */ inline const Object* operator->() const; /*! * \brief Try to downcast the internal Object to a * raw pointer of a corresponding type. * * The function will return a nullptr if the cast failed. * * if (const Add *add = node_ref.As<Add>()) { * // This is an add node * } * \tparam ObjectType the target type, must be a subtype of Object/ */ template <typename ObjectType> inline const ObjectType* as() const; /*! \brief type indicate the container type */ using ContainerType = Object; protected: /*! \brief Internal pointer that backs the reference. */ ObjectPtr<Object> data_; // friend classes. friend class TVMRetValue; friend class TVMArgsSetter; }; /*! * \brief helper macro to declare a base object type that can be inheritated. * \param TypeName The name of the current type. * \param ParentType The name of the ParentType */ #define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ static const uint32_t type_index() { \ if (_type_index != TypeIndex::kDynamic) return _type_index; \ return _GetOrAllocRuntimeTypeIndex(); \ } \ static const uint32_t _GetOrAllocRuntimeTypeIndex() { \ static uint32_t tidx = GetOrAllocRuntimeTypeIndex( \ TypeName::_type_key, \ TypeName::_type_index, \ ParentType::_GetOrAllocRuntimeTypeIndex(), \ TypeName::_type_child_slots, \ TypeName::_type_child_slots_can_overflow); \ return tidx; \ } \ /*! * \brief helper macro to declare type information in a final class. * \param TypeName The name of the current type. * \param ParentType The name of the ParentType */ #define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \ static const constexpr bool _type_final = true; \ static const constexpr int _type_child_slots = 0; \ TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ /*! * \brief Helper macro to register the object type to runtime. * Makes sure that the runtime type table is correctly populated. * * Use this macro in the cc file for each terminal class. */ #define TVM_REGISTER_OBJECT_TYPE(TypeName) \ static DMLC_ATTRIBUTE_UNUSED uint32_t __make_Object_tidx ## _ ## TypeName ## __ = \ TypeName::_GetOrAllocRuntimeTypeIndex() #define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ TypeName() {} \ explicit TypeName( \ ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ : ParentType(n) {} \ const ObjectName* operator->() const { \ return static_cast<const ObjectName*>(data_.get()); \ } \ operator bool() const { return data_ != nullptr; } \ using ContainerType = ObjectName; // Implementations details below // Object reference counting. #if TVM_OBJECT_ATOMIC_REF_COUNTER inline void Object::IncRef() { ref_counter_.fetch_add(1, std::memory_order_relaxed); } inline void Object::DecRef() { if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) { std::atomic_thread_fence(std::memory_order_acquire); if (this->deleter_ != nullptr) { (*this->deleter_)(this); } } } inline int Object::use_count() const { return ref_counter_.load(std::memory_order_relaxed); } #else inline void Object::IncRef() { ++ref_counter_; } inline void Object::DecRef() { if (--ref_counter == 0) { if (this->deleter_ != nullptr) { (*this->deleter_)(this); } } } inline int Object::use_count() const { return ref_counter_; } #endif // TVM_OBJECT_ATOMIC_REF_COUNTER template<typename TargetType> inline bool Object::IsInstance() const { const Object* self = this; // NOTE: the following code can be optimized by // compiler dead-code elimination for already known constants. if (self != nullptr) { // Everything is a subclass of object. if (std::is_same<TargetType, Object>::value) return true; if (TargetType::_type_final) { // if the target type is a final type // then we only need to check the equivalence. return self->type_index_ == TargetType::type_index(); } else { // if target type is a non-leaf type // Check if type index falls into the range of reserved slots. uint32_t begin = TargetType::type_index(); // The condition will be optimized by constant-folding. if (TargetType::_type_child_slots != 0) { uint32_t end = begin + TargetType::_type_child_slots; if (self->type_index_ >= begin && self->type_index_ < end) return true; } else { if (self->type_index_ == begin) return true; } if (!TargetType::_type_child_slots_can_overflow) return false; // Invariance: parent index is always smaller than the child. if (self->type_index_ < TargetType::type_index()) return false; // The rare slower-path, check type hierachy. return self->DerivedFrom(TargetType::type_index()); } } else { return false; } } inline const Object* ObjectRef::get() const { return data_.data_; } inline const Object* ObjectRef::operator->() const { return get(); } template <typename ObjectType> inline const ObjectType* ObjectRef::as() const { if (data_ != nullptr && data_->IsInstance<ObjectType>()) { return static_cast<ObjectType*>(data_.get()); } else { return nullptr; } } } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_OBJECT_H_