From e5d92e1b96c1fbc175be741f522c030f2d91a613 Mon Sep 17 00:00:00 2001
From: Dominic Symes <36929632+dominicsymes@users.noreply.github.com>
Date: Mon, 24 Dec 2018 21:08:45 +0000
Subject: [PATCH] [FRONTEND][TENSORFLOW] Bugfix (#2326)

---
 nnvm/python/nnvm/frontend/tensorflow.py               | 7 +++++--
 nnvm/tests/python/frontend/tensorflow/test_forward.py | 1 +
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py
index 10f23a49b..47aca3816 100644
--- a/nnvm/python/nnvm/frontend/tensorflow.py
+++ b/nnvm/python/nnvm/frontend/tensorflow.py
@@ -646,6 +646,9 @@ def _stridedSlice():
                 pass
             else:
                 final_output.append(out_shape[gather_index])
+        # Prevent 0-dim tensors which are not accepted by nnvm
+        if not final_output:
+            final_output.append(1)
         return _sym.reshape(out, shape=tuple(final_output))
     return _impl
 
@@ -1187,8 +1190,8 @@ class GraphProto(object):
                 raise NotImplementedError( \
                     "Please freeze the graph with add_shapes=True")
             self._outputs_are_0d[node.name] = [ \
-                not shape if isinstance(shape, list) else False \
-                for shape in self._output_shapes[node.name]]
+                not tshape if isinstance(tshape, list) else False \
+                for tshape in self._output_shapes[node.name]]
 
             if node.op == "Placeholder":
                 self._nodes[node.name] = _sym.Variable(name=node.name,
diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py
index ed3d0272b..5b8f11695 100644
--- a/nnvm/tests/python/frontend/tensorflow/test_forward.py
+++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py
@@ -463,6 +463,7 @@ def test_forward_stridedslice():
     _test_stridedslice((3, 4, 5, 4, 5, 6), [1, 2, 0, -3], [4, 5, 3, 3], [2, 2, 1, 1],
                        'float32', shrink_axis_mask=8, new_axis_mask=1, ellipsis_mask=2, begin_mask=5,
                        end_mask=8)
+    _test_stridedslice((1), [0], [1], [1], 'float32', shrink_axis_mask=1)
 
 
 #######################################################################
-- 
GitLab