From 42dc24a310170577f929f648f477ca2567c8bc9a Mon Sep 17 00:00:00 2001
From: Wuwei Lin <vincentl13x@gmail.com>
Date: Sat, 27 Oct 2018 11:18:24 +0800
Subject: [PATCH] [TOPI][CUDA] batched int8 conv2d (#1961)

---
 topi/python/topi/cuda/conv2d_int8.py       | 56 ++++++++++++++--------
 topi/tests/python/test_topi_conv2d_int8.py |  4 ++
 2 files changed, 40 insertions(+), 20 deletions(-)

diff --git a/topi/python/topi/cuda/conv2d_int8.py b/topi/python/topi/cuda/conv2d_int8.py
index 053c9bc6b..9d3757c35 100644
--- a/topi/python/topi/cuda/conv2d_int8.py
+++ b/topi/python/topi/cuda/conv2d_int8.py
@@ -9,7 +9,7 @@ from .tensor_intrin import dp4a
 from ..nn.conv2d import conv2d_NCHWc_int8_prepacked
 from ..nn.pad import pad
 from ..nn.util import get_pad_tuple
-from ..util import get_const_tuple, get_const_int, traverse_inline
+from ..util import get_const_tuple, traverse_inline
 
 
 def _conv2d_NCHWc_int8_arg_to_workload(data, kernel, stride, padding, out_dtype):
@@ -183,7 +183,7 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
             _schedule_injective(packed_data.op, s)
             _schedule_injective(packed_kernel.op, s)
     else:
-        kernel = packed_data
+        kernel = packed_kernel
 
     if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
         s[kernel].compute_inline()
@@ -191,7 +191,6 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
     if pad_data != packed_data:
         s[pad_data].compute_inline()
 
-    batch = get_const_int(packed_data.shape[0])
     if isinstance(stride, int):
         stride_h = stride_w = stride
     else:
@@ -210,33 +209,50 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
 
     # tile and bind spatial axes
     n, f, y, x, c = s[output].op.axis
+    cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
     cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
     cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
     cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
 
+    # this is the scope to attach global config inside this kernel
+    kernel_scope, n = s[output].split(n, nparts=1)
+
+    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
     bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
     by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
     bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
 
-    # this is the scope to attach global config inside this kernel
-    kernel_scope, n = s[output].split(n, nparts=1)
-
-    max_block_z = 128
-    if batch > max_block_z:
-        _, n = s[output].split(n, factor=max_block_z)
-    s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
-    fused_byx = s[output].fuse(by, bx)
-    s[output].bind(n, tvm.thread_axis("blockIdx.z"))
+    s[output].reorder(bn, bf, by, bx, vn, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
+    s[output].bind(bn, tvm.thread_axis("blockIdx.z"))
     s[output].bind(bf, tvm.thread_axis("blockIdx.y"))
-    s[output].bind(fused_byx, tvm.thread_axis("blockIdx.x"))
+    s[output].bind(s[output].fuse(by, bx), tvm.thread_axis("blockIdx.x"))
+    s[output].bind(vn, tvm.thread_axis("vthread"))
     s[output].bind(vf, tvm.thread_axis("vthread"))
     s[output].bind(vy, tvm.thread_axis("vthread"))
     s[output].bind(vx, tvm.thread_axis("vthread"))
-    s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
-    s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
-    s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
 
-    s[conv].compute_at(s[output], tx)
+    cfg.define_knob("fuse_yx", [0, 1]) # fuse ty,tx or tn,tf
+    if cfg["fuse_yx"].val:
+        s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
+        s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
+        tyx = s[output].fuse(ty, tx)
+        s[output].bind(tyx, tvm.thread_axis("threadIdx.x"))
+        s[conv].compute_at(s[output], tyx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2]
+        n_ty = cfg["tile_f"].size[2]
+        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
+    else:
+        s[output].bind(s[output].fuse(tn, tf), tvm.thread_axis("threadIdx.z"))
+        s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
+        s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
+        s[conv].compute_at(s[output], tx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
+        n_ty = cfg["tile_y"].size[2]
+        n_tx = cfg["tile_x"].size[2]
 
     # tile and bind reduction axes
     n, f, y, x, c = s[conv].op.axis
@@ -272,9 +288,9 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
             fused = s[load].fuse(n, f, y, x, oc_chunk)
             s[load].vectorize(c)
 
-        fused, tx = s[load].split(fused, factor=cfg["tile_x"].size[2])
-        fused, ty = s[load].split(fused, factor=cfg["tile_y"].size[2])
-        fused, tz = s[load].split(fused, factor=cfg["tile_f"].size[2])
+        fused, tx = s[load].split(fused, factor=n_tx)
+        fused, ty = s[load].split(fused, factor=n_ty)
+        fused, tz = s[load].split(fused, factor=n_tz)
         s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
         s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
         s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
diff --git a/topi/tests/python/test_topi_conv2d_int8.py b/topi/tests/python/test_topi_conv2d_int8.py
index af2d9e204..93a0587c6 100644
--- a/topi/tests/python/test_topi_conv2d_int8.py
+++ b/topi/tests/python/test_topi_conv2d_int8.py
@@ -172,6 +172,10 @@ def test_conv2d_nchw():
         verify_conv2d_NCHWc_int8(1, 2048,   8, 192, 1, 1, 0)
         verify_conv2d_NCHWc_int8(1, 1024,  19,  84, 3, 1, 1)
 
+        # batch > 1
+        verify_conv2d_NCHWc_int8(7,   32, 149,  32, 3, 1, 0)
+        verify_conv2d_NCHWc_int8(8,   32, 149,  32, 3, 1, 0)
+        verify_conv2d_NCHWc_int8(32,  32, 149,  32, 3, 1, 0)
 
 if __name__ == "__main__":
     test_conv2d_nchw()
-- 
GitLab