From 97ca4031f3ec0563baa3a974b2f39d651592fdea Mon Sep 17 00:00:00 2001 From: Haichen Shen <shenhaichen@gmail.com> Date: Thu, 20 Dec 2018 14:43:33 -0800 Subject: [PATCH] [Relay][Frontend] Add MXNet test example for relay (#2316) * Add MXNet test example for relay * Fix a bug in BiasAddSimplifier --- src/relay/pass/canonicalize_ops.cc | 2 +- .../frontend/mxnet/model_zoo/__init__.py | 46 ++++ .../relay/frontend/mxnet/model_zoo/dcgan.py | 66 ++++++ .../relay/frontend/mxnet/model_zoo/dqn.py | 27 +++ .../frontend/mxnet/model_zoo/inception_v3.py | 170 +++++++++++++++ .../relay/frontend/mxnet/model_zoo/mlp.py | 40 ++++ .../relay/frontend/mxnet/model_zoo/resnet.py | 199 +++++++++++++++++ .../frontend/mxnet/model_zoo/squeezenet.py | 76 +++++++ .../relay/frontend/mxnet/model_zoo/vgg.py | 85 ++++++++ .../relay/frontend/mxnet/test_forward.py | 206 ++++++++++++++++++ .../python/relay/frontend/mxnet/test_graph.py | 87 ++++++++ 11 files changed, 1003 insertions(+), 1 deletion(-) create mode 100644 tests/python/relay/frontend/mxnet/model_zoo/__init__.py create mode 100644 tests/python/relay/frontend/mxnet/model_zoo/dcgan.py create mode 100644 tests/python/relay/frontend/mxnet/model_zoo/dqn.py create mode 100644 tests/python/relay/frontend/mxnet/model_zoo/inception_v3.py create mode 100644 tests/python/relay/frontend/mxnet/model_zoo/mlp.py create mode 100644 tests/python/relay/frontend/mxnet/model_zoo/resnet.py create mode 100644 tests/python/relay/frontend/mxnet/model_zoo/squeezenet.py create mode 100644 tests/python/relay/frontend/mxnet/model_zoo/vgg.py create mode 100644 tests/python/relay/frontend/mxnet/test_forward.py create mode 100644 tests/python/relay/frontend/mxnet/test_graph.py diff --git a/src/relay/pass/canonicalize_ops.cc b/src/relay/pass/canonicalize_ops.cc index 77cd59e2a..4482dc395 100644 --- a/src/relay/pass/canonicalize_ops.cc +++ b/src/relay/pass/canonicalize_ops.cc @@ -22,7 +22,7 @@ class BiasAddSimplifier : public ExprMutator { CHECK_EQ(call->args.size(), 2); 
const BiasAddAttrs* param = call->attrs.as<BiasAddAttrs>(); - auto ttype = call->args[0]->type_as<TensorTypeNode>(); + auto ttype = n->args[0]->type_as<TensorTypeNode>(); size_t n_dim = ttype->shape.size(); Expr expanded_bias = ExpandBiasToMatchAxis(call->args[1], n_dim, {param->axis}); Expr ret = Add(call->args[0], expanded_bias); diff --git a/tests/python/relay/frontend/mxnet/model_zoo/__init__.py b/tests/python/relay/frontend/mxnet/model_zoo/__init__.py new file mode 100644 index 000000000..1c796f781 --- /dev/null +++ b/tests/python/relay/frontend/mxnet/model_zoo/__init__.py @@ -0,0 +1,46 @@ +"""MXNet and Relay model zoo.""" +from __future__ import absolute_import +from . import mlp, resnet, vgg, dqn, dcgan, squeezenet, inception_v3 +import tvm.relay.testing + +_num_class = 1000 +_batch = 2 + +# mlp fc +mx_mlp = mlp.get_symbol(_num_class) +relay_mlp = tvm.relay.testing.mlp.get_workload(_batch, _num_class)[0] + +# vgg fc +mx_vgg = {} +relay_vgg = {} +for num_layers in [11, 13, 16, 19]: + mx_vgg[num_layers] = vgg.get_symbol(_num_class, num_layers) + relay_vgg[num_layers] = tvm.relay.testing.vgg.get_workload( + _batch, _num_class, num_layers=num_layers)[0] + +# resnet fc +mx_resnet = {} +relay_resnet = {} +for num_layers in [18, 34, 50, 101, 152, 200, 269]: + mx_resnet[num_layers] = resnet.get_symbol(_num_class, num_layers, '3,224,224') + relay_resnet[num_layers] = tvm.relay.testing.resnet.get_workload( + _batch, _num_class, num_layers=num_layers)[0] + +# squeezenet +mx_squeezenet = {} +relay_squeezenet = {} +for version in ['1.0', '1.1']: + mx_squeezenet[version] = squeezenet.get_symbol(version=version) + relay_squeezenet[version] = tvm.relay.testing.squeezenet.get_workload(_batch, version=version)[0] + +# inception +mx_inception_v3 = inception_v3.get_symbol() +relay_inception_v3 = tvm.relay.testing.inception_v3.get_workload(_batch)[0] + +# dqn +mx_dqn = dqn.get_symbol() +relay_dqn = tvm.relay.testing.dqn.get_workload(_batch)[0] + +# dcgan generator +mx_dcgan = 
dcgan.get_symbol() +relay_dcgan = tvm.relay.testing.dcgan.get_workload(_batch)[0] diff --git a/tests/python/relay/frontend/mxnet/model_zoo/dcgan.py b/tests/python/relay/frontend/mxnet/model_zoo/dcgan.py new file mode 100644 index 000000000..8af030b6b --- /dev/null +++ b/tests/python/relay/frontend/mxnet/model_zoo/dcgan.py @@ -0,0 +1,66 @@ +# pylint: disable=unused-argument +""" +The MXNet symbol of DCGAN generator + +Adopted from: +https://github.com/tqchen/mxnet-gan/blob/master/mxgan/generator.py + +Reference: +Radford, Alec, Luke Metz, and Soumith Chintala. +"Unsupervised representation learning with deep convolutional generative adversarial networks." +arXiv preprint arXiv:1511.06434 (2015). +""" + +import mxnet as mx + +def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)): + """a deconv layer that enlarges the feature map""" + target_shape = (oshape[-2], oshape[-1]) + pad_y = (kshape[0] - 1) // 2 + pad_x = (kshape[1] - 1) // 2 + adj_y = (target_shape[0] + 2 * pad_y - kshape[0]) % stride[0] + adj_x = (target_shape[1] + 2 * pad_x - kshape[1]) % stride[1] + + net = mx.sym.Deconvolution(data, + kernel=kshape, + stride=stride, + pad=(pad_y, pad_x), + adj=(adj_y, adj_x), + num_filter=oshape[0], + no_bias=True, + name=name) + return net + +def deconv2d_bn_relu(data, prefix, **kwargs): + """a block of deconv + batch norm + relu""" + eps = 1e-5 + 1e-12 + + net = deconv2d(data, name="%s_deconv" % prefix, **kwargs) + net = mx.sym.BatchNorm(net, eps=eps, name="%s_bn" % prefix) + net = mx.sym.Activation(net, name="%s_act" % prefix, act_type='relu') + return net + +def get_symbol(oshape=(3, 64, 64), ngf=128, code=None): + """get symbol of dcgan generator""" + assert oshape[-1] == 64, "Only support 64x64 image" + assert oshape[-2] == 64, "Only support 64x64 image" + + code = mx.sym.Variable("data") if code is None else code + net = mx.sym.FullyConnected(code, name="g1", num_hidden=ngf*8*4*4, no_bias=True, flatten=False) + net = mx.sym.Activation(net, 
act_type='relu') + # 4 x 4 + net = mx.sym.reshape(net, shape=(-1, ngf * 8, 4, 4)) + # 8 x 8 + net = deconv2d_bn_relu( + net, ishape=(ngf * 8, 4, 4), oshape=(ngf * 4, 8, 8), kshape=(4, 4), prefix="g2") + # 16x16 + net = deconv2d_bn_relu( + net, ishape=(ngf * 4, 8, 8), oshape=(ngf * 2, 16, 16), kshape=(4, 4), prefix="g3") + # 32x32 + net = deconv2d_bn_relu( + net, ishape=(ngf * 2, 16, 16), oshape=(ngf, 32, 32), kshape=(4, 4), prefix="g4") + # 64x64 + net = deconv2d( + net, ishape=(ngf, 32, 32), oshape=oshape[-3:], kshape=(4, 4), name="g5_deconv") + net = mx.sym.Activation(net, act_type='tanh') + return net diff --git a/tests/python/relay/frontend/mxnet/model_zoo/dqn.py b/tests/python/relay/frontend/mxnet/model_zoo/dqn.py new file mode 100644 index 000000000..e037511ef --- /dev/null +++ b/tests/python/relay/frontend/mxnet/model_zoo/dqn.py @@ -0,0 +1,27 @@ +""" +The mxnet symbol of Nature DQN + +Reference: +Mnih, Volodymyr, et al. +"Human-level control through deep reinforcement learning." +Nature 518.7540 (2015): 529. 
+""" + +import mxnet as mx + +def get_symbol(num_action=18): + data = mx.sym.Variable(name='data') + net = mx.sym.Convolution(data, kernel=(8, 8), stride=(4, 4), + num_filter=32, name='conv1') + net = mx.sym.Activation(net, act_type='relu', name='relu1') + net = mx.sym.Convolution(net, kernel=(4, 4), stride=(2, 2), + num_filter=64, name='conv2') + net = mx.sym.Activation(net, act_type='relu', name='relu2') + net = mx.sym.Convolution(net, kernel=(3, 3), stride=(1, 1), + num_filter=64, name='conv3') + net = mx.sym.Activation(net, act_type='relu', name='relu3') + net = mx.sym.FullyConnected(net, num_hidden=512, name='fc4') + net = mx.sym.Activation(net, act_type='relu', name='relu4') + net = mx.sym.FullyConnected(net, num_hidden=num_action, name='fc5', flatten=False) + + return net diff --git a/tests/python/relay/frontend/mxnet/model_zoo/inception_v3.py b/tests/python/relay/frontend/mxnet/model_zoo/inception_v3.py new file mode 100644 index 000000000..b8585bf05 --- /dev/null +++ b/tests/python/relay/frontend/mxnet/model_zoo/inception_v3.py @@ -0,0 +1,170 @@ +""" +Inception V3, suitable for images with around 299 x 299 + +Reference: +Szegedy, Christian, et al. "Rethinking the Inception Architecture for Computer Vision." arXiv preprint arXiv:1512.00567 (2015). 
+ +Adopted from https://github.com/apache/incubator-mxnet/blob/ + master/example/image-classification/symbols/inception-v3.py +""" +import mxnet as mx +import numpy as np + +def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None, suffix=''): + conv = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, no_bias=True, name='%s%s_conv2d' %(name, suffix)) + bn = mx.sym.BatchNorm(data=conv, eps=2e-5, name='%s%s_batchnorm' % (name, suffix)) + act = mx.sym.Activation(data=bn, act_type='relu', name='%s%s_relu' %(name, suffix)) + return act + + +def Inception7A(data, + num_1x1, + num_3x3_red, num_3x3_1, num_3x3_2, + num_5x5_red, num_5x5, + pool, proj, + name): + tower_1x1 = Conv(data, num_1x1, name=('%s_conv' % name)) + tower_5x5 = Conv(data, num_5x5_red, name=('%s_tower' % name), suffix='_conv') + tower_5x5 = Conv(tower_5x5, num_5x5, kernel=(5, 5), pad=(2, 2), name=('%s_tower' % name), suffix='_conv_1') + tower_3x3 = Conv(data, num_3x3_red, name=('%s_tower_1' % name), suffix='_conv') + tower_3x3 = Conv(tower_3x3, num_3x3_1, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_1') + tower_3x3 = Conv(tower_3x3, num_3x3_2, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_2') + pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name))) + cproj = Conv(pooling, proj, name=('%s_tower_2' % name), suffix='_conv') + concat = mx.sym.Concat(*[tower_1x1, tower_5x5, tower_3x3, cproj], name='ch_concat_%s_chconcat' % name) + return concat + +# First Downsample +def Inception7B(data, + num_3x3, + num_d3x3_red, num_d3x3_1, num_d3x3_2, + pool, + name): + tower_3x3 = Conv(data, num_3x3, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_conv' % name)) + tower_d3x3 = Conv(data, num_d3x3_red, name=('%s_tower' % name), suffix='_conv') + tower_d3x3 = Conv(tower_d3x3, num_d3x3_1, kernel=(3, 3), pad=(1, 1), stride=(1, 1), 
name=('%s_tower' % name), suffix='_conv_1') + tower_d3x3 = Conv(tower_d3x3, num_d3x3_2, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_tower' % name), suffix='_conv_2') + pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pad=(0,0), pool_type="max", name=('max_pool_%s_pool' % name)) + concat = mx.sym.Concat(*[tower_3x3, tower_d3x3, pooling], name='ch_concat_%s_chconcat' % name) + return concat + +def Inception7C(data, + num_1x1, + num_d7_red, num_d7_1, num_d7_2, + num_q7_red, num_q7_1, num_q7_2, num_q7_3, num_q7_4, + pool, proj, + name): + tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name)) + tower_d7 = Conv(data=data, num_filter=num_d7_red, name=('%s_tower' % name), suffix='_conv') + tower_d7 = Conv(data=tower_d7, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3), name=('%s_tower' % name), suffix='_conv_1') + tower_d7 = Conv(data=tower_d7, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0), name=('%s_tower' % name), suffix='_conv_2') + tower_q7 = Conv(data=data, num_filter=num_q7_red, name=('%s_tower_1' % name), suffix='_conv') + tower_q7 = Conv(data=tower_q7, num_filter=num_q7_1, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_1') + tower_q7 = Conv(data=tower_q7, num_filter=num_q7_2, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_2') + tower_q7 = Conv(data=tower_q7, num_filter=num_q7_3, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_3') + tower_q7 = Conv(data=tower_q7, num_filter=num_q7_4, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_4') + pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name))) + cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' % name), suffix='_conv') + # concat + concat = mx.sym.Concat(*[tower_1x1, tower_d7, tower_q7, cproj], name='ch_concat_%s_chconcat' % name) + return concat + +def Inception7D(data, 
+ num_3x3_red, num_3x3, + num_d7_3x3_red, num_d7_1, num_d7_2, num_d7_3x3, + pool, + name): + tower_3x3 = Conv(data=data, num_filter=num_3x3_red, name=('%s_tower' % name), suffix='_conv') + tower_3x3 = Conv(data=tower_3x3, num_filter=num_3x3, kernel=(3, 3), pad=(0,0), stride=(2, 2), name=('%s_tower' % name), suffix='_conv_1') + tower_d7_3x3 = Conv(data=data, num_filter=num_d7_3x3_red, name=('%s_tower_1' % name), suffix='_conv') + tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_1') + tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_2') + tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_3x3, kernel=(3, 3), stride=(2, 2), name=('%s_tower_1' % name), suffix='_conv_3') + pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name))) + # concat + concat = mx.sym.Concat(*[tower_3x3, tower_d7_3x3, pooling], name='ch_concat_%s_chconcat' % name) + return concat + +def Inception7E(data, + num_1x1, + num_d3_red, num_d3_1, num_d3_2, + num_3x3_d3_red, num_3x3, num_3x3_d3_1, num_3x3_d3_2, + pool, proj, + name): + tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name)) + tower_d3 = Conv(data=data, num_filter=num_d3_red, name=('%s_tower' % name), suffix='_conv') + tower_d3_a = Conv(data=tower_d3, num_filter=num_d3_1, kernel=(1, 3), pad=(0, 1), name=('%s_tower' % name), suffix='_mixed_conv') + tower_d3_b = Conv(data=tower_d3, num_filter=num_d3_2, kernel=(3, 1), pad=(1, 0), name=('%s_tower' % name), suffix='_mixed_conv_1') + tower_3x3_d3 = Conv(data=data, num_filter=num_3x3_d3_red, name=('%s_tower_1' % name), suffix='_conv') + tower_3x3_d3 = Conv(data=tower_3x3_d3, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_1') + tower_3x3_d3_a = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_1, 
kernel=(1, 3), pad=(0, 1), name=('%s_tower_1' % name), suffix='_mixed_conv') + tower_3x3_d3_b = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_2, kernel=(3, 1), pad=(1, 0), name=('%s_tower_1' % name), suffix='_mixed_conv_1') + pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name))) + cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' % name), suffix='_conv') + # concat + concat = mx.sym.Concat(*[tower_1x1, tower_d3_a, tower_d3_b, tower_3x3_d3_a, tower_3x3_d3_b, cproj], name='ch_concat_%s_chconcat' % name) + return concat + +def get_symbol(num_classes=1000, **kwargs): + data = mx.sym.Variable(name="data") + # stage 1 + conv = Conv(data, 32, kernel=(3, 3), stride=(2, 2), name="conv") + conv_1 = Conv(conv, 32, kernel=(3, 3), name="conv_1") + conv_2 = Conv(conv_1, 64, kernel=(3, 3), pad=(1, 1), name="conv_2") + pool = mx.sym.Pooling(data=conv_2, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool") + # stage 2 + conv_3 = Conv(pool, 80, kernel=(1, 1), name="conv_3") + conv_4 = Conv(conv_3, 192, kernel=(3, 3), name="conv_4") + pool1 = mx.sym.Pooling(data=conv_4, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool1") + + # # stage 3 + in3a = Inception7A(pool1, 64, + 64, 96, 96, + 48, 64, + "avg", 32, "mixed") + in3b = Inception7A(in3a, 64, + 64, 96, 96, + 48, 64, + "avg", 64, "mixed_1") + in3c = Inception7A(in3b, 64, + 64, 96, 96, + 48, 64, + "avg", 64, "mixed_2") + in3d = Inception7B(in3c, 384, + 64, 96, 96, + "max", "mixed_3") + # stage 4 + in4a = Inception7C(in3d, 192, + 128, 128, 192, + 128, 128, 128, 128, 192, + "avg", 192, "mixed_4") + in4b = Inception7C(in4a, 192, + 160, 160, 192, + 160, 160, 160, 160, 192, + "avg", 192, "mixed_5") + in4c = Inception7C(in4b, 192, + 160, 160, 192, + 160, 160, 160, 160, 192, + "avg", 192, "mixed_6") + in4d = Inception7C(in4c, 192, + 192, 192, 192, + 192, 192, 192, 192, 192, + "avg", 192, "mixed_7") + in4e = 
Inception7D(in4d, 192, 320, + 192, 192, 192, 192, + "max", "mixed_8") + # stage 5 + in5a = Inception7E(in4e, 320, + 384, 384, 384, + 448, 384, 384, 384, + "avg", 192, "mixed_9") + in5b = Inception7E(in5a, 320, + 384, 384, 384, + 448, 384, 384, 384, + "max", 192, "mixed_10") + # pool + pool = mx.sym.Pooling(data=in5b, kernel=(8, 8), stride=(1, 1), pool_type="avg", name="global_pool") + flatten = mx.sym.Flatten(data=pool, name="flatten") + fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1', flatten=False) + softmax = mx.sym.SoftmaxOutput(data=fc1, name='softmax') + return softmax diff --git a/tests/python/relay/frontend/mxnet/model_zoo/mlp.py b/tests/python/relay/frontend/mxnet/model_zoo/mlp.py new file mode 100644 index 000000000..922b20874 --- /dev/null +++ b/tests/python/relay/frontend/mxnet/model_zoo/mlp.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +""" +a simple multilayer perceptron +""" +import mxnet as mx + +def get_symbol(num_classes=10, **kwargs): + data = mx.symbol.Variable('data') + data = mx.sym.Flatten(data=data) + try: + fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128, flatten=False) + act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu") + fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64, flatten=False) + act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu") + fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes, flatten=False) + mlp = mx.symbol.softmax(data = fc3, name = 'softmax') + except: + fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128) + act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu") + fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64) + act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu") + fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes) + mlp = mx.symbol.softmax(data = fc3, name = 'softmax') + return mlp diff --git a/tests/python/relay/frontend/mxnet/model_zoo/resnet.py b/tests/python/relay/frontend/mxnet/model_zoo/resnet.py new file mode 100644 index 000000000..3f9a870d3 --- /dev/null +++ b/tests/python/relay/frontend/mxnet/model_zoo/resnet.py @@ -0,0 +1,199 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +''' +Adapted from https://github.com/tornadomeet/ResNet/blob/master/symbol_resnet.py +Original author Wei Wu + +Implemented the following paper: + +Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks" +''' +import mxnet as mx +import numpy as np + +def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, bn_mom=0.9, workspace=256, memonger=False): + """Return ResNet Unit symbol for building ResNet + Parameters + ---------- + data : str + Input data + num_filter : int + Number of output channels + bnf : int + Bottle neck channels factor with regard to num_filter + stride : tuple + Stride used in convolution + dim_match : Boolean + True means channel number between input and output is the same, otherwise means differ + name : str + Base name of the operators + workspace : int + Workspace used in convolution operator + """ + if bottle_neck: + bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1') + act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1') + conv1 = mx.sym.Convolution(data=act1, num_filter=int(num_filter*0.25), kernel=(1,1), stride=stride, pad=(0,0), + no_bias=True, workspace=workspace, name=name + '_conv1') + bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2') + act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2') + conv2 = mx.sym.Convolution(data=act2, num_filter=int(num_filter*0.25), kernel=(3,3), stride=(1,1), pad=(1,1), + no_bias=True, 
workspace=workspace, name=name + '_conv2') + bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3') + act3 = mx.sym.Activation(data=bn3, act_type='relu', name=name + '_relu3') + conv3 = mx.sym.Convolution(data=act3, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0), no_bias=True, + workspace=workspace, name=name + '_conv3') + if dim_match: + shortcut = data + else: + shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True, + workspace=workspace, name=name+'_sc') + if memonger: + shortcut._set_attr(mirror_stage='True') + return conv3 + shortcut + else: + bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn1') + act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1') + conv1 = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(3,3), stride=stride, pad=(1,1), + no_bias=True, workspace=workspace, name=name + '_conv1') + bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn2') + act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2') + conv2 = mx.sym.Convolution(data=act2, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1,1), + no_bias=True, workspace=workspace, name=name + '_conv2') + if dim_match: + shortcut = data + else: + shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True, + workspace=workspace, name=name+'_sc') + if memonger: + shortcut._set_attr(mirror_stage='True') + return conv2 + shortcut + +def resnet(units, num_stages, filter_list, num_classes, image_shape, bottle_neck=True, bn_mom=0.9, workspace=256, dtype='float32', memonger=False): + """Return ResNet symbol + Parameters + ---------- + units : list + Number of units in each stage + num_stages : int + Number of stages + filter_list : list + Channel size of each stage + num_classes : int + Output size of symbol +
dataset : str + Dataset type; only cifar10 and imagenet are supported + workspace : int + Workspace used in convolution operator + dtype : str + Precision (float32 or float16) + """ + num_unit = len(units) + assert(num_unit == num_stages) + data = mx.sym.Variable(name='data') + if dtype == 'float32': + # data = mx.sym.identity(data=data, name='id') + data = data + else: + if dtype == 'float16': + data = mx.sym.Cast(data=data, dtype=np.float16) + data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data') + (nchannel, height, width) = image_shape + if height <= 32: # such as cifar10 + body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(3, 3), stride=(1,1), pad=(1, 1), + no_bias=True, name="conv0", workspace=workspace) + else: # often expected to be 224 such as imagenet + body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2,2), pad=(3, 3), + no_bias=True, name="conv0", workspace=workspace) + body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0') + body = mx.sym.Activation(data=body, act_type='relu', name='relu0') + body = mx.sym.Pooling(data=body, kernel=(3, 3), stride=(2,2), pad=(1,1), pool_type='max') + + for i in range(num_stages): + body = residual_unit(body, filter_list[i+1], (1 if i==0 else 2, 1 if i==0 else 2), False, + name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, workspace=workspace, + memonger=memonger) + for j in range(units[i]-1): + body = residual_unit(body, filter_list[i+1], (1,1), True, name='stage%d_unit%d' % (i + 1, j + 2), + bottle_neck=bottle_neck, workspace=workspace, memonger=memonger) + bn1 = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn1') + relu1 = mx.sym.Activation(data=bn1, act_type='relu', name='relu1') + # Although kernel is not used here when global_pool=True, we should put one + pool1 = mx.sym.Pooling(data=relu1, global_pool=True, kernel=(7, 7), pool_type='avg',
name='pool1') + flat = mx.sym.Flatten(data=pool1) + try: + fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1', flatten=False) + except: + fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1') + if dtype == 'float16': + fc1 = mx.sym.Cast(data=fc1, dtype=np.float32) + return mx.sym.softmax(data=fc1, name='softmax') + +def get_symbol(num_classes, num_layers, image_shape, conv_workspace=256, dtype='float32', **kwargs): + """ + Adapted from https://github.com/tornadomeet/ResNet/blob/master/train_resnet.py + Original author Wei Wu + """ + image_shape = [int(l) for l in image_shape.split(',')] + (nchannel, height, width) = image_shape + if height <= 28: + num_stages = 3 + if (num_layers-2) % 9 == 0 and num_layers >= 164: + per_unit = [(num_layers-2)//9] + filter_list = [16, 64, 128, 256] + bottle_neck = True + elif (num_layers-2) % 6 == 0 and num_layers < 164: + per_unit = [(num_layers-2)//6] + filter_list = [16, 16, 32, 64] + bottle_neck = False + else: + raise ValueError("no experiments done on num_layers {}, you can do it yourself".format(num_layers)) + units = per_unit * num_stages + else: + if num_layers >= 50: + filter_list = [64, 256, 512, 1024, 2048] + bottle_neck = True + else: + filter_list = [64, 64, 128, 256, 512] + bottle_neck = False + num_stages = 4 + if num_layers == 18: + units = [2, 2, 2, 2] + elif num_layers == 34: + units = [3, 4, 6, 3] + elif num_layers == 50: + units = [3, 4, 6, 3] + elif num_layers == 101: + units = [3, 4, 23, 3] + elif num_layers == 152: + units = [3, 8, 36, 3] + elif num_layers == 200: + units = [3, 24, 36, 3] + elif num_layers == 269: + units = [3, 30, 48, 8] + else: + raise ValueError("no experiments done on num_layers {}, you can do it yourself".format(num_layers)) + + return resnet(units = units, + num_stages = num_stages, + filter_list = filter_list, + num_classes = num_classes, + image_shape = image_shape, + bottle_neck = bottle_neck, + workspace = conv_workspace, + dtype = dtype) 
diff --git a/tests/python/relay/frontend/mxnet/model_zoo/squeezenet.py b/tests/python/relay/frontend/mxnet/model_zoo/squeezenet.py new file mode 100644 index 000000000..deb896a21 --- /dev/null +++ b/tests/python/relay/frontend/mxnet/model_zoo/squeezenet.py @@ -0,0 +1,76 @@ +""" +Symbol of SqueezeNet + +Reference: +Iandola, Forrest N., et al. +"Squeezenet: Alexnet-level accuracy with 50x fewer parameters and< 0.5 mb model size." (2016). +""" + +import mxnet as mx + +# Helpers +def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels): + net = _make_fire_conv(net, squeeze_channels, 1, 0) + + left = _make_fire_conv(net, expand1x1_channels, 1, 0) + right = _make_fire_conv(net, expand3x3_channels, 3, 1) + # NOTE : Assume NCHW layout here + net = mx.sym.concat(left, right, dim=1) + + return net + +def _make_fire_conv(net, channels, kernel_size, padding=0): + net = mx.sym.Convolution(net, num_filter=channels, kernel=(kernel_size, kernel_size), + pad=(padding, padding)) + net = mx.sym.Activation(net, act_type='relu') + return net + +# Net +def get_symbol(num_classes=1000, version='1.0', **kwargs): + """Get symbol of SqueezeNet + + Parameters + ---------- + num_classes: int + The number of classification results + + version : str, optional + "1.0" or "1.1" of SqueezeNet + """ + assert version in ['1.0', '1.1'], ("Unsupported SqueezeNet version {version}:" + "1.0 or 1.1 expected".format(version=version)) + net = mx.sym.Variable("data") + if version == '1.0': + net = mx.sym.Convolution(net, num_filter=96, kernel=(7, 7), stride=(2, 2), pad=(3, 3)) + net = mx.sym.Activation(net, act_type='relu') + net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2)) + net = _make_fire(net, 16, 64, 64) + net = _make_fire(net, 16, 64, 64) + net = _make_fire(net, 32, 128, 128) + net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2)) + net = _make_fire(net, 32, 128, 128) + net = _make_fire(net, 48, 192, 192) + net = _make_fire(net, 
48, 192, 192) + net = _make_fire(net, 64, 256, 256) + net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2)) + net = _make_fire(net, 64, 256, 256) + else: + net = mx.sym.Convolution(net, num_filter=64, kernel=(3, 3), stride=(2, 2), pad=(1, 1)) + net = mx.sym.Activation(net, act_type='relu') + net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2)) + net = _make_fire(net, 16, 64, 64) + net = _make_fire(net, 16, 64, 64) + net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2)) + net = _make_fire(net, 32, 128, 128) + net = _make_fire(net, 32, 128, 128) + net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2)) + net = _make_fire(net, 48, 192, 192) + net = _make_fire(net, 48, 192, 192) + net = _make_fire(net, 64, 256, 256) + net = _make_fire(net, 64, 256, 256) + net = mx.sym.Dropout(net, p=0.5) + net = mx.sym.Convolution(net, num_filter=num_classes, kernel=(1, 1)) + net = mx.sym.Activation(net, act_type='relu') + net = mx.sym.Pooling(data=net, global_pool=True, kernel=(13, 13), pool_type='avg') + net = mx.sym.flatten(net) + return mx.sym.softmax(net) diff --git a/tests/python/relay/frontend/mxnet/model_zoo/vgg.py b/tests/python/relay/frontend/mxnet/model_zoo/vgg.py new file mode 100644 index 000000000..68215bb80 --- /dev/null +++ b/tests/python/relay/frontend/mxnet/model_zoo/vgg.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""References:

Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for
large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
"""

import mxnet as mx
import numpy as np


def get_feature(internel_layer, layers, filters, batch_norm = False, **kwargs):
    """Stack the VGG convolutional stages on top of ``internel_layer``.

    Parameters
    ----------
    internel_layer : mx.sym.Symbol
        Input symbol the feature extractor is built on.
    layers : list of int
        Number of 3x3 convolutions in each stage.
    filters : list of int
        Number of output channels of the convolutions in each stage;
        indexed in parallel with ``layers``.
    batch_norm : bool, default False
        Insert a BatchNorm between every convolution and its ReLU.
    **kwargs
        Accepted for call-site compatibility; not used here.

    Returns
    -------
    mx.sym.Symbol
        Output of the last max-pooling layer.
    """
    for i, num in enumerate(layers):
        stage = i + 1  # layer names are 1-based: conv1_1, conv1_2, ...
        for j in range(num):
            tag = "%s_%s" % (stage, j + 1)
            internel_layer = mx.sym.Convolution(
                data=internel_layer, kernel=(3, 3), pad=(1, 1),
                num_filter=filters[i], name="conv" + tag)
            if batch_norm:
                internel_layer = mx.sym.BatchNorm(data=internel_layer, name="bn" + tag)
            internel_layer = mx.sym.Activation(
                data=internel_layer, act_type="relu", name="relu" + tag)
        # Every stage ends with a 2x2 / stride-2 max pooling.
        internel_layer = mx.sym.Pooling(
            data=internel_layer, pool_type="max", kernel=(2, 2), stride=(2, 2),
            name="pool%s" % stage)
    return internel_layer


def get_classifier(input_data, num_classes, **kwargs):
    """Build the VGG classifier head (flatten -> fc6 -> fc7 -> fc8).

    Parameters
    ----------
    input_data : mx.sym.Symbol
        Output of the feature extractor.
    num_classes : int
        Number of hidden units of the final fully connected layer.
    **kwargs
        Accepted for call-site compatibility; not used here.

    Returns
    -------
    mx.sym.Symbol
        The ``fc8`` symbol (no softmax applied).
    """
    flatten = mx.sym.Flatten(data=input_data, name="flatten")

    # Prefer passing flatten=False explicitly (the input is already flat).
    # If the installed MXNet rejects that keyword -- presumably an older
    # release without the `flatten` argument -- fall back to the plain call.
    # Only the "bad argument" errors are caught, so unrelated failures and
    # KeyboardInterrupt are no longer swallowed.
    fc_kwargs = {"flatten": False}
    try:
        fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6", **fc_kwargs)
    except (mx.base.MXNetError, TypeError):
        fc_kwargs = {}
        fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6")

    relu6 = mx.sym.Activation(data=fc6, act_type="relu", name="relu6")
    drop6 = mx.sym.Dropout(data=relu6, p=0.5, name="drop6")
    fc7 = mx.sym.FullyConnected(data=drop6, num_hidden=4096, name="fc7", **fc_kwargs)
    relu7 = mx.sym.Activation(data=fc7, act_type="relu", name="relu7")
    drop7 = mx.sym.Dropout(data=relu7, p=0.5, name="drop7")
    fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8", **fc_kwargs)
    return fc8


def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32', **kwargs):
    """Build the MXNet symbol of a VGG network ending in a softmax.

    Parameters
    ----------
    num_classes : int
        Number of classification classes.
    num_layers : int, default 11
        Number of layers for the variant of VGG. Options are 11, 13, 16, 19.
    batch_norm : bool, default False
        Use batch normalization.
    dtype : str, float32 or float16
        Data precision. For ``'float16'`` the input is cast to float16 and
        the classifier output is cast back to float32 before the softmax.
    **kwargs
        Accepted for call-site compatibility; not used here.

    Returns
    -------
    mx.sym.Symbol
        The softmax output symbol, named ``'softmax'``.

    Raises
    ------
    ValueError
        If ``num_layers`` is not one of the supported variants.
    """
    # num_layers -> (convolutions per stage, channels per stage)
    vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
                13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
                16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
                19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
    if num_layers not in vgg_spec:
        raise ValueError("Invalide num_layers {}. Possible choices are 11,13,16,19.".format(num_layers))
    layers, filters = vgg_spec[num_layers]
    data = mx.sym.Variable(name="data")
    if dtype == 'float16':
        data = mx.sym.Cast(data=data, dtype=np.float16)
    feature = get_feature(data, layers, filters, batch_norm)
    classifier = get_classifier(feature, num_classes)
    if dtype == 'float16':
        classifier = mx.sym.Cast(data=classifier, dtype=np.float32)
    symbol = mx.sym.softmax(data=classifier, name='softmax')
    return symbol


# diff --git a/tests/python/relay/frontend/mxnet/test_forward.py b/tests/python/relay/frontend/mxnet/test_forward.py
# new file mode 100644
# index 000000000..fcc760981
# --- /dev/null
# +++ b/tests/python/relay/frontend/mxnet/test_forward.py
# @@ -0,0 +1,206 @@
from collections import namedtuple

import numpy as np

import topi
import tvm
from tvm.contrib import graph_runtime
from tvm import relay
from tvm.relay.testing.config import ctx_list
import mxnet as mx
from mxnet import gluon
from mxnet.gluon.model_zoo import vision
import model_zoo


def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape=(1, 1000),
                               gluon_impl=False, name=None, dtype='float32'):
    """Run a model in MXNet and in TVM (via Relay) and compare output 0.

    Deliberately not named ``test_*`` so that nose does not collect it.

    Parameters
    ----------
    mx_symbol : mx.sym.Symbol
        Symbol to convert. Not used when ``gluon_impl`` is True.
    data_shape : tuple of int
        Shape of the random input bound to the variable ``'data'``.
    out_shape : tuple of int
        Shape used to allocate the buffer that receives TVM output 0.
    gluon_impl : bool, default False
        If True, fetch the Gluon model-zoo network called ``name`` instead
        of using ``mx_symbol``.
    name : str, optional
        Gluon model-zoo name; only read when ``gluon_impl`` is True.
    dtype : str, default 'float32'
        Dtype the input is cast to and the output buffer is allocated with.

    Raises
    ------
    AssertionError
        If the MXNet and TVM outputs differ beyond rtol=atol=1e-5, or if
        ``'data'`` shows up among the MXNet argument parameters.
    """
    def get_gluon_output(name, x):
        # Xavier-initialised model-zoo network wrapped in a SymbolBlock.
        net = vision.get_model(name)
        net.collect_params().initialize(mx.init.Xavier())
        net_sym = gluon.nn.SymbolBlock(outputs=net(mx.sym.var('data')),
                                       inputs=mx.sym.var('data'),
                                       params=net.collect_params())
        out = net_sym(mx.nd.array(x.astype(dtype))).asnumpy()
        return out, net_sym

    def get_mxnet_output(symbol, x, dtype='float32'):
        # Reference forward pass; also returns the initialised parameters
        # so the TVM run uses exactly the same weights.
        Batch = namedtuple('Batch', ['data'])
        mod = mx.mod.Module(symbol, label_names=None)
        mod.bind(data_shapes=[('data', x.shape)], for_training=False)
        mod.init_params()
        mod.forward(Batch([mx.nd.array(x.astype(dtype))]))
        out = mod.get_outputs()[0].asnumpy()
        args, auxs = mod.get_params()
        return out, args, auxs

    def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'):
        shape_dict = {'data': x.shape}
        if gluon_impl:
            new_sym, params = relay.frontend.from_mxnet(symbol, shape_dict)
        else:
            new_sym, params = relay.frontend.from_mxnet(
                symbol, shape_dict, arg_params=args, aux_params=auxs)

        with relay.build_config(opt_level=3):
            graph, lib, params = relay.build(new_sym, target, params=params)
        m = graph_runtime.create(graph, lib, ctx)
        # set inputs
        m.set_input("data", tvm.nd.array(x.astype(dtype)))
        m.set_input(**params)
        m.run()
        # get outputs
        out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
        return out.asnumpy()

    # random input
    x = np.random.uniform(size=data_shape)
    if gluon_impl:
        gluon_out, gluon_sym = get_gluon_output(name, x)
        for target, ctx in ctx_list():
            tvm_out = get_tvm_output(gluon_sym, x, None, None, target, ctx, dtype)
            tvm.testing.assert_allclose(gluon_out, tvm_out, rtol=1e-5, atol=1e-5)
    else:
        mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype)
        assert "data" not in args
        for target, ctx in ctx_list():
            tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
            tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)


def test_forward_mlp():
    mlp = model_zoo.mx_mlp
    verify_mxnet_frontend_impl(mlp)


def test_forward_vgg():
    for n in [11]:
        mx_sym = model_zoo.mx_vgg[n]
        verify_mxnet_frontend_impl(mx_sym)


def test_forward_resnet():
    for n in [18]:
        mx_sym = model_zoo.mx_resnet[n]
        verify_mxnet_frontend_impl(mx_sym)


def test_forward_elu():
    data = mx.sym.var('data')
    data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
    mx_sym = mx.sym.LeakyReLU(data, act_type='elu')
    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))


def test_forward_rrelu():
    data = mx.sym.var('data')
    data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
    mx_sym = mx.sym.LeakyReLU(data, act_type='rrelu', lower_bound=0.3, upper_bound=0.7)
    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))


def test_forward_prelu():
    data = mx.sym.var('data')
    data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
    mx_sym = mx.sym.LeakyReLU(data, act_type='prelu')
    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))


def test_forward_softrelu():
    data = mx.sym.var('data')
    data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
    mx_sym = mx.sym.Activation(data, act_type='softrelu')
    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))


def test_forward_fc_flatten():
    # test flatten=True option in mxnet 0.11.1
    data = mx.sym.var('data')
    try:
        sym_flatten = mx.sym.FullyConnected(data, num_hidden=100, flatten=True)
        sym_no_flatten = mx.sym.FullyConnected(mx.sym.Flatten(data), num_hidden=100,
                                               flatten=False)
    except (mx.base.MXNetError, TypeError):
        # The installed MXNet does not accept the `flatten` argument, so
        # there is nothing to test. Only symbol construction is guarded:
        # the verifications below must be able to fail, otherwise this
        # test passes no matter what the frontend produces.
        return
    verify_mxnet_frontend_impl(sym_flatten, (1, 3, 100, 100), (1, 100))
    verify_mxnet_frontend_impl(sym_no_flatten, (1, 3, 100, 100), (1, 100))


def test_forward_clip():
    data = mx.sym.var('data')
    data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
    mx_sym = mx.sym.clip(data, a_min=0, a_max=1)
    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))


def test_forward_split():
    data = mx.sym.var('data')
    mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=False)
    verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 1, 2, 1))


def test_forward_split_squeeze():
    data = mx.sym.var('data')
    mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=True)
    verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 2, 1))


def test_forward_expand_dims():
    data = mx.sym.var('data')
    mx_sym = mx.sym.expand_dims(data, axis=1)
    verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 1, 3, 4))


def test_forward_pooling():
    data = mx.sym.var('data')
    mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='avg')
    verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8))

    mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='max')
    verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8))


def test_forward_lrn():
    data = mx.sym.var('data')
    mx_sym = mx.sym.LRN(data, alpha=2, beta=2, knorm=1, nsize=5)
    verify_mxnet_frontend_impl(mx_sym, (1, 10, 24, 24), (1, 10, 24, 24))


def test_forward_ones():
    data = mx.sym.var('data')
    ones = mx.sym.ones(shape=(2, 3, 4), dtype='float32')
    mx_sym = mx.sym.elemwise_add(data, ones)
    verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))


def test_forward_zeros():
    data = mx.sym.var('data')
    zeros = mx.sym.zeros(shape=(2, 3, 4), dtype='float32')
    mx_sym = mx.sym.elemwise_add(data, zeros)
    verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))


def test_forward_ones_like():
    data = mx.sym.var('data')
    mx_sym = mx.sym.ones_like(data, dtype='float32')
    verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))


def test_forward_zeros_like():
    data = mx.sym.var('data')
    mx_sym = mx.sym.zeros_like(data, dtype='float32')
    verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))


def test_forward_argmax():
    data = mx.sym.var('data')
    mx_sym = mx.sym.argmax(data, axis=1)
    verify_mxnet_frontend_impl(mx_sym, (5, 3), (5,))


def test_forward_argmin():
    data = mx.sym.var('data')
    mx_sym = mx.sym.argmin(data, axis=0)
    verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,))


if __name__ == '__main__':
    test_forward_mlp()
    test_forward_vgg()
    test_forward_resnet()
    test_forward_elu()
    test_forward_rrelu()
    test_forward_prelu()
    test_forward_softrelu()
    test_forward_fc_flatten()
    test_forward_clip()
    test_forward_split()
    test_forward_split_squeeze()
    test_forward_expand_dims()
    test_forward_pooling()
    test_forward_lrn()
    test_forward_ones()
    test_forward_zeros()
    test_forward_ones_like()
    test_forward_zeros_like()
    test_forward_argmax()
    test_forward_argmin()

# diff --git a/tests/python/relay/frontend/mxnet/test_graph.py
# b/tests/python/relay/frontend/mxnet/test_graph.py
# new file mode 100644
# index 000000000..820e78242
# --- /dev/null
# +++ b/tests/python/relay/frontend/mxnet/test_graph.py
# @@ -0,0 +1,87 @@
"""Check that MXNet symbols converted by the Relay frontend are
alpha-equal to the hand-written Relay workloads in ``model_zoo``."""
import mxnet as mx
import tvm
from tvm import relay
import model_zoo
from model_zoo import _batch


def _convert_and_infer(mx_sym, shape_dict):
    """Convert ``mx_sym`` with the Relay MXNet frontend and type-infer it.

    The name avoids the word "test" on purpose so test collectors skip it.
    """
    converted, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
    return relay.ir_pass.infer_type(converted)


def test_mlp():
    converted = _convert_and_infer(model_zoo.mx_mlp, {'data': (_batch, 1, 28, 28)})
    assert relay.ir_pass.alpha_equal(converted, model_zoo.relay_mlp)


def test_vgg():
    image_shape = {'data': (_batch, 3, 224, 224)}
    for depth in [11, 13, 16, 19]:
        converted = _convert_and_infer(model_zoo.mx_vgg[depth], image_shape)
        assert relay.ir_pass.alpha_equal(converted, model_zoo.relay_vgg[depth])


def test_resnet():
    image_shape = {'data': (_batch, 3, 224, 224)}
    for depth in [18, 34, 50, 101, 152, 200, 269]:
        converted = _convert_and_infer(model_zoo.mx_resnet[depth], image_shape)
        assert relay.ir_pass.alpha_equal(converted, model_zoo.relay_resnet[depth])


def test_squeezenet():
    image_shape = {'data': (_batch, 3, 224, 224)}
    for ver in ['1.0', '1.1']:
        converted = _convert_and_infer(model_zoo.mx_squeezenet[ver], image_shape)
        assert relay.ir_pass.alpha_equal(converted, model_zoo.relay_squeezenet[ver])


def test_inception_v3():
    converted = _convert_and_infer(model_zoo.mx_inception_v3,
                                   {'data': (_batch, 3, 299, 299)})
    assert relay.ir_pass.alpha_equal(converted, model_zoo.relay_inception_v3)


def test_dqn():
    converted = _convert_and_infer(model_zoo.mx_dqn, {'data': (_batch, 4, 84, 84)})
    assert relay.ir_pass.alpha_equal(converted, model_zoo.relay_dqn)


def test_dcgan():
    converted = _convert_and_infer(model_zoo.mx_dcgan, {'data': (_batch, 100)})
    assert relay.ir_pass.alpha_equal(converted, model_zoo.relay_dcgan)


def test_multi_outputs():
    """A graph that reads two of the three outputs of ``split`` must
    convert to the equivalent hand-built Relay function."""
    def compose_mxnet(**kwargs):
        x = mx.sym.Variable('x')
        y = mx.sym.Variable('y')
        parts = mx.sym.split(x, **kwargs)
        summed = mx.sym.broadcast_add(parts[0], parts[2])
        return mx.sym.broadcast_sub(summed, y)

    def compose_relay(**kwargs):
        x = relay.var("x", shape=(_batch, 3, 224, 224))
        y = relay.var("y", shape=(1,))
        parts = relay.split(x, **kwargs)
        body = parts[0] + parts[2] - y
        return relay.Function(relay.ir_pass.free_vars(body), body)

    converted = _convert_and_infer(compose_mxnet(num_outputs=3, axis=1),
                                   {'x': (_batch, 3, 224, 224), 'y': (1,)})
    expected = relay.ir_pass.infer_type(compose_relay(indices_or_sections=3, axis=1))
    assert relay.ir_pass.alpha_equal(converted, expected)


if __name__ == '__main__':
    test_mlp()
    test_vgg()
    test_resnet()
    test_squeezenet()
    test_inception_v3()
    test_dqn()
    test_dcgan()
    test_multi_outputs()

# -- GitLab