diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index aed6b1ff722f853e3215690e11f5f5b30e932fbf..3e5169ba02b880f6d803075a570b7839e5b620ba 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -69,15 +69,27 @@ struct AttrError : public dmlc::Error { /*! * \brief Information about attribute fields in string representations. */ -struct AttrFieldInfo { +class AttrFieldInfoNode : public Node { + public: /*! \brief name of the field */ std::string name; /*! \brief type docstring information in str. */ std::string type_info; /*! \brief detailed description of the type */ std::string description; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("name", &name); + v->Visit("type_info", &type_info); + v->Visit("description", &description); + } + static constexpr const char* _type_key = "AttrFieldInfo"; + TVM_DECLARE_NODE_TYPE_INFO(AttrFieldInfoNode, Node); }; +/*! \brief AttrFieldInfo */ +TVM_DEFINE_NODE_REF(AttrFieldInfo, AttrFieldInfoNode); + /*! * \brief Base class of all attribute class * \note Do not subclass AttrBaseNode directly, @@ -104,7 +116,7 @@ class BaseAttrsNode : public Node { * \brief Get the field information about the * \note This function throws when the required a field is not present. */ - TVM_DLL virtual std::vector<AttrFieldInfo> ListFieldInfo() const = 0; + TVM_DLL virtual Array<AttrFieldInfo> ListFieldInfo() const = 0; /*! * \brief Initialize the attributes by arguments. * \param kwargs The key value pairs for initialization. @@ -159,7 +171,7 @@ class DictAttrsNode : public BaseAttrsNode { // implementations void VisitAttrs(AttrVisitor* v) final; void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; - std::vector<AttrFieldInfo> ListFieldInfo() const final; + Array<AttrFieldInfo> ListFieldInfo() const final; // type info static constexpr const char* _type_key = "DictAttrs"; TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode); @@ -430,7 +442,7 @@ class AttrDocEntry { public: using TSelf = AttrDocEntry; - explicit AttrDocEntry(AttrFieldInfo* info) + explicit AttrDocEntry(std::shared_ptr<AttrFieldInfoNode> info) : info_(info) { } TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { @@ -454,21 +466,22 @@ class AttrDocEntry { } private: - AttrFieldInfo* info_; + std::shared_ptr<AttrFieldInfoNode> info_; }; class AttrDocVisitor { public: template<typename T> AttrDocEntry operator()(const char* key, T* v) { - AttrFieldInfo info; - info.name = key; - info.type_info = TypeName<T>::value; - fields_.emplace_back(std::move(info)); - return AttrDocEntry(&(fields_.back())); + std::shared_ptr<AttrFieldInfoNode> info + = std::make_shared<AttrFieldInfoNode>(); + info->name = key; + info->type_info = TypeName<T>::value; + fields_.push_back(AttrFieldInfo(info)); + return AttrDocEntry(info); } - std::vector<AttrFieldInfo> fields_; + Array<AttrFieldInfo> fields_; }; class AttrExistVisitor { @@ -557,7 +570,7 @@ class AttrsNode : public BaseAttrsNode { } } - std::vector<AttrFieldInfo> ListFieldInfo() const final { + Array<AttrFieldInfo> ListFieldInfo() const final { detail::AttrDocVisitor visitor; self()->__VisitAttrs__(visitor); return visitor.fields_; @@ -580,11 +593,11 @@ inline void BaseAttrsNode::InitBySeq(Args&& ...args) { } inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*) - std::vector<AttrFieldInfo> entry = this->ListFieldInfo(); + Array<AttrFieldInfo> entry = this->ListFieldInfo(); for (AttrFieldInfo info : entry) { - os << info.name << " : " << info.type_info << '\n'; - if (info.description.length() != 0) { - os << " " << info.description << '\n'; + os << info->name << " : " << info->type_info << '\n'; + if (info->description.length() != 0) { + os << " " << info->description << '\n'; } } } diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index 49a91983e79d02e7cba074aec20df9a6a47737d4..0d8d1f3c9ece1163e44a3be766204c437b997375 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -25,7 +25,7 @@ void DictAttrsNode::InitByPackedArgs( } } -std::vector<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const { +Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const { return {}; }