diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index 195d49bfb9b4d892595ac03edecd7db4bafa7b99..f1acb972158dba288ffda495d73f08dcd26ae48f 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -218,6 +218,9 @@ std::shared_ptr<Symbol> JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) // rebuild attribute parser if (!no_parse && n.node->op() != nullptr && n.node->op()->attr_parser != nullptr) { n.node->op()->attr_parser(&(n.node->attrs)); + } else if (!no_parse && n.node->is_variable()) { + n.node->attrs.parsed = + Symbol::CreateVariable(n.node->attrs.name).outputs[0].node->attrs.parsed; } for (const JSONGraph &subgraph : n.subgraphs) { // The "no_parse" option here, is to be compatible with diff --git a/nnvm/tests/python/unittest/test_pass_saveload_json.py b/nnvm/tests/python/unittest/test_pass_saveload_json.py new file mode 100644 index 0000000000000000000000000000000000000000..7b5f5ea6867a6a94d437e911a515b7ce25f22fee --- /dev/null +++ b/nnvm/tests/python/unittest/test_pass_saveload_json.py @@ -0,0 +1,17 @@ +import nnvm +from tvm.contrib import util + + +def test_variable_node_parsed(): + sym = nnvm.sym.Variable('data') + tempdir = util.tempdir() + json_filename = 'test_nnvm_symbol.json' + with open(tempdir.relpath(json_filename), 'w') as fo: + fo.write(nnvm.graph.create(sym).json()) + sym_str = open(tempdir.relpath(json_filename), 'r').read() + sym = nnvm.graph.load_json(sym_str).symbol() + sym = nnvm.sym.relu(sym) + + +if __name__ == '__main__': + test_variable_node_parsed()