Skip to content
Snippets Groups Projects
Commit 7e7c24e1 authored by tqchen's avatar tqchen
Browse files

temp checkin

parent 8e04361c
No related branches found
No related tags found
No related merge requests found
......@@ -27,6 +27,7 @@ using Halide::abs;
using Halide::select;
using Halide::Expr;
using Halide::VarExpr;
using Halide::IR::FunctionBaseNode;
using Halide::Internal::Stmt;
......
......@@ -45,6 +45,49 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* Max = "Max";
static constexpr const char* Min = "Min";
};
// Reuse IR node defintiion from HalideIR
using Halide::Internal::IntImm;
using Halide::Internal::UIntImm;
using Halide::Internal::FloatImm;
using Halide::Internal::StringImm;
using Halide::Internal::Cast;
using Halide::Internal::Variable;
using Halide::Internal::Add;
using Halide::Internal::Sub;
using Halide::Internal::Mul;
using Halide::Internal::Div;
using Halide::Internal::Mod;
using Halide::Internal::Min;
using Halide::Internal::Max;
using Halide::Internal::EQ;
using Halide::Internal::NE;
using Halide::Internal::LT;
using Halide::Internal::LE;
using Halide::Internal::GT;
using Halide::Internal::GE;
using Halide::Internal::And;
using Halide::Internal::Or;
using Halide::Internal::Not;
using Halide::Internal::Select;
using Halide::Internal::Load;
using Halide::Internal::Ramp;
using Halide::Internal::Broadcast;
using Halide::Internal::Call;
using Halide::Internal::Let;
using Halide::Internal::LetStmt;
using Halide::Internal::AssertStmt;
using Halide::Internal::ProducerConsumer;
using Halide::Internal::For;
using Halide::Internal::Store;
using Halide::Internal::Provide;
using Halide::Internal::Allocate;
using Halide::Internal::Free;
using Halide::Internal::Realize;
using Halide::Internal::Block;
using Halide::Internal::IfThenElse;
using Halide::Internal::Evaluate;
} // namespace ir
} // namespace tvm
......
......@@ -14,11 +14,35 @@ namespace tvm {
namespace ir {
/*!
* \brief Substitute occurance of IRNode in expr
* \param replacements The replacement rule of substitution
* \param expr The expression to be substituted.
* \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For)
*
* \param ir The root of the IR DAG.
* \return Whether IR is in SSA form.
* \note All the passes in this file uses SSA form and outputs SSA form.
*/
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr);
bool VerifySSA(const IRNodeRef& ir);
/*!
* \brief Convert a IR node to be SSA form.
* \param stmt The source statement to be converted.
* \return The converted form.
*/
Stmt ConvertSSA(const Stmt stmt);
/*!
* \brief inline all calls of f in stmt.
*
* \param f The function reference to be inlined
* \param args The arguments variable of the function.
* \param body The defintion body of the function.
* \param stmt The statement to apply inline optimization.
* \return The result stmt
*
* \note All the passes in this file uses SSA form and outputs SSA form.
*/
Stmt InlineSSA(FunctionRef f, const std::vector<Var>& args, Expr body, Stmt stmt);
} // namespace ir
} // namespace tvm
......
......@@ -10,29 +10,128 @@
namespace tvm {
namespace ir {
namespace {
// visitor to implement apply
class IRSubstitute : public IRMutator {
struct SetVarDef {
// get var definition from node
using FType = IRFunctor<const Variable*(const IRNodeRef&)>;
static FGetVarDef& vtable_get_var_def() { // NOLINT(*)
static FGetVarDef inst; return inst;
}
static FSetVarExpr& vtable_set_var_expr() { // NOLINT(*)
static FSetVarExpr inst; return inst;
}
static FSetVarStmt& vtable_set_var_expr() { // NOLINT(*)
static FSetVarStmt inst; return inst;
}
};
// return a new node to
using FSetVarExpr = IRFunctor<Expr (const IRNodeRef&, VarExpr)>;
// return a new node to
using FSetVarStmt = IRFunctor<Expr (const IRNodeRef&, VarExpr)>;
inline const Variable* GetVarDef(const IRNodeRef& n) {
if (n.as<Let>()) {
return n.as<Let>()->var.get();
} else if (n.as<LetStmt>()) {
return n.as<LetStmt>()->var.get();
} else if (n.as<For>()) {
return n.as<For>()->loop_var.get();
} else if (n.as<Allocate>()) {
return n.as<Allocate>()->buffer_var.get();
} else {
return nullptr;
}
}
inline Expr ResetVar(const Expr& n, VarExpr var) {
if (n.as<Let>()) {
std::shared_ptr<Let> x = std::make_shared<Let>(*n.as<Let>());
x->var = var;
return Expr(x);
} else if (n.as<Allocate>()) {
}
}
inline Stmt ResetVarDef(const Stmt& n, VarExpr var) {
if (n.as<LetStmt>()) {
std::shared_ptr<LetStmt> x = std::make_shared<LetStmt>(*n.as<Let>());
x->var = var;
return Expr(x);
} else if (n.as<For>()) {
std::shared_ptr<For> x = std::make_shared<For>(*n.as<Let>());
x->loop_var = var;
return Expr(x);
} else {
LOG(FATAL) << "not reached";
}
}
class IRVerifySSA : public IRVisitor {
public:
Expr Mutate(Expr expr) final {
const IRNode* v = expr.get();
bool is_ssa{true};
std::unordered_set<const Variable*> defined;
void Visit(const IRNodeRef& n) final {
if (!is_ssa) return;
const Variable* v = GetVarDef(n);
if (v != nullptr) {
auto it = replacements_.find(v);
if (it != replacements_.end()) {
return it->second;
if (defined.count(v) != 0) {
is_ssa = false; return;
} else {
defined.insert(v);
}
}
return IRMutator::Mutate(expr);
IRVisitor::Visit(n);
}
explicit IRSubstitute(const std::unordered_map<const IRNode*, Expr>& replacements)
: replacements_(replacements) {}
};
private:
const std::unordered_map<const IRNode*, Expr>& replacements_;
class IRConvertSSA : public IRMutator {
public:
Expr Mutate(Expr expr) final {
static const auto& f = IRConvertSSA::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::Mutate(expr));
}
Stmt Mutate(Stmt stmt) final {
static const auto& f = IRMutatorExample::vtable_stmt();
return (f.can_dispatch(stmt) ?
f(stmt, stmt, this) : IRMutator::Mutate(stmt));
}
using FConvertExpr = IRFunctor<Expr(const IRNodeRef&, const Expr&, IRConvertSSA *)>;
using FConvertStmt = IRFunctor<Stmt(const IRNodeRef&, const Expr&, IRConvertSSA *)>;
std::unordered_map<const Variable*, std::vector<VarExpr> > scope;
std::unordered_set<const Variable*> defined;
};
temple<>
TVM_STATIC_IR_FUNCTOR(IRConvertSSA, vtable_expr)
.set_dispatch<Let>([](const Let* op, const Expr& e, IRConvertSSA* m) {
VarExpr var = op->var;
if (m->defined.count(var.get()) != 0) {
var = Variable::make(var->type, var->name_hint);
}
// insert scope before recursion.
m->scope[var.get()].push_back(var);
Expr new_expr = Mutate(e);
m->scope[var.get()].pop_back();
if (!var.same_as(op->var)) {
std::shared_ptr<Let> x = std::make_shared<Let>(*new_expr.as<Let>());
x->var = var;
return Expr(x);
} else {
return new_expr;
}
});
} // namespace
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr) {
return IRSubstitute(replacements).Mutate(expr);
bool VerifySSA(const IRNodeRef& ir) {
IRVerifySSA v;
v.Visit(ir);
return v.is_ssa;
}
} // namespace ir
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment