Skip to content
Snippets Groups Projects
Commit 38f03f1f authored by tqchen's avatar tqchen
Browse files

SSA Pass

parent 7e7c24e1
No related branches found
No related tags found
No related merge requests found
Subproject commit 4becbde67c8aa565941b02648cea90f50211f8dc
Subproject commit 24a7c0357a6a8db5db782d320aad7f706ebe8507
export CXX=g++
export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -Wno-unknown-pragmas -funroll-loops\
-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC
......
......@@ -27,7 +27,9 @@ using Halide::abs;
using Halide::select;
using Halide::Expr;
using Halide::VarExpr;
using Halide::IR::FunctionRef;
using Halide::IR::FunctionBaseNode;
using Halide::Internal::Stmt;
......
......@@ -28,7 +28,7 @@ bool VerifySSA(const IRNodeRef& ir);
* \param stmt The source statement to be converted.
* \return The converted form.
*/
Stmt ConvertSSA(const Stmt stmt);
Stmt ConvertSSA(const Stmt& stmt);
/*!
* \brief inline all calls of f in stmt.
......
/*!
* Copyright (c) 2016 by Contributors
* \file ir_pass.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
namespace tvm {
namespace ir {
namespace {
struct SetVarDef {
// get var definition from node
using FType = IRFunctor<const Variable*(const IRNodeRef&)>;
static FGetVarDef& vtable_get_var_def() { // NOLINT(*)
static FGetVarDef inst; return inst;
}
static FSetVarExpr& vtable_set_var_expr() { // NOLINT(*)
static FSetVarExpr inst; return inst;
}
static FSetVarStmt& vtable_set_var_expr() { // NOLINT(*)
static FSetVarStmt inst; return inst;
}
};
// return a new node to
using FSetVarExpr = IRFunctor<Expr (const IRNodeRef&, VarExpr)>;
// return a new node to
using FSetVarStmt = IRFunctor<Expr (const IRNodeRef&, VarExpr)>;
inline const Variable* GetVarDef(const IRNodeRef& n) {
if (n.as<Let>()) {
return n.as<Let>()->var.get();
} else if (n.as<LetStmt>()) {
return n.as<LetStmt>()->var.get();
} else if (n.as<For>()) {
return n.as<For>()->loop_var.get();
} else if (n.as<Allocate>()) {
return n.as<Allocate>()->buffer_var.get();
} else {
return nullptr;
}
}
inline Expr ResetVar(const Expr& n, VarExpr var) {
if (n.as<Let>()) {
std::shared_ptr<Let> x = std::make_shared<Let>(*n.as<Let>());
x->var = var;
return Expr(x);
} else if (n.as<Allocate>()) {
}
}
inline Stmt ResetVarDef(const Stmt& n, VarExpr var) {
if (n.as<LetStmt>()) {
std::shared_ptr<LetStmt> x = std::make_shared<LetStmt>(*n.as<Let>());
x->var = var;
return Expr(x);
} else if (n.as<For>()) {
std::shared_ptr<For> x = std::make_shared<For>(*n.as<Let>());
x->loop_var = var;
return Expr(x);
} else {
LOG(FATAL) << "not reached";
}
}
class IRVerifySSA : public IRVisitor {
public:
bool is_ssa{true};
std::unordered_set<const Variable*> defined;
void Visit(const IRNodeRef& n) final {
if (!is_ssa) return;
const Variable* v = GetVarDef(n);
if (v != nullptr) {
if (defined.count(v) != 0) {
is_ssa = false; return;
} else {
defined.insert(v);
}
}
IRVisitor::Visit(n);
}
};
class IRConvertSSA : public IRMutator {
public:
Expr Mutate(Expr expr) final {
static const auto& f = IRConvertSSA::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::Mutate(expr));
}
Stmt Mutate(Stmt stmt) final {
static const auto& f = IRMutatorExample::vtable_stmt();
return (f.can_dispatch(stmt) ?
f(stmt, stmt, this) : IRMutator::Mutate(stmt));
}
using FConvertExpr = IRFunctor<Expr(const IRNodeRef&, const Expr&, IRConvertSSA *)>;
using FConvertStmt = IRFunctor<Stmt(const IRNodeRef&, const Expr&, IRConvertSSA *)>;
std::unordered_map<const Variable*, std::vector<VarExpr> > scope;
std::unordered_set<const Variable*> defined;
};
temple<>
TVM_STATIC_IR_FUNCTOR(IRConvertSSA, vtable_expr)
.set_dispatch<Let>([](const Let* op, const Expr& e, IRConvertSSA* m) {
VarExpr var = op->var;
if (m->defined.count(var.get()) != 0) {
var = Variable::make(var->type, var->name_hint);
}
// insert scope before recursion.
m->scope[var.get()].push_back(var);
Expr new_expr = Mutate(e);
m->scope[var.get()].pop_back();
if (!var.same_as(op->var)) {
std::shared_ptr<Let> x = std::make_shared<Let>(*new_expr.as<Let>());
x->var = var;
return Expr(x);
} else {
return new_expr;
}
});
} // namespace
bool VerifySSA(const IRNodeRef& ir) {
IRVerifySSA v;
v.Visit(ir);
return v.is_ssa;
}
} // namespace ir
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* SSA related checks and pass.
* \file ssa.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include <unordered_map>
#include <vector>
namespace tvm {
namespace ir {
namespace {
// global functor to get var definition from
struct FGetVarDef {
using FType = IRFunctor<VarExpr (const IRNodeRef&)>;
static FType& vtable() { // NOLINT(*)
static FType inst; return inst;
}
};
TVM_STATIC_IR_FUNCTOR(FGetVarDef, vtable)
.set_dispatch<Let>([](const Let* op) {
return op->var;
})
.set_dispatch<LetStmt>([](const LetStmt* op) {
return op->var;
})
.set_dispatch<For>([](const For* op) {
return op->loop_var;
})
.set_dispatch<Allocate>([](const Allocate* op) {
return op->buffer_var;
});
struct FSetVarDef {
using FTypeExpr = IRFunctor<Expr (const IRNodeRef&, VarExpr)>;
using FTypeStmt = IRFunctor<Stmt (const IRNodeRef&, VarExpr)>;
static FTypeExpr& vtable_expr() { // NOLINT(*)
static FTypeExpr inst; return inst;
}
static FTypeStmt& vtable_stmt() { // NOLINT(*)
static FTypeStmt inst; return inst;
}
};
TVM_STATIC_IR_FUNCTOR(FSetVarDef, vtable_expr)
.set_dispatch<Let>([](const Let* op, VarExpr var) {
std::shared_ptr<Let> x = std::make_shared<Let>(*op);
x->var = var;
return Expr(x);
});
TVM_STATIC_IR_FUNCTOR(FSetVarDef, vtable_stmt)
.set_dispatch<LetStmt>([](const LetStmt* op, VarExpr var) {
std::shared_ptr<LetStmt> x = std::make_shared<LetStmt>(*op);
x->var = var;
return Stmt(x);
})
.set_dispatch<For>([](const For* op, VarExpr var) {
std::shared_ptr<For> x = std::make_shared<For>(*op);
x->loop_var = var;
return Stmt(x);
});
class IRVerifySSA : public IRVisitor {
public:
bool is_ssa{true};
void Visit(const IRNodeRef& n) final {
if (!is_ssa) return;
static auto& fget_var_def = FGetVarDef::vtable();
if (fget_var_def.can_dispatch(n)) {
VarExpr v = fget_var_def(n);
if (defined_.count(v.get()) != 0) {
is_ssa = false; return;
} else {
defined_[v.get()] = 1;
}
}
IRVisitor::Visit(n);
}
private:
std::unordered_map<const Variable*, int> defined_;
};
class IRConvertSSA : public IRMutator {
public:
Expr Mutate(Expr expr) final {
static auto& fget_var_def = FGetVarDef::vtable();
static auto& fset_var_def = FSetVarDef::vtable_expr();
if (fget_var_def.can_dispatch(expr)) {
VarExpr v = fget_var_def(expr);
VarExpr new_var = v;
if (defined_.count(v.get()) != 0) {
CHECK(expr.as<Allocate>() == nullptr)
<< "One allocation in two places, cannot rename buffer in allocate";
new_var = Variable::make(v->type, v->name_hint);
} else {
defined_.insert(v.get());
}
scope_[v.get()].push_back(new_var);
Expr new_expr = IRMutator::Mutate(expr);
scope_[v.get()].pop_back();
if (!new_var.same_as(v)) {
return fset_var_def(new_expr, new_var);
} else {
return new_expr;
}
} else if (expr.as<Variable>()) {
const Variable* v = expr.as<Variable>();
if (scope_.count(v) != 0) {
return scope_[v].back();
} else {
return expr;
}
} else {
Expr e = IRMutator::Mutate(expr);
return e;
}
}
Stmt Mutate(Stmt stmt) final {
static auto& fget_var_def = FGetVarDef::vtable();
static auto& fset_var_def = FSetVarDef::vtable_stmt();
if (fget_var_def.can_dispatch(stmt)) {
VarExpr v = fget_var_def(stmt);
VarExpr new_var = v;
if (defined_.count(v.get()) != 0) {
new_var = Variable::make(v->type, v->name_hint);
} else {
defined_.insert(v.get());
}
scope_[v.get()].push_back(new_var);
Stmt new_stmt = IRMutator::Mutate(stmt);
scope_[v.get()].pop_back();
if (!new_var.same_as(v)) {
return fset_var_def(new_stmt, new_var);
} else {
return new_stmt;
}
} else {
return IRMutator::Mutate(stmt);
}
}
private:
std::unordered_map<const Variable*, std::vector<VarExpr> > scope_;
std::unordered_set<const Variable*> defined_;
};
} // namespace
bool VerifySSA(const IRNodeRef& ir) {
IRVerifySSA v;
v.Visit(ir);
return v.is_ssa;
}
Stmt ConvertSSA(const Stmt& stmt) {
return IRConvertSSA().Mutate(stmt);
}
} // namespace ir
} // namespace tvm
......@@ -3,23 +3,25 @@
#include <tvm/tvm.h>
#include <tvm/ir_pass.h>
TEST(IRPass, Substitute) {
TEST(IRSSA, Convert) {
using namespace Halide::Internal;
using namespace tvm;
Var x("x"), y;
Expr let = Let::make(x, 1, x + 1);
auto z = let + let;
CHECK(!ir::VerifySSA(z));
auto z_ssa = ir::ConvertSSA(Evaluate::make(z));
CHECK(ir::VerifySSA(z_ssa));
}
TEST(IRSSA, Basic) {
using namespace Halide::Internal;
using namespace tvm;
Var x("x"), y;
auto z = x + y;
{
auto zz = ir::Substitute({{y.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "(x + 11)");
}
{
auto zz = ir::Substitute({{z.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "11");
}
CHECK(ir::VerifySSA(z));
}
int main(int argc, char ** argv) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment