From 3ac944394cd9bce023cd12c64a5078291925a8d9 Mon Sep 17 00:00:00 2001 From: ziheng <ziheng@apache.org> Date: Sun, 9 Apr 2017 20:31:32 -0700 Subject: [PATCH] [PASS] Support for partition loops with thread_axis (#81) * [PASS] Support for partition loops with thread_axis * Add check for AttrStmt.attr_key --- src/pass/loop_partition.cc | 38 +++++++++++++------ .../unittest/test_pass_loop_partition.py | 20 ++++++++++ 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 34a317824..3a8f30e7d 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -39,27 +39,43 @@ bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) { class PartitionFinder : public IRVisitor { public: - explicit PartitionFinder(VarExpr loop_var, + explicit PartitionFinder(VarExpr current_var, const std::unordered_map<const Variable*, IntSet>& dom_map) - : target_var_(loop_var), out_vars_(dom_map.size()), hint_map_(dom_map) { + : current_var_(current_var), out_vars_(dom_map.size()), hint_map_(dom_map) { for (const auto& kv : dom_map) out_vars_.insert(kv.first); } void Visit_(const For* op) { if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return; - hint_map_.insert({op->loop_var.get(), - IntSet::interval(op->min, op->min + op->extent - 1)}); - relax_map_.insert({op->loop_var.get(), - IntSet::interval(op->min, op->min + op->extent - 1)}); + const Variable* var = op->loop_var.get(); + hint_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)}); + relax_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)}); IRVisitor::Visit_(op); - relax_map_.erase(op->loop_var.get()); - hint_map_.erase(op->loop_var.get()); + relax_map_.erase(var); + hint_map_.erase(var); + } + + void Visit_(const AttrStmt* op) { + // handle thread_axis + if (op->attr_key == attr::thread_extent) { + const IterVarNode* thread_axis = op->node.as<IterVarNode>(); + CHECK(thread_axis); + const Variable* var = thread_axis->var.get(); + IntSet dom = IntSet::range(Range(make_zero(op->value.type()), op->value)); + hint_map_.insert({var, dom}); + relax_map_.insert({var, dom}); + IRVisitor::Visit_(op); + relax_map_.erase(var); + hint_map_.erase(var); + } else { + IRVisitor::Visit_(op); + } } void Visit_(const IfThenElse* op) { - if (ExprUseVars(op->condition, std::unordered_set<const Variable*>({target_var_.get()}))) { - IntSet interval = DeduceBound(target_var_, op->condition, hint_map_, relax_map_); + if (ExprUseVars(op->condition, std::unordered_set<const Variable*>({current_var_.get()}))) { + IntSet interval = DeduceBound(current_var_, op->condition, hint_map_, relax_map_); partitions[op->condition.get()] = Partition{op->condition, interval}; } else { IRVisitor::Visit_(op); @@ -69,7 +85,7 @@ class PartitionFinder : public IRVisitor { std::unordered_map<const Node*, Partition> partitions; private: - VarExpr target_var_; + VarExpr current_var_; std::unordered_set<const Variable*> out_vars_; std::unordered_map<const Variable*, IntSet> hint_map_; std::unordered_map<const Variable*, IntSet> relax_map_; diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py index fd0662c8d..9a3c6bbdd 100644 --- a/tests/python/unittest/test_pass_loop_partition.py +++ b/tests/python/unittest/test_pass_loop_partition.py @@ -53,8 +53,28 @@ def test_multi_if(): assert('if' not in str(stmt.body.first)) print(stmt) +def test_thread_axis(): + m = tvm.Var('m') + l = tvm.Var('l') + A = tvm.placeholder((m, l), name='A') + B = tvm.compute((m, l), lambda i, j: A[i, j] + 3, name='B') + + s = tvm.Schedule(B.op) + + s[B].set_scope("shared") + num_thread = 16 + xo, xi = s[B].split(B.op.axis[0], 32) + xi0, xi1 = s[B].split(xi, nparts=num_thread) + s[B].bind(xi0, tvm.thread_axis("threadIdx.x")) + + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + stmt_ = tvm.ir_pass.LoopPartition(stmt) + assert('if' not in str(stmt_.body.body.body.first)) + print(stmt_) if __name__ == "__main__": test_basic() test_multi_loop() test_multi_if() + test_thread_axis() -- GitLab