From 38f03f1f78f88a115829444c6e790b68417250ae Mon Sep 17 00:00:00 2001 From: tqchen <tianqi.tchen@gmail.com> Date: Fri, 18 Nov 2016 21:30:15 -0800 Subject: [PATCH] SSA Pass --- HalideIR | 2 +- Makefile | 1 + include/tvm/expr.h | 2 + include/tvm/ir_pass.h | 2 +- src/pass/ir_pass.cc | 138 -------------- src/pass/ssa.cc | 171 ++++++++++++++++++ tests/cpp/{ir_pass_test.cc => ir_ssa_test.cc} | 28 +-- 7 files changed, 191 insertions(+), 153 deletions(-) delete mode 100644 src/pass/ir_pass.cc create mode 100644 src/pass/ssa.cc rename tests/cpp/{ir_pass_test.cc => ir_ssa_test.cc} (52%) diff --git a/HalideIR b/HalideIR index 4becbde67..24a7c0357 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit 4becbde67c8aa565941b02648cea90f50211f8dc +Subproject commit 24a7c0357a6a8db5db782d320aad7f706ebe8507 diff --git a/Makefile b/Makefile index 7daddbd95..d2f8bd71b 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,4 @@ +export CXX=g++ export LDFLAGS = -pthread -lm export CFLAGS = -std=c++11 -Wall -O2 -Wno-unknown-pragmas -funroll-loops\ -Iinclude -Idmlc-core/include -IHalideIR/src -fPIC diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 0be106346..24e792fea 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -27,7 +27,9 @@ using Halide::abs; using Halide::select; using Halide::Expr; + using Halide::VarExpr; +using Halide::IR::FunctionRef; using Halide::IR::FunctionBaseNode; using Halide::Internal::Stmt; diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 3baca6bae..c02e57565 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -28,7 +28,7 @@ bool VerifySSA(const IRNodeRef& ir); * \param stmt The source statement to be converted. * \return The converted form. */ -Stmt ConvertSSA(const Stmt stmt); +Stmt ConvertSSA(const Stmt& stmt); /*! * \brief inline all calls of f in stmt. diff --git a/src/pass/ir_pass.cc b/src/pass/ir_pass.cc deleted file mode 100644 index 9abf04bd2..000000000 --- a/src/pass/ir_pass.cc +++ /dev/null @@ -1,138 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file ir_pass.cc - */ -#include <tvm/ir.h> -#include <tvm/ir_visitor.h> -#include <tvm/ir_mutator.h> -#include <unordered_set> - -namespace tvm { -namespace ir { -namespace { - -struct SetVarDef { - // get var definition from node - using FType = IRFunctor<const Variable*(const IRNodeRef&)>; - static FGetVarDef& vtable_get_var_def() { // NOLINT(*) - static FGetVarDef inst; return inst; - } - static FSetVarExpr& vtable_set_var_expr() { // NOLINT(*) - static FSetVarExpr inst; return inst; - } - static FSetVarStmt& vtable_set_var_expr() { // NOLINT(*) - static FSetVarStmt inst; return inst; - } -}; - - // return a new node to - using FSetVarExpr = IRFunctor<Expr (const IRNodeRef&, VarExpr)>; - // return a new node to - using FSetVarStmt = IRFunctor<Expr (const IRNodeRef&, VarExpr)>; - -inline const Variable* GetVarDef(const IRNodeRef& n) { - if (n.as<Let>()) { - return n.as<Let>()->var.get(); - } else if (n.as<LetStmt>()) { - return n.as<LetStmt>()->var.get(); - } else if (n.as<For>()) { - return n.as<For>()->loop_var.get(); - } else if (n.as<Allocate>()) { - return n.as<Allocate>()->buffer_var.get(); - } else { - return nullptr; - } -} - -inline Expr ResetVar(const Expr& n, VarExpr var) { - if (n.as<Let>()) { - std::shared_ptr<Let> x = std::make_shared<Let>(*n.as<Let>()); - x->var = var; - return Expr(x); - } else if (n.as<Allocate>()) { - } -} - -inline Stmt ResetVarDef(const Stmt& n, VarExpr var) { - if (n.as<LetStmt>()) { - std::shared_ptr<LetStmt> x = std::make_shared<LetStmt>(*n.as<Let>()); - x->var = var; - return Expr(x); - } else if (n.as<For>()) { - std::shared_ptr<For> x = std::make_shared<For>(*n.as<Let>()); - x->loop_var = var; - return Expr(x); - } else { - LOG(FATAL) << "not reached"; - } -} - -class IRVerifySSA : public IRVisitor { - public: - bool is_ssa{true}; - std::unordered_set<const Variable*> defined; - - void Visit(const IRNodeRef& n) final { - if (!is_ssa) return; - const Variable* v = GetVarDef(n); - if (v != nullptr) { - if (defined.count(v) != 0) { - is_ssa = false; return; - } else { - defined.insert(v); - } - } - IRVisitor::Visit(n); - } -}; - -class IRConvertSSA : public IRMutator { - public: - Expr Mutate(Expr expr) final { - static const auto& f = IRConvertSSA::vtable_expr(); - return (f.can_dispatch(expr) ? - f(expr, expr, this) : IRMutator::Mutate(expr)); - } - Stmt Mutate(Stmt stmt) final { - static const auto& f = IRMutatorExample::vtable_stmt(); - return (f.can_dispatch(stmt) ? - f(stmt, stmt, this) : IRMutator::Mutate(stmt)); - } - using FConvertExpr = IRFunctor<Expr(const IRNodeRef&, const Expr&, IRConvertSSA *)>; - using FConvertStmt = IRFunctor<Stmt(const IRNodeRef&, const Expr&, IRConvertSSA *)>; - std::unordered_map<const Variable*, std::vector<VarExpr> > scope; - std::unordered_set<const Variable*> defined; -}; - -temple<> - -TVM_STATIC_IR_FUNCTOR(IRConvertSSA, vtable_expr) -.set_dispatch<Let>([](const Let* op, const Expr& e, IRConvertSSA* m) { - VarExpr var = op->var; - if (m->defined.count(var.get()) != 0) { - var = Variable::make(var->type, var->name_hint); - } - // insert scope before recursion. - m->scope[var.get()].push_back(var); - Expr new_expr = Mutate(e); - m->scope[var.get()].pop_back(); - - if (!var.same_as(op->var)) { - std::shared_ptr<Let> x = std::make_shared<Let>(*new_expr.as<Let>()); - x->var = var; - return Expr(x); - } else { - return new_expr; - } - }); - -} // namespace - -bool VerifySSA(const IRNodeRef& ir) { - IRVerifySSA v; - v.Visit(ir); - return v.is_ssa; -} - -} // namespace ir -} // namespace tvm diff --git a/src/pass/ssa.cc b/src/pass/ssa.cc new file mode 100644 index 000000000..556626418 --- /dev/null +++ b/src/pass/ssa.cc @@ -0,0 +1,171 @@ +/*! + * Copyright (c) 2016 by Contributors + * SSA related checks and pass. + * \file ssa.cc + */ +#include <tvm/ir.h> +#include <tvm/ir_visitor.h> +#include <tvm/ir_mutator.h> +#include <tvm/ir_pass.h> +#include <unordered_set> +#include <unordered_map> +#include <vector> + +namespace tvm { +namespace ir { +namespace { + +// global functor to get var definition from +struct FGetVarDef { + using FType = IRFunctor<VarExpr (const IRNodeRef&)>; + static FType& vtable() { // NOLINT(*) + static FType inst; return inst; + } +}; +TVM_STATIC_IR_FUNCTOR(FGetVarDef, vtable) +.set_dispatch<Let>([](const Let* op) { + return op->var; + }) +.set_dispatch<LetStmt>([](const LetStmt* op) { + return op->var; + }) +.set_dispatch<For>([](const For* op) { + return op->loop_var; + }) +.set_dispatch<Allocate>([](const Allocate* op) { + return op->buffer_var; + }); + +struct FSetVarDef { + using FTypeExpr = IRFunctor<Expr (const IRNodeRef&, VarExpr)>; + using FTypeStmt = IRFunctor<Stmt (const IRNodeRef&, VarExpr)>; + static FTypeExpr& vtable_expr() { // NOLINT(*) + static FTypeExpr inst; return inst; + } + static FTypeStmt& vtable_stmt() { // NOLINT(*) + static FTypeStmt inst; return inst; + } +}; +TVM_STATIC_IR_FUNCTOR(FSetVarDef, vtable_expr) +.set_dispatch<Let>([](const Let* op, VarExpr var) { + std::shared_ptr<Let> x = std::make_shared<Let>(*op); + x->var = var; + return Expr(x); + }); + +TVM_STATIC_IR_FUNCTOR(FSetVarDef, vtable_stmt) +.set_dispatch<LetStmt>([](const LetStmt* op, VarExpr var) { + std::shared_ptr<LetStmt> x = std::make_shared<LetStmt>(*op); + x->var = var; + return Stmt(x); + }) +.set_dispatch<For>([](const For* op, VarExpr var) { + std::shared_ptr<For> x = std::make_shared<For>(*op); + x->loop_var = var; + return Stmt(x); + }); + +class IRVerifySSA : public IRVisitor { + public: + bool is_ssa{true}; + + void Visit(const IRNodeRef& n) final { + if (!is_ssa) return; + static auto& fget_var_def = FGetVarDef::vtable(); + if (fget_var_def.can_dispatch(n)) { + VarExpr v = fget_var_def(n); + if (defined_.count(v.get()) != 0) { + is_ssa = false; return; + } else { + defined_[v.get()] = 1; + } + } + IRVisitor::Visit(n); + } + + private: + std::unordered_map<const Variable*, int> defined_; +}; + +class IRConvertSSA : public IRMutator { + public: + Expr Mutate(Expr expr) final { + static auto& fget_var_def = FGetVarDef::vtable(); + static auto& fset_var_def = FSetVarDef::vtable_expr(); + if (fget_var_def.can_dispatch(expr)) { + VarExpr v = fget_var_def(expr); + VarExpr new_var = v; + if (defined_.count(v.get()) != 0) { + CHECK(expr.as<Allocate>() == nullptr) + << "One allocation in two places, cannot rename buffer in allocate"; + new_var = Variable::make(v->type, v->name_hint); + } else { + defined_.insert(v.get()); + } + scope_[v.get()].push_back(new_var); + Expr new_expr = IRMutator::Mutate(expr); + scope_[v.get()].pop_back(); + + if (!new_var.same_as(v)) { + return fset_var_def(new_expr, new_var); + } else { + return new_expr; + } + } else if (expr.as<Variable>()) { + const Variable* v = expr.as<Variable>(); + if (scope_.count(v) != 0) { + return scope_[v].back(); + } else { + return expr; + } + } else { + Expr e = IRMutator::Mutate(expr); + return e; + + } + } + + Stmt Mutate(Stmt stmt) final { + static auto& fget_var_def = FGetVarDef::vtable(); + static auto& fset_var_def = FSetVarDef::vtable_stmt(); + if (fget_var_def.can_dispatch(stmt)) { + VarExpr v = fget_var_def(stmt); + VarExpr new_var = v; + if (defined_.count(v.get()) != 0) { + new_var = Variable::make(v->type, v->name_hint); + } else { + defined_.insert(v.get()); + } + scope_[v.get()].push_back(new_var); + Stmt new_stmt = IRMutator::Mutate(stmt); + scope_[v.get()].pop_back(); + + if (!new_var.same_as(v)) { + return fset_var_def(new_stmt, new_var); + } else { + return new_stmt; + } + } else { + return IRMutator::Mutate(stmt); + } + } + + private: + std::unordered_map<const Variable*, std::vector<VarExpr> > scope_; + std::unordered_set<const Variable*> defined_; +}; + +} // namespace + +bool VerifySSA(const IRNodeRef& ir) { + IRVerifySSA v; + v.Visit(ir); + return v.is_ssa; +} + +Stmt ConvertSSA(const Stmt& stmt) { + return IRConvertSSA().Mutate(stmt); +} + +} // namespace ir +} // namespace tvm diff --git a/tests/cpp/ir_pass_test.cc b/tests/cpp/ir_ssa_test.cc similarity index 52% rename from tests/cpp/ir_pass_test.cc rename to tests/cpp/ir_ssa_test.cc index 4397cfa12..0f0f9e6da 100644 --- a/tests/cpp/ir_pass_test.cc +++ b/tests/cpp/ir_ssa_test.cc @@ -3,23 +3,25 @@ #include <tvm/tvm.h> #include <tvm/ir_pass.h> -TEST(IRPass, Substitute) { + +TEST(IRSSA, Convert) { + using namespace Halide::Internal; + using namespace tvm; + Var x("x"), y; + Expr let = Let::make(x, 1, x + 1); + + auto z = let + let; + CHECK(!ir::VerifySSA(z)); + auto z_ssa = ir::ConvertSSA(Evaluate::make(z)); + CHECK(ir::VerifySSA(z_ssa)); +} + +TEST(IRSSA, Basic) { using namespace Halide::Internal; using namespace tvm; Var x("x"), y; auto z = x + y; - { - auto zz = ir::Substitute({{y.get(), 11}}, z); - std::ostringstream os; - os << zz; - CHECK(os.str() == "(x + 11)"); - } - { - auto zz = ir::Substitute({{z.get(), 11}}, z); - std::ostringstream os; - os << zz; - CHECK(os.str() == "11"); - } + CHECK(ir::VerifySSA(z)); } int main(int argc, char ** argv) { -- GitLab