From 97ca4031f3ec0563baa3a974b2f39d651592fdea Mon Sep 17 00:00:00 2001
From: Haichen Shen <shenhaichen@gmail.com>
Date: Thu, 20 Dec 2018 14:43:33 -0800
Subject: [PATCH] [Relay][Frontend] Add MXNet test example for relay (#2316)

* Add MXNet test example for relay
* Fix a bug in BiasAddSimplifier
---
 src/relay/pass/canonicalize_ops.cc            |   2 +-
 .../frontend/mxnet/model_zoo/__init__.py      |  46 ++++
 .../relay/frontend/mxnet/model_zoo/dcgan.py   |  66 ++++++
 .../relay/frontend/mxnet/model_zoo/dqn.py     |  27 +++
 .../frontend/mxnet/model_zoo/inception_v3.py  | 170 +++++++++++++++
 .../relay/frontend/mxnet/model_zoo/mlp.py     |  40 ++++
 .../relay/frontend/mxnet/model_zoo/resnet.py  | 199 +++++++++++++++++
 .../frontend/mxnet/model_zoo/squeezenet.py    |  76 +++++++
 .../relay/frontend/mxnet/model_zoo/vgg.py     |  85 ++++++++
 .../relay/frontend/mxnet/test_forward.py      | 206 ++++++++++++++++++
 .../python/relay/frontend/mxnet/test_graph.py |  87 ++++++++
 11 files changed, 1003 insertions(+), 1 deletion(-)
 create mode 100644 tests/python/relay/frontend/mxnet/model_zoo/__init__.py
 create mode 100644 tests/python/relay/frontend/mxnet/model_zoo/dcgan.py
 create mode 100644 tests/python/relay/frontend/mxnet/model_zoo/dqn.py
 create mode 100644 tests/python/relay/frontend/mxnet/model_zoo/inception_v3.py
 create mode 100644 tests/python/relay/frontend/mxnet/model_zoo/mlp.py
 create mode 100644 tests/python/relay/frontend/mxnet/model_zoo/resnet.py
 create mode 100644 tests/python/relay/frontend/mxnet/model_zoo/squeezenet.py
 create mode 100644 tests/python/relay/frontend/mxnet/model_zoo/vgg.py
 create mode 100644 tests/python/relay/frontend/mxnet/test_forward.py
 create mode 100644 tests/python/relay/frontend/mxnet/test_graph.py

diff --git a/src/relay/pass/canonicalize_ops.cc b/src/relay/pass/canonicalize_ops.cc
index 77cd59e2a..4482dc395 100644
--- a/src/relay/pass/canonicalize_ops.cc
+++ b/src/relay/pass/canonicalize_ops.cc
@@ -22,7 +22,7 @@ class BiasAddSimplifier : public ExprMutator {
       CHECK_EQ(call->args.size(), 2);
       const BiasAddAttrs* param = call->attrs.as<BiasAddAttrs>();
 
-      auto ttype = call->args[0]->type_as<TensorTypeNode>();
+      auto ttype = n->args[0]->type_as<TensorTypeNode>();
       size_t n_dim = ttype->shape.size();
       Expr expanded_bias = ExpandBiasToMatchAxis(call->args[1], n_dim, {param->axis});
       Expr ret = Add(call->args[0], expanded_bias);
diff --git a/tests/python/relay/frontend/mxnet/model_zoo/__init__.py b/tests/python/relay/frontend/mxnet/model_zoo/__init__.py
new file mode 100644
index 000000000..1c796f781
--- /dev/null
+++ b/tests/python/relay/frontend/mxnet/model_zoo/__init__.py
@@ -0,0 +1,46 @@
+"""MXNet and Relay model zoo."""
+from __future__ import absolute_import
+from . import mlp, resnet, vgg, dqn, dcgan, squeezenet, inception_v3
+import tvm.relay.testing
+
+_num_class = 1000
+_batch = 2
+
+# mlp fc
+mx_mlp = mlp.get_symbol(_num_class)
+relay_mlp = tvm.relay.testing.mlp.get_workload(_batch, _num_class)[0]
+
+# vgg fc
+mx_vgg = {}
+relay_vgg = {}
+for num_layers in [11, 13, 16, 19]:
+    mx_vgg[num_layers] = vgg.get_symbol(_num_class, num_layers)
+    relay_vgg[num_layers] = tvm.relay.testing.vgg.get_workload(
+        _batch, _num_class, num_layers=num_layers)[0]
+
+# resnet fc
+mx_resnet = {}
+relay_resnet = {}
+for num_layers in [18, 34, 50, 101, 152, 200, 269]:
+    mx_resnet[num_layers] = resnet.get_symbol(_num_class, num_layers, '3,224,224')
+    relay_resnet[num_layers] = tvm.relay.testing.resnet.get_workload(
+        _batch, _num_class, num_layers=num_layers)[0]
+
+# squeezenet
+mx_squeezenet = {}
+relay_squeezenet = {}
+for version in ['1.0', '1.1']:
+    mx_squeezenet[version] = squeezenet.get_symbol(version=version)
+    relay_squeezenet[version] = tvm.relay.testing.squeezenet.get_workload(_batch, version=version)[0]
+
+# inception
+mx_inception_v3 = inception_v3.get_symbol()
+relay_inception_v3 = tvm.relay.testing.inception_v3.get_workload(_batch)[0]
+
+# dqn
+mx_dqn = dqn.get_symbol()
+relay_dqn = tvm.relay.testing.dqn.get_workload(_batch)[0]
+
+# dcgan generator
+mx_dcgan = dcgan.get_symbol()
+relay_dcgan = tvm.relay.testing.dcgan.get_workload(_batch)[0]
diff --git a/tests/python/relay/frontend/mxnet/model_zoo/dcgan.py b/tests/python/relay/frontend/mxnet/model_zoo/dcgan.py
new file mode 100644
index 000000000..8af030b6b
--- /dev/null
+++ b/tests/python/relay/frontend/mxnet/model_zoo/dcgan.py
@@ -0,0 +1,66 @@
+# pylint: disable=unused-argument
+"""
+The MXNet symbol of DCGAN generator
+
+Adopted from:
+https://github.com/tqchen/mxnet-gan/blob/master/mxgan/generator.py
+
+Reference:
+Radford, Alec, Luke Metz, and Soumith Chintala.
+"Unsupervised representation learning with deep convolutional generative adversarial networks."
+arXiv preprint arXiv:1511.06434 (2015).
+"""
+
+import mxnet as mx
+
+def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)):
+    """a deconv layer that enlarges the feature map"""
+    target_shape = (oshape[-2], oshape[-1])
+    pad_y = (kshape[0] - 1) // 2
+    pad_x = (kshape[1] - 1) // 2
+    adj_y = (target_shape[0] + 2 * pad_y - kshape[0]) % stride[0]
+    adj_x = (target_shape[1] + 2 * pad_x - kshape[1]) % stride[1]
+
+    net = mx.sym.Deconvolution(data,
+                               kernel=kshape,
+                               stride=stride,
+                               pad=(pad_y, pad_x),
+                               adj=(adj_y, adj_x),
+                               num_filter=oshape[0],
+                               no_bias=True,
+                               name=name)
+    return net
+
+def deconv2d_bn_relu(data, prefix, **kwargs):
+    """a block of deconv + batch norm + relu"""
+    eps = 1e-5 + 1e-12
+
+    net = deconv2d(data, name="%s_deconv" % prefix, **kwargs)
+    net = mx.sym.BatchNorm(net, eps=eps, name="%s_bn" % prefix)
+    net = mx.sym.Activation(net, name="%s_act" % prefix, act_type='relu')
+    return net
+
+def get_symbol(oshape=(3, 64, 64), ngf=128, code=None):
+    """get symbol of dcgan generator"""
+    assert oshape[-1] == 64, "Only support 64x64 image"
+    assert oshape[-2] == 64, "Only support 64x64 image"
+
+    code = mx.sym.Variable("data") if code is None else code
+    net = mx.sym.FullyConnected(code, name="g1", num_hidden=ngf*8*4*4, no_bias=True, flatten=False)
+    net = mx.sym.Activation(net, act_type='relu')
+    # 4 x 4
+    net = mx.sym.reshape(net, shape=(-1, ngf * 8, 4, 4))
+    # 8 x 8
+    net = deconv2d_bn_relu(
+        net, ishape=(ngf * 8, 4, 4), oshape=(ngf * 4, 8, 8), kshape=(4, 4), prefix="g2")
+    # 16x16
+    net = deconv2d_bn_relu(
+        net, ishape=(ngf * 4, 8, 8), oshape=(ngf * 2, 16, 16), kshape=(4, 4), prefix="g3")
+    # 32x32
+    net = deconv2d_bn_relu(
+        net, ishape=(ngf * 2, 16, 16), oshape=(ngf, 32, 32), kshape=(4, 4), prefix="g4")
+    # 64x64
+    net = deconv2d(
+        net, ishape=(ngf, 32, 32), oshape=oshape[-3:], kshape=(4, 4), name="g5_deconv")
+    net = mx.sym.Activation(net, act_type='tanh')
+    return net
diff --git a/tests/python/relay/frontend/mxnet/model_zoo/dqn.py b/tests/python/relay/frontend/mxnet/model_zoo/dqn.py
new file mode 100644
index 000000000..e037511ef
--- /dev/null
+++ b/tests/python/relay/frontend/mxnet/model_zoo/dqn.py
@@ -0,0 +1,27 @@
+"""
+The mxnet symbol of Nature DQN
+
+Reference:
+Mnih, Volodymyr, et al.
+"Human-level control through deep reinforcement learning."
+Nature 518.7540 (2015): 529.
+"""
+
+import mxnet as mx
+
+def get_symbol(num_action=18):
+    data = mx.sym.Variable(name='data')
+    net = mx.sym.Convolution(data, kernel=(8, 8), stride=(4, 4),
+                             num_filter=32, name='conv1')
+    net = mx.sym.Activation(net, act_type='relu', name='relu1')
+    net = mx.sym.Convolution(net, kernel=(4, 4), stride=(2, 2),
+                             num_filter=64, name='conv2')
+    net = mx.sym.Activation(net, act_type='relu', name='relu2')
+    net = mx.sym.Convolution(net, kernel=(3, 3), stride=(1, 1),
+                             num_filter=64, name='conv3')
+    net = mx.sym.Activation(net, act_type='relu', name='relu3')
+    net = mx.sym.FullyConnected(net, num_hidden=512, name='fc4')
+    net = mx.sym.Activation(net, act_type='relu', name='relu4')
+    net = mx.sym.FullyConnected(net, num_hidden=num_action, name='fc5', flatten=False)
+
+    return net
diff --git a/tests/python/relay/frontend/mxnet/model_zoo/inception_v3.py b/tests/python/relay/frontend/mxnet/model_zoo/inception_v3.py
new file mode 100644
index 000000000..b8585bf05
--- /dev/null
+++ b/tests/python/relay/frontend/mxnet/model_zoo/inception_v3.py
@@ -0,0 +1,170 @@
+"""
+Inception V3, suitable for images with around 299 x 299
+
+Reference:
+Szegedy, Christian, et al. "Rethinking the Inception Architecture for Computer Vision." arXiv preprint arXiv:1512.00567 (2015).
+
+Adopted from https://github.com/apache/incubator-mxnet/blob/
+             master/example/image-classification/symbols/inception-v3.py
+"""
+import mxnet as mx
+import numpy as np
+
+def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None, suffix=''):
+    conv = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, no_bias=True, name='%s%s_conv2d' %(name, suffix))
+    bn = mx.sym.BatchNorm(data=conv, eps=2e-5, name='%s%s_batchnorm' % (name, suffix))
+    act = mx.sym.Activation(data=bn, act_type='relu', name='%s%s_relu' %(name, suffix))
+    return act
+
+
+def Inception7A(data,
+                num_1x1,
+                num_3x3_red, num_3x3_1, num_3x3_2,
+                num_5x5_red, num_5x5,
+                pool, proj,
+                name):
+    tower_1x1 = Conv(data, num_1x1, name=('%s_conv' % name))
+    tower_5x5 = Conv(data, num_5x5_red, name=('%s_tower' % name), suffix='_conv')
+    tower_5x5 = Conv(tower_5x5, num_5x5, kernel=(5, 5), pad=(2, 2), name=('%s_tower' % name), suffix='_conv_1')
+    tower_3x3 = Conv(data, num_3x3_red, name=('%s_tower_1' % name), suffix='_conv')
+    tower_3x3 = Conv(tower_3x3, num_3x3_1, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_1')
+    tower_3x3 = Conv(tower_3x3, num_3x3_2, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_2')
+    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
+    cproj = Conv(pooling, proj, name=('%s_tower_2' %  name), suffix='_conv')
+    concat = mx.sym.Concat(*[tower_1x1, tower_5x5, tower_3x3, cproj], name='ch_concat_%s_chconcat' % name)
+    return concat
+
+# First Downsample
+def Inception7B(data,
+                num_3x3,
+                num_d3x3_red, num_d3x3_1, num_d3x3_2,
+                pool,
+                name):
+    tower_3x3 = Conv(data, num_3x3, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_conv' % name))
+    tower_d3x3 = Conv(data, num_d3x3_red, name=('%s_tower' % name), suffix='_conv')
+    tower_d3x3 = Conv(tower_d3x3, num_d3x3_1, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_tower' % name), suffix='_conv_1')
+    tower_d3x3 = Conv(tower_d3x3, num_d3x3_2, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_tower' % name), suffix='_conv_2')
+    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pad=(0,0), pool_type="max", name=('max_pool_%s_pool' % name))
+    concat = mx.sym.Concat(*[tower_3x3, tower_d3x3, pooling], name='ch_concat_%s_chconcat' % name)
+    return concat
+
+def Inception7C(data,
+                num_1x1,
+                num_d7_red, num_d7_1, num_d7_2,
+                num_q7_red, num_q7_1, num_q7_2, num_q7_3, num_q7_4,
+                pool, proj,
+                name):
+    tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name))
+    tower_d7 = Conv(data=data, num_filter=num_d7_red, name=('%s_tower' % name), suffix='_conv')
+    tower_d7 = Conv(data=tower_d7, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3), name=('%s_tower' % name), suffix='_conv_1')
+    tower_d7 = Conv(data=tower_d7, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0), name=('%s_tower' % name), suffix='_conv_2')
+    tower_q7 = Conv(data=data, num_filter=num_q7_red, name=('%s_tower_1' % name), suffix='_conv')
+    tower_q7 = Conv(data=tower_q7, num_filter=num_q7_1, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_1')
+    tower_q7 = Conv(data=tower_q7, num_filter=num_q7_2, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_2')
+    tower_q7 = Conv(data=tower_q7, num_filter=num_q7_3, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_3')
+    tower_q7 = Conv(data=tower_q7, num_filter=num_q7_4, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_4')
+    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
+    cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' %  name), suffix='_conv')
+    # concat
+    concat = mx.sym.Concat(*[tower_1x1, tower_d7, tower_q7, cproj], name='ch_concat_%s_chconcat' % name)
+    return concat
+
+def Inception7D(data,
+                num_3x3_red, num_3x3,
+                num_d7_3x3_red, num_d7_1, num_d7_2, num_d7_3x3,
+                pool,
+                name):
+    tower_3x3 = Conv(data=data, num_filter=num_3x3_red, name=('%s_tower' % name), suffix='_conv')
+    tower_3x3 = Conv(data=tower_3x3, num_filter=num_3x3, kernel=(3, 3), pad=(0,0), stride=(2, 2), name=('%s_tower' % name), suffix='_conv_1')
+    tower_d7_3x3 = Conv(data=data, num_filter=num_d7_3x3_red, name=('%s_tower_1' % name), suffix='_conv')
+    tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_1')
+    tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_2')
+    tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_3x3, kernel=(3, 3), stride=(2, 2), name=('%s_tower_1' % name), suffix='_conv_3')
+    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
+    # concat
+    concat = mx.sym.Concat(*[tower_3x3, tower_d7_3x3, pooling], name='ch_concat_%s_chconcat' % name)
+    return concat
+
+def Inception7E(data,
+                num_1x1,
+                num_d3_red, num_d3_1, num_d3_2,
+                num_3x3_d3_red, num_3x3, num_3x3_d3_1, num_3x3_d3_2,
+                pool, proj,
+                name):
+    tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name))
+    tower_d3 = Conv(data=data, num_filter=num_d3_red, name=('%s_tower' % name), suffix='_conv')
+    tower_d3_a = Conv(data=tower_d3, num_filter=num_d3_1, kernel=(1, 3), pad=(0, 1), name=('%s_tower' % name), suffix='_mixed_conv')
+    tower_d3_b = Conv(data=tower_d3, num_filter=num_d3_2, kernel=(3, 1), pad=(1, 0), name=('%s_tower' % name), suffix='_mixed_conv_1')
+    tower_3x3_d3 = Conv(data=data, num_filter=num_3x3_d3_red, name=('%s_tower_1' % name), suffix='_conv')
+    tower_3x3_d3 = Conv(data=tower_3x3_d3, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_1')
+    tower_3x3_d3_a = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_1, kernel=(1, 3), pad=(0, 1), name=('%s_tower_1' % name), suffix='_mixed_conv')
+    tower_3x3_d3_b = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_2, kernel=(3, 1), pad=(1, 0), name=('%s_tower_1' % name), suffix='_mixed_conv_1')
+    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
+    cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' %  name), suffix='_conv')
+    # concat
+    concat = mx.sym.Concat(*[tower_1x1, tower_d3_a, tower_d3_b, tower_3x3_d3_a, tower_3x3_d3_b, cproj], name='ch_concat_%s_chconcat' % name)
+    return concat
+
+def get_symbol(num_classes=1000, **kwargs):
+    data = mx.sym.Variable(name="data")
+    # stage 1
+    conv = Conv(data, 32, kernel=(3, 3), stride=(2, 2), name="conv")
+    conv_1 = Conv(conv, 32, kernel=(3, 3), name="conv_1")
+    conv_2 = Conv(conv_1, 64, kernel=(3, 3), pad=(1, 1), name="conv_2")
+    pool = mx.sym.Pooling(data=conv_2, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool")
+    # stage 2
+    conv_3 = Conv(pool, 80, kernel=(1, 1), name="conv_3")
+    conv_4 = Conv(conv_3, 192, kernel=(3, 3), name="conv_4")
+    pool1 = mx.sym.Pooling(data=conv_4, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool1")
+
+    # # stage 3
+    in3a = Inception7A(pool1, 64,
+                       64, 96, 96,
+                       48, 64,
+                       "avg", 32, "mixed")
+    in3b = Inception7A(in3a, 64,
+                       64, 96, 96,
+                       48, 64,
+                       "avg", 64, "mixed_1")
+    in3c = Inception7A(in3b, 64,
+                       64, 96, 96,
+                       48, 64,
+                       "avg", 64, "mixed_2")
+    in3d = Inception7B(in3c, 384,
+                       64, 96, 96,
+                       "max", "mixed_3")
+    # stage 4
+    in4a = Inception7C(in3d, 192,
+                       128, 128, 192,
+                       128, 128, 128, 128, 192,
+                       "avg", 192, "mixed_4")
+    in4b = Inception7C(in4a, 192,
+                       160, 160, 192,
+                       160, 160, 160, 160, 192,
+                       "avg", 192, "mixed_5")
+    in4c = Inception7C(in4b, 192,
+                       160, 160, 192,
+                       160, 160, 160, 160, 192,
+                       "avg", 192, "mixed_6")
+    in4d = Inception7C(in4c, 192,
+                       192, 192, 192,
+                       192, 192, 192, 192, 192,
+                       "avg", 192, "mixed_7")
+    in4e = Inception7D(in4d, 192, 320,
+                       192, 192, 192, 192,
+                       "max", "mixed_8")
+    # stage 5
+    in5a = Inception7E(in4e, 320,
+                       384, 384, 384,
+                       448, 384, 384, 384,
+                       "avg", 192, "mixed_9")
+    in5b = Inception7E(in5a, 320,
+                       384, 384, 384,
+                       448, 384, 384, 384,
+                       "max", 192, "mixed_10")
+    # pool
+    pool = mx.sym.Pooling(data=in5b, kernel=(8, 8), stride=(1, 1), pool_type="avg", name="global_pool")
+    flatten = mx.sym.Flatten(data=pool, name="flatten")
+    fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1', flatten=False)
+    softmax = mx.sym.SoftmaxOutput(data=fc1, name='softmax')
+    return softmax
diff --git a/tests/python/relay/frontend/mxnet/model_zoo/mlp.py b/tests/python/relay/frontend/mxnet/model_zoo/mlp.py
new file mode 100644
index 000000000..922b20874
--- /dev/null
+++ b/tests/python/relay/frontend/mxnet/model_zoo/mlp.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+a simple multilayer perceptron
+"""
+import mxnet as mx
+
+def get_symbol(num_classes=10, **kwargs):
+    data = mx.symbol.Variable('data')
+    data = mx.sym.Flatten(data=data)
+    try:
+        fc1  = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128, flatten=False)
+        act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
+        fc2  = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64, flatten=False)
+        act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
+        fc3  = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes, flatten=False)
+        mlp  = mx.symbol.softmax(data = fc3, name = 'softmax')
+    except:
+        fc1  = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
+        act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
+        fc2  = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
+        act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
+        fc3  = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes)
+        mlp  = mx.symbol.softmax(data = fc3, name = 'softmax')
+    return mlp
diff --git a/tests/python/relay/frontend/mxnet/model_zoo/resnet.py b/tests/python/relay/frontend/mxnet/model_zoo/resnet.py
new file mode 100644
index 000000000..3f9a870d3
--- /dev/null
+++ b/tests/python/relay/frontend/mxnet/model_zoo/resnet.py
@@ -0,0 +1,199 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+'''
+Adapted from https://github.com/tornadomeet/ResNet/blob/master/symbol_resnet.py
+Original author Wei Wu
+
+Implemented the following paper:
+
+Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks"
+'''
+import mxnet as mx
+import numpy as np
+
+def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, bn_mom=0.9, workspace=256, memonger=False):
+    """Return ResNet Unit symbol for building ResNet
+    Parameters
+    ----------
+    data : str
+        Input data
+    num_filter : int
+        Number of output channels
+    bnf : int
+        Bottle neck channels factor with regard to num_filter
+    stride : tuple
+        Stride used in convolution
+    dim_match : Boolean
+        True means channel number between input and output is the same, otherwise means differ
+    name : str
+        Base name of the operators
+    workspace : int
+        Workspace used in convolution operator
+    """
+    if bottle_neck:
+        bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1')
+        act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')
+        conv1 = mx.sym.Convolution(data=act1, num_filter=int(num_filter*0.25), kernel=(1,1), stride=stride, pad=(0,0),
+                                   no_bias=True, workspace=workspace, name=name + '_conv1')
+        bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2')
+        act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2')
+        conv2 = mx.sym.Convolution(data=act2, num_filter=int(num_filter*0.25), kernel=(3,3), stride=(1,1), pad=(1,1),
+                                   no_bias=True, workspace=workspace, name=name + '_conv2')
+        bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3')
+        act3 = mx.sym.Activation(data=bn3, act_type='relu', name=name + '_relu3')
+        conv3 = mx.sym.Convolution(data=act3, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0), no_bias=True,
+                                   workspace=workspace, name=name + '_conv3')
+        if dim_match:
+            shortcut = data
+        else:
+            shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
+                                            workspace=workspace, name=name+'_sc')
+        if memonger:
+            shortcut._set_attr(mirror_stage='True')
+        return conv3 + shortcut
+    else:
+        bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn1')
+        act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')
+        conv1 = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(3,3), stride=stride, pad=(1,1),
+                                      no_bias=True, workspace=workspace, name=name + '_conv1')
+        bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn2')
+        act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2')
+        conv2 = mx.sym.Convolution(data=act2, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1,1),
+                                      no_bias=True, workspace=workspace, name=name + '_conv2')
+        if dim_match:
+            shortcut = data
+        else:
+            shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
+                                            workspace=workspace, name=name+'_sc')
+        if memonger:
+            shortcut._set_attr(mirror_stage='True')
+        return conv2 + shortcut
+
+def resnet(units, num_stages, filter_list, num_classes, image_shape, bottle_neck=True, bn_mom=0.9, workspace=256, dtype='float32', memonger=False):
+    """Return ResNet symbol of
+    Parameters
+    ----------
+    units : list
+        Number of units in each stage
+    num_stages : int
+        Number of stage
+    filter_list : list
+        Channel size of each stage
+    num_classes : int
+        Ouput size of symbol
+    dataset : str
+        Dataset type, only cifar10 and imagenet supports
+    workspace : int
+        Workspace used in convolution operator
+    dtype : str
+        Precision (float32 or float16)
+    """
+    num_unit = len(units)
+    assert(num_unit == num_stages)
+    data = mx.sym.Variable(name='data')
+    if dtype == 'float32':
+        # data = mx.sym.identity(data=data, name='id')
+        data = data
+    else:
+        if dtype == 'float16':
+            data = mx.sym.Cast(data=data, dtype=np.float16)
+    data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data')
+    (nchannel, height, width) = image_shape
+    if height <= 32:            # such as cifar10
+        body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(3, 3), stride=(1,1), pad=(1, 1),
+                                  no_bias=True, name="conv0", workspace=workspace)
+    else:                       # often expected to be 224 such as imagenet
+        body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2,2), pad=(3, 3),
+                                  no_bias=True, name="conv0", workspace=workspace)
+        body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0')
+        body = mx.sym.Activation(data=body, act_type='relu', name='relu0')
+        body = mx.sym.Pooling(data=body, kernel=(3, 3), stride=(2,2), pad=(1,1), pool_type='max')
+
+    for i in range(num_stages):
+        body = residual_unit(body, filter_list[i+1], (1 if i==0 else 2, 1 if i==0 else 2), False,
+                             name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, workspace=workspace,
+                             memonger=memonger)
+        for j in range(units[i]-1):
+            body = residual_unit(body, filter_list[i+1], (1,1), True, name='stage%d_unit%d' % (i + 1, j + 2),
+                                 bottle_neck=bottle_neck, workspace=workspace, memonger=memonger)
+    bn1 = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn1')
+    relu1 = mx.sym.Activation(data=bn1, act_type='relu', name='relu1')
+    # Although kernel is not used here when global_pool=True, we should put one
+    pool1 = mx.sym.Pooling(data=relu1, global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1')
+    flat = mx.sym.Flatten(data=pool1)
+    try:
+        fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1', flatten=False)
+    except:
+        fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1')
+    if dtype == 'float16':
+        fc1 = mx.sym.Cast(data=fc1, dtype=np.float32)
+    return mx.sym.softmax(data=fc1, name='softmax')
+
+def get_symbol(num_classes, num_layers, image_shape, conv_workspace=256, dtype='float32', **kwargs):
+    """
+    Adapted from https://github.com/tornadomeet/ResNet/blob/master/train_resnet.py
+    Original author Wei Wu
+    """
+    image_shape = [int(l) for l in image_shape.split(',')]
+    (nchannel, height, width) = image_shape
+    if height <= 28:
+        num_stages = 3
+        if (num_layers-2) % 9 == 0 and num_layers >= 164:
+            per_unit = [(num_layers-2)//9]
+            filter_list = [16, 64, 128, 256]
+            bottle_neck = True
+        elif (num_layers-2) % 6 == 0 and num_layers < 164:
+            per_unit = [(num_layers-2)//6]
+            filter_list = [16, 16, 32, 64]
+            bottle_neck = False
+        else:
+            raise ValueError("no experiments done on num_layers {}, you can do it yourself".format(num_layers))
+        units = per_unit * num_stages
+    else:
+        if num_layers >= 50:
+            filter_list = [64, 256, 512, 1024, 2048]
+            bottle_neck = True
+        else:
+            filter_list = [64, 64, 128, 256, 512]
+            bottle_neck = False
+        num_stages = 4
+        if num_layers == 18:
+            units = [2, 2, 2, 2]
+        elif num_layers == 34:
+            units = [3, 4, 6, 3]
+        elif num_layers == 50:
+            units = [3, 4, 6, 3]
+        elif num_layers == 101:
+            units = [3, 4, 23, 3]
+        elif num_layers == 152:
+            units = [3, 8, 36, 3]
+        elif num_layers == 200:
+            units = [3, 24, 36, 3]
+        elif num_layers == 269:
+            units = [3, 30, 48, 8]
+        else:
+            raise ValueError("no experiments done on num_layers {}, you can do it yourself".format(num_layers))
+
+    return resnet(units       = units,
+                  num_stages  = num_stages,
+                  filter_list = filter_list,
+                  num_classes = num_classes,
+                  image_shape = image_shape,
+                  bottle_neck = bottle_neck,
+                  workspace   = conv_workspace,
+                  dtype       = dtype)
diff --git a/tests/python/relay/frontend/mxnet/model_zoo/squeezenet.py b/tests/python/relay/frontend/mxnet/model_zoo/squeezenet.py
new file mode 100644
index 000000000..deb896a21
--- /dev/null
+++ b/tests/python/relay/frontend/mxnet/model_zoo/squeezenet.py
@@ -0,0 +1,76 @@
+"""
+Symbol of SqueezeNet
+
+Reference:
+Iandola, Forrest N., et al.
+"Squeezenet: Alexnet-level accuracy with 50x fewer parameters and< 0.5 mb model size." (2016).
+"""
+
+import mxnet as mx
+
+# Helpers
+def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels):
+    net = _make_fire_conv(net, squeeze_channels, 1, 0)
+
+    left = _make_fire_conv(net, expand1x1_channels, 1, 0)
+    right = _make_fire_conv(net, expand3x3_channels, 3, 1)
+    # NOTE : Assume NCHW layout here
+    net = mx.sym.concat(left, right, dim=1)
+
+    return net
+
+def _make_fire_conv(net, channels, kernel_size, padding=0):
+    net = mx.sym.Convolution(net, num_filter=channels, kernel=(kernel_size, kernel_size),
+                             pad=(padding, padding))
+    net = mx.sym.Activation(net, act_type='relu')
+    return net
+
+# Net
+def get_symbol(num_classes=1000, version='1.0', **kwargs):
+    """Get symbol of SqueezeNet
+
+    Parameters
+    ----------
+    num_classes: int
+        The number of classification results
+
+    version : str, optional
+        "1.0" or "1.1" of SqueezeNet
+    """
+    assert version in ['1.0', '1.1'], ("Unsupported SqueezeNet version {version}:"
+                                       "1.0 or 1.1 expected".format(version=version))
+    net = mx.sym.Variable("data")
+    if version == '1.0':
+        net = mx.sym.Convolution(net, num_filter=96, kernel=(7, 7), stride=(2, 2), pad=(3, 3))
+        net = mx.sym.Activation(net, act_type='relu')
+        net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
+        net = _make_fire(net, 16, 64, 64)
+        net = _make_fire(net, 16, 64, 64)
+        net = _make_fire(net, 32, 128, 128)
+        net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
+        net = _make_fire(net, 32, 128, 128)
+        net = _make_fire(net, 48, 192, 192)
+        net = _make_fire(net, 48, 192, 192)
+        net = _make_fire(net, 64, 256, 256)
+        net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
+        net = _make_fire(net, 64, 256, 256)
+    else:
+        net = mx.sym.Convolution(net, num_filter=64, kernel=(3, 3), stride=(2, 2), pad=(1, 1))
+        net = mx.sym.Activation(net, act_type='relu')
+        net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
+        net = _make_fire(net, 16, 64, 64)
+        net = _make_fire(net, 16, 64, 64)
+        net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max',  stride=(2, 2))
+        net = _make_fire(net, 32, 128, 128)
+        net = _make_fire(net, 32, 128, 128)
+        net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max',  stride=(2, 2))
+        net = _make_fire(net, 48, 192, 192)
+        net = _make_fire(net, 48, 192, 192)
+        net = _make_fire(net, 64, 256, 256)
+        net = _make_fire(net, 64, 256, 256)
+    net = mx.sym.Dropout(net, p=0.5)
+    net = mx.sym.Convolution(net, num_filter=num_classes, kernel=(1, 1))
+    net = mx.sym.Activation(net, act_type='relu')
+    net = mx.sym.Pooling(data=net, global_pool=True, kernel=(13, 13), pool_type='avg')
+    net = mx.sym.flatten(net)
+    return mx.sym.softmax(net)
diff --git a/tests/python/relay/frontend/mxnet/model_zoo/vgg.py b/tests/python/relay/frontend/mxnet/model_zoo/vgg.py
new file mode 100644
index 000000000..68215bb80
--- /dev/null
+++ b/tests/python/relay/frontend/mxnet/model_zoo/vgg.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""References:
+
+Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for
+large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
+"""
+
+import mxnet as mx
+import numpy as np
+
+def get_feature(internel_layer, layers, filters, batch_norm = False, **kwargs):
+    for i, num in enumerate(layers):
+        for j in range(num):
+            internel_layer = mx.sym.Convolution(data = internel_layer, kernel=(3, 3), pad=(1, 1), num_filter=filters[i], name="conv%s_%s" %(i + 1, j + 1))
+            if batch_norm:
+                internel_layer = mx.symbol.BatchNorm(data=internel_layer, name="bn%s_%s" %(i + 1, j + 1))
+            internel_layer = mx.sym.Activation(data=internel_layer, act_type="relu", name="relu%s_%s" %(i + 1, j + 1))
+        internel_layer = mx.sym.Pooling(data=internel_layer, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool%s" %(i + 1))
+    return internel_layer
+
+def get_classifier(input_data, num_classes, **kwargs):
+    flatten = mx.sym.Flatten(data=input_data, name="flatten")
+    try:
+        fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6", flatten=False)
+        relu6 = mx.sym.Activation(data=fc6, act_type="relu", name="relu6")
+        drop6 = mx.sym.Dropout(data=relu6, p=0.5, name="drop6")
+        fc7 = mx.sym.FullyConnected(data=drop6, num_hidden=4096, name="fc7", flatten=False)
+        relu7 = mx.sym.Activation(data=fc7, act_type="relu", name="relu7")
+        drop7 = mx.sym.Dropout(data=relu7, p=0.5, name="drop7")
+        fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8", flatten=False)
+    except:
+        fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6")
+        relu6 = mx.sym.Activation(data=fc6, act_type="relu", name="relu6")
+        drop6 = mx.sym.Dropout(data=relu6, p=0.5, name="drop6")
+        fc7 = mx.sym.FullyConnected(data=drop6, num_hidden=4096, name="fc7")
+        relu7 = mx.sym.Activation(data=fc7, act_type="relu", name="relu7")
+        drop7 = mx.sym.Dropout(data=relu7, p=0.5, name="drop7")
+        fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8")
+    return fc8
+
+def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32', **kwargs):
+    """
+    Parameters
+    ----------
+    num_classes : int, default 1000
+        Number of classification classes.
+    num_layers : int
+        Number of layers for the variant of densenet. Options are 11, 13, 16, 19.
+    batch_norm : bool, default False
+        Use batch normalization.
+    dtype: str, float32 or float16
+        Data precision.
+    """
+    vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
+                13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
+                16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
+                19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
+    if num_layers not in vgg_spec:
+        raise ValueError("Invalide num_layers {}. Possible choices are 11,13,16,19.".format(num_layers))
+    layers, filters = vgg_spec[num_layers]
+    data = mx.sym.Variable(name="data")
+    if dtype == 'float16':
+        data = mx.sym.Cast(data=data, dtype=np.float16)
+    feature = get_feature(data, layers, filters, batch_norm)
+    classifier = get_classifier(feature, num_classes)
+    if dtype == 'float16':
+        classifier = mx.sym.Cast(data=classifier, dtype=np.float32)
+    symbol = mx.sym.softmax(data=classifier, name='softmax')
+    return symbol
diff --git a/tests/python/relay/frontend/mxnet/test_forward.py b/tests/python/relay/frontend/mxnet/test_forward.py
new file mode 100644
index 000000000..fcc760981
--- /dev/null
+++ b/tests/python/relay/frontend/mxnet/test_forward.py
@@ -0,0 +1,206 @@
+import numpy as np
+
+import topi
+import tvm
+from tvm.contrib import graph_runtime
+from tvm import relay
+from tvm.relay.testing.config import ctx_list
+import mxnet as mx
+from mxnet import gluon
+from mxnet.gluon.model_zoo import vision
+import model_zoo
+
+
+def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape=(1, 1000),
+                               gluon_impl=False, name=None, dtype='float32'):
+    """Use name different from test to avoid let nose pick it up"""
+    if gluon_impl:
+        def get_gluon_output(name, x):
+            net = vision.get_model(name)
+            net.collect_params().initialize(mx.init.Xavier())
+            net_sym = gluon.nn.SymbolBlock(outputs=net(mx.sym.var('data')),
+                                           inputs=mx.sym.var('data'),
+                                           params=net.collect_params())
+            out = net_sym(mx.nd.array(x.astype(dtype))).asnumpy()
+            return out, net_sym
+    else:
+        def get_mxnet_output(symbol, x, dtype='float32'):
+            from collections import namedtuple
+            Batch = namedtuple('Batch', ['data'])
+            mod = mx.mod.Module(symbol, label_names=None)
+            mod.bind(data_shapes=[('data', x.shape)], for_training=False)
+            mod.init_params()
+            mod.forward(Batch([mx.nd.array(x.astype(dtype))]))
+            out = mod.get_outputs()[0].asnumpy()
+            args, auxs = mod.get_params()
+            return out, args, auxs
+
+    def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'):
+        dshape = x.shape
+        shape_dict = {'data': dshape}
+        if gluon_impl:
+            new_sym, params = relay.frontend.from_mxnet(symbol, shape_dict)
+        else:
+            new_sym, params = relay.frontend.from_mxnet(symbol, shape_dict, arg_params=args, aux_params=auxs)
+
+        with relay.build_config(opt_level=3):
+            graph, lib, params = relay.build(new_sym, target, params=params)
+        m = graph_runtime.create(graph, lib, ctx)
+        # set inputs
+        m.set_input("data", tvm.nd.array(x.astype(dtype)))
+        m.set_input(**params)
+        m.run()
+        # get outputs
+        out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
+        return out.asnumpy()
+
+    # random input
+    x = np.random.uniform(size=data_shape)
+    if gluon_impl:
+        gluon_out, gluon_sym = get_gluon_output(name, x)
+        for target, ctx in ctx_list():
+            tvm_out = get_tvm_output(gluon_sym, x, None, None, target, ctx, dtype)
+            tvm.testing.assert_allclose(gluon_out, tvm_out, rtol=1e-5, atol=1e-5)
+    else:
+        mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype)
+        assert "data" not in args
+        for target, ctx in ctx_list():
+            tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
+            tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
+
+def test_forward_mlp():
+    mlp = model_zoo.mx_mlp
+    verify_mxnet_frontend_impl(mlp)
+
+def test_forward_vgg():
+    for n in [11]:
+        mx_sym = model_zoo.mx_vgg[n]
+        verify_mxnet_frontend_impl(mx_sym)
+
+def test_forward_resnet():
+    for n in [18]:
+        mx_sym = model_zoo.mx_resnet[n]
+        verify_mxnet_frontend_impl(mx_sym)
+
+def test_forward_elu():
+    data = mx.sym.var('data')
+    data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
+    mx_sym = mx.sym.LeakyReLU(data, act_type='elu')
+    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
+
+def test_forward_rrelu():
+    data = mx.sym.var('data')
+    data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
+    mx_sym = mx.sym.LeakyReLU(data, act_type='rrelu', lower_bound=0.3, upper_bound=0.7)
+    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
+
+def test_forward_prelu():
+    data = mx.sym.var('data')
+    data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
+    mx_sym = mx.sym.LeakyReLU(data, act_type='prelu')
+    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
+
+def test_forward_softrelu():
+    data = mx.sym.var('data')
+    data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
+    mx_sym = mx.sym.Activation(data, act_type='softrelu')
+    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
+
+def test_forward_fc_flatten():
+    # test flatten=True option in mxnet 0.11.1
+    data = mx.sym.var('data')
+    try:
+        mx_sym = mx.sym.FullyConnected(data, num_hidden=100, flatten=True)
+        verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100))
+        mx_sym = mx.sym.FullyConnected(mx.sym.Flatten(data), num_hidden=100, flatten=False)
+        verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100))
+    except:
+        pass
+
+def test_forward_clip():
+    data = mx.sym.var('data')
+    data = mx.sym.concat(data, -data, dim=1)  # negative part explicity
+    mx_sym = mx.sym.clip(data, a_min=0, a_max=1)
+    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
+
+def test_forward_split():
+    data = mx.sym.var('data')
+    mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=False)
+    verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 1, 2, 1))
+
+def test_forward_split_squeeze():
+    data = mx.sym.var('data')
+    mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=True)
+    verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 2, 1))
+
+def test_forward_expand_dims():
+    data = mx.sym.var('data')
+    mx_sym = mx.sym.expand_dims(data, axis=1)
+    verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 1, 3, 4))
+
+def test_forward_pooling():
+    data = mx.sym.var('data')
+    mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='avg')
+    verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8))
+
+    mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='max')
+    verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8))
+
+def test_forward_lrn():
+    data = mx.sym.var('data')
+    mx_sym = mx.sym.LRN(data, alpha=2, beta=2, knorm=1, nsize=5)
+    verify_mxnet_frontend_impl(mx_sym, (1, 10, 24, 24), (1, 10, 24, 24))
+
+def test_forward_ones():
+    data = mx.sym.var('data')
+    ones = mx.sym.ones(shape=(2, 3, 4), dtype='float32')
+    mx_sym = mx.sym.elemwise_add(data, ones)
+    verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
+    
+def test_forward_zeros():
+    data = mx.sym.var('data')
+    zeros = mx.sym.zeros(shape=(2, 3, 4), dtype='float32')
+    mx_sym = mx.sym.elemwise_add(data, zeros)
+    verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
+
+def test_forward_ones_like():
+    data = mx.sym.var('data')
+    mx_sym = mx.sym.ones_like(data, dtype='float32')
+    verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
+
+def test_forward_zeros_like():
+    data = mx.sym.var('data')
+    mx_sym = mx.sym.zeros_like(data, dtype='float32')
+    verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
+
+def test_forward_argmax():
+    data = mx.sym.var('data')
+    mx_sym = mx.sym.argmax(data, axis=1)
+    verify_mxnet_frontend_impl(mx_sym, (5, 3), (5,))
+
+def test_forward_argmin():
+    data = mx.sym.var('data')
+    mx_sym = mx.sym.argmin(data, axis=0)
+    verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,))
+    
+if __name__ == '__main__':
+    test_forward_mlp()
+    test_forward_vgg()
+    test_forward_resnet()
+    test_forward_elu()
+    test_forward_rrelu()
+    test_forward_prelu()
+    test_forward_softrelu()
+    test_forward_fc_flatten()
+    test_forward_clip()
+    test_forward_split()
+    test_forward_split_squeeze()
+    test_forward_expand_dims()
+    test_forward_pooling()
+    test_forward_lrn()
+    test_forward_ones()
+    test_forward_zeros()
+    test_forward_ones_like()
+    test_forward_zeros_like()
+    test_forward_argmax()
+    test_forward_argmin()
diff --git a/tests/python/relay/frontend/mxnet/test_graph.py b/tests/python/relay/frontend/mxnet/test_graph.py
new file mode 100644
index 000000000..820e78242
--- /dev/null
+++ b/tests/python/relay/frontend/mxnet/test_graph.py
@@ -0,0 +1,87 @@
+import mxnet as mx
+import tvm
+from tvm import relay
+import model_zoo
+from model_zoo import _batch
+
+def test_mlp():
+    mx_sym = model_zoo.mx_mlp
+    from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, {'data': (_batch, 1, 28, 28)})
+    from_mx_sym = relay.ir_pass.infer_type(from_mx_sym)
+    relay_sym = model_zoo.relay_mlp
+    assert relay.ir_pass.alpha_equal(from_mx_sym, relay_sym)
+
+def test_vgg():
+    for n in [11, 13, 16, 19]:
+        mx_sym = model_zoo.mx_vgg[n]
+        from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, {'data': (_batch, 3, 224, 224)})
+        from_mx_sym = relay.ir_pass.infer_type(from_mx_sym)
+        relay_sym = model_zoo.relay_vgg[n]
+        assert relay.ir_pass.alpha_equal(from_mx_sym, relay_sym)
+
+def test_resnet():
+    for n in [18, 34, 50, 101, 152, 200, 269]:
+        mx_sym = model_zoo.mx_resnet[n]
+        from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, {'data': (_batch, 3, 224, 224)})
+        from_mx_sym = relay.ir_pass.infer_type(from_mx_sym)
+        relay_sym = model_zoo.relay_resnet[n]
+        assert relay.ir_pass.alpha_equal(from_mx_sym, relay_sym)
+
+def test_squeezenet():
+    for version in ['1.0', '1.1']:
+        mx_sym = model_zoo.mx_squeezenet[version]
+        from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, {'data': (_batch, 3, 224, 224)})
+        from_mx_sym = relay.ir_pass.infer_type(from_mx_sym)
+        relay_sym = model_zoo.relay_squeezenet[version]
+        assert relay.ir_pass.alpha_equal(from_mx_sym, relay_sym)
+
+def test_inception_v3():
+    mx_sym = model_zoo.mx_inception_v3
+    from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, {'data': (_batch, 3, 299, 299)})
+    from_mx_sym = relay.ir_pass.infer_type(from_mx_sym)
+    relay_sym = model_zoo.relay_inception_v3
+    assert relay.ir_pass.alpha_equal(from_mx_sym, relay_sym)
+
+def test_dqn():
+    mx_sym = model_zoo.mx_dqn
+    from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, {'data': (_batch, 4, 84, 84)})
+    from_mx_sym = relay.ir_pass.infer_type(from_mx_sym)
+    relay_sym = model_zoo.relay_dqn
+    assert relay.ir_pass.alpha_equal(from_mx_sym, relay_sym)
+
+def test_dcgan():
+    mx_sym = model_zoo.mx_dcgan
+    from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, {'data': (_batch, 100)})
+    from_mx_sym = relay.ir_pass.infer_type(from_mx_sym)
+    relay_sym = model_zoo.relay_dcgan
+    assert relay.ir_pass.alpha_equal(from_mx_sym, relay_sym)
+
+def test_multi_outputs():
+    def compose_mxnet(**kwargs):
+        x = mx.sym.Variable('x')
+        y = mx.sym.Variable('y')
+        z = mx.sym.split(x, **kwargs)
+        return mx.sym.broadcast_sub(mx.sym.broadcast_add(z[0], z[2]), y)
+    def compose_relay(**kwargs):
+        x = relay.var("x", shape=(_batch, 3, 224, 224))
+        y = relay.var("y", shape=(1,))
+        z = relay.split(x, **kwargs)
+        ret = z[0] + z[2] - y
+        args = relay.ir_pass.free_vars(ret)
+        return relay.Function(args, ret)
+    mx_sym = compose_mxnet(num_outputs=3, axis=1)
+    from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, {'x': (_batch, 3, 224, 224), 'y': (1,)})
+    from_mx_sym = relay.ir_pass.infer_type(from_mx_sym)
+    relay_sym = compose_relay(indices_or_sections=3, axis=1)
+    relay_sym = relay.ir_pass.infer_type(relay_sym)
+    assert relay.ir_pass.alpha_equal(from_mx_sym, relay_sym)
+
+if __name__ == '__main__':
+    test_mlp()
+    test_vgg()
+    test_resnet()
+    test_squeezenet()
+    test_inception_v3()
+    test_dqn()
+    test_dcgan()
+    test_multi_outputs()
-- 
GitLab