From 27205e36fda1e4f9432f01883f21970424d3c45b Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Wed, 12 Apr 2017 21:24:38 -0700
Subject: [PATCH] [BUGFIX] Thread related bound (#86)

---
 src/schedule/bound.cc                              |  3 +
 .../unittest/test_schedule_bound_inference.py      | 58 +++++++++++++++++++
 2 files changed, 61 insertions(+)

diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc
index e40f32078..203ce2870 100644
--- a/src/schedule/bound.cc
+++ b/src/schedule/bound.cc
@@ -148,6 +148,9 @@ void InferRootBound(const Stage& stage,
         << "call schedule.normalize to achieve this.";
     if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
       relax_set[iv->var.get()] = IntSet::range(vrange);
+      if (ctx.bind_map.count(iv)) {
+        relax_set[ctx.bind_map.at(iv)->var.get()] = IntSet::range(vrange);
+      }
     }
   }
   CHECK(found_attach || stage_attach.size() == 0)
diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py
index b5defceba..57f328546 100644
--- a/tests/python/unittest/test_schedule_bound_inference.py
+++ b/tests/python/unittest/test_schedule_bound_inference.py
@@ -175,6 +175,63 @@ def test_bound_nest_thread():
     assert(bounds[A2.op.axis[0]].extent.value==32)
     assert(bounds[A3.op.axis[0]].extent == m)
 
+def test_gemm_bound():
+    nn = 1024
+    n = tvm.convert(nn)
+    A = tvm.placeholder((n, n), name='A')
+    B = tvm.placeholder((n, n), name='B')
+    k = tvm.reduce_axis((0, n), name='k')
+    C = tvm.compute(
+        (n, n),
+        lambda ii, jj: tvm.sum(A[ii, k] * B[jj, k], axis=k),
+        name='CC')
+    # schedule
+    s = tvm.Schedule(C.op)
+    xtile, ytile = 32, 32
+    scale = 8
+    num_thread = 8
+    block_factor = scale * num_thread
+    block_x = tvm.thread_axis("blockIdx.x")
+    thread_x = tvm.thread_axis("threadIdx.x")
+    block_y = tvm.thread_axis("blockIdx.y")
+    thread_y = tvm.thread_axis("threadIdx.y")
+
+    CC = s.cache_write(C, "local")
+    AA = s.cache_read(A, "shared", [CC])
+    BB = s.cache_read(B, "shared", [CC])
+    by, yi = s[C].split(C.op.axis[0], factor=block_factor)
+    bx, xi = s[C].split(C.op.axis[1], factor=block_factor)
+    s[C].reorder(by, bx, yi, xi)
+    s[C].bind(by, block_y)
+    s[C].bind(bx, block_x)
+    ty, yi = s[C].split(yi, nparts=num_thread)
+    tx, xi = s[C].split(xi, nparts=num_thread)
+    s[C].reorder(ty, tx, yi, xi)
+    s[C].bind(ty, thread_y)
+    s[C].bind(tx, thread_x)
+    yo, xo = CC.op.axis
+    s[CC].reorder(k, yo, xo)
+
+    s[CC].compute_at(s[C], tx)
+    s[AA].compute_at(s[CC], k)
+    s[BB].compute_at(s[CC], k)
+
+    ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
+    tx, xi = s[AA].split(xi, nparts=num_thread)
+    s[AA].bind(ty, thread_y)
+    s[AA].bind(tx, thread_x)
+
+    ty, xi = s[BB].split(s[BB].op.axis[0], nparts=num_thread)
+    tx, xi = s[BB].split(xi, nparts=num_thread)
+    s[BB].bind(ty, thread_y)
+    s[BB].bind(tx, thread_x)
+    s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    assert(bounds[BB.op.axis[0]].extent.value==64)
+    assert(bounds[AA.op.axis[0]].extent.value==64)
+    assert(bounds[CC.op.axis[0]].extent.value == 8)
+    assert(bounds[CC.op.axis[1]].extent.value == 8)
+
 if __name__ == "__main__":
     test_bound_nest_thread()
@@ -187,3 +244,4 @@ if __name__ == "__main__":
     test_bound_blur()
     test_bound_conv1d()
     test_bound2()
+    test_gemm_bound()
-- 
GitLab