/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*
 * \file src/runtime/object.cc
 * \brief Object type management system.
 */
#include <dmlc/logging.h>
#include <tvm/runtime/object.h>
#include <atomic>
#include <mutex>
#include <string>
#include <vector>
#include <utility>
#include <unordered_map>
#include "object_internal.h"
#include "runtime_base.h"

namespace tvm {
namespace runtime {

/*! \brief Type information */
struct TypeInfo {
  /*! \brief The current index. */
  uint32_t index{0};
  /*! \brief Index of the parent in the type hierarchy */
  uint32_t parent_index{0};
  // NOTE: the indices in [index, index + num_slots) are
  // reserved for this type and its child classes.
  /*! \brief Total number of slots reserved for the type and its children. */
  uint32_t num_slots{0};
  /*! \brief Number of allocated slots; 0 marks an unused table entry. */
  uint32_t allocated_slots{0};
  /*! \brief Whether child can overflow. */
  bool child_slots_can_overflow{true};
  /*! \brief name of the type. */
  std::string name;
  /*! \brief hash of the name */
  size_t name_hash{0};
};

/*!
 * \brief Type context that manages the type hierarchy information.
 *
 * All accesses to the type table and the key-to-index map are
 * serialized through mutex_.
 */
class TypeContext {
 public:
  /*!
   * \brief Check whether a type is the same as, or derives from, another type.
   *
   * NOTE: this is a relatively slow path for child checking.
   * Most types are already checked by the fast-path via reserved slot checking.
   *
   * \param child_tindex The type index of the candidate child.
   * \param parent_tindex The type index of the candidate parent.
   * \return true if following parent links from child_tindex reaches parent_tindex.
   */
  bool DerivedFrom(uint32_t child_tindex, uint32_t parent_tindex) {
    // invariance: child's type index is always bigger than its parent.
    if (child_tindex < parent_tindex) return false;
    if (child_tindex == parent_tindex) return true;
    {
      std::lock_guard<std::mutex> lock(mutex_);
      CHECK_LT(child_tindex, type_table_.size());
      // Climb the hierarchy until we reach or pass the candidate parent.
      while (child_tindex > parent_tindex) {
        child_tindex = type_table_[child_tindex].parent_index;
      }
    }
    return child_tindex == parent_tindex;
  }

  /*!
   * \brief Look up the type index registered under a key, allocating one if absent.
   *
   * \param skey The type key.
   * \param static_tindex A statically assigned index, or TypeIndex::kDynamic
   *        to request a dynamically allocated one.
   * \param parent_tindex The type index of the parent type.
   * \param num_child_slots Number of slots to reserve for child classes.
   * \param child_slots_can_overflow Whether children may be allocated
   *        outside of the reserved range once it is exhausted.
   * \return The type index associated with skey.
   */
  uint32_t GetOrAllocRuntimeTypeIndex(const std::string& skey,
                                      uint32_t static_tindex,
                                      uint32_t parent_tindex,
                                      uint32_t num_child_slots,
                                      bool child_slots_can_overflow) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = type_key2index_.find(skey);
    if (it != type_key2index_.end()) {
      return it->second;
    }
    // try to allocate from parent's type table.
    CHECK_LT(parent_tindex, type_table_.size());
    // NOTE: pinfo is invalidated by the resize in the overflow branch below;
    // it must not be used after that point.
    TypeInfo& pinfo = type_table_[parent_tindex];
    CHECK_EQ(pinfo.index, parent_tindex);

    // if parent cannot overflow, then this class cannot.
    if (!pinfo.child_slots_can_overflow) {
      child_slots_can_overflow = false;
    }

    // total number of slots include the type itself.
    uint32_t num_slots = num_child_slots + 1;
    uint32_t allocated_tindex;

    if (static_tindex != TypeIndex::kDynamic) {
      // statically assigned type
      allocated_tindex = static_tindex;
      CHECK_LT(static_tindex, type_table_.size());
      CHECK_EQ(type_table_[allocated_tindex].allocated_slots, 0U)
          << "Conflicting static index " << static_tindex
          << " between " << type_table_[allocated_tindex].name
          << " and "
          << skey;
    } else if (pinfo.allocated_slots + num_slots <= pinfo.num_slots) {
      // allocate the slot from parent's reserved pool.
      // The request fits when it ends exactly at the end of the reserved
      // range, hence the inclusive comparison.
      allocated_tindex = parent_tindex + pinfo.allocated_slots;
      // update parent's state
      pinfo.allocated_slots += num_slots;
    } else {
      CHECK(pinfo.child_slots_can_overflow)
          << "Reach maximum number of sub-classes for " << pinfo.name;
      // allocate new entries.
      allocated_tindex = type_counter_;
      type_counter_ += num_slots;
      CHECK_LE(type_table_.size(), allocated_tindex);
      // Grow the table to cover every slot reserved for this type, so that
      // children later placed in the reserved range index valid entries.
      type_table_.resize(type_counter_, TypeInfo());
    }
    CHECK_GT(allocated_tindex, parent_tindex);
    // initialize the slot.
    TypeInfo& info = type_table_[allocated_tindex];
    info.index = allocated_tindex;
    info.parent_index = parent_tindex;
    info.num_slots = num_slots;
    info.allocated_slots = 1;
    info.child_slots_can_overflow = child_slots_can_overflow;
    info.name = skey;
    info.name_hash = std::hash<std::string>()(skey);
    // update the key2index mapping.
    type_key2index_[skey] = allocated_tindex;
    return allocated_tindex;
  }

  /*!
   * \brief Get the type key registered for a type index.
   * \param tindex The type index; must refer to an allocated entry.
   * \return The type key.
   */
  std::string TypeIndex2Key(uint32_t tindex) {
    std::lock_guard<std::mutex> lock(mutex_);
    CHECK(tindex < type_table_.size() &&
          type_table_[tindex].allocated_slots != 0)
        << "Unknown type index " << tindex;
    return type_table_[tindex].name;
  }

  /*!
   * \brief Get the hash of the type key registered for a type index.
   * \param tindex The type index; must refer to an allocated entry.
   * \return The stored hash of the type key.
   */
  size_t TypeIndex2KeyHash(uint32_t tindex) {
    std::lock_guard<std::mutex> lock(mutex_);
    CHECK(tindex < type_table_.size() &&
          type_table_[tindex].allocated_slots != 0)
        << "Unknown type index " << tindex;
    return type_table_[tindex].name_hash;
  }

  /*!
   * \brief Get the type index registered for a type key.
   * \param skey The type key; must already be registered.
   * \return The type index.
   */
  uint32_t TypeKey2Index(const std::string& skey) {
    // Hold the lock like the other lookups: the map can be mutated
    // concurrently by GetOrAllocRuntimeTypeIndex.
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = type_key2index_.find(skey);
    CHECK(it != type_key2index_.end())
        << "Cannot find type " << skey
        << ". Did you forget to register the node by TVM_REGISTER_NODE_TYPE ?";
    return it->second;
  }

  /*! \return The process-wide TypeContext instance. */
  static TypeContext* Global() {
    static TypeContext inst;
    return &inst;
  }

 private:
  /*! \brief Pre-size the table so that every static index has an entry. */
  TypeContext() {
    type_table_.resize(TypeIndex::kStaticIndexEnd, TypeInfo());
  }
  // mutex to avoid registration from multiple threads.
  std::mutex mutex_;
  // next index handed out by the overflow (dynamic) allocation path.
  std::atomic<uint32_t> type_counter_{TypeIndex::kStaticIndexEnd};
  // type information, indexed by type index.
  std::vector<TypeInfo> type_table_;
  // map from type key to type index.
  std::unordered_map<std::string, uint32_t> type_key2index_;
};

/*! \brief Forward to the global TypeContext. */
uint32_t Object::GetOrAllocRuntimeTypeIndex(const std::string& key,
                                            uint32_t static_tindex,
                                            uint32_t parent_tindex,
                                            uint32_t num_child_slots,
                                            bool child_slots_can_overflow) {
  return TypeContext::Global()->GetOrAllocRuntimeTypeIndex(
      key, static_tindex, parent_tindex, num_child_slots, child_slots_can_overflow);
}

/*! \brief Check whether this object's type derives from parent_tindex. */
bool Object::DerivedFrom(uint32_t parent_tindex) const {
  return TypeContext::Global()->DerivedFrom(
      this->type_index_, parent_tindex);
}

/*! \brief Forward to the global TypeContext. */
std::string Object::TypeIndex2Key(uint32_t tindex) {
  return TypeContext::Global()->TypeIndex2Key(tindex);
}

/*! \brief Forward to the global TypeContext. */
size_t Object::TypeIndex2KeyHash(uint32_t tindex) {
  return TypeContext::Global()->TypeIndex2KeyHash(tindex);
}

/*! \brief Forward to the global TypeContext. */
uint32_t Object::TypeKey2Index(const std::string& key) {
  return TypeContext::Global()->TypeKey2Index(key);
}
}  // namespace runtime
}  // namespace tvm

// NOTE(review): API_BEGIN/API_END come from runtime_base.h; presumably they
// convert exceptions into the integer return code -- confirm there.

/*! \brief Write the type index of obj into out_tindex[0]; obj must not be null. */
int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) {
  API_BEGIN();
  CHECK(obj != nullptr);
  out_tindex[0] = static_cast<tvm::runtime::Object*>(obj)->type_index();
  API_END();
}

/*! \brief Delegate to ObjectInternal::ObjectFree for the given handle. */
int TVMObjectFree(TVMObjectHandle obj) {
  API_BEGIN();
  tvm::runtime::ObjectInternal::ObjectFree(obj);
  API_END();
}

/*! \brief Write the type index registered for type_key into out_tindex[0]. */
int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) {
  API_BEGIN();
  out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(
      type_key);
  API_END();
}