From 5bf1cbda61ef3ca2d0fd1e0ae10acdc8f9013dd5 Mon Sep 17 00:00:00 2001
From: reminisce <wujun.nju@gmail.com>
Date: Fri, 5 Oct 2018 19:00:53 -0700
Subject: [PATCH] Fix saveload json bug (#1831)

---
 nnvm/src/pass/saveload_json.cc                  |  3 +++
 .../python/unittest/test_pass_saveload_json.py  | 17 +++++++++++++++++
 2 files changed, 20 insertions(+)
 create mode 100644 nnvm/tests/python/unittest/test_pass_saveload_json.py

diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc
index 195d49bfb..f1acb9721 100644
--- a/nnvm/src/pass/saveload_json.cc
+++ b/nnvm/src/pass/saveload_json.cc
@@ -218,6 +218,9 @@ std::shared_ptr<Symbol> JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse)
     // rebuild attribute parser
     if (!no_parse && n.node->op() != nullptr && n.node->op()->attr_parser != nullptr) {
       n.node->op()->attr_parser(&(n.node->attrs));
+    } else if (!no_parse && n.node->is_variable()) {
+      n.node->attrs.parsed =
+        Symbol::CreateVariable(n.node->attrs.name).outputs[0].node->attrs.parsed;
     }
     for (const JSONGraph &subgraph : n.subgraphs) {
       // The "no_parse" option here, is to be compatible with
diff --git a/nnvm/tests/python/unittest/test_pass_saveload_json.py b/nnvm/tests/python/unittest/test_pass_saveload_json.py
new file mode 100644
index 000000000..7b5f5ea68
--- /dev/null
+++ b/nnvm/tests/python/unittest/test_pass_saveload_json.py
@@ -0,0 +1,17 @@
+import nnvm
+from tvm.contrib import util
+
+
+def test_variable_node_parsed():
+    sym = nnvm.sym.Variable('data')
+    tempdir = util.tempdir()
+    json_filename = 'test_nnvm_symbol.json'
+    with open(tempdir.relpath(json_filename), 'w') as fo:
+        fo.write(nnvm.graph.create(sym).json())
+    sym_str = open(tempdir.relpath(json_filename), 'r').read()
+    sym = nnvm.graph.load_json(sym_str).symbol()
+    sym = nnvm.sym.relu(sym)
+
+
+if __name__ == '__main__':
+    test_variable_node_parsed()
-- 
GitLab