Unverified Commit b11f2a04 by Tianqi Chen Committed by GitHub

[ATTRS] change AttrFiledInfo->Node (#1634)

parent d060e919
......@@ -69,15 +69,27 @@ struct AttrError : public dmlc::Error {
/*!
* \brief Information about attribute fields in string representations.
*/
struct AttrFieldInfo {
class AttrFieldInfoNode : public Node {
public:
/*! \brief name of the field */
std::string name;
/*! \brief type docstring information in str. */
std::string type_info;
/*! \brief detailed description of the type */
std::string description;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("type_info", &type_info);
v->Visit("description", &description);
}
static constexpr const char* _type_key = "AttrFieldInfo";
TVM_DECLARE_NODE_TYPE_INFO(AttrFieldInfoNode, Node);
};
/*! \brief AttrFieldInfo */
TVM_DEFINE_NODE_REF(AttrFieldInfo, AttrFieldInfoNode);
/*!
* \brief Base class of all attribute class
* \note Do not subclass AttrBaseNode directly,
......@@ -104,7 +116,7 @@ class BaseAttrsNode : public Node {
* \brief Get the field information about the
* \note This function throws when the required a field is not present.
*/
TVM_DLL virtual std::vector<AttrFieldInfo> ListFieldInfo() const = 0;
TVM_DLL virtual Array<AttrFieldInfo> ListFieldInfo() const = 0;
/*!
* \brief Initialize the attributes by arguments.
* \param kwargs The key value pairs for initialization.
......@@ -159,7 +171,7 @@ class DictAttrsNode : public BaseAttrsNode {
// implementations
void VisitAttrs(AttrVisitor* v) final;
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
std::vector<AttrFieldInfo> ListFieldInfo() const final;
Array<AttrFieldInfo> ListFieldInfo() const final;
// type info
static constexpr const char* _type_key = "DictAttrs";
TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode);
......@@ -430,7 +442,7 @@ class AttrDocEntry {
public:
using TSelf = AttrDocEntry;
explicit AttrDocEntry(AttrFieldInfo* info)
explicit AttrDocEntry(std::shared_ptr<AttrFieldInfoNode> info)
: info_(info) {
}
TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) {
......@@ -454,21 +466,22 @@ class AttrDocEntry {
}
private:
AttrFieldInfo* info_;
std::shared_ptr<AttrFieldInfoNode> info_;
};
class AttrDocVisitor {
public:
template<typename T>
AttrDocEntry operator()(const char* key, T* v) {
AttrFieldInfo info;
info.name = key;
info.type_info = TypeName<T>::value;
fields_.emplace_back(std::move(info));
return AttrDocEntry(&(fields_.back()));
std::shared_ptr<AttrFieldInfoNode> info
= std::make_shared<AttrFieldInfoNode>();
info->name = key;
info->type_info = TypeName<T>::value;
fields_.push_back(AttrFieldInfo(info));
return AttrDocEntry(info);
}
std::vector<AttrFieldInfo> fields_;
Array<AttrFieldInfo> fields_;
};
class AttrExistVisitor {
......@@ -557,7 +570,7 @@ class AttrsNode : public BaseAttrsNode {
}
}
std::vector<AttrFieldInfo> ListFieldInfo() const final {
Array<AttrFieldInfo> ListFieldInfo() const final {
detail::AttrDocVisitor visitor;
self()->__VisitAttrs__(visitor);
return visitor.fields_;
......@@ -580,11 +593,11 @@ inline void BaseAttrsNode::InitBySeq(Args&& ...args) {
}
inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*)
std::vector<AttrFieldInfo> entry = this->ListFieldInfo();
Array<AttrFieldInfo> entry = this->ListFieldInfo();
for (AttrFieldInfo info : entry) {
os << info.name << " : " << info.type_info << '\n';
if (info.description.length() != 0) {
os << " " << info.description << '\n';
os << info->name << " : " << info->type_info << '\n';
if (info->description.length() != 0) {
os << " " << info->description << '\n';
}
}
}
......
......@@ -25,7 +25,7 @@ void DictAttrsNode::InitByPackedArgs(
}
}
std::vector<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const {
Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const {
return {};
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment