From c68945c582e7b3371cde85e52056bbe6fb3dc2fe Mon Sep 17 00:00:00 2001
From: MORITA Kazutaka <morita.kazutaka@lab.ntt.co.jp>
Date: Fri, 25 May 2018 01:05:17 +0900
Subject: [PATCH] [FRONTEND][Keras] fix reshape (#493)

---
 nnvm/python/nnvm/frontend/keras.py               | 11 ++++++-----
 nnvm/tests/python/frontend/keras/test_forward.py |  9 +++++++++
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/nnvm/python/nnvm/frontend/keras.py b/nnvm/python/nnvm/frontend/keras.py
index 0d51487c3..a61fee7e3 100644
--- a/nnvm/python/nnvm/frontend/keras.py
+++ b/nnvm/python/nnvm/frontend/keras.py
@@ -345,11 +345,12 @@ def _convert_concat(insym, keras_layer, _):
 
 
 def _convert_reshape(insym, keras_layer, _):
-    shape = keras_layer.shape if hasattr(keras_layer, 'shape') \
-       else keras_layer.target_shape if hasattr(keras_layer, 'target_shape') \
-       else None
-    if shape is None:
-        raise TypeError("No shape attribute in reshape layer: {}".format(keras_layer))
+    _check_data_format(keras_layer)
+    ch = keras_layer.input_shape[-1]
+    assert ch == keras_layer.target_shape[-1], \
+        "Only supports last dimension in target shape being equal to " \
+        "the channel number of input tensor."
+    shape = (-1, ch) + keras_layer.target_shape[:-1]
     return _sym.reshape(insym, shape=shape)
 
 
diff --git a/nnvm/tests/python/frontend/keras/test_forward.py b/nnvm/tests/python/frontend/keras/test_forward.py
index 0147a3e2c..c751b6443 100644
--- a/nnvm/tests/python/frontend/keras/test_forward.py
+++ b/nnvm/tests/python/frontend/keras/test_forward.py
@@ -134,6 +134,14 @@ def test_forward_relu6():
     verify_keras_frontend(keras_model)
 
 
+def test_forward_reshape():
+    data = keras.layers.Input(shape=(32,32,3))
+    x = keras.layers.Reshape(target_shape=(32,32,3))(data)
+    x = keras.layers.GlobalAveragePooling2D()(x)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model)
+
+
 def test_forward_vgg16():
     keras_model = keras.applications.vgg16.VGG16(include_top=True, weights=None,
         input_shape=(224,224,3), classes=1000)
@@ -162,6 +170,7 @@ if __name__ == '__main__':
     test_forward_separable_conv()
     test_forward_upsample()
     test_forward_relu6()
+    test_forward_reshape()
 
     test_forward_vgg16()
     test_forward_xception()
-- 
GitLab