From 3ac944394cd9bce023cd12c64a5078291925a8d9 Mon Sep 17 00:00:00 2001
From: ziheng <ziheng@apache.org>
Date: Sun, 9 Apr 2017 20:31:32 -0700
Subject: [PATCH] [PASS] Support for partition loops with thread_axis (#81)

* [PASS] Support for partition loops with thread_axis

* Add check for AttrStmt.attr_key
---
 src/pass/loop_partition.cc                    | 38 +++++++++++++------
 .../unittest/test_pass_loop_partition.py      | 20 ++++++++++
 2 files changed, 47 insertions(+), 11 deletions(-)

diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc
index 34a317824..3a8f30e7d 100644
--- a/src/pass/loop_partition.cc
+++ b/src/pass/loop_partition.cc
@@ -39,27 +39,43 @@ bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) {
 
 class PartitionFinder : public IRVisitor {
  public:
-  explicit PartitionFinder(VarExpr loop_var,
+  explicit PartitionFinder(VarExpr current_var,
     const std::unordered_map<const Variable*, IntSet>& dom_map)
-      : target_var_(loop_var), out_vars_(dom_map.size()), hint_map_(dom_map) {
+      : current_var_(current_var), out_vars_(dom_map.size()), hint_map_(dom_map) {
         for (const auto& kv : dom_map) out_vars_.insert(kv.first);
       }
 
   void Visit_(const For* op) {
     if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return;
 
-    hint_map_.insert({op->loop_var.get(),
-      IntSet::interval(op->min, op->min + op->extent - 1)});
-    relax_map_.insert({op->loop_var.get(),
-      IntSet::interval(op->min, op->min + op->extent - 1)});
+    const Variable* var = op->loop_var.get();
+    hint_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)});
+    relax_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)});
     IRVisitor::Visit_(op);
-    relax_map_.erase(op->loop_var.get());
-    hint_map_.erase(op->loop_var.get());
+    relax_map_.erase(var);
+    hint_map_.erase(var);
+  }
+
+  void Visit_(const AttrStmt* op) {
+    // handle thread_axis
+    if (op->attr_key == attr::thread_extent) {
+      const IterVarNode* thread_axis = op->node.as<IterVarNode>();
+      CHECK(thread_axis);
+      const Variable* var = thread_axis->var.get();
+      IntSet dom = IntSet::range(Range(make_zero(op->value.type()), op->value));
+      hint_map_.insert({var, dom});
+      relax_map_.insert({var, dom});
+      IRVisitor::Visit_(op);
+      relax_map_.erase(var);
+      hint_map_.erase(var);
+    } else {
+      IRVisitor::Visit_(op);
+    }
   }
 
   void Visit_(const IfThenElse* op) {
-    if (ExprUseVars(op->condition, std::unordered_set<const Variable*>({target_var_.get()}))) {
-      IntSet interval = DeduceBound(target_var_, op->condition, hint_map_, relax_map_);
+    if (ExprUseVars(op->condition, std::unordered_set<const Variable*>({current_var_.get()}))) {
+      IntSet interval = DeduceBound(current_var_, op->condition, hint_map_, relax_map_);
       partitions[op->condition.get()] = Partition{op->condition, interval};
     } else {
       IRVisitor::Visit_(op);
@@ -69,7 +85,7 @@ class PartitionFinder : public IRVisitor {
   std::unordered_map<const Node*, Partition> partitions;
 
  private:
-  VarExpr target_var_;
+  VarExpr current_var_;
   std::unordered_set<const Variable*> out_vars_;
   std::unordered_map<const Variable*, IntSet> hint_map_;
   std::unordered_map<const Variable*, IntSet> relax_map_;
diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py
index fd0662c8d..9a3c6bbdd 100644
--- a/tests/python/unittest/test_pass_loop_partition.py
+++ b/tests/python/unittest/test_pass_loop_partition.py
@@ -53,8 +53,28 @@ def test_multi_if():
     assert('if' not in str(stmt.body.first))
     print(stmt)
 
+def test_thread_axis():
+    m = tvm.Var('m')
+    l = tvm.Var('l')
+    A = tvm.placeholder((m, l), name='A')
+    B = tvm.compute((m, l), lambda i, j: A[i, j] + 3, name='B')
+
+    s = tvm.Schedule(B.op)
+
+    s[B].set_scope("shared")
+    num_thread = 16
+    xo, xi = s[B].split(B.op.axis[0], 32)
+    xi0, xi1 = s[B].split(xi, nparts=num_thread)
+    s[B].bind(xi0, tvm.thread_axis("threadIdx.x"))
+
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+    stmt_ = tvm.ir_pass.LoopPartition(stmt)
+    assert('if' not in str(stmt_.body.body.body.first))
+    print(stmt_)
 
 if __name__ == "__main__":
     test_basic()
     test_multi_loop()
     test_multi_if()
+    test_thread_axis()
-- 
GitLab