/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file tvm/ir/attrs.h * \brief Helpers for attribute objects. * * This module enables declaration of named attributes * which support default value setup and bound checking. * * \code * struct MyAttrs : public tvm::AttrsNode<MyAttrs> { * float learning_rate; * int num_hidden; * std::string name; * // declare attribute fields in header file * TVM_DECLARE_ATTRS(MyAttrs, "attrs.MyAttrs") { * TVM_ATTR_FIELD(num_hidden).set_lower_bound(1); * TVM_ATTR_FIELD(learning_rate).set_default(0.01f); * TVM_ATTR_FIELD(name).set_default("hello"); * } * }; * // register it in cc file * TVM_REGISTER_NODE_TYPE(MyAttrs); * \endcode * * \sa AttrsNode, TVM_DECLARE_ATTRS, TVM_ATTR_FIELD */ #ifndef TVM_IR_ATTRS_H_ #define TVM_IR_ATTRS_H_ #include <dmlc/common.h> #include <tvm/ir/expr.h> #include <tvm/node/structural_equal.h> #include <tvm/node/structural_hash.h> #include <tvm/runtime/packed_func.h> #include <unordered_map> #include <vector> #include <functional> #include <type_traits> #include <string> #include <utility> namespace tvm { /*! * \brief Declare an attribute function. * \param ClassName The name of the class. * \param TypeKey The type key to be used by the TVM node system. */ #define TVM_DECLARE_ATTRS(ClassName, TypeKey) \ static constexpr const char* _type_key = TypeKey; \ TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \ template<typename FVisit> \ void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*) /*! * \brief Declare an attribute field. * \param FieldName The field name. */ #define TVM_ATTR_FIELD(FieldName) \ __fvisit__(#FieldName, &FieldName) /*! * \brief Create a NodeRef type that represents null. * \tparam TNodeRef the type to be created. * \return A instance that will represent None. */ template<typename TObjectRef> inline TObjectRef NullValue() { static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types"); return TObjectRef(ObjectPtr<Object>(nullptr)); } template<> inline DataType NullValue<DataType>() { return DataType(DataType::kHandle, 0, 0); } /*! \brief Error thrown during attribute checking. */ struct AttrError : public dmlc::Error { /*! * \brief constructor * \param msg error message */ explicit AttrError(const std::string &msg) : dmlc::Error(msg) {} }; /*! * \brief Information about attribute fields in string representations. */ class AttrFieldInfoNode : public Object { public: /*! \brief name of the field */ std::string name; /*! \brief type docstring information in str. */ std::string type_info; /*! \brief detailed description of the type */ std::string description; void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); v->Visit("type_info", &type_info); v->Visit("description", &description); } static constexpr const char* _type_key = "AttrFieldInfo"; static constexpr bool _type_has_method_sequal_reduce = false; static constexpr bool _type_has_method_shash_reduce = false; TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object); }; /*! \brief AttrFieldInfo */ class AttrFieldInfo : public ObjectRef { public: TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode); }; /*! * \brief Base class of all attribute class * \note Do not subclass AttrBaseNode directly, * subclass AttrsNode instead. * \sa AttrsNode */ class BaseAttrsNode : public Object { public: using TVMArgs = runtime::TVMArgs; using TVMRetValue = runtime::TVMRetValue; /*! \brief virtual destructor */ virtual ~BaseAttrsNode() {} // visit function virtual void VisitAttrs(AttrVisitor* v) {} /*! * \brief Initialize the attributes by sequence of arguments * \param args The postional arguments in the form * [key0, value0, key1, value1, ..., key_n, value_n] */ template<typename... Args> inline void InitBySeq(Args&& ...args); /*! * \brief Print readible docstring to ostream, add newline. * \param os the stream to print the docstring to. */ inline void PrintDocString(std::ostream &os) const; // NOLINT(*) /*! * \brief Visit attributes that do not equal the default value. * * \note This is useful to extract fields for concise printing. * \param v The visitor */ TVM_DLL virtual void VisitNonDefaultAttrs(AttrVisitor* v) = 0; /*! * \brief Get the field information * \return The fields in the Attrs. */ TVM_DLL virtual Array<AttrFieldInfo> ListFieldInfo() const = 0; /*! * \brief Initialize the attributes by arguments. * \param kwargs The key value pairs for initialization. * [key0, value0, key1, value1, ..., key_n, value_n] * \param allow_unknown Whether allow additional unknown fields. * \note This function throws when the required field is not present. */ TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; static constexpr const char* _type_key = "Attrs"; TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object); }; /*! * \brief Managed reference to BaseAttrsNode. * \sa AttrsNode, BaseAttrsNode */ class Attrs : public ObjectRef { public: TVM_DEFINE_OBJECT_REF_METHODS(Attrs, ObjectRef, BaseAttrsNode); }; /*! * \brief Specialized attribute type that is backed by a map. * The DictAttrsNode implements the Attrs behavior, * its fields are directly accessible via object.field_name * like other normal nodes. */ class DictAttrsNode : public BaseAttrsNode { public: /*! \brief internal attrs map */ Map<std::string, ObjectRef> dict; bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const { return equal(dict, other->dict); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dict); } // implementations void VisitAttrs(AttrVisitor* v) final; void VisitNonDefaultAttrs(AttrVisitor* v) final; void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; Array<AttrFieldInfo> ListFieldInfo() const final; // type info static constexpr const char* _type_key = "DictAttrs"; TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode); }; /*! * \brief Managed reference to DictAttrsNode * \sa DictAttrsNode. */ class DictAttrs : public Attrs { public: /*! * \brief Consruct a Attrs backed by DictAttrsNode. * \param dict The attributes. * \return The dict attributes. */ TVM_DLL explicit DictAttrs(Map<std::string, ObjectRef> dict); TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode); }; // Namespace containing detail implementations namespace detail { using runtime::TVMArgValue; // helper entry that does nothing in set_default/bound/describe calls. struct AttrNopEntry { using TSelf = AttrNopEntry; TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; } template<typename T> TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) { return *this; } template<typename T> TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { return *this; } template<typename T> TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { return *this; } }; // Wrapper for normal visitor. class AttrNormalVisitor { public: explicit AttrNormalVisitor(AttrVisitor* visitor) : visitor_(visitor) { } template<typename T> AttrNopEntry operator()(const char* key, T* value) { visitor_->Visit(key, value); return AttrNopEntry(); } private: AttrVisitor* visitor_; }; class AttrsSEqualVisitor { public: bool result_{true}; // constructor AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const SEqualReducer& equal) : lhs_(lhs), rhs_(rhs), equal_(equal) { } template<typename T> AttrNopEntry operator()(const char* key, T* lhs_value) { if (!result_) return AttrNopEntry(); const T* rhs_value = reinterpret_cast<const T*>( reinterpret_cast<const char*>(rhs_) + (reinterpret_cast<const char*>(lhs_value) - reinterpret_cast<const char*>(lhs_))); if (!equal_(*lhs_value, *rhs_value)) { result_ = false; } return AttrNopEntry(); } private: const Object* lhs_; const Object* rhs_; const SEqualReducer& equal_; }; class AttrsSHashVisitor { public: explicit AttrsSHashVisitor(const SHashReducer& hash_reducer) : hash_reducer_(hash_reducer) {} template<typename T> AttrNopEntry operator()(const char* key, T* value) { hash_reducer_(*value); return AttrNopEntry(); } private: const SHashReducer& hash_reducer_; }; // helper entry that does initialization, set default. template<typename T> struct AttrInitEntry { // The attributes using TSelf = AttrInitEntry<T>; // The type key const char* type_key_; // field name const char* key_; // internal value. T* value_; // whether the value is missing. bool value_missing_{true}; // If the value is still missing in destruction time throw an error. ~AttrInitEntry() DMLC_THROW_EXCEPTION { if (value_missing_) { std::ostringstream os; os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization"; throw AttrError(os.str()); } } // override fields. // This function sets the lower bound of the attribute TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { if (this->value_missing_) return *this; const T& val = *value_; if (begin > val) { std::ostringstream os; os << type_key_ << "." << key_ << ": " << "value " << val << " is smaller than the lower bound " << begin; throw AttrError(os.str()); } return *this; } // This function sets the upper bound of the attribute TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { if (this->value_missing_) return *this; const T& val = *value_; if (val > end) { std::ostringstream os; os << type_key_ << "." << key_ << ": " << "value " << val << " is bigger than the upper bound " << end; throw AttrError(os.str()); } return *this; } // set default when TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) { if (!value_missing_) return *this; *value_ = value; value_missing_ = false; return *this; } TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; } }; // Template function to allow smart conversion // from Expr types into the constants. template<typename T> inline void SetValue(T* ptr, const TVMArgValue& val) { *ptr = val.operator T(); } template<typename T> inline void SetIntValue(T* ptr, const TVMArgValue& val) { if (val.type_code() == kDLInt) { *ptr = static_cast<T>(val.value().v_int64); } else { IntImm expr = val; *ptr = static_cast<T>(expr->value); } } template<> inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) { if (val.type_code() == kTVMStr) { *ptr = val.operator std::string(); } else { LOG(FATAL) << "Expect str"; } } template<> inline void SetValue<double>(double* ptr, const TVMArgValue& val) { if (val.type_code() == kDLFloat || val.type_code() == kDLInt) { *ptr = val.operator double(); } else { ObjectRef expr = val; CHECK(expr.defined()); if (const IntImmNode* op = expr.as<IntImmNode>()) { *ptr = static_cast<double>(op->value); } else if (const FloatImmNode* op = expr.as<FloatImmNode>()) { *ptr = static_cast<double>(op->value); } else { LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey(); } } } template<> inline void SetValue<int>(int* ptr, const TVMArgValue& val) { SetIntValue(ptr, val); } template<> inline void SetValue<int64_t>(int64_t* ptr, const TVMArgValue& val) { SetIntValue(ptr, val); } template<> inline void SetValue<uint64_t>(uint64_t* ptr, const TVMArgValue& val) { SetIntValue(ptr, val); } template<> inline void SetValue<bool>(bool* ptr, const TVMArgValue& val) { SetIntValue(ptr, val); } // Visitor for value initialization template<typename FFind> class AttrInitVisitor { public: // Counter of number of matched attributes during visit. // This is used to decide if there is additional unmatched attributes. size_t hit_count_{0}; // constructor AttrInitVisitor(const char* type_key, FFind ffind) : type_key_(type_key), ffind_(ffind) { } template<typename T> AttrInitEntry<T> operator()(const char* key, T* value) { TVMArgValue val; AttrInitEntry<T> opt; opt.type_key_ = type_key_; opt.key_ = key; opt.value_ = value; if (ffind_(key, &val)) { SetValue(value, val); opt.value_missing_ = false; ++hit_count_; } else { opt.value_missing_ = true; } return opt; } private: // the type key const char* type_key_; FFind ffind_; }; template<typename FFind> inline AttrInitVisitor<FFind> CreateInitVisitor( const char* type_key, FFind ffind) { return AttrInitVisitor<FFind>(type_key, ffind); } /*! * \brief Helper struct to get the type name known to tvm. * \tparam T the type we are interested in. */ template<typename T> struct TypeName { static constexpr const char* value = T::ContainerType::_type_key; }; template<> struct TypeName<int> { static constexpr const char* value = "int"; }; template<> struct TypeName<int64_t> { static constexpr const char* value = "int64"; }; template<> struct TypeName<uint64_t> { static constexpr const char* value = "uint64_t"; }; template<> struct TypeName<DataType> { static constexpr const char* value = "DataType"; }; template<> struct TypeName<std::string> { static constexpr const char* value = "str"; }; template<> struct TypeName<bool> { static constexpr const char* value = "bool"; }; template<> struct TypeName<void*> { static constexpr const char* value = "handle"; }; template<> struct TypeName<double> { static constexpr const char* value = "double"; }; class AttrDocEntry { public: using TSelf = AttrDocEntry; explicit AttrDocEntry(ObjectPtr<AttrFieldInfoNode> info) : info_(info) { } TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { info_->description = str; return *this; } template<typename T> TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) { std::ostringstream os; os << info_->type_info << ", default=" << value; info_->type_info = os.str(); return *this; } template<typename T> TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) { return *this; } template<typename T> TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) { return *this; } private: ObjectPtr<AttrFieldInfoNode> info_; }; class AttrDocVisitor { public: template<typename T> AttrDocEntry operator()(const char* key, T* v) { ObjectPtr<AttrFieldInfoNode> info = make_object<AttrFieldInfoNode>(); info->name = key; info->type_info = TypeName<T>::value; fields_.push_back(AttrFieldInfo(info)); return AttrDocEntry(info); } Array<AttrFieldInfo> fields_; }; class AttrExistVisitor { public: std::string key_; bool exist_{false}; template<typename T> AttrNopEntry operator()(const char* key, T* v) { if (exist_) return AttrNopEntry(); if (key == key_) exist_ = true; return AttrNopEntry(); } }; template<typename T> struct AttrTriggerNonDefaultEntry { using TSelf = AttrTriggerNonDefaultEntry<T>; // constructor AttrTriggerNonDefaultEntry( AttrVisitor* visitor, const char* key, T* data) : visitor_(visitor), key_(key), data_(data) {} ~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION { if (trigger_) { visitor_->Visit(key_, data_); } } TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; } TSelf& set_default(const T& value) { if (tvm::StructuralEqual()(value, *data_)) { trigger_ = false; } return *this; } TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { return *this; } TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { return *this; } private: AttrVisitor* visitor_; const char * key_; T *data_; bool trigger_{true}; }; class AttrNonDefaultVisitor { public: explicit AttrNonDefaultVisitor(AttrVisitor* visitor) : visitor_(visitor) { } template<typename T> AttrTriggerNonDefaultEntry<T> operator()(const char* key, T* value) { return AttrTriggerNonDefaultEntry<T>(visitor_, key, value); } private: AttrVisitor* visitor_; }; } // namespace detail /*! * \brief The base class of the all the * Use "curiously recurring template pattern". * * \tparam DerivedType The final attribute type. */ template<typename DerivedType> class AttrsNode : public BaseAttrsNode { public: void VisitAttrs(AttrVisitor* v) { ::tvm::detail::AttrNormalVisitor vis(v); self()->__VisitAttrs__(vis); } void VisitNonDefaultAttrs(AttrVisitor* v) { ::tvm::detail::AttrNonDefaultVisitor vis(v); self()->__VisitAttrs__(vis); } void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final { CHECK_EQ(args.size() % 2, 0); const int kLinearSearchBound = 16; int hit_count = 0; // applies two stratgies to lookup if (args.size() < kLinearSearchBound) { // linear search. auto ffind = [&args](const char* key, runtime::TVMArgValue* val) { for (int i = 0; i < args.size(); i += 2) { CHECK_EQ(args.type_codes[i], kTVMStr); if (!std::strcmp(key, args.values[i].v_str)) { *val = args[i + 1]; return true; } } return false; }; auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind); self()->__VisitAttrs__(vis); hit_count = vis.hit_count_; } else { // construct a map then do lookup. std::unordered_map<std::string, runtime::TVMArgValue> kwargs; for (int i = 0; i < args.size(); i += 2) { CHECK_EQ(args.type_codes[i], kTVMStr); kwargs[args[i].operator std::string()] = args[i + 1]; } auto ffind = [&kwargs](const char *key, runtime::TVMArgValue* val) { auto it = kwargs.find(key); if (it != kwargs.end()) { *val = it->second; return true; } return false; }; auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind); self()->__VisitAttrs__(vis); hit_count = vis.hit_count_; } // error handling, slow path if (hit_count * 2 != args.size() && !allow_unknown) { for (int i = 0; i < args.size(); i += 2) { ::tvm::detail::AttrExistVisitor visitor; visitor.key_ = args[i].operator std::string(); self()->__VisitAttrs__(visitor); if (!visitor.exist_) { std::ostringstream os; os << DerivedType::_type_key << ": does not have field \'" << visitor.key_ << "\', Possible fields:\n"; os << "----------------\n"; this->PrintDocString(os); throw AttrError(os.str()); } } } } bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const { DerivedType* pself = self(); ::tvm::detail::AttrsSEqualVisitor visitor(pself, other, equal); self()->__VisitAttrs__(visitor); return visitor.result_; } void SHashReduce(SHashReducer hash_reducer) const { ::tvm::detail::AttrsSHashVisitor visitor(hash_reducer); self()->__VisitAttrs__(visitor); } Array<AttrFieldInfo> ListFieldInfo() const final { ::tvm::detail::AttrDocVisitor visitor; self()->__VisitAttrs__(visitor); return visitor.fields_; } private: DerivedType* self() const { return const_cast<DerivedType*>( static_cast<const DerivedType*>(this)); } }; template<typename... Args> inline void BaseAttrsNode::InitBySeq(Args&& ...args) { runtime::PackedFunc pf([this](const TVMArgs& args, TVMRetValue *rv) { this->InitByPackedArgs(args); }); pf(std::forward<Args>(args)...); } inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*) Array<AttrFieldInfo> entry = this->ListFieldInfo(); for (AttrFieldInfo info : entry) { os << info->name << " : " << info->type_info << '\n'; if (info->description.length() != 0) { os << " " << info->description << '\n'; } } } } // namespace tvm #endif // TVM_IR_ATTRS_H_