diff --git a/nnvm/python/nnvm/frontend/keras.py b/nnvm/python/nnvm/frontend/keras.py
index 0d51487c355d7b6fa1033df7aca8ccd8f835f7de..a61fee7e34138e489aab371463be9f2d0448dc79 100644
--- a/nnvm/python/nnvm/frontend/keras.py
+++ b/nnvm/python/nnvm/frontend/keras.py
@@ -345,11 +345,12 @@ def _convert_concat(insym, keras_layer, _):
 
 
 def _convert_reshape(insym, keras_layer, _):
-    shape = keras_layer.shape if hasattr(keras_layer, 'shape') \
-       else keras_layer.target_shape if hasattr(keras_layer, 'target_shape') \
-       else None
-    if shape is None:
-        raise TypeError("No shape attribute in reshape layer: {}".format(keras_layer))
+    _check_data_format(keras_layer)
+    ch = keras_layer.input_shape[-1]
+    assert ch == keras_layer.target_shape[-1], \
+        "Only supports last dimension in target shape being equal to " \
+        "the channel number of input tensor."
+    shape = (-1, ch) + keras_layer.target_shape[:-1]
     return _sym.reshape(insym, shape=shape)
 
 
diff --git a/nnvm/tests/python/frontend/keras/test_forward.py b/nnvm/tests/python/frontend/keras/test_forward.py
index 0147a3e2c65456b6b9aa8c4f673c7e49081a5960..c751b64435db40ad95a7c4889ac1ea720d7d70d4 100644
--- a/nnvm/tests/python/frontend/keras/test_forward.py
+++ b/nnvm/tests/python/frontend/keras/test_forward.py
@@ -134,6 +134,14 @@ def test_forward_relu6():
     verify_keras_frontend(keras_model)
 
 
+def test_forward_reshape():
+    data = keras.layers.Input(shape=(32,32,3))
+    x = keras.layers.Reshape(target_shape=(32,32,3))(data)
+    x = keras.layers.GlobalAveragePooling2D()(x)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model)
+
+
 def test_forward_vgg16():
     keras_model = keras.applications.vgg16.VGG16(include_top=True, weights=None,
         input_shape=(224,224,3), classes=1000)
@@ -162,6 +170,7 @@ if __name__ == '__main__':
     test_forward_separable_conv()
     test_forward_upsample()
     test_forward_relu6()
+    test_forward_reshape()
 
     test_forward_vgg16()
     test_forward_xception()