diff --git a/topi/python/topi/cuda/conv2d_nchw.py b/topi/python/topi/cuda/conv2d_nchw.py
index b6bd35768112fedb2015ca369d4fc2fa1f9c2008..a59a996242331f5cf47add07ca2cafb4c9766ee2 100644
--- a/topi/python/topi/cuda/conv2d_nchw.py
+++ b/topi/python/topi/cuda/conv2d_nchw.py
@@ -1,15 +1,17 @@
-#pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches
+#pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches, line-too-long
 """Schedule for conv2d_nchw with auto fusion"""
 import tvm
 from .. import util
 from .. import tag
 from .. import generic
 
-def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
+def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
     """Schedule conv2d for specific feature_in_out_filter pattern"""
     # scheduler params
     ofactor = 16
     hfactor = 2
+    if flag >= 96:
+        hfactor = 4
     ow_size = util.get_const_int(Out.shape[3])
     num_thread = ow_size * hfactor
     vthread = ofactor
@@ -22,7 +24,8 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
     oh, ih = s[Out].split(h, factor=hfactor)
     s[Out].reorder(ooc, oh, ioc, ih, w)
     oc = s[Out].fuse(ooc, oh)
-    w = s[Out].fuse(w, ih)
+    ow, _ = s[Out].split(w, nparts=ow_size)
+    w = s[Out].fuse(ow, ih)
     s[Out].bind(w, thread_x)
     s[Out].bind(ioc, thread_xz)
     s[Out].bind(oc, block_x)
@@ -360,7 +363,11 @@ def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L):
     if util.get_const_int(Filter.shape[0]) == 64:
         opart2 = 8
         ifactor = 16
-    sfactor = max(1, ofactor // (opart2*2))
+    if util.get_const_int(Out.shape[2]) == 224:
+        num_thread = 4
+        wfactor = 112
+        ifactor = 4
+    sfactor = max(1, ofactor // (opart2*vthread))
     spart = max(1, (wfactor + vthread-1) // vthread)
 
     block_x = tvm.thread_axis("blockIdx.x")
@@ -368,7 +375,7 @@ def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L):
     block_z = tvm.thread_axis("blockIdx.z")
     thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
     thread_y = tvm.thread_axis((0, wfactor // vthread), "threadIdx.y")
-    thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
+    thread_xz = tvm.thread_axis((0, opart2), "vthread", name="vx")
     thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
 
     i, oc, h, w = s[Out].op.axis
@@ -394,10 +401,10 @@ def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L):
     ic, dh, dw = s[Out_L].op.reduce_axis
     oic, iic = s[Out_L].split(ic, factor=ifactor)
     s[Out_L].reorder(oic, dh, dw, iic, h, w)
-
     fuse_index = s[Out_L].fuse(dw, dh)
     fuse_index = s[Out_L].fuse(fuse_index, oic)
     dw = fuse_index
+
     s[temp_S].compute_at(s[Out_L], dw)
     s[Filter_S].compute_at(s[Out_L], dw)
 
@@ -421,16 +428,6 @@ def schedule_conv2d_small_batch(outs):
 
     def schedule(temp, Filter, Output):
         """Schedule conv2d_nchw"""
-        block_h = util.get_const_int(Output.shape[3])
-        block_w = util.get_const_int(temp.shape[1])
-        if block_h % 48 == 0:
-            block_h = 48
-        elif block_h % 32 == 0:
-            block_h = 32
-        if block_w % 48 == 0:
-            block_w = 48
-        elif block_w % 32 == 0:
-            block_w = 32
 
         flag = util.get_const_int(Filter.shape[0])+util.get_const_int(Filter.shape[1])
 
@@ -450,7 +447,7 @@ def schedule_conv2d_small_batch(outs):
             s[temp_G].reorder(i, oic, h, w, iic)
             temp_R = s.cache_write(temp_G, "global")
             temp_S = s.cache_read(temp_R, "shared", [temp_G])
-        elif util.get_const_int(Filter.shape[3]) == 7:
+        elif util.get_const_int(Filter.shape[3]) == 7 or (util.get_const_int(Output.shape[2] == 224) and flag < 128):
             temp_G = s.cache_read(temp, "global", [Output])
             s[temp_G].compute_inline()
             i, ic, h, w = s[temp_G].op.axis
@@ -472,8 +469,8 @@ def schedule_conv2d_small_batch(outs):
             s[Output].set_scope("local")
             Out_L = Output
 
-        if util.get_const_int(Filter.shape[3]) == 7:
-            conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L)
+        if util.get_const_int(Filter.shape[3]) == 7 or (util.get_const_int(Output.shape[2] == 224) and flag < 128):
+            conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag)
         elif 128 < flag < 512:
             conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag)
         elif flag >= 512:
diff --git a/topi/python/topi/cuda/conv2d_transpose_nchw.py b/topi/python/topi/cuda/conv2d_transpose_nchw.py
index 82fd451f4ecb98ba2ad2920a5d2dd895fb5d1293..edd255a8fac036ca8da4fd8d48124fbe83869e29 100644
--- a/topi/python/topi/cuda/conv2d_transpose_nchw.py
+++ b/topi/python/topi/cuda/conv2d_transpose_nchw.py
@@ -1,4 +1,4 @@
-#pylint: disable=invalid-name
+#pylint: disable=invalid-name, line-too-long
 """Schedule for conv2d_transpose_nchw with auto fusion"""
 import tvm
 from .. import util
@@ -42,7 +42,7 @@ def schedule_conv2d_transpose_small_batch(outs):
             s[temp_G].reorder(i, oic, h, w, iic)
             temp_R = s.cache_write(temp_G, "global")
             temp_S = s.cache_read(temp_R, "shared", [temp_G])
-        elif util.get_const_int(Filter.shape[3]) == 7:
+        elif util.get_const_int(Filter.shape[3]) == 7 or (util.get_const_int(Output.shape[2] == 224) and flag < 128):
             temp_G = s.cache_read(temp, "global", [Output])
             s[temp_G].compute_inline()
             i, ic, h, w = s[temp_G].op.axis
@@ -64,8 +64,8 @@ def schedule_conv2d_transpose_small_batch(outs):
             s[Output].set_scope("local")
             Out_L = Output
 
-        if util.get_const_int(Filter.shape[3]) == 7:
-            conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L)
+        if util.get_const_int(Filter.shape[3]) == 7 or (util.get_const_int(Output.shape[2] == 224) and flag < 128):
+            conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag)
         elif 128 < flag < 512:
             conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag)
         elif flag >= 512:
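
For context, a minimal sketch of how the schedule dispatch behaves after this patch. The helper name pick_schedule and its plain-tuple arguments are hypothetical illustration only; flag is the sum of the filter's output- and input-channel extents, exactly as computed in schedule_conv2d_small_batch above, and the flag >= 512 branch is left abstract because the diff truncates it.

def pick_schedule(filter_shape, output_shape):
    """Hypothetical helper mirroring the branch conditions in the patch."""
    num_filter, in_channel, _, kernel_w = filter_shape  # Filter layout is (O, I, KH, KW)
    out_height = output_shape[2]                        # Output layout is (N, C, H, W)
    flag = num_filter + in_channel                      # same "flag" as in the patch

    # 7x7 kernels, and now also 224-high outputs with a small channel sum, use
    # conv2d_224_3_64; inside it, hfactor is raised from 2 to 4 once flag >= 96.
    if kernel_w == 7 or (out_height == 224 and flag < 128):
        return "conv2d_224_3_64"
    if 128 < flag < 512:
        return "conv2d_56_64_128"
    return "flag >= 512 branch (truncated in the diff)"

# ResNet-style stem layer: 3 -> 64 channels, 7x7 kernel, 224x224 output
assert pick_schedule((64, 3, 7, 7), (1, 64, 224, 224)) == "conv2d_224_3_64"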