From 80e4bc029c59bef7ae3a05f76c93c25ff010f009 Mon Sep 17 00:00:00 2001
From: Tatsuya Nishiyama <nishiyama.tatsuya0@gmail.com>
Date: Thu, 21 Jun 2018 01:58:07 +0900
Subject: [PATCH] [FRONTEND][MXNET] Add squeeze_axis support to split operator
 (#1288)

---
 nnvm/python/nnvm/frontend/mxnet.py               | 11 +++++++----
 nnvm/tests/python/frontend/mxnet/test_forward.py | 12 ++++++++++++
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py
index deae3112b..2f190ab71 100644
--- a/nnvm/python/nnvm/frontend/mxnet.py
+++ b/nnvm/python/nnvm/frontend/mxnet.py
@@ -188,12 +188,15 @@ def _reshape(inputs, attrs):
     return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
 def _split(inputs, attrs):
-    if _parse_bool_str(attrs, 'squeeze_axis'):
-        _raise_not_supported('squeeze_axis', 'split')
     op_name, new_attrs = 'split', {}
+    axis = attrs.get('axis', 1)
     new_attrs['indices_or_sections'] = _required_attr(attrs, 'num_outputs')
-    new_attrs['axis'] = attrs.get('axis', 1)
-    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
+    new_attrs['axis'] = axis
+    outputs = _get_nnvm_op(op_name)(*inputs, **new_attrs)
+    if _parse_bool_str(attrs, 'squeeze_axis'):
+        squeeze_attrs = {'axis': axis}
+        outputs = _sym.Group([_get_nnvm_op('squeeze')(o, **squeeze_attrs) for o in outputs])
+    return outputs
 
 def _softmax_activation(inputs, attrs):
     op_name, new_attrs = 'softmax', {}
diff --git a/nnvm/tests/python/frontend/mxnet/test_forward.py b/nnvm/tests/python/frontend/mxnet/test_forward.py
index d0930d75c..e6b6dffa1 100644
--- a/nnvm/tests/python/frontend/mxnet/test_forward.py
+++ b/nnvm/tests/python/frontend/mxnet/test_forward.py
@@ -126,6 +126,16 @@ def test_forward_clip():
     mx_sym = mx.sym.clip(data, a_min=0, a_max=1)
     verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
 
+def test_forward_split():
+    data = mx.sym.var('data')
+    mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=False)
+    verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 1, 2, 1))
+
+def test_forward_split_squeeze():
+    data = mx.sym.var('data')
+    mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=True)
+    verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 2, 1))
+
 if __name__ == '__main__':
     test_forward_mlp()
     test_forward_vgg()
@@ -136,3 +146,5 @@ if __name__ == '__main__':
     test_forward_softrelu()
     test_forward_fc_flatten()
     test_forward_clip()
+    test_forward_split()
+    test_forward_split_squeeze()
-- 
GitLab