/*! * Copyright (c) 2018 by Contributors * \file tvm/attrs.h * \brief TVM attribute module * * This module enables declaration of named attributes * which support default value setup and bound checking. * * \code * struct MyAttrs : public tvm::AttrsNode<MyAttrs> { * float learning_rate; * int num_hidden; * std::string name; * // declare attribute fields in header file * TVM_DECLARE_ATTRS(MyAttrs, "attrs.MyAttrs") { * TVM_ATTR_FIELD(num_hidden).set_lower_bound(1); * TVM_ATTR_FIELD(learning_rate).set_default(0.01f); * TVM_ATTR_FIELD(name).set_default("hello"); * } * }; * // register it in cc file * TVM_REGISTER_NODE_TYPE(MyAttrs); * \endcode * * \sa AttrsNode, TVM_DECLARE_ATTRS, TVM_ATTR_FIELD */ #ifndef TVM_ATTRS_H_ #define TVM_ATTRS_H_ #include <unordered_map> #include <vector> #include <type_traits> #include <string> #include "ir.h" #include "base.h" #include "packed_func_ext.h" namespace tvm { /*! * \brief Declare an attribute function. * \param ClassName The name of the class. * \param TypeKey The type key to be used by the TVM node system. */ #define TVM_DECLARE_ATTRS(ClassName, TypeKey) \ static constexpr const char* _type_key = TypeKey; \ TVM_DECLARE_NODE_TYPE_INFO(ClassName, ::tvm::BaseAttrsNode); \ template<typename FVisit> \ void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*) /*! * \brief Declare an attribute field. * \param FieldName The field name. */ #define TVM_ATTR_FIELD(FieldName) \ __fvisit__(#FieldName, &FieldName) /*! * \brief Create a NodeRef type that represents null. * \tparam TNodeRef the type to be created. * \return A instance that will represent None. */ template<typename TNodeRef> inline TNodeRef NullValue() { return TNodeRef(NodePtr<Node>(nullptr)); } template<> inline Type NullValue<Type>() { return Type(Type::Handle, 0, 0); } /*! \brief Error thrown during attribute checking. */ struct AttrError : public dmlc::Error { /*! * \brief constructor * \param msg error message */ explicit AttrError(const std::string &msg) : dmlc::Error(msg) {} }; /*! * \brief Information about attribute fields in string representations. */ class AttrFieldInfoNode : public Node { public: /*! \brief name of the field */ std::string name; /*! \brief type docstring information in str. */ std::string type_info; /*! \brief detailed description of the type */ std::string description; void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); v->Visit("type_info", &type_info); v->Visit("description", &description); } static constexpr const char* _type_key = "AttrFieldInfo"; TVM_DECLARE_NODE_TYPE_INFO(AttrFieldInfoNode, Node); }; /*! \brief AttrFieldInfo */ TVM_DEFINE_NODE_REF(AttrFieldInfo, AttrFieldInfoNode); /*! * \brief Base class of all attribute class * \note Do not subclass AttrBaseNode directly, * subclass AttrsNode instead. * \sa AttrsNode */ class BaseAttrsNode : public Node { public: using TVMArgs = runtime::TVMArgs; using TVMRetValue = runtime::TVMRetValue; /*! * \brief Initialize the attributes by sequence of arguments * \param args The postional arguments in the form * [key0, value0, key1, value1, ..., key_n, value_n] */ template<typename... Args> inline void InitBySeq(Args&& ...args); /*! * \brief Print readible docstring to ostream, add newline. * \param os the stream to print the docstring to. */ inline void PrintDocString(std::ostream &os) const; // NOLINT(*) /*! * \brief Get the field information about the * \note This function throws when the required a field is not present. */ TVM_DLL virtual Array<AttrFieldInfo> ListFieldInfo() const = 0; /*! * \brief Initialize the attributes by arguments. * \param kwargs The key value pairs for initialization. * [key0, value0, key1, value1, ..., key_n, value_n] * \param allow_unknown Whether allow additional unknown fields. * \note This function throws when the required a field is not present. */ TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0; static constexpr const char* _type_key = "Attrs"; TVM_DECLARE_BASE_NODE_INFO(BaseAttrsNode, Node); }; /*! \brief Base attribute container for all attributes */ class Attrs : public NodeRef { public: // normal constructor Attrs() {} // construct from shared ptr. explicit Attrs(NodePtr<Node> n) : NodeRef(n) {} /*! \return The attribute node */ const BaseAttrsNode* operator->() const { return ptr(); } /*! \brief specify container node */ using ContainerType = BaseAttrsNode; private: /*! \return the internal attribute node */ const BaseAttrsNode* ptr() const { return static_cast<const BaseAttrsNode*>(node_.get()); } }; /*! * \brief Specialized attribute type that is backed by a map. * The DictAttrsNode implements the Attrs behavior, * its fields are directly accessible via object.field_name * like other normal nodes. */ class DictAttrsNode : public BaseAttrsNode { public: /*! \brief internal attrs map */ Map<std::string, NodeRef> dict; /*! * \brief Consruct a Attrs backed by DictAttrsNode. * \param dict The attributes. * \return The dict attributes. */ TVM_DLL static Attrs make(Map<std::string, NodeRef> dict); // implementations void VisitAttrs(AttrVisitor* v) final; void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; Array<AttrFieldInfo> ListFieldInfo() const final; // type info static constexpr const char* _type_key = "DictAttrs"; TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode); }; // Namespace containing detail implementations namespace detail { using runtime::TVMArgValue; // helper entry that does nothing in set_default/bound/describe calls. struct AttrNopEntry { using TSelf = AttrNopEntry; TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; } template<typename T> TSelf& set_default(DMLC_ATTRIBUTE_UNUSED T value) { return *this; } template<typename T> TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) { return *this; } template<typename T> TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) { return *this; } }; // Wrapper for normal visitor. class AttrNormalVisitor { public: explicit AttrNormalVisitor(AttrVisitor* visitor) : visitor_(visitor) { } template<typename T> AttrNopEntry operator()(const char* key, T* value) { visitor_->Visit(key, value); return AttrNopEntry(); } private: AttrVisitor* visitor_; }; // helper entry that does initialization, set default. template<typename T> struct AttrInitEntry { // The attributes using TSelf = AttrInitEntry<T>; // The type key const char* type_key_; // field name const char* key_; // internal value. T* value_; // whether the value is missing. bool value_missing_{true}; // If the value is still missing in destruction time throw an error. ~AttrInitEntry() DMLC_THROW_EXCEPTION { if (value_missing_) { std::ostringstream os; os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization"; throw AttrError(os.str()); } } // override fields. // This function sets the lower bound of the attribute TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { if (this->value_missing_) return *this; const T& val = *value_; if (begin > val) { std::ostringstream os; os << type_key_ << "." << key_ << ": " << "value " << val << " is smaller than the lower bound " << begin; throw AttrError(os.str()); } return *this; } // This function sets the upper bound of the attribute TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { if (this->value_missing_) return *this; const T& val = *value_; if (val > end) { std::ostringstream os; os << type_key_ << "." << key_ << ": " << "value " << val << " is bigger than the upper bound " << end; throw AttrError(os.str()); } return *this; } // set default when TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) { if (!value_missing_) return *this; *value_ = value; value_missing_ = false; return *this; } TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; } }; // Template function to allow smart conversion // from Expr types into the constants. template<typename T> inline void SetValue(T* ptr, const TVMArgValue& val) { *ptr = val.operator T(); } template<typename T> inline void SetIntValue(T* ptr, const TVMArgValue& val) { if (val.type_code() == kDLInt) { *ptr = static_cast<T>(val.value().v_int64); } else { Expr expr = val; CHECK(expr.defined()); if (const ir::IntImm* op = expr.as<ir::IntImm>()) { *ptr = static_cast<T>(op->value); } else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) { *ptr = static_cast<T>(op->value); } else { LOG(FATAL) << "Expect int value, but get " << expr->type_key(); } } } template<> inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) { if (val.type_code() == kStr) { *ptr = val.operator std::string(); } else { Expr expr = val; const ir::StringImm* op = expr.as<ir::StringImm>(); CHECK(op != nullptr); *ptr = op->value; } } template<> inline void SetValue(Type* ptr, const TVMArgValue& val) { *ptr = val.operator Type(); } template<> inline void SetValue<double>(double* ptr, const TVMArgValue& val) { if (val.type_code() == kDLFloat || val.type_code() == kDLInt) { *ptr = val.operator double(); } else { Expr expr = val; CHECK(expr.defined()); if (const ir::IntImm* op = expr.as<ir::IntImm>()) { *ptr = static_cast<double>(op->value); } else if (const ir::IntImm* op = expr.as<ir::IntImm>()) { *ptr = static_cast<double>(op->value); } else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) { *ptr = static_cast<double>(op->value); } else { LOG(FATAL) << "Expect float value, but get " << expr->type_key(); } } } template<> inline void SetValue<int>(int* ptr, const TVMArgValue& val) { SetIntValue(ptr, val); } template<> inline void SetValue<int64_t>(int64_t* ptr, const TVMArgValue& val) { SetIntValue(ptr, val); } template<> inline void SetValue<uint64_t>(uint64_t* ptr, const TVMArgValue& val) { SetIntValue(ptr, val); } template<> inline void SetValue<bool>(bool* ptr, const TVMArgValue& val) { SetIntValue(ptr, val); } // Visitor for value initialization template<typename FFind> class AttrInitVisitor { public: // Counter of number of matched attributes during visit. // This is used to decide if there is additional unmatched attributes. size_t hit_count_{0}; // constructor AttrInitVisitor(const char* type_key, FFind ffind) : type_key_(type_key), ffind_(ffind) { } template<typename T> AttrInitEntry<T> operator()(const char* key, T* value) { TVMArgValue val; AttrInitEntry<T> opt; opt.type_key_ = type_key_; opt.key_ = key; opt.value_ = value; if (ffind_(key, &val)) { SetValue(value, val); opt.value_missing_ = false; ++hit_count_; } else { opt.value_missing_ = true; } return opt; } private: // the type key const char* type_key_; FFind ffind_; }; template<typename FFind> inline AttrInitVisitor<FFind> CreateInitVisitor( const char* type_key, FFind ffind) { return AttrInitVisitor<FFind>(type_key, ffind); } /*! * \brief Helper struct to get the type name known to tvm. * \tparam T the type we are interested in. */ template<typename T> struct TypeName { static constexpr const char* value = T::ContainerType::_type_key; }; template<> struct TypeName<int> { static constexpr const char* value = "int"; }; template<> struct TypeName<int64_t> { static constexpr const char* value = "int64"; }; template<> struct TypeName<uint64_t> { static constexpr const char* value = "uint64_t"; }; template<> struct TypeName<Type> { static constexpr const char* value = "Type"; }; template<> struct TypeName<std::string> { static constexpr const char* value = "str"; }; template<> struct TypeName<bool> { static constexpr const char* value = "bool"; }; template<> struct TypeName<void*> { static constexpr const char* value = "handle"; }; template<> struct TypeName<double> { static constexpr const char* value = "double"; }; class AttrDocEntry { public: using TSelf = AttrDocEntry; explicit AttrDocEntry(NodePtr<AttrFieldInfoNode> info) : info_(info) { } TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { info_->description = str; return *this; } template<typename T> TSelf& set_default(DMLC_ATTRIBUTE_UNUSED T value) { std::ostringstream os; os << info_->type_info << ", default=" << value; info_->type_info = os.str(); return *this; } template<typename T> TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) { return *this; } template<typename T> TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) { return *this; } private: NodePtr<AttrFieldInfoNode> info_; }; class AttrDocVisitor { public: template<typename T> AttrDocEntry operator()(const char* key, T* v) { NodePtr<AttrFieldInfoNode> info = make_node<AttrFieldInfoNode>(); info->name = key; info->type_info = TypeName<T>::value; fields_.push_back(AttrFieldInfo(info)); return AttrDocEntry(info); } Array<AttrFieldInfo> fields_; }; class AttrExistVisitor { public: std::string key_; bool exist_{false}; template<typename T> AttrNopEntry operator()(const char* key, T* v) { if (exist_) return AttrNopEntry(); if (key == key_) exist_ = true; return AttrNopEntry(); } }; } // namespace detail /*! * \brief The base class of the all the * Use "curiously recurring template pattern". * * \tparam DerivedType The final attribute type. */ template<typename DerivedType> class AttrsNode : public BaseAttrsNode { public: void VisitAttrs(AttrVisitor* v) final { detail::AttrNormalVisitor vis(v); self()->__VisitAttrs__(vis); } void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final { CHECK_EQ(args.size() % 2, 0); const int kLinearSearchBound = 16; int hit_count = 0; // applies two stratgies to lookup if (args.size() < kLinearSearchBound) { // linear search. auto ffind = [&args](const char* key, runtime::TVMArgValue* val) { for (int i = 0; i < args.size(); i += 2) { CHECK_EQ(args.type_codes[i], kStr); if (!std::strcmp(key, args.values[i].v_str)) { *val = args[i + 1]; return true; } } return false; }; auto vis = detail::CreateInitVisitor(DerivedType::_type_key, ffind); self()->__VisitAttrs__(vis); hit_count = vis.hit_count_; } else { // construct a map then do lookup. std::unordered_map<std::string, runtime::TVMArgValue> kwargs; for (int i = 0; i < args.size(); i += 2) { CHECK_EQ(args.type_codes[i], kStr); kwargs[args[i].operator std::string()] = args[i + 1]; } auto ffind = [&kwargs](const char *key, runtime::TVMArgValue* val) { auto it = kwargs.find(key); if (it != kwargs.end()) { *val = it->second; return true; } return false; }; auto vis = detail::CreateInitVisitor(DerivedType::_type_key, ffind); self()->__VisitAttrs__(vis); hit_count = vis.hit_count_; } // error handling, slow path if (hit_count * 2 != args.size() && !allow_unknown) { for (int i = 0; i < args.size(); i += 2) { detail::AttrExistVisitor visitor; visitor.key_ = args[i].operator std::string(); self()->__VisitAttrs__(visitor); if (!visitor.exist_) { std::ostringstream os; os << DerivedType::_type_key << ": does not have field \'" << visitor.key_ << "\', Possible fields:\n"; os << "----------------\n"; this->PrintDocString(os); throw AttrError(os.str()); } } } } Array<AttrFieldInfo> ListFieldInfo() const final { detail::AttrDocVisitor visitor; self()->__VisitAttrs__(visitor); return visitor.fields_; } private: DerivedType* self() const { return const_cast<DerivedType*>( static_cast<const DerivedType*>(this)); } }; template<typename... Args> inline void BaseAttrsNode::InitBySeq(Args&& ...args) { runtime::PackedFunc pf([this](const TVMArgs& args, TVMRetValue *rv) { this->InitByPackedArgs(args); }); pf(std::forward<Args>(args)...); } inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*) Array<AttrFieldInfo> entry = this->ListFieldInfo(); for (AttrFieldInfo info : entry) { os << info->name << " : " << info->type_info << '\n'; if (info->description.length() != 0) { os << " " << info->description << '\n'; } } } } // namespace tvm #endif // TVM_ATTRS_H_