From 10b7757ab6fe6aadda4b48d6a39fef0d06daef77 Mon Sep 17 00:00:00 2001 From: Albin Joy <albin.joy@huawei.com> Date: Thu, 5 Jul 2018 09:15:30 +0530 Subject: [PATCH] [NNVM][TENSORFLOW] Fixed variable ops shape parsing issue (#1381) --- nnvm/python/nnvm/frontend/tensorflow.py | 15 +++++-- .../frontend/tensorflow/test_forward.py | 41 +++++++++++++++++++ 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index fc7aeeda7..aa00c5183 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -593,11 +593,18 @@ class GraphProto(object): raise NotImplementedError( \ "Const {} couldn't be converted to Param.".format(node.name)) - try: + attr = self._parse_attr(node.attr) + #Variable converted to Const will not have only value attr + if 'value' in attr: + tensor_value = attr['value'] self._output_shapes[node.name] = \ - [tensor_util.TensorShapeProtoToList(shape) \ - for shape in self._parse_attr(node.attr)['_output_shapes']] - except KeyError: + [tensor_util.TensorShapeProtoToList( \ + tensor_value.tensor_shape)] + elif '_output_shapes' in attr: + self._output_shapes[node.name] = \ + [tensor_util.TensorShapeProtoToList(shape) \ + for shape in self._parse_attr(node.attr)['_output_shapes']] + else: raise NotImplementedError( \ "Please freeze the graph with add_shapes=True") else: diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index 1ec8222e9..5a37918f1 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -14,6 +14,8 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables from tensorflow.core.framework import graph_pb2 import nnvm.testing.tf @@ -393,6 +395,44 @@ def test_forward_sigmoid(): _test_sigmoid(np.random.uniform(size=(3, 4, 4, 3)).astype('float32')) + +####################################################################### +# Variable +# -------- + +def _test_variable(data): + tf.reset_default_graph() + input_op = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + input_tensor = array_ops.reshape(input_op, data.shape) + + size = input_tensor.shape.dims[1] + with variable_scope.variable_scope("linear", reuse=None): + w = variable_scope.get_variable( + "w", shape=[size, size], dtype=input_tensor.dtype) + # pylint: disable=unused-variable + output_op = math_ops.matmul(input_tensor, w) + # pylint: enable=unused-variable + + with tf.Session() as sess: + sess.run(variables.global_variables_initializer()) + final_graph_def = tf.graph_util.convert_variables_to_constants( + sess, + sess.graph.as_graph_def(add_shapes=True), + ['MatMul'], + ) + + tf_output = run_tf_graph(sess, data, 'Placeholder:0', 'MatMul:0') + tvm_output = run_tvm_graph(final_graph_def, data, + "Placeholder", tf_output.shape, data.dtype) + + np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5) + sess.close() + +def test_forward_variable(): + """Variable type op test""" + _test_variable(np.random.uniform(size=(32, 100)).astype('float32')) + + ####################################################################### # Multi Input to graph # -------------------- @@ -503,3 +543,4 @@ if __name__ == '__main__': test_forward_inception_v3() test_forward_inception_v1() test_forward_mobilenet() + test_forward_variable() -- GitLab