diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index 485b1417a49321733d16cfaabe1da5b00561c77f..4a0706b6d501b275e0550e7819f31cc967dd7894 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -209,10 +209,12 @@ std::shared_ptr<Symbol> JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) for (const JSONNode &n : jgraph.nodes) { n.node->inputs.reserve(n.inputs.size()); for (const JSONNode::Entry &e : n.inputs) { + CHECK(e.node_id < jgraph.nodes.size()); n.node->inputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); } n.node->control_deps.reserve(n.control_deps.size()); for (uint32_t nid : n.control_deps) { + CHECK(nid < jgraph.nodes.size()); n.node->control_deps.push_back(jgraph.nodes[nid].node); } for (const JSONGraph &subgraph : n.subgraphs) { @@ -233,11 +235,13 @@ std::shared_ptr<Symbol> JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) } // consistency check for (uint32_t nid : jgraph.arg_nodes) { + CHECK(nid < jgraph.nodes.size()); CHECK(jgraph.nodes[nid].node->is_variable()); } std::shared_ptr<Symbol> symbol = std::make_shared<Symbol>(); symbol->outputs.reserve(jgraph.heads.size()); for (const JSONNode::Entry &e : jgraph.heads) { + CHECK(e.node_id < jgraph.nodes.size()); symbol->outputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); } return symbol;