From 42b189cbc040be496ce9af7d2f59270ca586edf8 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng <mercy_zheng@sjtu.edu.cn>
Date: Mon, 18 Jun 2018 03:38:23 +0800
Subject: [PATCH] [NNVM][TESTING] Add two testing symbols: dqn and dcgan
 (#1294)

---
 nnvm/python/nnvm/testing/__init__.py          |  2 +
 nnvm/python/nnvm/testing/dcgan.py             | 90 +++++++++++++++++++
 nnvm/python/nnvm/testing/dqn.py               | 71 +++++++++++++++
 .../frontend/mxnet/model_zoo/__init__.py      | 10 ++-
 .../python/frontend/mxnet/model_zoo/dcgan.py  | 63 +++++++++++++
 .../python/frontend/mxnet/model_zoo/dqn.py    | 27 ++++++
 .../tests/python/frontend/mxnet/test_graph.py | 14 +++
 7 files changed, 276 insertions(+), 1 deletion(-)
 create mode 100644 nnvm/python/nnvm/testing/dcgan.py
 create mode 100644 nnvm/python/nnvm/testing/dqn.py
 create mode 100644 nnvm/tests/python/frontend/mxnet/model_zoo/dcgan.py
 create mode 100644 nnvm/tests/python/frontend/mxnet/model_zoo/dqn.py

diff --git a/nnvm/python/nnvm/testing/__init__.py b/nnvm/python/nnvm/testing/__init__.py
index 56d5a9a48..2700aea10 100644
--- a/nnvm/python/nnvm/testing/__init__.py
+++ b/nnvm/python/nnvm/testing/__init__.py
@@ -7,4 +7,6 @@ from . import mobilenet
 from . import mlp
 from . import resnet
 from . import vgg
+from . import dcgan
+from . import dqn
 from . import yolo2_detection
diff --git a/nnvm/python/nnvm/testing/dcgan.py b/nnvm/python/nnvm/testing/dcgan.py
new file mode 100644
index 000000000..421699ad4
--- /dev/null
+++ b/nnvm/python/nnvm/testing/dcgan.py
@@ -0,0 +1,90 @@
+# pylint: disable=unused-argument
+"""
+Symbol of the generator of DCGAN
+
+Adopted from:
+https://github.com/tqchen/mxnet-gan/blob/master/mxgan/generator.py
+
+Reference:
+Radford, Alec, Luke Metz, and Soumith Chintala.
+"Unsupervised representation learning with deep convolutional generative adversarial networks."
+arXiv preprint arXiv:1511.06434 (2015).
+"""
+from .. import symbol as sym
+from . utils import create_workload
+
+def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)):
+    """a deconv layer that enlarges the feature map"""
+    target_shape = (oshape[-2], oshape[-1])
+
+    pad_y = (kshape[0] - 1) // 2
+    pad_x = (kshape[1] - 1) // 2
+    adj_y = (target_shape[0] + 2 * pad_y - kshape[0]) % stride[0]
+    adj_x = (target_shape[1] + 2 * pad_x - kshape[1]) % stride[1]
+
+    net = sym.conv2d_transpose(data,
+                               kernel_size=kshape,
+                               strides=stride,
+                               channels=oshape[0],
+                               padding=(pad_y, pad_x),
+                               output_padding=(adj_y, adj_x),
+                               use_bias=False,
+                               name=name)
+    return net
+
+def deconv2d_bn_relu(data, prefix, **kwargs):
+    """a block of deconv + batch norm + relu"""
+    eps = 1e-5 + 1e-12
+    net = deconv2d(data, name="%s_deconv" % prefix, **kwargs)
+    net = sym.batch_norm(net, epsilon=eps, name="%s_bn" % prefix)
+    net = sym.relu(net, name="%s_act" % prefix)
+    return net
+
+def get_symbol(oshape, ngf=128, code=None):
+    """get symbol of dcgan generator"""
+    assert oshape[-1] == 32, "Only support 32x32 image"
+    assert oshape[-2] == 32, "Only support 32x32 image"
+
+    code = sym.Variable("data") if code is None else code
+    net = sym.dense(code, name="g1", units=4*4*ngf*4, use_bias=False)
+    net = sym.relu(net)
+    # 4 x 4
+    net = sym.reshape(net, shape=(-1, ngf * 4, 4, 4))
+    # 8 x 8
+    net = deconv2d_bn_relu(
+        net, ishape=(ngf * 4, 4, 4), oshape=(ngf * 2, 8, 8), kshape=(4, 4), prefix="g2")
+    # 16x16
+    net = deconv2d_bn_relu(
+        net, ishape=(ngf * 2, 8, 8), oshape=(ngf, 16, 16), kshape=(4, 4), prefix="g3")
+    # 32x32
+    net = deconv2d(
+        net, ishape=(ngf, 16, 16), oshape=oshape[-3:], kshape=(4, 4), name="g4_deconv")
+    net = sym.tanh(net)
+    return net
+
+
+def get_workload(batch_size, oshape=(3, 32, 32), ngf=128, random_len=100, dtype="float32"):
+    """Get benchmark workload for a DCGAN generator
+
+    Parameters
+    ----------
+    batch_size : int
+        The batch size used in the model
+    oshape : tuple, optional
+        The shape of output image, layout="CHW"
+    ngf: int, optional
+        The number of final feature maps in the generator
+    random_len : int, optional
+        The length of random input
+    dtype : str, optional
+        The data type
+
+    Returns
+    -------
+    net : nnvm.symbol
+        The computational graph
+    params : dict of str to NDArray
+        The parameters.
+    """
+    net = get_symbol(oshape=oshape, ngf=ngf)
+    return create_workload(net, batch_size, (random_len, ), dtype)
diff --git a/nnvm/python/nnvm/testing/dqn.py b/nnvm/python/nnvm/testing/dqn.py
new file mode 100644
index 000000000..b04475efa
--- /dev/null
+++ b/nnvm/python/nnvm/testing/dqn.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Symbol of Nature DQN
+
+Reference:
+Mnih, Volodymyr, et al. "Human-level control through deep reinforcement learning."
+Nature 518.7540 (2015): 529.
+"""
+
+from .. import symbol as sym
+from . utils import create_workload
+
+def get_symbol(num_actions=18):
+    """get symbol of nature dqn"""
+    data = sym.Variable(name='data')
+    net = sym.conv2d(data, kernel_size=(8, 8), strides=(4, 4), padding=(0, 0),
+                     channels=32, name='conv1')
+    net = sym.relu(net, name='relu1')
+    net = sym.conv2d(net, kernel_size=(4, 4), strides=(2, 2), padding=(0, 0),
+                     channels=64, name='conv2')
+    net = sym.relu(net, name='relu2')
+    net = sym.conv2d(net, kernel_size=(3, 3), strides=(1, 1), padding=(0, 0),
+                     channels=64, name='conv3')
+    net = sym.relu(net, name='relu3')
+    net = sym.flatten(net, name='flatten')
+    net = sym.dense(net, units=512, name='fc4')
+    net = sym.relu(net, name='relu4')
+    net = sym.dense(net, units=num_actions, name='fc5')
+
+    return net
+
+
+def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"):
+    """Get benchmark workload for a Deep Q Network
+
+    Parameters
+    ----------
+    batch_size : int
+        The batch size used in the model
+    num_actions : int, optional
+        Number of actions
+    image_shape : tuple, optional
+        The input image shape
+    dtype : str, optional
+        The data type
+
+    Returns
+    -------
+    net : nnvm.symbol
+        The computational graph
+    params : dict of str to NDArray
+        The parameters.
+    """
+    net = get_symbol(num_actions=num_actions)
+    return create_workload(net, batch_size, image_shape, dtype)
diff --git a/nnvm/tests/python/frontend/mxnet/model_zoo/__init__.py b/nnvm/tests/python/frontend/mxnet/model_zoo/__init__.py
index c39e2b214..6c3d07ffc 100644
--- a/nnvm/tests/python/frontend/mxnet/model_zoo/__init__.py
+++ b/nnvm/tests/python/frontend/mxnet/model_zoo/__init__.py
@@ -1,6 +1,6 @@
 """MXNet and NNVM model zoo."""
 from __future__ import absolute_import
-from . import mlp, resnet, vgg
+from . import mlp, resnet, vgg, dqn, dcgan
 import nnvm.testing
 
 __all__ = ['mx_mlp', 'nnvm_mlp', 'mx_resnet', 'nnvm_resnet', 'mx_vgg', 'nnvm_vgg']
@@ -26,3 +26,11 @@ for num_layer in [11, 13, 16, 19]:
     mx_vgg[num_layer] = vgg.get_symbol(_num_class, num_layer)
     nnvm_vgg[num_layer] = nnvm.testing.vgg.get_workload(
         1, _num_class, num_layers=num_layer)[0]
+
+# dqn
+mx_dqn = dqn.get_symbol()
+nnvm_dqn = nnvm.testing.dqn.get_workload(1)[0]
+
+# dcgan generator
+mx_dcgan = dcgan.get_symbol()
+nnvm_dcgan = nnvm.testing.dcgan.get_workload(1)[0]
diff --git a/nnvm/tests/python/frontend/mxnet/model_zoo/dcgan.py b/nnvm/tests/python/frontend/mxnet/model_zoo/dcgan.py
new file mode 100644
index 000000000..98133d369
--- /dev/null
+++ b/nnvm/tests/python/frontend/mxnet/model_zoo/dcgan.py
@@ -0,0 +1,63 @@
+# pylint: disable=unused-argument
+"""
+The MXNet symbol of DCGAN generator
+
+Adopted from:
+https://github.com/tqchen/mxnet-gan/blob/master/mxgan/generator.py
+
+Reference:
+Radford, Alec, Luke Metz, and Soumith Chintala.
+"Unsupervised representation learning with deep convolutional generative adversarial networks."
+arXiv preprint arXiv:1511.06434 (2015).
+"""
+
+import mxnet as mx
+
+def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)):
+    """a deconv layer that enlarges the feature map"""
+    target_shape = (oshape[-2], oshape[-1])
+    pad_y = (kshape[0] - 1) // 2
+    pad_x = (kshape[1] - 1) // 2
+    adj_y = (target_shape[0] + 2 * pad_y - kshape[0]) % stride[0]
+    adj_x = (target_shape[1] + 2 * pad_x - kshape[1]) % stride[1]
+
+    net = mx.sym.Deconvolution(data,
+                               kernel=kshape,
+                               stride=stride,
+                               pad=(pad_y, pad_x),
+                               adj=(adj_y, adj_x),
+                               num_filter=oshape[0],
+                               no_bias=True,
+                               name=name)
+    return net
+
+def deconv2d_bn_relu(data, prefix, **kwargs):
+    """a block of deconv + batch norm + relu"""
+    eps = 1e-5 + 1e-12
+
+    net = deconv2d(data, name="%s_deconv" % prefix, **kwargs)
+    net = mx.sym.BatchNorm(net, eps=eps, name="%s_bn" % prefix)
+    net = mx.sym.Activation(net, name="%s_act" % prefix, act_type='relu')
+    return net
+
+def get_symbol(oshape=(3, 32, 32), ngf=128, code=None):
+    """get symbol of dcgan generator"""
+    assert oshape[-1] == 32, "Only support 32x32 image"
+    assert oshape[-2] == 32, "Only support 32x32 image"
+
+    code = mx.sym.Variable("data") if code is None else code
+    net = mx.sym.FullyConnected(code, name="g1", num_hidden=4*4*ngf*4, no_bias=True, flatten=False)
+    net = mx.sym.Activation(net, act_type='relu')
+    # 4 x 4
+    net = mx.sym.reshape(net, shape=(-1, ngf * 4, 4, 4))
+    # 8 x 8
+    net = deconv2d_bn_relu(
+        net, ishape=(ngf * 4, 4, 4), oshape=(ngf * 2, 8, 8), kshape=(4, 4), prefix="g2")
+    # 16x16
+    net = deconv2d_bn_relu(
+        net, ishape=(ngf * 2, 8, 8), oshape=(ngf, 16, 16), kshape=(4, 4), prefix="g3")
+    # 32x32
+    net = deconv2d(
+        net, ishape=(ngf, 16, 16), oshape=oshape[-3:], kshape=(4, 4), name="g4_deconv")
+    net = mx.sym.Activation(net, act_type='tanh')
+    return net
diff --git a/nnvm/tests/python/frontend/mxnet/model_zoo/dqn.py b/nnvm/tests/python/frontend/mxnet/model_zoo/dqn.py
new file mode 100644
index 000000000..e037511ef
--- /dev/null
+++ b/nnvm/tests/python/frontend/mxnet/model_zoo/dqn.py
@@ -0,0 +1,27 @@
+"""
+The mxnet symbol of Nature DQN
+
+Reference:
+Mnih, Volodymyr, et al.
+"Human-level control through deep reinforcement learning."
+Nature 518.7540 (2015): 529.
+"""
+
+import mxnet as mx
+
+def get_symbol(num_action=18):
+    data = mx.sym.Variable(name='data')
+    net = mx.sym.Convolution(data, kernel=(8, 8), stride=(4, 4),
+                             num_filter=32, name='conv1')
+    net = mx.sym.Activation(net, act_type='relu', name='relu1')
+    net = mx.sym.Convolution(net, kernel=(4, 4), stride=(2, 2),
+                             num_filter=64, name='conv2')
+    net = mx.sym.Activation(net, act_type='relu', name='relu2')
+    net = mx.sym.Convolution(net, kernel=(3, 3), stride=(1, 1),
+                             num_filter=64, name='conv3')
+    net = mx.sym.Activation(net, act_type='relu', name='relu3')
+    net = mx.sym.FullyConnected(net, num_hidden=512, name='fc4')
+    net = mx.sym.Activation(net, act_type='relu', name='relu4')
+    net = mx.sym.FullyConnected(net, num_hidden=num_action, name='fc5', flatten=False)
+
+    return net
diff --git a/nnvm/tests/python/frontend/mxnet/test_graph.py b/nnvm/tests/python/frontend/mxnet/test_graph.py
index bbbc42db1..cd14f9c7c 100644
--- a/nnvm/tests/python/frontend/mxnet/test_graph.py
+++ b/nnvm/tests/python/frontend/mxnet/test_graph.py
@@ -32,6 +32,18 @@ def test_resnet():
         nnvm_sym = model_zoo.nnvm_resnet[n]
         compare_graph(from_mx_sym, nnvm_sym)
 
+def test_dqn():
+    mx_sym = model_zoo.mx_dqn
+    from_mx_sym, _ = nnvm.frontend.from_mxnet(mx_sym)
+    nnvm_sym = model_zoo.nnvm_dqn
+    compare_graph(from_mx_sym, nnvm_sym)
+
+def test_dcgan():
+    mx_sym = model_zoo.mx_dcgan
+    from_mx_sym, _ = nnvm.frontend.from_mxnet(mx_sym)
+    nnvm_sym = model_zoo.nnvm_dcgan
+    compare_graph(from_mx_sym, nnvm_sym)
+
 def test_multi_outputs():
     def compose(F, **kwargs):
         x = F.sym.Variable('x')
@@ -48,3 +60,5 @@ if __name__ == '__main__':
     test_vgg()
     test_resnet()
     test_multi_outputs()
+    test_dqn()
+    test_dcgan()
-- 
GitLab