From 03b09f749f8bb586e77d8f3ad428484756f1b99f Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Mon, 15 May 2017 17:32:13 -0700 Subject: [PATCH] [PASS] Improve SSA conversion, add forbid list in loop-par (#142) --- python/tvm/build.py | 4 +- src/pass/loop_partition.cc | 19 +- src/pass/ssa.cc | 237 ++++++++++++---------- tests/python/unittest/test_build_lower.py | 20 ++ 4 files changed, 167 insertions(+), 113 deletions(-) create mode 100644 tests/python/unittest/test_build_lower.py diff --git a/python/tvm/build.py b/python/tvm/build.py index 273c16a07..a592bb314 100644 --- a/python/tvm/build.py +++ b/python/tvm/build.py @@ -90,10 +90,10 @@ def lower(sch, sch = sch.normalize() bounds = schedule.InferBound(sch) stmt = schedule.ScheduleOps(sch, bounds) - if not simple_mode: - stmt = ir_pass.LoopPartition(stmt) stmt = ir_pass.StorageFlatten(stmt, binds) stmt = ir_pass.CanonicalSimplify(stmt) + if not simple_mode: + stmt = ir_pass.LoopPartition(stmt) stmt = ir_pass.VectorizeLoop(stmt) stmt = ir_pass.InjectVirtualThread(stmt) stmt = ir_pass.StorageRewrite(stmt) diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index c4567b6ae..1e524583b 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -52,7 +52,7 @@ class CandidateSelector : public IRVisitor { const Variable* var = op->loop_var.get(); record_.insert({var, false}); IRVisitor::Visit_(op); - if (record_.at(var)) { + if (record_.at(var) && !no_split_) { candidates.insert(op); } record_.erase(var); @@ -70,7 +70,7 @@ class CandidateSelector : public IRVisitor { if ((scope.rank == 0) && !is_const(op->value)) { record_.insert({var.get(), false}); IRVisitor::Visit_(op); - if (record_.at(var.get())) { + if (record_.at(var.get()) && !no_split_) { candidates.insert(op); } record_.erase(var.get()); @@ -80,11 +80,25 @@ class CandidateSelector : public IRVisitor { IRVisitor::Visit_(op); } + void Visit_(const Block* op) { + bool temp = no_split_; + this->Visit(op->first); + // erase the no split state of first when visit rest. + std::swap(temp, no_split_); + this->Visit(op->rest); + // restore the no split flag. + no_split_ = no_split_ || temp; + } + void Visit_(const Call* op) { if (op->is_intrinsic(Call::likely)) { in_likely_ = true; IRVisitor::Visit_(op); in_likely_ = false; + } else if (op->is_intrinsic(intrinsic::tvm_thread_allreduce)) { + // no split if the body contains allreduce. + no_split_ = true; + return; } else { IRVisitor::Visit_(op); } @@ -100,6 +114,7 @@ class CandidateSelector : public IRVisitor { private: bool in_likely_; + bool no_split_{false}; std::unordered_map<const Variable*, VarIsUsed> record_; }; diff --git a/src/pass/ssa.cc b/src/pass/ssa.cc index 9b01e24fe..ed12160af 100644 --- a/src/pass/ssa.cc +++ b/src/pass/ssa.cc @@ -1,6 +1,8 @@ /*! * Copyright (c) 2016 by Contributors * SSA related checks and pass. + * + * SSA requires each varaible to be only defined once. * \file ssa.cc */ #include <tvm/ir.h> @@ -14,138 +16,155 @@ namespace tvm { namespace ir { namespace { - -// global functor to get var definition from -struct FGetVarDef { - using FType = IRFunctor<VarExpr (const NodeRef&)>; - static FType& vtable() { // NOLINT(*) - static FType inst; return inst; - } -}; -TVM_STATIC_IR_FUNCTOR(FGetVarDef, vtable) -.set_dispatch<Let>([](const Let* op) { - return op->var; - }) -.set_dispatch<LetStmt>([](const LetStmt* op) { - return op->var; - }) -.set_dispatch<For>([](const For* op) { - return op->loop_var; - }) -.set_dispatch<Allocate>([](const Allocate* op) { - return op->buffer_var; - }); - -struct FSetVarDef { - using FTypeExpr = IRFunctor<Expr (const NodeRef&, VarExpr)>; - using FTypeStmt = IRFunctor<Stmt (const NodeRef&, VarExpr)>; - static FTypeExpr& vtable_expr() { // NOLINT(*) - static FTypeExpr inst; return inst; - } - static FTypeStmt& vtable_stmt() { // NOLINT(*) - static FTypeStmt inst; return inst; - } -}; -TVM_STATIC_IR_FUNCTOR(FSetVarDef, vtable_expr) -.set_dispatch<Let>([](const Let* op, VarExpr var) { - std::shared_ptr<Let> x = std::make_shared<Let>(*op); - x->var = var; - return Expr(x); - }); - -TVM_STATIC_IR_FUNCTOR(FSetVarDef, vtable_stmt) -.set_dispatch<LetStmt>([](const LetStmt* op, VarExpr var) { - std::shared_ptr<LetStmt> x = std::make_shared<LetStmt>(*op); - x->var = var; - return Stmt(x); - }) -.set_dispatch<For>([](const For* op, VarExpr var) { - std::shared_ptr<For> x = std::make_shared<For>(*op); - x->loop_var = var; - return Stmt(x); - }); - -class IRVerifySSA : public IRVisitor { +class IRVerifySSA final : public IRVisitor { public: bool is_ssa{true}; void Visit(const NodeRef& n) final { if (!is_ssa) return; - static auto& fget_var_def = FGetVarDef::vtable(); - if (fget_var_def.can_dispatch(n)) { - VarExpr v = fget_var_def(n); - if (defined_.count(v.get()) != 0) { - is_ssa = false; return; - } else { - defined_[v.get()] = 1; - } - } IRVisitor::Visit(n); } + void Visit_(const Let* op) final { + MarkDef(op->var.get()); + IRVisitor::Visit_(op); + } + void Visit_(const LetStmt* op) final { + MarkDef(op->var.get()); + IRVisitor::Visit_(op); + } + void Visit_(const For* op) final { + MarkDef(op->loop_var.get()); + IRVisitor::Visit_(op); + } + void Visit_(const Allocate* op) final { + MarkDef(op->buffer_var.get()); + IRVisitor::Visit_(op); + } private: + void MarkDef(const Variable* v) { + if (defined_.count(v) != 0) { + is_ssa = false; return; + } else { + defined_[v] = 1; + } + } std::unordered_map<const Variable*, int> defined_; }; -class IRConvertSSA : public IRMutator { +class IRConvertSSA final : public IRMutator { public: - Expr Mutate(Expr expr) final { - static auto& fget_var_def = FGetVarDef::vtable(); - static auto& fset_var_def = FSetVarDef::vtable_expr(); - if (fget_var_def.can_dispatch(expr)) { - VarExpr v = fget_var_def(expr); - VarExpr new_var = v; - if (defined_.count(v.get()) != 0) { - CHECK(expr.as<Allocate>() == nullptr) - << "One allocation in two places, cannot rename buffer in allocate"; - new_var = Variable::make(v->type, v->name_hint); - } else { - defined_.insert(v.get()); - } + Expr Mutate_(const Variable* op, const Expr& e) final { + if (scope_.count(op)) { + return scope_[op].back(); + } else { + return e; + } + } + Expr Mutate_(const Let* op, const Expr& e) final { + const VarExpr& v = op->var; + if (defined_.count(v.get())) { + Expr value = IRMutator::Mutate(op->value); + VarExpr new_var = Variable::make(v.type(), v->name_hint); scope_[v.get()].push_back(new_var); - Expr new_expr = IRMutator::Mutate(expr); + Expr body = IRMutator::Mutate(op->body); scope_[v.get()].pop_back(); - - if (!new_var.same_as(v)) { - return fset_var_def(new_expr, new_var); - } else { - return new_expr; - } - } else if (expr.as<Variable>()) { - const Variable* v = expr.as<Variable>(); - if (scope_.count(v) != 0) { - return scope_[v].back(); - } else { - return expr; - } + return Let::make(new_var, value, body); } else { - Expr e = IRMutator::Mutate(expr); - return e; + defined_.insert(v.get()); + return IRMutator::Mutate_(op, e); } } - - Stmt Mutate(Stmt stmt) final { - static auto& fget_var_def = FGetVarDef::vtable(); - static auto& fset_var_def = FSetVarDef::vtable_stmt(); - if (fget_var_def.can_dispatch(stmt)) { - VarExpr v = fget_var_def(stmt); - VarExpr new_var = v; - if (defined_.count(v.get()) != 0) { - new_var = Variable::make(v->type, v->name_hint); - } else { - defined_.insert(v.get()); - } + Expr Mutate_(const Load* op, const Expr& e) final { + Expr expr = IRMutator::Mutate_(op, e); + op = expr.as<Load>(); + if (scope_.count(op->buffer_var.get())) { + return Load::make( + op->type, scope_[op->buffer_var.get()].back(), + op->index, op->predicate); + } else { + return expr; + } + } + Stmt Mutate_(const Store* op, const Stmt& s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as<Store>(); + if (scope_.count(op->buffer_var.get())) { + return Store::make( + scope_[op->buffer_var.get()].back(), op->value, + op->index, op->predicate); + } else { + return stmt; + } + } + Stmt Mutate_(const LetStmt* op, const Stmt& s) final { + const VarExpr& v = op->var; + if (defined_.count(v.get())) { + Expr value = IRMutator::Mutate(op->value); + VarExpr new_var = Variable::make(v.type(), v->name_hint); scope_[v.get()].push_back(new_var); - Stmt new_stmt = IRMutator::Mutate(stmt); + Stmt body = IRMutator::Mutate(op->body); scope_[v.get()].pop_back(); - - if (!new_var.same_as(v)) { - return fset_var_def(new_stmt, new_var); + return LetStmt::make(new_var, value, body); + } else { + defined_.insert(v.get()); + return IRMutator::Mutate_(op, s); + } + } + Stmt Mutate_(const For* op, const Stmt& s) final { + const VarExpr& v = op->loop_var; + if (defined_.count(v.get())) { + VarExpr new_var = Variable::make(v.type(), v->name_hint); + scope_[v.get()].push_back(new_var); + Stmt stmt = IRMutator::Mutate_(op, s); + scope_[v.get()].pop_back(); + op = stmt.as<For>(); + return For::make( + new_var, op->min, op->extent, op->for_type, op->device_api, op->body); + } else { + defined_.insert(v.get()); + return IRMutator::Mutate_(op, s); + } + } + Stmt Mutate_(const Allocate* op, const Stmt& s) final { + const VarExpr& v = op->buffer_var; + if (defined_.count(v.get())) { + VarExpr new_var = Variable::make(v.type(), v->name_hint); + scope_[v.get()].push_back(new_var); + Stmt stmt = IRMutator::Mutate_(op, s); + scope_[v.get()].pop_back(); + op = stmt.as<Allocate>(); + return Allocate::make( + new_var, op->type, op->extents, op->condition, + op->body, op->new_expr, op->free_function); + } else { + defined_.insert(v.get()); + return IRMutator::Mutate_(op, s); + } + } + Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + if (const Variable* v = op->node.as<Variable>()) { + if (op->attr_key == attr::storage_scope) { + const Allocate* alloc = op->body.as<Allocate>(); + if (alloc && op->node.same_as(alloc->buffer_var)) { + Stmt new_alloc = Mutate(op->body); + if (new_alloc.same_as(op->body)) return s; + alloc = new_alloc.as<Allocate>(); + CHECK(alloc); + return AttrStmt::make( + alloc->buffer_var, op->attr_key, op->value, new_alloc); + } + } + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as<AttrStmt>(); + if (scope_.count(v) && scope_[v].size() != 0) { + return AttrStmt::make( + scope_[v].back(), op->attr_key, op->value, op->body); } else { - return new_stmt; + return stmt; } } else { - return IRMutator::Mutate(stmt); + return IRMutator::Mutate_(op, s); } } diff --git a/tests/python/unittest/test_build_lower.py b/tests/python/unittest/test_build_lower.py new file mode 100644 index 000000000..37ef2a417 --- /dev/null +++ b/tests/python/unittest/test_build_lower.py @@ -0,0 +1,20 @@ +import tvm + +def test_lower_rfactor(): + n = tvm.var("n") + m = tvm.var("m") + A = tvm.placeholder((n, m), name='A') + k = tvm.reduce_axis((0, m), "k") + B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name="B") + s = tvm.create_schedule(B.op) + ko, ki = s[B].split(B.op.reduce_axis[0], factor=16) + BF = s.rfactor(B, ki) + xo, xi = s[B].split(s[B].op.axis[0], factor=32) + s[B.op].bind(xo, tvm.thread_axis("blockIdx.x")) + s[B.op].bind(xi, tvm.thread_axis("threadIdx.y")) + s[B].bind(s[B].op.reduce_axis[0], tvm.thread_axis("threadIdx.x")) + s[BF].compute_at(s[B], s[B].op.reduce_axis[0]) + fapi = tvm.lower(s, [A, B]) + +if __name__ == "__main__": + test_lower_rfactor() -- GitLab