From b11f2a0495541cb348ae89093fd233d78eefec6e Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Wed, 22 Aug 2018 11:18:05 -0700
Subject: [PATCH] [ATTRS] change AttrFiledInfo->Node (#1634)

---
 include/tvm/attrs.h | 45 +++++++++++++++++++++++++++++----------------
 src/lang/attrs.cc   |  2 +-
 2 files changed, 30 insertions(+), 17 deletions(-)

diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h
index aed6b1ff7..3e5169ba0 100644
--- a/include/tvm/attrs.h
+++ b/include/tvm/attrs.h
@@ -69,15 +69,27 @@ struct AttrError : public dmlc::Error {
 /*!
  * \brief Information about attribute fields in string representations.
  */
-struct AttrFieldInfo {
+class AttrFieldInfoNode : public Node {
+ public:
   /*! \brief name of the field */
   std::string name;
   /*! \brief type docstring information in str. */
   std::string type_info;
   /*! \brief detailed description of the type */
   std::string description;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("name", &name);
+    v->Visit("type_info", &type_info);
+    v->Visit("description", &description);
+  }
+  static constexpr const char* _type_key = "AttrFieldInfo";
+  TVM_DECLARE_NODE_TYPE_INFO(AttrFieldInfoNode, Node);
 };
 
+/*! \brief AttrFieldInfo */
+TVM_DEFINE_NODE_REF(AttrFieldInfo, AttrFieldInfoNode);
+
 /*!
  * \brief Base class of all attribute class
  * \note Do not subclass AttrBaseNode directly,
@@ -104,7 +116,7 @@ class BaseAttrsNode : public Node {
    * \brief Get the field information about the
    * \note This function throws when the required a field is not present.
    */
-  TVM_DLL virtual std::vector<AttrFieldInfo> ListFieldInfo() const = 0;
+  TVM_DLL virtual Array<AttrFieldInfo> ListFieldInfo() const = 0;
   /*!
    * \brief Initialize the attributes by arguments.
    * \param kwargs The key value pairs for initialization.
@@ -159,7 +171,7 @@ class DictAttrsNode : public BaseAttrsNode {
   // implementations
   void VisitAttrs(AttrVisitor* v) final;
   void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
-  std::vector<AttrFieldInfo> ListFieldInfo() const final;
+  Array<AttrFieldInfo> ListFieldInfo() const final;
   // type info
   static constexpr const char* _type_key = "DictAttrs";
   TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode);
@@ -430,7 +442,7 @@ class AttrDocEntry {
  public:
   using TSelf = AttrDocEntry;
 
-  explicit AttrDocEntry(AttrFieldInfo* info)
+  explicit AttrDocEntry(std::shared_ptr<AttrFieldInfoNode> info)
       : info_(info) {
   }
   TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) {
@@ -454,21 +466,22 @@ class AttrDocEntry {
   }
 
  private:
-  AttrFieldInfo* info_;
+  std::shared_ptr<AttrFieldInfoNode> info_;
 };
 
 class AttrDocVisitor {
  public:
   template<typename T>
   AttrDocEntry operator()(const char* key, T* v) {
-    AttrFieldInfo info;
-    info.name = key;
-    info.type_info = TypeName<T>::value;
-    fields_.emplace_back(std::move(info));
-    return AttrDocEntry(&(fields_.back()));
+    std::shared_ptr<AttrFieldInfoNode> info
+        = std::make_shared<AttrFieldInfoNode>();
+    info->name = key;
+    info->type_info = TypeName<T>::value;
+    fields_.push_back(AttrFieldInfo(info));
+    return AttrDocEntry(info);
   }
 
-  std::vector<AttrFieldInfo> fields_;
+  Array<AttrFieldInfo> fields_;
 };
 
 class AttrExistVisitor {
@@ -557,7 +570,7 @@ class AttrsNode : public BaseAttrsNode {
     }
   }
 
-  std::vector<AttrFieldInfo> ListFieldInfo() const final {
+  Array<AttrFieldInfo> ListFieldInfo() const final {
     detail::AttrDocVisitor visitor;
     self()->__VisitAttrs__(visitor);
     return visitor.fields_;
@@ -580,11 +593,11 @@ inline void BaseAttrsNode::InitBySeq(Args&& ...args) {
 }
 
 inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*)
-  std::vector<AttrFieldInfo> entry = this->ListFieldInfo();
+  Array<AttrFieldInfo> entry = this->ListFieldInfo();
   for (AttrFieldInfo info : entry) {
-    os << info.name << " : " << info.type_info << '\n';
-    if (info.description.length() != 0) {
-      os << "    " << info.description << '\n';
+    os << info->name << " : " << info->type_info << '\n';
+    if (info->description.length() != 0) {
+      os << "    " << info->description << '\n';
     }
   }
 }
diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc
index 49a91983e..0d8d1f3c9 100644
--- a/src/lang/attrs.cc
+++ b/src/lang/attrs.cc
@@ -25,7 +25,7 @@ void DictAttrsNode::InitByPackedArgs(
   }
 }
 
-std::vector<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const {
+Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const {
   return {};
 }
 
-- 
GitLab