Unverified Commit b11f2a04 by Tianqi Chen Committed by GitHub

[ATTRS] change AttrFiledInfo->Node (#1634)

parent d060e919
...@@ -69,15 +69,27 @@ struct AttrError : public dmlc::Error { ...@@ -69,15 +69,27 @@ struct AttrError : public dmlc::Error {
/*! /*!
* \brief Information about attribute fields in string representations. * \brief Information about attribute fields in string representations.
*/ */
struct AttrFieldInfo { class AttrFieldInfoNode : public Node {
public:
/*! \brief name of the field */ /*! \brief name of the field */
std::string name; std::string name;
/*! \brief type docstring information in str. */ /*! \brief type docstring information in str. */
std::string type_info; std::string type_info;
/*! \brief detailed description of the type */ /*! \brief detailed description of the type */
std::string description; std::string description;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("type_info", &type_info);
v->Visit("description", &description);
}
static constexpr const char* _type_key = "AttrFieldInfo";
TVM_DECLARE_NODE_TYPE_INFO(AttrFieldInfoNode, Node);
}; };
/*! \brief AttrFieldInfo */
TVM_DEFINE_NODE_REF(AttrFieldInfo, AttrFieldInfoNode);
/*! /*!
* \brief Base class of all attribute class * \brief Base class of all attribute class
* \note Do not subclass AttrBaseNode directly, * \note Do not subclass AttrBaseNode directly,
...@@ -104,7 +116,7 @@ class BaseAttrsNode : public Node { ...@@ -104,7 +116,7 @@ class BaseAttrsNode : public Node {
* \brief Get the field information about the * \brief Get the field information about the
* \note This function throws when the required a field is not present. * \note This function throws when the required a field is not present.
*/ */
TVM_DLL virtual std::vector<AttrFieldInfo> ListFieldInfo() const = 0; TVM_DLL virtual Array<AttrFieldInfo> ListFieldInfo() const = 0;
/*! /*!
* \brief Initialize the attributes by arguments. * \brief Initialize the attributes by arguments.
* \param kwargs The key value pairs for initialization. * \param kwargs The key value pairs for initialization.
...@@ -159,7 +171,7 @@ class DictAttrsNode : public BaseAttrsNode { ...@@ -159,7 +171,7 @@ class DictAttrsNode : public BaseAttrsNode {
// implementations // implementations
void VisitAttrs(AttrVisitor* v) final; void VisitAttrs(AttrVisitor* v) final;
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
std::vector<AttrFieldInfo> ListFieldInfo() const final; Array<AttrFieldInfo> ListFieldInfo() const final;
// type info // type info
static constexpr const char* _type_key = "DictAttrs"; static constexpr const char* _type_key = "DictAttrs";
TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode); TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode);
...@@ -430,7 +442,7 @@ class AttrDocEntry { ...@@ -430,7 +442,7 @@ class AttrDocEntry {
public: public:
using TSelf = AttrDocEntry; using TSelf = AttrDocEntry;
explicit AttrDocEntry(AttrFieldInfo* info) explicit AttrDocEntry(std::shared_ptr<AttrFieldInfoNode> info)
: info_(info) { : info_(info) {
} }
TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) {
...@@ -454,21 +466,22 @@ class AttrDocEntry { ...@@ -454,21 +466,22 @@ class AttrDocEntry {
} }
private: private:
AttrFieldInfo* info_; std::shared_ptr<AttrFieldInfoNode> info_;
}; };
class AttrDocVisitor { class AttrDocVisitor {
public: public:
template<typename T> template<typename T>
AttrDocEntry operator()(const char* key, T* v) { AttrDocEntry operator()(const char* key, T* v) {
AttrFieldInfo info; std::shared_ptr<AttrFieldInfoNode> info
info.name = key; = std::make_shared<AttrFieldInfoNode>();
info.type_info = TypeName<T>::value; info->name = key;
fields_.emplace_back(std::move(info)); info->type_info = TypeName<T>::value;
return AttrDocEntry(&(fields_.back())); fields_.push_back(AttrFieldInfo(info));
return AttrDocEntry(info);
} }
std::vector<AttrFieldInfo> fields_; Array<AttrFieldInfo> fields_;
}; };
class AttrExistVisitor { class AttrExistVisitor {
...@@ -557,7 +570,7 @@ class AttrsNode : public BaseAttrsNode { ...@@ -557,7 +570,7 @@ class AttrsNode : public BaseAttrsNode {
} }
} }
std::vector<AttrFieldInfo> ListFieldInfo() const final { Array<AttrFieldInfo> ListFieldInfo() const final {
detail::AttrDocVisitor visitor; detail::AttrDocVisitor visitor;
self()->__VisitAttrs__(visitor); self()->__VisitAttrs__(visitor);
return visitor.fields_; return visitor.fields_;
...@@ -580,11 +593,11 @@ inline void BaseAttrsNode::InitBySeq(Args&& ...args) { ...@@ -580,11 +593,11 @@ inline void BaseAttrsNode::InitBySeq(Args&& ...args) {
} }
inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*) inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*)
std::vector<AttrFieldInfo> entry = this->ListFieldInfo(); Array<AttrFieldInfo> entry = this->ListFieldInfo();
for (AttrFieldInfo info : entry) { for (AttrFieldInfo info : entry) {
os << info.name << " : " << info.type_info << '\n'; os << info->name << " : " << info->type_info << '\n';
if (info.description.length() != 0) { if (info->description.length() != 0) {
os << " " << info.description << '\n'; os << " " << info->description << '\n';
} }
} }
} }
......
...@@ -25,7 +25,7 @@ void DictAttrsNode::InitByPackedArgs( ...@@ -25,7 +25,7 @@ void DictAttrsNode::InitByPackedArgs(
} }
} }
std::vector<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const { Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const {
return {}; return {};
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment