From 558cf098c0d4f80f8539ec2e1de9b4ce50a91916 Mon Sep 17 00:00:00 2001
From: Da Zheng <zhengda1936@gmail.com>
Date: Tue, 12 Jun 2018 15:46:48 -0700
Subject: [PATCH] add support for subgraphs. (#1221)

* add support for subgraphs.

* fix.

* fix.

* Fix compilation error

* Fix compilation error

* add comments.

* update comments.

* Sanity check on subgraphs when creating IndexedGraph

* avoid the overhead of sanity check.

* Stop using non-recursive DFS

* Trigger CI

* trigger CI
---
 nnvm/include/nnvm/node.h          |  16 ++++
 nnvm/include/nnvm/op_attr_types.h |  12 +++
 nnvm/src/core/graph.cc            |  40 ++++++++-
 nnvm/src/core/symbolic.cc         |  81 +++++++++++++----
 nnvm/src/pass/saveload_json.cc    | 144 +++++++++++++++++++-----------
 5 files changed, 219 insertions(+), 74 deletions(-)

diff --git a/nnvm/include/nnvm/node.h b/nnvm/include/nnvm/node.h
index 15db77ee6..57afb0c55 100644
--- a/nnvm/include/nnvm/node.h
+++ b/nnvm/include/nnvm/node.h
@@ -18,6 +18,7 @@ namespace nnvm {
 
 // Forward declare node.
 class Node;
+class Symbol;
 
 /*!
  * \brief we always used NodePtr for a reference pointer
@@ -90,6 +91,21 @@ struct NodeAttrs {
    * The object can be used to quickly access attributes.
    */
   any parsed;
+  /*!
+   * \brief Some operators take graphs as input. These operators include
+   * control flow operators and high-order functions.
+   * These graphs don't change when the operators are invoked for different
+   * mini-batches. In this sense, the subgraphs are kind of similar to
+   * the parameters and show be kept as node attributes.
+   *
+   * Users need to make sure the subgraphs are disjoint with the main graph.
+   * If a graph shares nodes with subgraphs, loading the graph from LoadJSON
+   * may generate a graph that has a different structure from the original graph
+   * (some of the nodes are duplicated). If nodes are shared between two graphs,
+   * shared nodes might be executed multiple times, which can be a problem for
+   * stateful operators.
+   */
+  std::vector<std::shared_ptr<Symbol> > subgraphs;
 };
 
 /*!
diff --git a/nnvm/include/nnvm/op_attr_types.h b/nnvm/include/nnvm/op_attr_types.h
index e58e9ceb3..b7f6be408 100644
--- a/nnvm/include/nnvm/op_attr_types.h
+++ b/nnvm/include/nnvm/op_attr_types.h
@@ -202,6 +202,18 @@ using FCorrectLayout = std::function<bool(
     const std::vector<Layout> *last_ilayouts,
     std::vector<Layout> *olayouts)>;
 
+/*!
+ * \brief Get a list of inputs that represent graphs instead of data.
+ * Normally, input symbols are considered as data to the operator. However,
+ * control flow operators and high-order functions need to interpret symbols
+ * as graphs.
+ * \param attrs The attributes of this node.
+ * \return a list of input index that are interpreted as symbols by the operator.
+ *
+ * \note Register under "FInputGraph".
+ */
+using FInputGraph = std::function<std::vector<uint32_t>(const NodeAttrs& attrs)>;
+
 }  // namespace nnvm
 
 #endif  // NNVM_OP_ATTR_TYPES_H_
diff --git a/nnvm/src/core/graph.cc b/nnvm/src/core/graph.cc
index 62c7085c1..b8bcae70f 100644
--- a/nnvm/src/core/graph.cc
+++ b/nnvm/src/core/graph.cc
@@ -16,15 +16,51 @@ const IndexedGraph& Graph::indexed_graph() const {
   return *indexed_graph_;
 }
 
+// a subgraph should not refer to any nodes with higher level
+// where "level" refers to the nested depth of the subgraph
+// e.g. the main graph is level 0
+// subgraphs of the main graph is level 1
+// subgraphs of the subgraphs of the main graph is level 2
+static void SubgraphSanityCheck(const std::vector<std::shared_ptr<Symbol>> &subgraphs) {
+  std::vector<const std::vector<nnvm::NodeEntry>*> curr_level;
+  std::vector<const std::vector<nnvm::NodeEntry>*> next_level;
+  std::unordered_map<nnvm::Node*, uint32_t> node2level;
+  for (auto &subgraph : subgraphs)
+    next_level.push_back(&subgraph->outputs);
+  for (uint32_t level = 0; !next_level.empty(); ++level) {
+    curr_level.swap(next_level);
+    next_level.clear();
+    for (const std::vector<NodeEntry> *graph_ptr : curr_level) {
+      const std::vector<NodeEntry> &graph = *graph_ptr;
+      DFSVisit(graph, [&next_level, &node2level, level](const NodePtr& n) {
+        nnvm::Node *node = n.get();
+        // if the node is visited, but on a different level, then check failed
+        // if check failed here or before, we stop doing anything, but raise an error
+        CHECK(!node2level.count(node) || node2level[node] == level)
+          << "A subgraph should not depend on the outputs of nodes on higher levels";
+        // otherwise, this node belongs to the current level
+        node2level[node] = level;
+        // subgraphs of current node belongs to next level
+        for (const auto& subgraph : n->attrs.subgraphs) {
+          next_level.push_back(&subgraph->outputs);
+        }
+      });
+    }
+  }
+}
+
 // implement constructor from graph
 IndexedGraph::IndexedGraph(const Graph &g) {
   entry_rptr_.push_back(0);
   std::vector<size_t> inputs_rptr{0}, control_rptr{0};
+  std::vector<std::shared_ptr<Symbol>> subgraphs;
 
-  DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr]
+  DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs]
              (const NodePtr& n) {
       CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
       uint32_t nid = static_cast<uint32_t>(nodes_.size());
+      for (const auto &subgraph : n->attrs.subgraphs)
+        subgraphs.push_back(subgraph);
       // nodes_
       IndexedGraph::Node new_node;
       new_node.source = n.get();
@@ -53,6 +89,8 @@ IndexedGraph::IndexedGraph(const Graph &g) {
       }
       control_rptr.push_back(control_deps_.size());
   });
+  if (!subgraphs.empty())
+    SubgraphSanityCheck(subgraphs);
 
   for (const auto& e : g.outputs) {
     outputs_.emplace_back(NodeEntry{
diff --git a/nnvm/src/core/symbolic.cc b/nnvm/src/core/symbolic.cc
index 2a2f5be50..927dd2b70 100644
--- a/nnvm/src/core/symbolic.cc
+++ b/nnvm/src/core/symbolic.cc
@@ -267,14 +267,36 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
                      const std::string& name) {
   static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
   static auto& fset_attrs = Op::GetAttr<FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose");
+  static auto& fgraph = Op::GetAttr<FInputGraph>("FInputGraph");
+
+  // The arguments that contain graphs.
+  Node* n = outputs[0].node.get();
+  FInputGraph fng = fgraph.get(n->op(), nullptr);
+  std::vector<uint32_t> garg_idx;
+  if (fng != nullptr)
+    garg_idx = fng(n->attrs);
+
+  // The names of the arguments that contain graphs.
+  FListInputNames name_fn = flist_inputs.get(n->op(), nullptr);
+  auto arg_names = (name_fn == nullptr) ? std::vector<std::string>{"data"} : name_fn(n->attrs);
+  std::vector<std::string> garg_names(garg_idx.size());
+  for (size_t i = 0; i < garg_idx.size(); i++) {
+    size_t idx = garg_idx[i];
+    if (idx < arg_names.size())
+      garg_names[i] = arg_names[idx];
+  }
 
   // parameter check.
   for (size_t i = 0; i < args.size(); ++i) {
-    CHECK_EQ(args[i]->outputs.size(), 1U)
+    // If the argument isn't a graph, it should have only one output.
+    if (garg_idx.empty() || std::find(garg_idx.begin(), garg_idx.end(), i) == garg_idx.end())
+      CHECK_EQ(args[i]->outputs.size(), 1U)
         << "Argument " << i << " is a tuple, single value is required";
   }
   for (const auto& kv : kwargs) {
-    CHECK_EQ(kv.second->outputs.size(), 1U)
+    if (garg_names.empty()
+        || std::find(garg_names.begin(), garg_names.end(), kv.first) == garg_names.end())
+      CHECK_EQ(kv.second->outputs.size(), 1U)
         << "Keyword Argument " << kv.first << " is a tuple, single value is required";
   }
   // assign new name
@@ -282,28 +304,49 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
 
   // Atomic functor composition.
   if (IsAtomic(outputs)) {
-    Node* n = outputs[0].node.get();
     uint32_t n_req = n->num_inputs();
+    std::vector<const Symbol *> arg_vec(args.begin(), args.end());
+    std::unordered_map<std::string, const Symbol*> kwarg_map(kwargs.begin(), kwargs.end());
+    // If one of the input arguments is a graph, we need to remove it from the
+    // list.
+    if (fng != nullptr) {
+      std::vector<uint32_t> idxes = fng(n->attrs);
+      for (auto idx : idxes) {
+        const Symbol *sym;
+        if (idx < arg_vec.size()) {
+          sym = arg_vec[idx];
+          arg_vec.erase(arg_vec.begin() + idx);
+        } else {
+          auto it = kwarg_map.find(arg_names[idx]);
+          CHECK(it != kwarg_map.end());
+          sym = it->second;
+          kwarg_map.erase(it);
+        }
+
+        if (n_req != kVarg)
+          n_req--;
+        arg_names.erase(arg_names.begin() + idx);
+        n->attrs.subgraphs.push_back(std::make_shared<Symbol>(*sym));
+      }
+    }
 
     if (n_req != kVarg) {
       n->inputs.resize(n_req);
-      CHECK_LE(args.size(), n_req)
+      CHECK_LE(arg_vec.size(), n_req)
           << "Incorrect number of arguments, requires " << n_req
-          << ", provided " << args.size();
-      for (size_t i = 0; i < args.size(); ++i) {
-        n->inputs[i] = args[i]->outputs[0];
+          << ", provided " << arg_vec.size();
+      for (size_t i = 0; i < arg_vec.size(); ++i) {
+        n->inputs[i] = arg_vec[i]->outputs[0];
       }
       // switch to keyword argument matching
-      if (args.size() != n_req) {
-        FListInputNames fn = flist_inputs.get(n->op(), nullptr);
-        auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(n->attrs);
+      if (arg_vec.size() != n_req) {
         if (arg_names.size() != n_req) {
           LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op()->name;
         }
         size_t nmatched = 0;
-        for (size_t i = args.size(); i < n_req; ++i) {
-          auto it = kwargs.find(arg_names[i]);
-          if (it != kwargs.end() && it->first == arg_names[i]) {
+        for (size_t i = arg_vec.size(); i < n_req; ++i) {
+          auto it = kwarg_map.find(arg_names[i]);
+          if (it != kwarg_map.end() && it->first == arg_names[i]) {
             n->inputs[i] = it->second->outputs[0];
             ++nmatched;
           } else {
@@ -314,18 +357,18 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
           }
         }
 
-        if (nmatched != kwargs.size()) {
+        if (nmatched != kwarg_map.size()) {
           n->inputs.clear();
-          std::vector<std::string> keys = GetKeys(kwargs);
-          array_view<std::string> view(dmlc::BeginPtr(arg_names) + args.size(),
+          std::vector<std::string> keys = GetKeys(kwarg_map);
+          array_view<std::string> view(dmlc::BeginPtr(arg_names) + arg_vec.size(),
                                        dmlc::BeginPtr(arg_names) + arg_names.size());
           KeywordArgumentMismatch("Symbol.Compose", keys, view);
         }
       }
     } else {
-      CHECK_EQ(kwargs.size(), 0U) << "Variable length function do not accept kwargs";
-      n->inputs.reserve(args.size());
-      for (const Symbol* s : args) {
+      CHECK_EQ(kwarg_map.size(), 0U) << "Variable length function do not accept kwargs";
+      n->inputs.reserve(arg_vec.size());
+      for (const Symbol* s : arg_vec) {
         n->inputs.push_back(s->outputs[0]);
       }
     }
diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc
index 3170d245c..195d49bfb 100644
--- a/nnvm/src/pass/saveload_json.cc
+++ b/nnvm/src/pass/saveload_json.cc
@@ -29,6 +29,11 @@ namespace nnvm {
 namespace pass {
 namespace {
 
+// JSONNode represents an nnvm::Node in JSON
+struct JSONNode;
+// JSONGraph represents an nnvm::Graph or nnvm::Symbol in JSON
+struct JSONGraph;
+
 // auxiliary node structure for serialization.
 struct JSONNode {
   // the node entry structure in serialized format
@@ -36,6 +41,10 @@ struct JSONNode {
     uint32_t node_id;
     uint32_t index;
     uint32_t version;
+    Entry() = default;
+    Entry(uint32_t node_id, uint32_t index, uint32_t version):
+      node_id(node_id), index(index), version(version) {
+    }
     void Save(dmlc::JSONWriter *writer) const {
       writer->BeginArray(false);
       writer->WriteArrayItem(node_id);
@@ -64,6 +73,8 @@ struct JSONNode {
   std::vector<Entry> inputs;
   // control flow dependencies
   std::vector<uint32_t> control_deps;
+  // subgraphs
+  std::vector<JSONGraph> subgraphs;
 
   // function to save JSON node.
   void Save(dmlc::JSONWriter *writer) const {
@@ -85,6 +96,9 @@ struct JSONNode {
     if (control_deps.size() != 0) {
       writer->WriteObjectKeyValue("control_deps", control_deps);
     }
+    if (subgraphs.size() != 0) {
+      writer->WriteObjectKeyValue("subgraphs", subgraphs);
+    }
     writer->EndObject();
   }
 
@@ -99,6 +113,7 @@ struct JSONNode {
     helper.DeclareOptionalField("attrs", &(node->attrs.dict));
     helper.DeclareOptionalField("attr", &(node->attrs.dict));
     helper.DeclareOptionalField("control_deps", &control_deps);
+    helper.DeclareOptionalField("subgraphs", &subgraphs);
     // backward compatible code with mxnet graph.
     int backward_source_id;
     std::unordered_map<std::string, std::string> param;
@@ -154,86 +169,107 @@ struct JSONGraph {
   }
 };
 
-// Load a graph from JSON file.
-Graph LoadJSON(Graph src) {
-  CHECK_NE(src.attrs.count("json"), 0U)
-      << "Load JSON require json to be presented.";
-  const std::string &json_str =
-      nnvm::get<std::string>(*src.attrs.at("json"));
-  bool no_parse = false;
-  if (src.attrs.count("load_json_no_parse")) {
-    no_parse = nnvm::get<bool>(*src.attrs.at("load_json_no_parse"));
+void Symbol2JSONGraph(std::shared_ptr<Symbol> src, JSONGraph *jgraph) {
+  std::unordered_map<Node*, uint32_t> node2index;
+  jgraph->node_row_ptr.push_back(0);
+  DFSVisit(src->outputs, [&node2index, jgraph](const NodePtr& n) {
+    uint32_t nid = static_cast<uint32_t>(jgraph->nodes.size());
+    node2index[n.get()] = nid;
+    if (n->is_variable()) {
+      jgraph->arg_nodes.push_back(nid);
+    }
+    JSONNode jnode;
+    jnode.node = n;
+    jnode.inputs.reserve(n->inputs.size());
+    for (const NodeEntry& e : n->inputs) {
+      jnode.inputs.emplace_back(node2index.at(e.node.get()), e.index, e.version);
+    }
+    for (const NodePtr& c : n->control_deps) {
+      jnode.control_deps.push_back(node2index.at(c.get()));
+    }
+    jgraph->node_row_ptr.push_back(jgraph->node_row_ptr.back() + n->num_outputs());
+    jgraph->nodes.emplace_back(std::move(jnode));
+  });
+  for (const NodeEntry& e : src->outputs) {
+    jgraph->heads.emplace_back(node2index.at(e.node.get()), e.index, e.version);
   }
-  std::istringstream is(json_str);
-  dmlc::JSONReader reader(&is);
-  JSONGraph jgraph;
-  // load in json graph.
-  jgraph.Load(&reader);
-  // connects the nodes
-  for (JSONNode &n : jgraph.nodes) {
+  // recursively construct subgraphs
+  for (JSONNode &jnode : jgraph->nodes) {
+    // construct jnode's subgraphs
+    const std::vector<std::shared_ptr<Symbol>> &subgraphs = jnode.node->attrs.subgraphs;
+    std::vector<JSONGraph> &jsubgraphs = jnode.subgraphs;
+    jsubgraphs.resize(subgraphs.size());
+    for (uint32_t i = 0; i < subgraphs.size(); ++i) {
+      Symbol2JSONGraph(subgraphs[i], &jsubgraphs[i]);
+    }
+  }
+}
+
+std::shared_ptr<Symbol> JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) {
+  for (const JSONNode &n : jgraph.nodes) {
     n.node->inputs.reserve(n.inputs.size());
     for (const JSONNode::Entry &e : n.inputs) {
-      n.node->inputs.emplace_back(
-          NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
+      n.node->inputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
     }
     n.node->control_deps.reserve(n.control_deps.size());
     for (uint32_t nid : n.control_deps) {
       n.node->control_deps.push_back(jgraph.nodes[nid].node);
     }
     // rebuild attribute parser
-    if (!no_parse && n.node->op() != nullptr &&
-        n.node->op()->attr_parser != nullptr) {
+    if (!no_parse && n.node->op() != nullptr && n.node->op()->attr_parser != nullptr) {
       n.node->op()->attr_parser(&(n.node->attrs));
     }
+    for (const JSONGraph &subgraph : n.subgraphs) {
+      // The "no_parse" option here, is to be compatible with
+      // commit cfd3075e85807dcd8f9534c37e053583dee87524
+      // (https://github.com/apache/incubator-mxnet/tree/cfd3075e85807dcd8f9534c37e053583dee87524),
+      // where the parsing of main graph is deferred until
+      // incubator-mxnet/src/nnvm/legacy_json_util.cc:UpgradeJSON_Parse
+      n.node->attrs.subgraphs.push_back(JSONGraph2Symbol(subgraph, false));
+    }
   }
-  // consistent check
+  // consistency check
   for (uint32_t nid : jgraph.arg_nodes) {
     CHECK(jgraph.nodes[nid].node->is_variable());
   }
+  std::shared_ptr<Symbol> symbol = std::make_shared<Symbol>();
+  symbol->outputs.reserve(jgraph.heads.size());
+  for (const JSONNode::Entry &e : jgraph.heads) {
+    symbol->outputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
+  }
+  return symbol;
+}
 
+// Load a graph from JSON file.
+Graph LoadJSON(Graph src) {
+  CHECK_NE(src.attrs.count("json"), 0U)
+      << "Load JSON require json to be presented.";
+  const std::string &json_str =
+      nnvm::get<std::string>(*src.attrs.at("json"));
+  bool no_parse = false;
+  if (src.attrs.count("load_json_no_parse")) {
+    no_parse = nnvm::get<bool>(*src.attrs.at("load_json_no_parse"));
+  }
+  std::istringstream is(json_str);
+  dmlc::JSONReader reader(&is);
+  JSONGraph jgraph;
+  // load in json graph.
+  jgraph.Load(&reader);
+  std::shared_ptr<Symbol> symbol = JSONGraph2Symbol(jgraph, no_parse);
   // return the graph
   Graph ret;
   ret.attrs = std::move(jgraph.attrs);
-  ret.outputs.reserve(jgraph.heads.size());
-  for (const JSONNode::Entry &e : jgraph.heads) {
-    ret.outputs.emplace_back(
-        NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
-  }
+  ret.outputs = symbol->outputs;
   return ret;
 }
 
 // save a graph to json
 Graph SaveJSON(Graph src) {
+  std::shared_ptr<Symbol> src_symbol = std::make_shared<Symbol>();
+  src_symbol->outputs = src.outputs;
   JSONGraph jgraph;
+  Symbol2JSONGraph(src_symbol, &jgraph);
   jgraph.attrs = src.attrs;
-  std::unordered_map<Node*, uint32_t> node2index;
-  jgraph.node_row_ptr.push_back(0);
-  DFSVisit(src.outputs, [&node2index, &jgraph](const NodePtr& n) {
-      uint32_t nid = static_cast<uint32_t>(jgraph.nodes.size());
-      node2index[n.get()] = nid;
-      if (n->is_variable()) {
-        jgraph.arg_nodes.push_back(nid);
-      }
-      JSONNode jnode;
-      jnode.node = n;
-      jnode.inputs.reserve(n->inputs.size());
-      for (const NodeEntry& e : n->inputs) {
-        jnode.inputs.emplace_back(
-            JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version});
-      }
-      for (const NodePtr& c : n->control_deps) {
-        jnode.control_deps.push_back(node2index.at(c.get()));
-      }
-      jgraph.node_row_ptr.push_back(
-          jgraph.node_row_ptr.back() + n->num_outputs());
-      jgraph.nodes.emplace_back(std::move(jnode));
-    });
-
-  for (const NodeEntry& e : src.outputs) {
-    jgraph.heads.push_back(
-        JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version});
-  }
-
   std::ostringstream os;
   dmlc::JSONWriter writer(&os);
   jgraph.Save(&writer);
-- 
GitLab