From 5c410c4c31aff07b063d3c95123261af003ab33d Mon Sep 17 00:00:00 2001
From: Wuwei Lin <vincentl13x@gmail.com>
Date: Sun, 11 Nov 2018 00:52:22 +0800
Subject: [PATCH] [TOPI][CUDA] int8 group conv2d  (#2075)

---
 nnvm/python/nnvm/top/nn.py                    |   5 +
 python/tvm/autotvm/task/nnvm_integration.py   |  14 +-
 topi/python/topi/cuda/__init__.py             |   3 +-
 topi/python/topi/cuda/group_conv2d_nchw.py    | 308 ++++++++++++++++++
 topi/python/topi/generic/nn.py                |  19 ++
 topi/python/topi/nn/conv2d.py                 |  77 +++++
 .../python/topi/testing/conv2d_nchw_python.py |  37 ++-
 topi/tests/python/common.py                   |  15 +
 topi/tests/python/test_topi_conv2d_int8.py    |  13 +-
 topi/tests/python/test_topi_group_conv2d.py   | 215 ++++++++++++
 10 files changed, 690 insertions(+), 16 deletions(-)
 create mode 100644 topi/python/topi/cuda/group_conv2d_nchw.py
 create mode 100644 topi/tests/python/test_topi_group_conv2d.py

diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py
index 03ffb46a5..34dd2303f 100644
--- a/nnvm/python/nnvm/top/nn.py
+++ b/nnvm/python/nnvm/top/nn.py
@@ -108,6 +108,9 @@ def compute_conv2d(attrs, inputs, _):
          groups == channels:
         out = topi.nn.depthwise_conv2d_nchw(
             inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
+    elif layout == "NCHW":
+        out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups,
+                                        out_dtype=out_dtype)
     elif layout == "NHWC" and \
          kernel_layout == "HWOI" and \
          groups == get_const_int(inputs[0].shape[3]) and \
@@ -143,6 +146,8 @@ def schedule_conv2d(attrs, outs, target):
             return topi.generic.schedule_depthwise_conv2d_nchw(outs)
         elif groups == channels and layout == "NHWC" and kernel_layout == "HWOI":
             return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
+        elif layout == "NCHW":
+            return topi.generic.schedule_group_conv2d_nchw(outs)
         else:
             raise ValueError("No compatible schedule")
 
diff --git a/python/tvm/autotvm/task/nnvm_integration.py b/python/tvm/autotvm/task/nnvm_integration.py
index 80b62229a..6a07194a5 100644
--- a/python/tvm/autotvm/task/nnvm_integration.py
+++ b/python/tvm/autotvm/task/nnvm_integration.py
@@ -58,7 +58,8 @@ class TaskExtractEnv:
         # NOTE: To add more symbols, you only need to change the following lists
         # nnvm symbol -> topi compute
         self.symbol2topi = {
-            nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw],
+            nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
+                              topi.nn.group_conv2d_nchw],
             nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
             nnvm.sym.dense: [topi.nn.dense],
         }
@@ -67,6 +68,7 @@ class TaskExtractEnv:
         self.topi_to_task = {
             topi.nn.conv2d: "topi_nn_conv2d",
             topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw",
+            topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
             topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
             topi.nn.dense: "topi_nn_dense",
         }
@@ -76,6 +78,7 @@ class TaskExtractEnv:
                              topi.generic.schedule_conv2d_nhwc],
             topi.nn.depthwise_conv2d_nchw: [topi.generic.schedule_depthwise_conv2d_nchw,
                                             topi.generic.schedule_depthwise_conv2d_nhwc],
+            topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
             topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
             topi.nn.dense: [topi.generic.schedule_dense],
         }
@@ -143,6 +146,15 @@ class TaskExtractEnv:
             s = topi.generic.schedule_depthwise_conv2d_nchw([C])
             return s, [A, W, C]
 
+        @register("topi_nn_group_conv2d_nchw")
+        def _topi_nn_group_conv2d_nchw(*args, **kwargs):
+            assert not kwargs, "Do not support kwargs in template function call"
+            args = deserialize_args(args)
+            A, W = args[:2]
+            C = topi.nn.group_conv2d_nchw(*args, **kwargs)
+            s = topi.generic.schedule_group_conv2d_nchw([C])
+            return s, [A, W, C]
+
         @register("topi_nn_conv2d_transpose_nchw")
         def _topi_nn_conv2d_transpose_nchw(*args, **kwargs):
             assert not kwargs, "Do not support kwargs in template function call"
diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py
index e1db2c6fd..28d2eb258 100644
--- a/topi/python/topi/cuda/__init__.py
+++ b/topi/python/topi/cuda/__init__.py
@@ -2,10 +2,11 @@
 """CUDA specific declaration and schedules."""
 from __future__ import absolute_import as _abs
 
-from . import conv2d, depthwise_conv2d, conv2d_transpose_nchw
+from . import conv2d, depthwise_conv2d, conv2d_transpose_nchw, group_conv2d_nchw
 from .conv2d_hwcn import schedule_conv2d_hwcn
 from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
 from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
+from .group_conv2d_nchw import schedule_conv2d_nchw_cuda
 from .reduction import schedule_reduce
 from .softmax import schedule_softmax
 from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
diff --git a/topi/python/topi/cuda/group_conv2d_nchw.py b/topi/python/topi/cuda/group_conv2d_nchw.py
new file mode 100644
index 000000000..739691131
--- /dev/null
+++ b/topi/python/topi/cuda/group_conv2d_nchw.py
@@ -0,0 +1,308 @@
+# pylint: disable=invalid-name
+"""The template for cuda group_conv2d_nchw"""
+import tvm
+from tvm import autotvm
+
+from .injective import _schedule_injective
+from .tensor_intrin import dp4a
+from ..nn.pad import pad
+from ..nn.util import get_pad_tuple
+from ..util import traverse_inline, get_const_tuple, get_const_int
+from .. import nn, generic
+
+
+@autotvm.register_topi_compute(nn.group_conv2d_nchw, ['cuda', 'gpu'], ['direct', 'int8'])
+def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
+                           out_dtype='float32'):
+    """Group convolution operator in NCHW layout.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        4-D with shape [batch, in_channel, in_height, in_width] or
+        5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
+
+    kernel : tvm.Tensor
+        4-D with shape [num_filter, in_channel // groups, filter_height, filter_width] or
+        6-D with shape [num_filter_chunk, in_channel_chunk // groups, filter_height,
+        filter_width, num_filter_block, in_channel_block]
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    dilation : int or a list/tuple of two ints
+        dilation size, or [dilation_height, dilation_width]
+
+    groups : int
+        number of groups
+
+    out_dtype : str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    Output : tvm.Tensor
+        5-D with shape [batch, out_channel, out_height, out_width, out_channel_block]
+    """
+    ic_block_factor = 4
+    oc_block_factor = 4
+
+    pre_computed = len(kernel.shape) == 6
+    if not pre_computed:
+        batch, channels, height, width = get_const_tuple(data.shape)
+        out_channels, in_channels, kernel_h, kernel_w = get_const_tuple(
+            kernel.shape)
+
+        assert channels % groups == 0, "input channels must divide group size"
+        assert out_channels % groups == 0, "output channels must divide group size"
+        assert channels % ic_block_factor == 0, \
+            "Number of input channels per group must divide {}".format(ic_block_factor)
+        assert out_channels % 4 == 0, \
+            "Number of output channels per group must divide {}".format(oc_block_factor)
+
+        packed_data = tvm.compute((batch, channels // ic_block_factor, height, width,
+                                   ic_block_factor),
+                                  lambda n, c, h, w, vc: data[n, c*ic_block_factor + vc, h, w],
+                                  name="packed_data")
+        packed_kernel = tvm.compute(
+            (out_channels // oc_block_factor, in_channels // ic_block_factor, kernel_h, kernel_w,
+             oc_block_factor, ic_block_factor),
+            lambda oc_chunk, ic_chunk, kh, kw, oc_block, ic_block:
+            kernel[oc_chunk * oc_block_factor + oc_block,
+                   ic_chunk * ic_block_factor + ic_block, kh, kw],
+            name="packed_kernel")
+    else:
+        packed_data = data
+        packed_kernel = kernel
+
+    batch, ic_chunk, in_height, in_width, _ = get_const_tuple(
+        packed_data.shape)
+    oc_chunk, _, kernel_h, kernel_w, oc_block, ic_block = get_const_tuple(
+        packed_kernel.shape)
+
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (kernel_h, kernel_w))
+    # compute graph
+    pad_before = [0, 0, pad_top, pad_left, 0]
+    pad_after = [0, 0, pad_down, pad_right, 0]
+    pad_data = pad(packed_data, pad_before, pad_after, name="pad_data")
+
+    # compute the output shape
+    out_height = (in_height - (kernel_h - 1) * dilation_h -
+                  1 + pad_top + pad_down) // stride_h + 1
+    out_width = (in_width - (kernel_w - 1) * dilation_w -
+                 1 + pad_left + pad_right) // stride_w + 1
+
+    oshape = (batch, oc_chunk, out_height, out_width, oc_block)
+
+    icc = tvm.reduce_axis((0, ic_chunk // groups), name='ic_chunk')
+    icb = tvm.reduce_axis((0, ic_block_factor), name='ic_block')
+    kh = tvm.reduce_axis((0, kernel_h), name='kh')
+    kw = tvm.reduce_axis((0, kernel_w), name='kw')
+
+    conv = tvm.compute(oshape, lambda n, occ, oh, ow, ocb:
+                       tvm.sum(pad_data[n, occ//(oc_chunk//groups)*(ic_chunk//groups)+icc,
+                                        oh*stride_h+kh*dilation_h, ow*stride_w+kw*dilation_w, icb]
+                               .astype('int32') *
+                               packed_kernel[occ, icc,
+                                             kh, kw, ocb, icb]
+                               .astype('int32'),
+                               axis=[icc, kh, kw, icb]))
+
+    output = tvm.compute(oshape, lambda *index: conv(*index).astype(out_dtype),
+                         tag='group_conv2d_NCHWc_int8')
+    num_flop = batch * oc_chunk * oc_block * out_height * out_width * \
+        ic_chunk * ic_block * kernel_h * kernel_w * 2 // groups
+    cfg.add_flop(num_flop)
+
+    return output
+
+
+_dp4a = dp4a('shared', 'shared', 'local')
+
+
+def schedule_group_conv2d_NCHWc_int8(cfg, s, output):
+    """Schedule group conv2d int8 NCHWc template"""
+    workload = output.op.attrs["workload"]
+    groups = get_const_int(workload[6])
+
+    conv = output.op.input_tensors[0]
+    packed_data, packed_kernel = conv.op.input_tensors
+
+    if isinstance(packed_data.op, tvm.tensor.ComputeOp) and "pad" in packed_data.op.tag:
+        pad_data = packed_data
+        packed_data = pad_data.op.input_tensors[0]
+    else:
+        pad_data = packed_data
+
+    if autotvm.GLOBAL_SCOPE.in_tuning:
+        # skip this part during tuning to make records accurate
+        # this part will be pre-computed during NNVM's pre-compute optimization pass
+        s[packed_data].pragma(s[packed_data].op.axis[0], "debug_skip_region")
+        s[packed_kernel].pragma(
+            s[packed_kernel].op.axis[0], "debug_skip_region")
+    else:
+        if isinstance(packed_kernel.op, tvm.tensor.ComputeOp) and\
+                packed_kernel.name == 'packed_kernel':
+            # data and kernel are not pre-computed, schedule layout transform here
+            _schedule_injective(packed_data.op, s)
+            _schedule_injective(packed_kernel.op, s)
+
+    if pad_data != packed_data:
+        s[pad_data].compute_inline()
+
+    # create cache stage
+    AA = s.cache_read(pad_data, 'shared', [conv])
+    WW = s.cache_read(packed_kernel, 'shared', [conv])
+
+    s[conv].set_scope('local')
+
+    # handle bias
+    if output.op not in s.outputs:
+        s[output].compute_inline()
+        output = s.outputs[0].output(0)
+
+    oc_chunk = get_const_int(output.shape[1])
+    # tile and bind spatial axes
+    n, f, y, x, c = s[output].op.axis
+    cfg.define_split("tile_n", n, num_outputs=4)
+    cfg.define_split("tile_g", cfg.axis(groups), num_outputs=2)
+    cfg.define_split("tile_f", cfg.axis(oc_chunk // groups), num_outputs=4)
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
+
+    # this is the scope to attach global config inside this kernel
+    kernel_scope, n = s[output].split(n, nparts=1)
+
+    g, f = s[output].split(f, nparts=groups)
+    s[output].bind(n, tvm.thread_axis('blockIdx.z'))
+    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
+    bg, vg = cfg["tile_g"].apply(s, output, g)
+    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
+    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
+
+    s[output].reorder(bn, bg, bf, by, bx, vn, vg, vf, vy,
+                      vx, tn, tf, ty, tx, ni, fi, yi, xi)
+    s[output].bind(bn, tvm.thread_axis("blockIdx.z"))
+    s[output].bind(s[output].fuse(bg, bf), tvm.thread_axis("blockIdx.y"))
+    s[output].bind(s[output].fuse(by, bx), tvm.thread_axis("blockIdx.x"))
+    s[output].bind(vn, tvm.thread_axis("vthread"))
+    s[output].bind(vg, tvm.thread_axis("vthread"))
+    s[output].bind(vf, tvm.thread_axis("vthread"))
+    s[output].bind(vy, tvm.thread_axis("vthread"))
+    s[output].bind(vx, tvm.thread_axis("vthread"))
+    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
+    if cfg["fuse_yx"].val:
+        s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
+        s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
+        tyx = s[output].fuse(ty, tx)
+        s[output].bind(tyx, tvm.thread_axis("threadIdx.x"))
+        s[conv].compute_at(s[output], tyx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2]
+        n_ty = cfg["tile_f"].size[2]
+        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
+    else:
+        s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
+        s[output].bind(s[output].fuse(tn, tf), tvm.thread_axis("threadIdx.z"))
+        s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
+        s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
+        s[conv].compute_at(s[output], tx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
+        n_ty = cfg["tile_y"].size[2]
+        n_tx = cfg["tile_x"].size[2]
+
+    # tile and bind reduction axes
+    n, f, y, x, c = s[conv].op.axis
+    rc, ry, rx, rc_block = s[conv].op.reduce_axis
+    cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=2)
+    cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=2)
+    cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=2)
+    rco, rci = cfg['tile_rc'].apply(s, conv, rc)
+    ryo, ryi = cfg['tile_ry'].apply(s, conv, ry)
+    rxo, rxi = cfg['tile_rx'].apply(s, conv, rx)
+
+    s[conv].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x, c, rc_block)
+    _, rc_block = s[conv].split(rc_block, factor=4)
+    s[conv].tensorize(rc_block, _dp4a)
+
+    s[AA].compute_at(s[conv], rxo)
+    s[WW].compute_at(s[conv], rxo)
+
+    # cooperative fetching
+    for load in [AA, WW]:
+        c = s[load].op.axis[-1]
+        c_outer, c = s[load].split(c, factor=4)
+        s[load].vectorize(c)
+        fused = s[load].op.axis[:-1] + [c_outer]
+        fused = s[load].fuse(*fused)
+
+        fused, tx = s[load].split(fused, factor=n_tx)
+        fused, ty = s[load].split(fused, factor=n_ty)
+        fused, tz = s[load].split(fused, factor=n_tz)
+        s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
+        s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
+        s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
+
+    # double buffer
+    cfg.define_knob('AA_double_buffer', [0, 1])
+    cfg.define_knob('WW_double_buffer', [0, 1])
+    if cfg['AA_double_buffer'].val:
+        s[AA].double_buffer()
+    if cfg['WW_double_buffer'].val:
+        s[WW].double_buffer()
+
+    # unroll
+    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
+    s[output].pragma(kernel_scope, 'auto_unroll_max_step',
+                     cfg['auto_unroll_max_step'].val)
+    s[output].pragma(kernel_scope, 'unroll_explicit', False)
+
+    return s
+
+
+@autotvm.register_topi_schedule(generic.schedule_group_conv2d_nchw,
+                                ["cuda", "gpu"], ["direct", "int8"])
+def schedule_conv2d_nchw_cuda(cfg, outs):
+    """TOPI schedule callback of group conv2d for cuda gpu
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    outs: Array of Tensor
+        The computation graph description of conv2d
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for group conv2d.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == "group_conv2d_NCHWc_int8":
+            schedule_group_conv2d_NCHWc_int8(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py
index a48b85638..0f4b51b81 100644
--- a/topi/python/topi/generic/nn.py
+++ b/topi/python/topi/generic/nn.py
@@ -173,6 +173,25 @@ def schedule_depthwise_conv2d_nhwc(outs):
     """
     return _default_schedule(outs, False)
 
+
+@tvm.target.generic_func
+def schedule_group_conv2d_nchw(outs):
+    """Schedule for conv2d_nchw
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of group_conv2d_nchw
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
 @tvm.target.generic_func
 def schedule_bitserial_conv2d_nchw(outs):
     """Schedule for bitserial_conv2d_nchw
diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py
index 2b8888652..d4b9393c1 100644
--- a/topi/python/topi/nn/conv2d.py
+++ b/topi/python/topi/nn/conv2d.py
@@ -403,3 +403,80 @@ def conv2d_winograd_without_weight_transform(input, filter, strides, padding, di
         4-D with shape [batch, out_height, out_width, out_channel]
     """
     raise ValueError("missing register for topi.nn.conv2d_winograd_without_weight_transform")
+
+
+@tvm.target.generic_func
+def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtype=None):
+    """Group convolution operator in NCHW layout.
+
+    Parameters
+    ----------
+    Input : tvm.Tensor
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    Filter : tvm.Tensor
+        4-D with shape [num_filter, in_channel // groups, filter_height, filter_width]
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    dilation : int or a list/tuple of two ints
+        dilation size, or [dilation_height, dilation_width]
+
+    groups : int
+        number of groups
+
+    out_dtype : str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    Output : tvm.Tensor
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    if out_dtype is None:
+        out_dtype = Input.dtype
+    assert isinstance(stride, int) or len(stride) == 2
+    assert isinstance(dilation, int) or len(dilation) == 2
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    batch, in_channel, in_height, in_width = get_const_tuple(Input.shape)
+    num_filter, _, kernel_h, kernel_w = get_const_tuple(Filter.shape)
+
+    assert in_channel % groups == 0, "input channels must divide group size"
+    assert num_filter % groups == 0, "output channels must divide group size"
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (kernel_h, kernel_w))
+    # compute the output shape
+    out_channel = num_filter
+    out_height = simplify(
+        (in_height - (kernel_h - 1) * dilation_h - 1 + pad_top + pad_down) // stride_h + 1)
+    out_width = simplify(
+        (in_width - (kernel_w - 1) * dilation_w - 1 + pad_left + pad_right) // stride_w + 1)
+    # compute graph
+    pad_before = [0, 0, pad_top, pad_left]
+    pad_after = [0, 0, pad_down, pad_right]
+    temp = pad(Input, pad_before, pad_after, name="pad_temp")
+    rc = tvm.reduce_axis((0, in_channel // groups), name='rc')
+    ry = tvm.reduce_axis((0, kernel_h), name='ry')
+    rx = tvm.reduce_axis((0, kernel_w), name='rx')
+    return tvm.compute(
+        (batch, out_channel, out_height, out_width),
+        lambda nn, ff, yy, xx: tvm.sum(
+            temp[nn, ff // (num_filter//groups) * (in_channel//groups) + rc,
+                 yy * stride_h + ry * dilation_h,
+                 xx * stride_w + rx * dilation_w].astype(out_dtype) *
+            Filter[ff, rc, ry, rx].astype(out_dtype),
+            axis=[rc, ry, rx]), tag="conv2d_nchw")
diff --git a/topi/python/topi/testing/conv2d_nchw_python.py b/topi/python/topi/testing/conv2d_nchw_python.py
index 4a40d02d2..7d2aa0d0f 100644
--- a/topi/python/topi/testing/conv2d_nchw_python.py
+++ b/topi/python/topi/testing/conv2d_nchw_python.py
@@ -4,8 +4,8 @@ import numpy as np
 import scipy.signal
 
 
-def conv2d_nchw_python(a_np, w_np, stride, padding):
-    """Convolution operator in HWCN layout.
+def _conv2d_nchw_python(a_np, w_np, stride, padding):
+    """Convolution operator in NCHW layout.
 
     Parameters
     ----------
@@ -66,3 +66,36 @@ def conv2d_nchw_python(a_np, w_np, stride, padding):
                     apad, np.rot90(np.rot90(w_np[f, c])), mode='valid')
                 b_np[n, f] += out[::stride_h, ::stride_w]
     return b_np
+
+
+def conv2d_nchw_python(a_np, w_np, stride, padding, groups=1):
+    """Convolution operator in NCHW layout.
+
+    Parameters
+    ----------
+    a_np : numpy.ndarray
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    w_np : numpy.ndarray
+        4-D with shape [num_filter, in_channel // groups, filter_height, filter_width]
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str or a list/tuple of two ints
+        Padding size, or ['VALID', 'SAME'], or [pad_height, pad_width]
+
+    groups : int
+        Number of groups
+
+    Returns
+    -------
+    b_np : np.ndarray
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    a_slices = np.array_split(a_np, groups, axis=1)
+    w_slices = np.array_split(w_np, groups, axis=0)
+    b_slices = [_conv2d_nchw_python(a_slice, w_slice, stride, padding)
+                for a_slice, w_slice in zip(a_slices, w_slices)]
+    b_np = np.concatenate(b_slices, axis=1)
+    return b_np
diff --git a/topi/tests/python/common.py b/topi/tests/python/common.py
index 763db5f86..f34f3b331 100644
--- a/topi/tests/python/common.py
+++ b/topi/tests/python/common.py
@@ -1,5 +1,9 @@
 """Common utility for topi test"""
 
+from tvm import autotvm
+from tvm.autotvm.task.space import FallbackConfigEntity
+
+
 def get_all_backend():
     """return all supported target
 
@@ -10,3 +14,14 @@ def get_all_backend():
     """
     return ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx',
             'llvm -device=arm_cpu', 'opencl -device=mali', 'aocl_sw_emu']
+
+
+class NCHWcInt8Fallback(autotvm.FallbackContext):
+    def _query_inside(self, target, workload):
+        key = (target, workload)
+        if key in self.memory:
+            return self.memory[key]
+        cfg = FallbackConfigEntity()
+        cfg.template_key = 'int8'
+        self.memory[key] = cfg
+        return cfg
diff --git a/topi/tests/python/test_topi_conv2d_int8.py b/topi/tests/python/test_topi_conv2d_int8.py
index fd5e91eed..272a72f82 100644
--- a/topi/tests/python/test_topi_conv2d_int8.py
+++ b/topi/tests/python/test_topi_conv2d_int8.py
@@ -9,7 +9,7 @@ import topi.testing
 from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
 
-from common import get_all_backend
+from common import get_all_backend, NCHWcInt8Fallback
 
 oc_block_factor = 4
 
@@ -88,17 +88,6 @@ def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, str
         check_device(device)
 
 
-class NCHWcInt8Fallback(autotvm.FallbackContext):
-    def _query_inside(self, target, workload):
-        key = (target, workload)
-        if key in self.memory:
-            return self.memory[key]
-        cfg = FallbackConfigEntity()
-        cfg.template_key = 'int8'
-        self.memory[key] = cfg
-        return cfg
-
-
 def test_conv2d_nchw():
     with NCHWcInt8Fallback():
         # ResNet18 workloads where channels in / out are multiple of oc_block_factor
diff --git a/topi/tests/python/test_topi_group_conv2d.py b/topi/tests/python/test_topi_group_conv2d.py
new file mode 100644
index 000000000..c1ff656fc
--- /dev/null
+++ b/topi/tests/python/test_topi_group_conv2d.py
@@ -0,0 +1,215 @@
+"""Example code to do group convolution."""
+
+import numpy as np
+import tvm
+from tvm import autotvm
+from tvm.autotvm.task.space import FallbackConfigEntity
+import topi
+import topi.testing
+from tvm.contrib.pickle_memoize import memoize
+from topi.util import get_const_tuple
+
+from common import get_all_backend, NCHWcInt8Fallback
+
+
+def verify_group_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups, add_bias=False, add_relu=False):
+    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d)" %
+        (batch, in_channel, in_size, num_filter,
+         kernel, stride, padding, dilation, groups))
+
+    in_height = in_width = in_size
+
+    A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
+    W = tvm.placeholder((num_filter, in_channel // groups, kernel, kernel), name='W')
+    bias = tvm.placeholder((num_filter, 1, 1), name='bias')
+
+    a_shape = get_const_tuple(A.shape)
+    w_shape = get_const_tuple(W.shape)
+    bias_shape = get_const_tuple(bias.shape)
+    dtype = A.dtype
+
+    @memoize("topi.tests.test_topi_group_conv2d.verify_group_conv2d_nchw")
+    def get_ref_data():
+        a_np = np.random.uniform(size=a_shape).astype(dtype)
+        w_np = np.random.uniform(size=w_shape).astype(dtype)
+        b_np = np.random.uniform(size=bias_shape).astype(dtype)
+        dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+        c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding, groups).astype(dtype)
+
+        if add_bias:
+            b_np = np.random.uniform(size=bias_shape).astype(dtype)
+            c_np += b_np
+        if add_relu:
+            c_np = np.maximum(c_np, 0)
+
+        return a_np, w_np, b_np, c_np
+
+    a_np, w_np, b_np, c_np = get_ref_data()
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version):
+            print("Skip because int8 intrinsics are not available")
+            return
+
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            C = topi.nn.group_conv2d_nchw(A, W, stride, padding, dilation, groups, out_dtype=dtype)
+            if add_bias:
+                C = topi.add(C, bias)
+            if add_relu:
+                C = topi.nn.relu(C)
+            s = topi.generic.schedule_group_conv2d_nchw([C])
+
+        a = tvm.nd.array(a_np, ctx)
+        w = tvm.nd.array(w_np, ctx)
+        b = tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
+        if add_bias:
+            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" %\
+                (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups))
+            func(a, w, b, c)
+        else:
+            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" % \
+            (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups))
+            func(a, w, c)
+        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+
+    for device in ["llvm"]:
+        check_device(device)
+
+
+oc_block_factor = 4
+
+
+def verify_group_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups, add_bias=False, add_relu=False):
+    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d)" %
+        (batch, in_channel, in_size, num_filter,
+         kernel, stride, padding, dilation, groups))
+
+    in_height = in_width = in_size
+
+    A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='int8')
+    W = tvm.placeholder((num_filter, in_channel // groups, kernel, kernel), name='W', dtype='int8')
+    bias = tvm.placeholder((num_filter // oc_block_factor, 1, 1, oc_block_factor), name='bias',
+                            dtype='int8')
+
+    a_shape = get_const_tuple(A.shape)
+    w_shape = get_const_tuple(W.shape)
+    bias_shape = get_const_tuple(bias.shape)
+    dtype = A.dtype
+
+    @memoize("topi.tests.test_topi_group_conv2d.verify_group_conv2d_NCHWc_int8")
+    def get_ref_data():
+        a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
+        w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
+        b_np = np.random.uniform(size=bias_shape).astype(dtype)
+        dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+        c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding, groups).astype(dtype)
+
+        # convert to NCHWc
+        _, _, out_height, out_width = c_np.shape
+        c_np = c_np.reshape((batch, num_filter // oc_block_factor, oc_block_factor, \
+                out_height, out_width)).transpose(0, 1, 3, 4, 2)
+
+        if add_bias:
+            b_np = np.random.uniform(size=bias_shape).astype(dtype)
+            c_np += b_np
+        if add_relu:
+            c_np = np.maximum(c_np, 0)
+
+        return a_np, w_np, b_np, c_np
+
+    a_np, w_np, b_np, c_np = get_ref_data()
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version):
+            print("Skip because int8 intrinsics are not available")
+            return
+
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            C = topi.nn.group_conv2d_nchw(A, W, stride, padding, dilation, groups, out_dtype=dtype)
+            if add_bias:
+                C = topi.add(C, bias)
+            if add_relu:
+                C = topi.nn.relu(C)
+            s = topi.generic.schedule_group_conv2d_nchw([C])
+
+        a = tvm.nd.array(a_np, ctx)
+        w = tvm.nd.array(w_np, ctx)
+        b = tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
+        if add_bias:
+            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" %\
+                (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups))
+            func(a, w, b, c)
+        else:
+            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" % \
+            (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups))
+            func(a, w, c)
+        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+
+    for device in ["cuda"]:
+        check_device(device)
+
+
+def test_group_conv2d_nchw():
+    # ResNeXt-50 workload
+    verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32)
+    verify_group_conv2d_nchw(1, 256, 56, 256, 3, 2, 1, 1, 32)
+    verify_group_conv2d_nchw(1, 256, 28, 256, 3, 1, 1, 1, 32)
+    verify_group_conv2d_nchw(1, 512, 28, 512, 3, 2, 1, 1, 32)
+    verify_group_conv2d_nchw(1, 512, 14, 512, 3, 1, 1, 1, 32)
+    verify_group_conv2d_nchw(1, 1024, 14, 1024, 3, 2, 1, 1, 32)
+    verify_group_conv2d_nchw(1, 1024, 7, 1024, 3, 1, 1, 1, 32)
+
+    # bias, relu
+    verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True)
+    verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32, add_bias=True)
+    verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True,
+                             add_bias=True)
+
+    # dilation
+    verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 2, 32)
+
+    # batch size
+    verify_group_conv2d_nchw(2, 128, 56, 128, 3, 1, 1, 1, 32)
+    verify_group_conv2d_nchw(9, 128, 56, 128, 3, 1, 1, 1, 32)
+
+
+
+def test_group_conv2d_NCHWc_int8():
+    with NCHWcInt8Fallback():
+        # ResNeXt-50 workload
+        verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 1, 32)
+        verify_group_conv2d_NCHWc_int8(1, 256, 56, 256, 3, 2, 1, 1, 32)
+        verify_group_conv2d_NCHWc_int8(1, 256, 28, 256, 3, 1, 1, 1, 32)
+        verify_group_conv2d_NCHWc_int8(1, 512, 28, 512, 3, 2, 1, 1, 32)
+        verify_group_conv2d_NCHWc_int8(1, 512, 14, 512, 3, 1, 1, 1, 32)
+        verify_group_conv2d_NCHWc_int8(1, 1024, 14, 1024, 3, 2, 1, 1, 32)
+        verify_group_conv2d_NCHWc_int8(1, 1024, 7, 1024, 3, 1, 1, 1, 32)
+
+        # bias, relu
+        verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True)
+        verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_bias=True)
+        verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True,
+                                       add_bias=True)
+        # dilation
+        verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 2, 32)
+
+        # batch size
+        verify_group_conv2d_NCHWc_int8(2, 128, 56, 128, 3, 1, 1, 1, 32)
+        verify_group_conv2d_NCHWc_int8(9, 128, 56, 128, 3, 1, 1, 1, 32)
+
+
+if __name__ == "__main__":
+    test_group_conv2d_nchw()
+    test_group_conv2d_NCHWc_int8()
-- 
GitLab