Newer
Older
import numpy as np
import topi
import tvm
from tvm.contrib import graph_runtime
import nnvm.symbol as sym
import nnvm.compiler
from nnvm.testing.config import ctx_list
from nnvm import frontend
import mxnet as mx
from mxnet import gluon
from mxnet.gluon.model_zoo import vision
def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape=(1, 1000),
                               gluon_impl=False, name=None, dtype='float32'):
    """Check that the NNVM MXNet frontend reproduces a reference MXNet result.

    The helper is deliberately not named ``test_*`` so that nose does not
    collect it as a test by itself.

    Parameters
    ----------
    mx_symbol : mx.sym.Symbol
        Symbol to convert and compare when ``gluon_impl`` is False
        (unused otherwise).
    data_shape : tuple
        Shape of the random input fed to the ``data`` variable.
    out_shape : tuple
        Shape of the buffer that receives TVM output 0.
    gluon_impl : bool
        When True, the reference network is ``vision.get_model(name)``
        re-wrapped as a ``gluon.nn.SymbolBlock`` instead of ``mx_symbol``.
    name : str
        Gluon vision model-zoo model name; only read when ``gluon_impl`` is True.
    dtype : str
        Dtype the input is cast to before it is fed to MXNet and to TVM.
    """
    def _module_reference(symbol, inputs):
        # Bind ``symbol`` in an inference-only Module, initialize its params,
        # run one forward pass and return (output 0, arg params, aux params).
        from collections import namedtuple
        Batch = namedtuple('Batch', ['data'])
        module = mx.mod.Module(symbol, label_names=None)
        module.bind(data_shapes=[('data', inputs.shape)], for_training=False)
        module.init_params()
        module.forward(Batch([mx.nd.array(inputs.astype(dtype))]))
        reference = module.get_outputs()[0].asnumpy()
        arg_params, aux_params = module.get_params()
        return reference, arg_params, aux_params

    def _gluon_reference(model_name, inputs):
        # Build a Gluon model-zoo network, Xavier-initialize it, wrap it as a
        # SymbolBlock and return (its output on ``inputs``, the SymbolBlock).
        net = vision.get_model(model_name)
        net.collect_params().initialize(mx.init.Xavier())
        traced_outputs = net(mx.sym.var('data'))
        block = gluon.nn.SymbolBlock(outputs=traced_outputs,
                                     inputs=mx.sym.var('data'),
                                     params=net.collect_params())
        reference = block(mx.nd.array(inputs.astype(dtype))).asnumpy()
        return reference, block

    def _tvm_result(symbol, inputs, arg_params, aux_params, target, ctx):
        # Convert with the NNVM frontend, compile at opt_level=3 for ``target``
        # and return output 0 computed on ``ctx``.
        if gluon_impl:
            converted, params = frontend.from_mxnet(symbol)
        else:
            converted, params = frontend.from_mxnet(symbol, arg_params, aux_params)
        with nnvm.compiler.build_config(opt_level=3):
            graph, lib, params = nnvm.compiler.build(converted, target,
                                                     {'data': inputs.shape},
                                                     params=params)
        runtime = graph_runtime.create(graph, lib, ctx)
        runtime.set_input("data", tvm.nd.array(inputs.astype(dtype)))
        runtime.set_input(**params)
        runtime.run()
        return runtime.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()

    # Unseeded random input, drawn fresh on every call.
    x = np.random.uniform(size=data_shape)
    if gluon_impl:
        expected, symbol = _gluon_reference(name, x)
        arg_params, aux_params = None, None
    else:
        expected, arg_params, aux_params = _module_reference(mx_symbol, x)
        # The input variable must not appear among the module's parameters.
        assert "data" not in arg_params
        symbol = mx_symbol
    for target, ctx in ctx_list():
        actual = _tvm_result(symbol, x, arg_params, aux_params, target, ctx)
        tvm.testing.assert_allclose(expected, actual, rtol=1e-5, atol=1e-5)
def test_forward_mlp():
    """Compare the model-zoo MLP between MXNet and TVM."""
    # NOTE(review): ``model_zoo`` is not among the visible imports; presumably a
    # sibling test module imported elsewhere -- confirm, or this raises NameError.
    verify_mxnet_frontend_impl(model_zoo.mx_mlp)
def test_forward_vgg():
    """Compare model-zoo VGG variants between MXNet and TVM (only depth 11)."""
    # NOTE(review): ``model_zoo`` is not among the visible imports -- confirm.
    for depth in (11,):
        verify_mxnet_frontend_impl(model_zoo.mx_vgg[depth])
def test_forward_resnet():
    """Compare model-zoo ResNet variants between MXNet and TVM (only depth 18)."""
    # NOTE(review): ``model_zoo`` is not among the visible imports -- confirm.
    for depth in (18,):
        verify_mxnet_frontend_impl(model_zoo.mx_resnet[depth])
def test_forward_elu():
    """LeakyReLU with act_type='elu' on an input holding both x and -x."""
    inp = mx.sym.var('data')
    # Stack x and -x along the channel axis so both signs are covered.
    both_signs = mx.sym.concat(inp, -inp, dim=1)
    verify_mxnet_frontend_impl(mx.sym.LeakyReLU(both_signs, act_type='elu'),
                               data_shape=(1, 3, 100, 100),
                               out_shape=(1, 6, 100, 100))
def test_forward_rrelu():
    """LeakyReLU with act_type='rrelu' (bounds 0.3..0.7) on x and -x."""
    inp = mx.sym.var('data')
    # Stack x and -x along the channel axis so both signs are covered.
    both_signs = mx.sym.concat(inp, -inp, dim=1)
    rrelu = mx.sym.LeakyReLU(both_signs, act_type='rrelu',
                             lower_bound=0.3, upper_bound=0.7)
    verify_mxnet_frontend_impl(rrelu,
                               data_shape=(1, 3, 100, 100),
                               out_shape=(1, 6, 100, 100))
def test_forward_prelu():
    """LeakyReLU with act_type='prelu' on an input holding both x and -x."""
    inp = mx.sym.var('data')
    # Stack x and -x along the channel axis so both signs are covered.
    both_signs = mx.sym.concat(inp, -inp, dim=1)
    verify_mxnet_frontend_impl(mx.sym.LeakyReLU(both_signs, act_type='prelu'),
                               data_shape=(1, 3, 100, 100),
                               out_shape=(1, 6, 100, 100))
def test_forward_softrelu():
    """Activation with act_type='softrelu' on an input holding both x and -x."""
    inp = mx.sym.var('data')
    # Stack x and -x along the channel axis so both signs are covered.
    both_signs = mx.sym.concat(inp, -inp, dim=1)
    verify_mxnet_frontend_impl(mx.sym.Activation(both_signs, act_type='softrelu'),
                               data_shape=(1, 3, 100, 100),
                               out_shape=(1, 6, 100, 100))
def test_forward_fc_flatten():
    """FullyConnected with the ``flatten`` option (added in mxnet 0.11.1).

    Checks ``flatten=True`` on a 4-D input, and ``flatten=False`` applied
    after an explicit ``Flatten``. Both should produce a (1, 100) output.

    The test is skipped (returns early) only when the installed mxnet rejects
    the ``flatten`` argument while the symbol is being built. Verification
    failures are no longer swallowed: the previous bare ``except: pass``
    wrapped the checks too, so this test could never fail.
    """
    data = mx.sym.var('data')
    try:
        mx_sym = mx.sym.FullyConnected(data, num_hidden=100, flatten=True)
    except (mx.base.MXNetError, TypeError):
        # Older mxnet does not know the ``flatten`` parameter; nothing to test.
        return
    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100))
    mx_sym = mx.sym.FullyConnected(mx.sym.Flatten(data), num_hidden=100, flatten=False)
    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100))
def test_forward_clip():
    """clip to [0, 1] on an input holding both x and -x."""
    inp = mx.sym.var('data')
    # Stack x and -x along the channel axis so negative values are explicitly present.
    both_signs = mx.sym.concat(inp, -inp, dim=1)
    verify_mxnet_frontend_impl(mx.sym.clip(both_signs, a_min=0, a_max=1),
                               data_shape=(1, 3, 100, 100),
                               out_shape=(1, 6, 100, 100))
def test_forward_split():
    """split into 4 along axis 1 without squeezing; output 0 is compared."""
    pieces = mx.sym.split(mx.sym.var('data'), axis=1, num_outputs=4, squeeze_axis=False)
    verify_mxnet_frontend_impl(pieces, data_shape=(1, 4, 2, 1), out_shape=(1, 1, 2, 1))
def test_forward_split_squeeze():
    """split into 4 along axis 1 with the split axis squeezed away."""
    pieces = mx.sym.split(mx.sym.var('data'), axis=1, num_outputs=4, squeeze_axis=True)
    verify_mxnet_frontend_impl(pieces, data_shape=(1, 4, 2, 1), out_shape=(1, 2, 1))
def test_forward_expand_dims():
    """expand_dims inserting a new axis at position 1."""
    expanded = mx.sym.expand_dims(mx.sym.var('data'), axis=1)
    verify_mxnet_frontend_impl(expanded, data_shape=(2, 3, 4), out_shape=(2, 1, 3, 4))
def test_forward_pooling():
    """3x3 pooling with padding 1, first 'avg' then 'max'."""
    inp = mx.sym.var('data')
    for kind in ('avg', 'max'):
        pooled = mx.sym.Pooling(inp, kernel=(3, 3), pad=(1, 1), pool_type=kind)
        verify_mxnet_frontend_impl(pooled,
                                   data_shape=(1, 20, 8, 8),
                                   out_shape=(1, 20, 8, 8))
def test_forward_lrn():
    """Local response normalization (alpha=2, beta=2, knorm=1, nsize=5)."""
    normalized = mx.sym.LRN(mx.sym.var('data'), alpha=2, beta=2, knorm=1, nsize=5)
    verify_mxnet_frontend_impl(normalized,
                               data_shape=(1, 10, 24, 24),
                               out_shape=(1, 10, 24, 24))
def test_forward_ones():
    """Element-wise add of the input and a constant ``ones`` tensor."""
    summed = mx.sym.elemwise_add(mx.sym.var('data'),
                                 mx.sym.ones(shape=(2, 3, 4), dtype='float32'))
    verify_mxnet_frontend_impl(summed, data_shape=(2, 3, 4), out_shape=(2, 3, 4))
def test_forward_zeros():
    """Element-wise add of the input and a constant ``zeros`` tensor."""
    summed = mx.sym.elemwise_add(mx.sym.var('data'),
                                 mx.sym.zeros(shape=(2, 3, 4), dtype='float32'))
    verify_mxnet_frontend_impl(summed, data_shape=(2, 3, 4), out_shape=(2, 3, 4))
def test_forward_ones_like():
    """ones_like of the input tensor."""
    filled = mx.sym.ones_like(mx.sym.var('data'), dtype='float32')
    verify_mxnet_frontend_impl(filled, data_shape=(2, 3, 4), out_shape=(2, 3, 4))
def test_forward_zeros_like():
    """zeros_like of the input tensor."""
    filled = mx.sym.zeros_like(mx.sym.var('data'), dtype='float32')
    verify_mxnet_frontend_impl(filled, data_shape=(2, 3, 4), out_shape=(2, 3, 4))
if __name__ == '__main__':
    # Run every test defined in this module, in definition order. The previous
    # list skipped prelu, clip, pooling, lrn, ones, zeros, ones_like and zeros_like.
    test_forward_mlp()
    test_forward_vgg()
    test_forward_resnet()
    test_forward_elu()
    test_forward_rrelu()
    test_forward_prelu()
    test_forward_softrelu()
    test_forward_fc_flatten()
    test_forward_clip()
    test_forward_split()
    test_forward_split_squeeze()
    test_forward_expand_dims()
    test_forward_pooling()
    test_forward_lrn()
    test_forward_ones()
    test_forward_zeros()
    test_forward_ones_like()
    test_forward_zeros_like()