diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index deae3112bf5f89c72e4ada15afca37159723b5dd..2f190ab71f4d22d133e1e08b881a9a4a2bb73dea 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -188,12 +188,15 @@ def _reshape(inputs, attrs): return _get_nnvm_op(op_name)(*inputs, **new_attrs) def _split(inputs, attrs): - if _parse_bool_str(attrs, 'squeeze_axis'): - _raise_not_supported('squeeze_axis', 'split') op_name, new_attrs = 'split', {} + axis = attrs.get('axis', 1) new_attrs['indices_or_sections'] = _required_attr(attrs, 'num_outputs') - new_attrs['axis'] = attrs.get('axis', 1) - return _get_nnvm_op(op_name)(*inputs, **new_attrs) + new_attrs['axis'] = axis + outputs = _get_nnvm_op(op_name)(*inputs, **new_attrs) + if _parse_bool_str(attrs, 'squeeze_axis'): + squeeze_attrs = {'axis': axis} + outputs = _sym.Group([_get_nnvm_op('squeeze')(o, **squeeze_attrs) for o in outputs]) + return outputs def _softmax_activation(inputs, attrs): op_name, new_attrs = 'softmax', {} diff --git a/nnvm/tests/python/frontend/mxnet/test_forward.py b/nnvm/tests/python/frontend/mxnet/test_forward.py index d0930d75cce19a5d22410b7ceb0352a7b120947f..e6b6dffa17b0b871793988116d57934cdae7f0fd 100644 --- a/nnvm/tests/python/frontend/mxnet/test_forward.py +++ b/nnvm/tests/python/frontend/mxnet/test_forward.py @@ -126,6 +126,16 @@ def test_forward_clip(): mx_sym = mx.sym.clip(data, a_min=0, a_max=1) verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100)) +def test_forward_split(): + data = mx.sym.var('data') + mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=False) + verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 1, 2, 1)) + +def test_forward_split_squeeze(): + data = mx.sym.var('data') + mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=True) + verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 2, 1)) + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -136,3 +146,5 @@ if __name__ == '__main__': test_forward_softrelu() test_forward_fc_flatten() test_forward_clip() + test_forward_split() + test_forward_split_squeeze()