Skip to content
Snippets Groups Projects
Commit 58b2395d authored by MORITA Kazutaka's avatar MORITA Kazutaka Committed by Tianqi Chen
Browse files

[NNVM][KERAS] Fixed padding in pooling (#1635)

parent d90c1e45
No related branches found
No related tags found
No related merge requests found
...@@ -269,14 +269,12 @@ def _convert_pooling(insym, keras_layer, symtab): ...@@ -269,14 +269,12 @@ def _convert_pooling(insym, keras_layer, symtab):
'padding': [0, 0]} 'padding': [0, 0]}
if keras_layer.padding == 'valid': if keras_layer.padding == 'valid':
pass pass
# we insert a separate pad operator
elif keras_layer.padding == 'same': elif keras_layer.padding == 'same':
in_h = keras_layer.input_shape[1] in_h = keras_layer.input_shape[1]
in_w = keras_layer.input_shape[2] in_w = keras_layer.input_shape[2]
pad_t, pad_b = _get_pad_pair(in_h, pool_h, stride_h) pad_t, pad_b = _get_pad_pair(in_h, pool_h, stride_h)
pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w) pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w)
insym = _sym.pad(data=insym, pad_width=( params['padding'] = [pad_t, pad_l, pad_b, pad_r]
(0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r)))
else: else:
raise TypeError("Unsupported padding type : {}".format(keras_layer.padding)) raise TypeError("Unsupported padding type : {}".format(keras_layer.padding))
if pool_type == 'MaxPooling2D': if pool_type == 'MaxPooling2D':
......
...@@ -38,7 +38,7 @@ def verify_keras_frontend(keras_model): ...@@ -38,7 +38,7 @@ def verify_keras_frontend(keras_model):
out = m.get_output(0, tvm.nd.empty(out_shape, dtype)) out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
return out.asnumpy() return out.asnumpy()
xs = [np.random.uniform(size=shape) for shape in in_shapes] xs = [np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in in_shapes]
keras_out = get_keras_output(xs) keras_out = get_keras_output(xs)
for target, ctx in ctx_list(): for target, ctx in ctx_list():
tvm_out = get_tvm_output([x.transpose([0,3,1,2]) for x in xs], target, ctx) tvm_out = get_tvm_output([x.transpose([0,3,1,2]) for x in xs], target, ctx)
...@@ -74,6 +74,18 @@ def test_forward_dense(): ...@@ -74,6 +74,18 @@ def test_forward_dense():
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_pool():
data = keras.layers.Input(shape=(2,2,1))
# maxpool
x = keras.layers.MaxPooling2D((3, 3), strides=(1, 1), padding='same')(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
# avgpool
y = keras.layers.AveragePooling2D((3, 3), strides=(1, 1), padding='same')(data)
keras_model = keras.models.Model(data, y)
verify_keras_frontend(keras_model)
def test_forward_transpose_conv(): def test_forward_transpose_conv():
data = keras.layers.Input(shape=(32,32,3)) data = keras.layers.Input(shape=(32,32,3))
x = keras.layers.Conv2D(filters=10, kernel_size=(3,3), strides=(2,2), padding='same')(data) x = keras.layers.Conv2D(filters=10, kernel_size=(3,3), strides=(2,2), padding='same')(data)
...@@ -206,6 +218,7 @@ if __name__ == '__main__': ...@@ -206,6 +218,7 @@ if __name__ == '__main__':
test_forward_elemwise_add() test_forward_elemwise_add()
test_forward_activations() test_forward_activations()
test_forward_dense() test_forward_dense()
test_forward_pool()
test_forward_transpose_conv() test_forward_transpose_conv()
test_forward_separable_conv() test_forward_separable_conv()
test_forward_upsample() test_forward_upsample()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment