From c3df7726ac8d25529368aa49e5bc0541d22cd996 Mon Sep 17 00:00:00 2001
From: Wenhao Hu <fumihwh@gmail.com>
Date: Wed, 27 Jun 2018 01:03:52 +0900
Subject: [PATCH] support t attr in onnx (#1300)

---
 nnvm/python/nnvm/frontend/onnx.py | 32 ++++++++++++++++++++++---------
 1 file changed, 23 insertions(+), 9 deletions(-)

diff --git a/nnvm/python/nnvm/frontend/onnx.py b/nnvm/python/nnvm/frontend/onnx.py
index 129c5089c..d92a856d7 100644
--- a/nnvm/python/nnvm/frontend/onnx.py
+++ b/nnvm/python/nnvm/frontend/onnx.py
@@ -571,13 +571,20 @@ class GraphProto(object):
             op_name = node.op_type
             attr = self._parse_attr(node.attribute)
             inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
-            op = self._convert_operator(op_name, inputs, attr, opset)
-            node_output = self._fix_outputs(op_name, node.output)
-            assert len(node_output) == len(op.list_output_names()), (
-                "Number of output mismatch {} vs {} in {}.".format(
-                    len(node_output), len(op.list_output_names()), op_name))
-            for k, i in zip(list(node_output), range(len(node_output))):
-                self._nodes[k] = op[i]
+            if op_name == "Constant":
+                t_proto = self._parse_attr(node.attribute)["value"]
+                self._num_param += 1
+                self._params[node.output[0]] = self._parse_array(t_proto)
+                self._nodes[node.output[0]] = _sym.Variable(name=node.output[0],
+                                                            shape=list(t_proto.dims))
+            else:
+                op = self._convert_operator(op_name, inputs, attr, opset)
+                node_output = self._fix_outputs(op_name, node.output)
+                assert len(node_output) == len(op.list_output_names()), (
+                    "Number of output mismatch {} vs {} in {}.".format(
+                        len(node_output), len(op.list_output_names()), op_name))
+                for k, i in zip(list(node_output), range(len(node_output))):
+                    self._nodes[k] = op[i]
         # now return the outputs
         out = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
         if len(out) > 1:
@@ -615,11 +622,18 @@ class GraphProto(object):
                 if list(getattr(a, f)):
                     assert a.name not in attrs, "Only one type of attr is allowed"
                     attrs[a.name] = tuple(getattr(a, f))
-            for f in ['t', 'g']:
+            for f in ['t']:
+                if a.HasField(f):
+                    attrs[a.name] = getattr(a, f)
+            for f in ['tensors']:
+                if list(getattr(a, f)):
+                    assert a.name not in attrs, "Only one type of attr is allowed"
+                    attrs[a.name] = tuple(getattr(a, f))
+            for f in ['g']:
                 if a.HasField(f):
                     raise NotImplementedError(
                         "Filed {} is not supported in nnvm.".format(f))
-            for f in ['tensors', 'graphs']:
+            for f in ['graphs']:
                 if list(getattr(a, f)):
                     raise NotImplementedError(
                         "Filed {} is not supported in nnvm.".format(f))
-- 
GitLab