From 5f79521b8da4172553ba4da880d9b633b18b9d39 Mon Sep 17 00:00:00 2001
From: Yuwei Hu <huyuwei1995@gmail.com>
Date: Thu, 26 Oct 2017 00:39:00 +0800
Subject: [PATCH] [TOPI] add conv2d_transpose_nchw (#586)

---
 topi/python/topi/cuda/__init__.py             |   1 +
 .../python/topi/cuda/conv2d_transpose_nchw.py | 116 ++++++++++++++++++
 topi/python/topi/generic/nn.py                |  58 ++++++---
 topi/python/topi/nn/__init__.py               |   1 +
 topi/python/topi/nn/conv2d_transpose.py       |  63 ++++++++++
 topi/python/topi/testing/__init__.py          |   1 +
 .../python/topi/testing/conv2d_nchw_python.py |   2 +-
 .../testing/conv2d_transpose_nchw_python.py   |  51 ++++++++
 .../python/test_topi_conv2d_transpose_nchw.py |  64 ++++++++++
 9 files changed, 336 insertions(+), 21 deletions(-)
 create mode 100644 topi/python/topi/cuda/conv2d_transpose_nchw.py
 create mode 100644 topi/python/topi/nn/conv2d_transpose.py
 create mode 100644 topi/python/topi/testing/conv2d_transpose_nchw_python.py
 create mode 100644 topi/tests/python/test_topi_conv2d_transpose_nchw.py

diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py
index 186a68df4..b898dde6a 100644
--- a/topi/python/topi/cuda/__init__.py
+++ b/topi/python/topi/cuda/__init__.py
@@ -12,3 +12,4 @@ from .softmax import schedule_softmax
 from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
 from .dense import schedule_dense
 from .pooling import schedule_pool, schedule_global_pool
+from .conv2d_transpose_nchw import schedule_conv2d_transpose_nchw
diff --git a/topi/python/topi/cuda/conv2d_transpose_nchw.py b/topi/python/topi/cuda/conv2d_transpose_nchw.py
new file mode 100644
index 000000000..1e5c39973
--- /dev/null
+++ b/topi/python/topi/cuda/conv2d_transpose_nchw.py
@@ -0,0 +1,116 @@
+#pylint: disable=invalid-name
+"""Schedule for conv2d_transpose_nchw with auto fusion"""
+import tvm
+from .. import util
+from .. import tag
+from .. import generic
+from .conv2d_nchw import conv2d_224_3_64, conv2d_56_64_128, conv2d_14_256_256, conv2d_56_64_64
+
+
+def schedule_conv2d_transpose_small_batch(outs):
+    """Create schedule for tensors or return error if batch size is larger than 1"""
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def schedule(temp, Filter, Output):
+        """Schedule conv2d_transpose_nchw"""
+        block_h = util.get_const_int(Output.shape[3])
+        block_w = util.get_const_int(temp.shape[1])
+        if block_h % 48 == 0:
+            block_h = 48
+        elif block_h % 32 == 0:
+            block_h = 32
+        if block_w % 48 == 0:
+            block_w = 48
+        elif block_w % 32 == 0:
+            block_w = 32
+
+        flag = util.get_const_int(Filter.shape[0]) + util.get_const_int(Filter.shape[1])
+
+        if flag > 768:
+            temp_G = s.cache_read(temp, "global", [Output])
+            s[temp_G].compute_inline()
+            i, ic, h, w = s[temp_G].op.axis
+            oic, iic = s[temp_G].split(ic, factor=4)
+            s[temp_G].reorder(i, h, w, oic, iic)
+            temp_R = s.cache_write(temp_G, "global")
+            temp_S = s.cache_read(temp_R, "shared", [temp_G])
+        elif 128 < flag < 512:
+            temp_G = s.cache_read(temp, "global", [Output])
+            s[temp_G].compute_inline()
+            i, ic, h, w = s[temp_G].op.axis
+            oic, iic = s[temp_G].split(ic, factor=4)
+            s[temp_G].reorder(i, oic, h, w, iic)
+            temp_R = s.cache_write(temp_G, "global")
+            temp_S = s.cache_read(temp_R, "shared", [temp_G])
+        elif util.get_const_int(Filter.shape[3]) == 7:
+            temp_G = s.cache_read(temp, "global", [Output])
+            s[temp_G].compute_inline()
+            i, ic, h, w = s[temp_G].op.axis
+            s[temp_G].split(w, factor=4)
+            temp_R = s.cache_write(temp_G, "global")
+            temp_S = s.cache_read(temp_R, "shared", [temp_G])
+        else:
+            s[temp].compute_inline()
+            temp_S = s.cache_read(temp, "shared", [Output])
+            temp_R = temp_S
+
+        Filter_S = s.cache_read(Filter, "shared", [Output])
+
+        if Output.op in s.outputs:
+            Out = Output
+            Out_L = s.cache_write(Out, "local")
+        else:
+            Out = outs[0].op.output(0)
+            s[Output].set_scope("local")
+            Out_L = Output
+
+        if util.get_const_int(Filter.shape[3]) == 7:
+            conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L)
+        elif 128 < flag < 512:
+            conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag)
+        elif flag >= 512:
+            conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L)
+        else:
+            conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L)
+
+    def traverse(OP):
+        # inline all one-to-one-mapping operators except the last stage (output)
+        if tag.is_broadcast(OP.tag):
+            if OP not in s.outputs:
+                s[OP].compute_inline()
+            for tensor in OP.input_tensors:
+                if tensor.op.input_tensors:
+                    traverse(tensor.op)
+        # schedule conv2d_transpose_nchw
+        if 'conv2d_transpose_nchw' in OP.tag:
+            temp = OP.input_tensors[0]
+            DilatedInput = temp.op.input_tensors[0]
+            s[DilatedInput].compute_inline()
+            Filter = OP.input_tensors[1]
+            Output = OP.output(0)
+            schedule(temp, Filter, Output)
+
+    traverse(outs[0].op)
+    return s
+
+
+@generic.schedule_conv2d_transpose_nchw.register(["cuda", "gpu"])
+def schedule_conv2d_transpose_nchw(outs):
+    """Schedule for conv2d_transpose_nchw.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of conv2d_transpose_nchw
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for conv2d_transpose_nchw.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    batch_size = util.get_const_int(outs[0].op.output(0).shape[0])
+    if batch_size > 1:
+        raise RuntimeError("Batch size: %d is too large for this schedule" % batch_size)
+    return schedule_conv2d_transpose_small_batch(outs)
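
For context, a minimal usage sketch (not part of the patch) of how the schedule registered above is reached through the generic dispatcher. The shapes follow one of the test cases added below, and a CUDA-enabled TVM build is assumed:

    import tvm
    import topi

    A = tvm.placeholder((1, 32, 32, 32), name='A')   # batch must be 1 for this schedule
    W = tvm.placeholder((128, 32, 5, 5), name='W')
    B = topi.nn.conv2d_transpose_nchw(A, W, [2, 2], 1)
    with tvm.target.create("cuda"):
        # dispatches to schedule_conv2d_transpose_nchw defined in this file
        s = topi.generic.schedule_conv2d_transpose_nchw([B])
    func = tvm.build(s, [A, W, B], "cuda")
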
diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py
index 3a335790d..2cb64407c 100644
--- a/topi/python/topi/generic/nn.py
+++ b/topi/python/topi/generic/nn.py
@@ -8,7 +8,7 @@ def _default_schedule(outs, auto_inline):
     target = tvm.target.current_target(allow_none=False)
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     if target.target_name != "llvm":
-        raise RuntimeError("schedule_pool not registered for '%s'" % target)
+        raise RuntimeError("schedule not registered for '%s'" % target)
     s = tvm.create_schedule([x.op for x in outs])
     if auto_inline:
         x = outs[0]
@@ -19,13 +19,13 @@ def _default_schedule(outs, auto_inline):
 
 @tvm.target.generic_func
 def schedule_conv2d_nchw(outs):
-    """Schedule for conv2d nchow
+    """Schedule for conv2d_nchw
 
     Parameters
     ----------
     outs: Array of Tensor
-          The computation graph description of reduce in the format
-          of an array of tensors.
+          The computation graph description of conv2d_nchw
+          in the format of an array of tensors.
 
     Returns
     -------
@@ -35,15 +35,33 @@ def schedule_conv2d_nchw(outs):
     return _default_schedule(outs, False)
 
 
+@tvm.target.generic_func
+def schedule_conv2d_transpose_nchw(outs):
+    """Schedule for conv2d_transpose_nchw
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of conv2d_transpose_nchw
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
 @tvm.target.generic_func
 def schedule_depthwise_conv2d_nchw(outs):
-    """Schedule for conv2d nchow
+    """Schedule for depthwise_conv2d_nchw
 
     Parameters
     ----------
     outs: Array of Tensor
-          The computation graph description of reduce in the format
-          of an array of tensors.
+          The computation graph description of depthwise_conv2d_nchw
+          in the format of an array of tensors.
 
     Returns
     -------
@@ -55,12 +73,12 @@ def schedule_depthwise_conv2d_nchw(outs):
 
 @tvm.target.generic_func
 def schedule_depthwise_conv2d_nhwc(outs):
-    """Schedule for depthwise nhcw conv2
+    """Schedule for depthwise_conv2d_nhwc
     Parameters
     ----------
     outs: Array of Tensor
-          The computation graph description of reduce in the format
-          of an array of tensors.
+          The computation graph description of depthwise_conv2d_nhwc
+          in the format of an array of tensors.
 
     Returns
     -------
@@ -77,8 +95,8 @@ def schedule_reduce(outs):
     Parameters
     ----------
     outs: Array of Tensor
-          The computation graph description of reduce in the format
-          of an array of tensors.
+          The computation graph description of reduce
+          in the format of an array of tensors.
 
     Returns
     -------
@@ -95,8 +113,8 @@ def schedule_softmax(outs):
     Parameters
     ----------
     outs: Array of Tensor
-          The computation graph description of reduce in the format
-          of an array of tensors.
+          The computation graph description of softmax
+          in the format of an array of tensors.
 
     Returns
     -------
@@ -113,8 +131,8 @@ def schedule_dense(outs):
     Parameters
     ----------
     outs: Array of Tensor
-          The computation graph description of reduce in the format
-          of an array of tensors.
+          The computation graph description of dense
+          in the format of an array of tensors.
 
     Returns
     -------
@@ -131,8 +149,8 @@ def schedule_pool(outs):
     Parameters
     ----------
     outs: Array of Tensor
-          The computation graph description of reduce in the format
-          of an array of tensors.
+          The computation graph description of pool
+          in the format of an array of tensors.
 
     Returns
     -------
@@ -149,8 +167,8 @@ def schedule_global_pool(outs):
     Parameters
     ----------
     outs: Array of Tensor
-          The computation graph description of reduce in the format
-          of an array of tensors.
+          The computation graph description of global_pool
+          in the format of an array of tensors.
 
     Returns
     -------
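
A quick sketch (not part of the patch) of how the new generic function dispatches, assuming import topi also pulls in the cuda sub-package as it does in the tests: under an "llvm" target the _default_schedule fallback above is used, while under "cuda"/"gpu" the registration in topi/python/topi/cuda/conv2d_transpose_nchw.py takes over.

    import tvm
    import topi

    A = tvm.placeholder((1, 3, 224, 224), name='A')
    W = tvm.placeholder((32, 3, 3, 3), name='W')
    B = topi.nn.conv2d_transpose_nchw(A, W, [1, 1], 0)

    with tvm.target.create("llvm"):
        s_cpu = topi.generic.schedule_conv2d_transpose_nchw([B])  # falls back to _default_schedule
    with tvm.target.create("cuda"):
        s_gpu = topi.generic.schedule_conv2d_transpose_nchw([B])  # uses the CUDA schedule in this patch
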
diff --git a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py
index 9e4a5bfbc..b66061082 100644
--- a/topi/python/topi/nn/__init__.py
+++ b/topi/python/topi/nn/__init__.py
@@ -12,3 +12,4 @@ from .dense import *
 from .mapping import *
 from .pooling import *
 from .softmax import *
+from .conv2d_transpose import *
diff --git a/topi/python/topi/nn/conv2d_transpose.py b/topi/python/topi/nn/conv2d_transpose.py
new file mode 100644
index 000000000..33f66d95c
--- /dev/null
+++ b/topi/python/topi/nn/conv2d_transpose.py
@@ -0,0 +1,63 @@
+# pylint: disable=invalid-name, unused-variable
+"""Transposed 2D convolution operators (sometimes called Deconvolution)."""
+from __future__ import absolute_import as _abs
+import tvm
+
+from .dilate import dilate
+from .pad import pad
+from .util import get_pad_tuple
+from ..util import simplify
+
+
+def conv2d_transpose_nchw(Input, Filter, strides, padding):
+    """Transposed 2D convolution nchw forward operator.
+
+    Parameters
+    ----------
+    Input : tvm.Tensor
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    Filter : tvm.Tensor
+        4-D with shape [num_filter, in_channel, filter_height, filter_width]
+
+    strides : tuple of two ints
+        The spatial stride along height and width
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    Returns
+    -------
+    Output : tvm.Tensor
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    batch, in_c, in_h, in_w = Input.shape
+    out_c, _, filter_h, filter_w = Filter.shape
+    stride_h, stride_w = strides
+    # dilate stage
+    DilatedInput = dilate(Input, [1, 1, stride_h, stride_w], name='DilatedInput')
+    # padding stage
+    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w))
+    bpad_top = filter_h - 1 - fpad_top
+    bpad_bottom = filter_h - 1 - fpad_bottom
+    bpad_left = filter_w - 1 - fpad_left
+    bpad_right = filter_w - 1 - fpad_right
+    PaddedInput = pad(DilatedInput,
+                      [0, 0, bpad_top, bpad_left],
+                      [0, 0, bpad_bottom, bpad_right],
+                      name='PaddedInput')
+    # convolution stage
+    out_c = simplify(out_c)
+    out_h = simplify((in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h)
+    out_w = simplify((in_w - 1) * stride_w - fpad_left - fpad_right + filter_w)
+    dc = tvm.reduce_axis((0, in_c), name='dc')
+    dh = tvm.reduce_axis((0, filter_h), name='dh')
+    dw = tvm.reduce_axis((0, filter_w), name='dw')
+
+    Output = tvm.compute(
+        (batch, out_c, out_h, out_w),
+        lambda b, c, h, w: tvm.sum(
+            PaddedInput[b, dc, h+dh, w+dw] * Filter[c, dc, filter_h-1-dh, filter_w-1-dw],
+            axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")
+
+    return Output
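
To make the output-shape arithmetic above concrete, a small worked example (not part of the patch) using one of the configurations from the tests below (in_size=32, kernel=5, stride=2, padding=1, for which get_pad_tuple yields a symmetric pad of 1):

    in_h, stride_h, filter_h, fpad_top, fpad_bottom = 32, 2, 5, 1, 1
    out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
    assert out_h == 65   # a 1x32x32x32 input with 128 5x5 filters gives a 1x128x65x65 output
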
diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py
index 2a1866d2f..3a43a0443 100644
--- a/topi/python/topi/testing/__init__.py
+++ b/topi/python/topi/testing/__init__.py
@@ -6,6 +6,7 @@ from __future__ import absolute_import as _abs
 
 from .conv2d_hwcn_python import conv2d_hwcn_python
 from .conv2d_nchw_python import conv2d_nchw_python
+from .conv2d_transpose_nchw_python import conv2d_transpose_nchw_python
 from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
 from .dilate_python import dilate_python
 from .softmax_python import softmax_python, log_softmax_python
diff --git a/topi/python/topi/testing/conv2d_nchw_python.py b/topi/python/topi/testing/conv2d_nchw_python.py
index 169605faa..20bcddd1d 100644
--- a/topi/python/topi/testing/conv2d_nchw_python.py
+++ b/topi/python/topi/testing/conv2d_nchw_python.py
@@ -60,5 +60,5 @@ def conv2d_nchw_python(a_np, w_np, stride, padding):
                     apad = a_np[n, c]
                 out = scipy.signal.convolve2d(
                     apad, np.rot90(np.rot90(w_np[f, c])), mode='valid')
-                b_np[n, f] += out[::stride, ::stride]
+                b_np[n, f] += out[::stride_h, ::stride_w]
     return b_np
diff --git a/topi/python/topi/testing/conv2d_transpose_nchw_python.py b/topi/python/topi/testing/conv2d_transpose_nchw_python.py
new file mode 100644
index 000000000..43af160e8
--- /dev/null
+++ b/topi/python/topi/testing/conv2d_transpose_nchw_python.py
@@ -0,0 +1,51 @@
+# pylint: disable=unused-variable
+"""Transposed convolution in python"""
+import numpy as np
+import topi
+from topi.nn.util import get_pad_tuple
+
+
+def conv2d_transpose_nchw_python(a_np, w_np, stride, padding):
+    """Transposed convolution operator in NCHW layout.
+
+    Parameters
+    ----------
+    a_np : numpy.ndarray
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    w_np : numpy.ndarray
+        4-D with shape [num_filter, in_channel, filter_height, filter_width]
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    Returns
+    -------
+    b_np : np.ndarray
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    batch, in_c, in_h, in_w = a_np.shape
+    out_c, _, filter_h, filter_w = w_np.shape
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+    # dilate stage
+    dilated_a_np = topi.testing.dilate_python(a_np, [1, 1, stride_h, stride_w])
+    # padding stage
+    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w))
+    bpad_top = filter_h - 1 - fpad_top
+    bpad_bottom = filter_h - 1 - fpad_bottom
+    bpad_left = filter_w - 1 - fpad_left
+    bpad_right = filter_w - 1 - fpad_right
+    padded_a_np = np.zeros((batch, in_c, dilated_a_np.shape[2]+bpad_top+bpad_bottom,
+                            dilated_a_np.shape[3]+bpad_left+bpad_right))
+    padded_a_np[:, :, bpad_top:dilated_a_np.shape[2]+bpad_top,
+                bpad_left:dilated_a_np.shape[3]+bpad_left] = dilated_a_np
+    # convolution stage
+    rotated_w_np = np.rot90(w_np, k=2, axes=(2, 3))
+    b_np = topi.testing.conv2d_nchw_python(padded_a_np, rotated_w_np, stride=1, padding='VALID')
+    return b_np
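
A brief usage sketch (not part of the patch) of this reference implementation; the shapes are arbitrary and the expected output size follows the formula in topi/python/topi/nn/conv2d_transpose.py:

    import numpy as np
    import topi

    a_np = np.random.uniform(size=(1, 3, 8, 8)).astype("float32")
    w_np = np.random.uniform(size=(4, 3, 3, 3)).astype("float32")
    b_np = topi.testing.conv2d_transpose_nchw_python(a_np, w_np, stride=2, padding=1)
    # (in - 1) * stride - 2 * pad + kernel = (8 - 1) * 2 - 2 + 3 = 15
    assert b_np.shape == (1, 4, 15, 15)
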
diff --git a/topi/tests/python/test_topi_conv2d_transpose_nchw.py b/topi/tests/python/test_topi_conv2d_transpose_nchw.py
new file mode 100644
index 000000000..738831d6c
--- /dev/null
+++ b/topi/tests/python/test_topi_conv2d_transpose_nchw.py
@@ -0,0 +1,64 @@
+"""Test code for transposed convolution."""
+import numpy as np
+import tvm
+import topi
+from tvm.contrib.pickle_memoize import memoize
+from topi.util import get_const_tuple
+
+
+def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
+    in_height = in_width = in_size
+
+    A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
+    W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
+    B = topi.nn.conv2d_transpose_nchw(A, W, [stride, stride], padding)
+    C = topi.nn.relu(B)
+
+    a_shape = get_const_tuple(A.shape)
+    w_shape = get_const_tuple(W.shape)
+    dtype = A.dtype
+
+    @memoize("topi.tests.test_topi_conv2d_transpose.verify_conv2d_transpose_nchw")
+    def get_ref_data():
+        a_np = np.random.uniform(size=a_shape).astype(dtype)
+        w_np = np.random.uniform(size=w_shape).astype(dtype)
+        b_np = topi.testing.conv2d_transpose_nchw_python(a_np, w_np, stride, padding)
+        c_np = np.maximum(b_np, 0)
+        return a_np, w_np, b_np, c_np
+
+    a_np, w_np, b_np, c_np = get_ref_data()
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        with tvm.target.create(device):
+            s1 = topi.generic.schedule_conv2d_transpose_nchw([B])
+            s2 = topi.generic.schedule_conv2d_transpose_nchw([C])
+        ctx = tvm.context(device, 0)
+        a = tvm.nd.array(a_np, ctx)
+        w = tvm.nd.array(w_np, ctx)
+        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
+        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
+        with tvm.build_config(auto_unroll_max_step=128,
+                              unroll_explicit=(device != "cuda")):
+            func1 = tvm.build(s1, [A, W, B], device)
+            func2 = tvm.build(s2, [A, W, C], device)
+            func1(a, w, b)
+            func2(a, w, c)
+            np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+            np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+
+    for device in ['cuda', 'opencl', 'metal', 'rocm']:
+        check_device(device)
+
+
+def test_conv2d_transpose_nchw():
+    verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 1, 0)
+    verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 2, 1)
+    verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 1, 0)
+    verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 2, 1)
+
+
+if __name__ == "__main__":
+    test_conv2d_transpose_nchw()
-- 
GitLab