object.h 28.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * \file tvm/runtime/object.h
 * \brief A managed object in the TVM runtime.
 */
#ifndef TVM_RUNTIME_OBJECT_H_
#define TVM_RUNTIME_OBJECT_H_

#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>

#include <functional>
#include <string>
#include <type_traits>
#include <utility>

/*!
 * \brief Whether to use an atomic reference counter.
 *  Defaults to 1 (atomic) unless the build already defines it.
 *  With a non-atomic counter an object must not be owned by multiple
 *  threads at once, although it can still be moved from one thread to another.
 */
#if !defined(TVM_OBJECT_ATOMIC_REF_COUNTER)
#define TVM_OBJECT_ATOMIC_REF_COUNTER 1
#endif  // !defined(TVM_OBJECT_ATOMIC_REF_COUNTER)

// std::atomic is only needed when the atomic counter is selected.
#if TVM_OBJECT_ATOMIC_REF_COUNTER
#include <atomic>
#endif  // TVM_OBJECT_ATOMIC_REF_COUNTER

namespace tvm {
namespace runtime {

/*! \brief List of the statically assigned type indices. */
enum TypeIndex {
  /*! \brief Root object type. */
  kRoot = 0,
  // Indices reserved at compile time for built-in runtime object types.
  kClosure = 1,
  kVMADT = 2,
  kRuntimeModule = 3,
  /*! \brief Marks the end of the statically assigned index range. */
  kStaticIndexEnd,
  /*! \brief Type index is allocated during runtime. */
  kDynamic = kStaticIndexEnd
};

61 62 63 64 65 66 67 68
/*!
 * \brief base class of all object containers.
 *
 * Sub-class of objects should declare the following static constexpr fields:
 *
 * - _type_index:
 *      Static type index of the object, if assigned to TypeIndex::kDynamic
 *      the type index will be assigned during runtime.
69
 *      Runtime type index can be accessed by ObjectType::TypeIndex();
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
 * - _type_key:
 *       The unique string identifier of tyep type.
 * - _type_final:
 *       Whether the type is terminal type(there is no subclass of the type in the object system).
 *       This field is automatically set by marco TVM_DECLARE_FINAL_OBJECT_INFO
 *       It is still OK to sub-class a terminal object type T and construct it using make_object.
 *       But IsInstance check will only show that the object type is T(instead of the sub-class).
 *
 * The following two fields are necessary for base classes that can be sub-classed.
 *
 * - _type_child_slots:
 *       Number of reserved type index slots for child classes.
 *       Used for runtime optimization for type checking in IsInstance.
 *       If an object's type_index is within range of [type_index, type_index + _type_child_slots]
 *       Then the object can be quickly decided as sub-class of the current object class.
 *       If not, a fallback mechanism is used to check the global type table.
 *       Recommendation: set to estimate number of children needed.
 * - _type_child_slots_can_overflow:
 *       Whether we can add additional child classes even if the number of child classes
 *       exceeds the _type_child_slots. A fallback mechanism to check global type table will be used.
 *       Recommendation: set to false for optimal runtime speed if we know exact number of children.
 *
 * Two macros are used to declare helper functions in the object:
 * - Use TVM_DECLARE_BASE_OBJECT_INFO for object classes that can be sub-classed.
 * - Use TVM_DECLARE_FINAL_OBJECT_INFO for object classes that cannot be sub-classed.
 *
 * New objects can be created using make_object function.
 * Which will automatically populate the type_index and deleter of the object.
 *
 * \sa make_object
 * \sa ObjectPtr
 * \sa ObjectRef
 *
 * \code
 *
 *  // Create a base object
 *  class BaseObj : public Object {
 *   public:
 *    // object fields
 *    int field0;
 *
 *    // object properties
 *    static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
 *    static constexpr const char* _type_key = "test.BaseObj";
 *    TVM_DECLARE_BASE_OBJECT_INFO(BaseObj, Object);
 *  };
 *
 *  class ObjLeaf : public ObjBase {
 *   public:
 *    // fields
 *    int child_field0;
 *    // object properties
 *    static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
 *    static constexpr const char* _type_key = "test.LeafObj";
 *    TVM_DECLARE_BASE_OBJECT_INFO(LeaffObj, Object);
 *  };
 *
 *  // The following code should be put into a cc file.
 *  TVM_REGISTER_OBJECT_TYPE(ObjBase);
 *  TVM_REGISTER_OBJECT_TYPE(ObjLeaf);
 *
 *  // Usage example.
 *  void TestObjects() {
 *    // create an object
 *    ObjectRef leaf_ref(make_object<LeafObj>());
 *    // cast to a specific instance
 *    const LeafObj* leaf_ptr = leaf_ref.as<LeafObj>();
 *    CHECK(leaf_ptr != nullptr);
 *    // can also cast to the base class.
 *    CHECK(leaf_ref.as<BaseObj>() != nullptr);
 *  }
 *
 * \endcode
 */
class Object {
145 146
 public:
  /*!
147 148
   * \brief Object deleter
   * \param self pointer to the Object.
149
   */
150
  typedef void (*FDeleter)(Object* self);
151
  /*! \return The internal runtime type index of the object. */
152 153 154 155
  uint32_t type_index() const {
    return type_index_;
  }
  /*!
156 157 158 159 160 161 162 163 164 165 166 167 168
   * \return the type key of the object.
   * \note this operation is expensive, can be used for error reporting.
   */
  std::string GetTypeKey() const {
    return TypeIndex2Key(type_index_);
  }
  /*!
   * \return A hash value of the return of GetTypeKey.
   */
  size_t GetTypeKeyHash() const {
    return TypeIndex2KeyHash(type_index_);
  }
  /*!
169 170 171
   * Check if the object is an instance of TargetType.
   * \tparam TargetType The target type to be checked.
   * \return Whether the target type is true.
172
   */
173 174 175
  template<typename TargetType>
  inline bool IsInstance() const;

176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192
  /*!
   * \brief Get the type key of the corresponding index from runtime.
   * \param tindex The type index.
   * \return the result.
   */
  TVM_DLL static std::string TypeIndex2Key(uint32_t tindex);
  /*!
   * \brief Get the type key hash of the corresponding index from runtime.
   * \param tindex The type index.
   * \return the related key-hash.
   */
  TVM_DLL static size_t TypeIndex2KeyHash(uint32_t tindex);
  /*!
   * \brief Get the type index of the corresponding key from runtime.
   * \param key The type key.
   * \return the result.
   */
193
  TVM_DLL static uint32_t TypeKey2Index(const std::string& key);
194

195 196 197 198 199 200 201
#if TVM_OBJECT_ATOMIC_REF_COUNTER
  using RefCounterType = std::atomic<int32_t>;
#else
  using RefCounterType = int32_t;
#endif

  static constexpr const char* _type_key = "Object";
202

203
  static uint32_t _GetOrAllocRuntimeTypeIndex() {
204
    return TypeIndex::kRoot;
205
  }
206
  static uint32_t RuntimeTypeIndex() {
207
    return TypeIndex::kRoot;
208 209
  }

210 211 212 213
  // Default object type properties for sub-classes
  static constexpr bool _type_final = false;
  static constexpr uint32_t _type_child_slots = 0;
  static constexpr bool _type_child_slots_can_overflow = true;
214 215 216
  // member information
  static constexpr bool _type_has_method_visit_attrs = true;
  static constexpr bool _type_has_method_sequal_reduce = false;
217
  static constexpr bool _type_has_method_shash_reduce = false;
218 219 220 221 222
  // NOTE: the following field is not type index of Object
  // but was intended to be used by sub-classes as default value.
  // The type index of Object is TypeIndex::kRoot
  static constexpr uint32_t _type_index = TypeIndex::kDynamic;

223

224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240
  // Default constructor and copy constructor
  Object() {}
  // Override the copy and assign constructors to do nothing.
  // This is to make sure only contents, but not deleter and ref_counter
  // are copied when a child class copies itself.
  // This will enable us to use make_object<ObjectClass>(*obj_ptr)
  // to copy an existing object.
  Object(const Object& other) {  // NOLINT(*)
  }
  Object(Object&& other) {  // NOLINT(*)
  }
  Object& operator=(const Object& other) {  //NOLINT(*)
    return *this;
  }
  Object& operator=(Object&& other) {  //NOLINT(*)
    return *this;
  }
241

242 243 244 245 246 247
 protected:
  // The fields of the base object cell.
  /*! \brief Type index(tag) that indicates the type of the object. */
  uint32_t type_index_{0};
  /*! \brief The internal reference counter */
  RefCounterType ref_counter_{0};
248
  /*!
249 250 251
   * \brief deleter of this object to enable customized allocation.
   * If the deleter is nullptr, no deletion will be performed.
   * The creator of the object must always set the deleter field properly.
252
   */
253 254 255 256 257
  FDeleter deleter_ = nullptr;
  // Invariant checks.
  static_assert(sizeof(int32_t) == sizeof(RefCounterType) &&
                alignof(int32_t) == sizeof(RefCounterType),
                "RefCounter ABI check.");
258 259

  /*!
260 261 262 263 264 265 266 267 268 269 270 271 272 273 274
   * \brief Get the type index using type key.
   *
   *  When the function is first time called for a type,
   *  it will register the type to the type table in the runtime.
   *  If the static_tindex is TypeIndex::kDynamic, the function will
   *  allocate a runtime type index.
   *  Otherwise, we will populate the type table and return the static index.
   *
   * \param key the type key.
   * \param static_tindex The current _type_index field.
   *                      can be TypeIndex::kDynamic.
   * \param parent_tindex The index of the parent.
   * \param type_child_slots Number of slots reserved for its children.
   * \param type_child_slots_can_overflow Whether to allow child to overflow the slots.
   * \return The allocated type index.
275
   */
276
  TVM_DLL static uint32_t GetOrAllocRuntimeTypeIndex(
277
      const std::string& key,
278 279 280 281
      uint32_t static_tindex,
      uint32_t parent_tindex,
      uint32_t type_child_slots,
      bool type_child_slots_can_overflow);
282

283 284 285
  // reference counter related operations
  /*! \brief developer function, increases reference counter. */
  inline void IncRef();
286
  /*!
287 288
   * \brief developer function, decrease reference counter.
   * \note The deleter will be called when ref_counter_ becomes zero.
289
   */
290
  inline void DecRef();
291 292

 private:
293 294 295 296 297 298 299 300 301 302 303 304 305 306 307
  /*!
   * \return The usage count of the cell.
   * \note We use stl style naming to be consistent with known API in shared_ptr.
   */
  inline int use_count() const;
  /*!
   * \brief Check of this object is derived from the parent.
   * \param parent_tindex The parent type index.
   * \return The derivation results.
   */
  TVM_DLL bool DerivedFrom(uint32_t parent_tindex) const;
  // friend classes
  template<typename>
  friend class ObjAllocatorBase;
  template<typename>
308
  friend class ObjectPtr;
309
  friend class TVMRetValue;
310
  friend class ObjectInternal;
311 312 313
};

/*!
 * \brief Get a reference type from a raw object ptr type
 *
 *  It is always important to get a reference type
 *  if we want to return a value as reference or keep
 *  the object alive beyond the scope of the function.
 *
 * \param ptr The object pointer
 * \tparam RelayRefType The reference type
 * \tparam ObjectType The object type
 * \return The corresponding RelayRefType
 */
template <typename RelayRefType, typename ObjectType>
inline RelayRefType GetRef(const ObjectType* ptr);

/*!
 * \brief Downcast a base reference type to a more specific type.
 *
 * \param ref The input reference
 * \return The corresponding SubRef.
 * \tparam SubRef The target specific reference type.
 * \tparam BaseRef the current reference type.
 * \note The check fails if ref is undefined or does not hold an
 *       instance of SubRef::ContainerType.
 */
template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref);

/*!
340 341
 * \brief A custom smart pointer for Object.
 * \tparam T the content data type.
342
 * \sa make_object
343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386
 */
template <typename T>
class ObjectPtr {
 public:
  /*! \brief default constructor */
  ObjectPtr() {}
  /*! \brief default constructor */
  ObjectPtr(std::nullptr_t) {}  // NOLINT(*)
  /*!
   * \brief copy constructor
   * \param other The value to be moved
   */
  ObjectPtr(const ObjectPtr<T>& other)  // NOLINT(*)
      : ObjectPtr(other.data_) {}
  /*!
   * \brief copy constructor
   * \param other The value to be moved
   */
  template <typename U>
  ObjectPtr(const ObjectPtr<U>& other)  // NOLINT(*)
      : ObjectPtr(other.data_) {
    static_assert(std::is_base_of<T, U>::value,
                  "can only assign of child class ObjectPtr to parent");
  }
  /*!
   * \brief move constructor
   * \param other The value to be moved
   */
  ObjectPtr(ObjectPtr<T>&& other)  // NOLINT(*)
      : data_(other.data_) {
    other.data_ = nullptr;
  }
  /*!
   * \brief move constructor
   * \param other The value to be moved
   */
  template <typename Y>
  ObjectPtr(ObjectPtr<Y>&& other)  // NOLINT(*)
      : data_(other.data_) {
    static_assert(std::is_base_of<T, Y>::value,
                  "can only assign of child class ObjectPtr to parent");
    other.data_ = nullptr;
  }
  /*! \brief destructor */
387 388 389
  ~ObjectPtr() {
    this->reset();
  }
390 391 392 393 394 395 396 397 398 399
  /*!
   * \brief Swap this array with another Object
   * \param other The other Object
   */
  void swap(ObjectPtr<T>& other) {  // NOLINT(*)
    std::swap(data_, other.data_);
  }
  /*!
   * \return Get the content of the pointer
   */
400 401 402
  T* get() const {
    return static_cast<T*>(data_);
  }
403 404 405
  /*!
   * \return The pointer
   */
406 407 408
  T* operator->() const {
    return get();
  }
409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443
  /*!
   * \return The reference
   */
  T& operator*() const {  // NOLINT(*)
    return *get();
  }
  /*!
   * \brief copy assignmemt
   * \param other The value to be assigned.
   * \return reference to self.
   */
  ObjectPtr<T>& operator=(const ObjectPtr<T>& other) {  // NOLINT(*)
    // takes in plane operator to enable copy elison.
    // copy-and-swap idiom
    ObjectPtr(other).swap(*this);  // NOLINT(*)
    return *this;
  }
  /*!
   * \brief move assignmemt
   * \param other The value to be assigned.
   * \return reference to self.
   */
  ObjectPtr<T>& operator=(ObjectPtr<T>&& other) {  // NOLINT(*)
    // copy-and-swap idiom
    ObjectPtr(std::move(other)).swap(*this);  // NOLINT(*)
    return *this;
  }
  /*! \brief reset the content of ptr to be nullptr */
  void reset() {
    if (data_ != nullptr) {
      data_->DecRef();
      data_ = nullptr;
    }
  }
  /*! \return The use count of the ptr, for debug purposes */
444 445 446
  int use_count() const {
    return data_ != nullptr ? data_->use_count() : 0;
  }
447
  /*! \return whether the reference is unique */
448 449 450
  bool unique() const {
    return data_ != nullptr && data_->use_count() == 1;
  }
451
  /*! \return Whether two ObjectPtr do not equal each other */
452 453 454
  bool operator==(const ObjectPtr<T>& other) const {
    return data_ == other.data_;
  }
455
  /*! \return Whether two ObjectPtr equals each other */
456 457 458
  bool operator!=(const ObjectPtr<T>& other) const {
    return data_ != other.data_;
  }
459
  /*! \return Whether the pointer is nullptr */
460 461 462
  bool operator==(std::nullptr_t null) const {
    return data_ == nullptr;
  }
463
  /*! \return Whether the pointer is not nullptr */
464 465
  bool operator!=(std::nullptr_t null) const {
    return data_ != nullptr;
466 467 468 469
  }

 private:
  /*! \brief internal pointer field */
470
  Object* data_{nullptr};
471
  /*!
472
   * \brief constructor from Object
473
   * \param data The data pointer
474
   */
475
  explicit ObjectPtr(Object* data) : data_(data) {
476 477 478 479
    if (data != nullptr) {
      data_->IncRef();
    }
  }
480 481 482
  // friend classes
  friend class Object;
  friend class ObjectRef;
483
  friend struct ObjectHash;
484
  template<typename>
485
  friend class ObjectPtr;
486 487
  template<typename>
  friend class ObjAllocatorBase;
488
  friend class TVMPODValue_;
489
  friend class TVMArgsSetter;
490
  friend class TVMRetValue;
491
  friend class TVMArgValue;
492 493
  template <typename RelayRefType, typename ObjType>
  friend RelayRefType GetRef(const ObjType* ptr);
494 495
  template <typename BaseType, typename ObjType>
  friend ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr);
496 497
};

498 499 500 501 502 503 504
/*! \brief Base class of all object reference */
class ObjectRef {
 public:
  /*! \brief default constructor */
  ObjectRef() = default;
  /*! \brief Constructor from existing object ptr */
  explicit ObjectRef(ObjectPtr<Object> data) : data_(data) {}
505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522
  /*!
   * \brief Comparator
   * \param other Another object ref.
   * \return the compare result.
   */
  bool same_as(const ObjectRef& other) const {
    return data_ == other.data_;
  }
  /*!
   * \brief Comparator
   * \param other Another object ref.
   * \return the compare result.
   */
  bool operator==(const ObjectRef& other) const {
    return data_ == other.data_;
  }
  /*!
   * \brief Comparator
523
   * \param other Another object ref.
524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540
   * \return the compare result.
   */
  bool operator!=(const ObjectRef& other) const {
    return data_ != other.data_;
  }
  /*!
   * \brief Comparator
   * \param other Another object ref by address.
   * \return the compare result.
   */
  bool operator<(const ObjectRef& other) const {
    return data_.get() < other.data_.get();
  }
  /*! \return whether the expression is null */
  bool defined() const {
    return data_ != nullptr;
  }
541
  /*! \return the internal object pointer */
542 543 544
  const Object* get() const {
    return data_.get();
  }
545
  /*! \return the internal object pointer */
546 547 548 549 550 551 552
  const Object* operator->() const {
    return get();
  }
  /*! \return whether the reference is unique */
  bool unique() const {
    return data_.unique();
  }
553 554 555 556 557 558 559 560 561 562 563 564 565 566
  /*!
   * \brief Try to downcast the internal Object to a
   *  raw pointer of a corresponding type.
   *
   *  The function will return a nullptr if the cast failed.
   *
   * if (const Add *add = node_ref.As<Add>()) {
   *   // This is an add node
   * }
   * \tparam ObjectType the target type, must be a subtype of Object/
   */
  template <typename ObjectType>
  inline const ObjectType* as() const;

567
  /*! \brief type indicate the container type. */
568 569 570 571 572
  using ContainerType = Object;

 protected:
  /*! \brief Internal pointer that backs the reference. */
  ObjectPtr<Object> data_;
573 574 575 576 577 578 579 580 581 582 583 584 585 586 587
  /*! \return return a mutable internal ptr, can be used by sub-classes. */
  Object* get_mutable() const {
    return data_.get();
  }
  /*!
   * \brief Internal helper function downcast a ref without check.
   * \note Only used for internal dev purposes.
   * \tparam T The target reference type.
   * \return The casted result.
   */
  template<typename T>
  static T DowncastNoCheck(ObjectRef ref) {
    return T(std::move(ref.data_));
  }
  /*!
588 589 590 591 592 593 594 595
   * \brief Clear the object ref data field without DecRef
   *        after we successfully moved the field.
   * \param ref The reference data.
   */
  static void FFIClearAfterMove(ObjectRef* ref) {
    ref->data_.data_ = nullptr;
  }
  /*!
596 597 598 599 600 601 602 603 604
   * \brief Internal helper function get data_ as ObjectPtr of ObjectType.
   * \note only used for internal dev purpose.
   * \tparam ObjectType The corresponding object type.
   * \return the corresponding type.
   */
  template<typename ObjectType>
  static ObjectPtr<ObjectType> GetDataPtr(const ObjectRef& ref) {
    return ObjectPtr<ObjectType>(ref.data_.data_);
  }
605
  // friend classes.
606
  friend struct ObjectHash;
607 608
  friend class TVMRetValue;
  friend class TVMArgsSetter;
609 610
  template <typename SubRef, typename BaseRef>
  friend SubRef Downcast(BaseRef ref);
611
};
612

613 614 615 616 617 618 619 620 621 622
/*!
 * \brief Get an object ptr type from a raw object ptr.
 *
 * \param ptr The object pointer
 * \tparam BaseType The reference type
 * \tparam ObjectType The object type
 * \return The corresponding RefType
 */
template <typename BaseType, typename ObjectType>
inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr);
623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649

/*! \brief ObjectRef hash functor */
struct ObjectHash {
  size_t operator()(const ObjectRef& a) const {
    return operator()(a.data_);
  }

  template<typename T>
  size_t operator()(const ObjectPtr<T>& a) const {
    return std::hash<Object*>()(a.get());
  }
};


/*! \brief ObjectRef equal functor */
struct ObjectEqual {
  bool operator()(const ObjectRef& a, const ObjectRef& b) const {
    return a.same_as(b);
  }

  template<typename T>
  size_t operator()(const ObjectPtr<T>& a, const ObjectPtr<T>& b) const {
    return a == b;
  }
};


650
/*!
 * \brief helper macro to declare a base object type that can be inherited.
 *
 *  Defines RuntimeTypeIndex(), which returns the static _type_index when one
 *  is assigned, and _GetOrAllocRuntimeTypeIndex(), which registers the type
 *  (and, recursively, its parent) once and caches the result in a
 *  function-local static.
 *
 * \param TypeName The name of the current type.
 * \param ParentType The name of the ParentType
 */
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)              \
  static_assert(!ParentType::_type_final, "ParentObj marked as final"); \
  static uint32_t RuntimeTypeIndex()  {                                 \
    if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \
      return TypeName::_type_index;                                     \
    }                                                                   \
    return _GetOrAllocRuntimeTypeIndex();                               \
  }                                                                     \
  static uint32_t _GetOrAllocRuntimeTypeIndex()  {                      \
    static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex(          \
        TypeName::_type_key,                                            \
        TypeName::_type_index,                                          \
        ParentType::_GetOrAllocRuntimeTypeIndex(),                      \
        TypeName::_type_child_slots,                                    \
        TypeName::_type_child_slots_can_overflow);                      \
    return tidx;                                                        \
  }

/*!
 * \brief helper macro to declare type information in a final class.
 *
 *  Marks the type as final and reserves no child slots; the slot count uses
 *  the same uint32_t type as Object::_type_child_slots.
 * \param TypeName The name of the current type.
 * \param ParentType The name of the ParentType
 */
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)             \
  static const constexpr bool _type_final = true;                       \
  static const constexpr uint32_t _type_child_slots = 0;                \
  TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)


/*! \brief helper macro to suppress unused-variable warnings */
#if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define TVM_ATTRIBUTE_UNUSED
#endif

// Two-level token pasting: the outer macro macro-expands both arguments
// (e.g. __COUNTER__) before the inner macro pastes them with ##.
#define TVM_STR_CONCAT_(lhs, rhs) lhs##rhs
#define TVM_STR_CONCAT(lhs, rhs) TVM_STR_CONCAT_(lhs, rhs)

// Declaration prefix for the per-registration static variable; a unique
// suffix is pasted onto the variable name by TVM_STR_CONCAT.
#define TVM_OBJECT_REG_VAR_DEF                              \
  static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid

/*!
 * \brief Helper macro to register the object type to runtime.
 *
 *  Expands to a uniquely named, file-local static variable whose initializer
 *  calls TypeName::_GetOrAllocRuntimeTypeIndex(), so the runtime type table is
 *  populated during static initialization.
 *
 *  Use this macro in the cc file for each terminal class.
 */
#define TVM_REGISTER_OBJECT_TYPE(TypeName)                                  \
  TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) =                     \
      TypeName::_GetOrAllocRuntimeTypeIndex()

/*!
 * \brief Define object reference methods.
 *
 *  Adds a default constructor, an explicit constructor from
 *  ObjectPtr<Object>, a const operator-> returning ObjectName, and the
 *  ContainerType alias. The pointer argument is moved into ParentType so
 *  no extra reference-count increment/decrement pair is performed.
 *
 * \param TypeName The object type name
 * \param ParentType The parent type of the objectref
 * \param ObjectName The type name of the object.
 */
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
  TypeName() {}                                                         \
  explicit TypeName(                                                    \
      ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n)              \
      : ParentType(::std::move(n)) {}                                   \
  const ObjectName* operator->() const {                                \
    return static_cast<const ObjectName*>(data_.get());                 \
  }                                                                     \
  using ContainerType = ObjectName;

/*!
 * \brief Define object reference methods of whose content is mutable.
 *
 *  Same as TVM_DEFINE_OBJECT_REF_METHODS except operator-> returns a
 *  non-const ObjectName*. The pointer argument is moved into ParentType.
 *
 * \param TypeName The object type name
 * \param ParentType The parent type of the objectref
 * \param ObjectName The type name of the object.
 * \note We recommend making objects immutable when possible.
 *       This macro is only reserved for objects that stores runtime states.
 */
#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
  TypeName() {}                                                         \
  explicit TypeName(                                                    \
      ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n)              \
      : ParentType(::std::move(n)) {}                                   \
  ObjectName* operator->() const {                                      \
    return static_cast<ObjectName*>(data_.get());                       \
  }                                                                     \
  using ContainerType = ObjectName;

/*!
 * \brief Define CopyOnWrite function in an ObjectRef.
 * \param ObjectName The Type of the Node.
 *
 *  CopyOnWrite returns a mutable raw pointer to the referenced node.
 *  If the node is shared (use count != 1), it is first replaced by a fresh
 *  copy made with make_object<ObjectName>, so other references keep seeing
 *  the old contents. The reference must be defined (checked).
 *
 * \code
 *
 *  MyCOWObjectRef ref, ref2;
 *  ref2 = ref;
 *  ref.CopyOnWrite()->value = new_value;
 *  assert(ref2->value == old_value);
 *  assert(ref->value == new_value);
 *
 * \endcode
 */
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)                    \
  ObjectName* CopyOnWrite() {                                           \
    CHECK(data_ != nullptr);                                            \
    if (!data_.unique()) {                                              \
      data_ = ObjectPtr<Object>(                                        \
          make_object<ObjectName>(*(operator->())));                    \
    }                                                                   \
    return static_cast<ObjectName*>(data_.get());                       \
  }

770 771 772
// Implementations details below
// Object reference counting.
#if TVM_OBJECT_ATOMIC_REF_COUNTER
773

774 775 776
inline void Object::IncRef() {
  ref_counter_.fetch_add(1, std::memory_order_relaxed);
}
777

778 779 780 781 782 783 784 785
inline void Object::DecRef() {
  if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
    std::atomic_thread_fence(std::memory_order_acquire);
    if (this->deleter_ != nullptr) {
      (*this->deleter_)(this);
    }
  }
}
786

787 788 789 790 791 792 793 794
inline int Object::use_count() const {
  return ref_counter_.load(std::memory_order_relaxed);
}

#else

inline void Object::IncRef() {
  ++ref_counter_;
795 796
}

797
inline void Object::DecRef() {
798
  if (--ref_counter_ == 0) {
799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821
    if (this->deleter_ != nullptr) {
      (*this->deleter_)(this);
    }
  }
}

inline int Object::use_count() const {
  return ref_counter_;
}

#endif  // TVM_OBJECT_ATOMIC_REF_COUNTER

/*!
 * \brief Implementation of Object::IsInstance.
 *
 *  Final targets need an exact index match. Non-final targets first try the
 *  reserved index range, which is inclusive:
 *  [RuntimeTypeIndex(), RuntimeTypeIndex() + _type_child_slots].
 *  If that fails and the target allows slot overflow, the check falls back
 *  to DerivedFrom, which walks the runtime type table.
 */
template<typename TargetType>
inline bool Object::IsInstance() const {
  const Object* self = this;
  if (self == nullptr) return false;
  // NOTE: the branches below test compile-time constants and are
  // folded away by the compiler.
  // Everything is a subclass of object.
  if (std::is_same<TargetType, Object>::value) return true;
  if (TargetType::_type_final) {
    // A final type has no sub-classes, so only equality needs checking.
    return self->type_index_ == TargetType::RuntimeTypeIndex();
  }
  // Non-leaf target: check whether the index falls into the reserved slots.
  uint32_t begin = TargetType::RuntimeTypeIndex();
  if (TargetType::_type_child_slots != 0) {
    // The last reserved child slot is begin + _type_child_slots (inclusive),
    // matching the documented range; using '<' here would miss that child.
    uint32_t end = begin + TargetType::_type_child_slots;
    if (self->type_index_ >= begin && self->type_index_ <= end) return true;
  } else {
    if (self->type_index_ == begin) return true;
  }
  if (!TargetType::_type_child_slots_can_overflow) return false;
  // Invariance: parent index is always smaller than the child.
  if (self->type_index_ < begin) return false;
  // The rare slower-path, check type hierarchy.
  return self->DerivedFrom(begin);
}


template <typename ObjectType>
inline const ObjectType* ObjectRef::as() const {
  // An undefined reference or a type mismatch both yield nullptr.
  if (data_ == nullptr || !data_->IsInstance<ObjectType>()) {
    return nullptr;
  }
  return static_cast<ObjectType*>(data_.get());
}

856 857 858
template <typename RelayRefType, typename ObjType>
inline RelayRefType GetRef(const ObjType* ptr) {
  static_assert(std::is_base_of<typename RelayRefType::ContainerType, ObjType>::value,
859
                "Can only cast to the ref of same container type");
860
  return RelayRefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
861 862
}

863 864 865 866 867 868 869
template <typename BaseType, typename ObjType>
inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) {
  static_assert(std::is_base_of<BaseType, ObjType>::value,
                "Can only cast to the ref of same container type");
  return ObjectPtr<BaseType>(static_cast<Object*>(ptr));
}

/*!
 * \brief Implementation of Downcast: checked conversion of ref to SubRef.
 *
 *  Fails a CHECK when ref is undefined or does not hold an instance of
 *  SubRef::ContainerType; otherwise ref's pointer is moved into the result.
 */
template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref) {
  // Reject undefined references up front. Otherwise IsInstance would be
  // invoked through a null Object*, which is undefined behavior, and the
  // failure message's GetTypeKey() would read type_index_ through that
  // null pointer.
  CHECK(ref.defined())
      << "Downcast from nullptr to "
      << SubRef::ContainerType::_type_key << " failed.";
  CHECK(ref->template IsInstance<typename SubRef::ContainerType>())
      << "Downcast from " << ref->GetTypeKey() << " to "
      << SubRef::ContainerType::_type_key << " failed.";
  return SubRef(std::move(ref.data_));
}
}  // namespace runtime
}  // namespace tvm
880

881
#endif  // TVM_RUNTIME_OBJECT_H_