From 42dc24a310170577f929f648f477ca2567c8bc9a Mon Sep 17 00:00:00 2001
From: Wuwei Lin <vincentl13x@gmail.com>
Date: Sat, 27 Oct 2018 11:18:24 +0800
Subject: [PATCH] [TOPI][CUDA] batched int8 conv2d (#1961)

---
 topi/python/topi/cuda/conv2d_int8.py       | 56 ++++++++++++++--------
 topi/tests/python/test_topi_conv2d_int8.py |  4 ++
 2 files changed, 40 insertions(+), 20 deletions(-)

diff --git a/topi/python/topi/cuda/conv2d_int8.py b/topi/python/topi/cuda/conv2d_int8.py
index 053c9bc6b..9d3757c35 100644
--- a/topi/python/topi/cuda/conv2d_int8.py
+++ b/topi/python/topi/cuda/conv2d_int8.py
@@ -9,7 +9,7 @@ from .tensor_intrin import dp4a
 from ..nn.conv2d import conv2d_NCHWc_int8_prepacked
 from ..nn.pad import pad
 from ..nn.util import get_pad_tuple
-from ..util import get_const_tuple, get_const_int, traverse_inline
+from ..util import get_const_tuple, traverse_inline
 
 
 def _conv2d_NCHWc_int8_arg_to_workload(data, kernel, stride, padding, out_dtype):
@@ -183,7 +183,7 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
         _schedule_injective(packed_data.op, s)
         _schedule_injective(packed_kernel.op, s)
     else:
-        kernel = packed_data
+        kernel = packed_kernel
 
     if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
         s[kernel].compute_inline()
@@ -191,7 +191,6 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
     if pad_data != packed_data:
         s[pad_data].compute_inline()
 
-    batch = get_const_int(packed_data.shape[0])
     if isinstance(stride, int):
         stride_h = stride_w = stride
     else:
@@ -210,33 +209,50 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
 
     # tile and bind spatial axes
     n, f, y, x, c = s[output].op.axis
+    cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
     cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
     cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
     cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
 
+    # this is the scope to attach global config inside this kernel
+    kernel_scope, n = s[output].split(n, nparts=1)
+
+    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
     bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
     by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
     bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
 
-    # this is the scope to attach global config inside this kernel
-    kernel_scope, n = s[output].split(n, nparts=1)
-
-    max_block_z = 128
-    if batch > max_block_z:
-        _, n = s[output].split(n, factor=max_block_z)
-    s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
-    fused_byx = s[output].fuse(by, bx)
-    s[output].bind(n, tvm.thread_axis("blockIdx.z"))
+    s[output].reorder(bn, bf, by, bx, vn, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
+    s[output].bind(bn, tvm.thread_axis("blockIdx.z"))
     s[output].bind(bf, tvm.thread_axis("blockIdx.y"))
-    s[output].bind(fused_byx, tvm.thread_axis("blockIdx.x"))
+    s[output].bind(s[output].fuse(by, bx), tvm.thread_axis("blockIdx.x"))
+    s[output].bind(vn, tvm.thread_axis("vthread"))
     s[output].bind(vf, tvm.thread_axis("vthread"))
     s[output].bind(vy, tvm.thread_axis("vthread"))
     s[output].bind(vx, tvm.thread_axis("vthread"))
 
-    s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
-    s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
-    s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
-    s[conv].compute_at(s[output], tx)
+    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
+    if cfg["fuse_yx"].val:
+        s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
+        s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
+        tyx = s[output].fuse(ty, tx)
+        s[output].bind(tyx, tvm.thread_axis("threadIdx.x"))
+        s[conv].compute_at(s[output], tyx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2]
+        n_ty = cfg["tile_f"].size[2]
+        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
+    else:
+        s[output].bind(s[output].fuse(tn, tf), tvm.thread_axis("threadIdx.z"))
+        s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
+        s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
+        s[conv].compute_at(s[output], tx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
+        n_ty = cfg["tile_y"].size[2]
+        n_tx = cfg["tile_x"].size[2]
 
     # tile and bind reduction axes
     n, f, y, x, c = s[conv].op.axis
@@ -272,9 +288,9 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
 
         fused = s[load].fuse(n, f, y, x, oc_chunk)
         s[load].vectorize(c)
-        fused, tx = s[load].split(fused, factor=cfg["tile_x"].size[2])
-        fused, ty = s[load].split(fused, factor=cfg["tile_y"].size[2])
-        fused, tz = s[load].split(fused, factor=cfg["tile_f"].size[2])
+        fused, tx = s[load].split(fused, factor=n_tx)
+        fused, ty = s[load].split(fused, factor=n_ty)
+        fused, tz = s[load].split(fused, factor=n_tz)
         s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
         s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
         s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
diff --git a/topi/tests/python/test_topi_conv2d_int8.py b/topi/tests/python/test_topi_conv2d_int8.py
index af2d9e204..93a0587c6 100644
--- a/topi/tests/python/test_topi_conv2d_int8.py
+++ b/topi/tests/python/test_topi_conv2d_int8.py
@@ -172,6 +172,10 @@ def test_conv2d_nchw():
     verify_conv2d_NCHWc_int8(1, 2048, 8, 192, 1, 1, 0)
     verify_conv2d_NCHWc_int8(1, 1024, 19, 84, 3, 1, 1)
 
+    # batch > 1
+    verify_conv2d_NCHWc_int8(7, 32, 149, 32, 3, 1, 0)
+    verify_conv2d_NCHWc_int8(8, 32, 149, 32, 3, 1, 0)
+    verify_conv2d_NCHWc_int8(32, 32, 149, 32, 3, 1, 0)
 
 if __name__ == "__main__":
     test_conv2d_nchw()
-- 
GitLab