diff --git a/HalideIR b/HalideIR index ec84af1359c841df622f683048968348381e328a..89b7939957d66a37dd6083ad6b09a5644e73fd8b 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit ec84af1359c841df622f683048968348381e328a +Subproject commit 89b7939957d66a37dd6083ad6b09a5644e73fd8b diff --git a/include/tvm/domain.h b/include/tvm/domain.h index a2c42a31f106721f917b3edfc5dde1591c877d66..634a72b97be8d82779cd23c22aaa8c45aa38d770 100644 --- a/include/tvm/domain.h +++ b/include/tvm/domain.h @@ -36,6 +36,8 @@ class Range : public Halide::IR::Range { * \param end The end of the range. */ Range(Expr begin, Expr end); + + static Range make_with_min_extent(Expr min, Expr extent); }; /*! \brief Domain is a multi-dimensional range */ @@ -74,6 +76,8 @@ class RDomain : public NodeRef { inline Var i0() const { return index(0); } + // low level constructor + static RDomain make(Array<Var> index, Domain domain); }; /*! \brief use RDom as alias of RDomain */ @@ -88,8 +92,8 @@ class RDomainNode : public Node { Domain domain; /*! \brief constructor */ RDomainNode() {} - RDomainNode(Array<Var> && index, Domain && domain) - : index(std::move(index)), domain(std::move(domain)) { + RDomainNode(Array<Var> index, Domain domain) + : index(index), domain(domain) { } const char* type_key() const override { return "RDomain"; diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 936b02bcbf8b0865534825c46e5cbcc21525c9df..3dbc1a852437ddac7e7a9aa646fd5cd86dbcb0bd 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -8,7 +8,7 @@ #include <ir/Expr.h> #include <ir/IROperator.h> -#include <type_traits> +#include <string> #include "./base.h" namespace tvm { @@ -28,7 +28,12 @@ using Halide::select; using Halide::Expr; using Halide::Internal::Stmt; -using Var = Halide::VarExpr; + +class Var : public Halide::VarExpr { + public: + explicit Var(const std::string& name_hint = "v", + Type t = Int(32)) : VarExpr(name_hint, t) {} +}; } // namespace tvm #endif // TVM_EXPR_H_ diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h new file mode 100644 index 0000000000000000000000000000000000000000..e2ada0cac6f392e9debff652dd37f5c864e3b2a1 --- /dev/null +++ b/include/tvm/ir_mutator.h @@ -0,0 +1,83 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file ir_mutator.h + * \brief Defines general IRMutation pass + */ +#ifndef TVM_IR_MUTATOR_H_ +#define TVM_IR_MUTATOR_H_ + +#include <tvm/ir_node.h> +#include "./expr.h" + +namespace tvm { +namespace ir { +/*! + * \brief a base class for mutator to iterative mutate the IR + * + * This IRMutator is implemented via IRFunctor instead of Visitor Pattern. + * This enables easy extensions of possible new IRNode. + * It also makes changing return types easier. + * + * \note If you want to return a different type other than Expr and Stmt, + * Simply following the same pattern as IRMutator and create a seperate class. + * \sa IRFunctor + */ +class IRMutator { + public: + /*! + * \brief mutate expression + * \return the mutated expr + */ + virtual Expr mutate(Expr expr) { + static const FMutateExpr& f = vtable_expr(); + return f(expr, expr, this); + } + /*! + * \brief mutate expression + * \return the mutated stmt + */ + virtual Stmt mutate(Stmt stmt) { + static const FMutateStmt& f = vtable_stmt(); + return f(stmt, stmt, this); + } + /*! \brief destructor */ + virtual ~IRMutator() {} + /*! \brief functor type of expr mutation */ + using FMutateExpr = IRFunctor<Expr(const IRNodeRef&, const Expr&, IRMutator*)>; + /*! \brief functor type of stmt mutation */ + using FMutateStmt = IRFunctor<Stmt(const IRNodeRef&, const Stmt&, IRMutator*)>; + /*! \return internal vtable of expr */ + static FMutateExpr& vtable_expr(); // NOLINT(*) + /*! \return internal stmt of expr */ + static FMutateStmt& vtable_stmt(); // NOLINT(*) +}; + +/*! + * \brief templatized base class of subclass of IRMutator + * + * Use "curiously recurring template pattern" to implement mutate for you. + * Child class need to declare IRMutatorBase<T>::vtable_expr and IRMutatorBase<T>::vtable_stmt + * + * \note This only implement direct subclass from IRMutator, similar code + * can be created to implement deeper subclassing when needed. + */ +class IRMutatorExample : public IRMutator { + public: + Expr mutate(Expr expr) final { + static const FMutateExpr& f = IRMutatorExample::vtable_expr(); + return (f.can_dispatch(expr) ? + f(expr, expr, this) : IRMutator::mutate(expr)); + } + Stmt mutate(Stmt stmt) final { + static const FMutateStmt& f = IRMutatorExample::vtable_stmt(); + return (f.can_dispatch(stmt) ? + f(stmt, stmt, this) : IRMutator::mutate(stmt)); + } + // to be implemented by child class + static FMutateExpr& vtable_expr(); // NOLINT(*) + static FMutateStmt& vtable_stmt(); // NOLINT(*) +}; + +} // namespace ir +} // namespace tvm +#endif // TVM_IR_MUTATOR_H_ diff --git a/src/lang/domain.cc b/src/lang/domain.cc index 27f2e860b2314a4522c31fc05c1666ab8a2f470e..7e88fb8d406e1cea6482c20fd093b5955758b745 100644 --- a/src/lang/domain.cc +++ b/src/lang/domain.cc @@ -12,6 +12,10 @@ Range::Range(Expr begin, Expr end) // TODO(tqchen) add simplify to end - begin } +Range Range::make_with_min_extent(Expr min, Expr extent) { + return Range(std::make_shared<Halide::IR::RangeNode>(min, extent)); +} + RDomain::RDomain(Domain domain) { std::vector<Var> index; for (size_t i = 0; i < domain.size(); ++i) { @@ -24,6 +28,10 @@ RDomain::RDomain(Domain domain) { std::move(idx), std::move(domain)); } +RDomain RDomain::make(Array<Var> index, Domain domain) { + return RDomain(std::make_shared<RDomainNode>(index, domain)); +} + TVM_REGISTER_NODE_TYPE(RDomainNode); } // namespace tvm diff --git a/src/lang/ir.cc b/src/lang/ir.cc index a702d8f5f7b117aa666e7986ee85c4ce9f3360c0..65fa69bf2d95152c02cedf7ced8f7e3e62d418d4 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -20,7 +20,7 @@ namespace Internal { using tvm::ir::Reduce; template<> -void ExprNode<Reduce>::accept(IRVisitor *v) const { +void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const { LOG(FATAL) << "Reduce do not work with IRVisitor yet"; } diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc new file mode 100644 index 0000000000000000000000000000000000000000..264f1c669b6b38349a14dbd2cd9caa583194d32a --- /dev/null +++ b/src/pass/ir_mutator.cc @@ -0,0 +1,337 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file ir_mutator.cc + */ +#include <tvm/ir.h> +#include <tvm/ir_mutator.h> + +namespace tvm { +namespace ir { + +IRMutator::FMutateExpr& IRMutator::vtable_expr() { // NOLINT(*) + static FMutateExpr inst; return inst; +} + +IRMutator::FMutateStmt& IRMutator::vtable_stmt() { // NOLINT(*) + static FMutateStmt inst; return inst; +} + +// namespace to register the functors. +namespace { + +using namespace Halide::Internal; + +// const expr +inline Expr ReturnSelfExpr(const IRNodeRef&, const Expr& e, IRMutator*) { + return e; +} + +inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) { + std::vector<Expr> new_arr(arr.size()); + bool changed = false; + for (size_t i = 0; i < arr.size(); i++) { + Expr old_elem = arr[i]; + Expr new_elem = m->mutate(old_elem); + if (!new_elem.same_as(old_elem)) changed = true; + new_arr[i] = new_elem; + } + if (!changed) { + return arr; + } else { + return Array<Expr>(new_arr); + } +} + +inline RDomain MutateRDom(RDomain rdom, IRMutator *m) { + std::vector<Range> new_dom(rdom->domain.size()); + bool changed = false; + for (size_t i = 0; i < rdom->domain.size(); i++) { + Range r = rdom->domain[i]; + Expr new_min = m->mutate(r->min); + Expr new_extent = m->mutate(r->extent); + if (!r->min.same_as(new_min)) changed = true; + if (!r->extent.same_as(new_extent)) changed = true; + new_dom[i] = Range::make_with_min_extent(new_min, new_extent); + } + if (!changed) { + return rdom; + } else { + return RDomain::make(rdom->index, Domain(new_dom)); + } +} + +TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) +.set_dispatch<Reduce>([](const Reduce* op, const Expr& e, IRMutator* m) { + RDomain new_rdom = MutateRDom(op->rdom, m); + Expr new_source = m->mutate(op->source); + if (op->rdom.same_as(new_rdom) && + op->source.same_as(new_source)) { + return e; + } else { + return Reduce::make(op->op, new_source, new_rdom); + } + }); + +TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) +.set_dispatch<IntImm>(ReturnSelfExpr) +.set_dispatch<UIntImm>(ReturnSelfExpr) +.set_dispatch<FloatImm>(ReturnSelfExpr) +.set_dispatch<StringImm>(ReturnSelfExpr) +.set_dispatch<Variable>(ReturnSelfExpr); + +TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) +.set_dispatch<Cast>([](const Cast* op, const Expr& e, IRMutator* m) { + Expr value = m->mutate(op->value); + if (value.same_as(op->value)) { + return e; + } else { + return Cast::make(op->type, value); + } + }); + +// binary operator +template<typename T> +inline Expr Binary(const T* op, const Expr& e, IRMutator* m) { + Expr a = m->mutate(op->a); + Expr b = m->mutate(op->b); + if (a.same_as(op->a) && + b.same_as(op->b)) { + return e; + } else { + return T::make(a, b); + } +} + +TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) +.set_dispatch<Add>(Binary<Add>) +.set_dispatch<Sub>(Binary<Sub>) +.set_dispatch<Mul>(Binary<Mul>) +.set_dispatch<Div>(Binary<Div>) +.set_dispatch<Mod>(Binary<Mod>) +.set_dispatch<Min>(Binary<Min>) +.set_dispatch<Max>(Binary<Max>) +.set_dispatch<EQ>(Binary<EQ>) +.set_dispatch<NE>(Binary<NE>) +.set_dispatch<LT>(Binary<LT>) +.set_dispatch<LE>(Binary<LE>) +.set_dispatch<GT>(Binary<GT>) +.set_dispatch<GE>(Binary<GE>) +.set_dispatch<And>(Binary<And>) +.set_dispatch<Or>(Binary<Or>); + +TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) +.set_dispatch<Not>([](const Not* op, const Expr& e, IRMutator* m) { + Expr a = m->mutate(op->a); + if (a.same_as(op->a)) { + return e; + } else { + return Not::make(a); + } + }) +.set_dispatch<Select>([](const Select *op, const Expr& e, IRMutator* m) { + Expr cond = m->mutate(op->condition); + Expr t = m->mutate(op->true_value); + Expr f = m->mutate(op->false_value); + if (cond.same_as(op->condition) && + t.same_as(op->true_value) && + f.same_as(op->false_value)) { + return e; + } else { + return Select::make(cond, t, f); + } + }) +.set_dispatch<Load>([](const Load *op, const Expr& e, IRMutator* m) { + Expr index = m->mutate(op->index); + if (index.same_as(op->index)) { + return e; + } else { + return Load::make(op->type, op->buffer_var, index); + } + }) +.set_dispatch<Ramp>([](const Ramp *op, const Expr& e, IRMutator* m) { + Expr base = m->mutate(op->base); + Expr stride = m->mutate(op->stride); + if (base.same_as(op->base) && + stride.same_as(op->stride)) { + return e; + } else { + return Ramp::make(base, stride, op->lanes); + } + }) +.set_dispatch<Broadcast>([](const Broadcast *op, const Expr& e, IRMutator* m) { + Expr value = m->mutate(op->value); + if (value.same_as(op->value)) { + return e; + } else { + return Broadcast::make(value, op->lanes); + } + }) +.set_dispatch<Call>([](const Call *op, const Expr& e, IRMutator* m) { + auto new_args = MutateArray(op->args, m); + if (op->args.same_as(new_args)) { + return e; + } else { + return Call::make(op->type, op->name, new_args, op->call_type, + op->func, op->value_index); + } + }) +.set_dispatch<Let>([](const Let *op, const Expr& e, IRMutator* m) { + Expr value = m->mutate(op->value); + Expr body = m->mutate(op->body); + if (value.same_as(op->value) && + body.same_as(op->body)) { + return e; + } else { + return Let::make(op->var, value, body); + } + }); + +TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) +.set_dispatch<LetStmt>([](const LetStmt *op, const Stmt& s, IRMutator* m) { + Expr value = m->mutate(op->value); + Stmt body = m->mutate(op->body); + if (value.same_as(op->value) && + body.same_as(op->body)) { + return s; + } else { + return LetStmt::make(op->var, value, body); + } + }) +.set_dispatch<AssertStmt>([](const AssertStmt *op, const Stmt& s, IRMutator* m) { + Expr condition = m->mutate(op->condition); + Expr message = m->mutate(op->message); + + if (condition.same_as(op->condition) && message.same_as(op->message)) { + return s; + } else { + return AssertStmt::make(condition, message); + } + }) +.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, const Stmt& s, IRMutator* m) { + Stmt body = m->mutate(op->body); + if (body.same_as(op->body)) { + return s; + } else { + return ProducerConsumer::make(op->func, op->is_producer, body); + } + }) +.set_dispatch<For>([](const For *op, const Stmt& s, IRMutator* m) { + Expr min = m->mutate(op->min); + Expr extent = m->mutate(op->extent); + Stmt body = m->mutate(op->body); + if (min.same_as(op->min) && + extent.same_as(op->extent) && + body.same_as(op->body)) { + return s; + } else { + return For::make( + op->loop_var, min, extent, op->for_type, op->device_api, body); + } + }) +.set_dispatch<Store>([](const Store *op, const Stmt& s, IRMutator* m) { + Expr value = m->mutate(op->value); + Expr index = m->mutate(op->index); + if (value.same_as(op->value) && index.same_as(op->index)) { + return s; + } else { + return Store::make(op->buffer_var, value, index); + } + }) +.set_dispatch<Provide>([](const Provide *op, const Stmt& s, IRMutator* m) { + auto new_args = MutateArray(op->args, m); + auto new_values = MutateArray(op->values, m); + if (op->args.same_as(new_args) && op->values.same_as(new_values)) { + return s; + } else { + return Provide::make(op->func, new_values, new_args); + } + }) +.set_dispatch<Allocate>([](const Allocate *op, const Stmt& s, IRMutator* m) { + std::vector<Expr> new_extents; + bool all_extents_unmodified = true; + for (size_t i = 0; i < op->extents.size(); i++) { + new_extents.push_back(m->mutate(op->extents[i])); + all_extents_unmodified &= new_extents[i].same_as(op->extents[i]); + } + Stmt body = m->mutate(op->body); + Expr condition = m->mutate(op->condition); + Expr new_expr; + if (op->new_expr.defined()) { + new_expr = m->mutate(op->new_expr); + } + if (all_extents_unmodified && + body.same_as(op->body) && + condition.same_as(op->condition) && + new_expr.same_as(op->new_expr)) { + return s; + } else { + return Allocate::make( + op->buffer_var, op->type, + new_extents, condition, body, + new_expr, op->free_function); + } + }) +.set_dispatch<Free>([](const Free *op, const Stmt& s, IRMutator* m) { + return s; + }) +.set_dispatch<Realize>([](const Realize *op, const Stmt& s, IRMutator* m) { + Region new_bounds; + bool bounds_changed = false; + + // Mutate the bounds + for (size_t i = 0; i < op->bounds.size(); i++) { + Expr old_min = op->bounds[i]->min; + Expr old_extent = op->bounds[i]->extent; + Expr new_min = m->mutate(old_min); + Expr new_extent = m->mutate(old_extent); + if (!new_min.same_as(old_min)) bounds_changed = true; + if (!new_extent.same_as(old_extent)) bounds_changed = true; + new_bounds.push_back( + Range::make_by_min_extent(new_min, new_extent)); + } + + Stmt body = m->mutate(op->body); + Expr condition = m->mutate(op->condition); + if (!bounds_changed && + body.same_as(op->body) && + condition.same_as(op->condition)) { + return s; + } else { + return Realize::make(op->func, op->types, new_bounds, + condition, body); + } + }) +.set_dispatch<Block>([](const Block *op, const Stmt& s, IRMutator* m) { + Stmt first = m->mutate(op->first); + Stmt rest = m->mutate(op->rest); + if (first.same_as(op->first) && + rest.same_as(op->rest)) { + return s; + } else { + return Block::make(first, rest); + } + }) +.set_dispatch<IfThenElse>([](const IfThenElse *op, const Stmt& s, IRMutator* m) { + Expr condition = m->mutate(op->condition); + Stmt then_case = m->mutate(op->then_case); + Stmt else_case = m->mutate(op->else_case); + if (condition.same_as(op->condition) && + then_case.same_as(op->then_case) && + else_case.same_as(op->else_case)) { + return s; + } else { + return IfThenElse::make(condition, then_case, else_case); + } + }) +.set_dispatch<Evaluate>([](const Evaluate *op, const Stmt& s, IRMutator* m) { + Expr v = m->mutate(op->value); + if (v.same_as(op->value)) { + return s; + } else { + return Evaluate::make(v); + } + }); + +} // namespace +} // namespace ir +} // namespace tvm diff --git a/tests/cpp/ir_mutator_test.cc b/tests/cpp/ir_mutator_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..93843d7ff209d31fba9265828e07293c3ca39712 --- /dev/null +++ b/tests/cpp/ir_mutator_test.cc @@ -0,0 +1,59 @@ +#include <dmlc/logging.h> +#include <gtest/gtest.h> +#include <tvm/tvm.h> +#include <tvm/ir_mutator.h> + +namespace { +using namespace tvm::ir; +using namespace Halide::Internal; +using namespace Halide; + +// replace variable to constant +class IRVar2Const : public IRMutator { + public: + VarExpr var; + int int_val; + Expr mutate(Expr expr) final { + static const FMutateExpr& f = IRVar2Const::vtable_expr(); + return (f.can_dispatch(expr) ? + f(expr, expr, this) : IRMutator::mutate(expr)); + } + static FMutateExpr &vtable_expr(); +}; + +// implement vtable +IRMutator::FMutateExpr &IRVar2Const::vtable_expr() { // NOLINT(*) + static FMutateExpr inst; return inst; +} + +TVM_STATIC_IR_FUNCTOR(IRVar2Const, vtable_expr) +.set_dispatch<Variable>([](const Variable* op, const Expr &e, IRMutator* m) { + IRVar2Const* vm = static_cast<IRVar2Const*>(m); + if (e.same_as(vm->var)) { + return IntImm::make(Int(32), vm->int_val); + } else { + return e; + } + }); + +} // namespace + +TEST(IRMutator, Basic) { + using namespace Halide::Internal; + using namespace tvm; + Var x("x"), y; + auto z = x + y; + IRVar2Const mu; + mu.var = y; + mu.int_val = 10; + auto zz = mu.mutate(z); + std::ostringstream os; + os << zz; + CHECK(os.str() == "(x + 10)"); +} + +int main(int argc, char ** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +}