From b11f2a0495541cb348ae89093fd233d78eefec6e Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Wed, 22 Aug 2018 11:18:05 -0700 Subject: [PATCH] [ATTRS] change AttrFiledInfo->Node (#1634) --- include/tvm/attrs.h | 45 +++++++++++++++++++++++++++++---------------- src/lang/attrs.cc | 2 +- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index aed6b1ff7..3e5169ba0 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -69,15 +69,27 @@ struct AttrError : public dmlc::Error { /*! * \brief Information about attribute fields in string representations. */ -struct AttrFieldInfo { +class AttrFieldInfoNode : public Node { + public: /*! \brief name of the field */ std::string name; /*! \brief type docstring information in str. */ std::string type_info; /*! \brief detailed description of the type */ std::string description; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("name", &name); + v->Visit("type_info", &type_info); + v->Visit("description", &description); + } + static constexpr const char* _type_key = "AttrFieldInfo"; + TVM_DECLARE_NODE_TYPE_INFO(AttrFieldInfoNode, Node); }; +/*! \brief AttrFieldInfo */ +TVM_DEFINE_NODE_REF(AttrFieldInfo, AttrFieldInfoNode); + /*! * \brief Base class of all attribute class * \note Do not subclass AttrBaseNode directly, @@ -104,7 +116,7 @@ class BaseAttrsNode : public Node { * \brief Get the field information about the * \note This function throws when the required a field is not present. */ - TVM_DLL virtual std::vector<AttrFieldInfo> ListFieldInfo() const = 0; + TVM_DLL virtual Array<AttrFieldInfo> ListFieldInfo() const = 0; /*! * \brief Initialize the attributes by arguments. * \param kwargs The key value pairs for initialization. @@ -159,7 +171,7 @@ class DictAttrsNode : public BaseAttrsNode { // implementations void VisitAttrs(AttrVisitor* v) final; void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; - std::vector<AttrFieldInfo> ListFieldInfo() const final; + Array<AttrFieldInfo> ListFieldInfo() const final; // type info static constexpr const char* _type_key = "DictAttrs"; TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode); @@ -430,7 +442,7 @@ class AttrDocEntry { public: using TSelf = AttrDocEntry; - explicit AttrDocEntry(AttrFieldInfo* info) + explicit AttrDocEntry(std::shared_ptr<AttrFieldInfoNode> info) : info_(info) { } TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { @@ -454,21 +466,22 @@ class AttrDocEntry { } private: - AttrFieldInfo* info_; + std::shared_ptr<AttrFieldInfoNode> info_; }; class AttrDocVisitor { public: template<typename T> AttrDocEntry operator()(const char* key, T* v) { - AttrFieldInfo info; - info.name = key; - info.type_info = TypeName<T>::value; - fields_.emplace_back(std::move(info)); - return AttrDocEntry(&(fields_.back())); + std::shared_ptr<AttrFieldInfoNode> info + = std::make_shared<AttrFieldInfoNode>(); + info->name = key; + info->type_info = TypeName<T>::value; + fields_.push_back(AttrFieldInfo(info)); + return AttrDocEntry(info); } - std::vector<AttrFieldInfo> fields_; + Array<AttrFieldInfo> fields_; }; class AttrExistVisitor { @@ -557,7 +570,7 @@ class AttrsNode : public BaseAttrsNode { } } - std::vector<AttrFieldInfo> ListFieldInfo() const final { + Array<AttrFieldInfo> ListFieldInfo() const final { detail::AttrDocVisitor visitor; self()->__VisitAttrs__(visitor); return visitor.fields_; @@ -580,11 +593,11 @@ inline void BaseAttrsNode::InitBySeq(Args&& ...args) { } inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*) - std::vector<AttrFieldInfo> entry = this->ListFieldInfo(); + Array<AttrFieldInfo> entry = this->ListFieldInfo(); for (AttrFieldInfo info : entry) { - os << info.name << " : " << info.type_info << '\n'; - if (info.description.length() != 0) { - os << " " << info.description << '\n'; + os << info->name << " : " << info->type_info << '\n'; + if (info->description.length() != 0) { + os << " " << info->description << '\n'; } } } diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index 49a91983e..0d8d1f3c9 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -25,7 +25,7 @@ void DictAttrsNode::InitByPackedArgs( } } -std::vector<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const { +Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const { return {}; } -- GitLab