From ff26cd68d041221f9829f794d1e53b974bae20ac Mon Sep 17 00:00:00 2001
From: tqchen <tianqi.tchen@gmail.com>
Date: Fri, 6 Jan 2017 14:18:45 -0800
Subject: [PATCH] Fix Tile, add a few more test cases on bound inference

---
 src/lang/schedule.cc                 |  2 +-
 src/schedule/bound.cc                | 11 +++++--
 src/schedule/int_set.cc              |  4 ++-
 tests/python/test_bound_inference.py | 44 +++++++++++++++++++++++++---
 tests/python/test_schedule.py        | 13 +++++++-
 5 files changed, 65 insertions(+), 9 deletions(-)

diff --git a/src/lang/schedule.cc b/src/lang/schedule.cc
index aa7c5b51f..b5d4429eb 100644
--- a/src/lang/schedule.cc
+++ b/src/lang/schedule.cc
@@ -154,7 +154,7 @@ Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent,
                          Expr x_factor, Expr y_factor) { // NOLINT(*)
   split(x_parent, p_x_outer, p_x_inner, x_factor);
   split(y_parent, p_y_outer, p_y_inner, y_factor);
-  reorder(Array<IterVar>({*p_x_inner, *p_y_inner, *p_x_outer, *p_y_outer}));
+  reorder(Array<IterVar>({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner}));
   return *this;
 }
 
diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc
index 329d118e6..7d3f25d61 100644
--- a/src/schedule/bound.cc
+++ b/src/schedule/bound.cc
@@ -165,8 +165,15 @@ bool ScopeRelax(const IterVar& iv, const std::string& scope) {
     {"shared", 1},
     {"local", 2}
   };
-
-  return scope_rank.at(scope) <= scope_rank.at(iv->thread_tag);
+  static std::unordered_map<std::string, int> thread_tag_rank{
+    {"gridIdx.x", 0},
+    {"gridIdx.y", 0},
+    {"gridIdx.z", 0},
+    {"threadIdx.x", 1},
+    {"threadIdx.y", 1},
+    {"threadIdx.z", 1}
+  };
+  return scope_rank.at(scope) <= thread_tag_rank.at(iv->thread_tag);
 }
 
 void InferBound(
diff --git a/src/schedule/int_set.cc b/src/schedule/int_set.cc
index b5e632e6e..6a770b32f 100644
--- a/src/schedule/int_set.cc
+++ b/src/schedule/int_set.cc
@@ -220,6 +220,8 @@ void PassUp(const SplitNode* s,
     *parent = IntSet::make_range(dom_map.at(s->parent));
     return;
   }
+  CHECK(outer.defined());
+  CHECK(inner.defined());
   // copy construct
   auto n = std::make_shared<IntSetNode>(*(inner.operator->()));
 
@@ -228,7 +230,6 @@ void PassUp(const SplitNode* s,
     n->base = Range::make_with_min_extent(
         AsNumber(outer) * s->factor + inner->base->min,
         inner->base->extent);
-    *parent = IntSet(n);
   } else {
     // default use all domains in the data.
     n->domain.push_back(outer->base);
@@ -238,6 +239,7 @@ void PassUp(const SplitNode* s,
       n->stride.push_back(outer->stride[i] * s->factor);
     }
   }
+  *parent = IntSet(n);
 }
 
 void PassUp(const FuseNode* s,
diff --git a/tests/python/test_bound_inference.py b/tests/python/test_bound_inference.py
index 6de6e44e5..fb169e603 100644
--- a/tests/python/test_bound_inference.py
+++ b/tests/python/test_bound_inference.py
@@ -1,6 +1,6 @@
 import tvm
 
-def test_bound_inference():
+def test_bound1():
     m = tvm.Var('m')
     l = tvm.Var('l')
     A = tvm.placeholder((m, l), name='A')
@@ -12,8 +12,42 @@ def test_bound_inference():
     sA1.compute_at(sA2, xo)
     bounds = tvm.schedule.InferBound(sA2)
     assert isinstance(bounds, tvm.collections.Map)
-    print(bounds[A1.op.dim_var[0]])
-    print(bounds[A1.op.dim_var[1]])
+    assert(bounds[A1.op.dim_var[0]].extent.value == 8)
+
+def test_bound2():
+    m = tvm.Var('m')
+    l = tvm.Var('l')
+    A = tvm.placeholder((m, l), name='A')
+    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
+    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
+    sA1 = tvm.Schedule(A1.op)
+    sA2 = tvm.Schedule(A2.op)
+    xo, yo, xi, yi = sA2.tile(A2.op.dim_var[0], A2.op.dim_var[1], 8, 8)
+    sA1.compute_at(sA2, yo)
+    bounds = tvm.schedule.InferBound(sA2)
+    assert isinstance(bounds, tvm.collections.Map)
+    assert(bounds[A1.op.dim_var[0]].extent.value == 8)
+    assert(bounds[A1.op.dim_var[1]].extent.value == 8)
+
+def test_bound3():
+    m = tvm.Var('m')
+    l = tvm.Var('l')
+    A = tvm.placeholder((m, l), name='A')
+    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
+    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
+    sA1 = tvm.Schedule(A1.op, scope="shared")
+    sA2 = tvm.Schedule(A2.op)
+    thread_x = tvm.IterVar((0, 16), thread_tag="threadIdx.x")
+    xo, xi = sA2.split(A2.op.dim_var[0], 32)
+    xi0, xi1 = sA2.split(xi, outer=thread_x)
+    yo, yi = sA2.split(A2.op.dim_var[1], 16)
+    sA2.reorder(xo, xi0, yo, xi1, yi)
+    sA1.compute_at(sA2, yo)
+
+    bounds = tvm.schedule.InferBound(sA2)
+    assert isinstance(bounds, tvm.collections.Map)
+    assert(bounds[A1.op.dim_var[0]].extent.value==32)
+    assert(bounds[A1.op.dim_var[1]].extent.value==16)
 
 
 def test_create_read_graph():
@@ -31,5 +65,7 @@ def test_create_read_graph():
 
 
 if __name__ == "__main__":
-    test_bound_inference()
+    test_bound3()
+    test_bound1()
+    test_bound2()
     test_create_read_graph()
diff --git a/tests/python/test_schedule.py b/tests/python/test_schedule.py
index b08b5f6fb..efdeab9a6 100644
--- a/tests/python/test_schedule.py
+++ b/tests/python/test_schedule.py
@@ -34,6 +34,16 @@ def test_reorder():
     sch_T.reorder(*order)
     assert tuple(sch_T.leaf_iter_vars) == order
 
+def test_split():
+    m = tvm.Var('m')
+    A = tvm.placeholder((m,), name='A')
+    T = tvm.compute((m,), lambda i: A[i])
+
+    sT = tvm.Schedule(T.op)
+    xo, xi = sT.split(T.op.dim_var[0], factor=10)
+    assert tuple(sT.leaf_iter_vars) == (xo, xi)
+
+
 def test_tile():
     m = tvm.Var('m')
     n = tvm.Var('n')
@@ -42,9 +52,10 @@ def test_tile():
 
     sch_T = tvm.Schedule(T.op, scope="shared")
     xo, yo, xi, yi = sch_T.tile(T.op.dim_var[0], T.op.dim_var[1], x_factor=10, y_factor=5)
-    assert tuple(sch_T.leaf_iter_vars) == (xi, yi, xo, yo)
+    assert tuple(sch_T.leaf_iter_vars) == (xo, yo, xi, yi)
 
 if __name__ == "__main__":
     test_schedule_create()
     test_reorder()
     test_tile()
+    test_split()
-- 
GitLab