From 27205e36fda1e4f9432f01883f21970424d3c45b Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Wed, 12 Apr 2017 21:24:38 -0700
Subject: [PATCH] [BUGFIX] Thread-related bound inference (#86)
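
InferRootBound added only the iteration variable itself to the relax set
when relaxation was required. When that variable is bound to a thread axis
through ctx.bind_map, the range of the bound thread variable must be
relaxed as well, so that regions indexed through the bound variable are
covered during bound inference. This patch adds the bound variable's range
to the relax set and introduces a GEMM schedule test with shared/local
caching and thread binding that checks the inferred extents of the cached
stages.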

---
 src/schedule/bound.cc                         |  3 +
 .../unittest/test_schedule_bound_inference.py | 58 +++++++++++++++++++
 2 files changed, 61 insertions(+)

diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc
index e40f32078..203ce2870 100644
--- a/src/schedule/bound.cc
+++ b/src/schedule/bound.cc
@@ -148,6 +148,9 @@ void InferRootBound(const Stage& stage,
           << "call schedule.normalize to achieve this.";
       if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
         relax_set[iv->var.get()] = IntSet::range(vrange);
+        if (ctx.bind_map.count(iv)) {
+          relax_set[ctx.bind_map.at(iv)->var.get()] = IntSet::range(vrange);
+        }
       }
     }
     CHECK(found_attach || stage_attach.size() == 0)
diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py
index b5defceba..57f328546 100644
--- a/tests/python/unittest/test_schedule_bound_inference.py
+++ b/tests/python/unittest/test_schedule_bound_inference.py
@@ -175,6 +175,63 @@ def test_bound_nest_thread():
     assert(bounds[A2.op.axis[0]].extent.value==32)
     assert(bounds[A3.op.axis[0]].extent == m)
 
+def test_gemm_bound():
+    nn = 1024
+    n = tvm.convert(nn)
+    A = tvm.placeholder((n, n), name='A')
+    B = tvm.placeholder((n, n), name='B')
+    k = tvm.reduce_axis((0, n), name='k')
+    C = tvm.compute(
+        (n, n),
+        lambda ii, jj: tvm.sum(A[ii, k] * B[jj, k], axis=k),
+        name='CC')
+    # schedule
+    s = tvm.Schedule(C.op)
+    xtile, ytile = 32, 32
+    scale = 8
+    num_thread = 8
+    block_factor = scale * num_thread
+    block_x = tvm.thread_axis("blockIdx.x")
+    thread_x = tvm.thread_axis("threadIdx.x")
+    block_y = tvm.thread_axis("blockIdx.y")
+    thread_y = tvm.thread_axis("threadIdx.y")
+
+    CC = s.cache_write(C, "local")
+    AA = s.cache_read(A, "shared", [CC])
+    BB = s.cache_read(B, "shared", [CC])
+    by, yi = s[C].split(C.op.axis[0], factor=block_factor)
+    bx, xi = s[C].split(C.op.axis[1], factor=block_factor)
+    s[C].reorder(by, bx, yi, xi)
+    s[C].bind(by, block_y)
+    s[C].bind(bx, block_x)
+    ty, yi = s[C].split(yi, nparts=num_thread)
+    tx, xi = s[C].split(xi, nparts=num_thread)
+    s[C].reorder(ty, tx, yi, xi)
+    s[C].bind(ty, thread_y)
+    s[C].bind(tx, thread_x)
+    yo, xo = CC.op.axis
+    s[CC].reorder(k, yo, xo)
+
+    s[CC].compute_at(s[C], tx)
+    s[AA].compute_at(s[CC], k)
+    s[BB].compute_at(s[CC], k)
+
+    ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
+    tx, xi = s[AA].split(xi, nparts=num_thread)
+    s[AA].bind(ty, thread_y)
+    s[AA].bind(tx, thread_x)
+
+    ty, xi = s[BB].split(s[BB].op.axis[0], nparts=num_thread)
+    tx, xi = s[BB].split(xi, nparts=num_thread)
+    s[BB].bind(ty, thread_y)
+    s[BB].bind(tx, thread_x)
+    s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    assert(bounds[BB.op.axis[0]].extent.value==64)
+    assert(bounds[AA.op.axis[0]].extent.value==64)
+    assert(bounds[CC.op.axis[0]].extent.value == 8)
+    assert(bounds[CC.op.axis[1]].extent.value == 8)
+
 
 if __name__ == "__main__":
     test_bound_nest_thread()
@@ -187,3 +244,4 @@ if __name__ == "__main__":
     test_bound_blur()
     test_bound_conv1d()
     test_bound2()
+    test_gemm_bound()
-- 
GitLab