Skip to content
Snippets Groups Projects
Commit e94e2450 authored by Yida Wang, committed by Tianqi Chen
Browse files

Remove the pragma primitives for better performance when the threads are bound (#949)

parent 413e2b7a
No related branches found
No related tags found
No related merge requests found
......@@ -77,9 +77,6 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou
batch, ic_chunk, ih, ic_block, iw = s[A1].op.axis
parallel_axis = s[A1].fuse(ic_chunk, ih)
s[A1].parallel(parallel_axis)
s[A1].pragma(batch, "parallel_launch_point")
s[A1].pragma(parallel_axis, "parallel_stride_pattern")
s[A1].pragma(batch, "parallel_barrier_when_finish")
# schedule kernel pack
oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[W].op.axis
......@@ -88,9 +85,6 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou
s[W].vectorize(oc_block)
parallel_axis = s[W].fuse(oc_chunk, oh)
s[W].parallel(parallel_axis)
s[W].pragma(parallel_axis, "parallel_launch_point")
s[W].pragma(parallel_axis, "parallel_stride_pattern")
s[W].pragma(parallel_axis, "parallel_barrier_when_finish")
C, O0, O = conv_out, output, last
CC = s.cache_write(C, 'global')
......@@ -128,8 +122,5 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
s[O].pragma(batch, "parallel_launch_point")
s[O].pragma(parallel_axis, "parallel_stride_pattern")
s[O].pragma(batch, "parallel_barrier_when_finish")
return s
......@@ -90,9 +90,6 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou
batch, ic_chunk, ih, ic_block, iw = s[A1].op.axis
parallel_axis = s[A1].fuse(ic_chunk, ih)
s[A1].parallel(parallel_axis)
s[A1].pragma(batch, "parallel_launch_point")
s[A1].pragma(parallel_axis, "parallel_stride_pattern")
s[A1].pragma(batch, "parallel_barrier_when_finish")
# schedule kernel pack
oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[W].op.axis
......@@ -101,9 +98,6 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou
s[W].vectorize(oc_block)
parallel_axis = s[W].fuse(oc_chunk, oh)
s[W].parallel(parallel_axis)
s[W].pragma(parallel_axis, "parallel_launch_point")
s[W].pragma(parallel_axis, "parallel_stride_pattern")
s[W].pragma(parallel_axis, "parallel_barrier_when_finish")
# schedule conv
C, O0, O = conv_out, output, last
......@@ -144,8 +138,5 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
s[O].pragma(batch, "parallel_launch_point")
s[O].pragma(parallel_axis, "parallel_stride_pattern")
s[O].pragma(batch, "parallel_barrier_when_finish")
return s
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment