From c3df7726ac8d25529368aa49e5bc0541d22cd996 Mon Sep 17 00:00:00 2001 From: Wenhao Hu <fumihwh@gmail.com> Date: Wed, 27 Jun 2018 01:03:52 +0900 Subject: [PATCH] support t attr in onnx (#1300) --- nnvm/python/nnvm/frontend/onnx.py | 32 ++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/nnvm/python/nnvm/frontend/onnx.py b/nnvm/python/nnvm/frontend/onnx.py index 129c5089c..d92a856d7 100644 --- a/nnvm/python/nnvm/frontend/onnx.py +++ b/nnvm/python/nnvm/frontend/onnx.py @@ -571,13 +571,20 @@ class GraphProto(object): op_name = node.op_type attr = self._parse_attr(node.attribute) inputs = [self._nodes[self._renames.get(i, i)] for i in node.input] - op = self._convert_operator(op_name, inputs, attr, opset) - node_output = self._fix_outputs(op_name, node.output) - assert len(node_output) == len(op.list_output_names()), ( - "Number of output mismatch {} vs {} in {}.".format( - len(node_output), len(op.list_output_names()), op_name)) - for k, i in zip(list(node_output), range(len(node_output))): - self._nodes[k] = op[i] + if op_name == "Constant": + t_proto = self._parse_attr(node.attribute)["value"] + self._num_param += 1 + self._params[node.output[0]] = self._parse_array(t_proto) + self._nodes[node.output[0]] = _sym.Variable(name=node.output[0], + shape=list(t_proto.dims)) + else: + op = self._convert_operator(op_name, inputs, attr, opset) + node_output = self._fix_outputs(op_name, node.output) + assert len(node_output) == len(op.list_output_names()), ( + "Number of output mismatch {} vs {} in {}.".format( + len(node_output), len(op.list_output_names()), op_name)) + for k, i in zip(list(node_output), range(len(node_output))): + self._nodes[k] = op[i] # now return the outputs out = [self._nodes[self._parse_value_proto(i)] for i in graph.output] if len(out) > 1: @@ -615,11 +622,18 @@ class GraphProto(object): if list(getattr(a, f)): assert a.name not in attrs, "Only one type of attr is allowed" attrs[a.name] = tuple(getattr(a, f)) - for f in ['t', 'g']: + for f in ['t']: + if a.HasField(f): + attrs[a.name] = getattr(a, f) + for f in ['tensors']: + if list(getattr(a, f)): + assert a.name not in attrs, "Only one type of attr is allowed" + attrs[a.name] = tuple(getattr(a, f)) + for f in ['g']: if a.HasField(f): raise NotImplementedError( "Filed {} is not supported in nnvm.".format(f)) - for f in ['tensors', 'graphs']: + for f in ['graphs']: if list(getattr(a, f)): raise NotImplementedError( "Filed {} is not supported in nnvm.".format(f)) -- GitLab