diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 2a9b81b842309f19ef28374a756c4a57751d06f4..547fff4255956a426ec37ee22e7ed754247da473 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -4,3 +4,4 @@ from __future__ import absolute_import as _abs from . import mlp from . import resnet from . import dqn +from . import dcgan diff --git a/python/tvm/relay/testing/dcgan.py b/python/tvm/relay/testing/dcgan.py new file mode 100644 index 0000000000000000000000000000000000000000..96cd871e4122b55780f96f44023124560603d34c --- /dev/null +++ b/python/tvm/relay/testing/dcgan.py @@ -0,0 +1,96 @@ +# pylint: disable=unused-argument +""" +Net of the generator of DCGAN + +Adopted from: +https://github.com/tqchen/mxnet-gan/blob/master/mxgan/generator.py + +Reference: +Radford, Alec, Luke Metz, and Soumith Chintala. +"Unsupervised representation learning with deep convolutional generative adversarial networks." +arXiv preprint arXiv:1511.06434 (2015). +""" +from tvm import relay +from . import layers +from .init import create_workload + +def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)): + """a deconv layer that enlarges the feature map""" + target_shape = (oshape[-2], oshape[-1]) + + pad_y = (kshape[0] - 1) // 2 + pad_x = (kshape[1] - 1) // 2 + adj_y = (target_shape[0] + 2 * pad_y - kshape[0]) % stride[0] + adj_x = (target_shape[1] + 2 * pad_x - kshape[1]) % stride[1] + + net = layers.conv2d_transpose(data, + kernel_size=kshape, + strides=stride, + channels=oshape[0], + padding=(pad_y, pad_x), + output_padding=(adj_y, adj_x), + name=name) + return net + +def deconv2d_bn_relu(data, prefix, **kwargs): + """a block of deconv + batch norm + relu""" + eps = 1e-5 + 1e-12 + net = deconv2d(data, name="%s_deconv" % prefix, **kwargs) + net = layers.batch_norm_infer(net, epsilon=eps, name="batch_norm") + net = relay.nn.relu(net) + return net + +def get_net(batch_size, random_len=100, oshape=(3, 64, 64), ngf=128, code=None, dtype="float32"): + """get net of dcgan generator""" + assert oshape[-1] == 64, "Only support 64x64 image" + assert oshape[-2] == 64, "Only support 64x64 image" + + code = relay.var("data", dtype=dtype, shape=(batch_size, random_len)) if code is None else code + dense_weight = relay.var("dense_weight") + dense = relay.nn.dense(code, weight=dense_weight, units=4*4*ngf*8) + relu = relay.nn.relu(dense) + # 4 x 4 + reshape = relay.reshape(relu, newshape=(-1, ngf * 8, 4, 4)) + # 8 x 8 + dc8 = deconv2d_bn_relu( + reshape, ishape=(ngf * 8, 4, 4), oshape=(ngf * 4, 8, 8), kshape=(4, 4), prefix="g2") + # 16x16 + dc16 = deconv2d_bn_relu( + dc8, ishape=(ngf * 4, 8, 8), oshape=(ngf * 2, 16, 16), kshape=(4, 4), prefix="g3") + # 32x32 + dc32 = deconv2d_bn_relu( + dc16, ishape=(ngf * 2, 16, 16), oshape=(ngf, 32, 32), kshape=(4, 4), prefix="g4") + # 64x64 + dc64 = deconv2d( + dc32, ishape=(ngf, 32, 32), oshape=oshape[-3:], kshape=(4, 4), name="g5_deconv") + tanh = relay.tanh(dc64) + + args = relay.ir_pass.free_vars(tanh) + return relay.Function(args, tanh) + + +def get_workload(batch_size, oshape=(3, 64, 64), ngf=128, random_len=100, dtype="float32"): + """Get benchmark workload for a DCGAN generator + + Parameters + ---------- + batch_size : int + The batch size used in the model + oshape : tuple, optional + The shape of output image, layout="CHW" + ngf: int, optional + The number of final feature maps in the generator + random_len : int, optional + The length of random input + dtype : str, optional + The data type + + Returns + ------- + net : nnvm.symbol + The computational graph + params : dict of str to NDArray + The parameters. + """ + net = get_net(batch_size, random_len, oshape=oshape, ngf=ngf, dtype=dtype) + return create_workload(net) diff --git a/python/tvm/relay/testing/dqn.py b/python/tvm/relay/testing/dqn.py index 736894612e199ae98287030db185c6cfe7b5772e..034ac0a6c2e5f7a68754346f633d4ef209f0c48f 100644 --- a/python/tvm/relay/testing/dqn.py +++ b/python/tvm/relay/testing/dqn.py @@ -30,15 +30,25 @@ def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32" """get symbol of nature dqn""" data_shape = (batch_size,) + image_shape data = relay.var("data", shape=data_shape, dtype=dtype) + + conv1_bias = relay.var("conv1_bias") conv1 = layers.conv2d(data, kernel_size=(8, 8), strides=(4, 4), padding=(0, 0), channels=32, name="conv1") + conv1 = relay.nn.bias_add(conv1, conv1_bias) relu1 = relay.nn.relu(conv1) + + conv2_bias = relay.var("conv2_bias") conv2 = layers.conv2d(relu1, kernel_size=(4, 4), strides=(2, 2), padding=(0, 0), channels=64, name="conv2") + conv2 = relay.nn.bias_add(conv2, conv2_bias) relu2 = relay.nn.relu(conv2) + + conv3_bias = relay.var("conv3_bias") conv3 = layers.conv2d(relu2, kernel_size=(3, 3), strides=(1, 1), padding=(0, 0), channels=64, name="conv3") + conv3 = relay.nn.bias_add(conv3, conv3_bias) relu3 = relay.nn.relu(conv3) + bf1 = relay.nn.batch_flatten(relu3) dense1 = layers.dense_add_bias(bf1, units=512, name="dense1") relu4 = relay.nn.relu(dense1) diff --git a/python/tvm/relay/testing/layers.py b/python/tvm/relay/testing/layers.py index fc06ca229f771b8b04839d6cb52c04b7de44887a..1b279d9e72af7ecc5be48b073f3865f1bd1b2e55 100644 --- a/python/tvm/relay/testing/layers.py +++ b/python/tvm/relay/testing/layers.py @@ -80,6 +80,30 @@ def conv2d(data, weight=None, **kwargs): weight = relay.var(name + "_weight") return relay.nn.conv2d(data, weight, **kwargs) +def conv2d_transpose(data, weight=None, **kwargs): + """Wrapper of conv2d_transpose which automatically creates weights if not given. + + Parameters + ---------- + data : relay.Expr + The input expression. + + weight : relay.Expr + The weight to conv2d_transpose. + + kwargs : dict + Additional arguments. + + Returns + ------- + result : relay.Expr + The result. + """ + name = kwargs.get("name") + kwargs.pop("name") + if not weight: + weight = relay.var(name + "_weight") + return relay.nn.conv2d_transpose(data, weight, **kwargs) def dense_add_bias(data, weight=None, bias=None, **kwargs): """Wrapper of dense which automatically creates weights if not given. diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 7b2c343b084444ed884e5555eb4cad98abd5881e..fd446f9b7f03c72fa5f7f9583016ac12a6c95629 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -106,13 +106,18 @@ def test_resnet(): def test_dqn(): net, params = tvm.relay.testing.dqn.get_workload(batch_size=1) - show(net.astext()) + net.astext() + +def test_dcgan(): + net, params = tvm.relay.testing.dcgan.get_workload(batch_size=1) + net.astext() if __name__ == "__main__": do_print[0] = True test_resnet() test_mlp() test_dqn() + test_dcgan() test_func() test_env() test_meta_data()