Skip to content
Snippets Groups Projects
Commit be8de13f authored by tqchen's avatar tqchen
Browse files

Enable IRFunctor based IRMutator

parent 0a392dd0
No related branches found
No related tags found
No related merge requests found
Subproject commit ec84af1359c841df622f683048968348381e328a
Subproject commit 89b7939957d66a37dd6083ad6b09a5644e73fd8b
......@@ -36,6 +36,8 @@ class Range : public Halide::IR::Range {
* \param end The end of the range.
*/
Range(Expr begin, Expr end);
static Range make_with_min_extent(Expr min, Expr extent);
};
/*! \brief Domain is a multi-dimensional range */
......@@ -74,6 +76,8 @@ class RDomain : public NodeRef {
inline Var i0() const {
return index(0);
}
// low level constructor
static RDomain make(Array<Var> index, Domain domain);
};
/*! \brief use RDom as alias of RDomain */
......@@ -88,8 +92,8 @@ class RDomainNode : public Node {
Domain domain;
/*! \brief constructor */
RDomainNode() {}
RDomainNode(Array<Var> && index, Domain && domain)
: index(std::move(index)), domain(std::move(domain)) {
RDomainNode(Array<Var> index, Domain domain)
: index(index), domain(domain) {
}
const char* type_key() const override {
return "RDomain";
......
......@@ -8,7 +8,7 @@
#include <ir/Expr.h>
#include <ir/IROperator.h>
#include <type_traits>
#include <string>
#include "./base.h"
namespace tvm {
......@@ -28,7 +28,12 @@ using Halide::select;
using Halide::Expr;
using Halide::Internal::Stmt;
using Var = Halide::VarExpr;
class Var : public Halide::VarExpr {
public:
explicit Var(const std::string& name_hint = "v",
Type t = Int(32)) : VarExpr(name_hint, t) {}
};
} // namespace tvm
#endif // TVM_EXPR_H_
/*!
* Copyright (c) 2016 by Contributors
* \file ir_mutator.h
* \brief Defines general IRMutation pass
*/
#ifndef TVM_IR_MUTATOR_H_
#define TVM_IR_MUTATOR_H_
#include <tvm/ir_node.h>
#include "./expr.h"
namespace tvm {
namespace ir {
/*!
* \brief a base class for mutator to iterative mutate the IR
*
* This IRMutator is implemented via IRFunctor instead of Visitor Pattern.
* This enables easy extensions of possible new IRNode.
* It also makes changing return types easier.
*
* \note If you want to return a different type other than Expr and Stmt,
* Simply following the same pattern as IRMutator and create a seperate class.
* \sa IRFunctor
*/
class IRMutator {
public:
/*!
* \brief mutate expression
* \return the mutated expr
*/
virtual Expr mutate(Expr expr) {
static const FMutateExpr& f = vtable_expr();
return f(expr, expr, this);
}
/*!
* \brief mutate expression
* \return the mutated stmt
*/
virtual Stmt mutate(Stmt stmt) {
static const FMutateStmt& f = vtable_stmt();
return f(stmt, stmt, this);
}
/*! \brief destructor */
virtual ~IRMutator() {}
/*! \brief functor type of expr mutation */
using FMutateExpr = IRFunctor<Expr(const IRNodeRef&, const Expr&, IRMutator*)>;
/*! \brief functor type of stmt mutation */
using FMutateStmt = IRFunctor<Stmt(const IRNodeRef&, const Stmt&, IRMutator*)>;
/*! \return internal vtable of expr */
static FMutateExpr& vtable_expr(); // NOLINT(*)
/*! \return internal stmt of expr */
static FMutateStmt& vtable_stmt(); // NOLINT(*)
};
/*!
* \brief templatized base class of subclass of IRMutator
*
* Use "curiously recurring template pattern" to implement mutate for you.
* Child class need to declare IRMutatorBase<T>::vtable_expr and IRMutatorBase<T>::vtable_stmt
*
* \note This only implement direct subclass from IRMutator, similar code
* can be created to implement deeper subclassing when needed.
*/
class IRMutatorExample : public IRMutator {
public:
Expr mutate(Expr expr) final {
static const FMutateExpr& f = IRMutatorExample::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::mutate(expr));
}
Stmt mutate(Stmt stmt) final {
static const FMutateStmt& f = IRMutatorExample::vtable_stmt();
return (f.can_dispatch(stmt) ?
f(stmt, stmt, this) : IRMutator::mutate(stmt));
}
// to be implemented by child class
static FMutateExpr& vtable_expr(); // NOLINT(*)
static FMutateStmt& vtable_stmt(); // NOLINT(*)
};
} // namespace ir
} // namespace tvm
#endif // TVM_IR_MUTATOR_H_
......@@ -12,6 +12,10 @@ Range::Range(Expr begin, Expr end)
// TODO(tqchen) add simplify to end - begin
}
Range Range::make_with_min_extent(Expr min, Expr extent) {
return Range(std::make_shared<Halide::IR::RangeNode>(min, extent));
}
RDomain::RDomain(Domain domain) {
std::vector<Var> index;
for (size_t i = 0; i < domain.size(); ++i) {
......@@ -24,6 +28,10 @@ RDomain::RDomain(Domain domain) {
std::move(idx), std::move(domain));
}
RDomain RDomain::make(Array<Var> index, Domain domain) {
return RDomain(std::make_shared<RDomainNode>(index, domain));
}
TVM_REGISTER_NODE_TYPE(RDomainNode);
} // namespace tvm
......@@ -20,7 +20,7 @@ namespace Internal {
using tvm::ir::Reduce;
template<>
void ExprNode<Reduce>::accept(IRVisitor *v) const {
void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
LOG(FATAL) << "Reduce do not work with IRVisitor yet";
}
......
/*!
* Copyright (c) 2016 by Contributors
* \file ir_mutator.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
namespace tvm {
namespace ir {
IRMutator::FMutateExpr& IRMutator::vtable_expr() { // NOLINT(*)
static FMutateExpr inst; return inst;
}
IRMutator::FMutateStmt& IRMutator::vtable_stmt() { // NOLINT(*)
static FMutateStmt inst; return inst;
}
// namespace to register the functors.
namespace {
using namespace Halide::Internal;
// const expr
inline Expr ReturnSelfExpr(const IRNodeRef&, const Expr& e, IRMutator*) {
return e;
}
inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
std::vector<Expr> new_arr(arr.size());
bool changed = false;
for (size_t i = 0; i < arr.size(); i++) {
Expr old_elem = arr[i];
Expr new_elem = m->mutate(old_elem);
if (!new_elem.same_as(old_elem)) changed = true;
new_arr[i] = new_elem;
}
if (!changed) {
return arr;
} else {
return Array<Expr>(new_arr);
}
}
inline RDomain MutateRDom(RDomain rdom, IRMutator *m) {
std::vector<Range> new_dom(rdom->domain.size());
bool changed = false;
for (size_t i = 0; i < rdom->domain.size(); i++) {
Range r = rdom->domain[i];
Expr new_min = m->mutate(r->min);
Expr new_extent = m->mutate(r->extent);
if (!r->min.same_as(new_min)) changed = true;
if (!r->extent.same_as(new_extent)) changed = true;
new_dom[i] = Range::make_with_min_extent(new_min, new_extent);
}
if (!changed) {
return rdom;
} else {
return RDomain::make(rdom->index, Domain(new_dom));
}
}
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Reduce>([](const Reduce* op, const Expr& e, IRMutator* m) {
RDomain new_rdom = MutateRDom(op->rdom, m);
Expr new_source = m->mutate(op->source);
if (op->rdom.same_as(new_rdom) &&
op->source.same_as(new_source)) {
return e;
} else {
return Reduce::make(op->op, new_source, new_rdom);
}
});
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<IntImm>(ReturnSelfExpr)
.set_dispatch<UIntImm>(ReturnSelfExpr)
.set_dispatch<FloatImm>(ReturnSelfExpr)
.set_dispatch<StringImm>(ReturnSelfExpr)
.set_dispatch<Variable>(ReturnSelfExpr);
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Cast>([](const Cast* op, const Expr& e, IRMutator* m) {
Expr value = m->mutate(op->value);
if (value.same_as(op->value)) {
return e;
} else {
return Cast::make(op->type, value);
}
});
// binary operator
template<typename T>
inline Expr Binary(const T* op, const Expr& e, IRMutator* m) {
Expr a = m->mutate(op->a);
Expr b = m->mutate(op->b);
if (a.same_as(op->a) &&
b.same_as(op->b)) {
return e;
} else {
return T::make(a, b);
}
}
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Add>(Binary<Add>)
.set_dispatch<Sub>(Binary<Sub>)
.set_dispatch<Mul>(Binary<Mul>)
.set_dispatch<Div>(Binary<Div>)
.set_dispatch<Mod>(Binary<Mod>)
.set_dispatch<Min>(Binary<Min>)
.set_dispatch<Max>(Binary<Max>)
.set_dispatch<EQ>(Binary<EQ>)
.set_dispatch<NE>(Binary<NE>)
.set_dispatch<LT>(Binary<LT>)
.set_dispatch<LE>(Binary<LE>)
.set_dispatch<GT>(Binary<GT>)
.set_dispatch<GE>(Binary<GE>)
.set_dispatch<And>(Binary<And>)
.set_dispatch<Or>(Binary<Or>);
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Not>([](const Not* op, const Expr& e, IRMutator* m) {
Expr a = m->mutate(op->a);
if (a.same_as(op->a)) {
return e;
} else {
return Not::make(a);
}
})
.set_dispatch<Select>([](const Select *op, const Expr& e, IRMutator* m) {
Expr cond = m->mutate(op->condition);
Expr t = m->mutate(op->true_value);
Expr f = m->mutate(op->false_value);
if (cond.same_as(op->condition) &&
t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
return e;
} else {
return Select::make(cond, t, f);
}
})
.set_dispatch<Load>([](const Load *op, const Expr& e, IRMutator* m) {
Expr index = m->mutate(op->index);
if (index.same_as(op->index)) {
return e;
} else {
return Load::make(op->type, op->buffer_var, index);
}
})
.set_dispatch<Ramp>([](const Ramp *op, const Expr& e, IRMutator* m) {
Expr base = m->mutate(op->base);
Expr stride = m->mutate(op->stride);
if (base.same_as(op->base) &&
stride.same_as(op->stride)) {
return e;
} else {
return Ramp::make(base, stride, op->lanes);
}
})
.set_dispatch<Broadcast>([](const Broadcast *op, const Expr& e, IRMutator* m) {
Expr value = m->mutate(op->value);
if (value.same_as(op->value)) {
return e;
} else {
return Broadcast::make(value, op->lanes);
}
})
.set_dispatch<Call>([](const Call *op, const Expr& e, IRMutator* m) {
auto new_args = MutateArray(op->args, m);
if (op->args.same_as(new_args)) {
return e;
} else {
return Call::make(op->type, op->name, new_args, op->call_type,
op->func, op->value_index);
}
})
.set_dispatch<Let>([](const Let *op, const Expr& e, IRMutator* m) {
Expr value = m->mutate(op->value);
Expr body = m->mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return e;
} else {
return Let::make(op->var, value, body);
}
});
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.set_dispatch<LetStmt>([](const LetStmt *op, const Stmt& s, IRMutator* m) {
Expr value = m->mutate(op->value);
Stmt body = m->mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
} else {
return LetStmt::make(op->var, value, body);
}
})
.set_dispatch<AssertStmt>([](const AssertStmt *op, const Stmt& s, IRMutator* m) {
Expr condition = m->mutate(op->condition);
Expr message = m->mutate(op->message);
if (condition.same_as(op->condition) && message.same_as(op->message)) {
return s;
} else {
return AssertStmt::make(condition, message);
}
})
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, const Stmt& s, IRMutator* m) {
Stmt body = m->mutate(op->body);
if (body.same_as(op->body)) {
return s;
} else {
return ProducerConsumer::make(op->func, op->is_producer, body);
}
})
.set_dispatch<For>([](const For *op, const Stmt& s, IRMutator* m) {
Expr min = m->mutate(op->min);
Expr extent = m->mutate(op->extent);
Stmt body = m->mutate(op->body);
if (min.same_as(op->min) &&
extent.same_as(op->extent) &&
body.same_as(op->body)) {
return s;
} else {
return For::make(
op->loop_var, min, extent, op->for_type, op->device_api, body);
}
})
.set_dispatch<Store>([](const Store *op, const Stmt& s, IRMutator* m) {
Expr value = m->mutate(op->value);
Expr index = m->mutate(op->index);
if (value.same_as(op->value) && index.same_as(op->index)) {
return s;
} else {
return Store::make(op->buffer_var, value, index);
}
})
.set_dispatch<Provide>([](const Provide *op, const Stmt& s, IRMutator* m) {
auto new_args = MutateArray(op->args, m);
auto new_values = MutateArray(op->values, m);
if (op->args.same_as(new_args) && op->values.same_as(new_values)) {
return s;
} else {
return Provide::make(op->func, new_values, new_args);
}
})
.set_dispatch<Allocate>([](const Allocate *op, const Stmt& s, IRMutator* m) {
std::vector<Expr> new_extents;
bool all_extents_unmodified = true;
for (size_t i = 0; i < op->extents.size(); i++) {
new_extents.push_back(m->mutate(op->extents[i]));
all_extents_unmodified &= new_extents[i].same_as(op->extents[i]);
}
Stmt body = m->mutate(op->body);
Expr condition = m->mutate(op->condition);
Expr new_expr;
if (op->new_expr.defined()) {
new_expr = m->mutate(op->new_expr);
}
if (all_extents_unmodified &&
body.same_as(op->body) &&
condition.same_as(op->condition) &&
new_expr.same_as(op->new_expr)) {
return s;
} else {
return Allocate::make(
op->buffer_var, op->type,
new_extents, condition, body,
new_expr, op->free_function);
}
})
.set_dispatch<Free>([](const Free *op, const Stmt& s, IRMutator* m) {
return s;
})
.set_dispatch<Realize>([](const Realize *op, const Stmt& s, IRMutator* m) {
Region new_bounds;
bool bounds_changed = false;
// Mutate the bounds
for (size_t i = 0; i < op->bounds.size(); i++) {
Expr old_min = op->bounds[i]->min;
Expr old_extent = op->bounds[i]->extent;
Expr new_min = m->mutate(old_min);
Expr new_extent = m->mutate(old_extent);
if (!new_min.same_as(old_min)) bounds_changed = true;
if (!new_extent.same_as(old_extent)) bounds_changed = true;
new_bounds.push_back(
Range::make_by_min_extent(new_min, new_extent));
}
Stmt body = m->mutate(op->body);
Expr condition = m->mutate(op->condition);
if (!bounds_changed &&
body.same_as(op->body) &&
condition.same_as(op->condition)) {
return s;
} else {
return Realize::make(op->func, op->types, new_bounds,
condition, body);
}
})
.set_dispatch<Block>([](const Block *op, const Stmt& s, IRMutator* m) {
Stmt first = m->mutate(op->first);
Stmt rest = m->mutate(op->rest);
if (first.same_as(op->first) &&
rest.same_as(op->rest)) {
return s;
} else {
return Block::make(first, rest);
}
})
.set_dispatch<IfThenElse>([](const IfThenElse *op, const Stmt& s, IRMutator* m) {
Expr condition = m->mutate(op->condition);
Stmt then_case = m->mutate(op->then_case);
Stmt else_case = m->mutate(op->else_case);
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
} else {
return IfThenElse::make(condition, then_case, else_case);
}
})
.set_dispatch<Evaluate>([](const Evaluate *op, const Stmt& s, IRMutator* m) {
Expr v = m->mutate(op->value);
if (v.same_as(op->value)) {
return s;
} else {
return Evaluate::make(v);
}
});
} // namespace
} // namespace ir
} // namespace tvm
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_mutator.h>
namespace {
using namespace tvm::ir;
using namespace Halide::Internal;
using namespace Halide;
// replace variable to constant
class IRVar2Const : public IRMutator {
public:
VarExpr var;
int int_val;
Expr mutate(Expr expr) final {
static const FMutateExpr& f = IRVar2Const::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::mutate(expr));
}
static FMutateExpr &vtable_expr();
};
// implement vtable
IRMutator::FMutateExpr &IRVar2Const::vtable_expr() { // NOLINT(*)
static FMutateExpr inst; return inst;
}
TVM_STATIC_IR_FUNCTOR(IRVar2Const, vtable_expr)
.set_dispatch<Variable>([](const Variable* op, const Expr &e, IRMutator* m) {
IRVar2Const* vm = static_cast<IRVar2Const*>(m);
if (e.same_as(vm->var)) {
return IntImm::make(Int(32), vm->int_val);
} else {
return e;
}
});
} // namespace
TEST(IRMutator, Basic) {
using namespace Halide::Internal;
using namespace tvm;
Var x("x"), y;
auto z = x + y;
IRVar2Const mu;
mu.var = y;
mu.int_val = 10;
auto zz = mu.mutate(z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "(x + 10)");
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment