From 10b7757ab6fe6aadda4b48d6a39fef0d06daef77 Mon Sep 17 00:00:00 2001
From: Albin Joy <albin.joy@huawei.com>
Date: Thu, 5 Jul 2018 09:15:30 +0530
Subject: [PATCH] [NNVM][TENSORFLOW] Fixed variable ops shape parsing issue
 (#1381)

---
 nnvm/python/nnvm/frontend/tensorflow.py       | 15 +++++--
 .../frontend/tensorflow/test_forward.py       | 41 +++++++++++++++++++
 2 files changed, 52 insertions(+), 4 deletions(-)

diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py
index fc7aeeda7..aa00c5183 100644
--- a/nnvm/python/nnvm/frontend/tensorflow.py
+++ b/nnvm/python/nnvm/frontend/tensorflow.py
@@ -593,11 +593,18 @@ class GraphProto(object):
                         raise NotImplementedError( \
                             "Const {} couldn't be converted to Param.".format(node.name))
 
-                try:
+                attr = self._parse_attr(node.attr)
+                #Variable converted to Const will not have only value attr
+                if 'value' in attr:
+                    tensor_value = attr['value']
                     self._output_shapes[node.name] = \
-                         [tensor_util.TensorShapeProtoToList(shape) \
-                         for shape in self._parse_attr(node.attr)['_output_shapes']]
-                except KeyError:
+                        [tensor_util.TensorShapeProtoToList( \
+                            tensor_value.tensor_shape)]
+                elif '_output_shapes' in attr:
+                    self._output_shapes[node.name] = \
+                        [tensor_util.TensorShapeProtoToList(shape) \
+                        for shape in self._parse_attr(node.attr)['_output_shapes']]
+                else:
                     raise NotImplementedError( \
                         "Please freeze the graph with add_shapes=True")
             else:
diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py
index 1ec8222e9..5a37918f1 100644
--- a/nnvm/tests/python/frontend/tensorflow/test_forward.py
+++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py
@@ -14,6 +14,8 @@ from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
 from tensorflow.core.framework import graph_pb2
 
 import nnvm.testing.tf
@@ -393,6 +395,44 @@ def test_forward_sigmoid():
 
     _test_sigmoid(np.random.uniform(size=(3, 4, 4, 3)).astype('float32'))
 
+
+#######################################################################
+# Variable
+# --------
+
+def _test_variable(data):
+    tf.reset_default_graph()
+    input_op = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+    input_tensor = array_ops.reshape(input_op, data.shape)
+
+    size = input_tensor.shape.dims[1]
+    with variable_scope.variable_scope("linear", reuse=None):
+        w = variable_scope.get_variable(
+            "w", shape=[size, size], dtype=input_tensor.dtype)
+    # pylint: disable=unused-variable
+    output_op = math_ops.matmul(input_tensor, w)
+    # pylint: enable=unused-variable
+
+    with tf.Session() as sess:
+        sess.run(variables.global_variables_initializer())
+        final_graph_def = tf.graph_util.convert_variables_to_constants(
+            sess,
+            sess.graph.as_graph_def(add_shapes=True),
+            ['MatMul'],
+            )
+
+        tf_output = run_tf_graph(sess, data, 'Placeholder:0', 'MatMul:0')
+        tvm_output = run_tvm_graph(final_graph_def, data,
+                                   "Placeholder", tf_output.shape, data.dtype)
+
+        np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
+        sess.close()
+
+def test_forward_variable():
+    """Variable type op test"""
+    _test_variable(np.random.uniform(size=(32, 100)).astype('float32'))
+
+
 #######################################################################
 # Multi Input to graph
 # --------------------
@@ -503,3 +543,4 @@ if __name__ == '__main__':
     test_forward_inception_v3()
     test_forward_inception_v1()
     test_forward_mobilenet()
+    test_forward_variable()
-- 
GitLab