From 31edf3f7eb4c7c2473cc13cbe74d534c70788f62 Mon Sep 17 00:00:00 2001
From: "Ehsan M. Kermani" <ehsanmo1367@gmail.com>
Date: Fri, 25 May 2018 09:26:23 -0700
Subject: [PATCH] Expose clip to frontend mxnet (#512)

---
 nnvm/python/nnvm/frontend/mxnet.py            |  7 ++
 .../python/frontend/mxnet/model_zoo/vgg.py    |  2 +-
 .../python/frontend/mxnet/test_forward.py     | 64 ++++++++++++++-----
 3 files changed, 55 insertions(+), 18 deletions(-)

diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py
index e671acbe7..3fc3ca851 100644
--- a/nnvm/python/nnvm/frontend/mxnet.py
+++ b/nnvm/python/nnvm/frontend/mxnet.py
@@ -205,6 +205,12 @@ def _upsampling(inputs, attrs):
     new_attrs = {'scale':int(scale)}
     return _get_nnvm_op('upsampling')(inputs[0], **new_attrs)
 
+def _clip(inputs, attrs):
+    op_name, new_attrs = "clip", {}
+    new_attrs['a_min'] = _required_attr(attrs, 'a_min')
+    new_attrs['a_max'] = _required_attr(attrs, 'a_max')
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
+
 
 _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
                   '__div_symbol__', '__mul_scalar__', '__mul_symbol__',
@@ -248,6 +254,7 @@ _convert_map = {
     'reshape'       : _reshape,
     'sum_axis'      : _rename('sum'),
     'UpSampling'    : _upsampling,
+    'clip'          : _clip
 }
 
 def _convert_symbol(op_name, inputs, attrs,
diff --git a/nnvm/tests/python/frontend/mxnet/model_zoo/vgg.py b/nnvm/tests/python/frontend/mxnet/model_zoo/vgg.py
index c9243d361..68215bb80 100644
--- a/nnvm/tests/python/frontend/mxnet/model_zoo/vgg.py
+++ b/nnvm/tests/python/frontend/mxnet/model_zoo/vgg.py
@@ -71,7 +71,7 @@ def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32', **
                 13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
                 16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
                 19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
-    if not vgg_spec.has_key(num_layers):
+    if num_layers not in vgg_spec:
         raise ValueError("Invalide num_layers {}. Possible choices are 11,13,16,19.".format(num_layers))
     layers, filters = vgg_spec[num_layers]
     data = mx.sym.Variable(name="data")
diff --git a/nnvm/tests/python/frontend/mxnet/test_forward.py b/nnvm/tests/python/frontend/mxnet/test_forward.py
index 0f9747538..fca19a693 100644
--- a/nnvm/tests/python/frontend/mxnet/test_forward.py
+++ b/nnvm/tests/python/frontend/mxnet/test_forward.py
@@ -8,24 +8,41 @@ import nnvm.compiler
 from nnvm.testing.config import ctx_list
 from nnvm import frontend
 import mxnet as mx
+from mxnet import gluon
+from mxnet.gluon.model_zoo import vision
 import model_zoo
 
 
-def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape=(1, 1000)):
+def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape=(1, 1000),
+                               gluon_impl=False, name=None):
     """Use name different from test to avoid let nose pick it up"""
-    def get_mxnet_output(symbol, x, dtype='float32'):
-        from collections import namedtuple
-        Batch = namedtuple('Batch', ['data'])
-        mod = mx.mod.Module(symbol, label_names=None)
-        mod.bind(data_shapes=[('data', x.shape)], for_training=False)
-        mod.init_params()
-        mod.forward(Batch([mx.nd.array(x.astype(dtype))]))
-        out = mod.get_outputs()[0].asnumpy()
-        args, auxs = mod.get_params()
-        return out, args, auxs
+    if gluon_impl:
+        def get_gluon_output(name, x):
+            net = vision.get_model(name)
+            net.collect_params().initialize(mx.init.Xavier())
+            net_sym = gluon.nn.SymbolBlock(outputs=net(mx.sym.var('data')),
+                                           inputs=mx.sym.var('data'),
+                                           params=net.collect_params())
+            out = net_sym(mx.nd.array(x.astype(dtype))).asnumpy()
+            return out, net_sym
+    else:
+        def get_mxnet_output(symbol, x, dtype='float32'):
+            from collections import namedtuple
+            Batch = namedtuple('Batch', ['data'])
+            mod = mx.mod.Module(symbol, label_names=None)
+            mod.bind(data_shapes=[('data', x.shape)], for_training=False)
+            mod.init_params()
+            mod.forward(Batch([mx.nd.array(x.astype(dtype))]))
+            out = mod.get_outputs()[0].asnumpy()
+            args, auxs = mod.get_params()
+            return out, args, auxs
 
     def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'):
-        new_sym, params = frontend.from_mxnet(symbol, args, auxs)
+        if gluon_impl:
+            new_sym, params = frontend.from_mxnet(symbol)
+        else:
+            new_sym, params = frontend.from_mxnet(symbol, args, auxs)
+
         dshape = x.shape
         shape_dict = {'data': dshape}
         with nnvm.compiler.build_config(opt_level=3):
@@ -42,11 +59,17 @@ def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape
     # random input
     dtype = 'float32'
     x = np.random.uniform(size=data_shape)
-    mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype)
-    assert "data" not in args
-    for target, ctx in ctx_list():
-        tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
-        np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
+    if gluon_impl:
+        gluon_out, gluon_sym = get_gluon_output(name, x)
+        for target, ctx in ctx_list():
+            tvm_out = get_tvm_output(gluon_sym, x, None, None, target, ctx, dtype)
+            np.testing.assert_allclose(gluon_out, tvm_out, rtol=1e-5, atol=1e-5)
+    else:
+        mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype)
+        assert "data" not in args
+        for target, ctx in ctx_list():
+            tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
+            np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
 
 def test_forward_mlp():
     mlp = model_zoo.mx_mlp
@@ -91,6 +114,12 @@ def test_forward_fc_flatten():
     except:
         pass
 
+def test_forward_clip():
+    data = mx.sym.var('data')
+    data = mx.sym.concat(data, -data, dim=1)  # negative part explicity
+    mx_sym = mx.sym.clip(data, a_min=0, a_max=1)
+    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
+
 if __name__ == '__main__':
     test_forward_mlp()
     test_forward_vgg()
@@ -99,3 +128,4 @@ if __name__ == '__main__':
     test_forward_rrelu()
     test_forward_softrelu()
     test_forward_fc_flatten()
+    test_forward_clip()
-- 
GitLab