From 5bf1cbda61ef3ca2d0fd1e0ae10acdc8f9013dd5 Mon Sep 17 00:00:00 2001 From: reminisce <wujun.nju@gmail.com> Date: Fri, 5 Oct 2018 19:00:53 -0700 Subject: [PATCH] Fix saveload json bug (#1831) --- nnvm/src/pass/saveload_json.cc | 3 +++ .../python/unittest/test_pass_saveload_json.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+) create mode 100644 nnvm/tests/python/unittest/test_pass_saveload_json.py diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index 195d49bfb..f1acb9721 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -218,6 +218,9 @@ std::shared_ptr<Symbol> JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) // rebuild attribute parser if (!no_parse && n.node->op() != nullptr && n.node->op()->attr_parser != nullptr) { n.node->op()->attr_parser(&(n.node->attrs)); + } else if (!no_parse && n.node->is_variable()) { + n.node->attrs.parsed = + Symbol::CreateVariable(n.node->attrs.name).outputs[0].node->attrs.parsed; } for (const JSONGraph &subgraph : n.subgraphs) { // The "no_parse" option here, is to be compatible with diff --git a/nnvm/tests/python/unittest/test_pass_saveload_json.py b/nnvm/tests/python/unittest/test_pass_saveload_json.py new file mode 100644 index 000000000..7b5f5ea68 --- /dev/null +++ b/nnvm/tests/python/unittest/test_pass_saveload_json.py @@ -0,0 +1,17 @@ +import nnvm +from tvm.contrib import util + + +def test_variable_node_parsed(): + sym = nnvm.sym.Variable('data') + tempdir = util.tempdir() + json_filename = 'test_nnvm_symbol.json' + with open(tempdir.relpath(json_filename), 'w') as fo: + fo.write(nnvm.graph.create(sym).json()) + sym_str = open(tempdir.relpath(json_filename), 'r').read() + sym = nnvm.graph.load_json(sym_str).symbol() + sym = nnvm.sym.relu(sym) + + +if __name__ == '__main__': + test_variable_node_parsed() -- GitLab