From 8a66ac230f4c92ff3b5b6fcb8dec4aa7ec8e6eb4 Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Tue, 4 Jul 2017 21:53:15 -0700 Subject: [PATCH] [PASS/OP/REFACTOR] IRDeepCompare, isolate computeop part, allow fuzzy bind (#218) --- include/tvm/ir_functor_ext.h | 1 + include/tvm/ir_pass.h | 24 +- src/op/compute_op.cc | 233 +++---------- src/op/compute_op.h | 68 ++++ src/op/cross_thread_reduction.cc | 120 +++++++ src/op/op_util.cc | 12 +- src/op/op_util.h | 10 + src/pass/arg_binder.cc | 31 +- src/pass/arg_binder.h | 4 +- src/pass/ir_deep_compare.cc | 417 +++++++++++++++++++++++ src/pass/storage_flatten.cc | 2 +- tests/python/unittest/test_pass_equal.py | 48 +++ 12 files changed, 775 insertions(+), 195 deletions(-) create mode 100644 src/op/compute_op.h create mode 100644 src/op/cross_thread_reduction.cc create mode 100644 src/pass/ir_deep_compare.cc create mode 100644 tests/python/unittest/test_pass_equal.py diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index 6feb75566..55368fbea 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -137,6 +137,7 @@ class ExprFunctor<R(const Expr& n, Args...)> { virtual R VisitExpr_(const Select* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Ramp* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Broadcast* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const Shuffle* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const IntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloatImm* op, Args... 
args) EXPR_FUNCTOR_DEFAULT; diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 714733a19..872fca353 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -23,14 +23,6 @@ namespace tvm { namespace ir { -inline bool Equal(Expr a, Expr b) { - return Halide::Internal::equal(a, b); -} - -inline bool Equal(Stmt a, Stmt b) { - return Halide::Internal::equal(a, b); -} - inline Expr Simplify(Expr a) { return Halide::Internal::simplify(a); } @@ -39,6 +31,22 @@ inline Stmt Simplify(Stmt a) { return Halide::Internal::simplify(a); } +/*! + * \brief Deep compare lhs and rhs + * \param lhs The left operand + * \param rhs The right operand + * \return The comparison result. + */ +bool Equal(const Expr& lhs, const Expr& rhs); + +/*! + * \brief Deep compare lhs and rhs + * \param lhs The left operand + * \param rhs The right operand + * \return The comparison result. + */ +bool Equal(const Stmt& lhs, const Stmt& rhs); + /*! * \brief verifies whether the IR stmt or Expr is in SSA form. * That is: each VarExpr is defined and assigned once(in Let/For) diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index a4bb99e1b..abb83b6ec 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -9,6 +9,7 @@ #include <tvm/ir_visitor.h> #include <tvm/ir_pass.h> #include <unordered_set> +#include "./compute_op.h" #include "./op_util.h" #include "../schedule/message_passing.h" @@ -242,124 +243,6 @@ void MakeReduction(const ComputeOpNode* op, } } -Stmt Substitute(Stmt s, - const std::unordered_map<IterVar, Expr>& value_map) { - Map<Var, Expr> temp; - for (const auto& kv : value_map) { - temp.Set(kv.first->var, kv.second); - } - return ir::Substitute(s, temp); -} - -// Cross Thread reduction -bool IsCrossThreadReduction(const ComputeOpNode* self, - const Stage& stage) { - // Verify correctness of leaf nest. 
- int normal_red = 0, thread_red = 0; - for (IterVar iv : stage->leaf_iter_vars) { - if (iv->iter_type == kCommReduce) { - auto it = stage->iter_var_attrs.find(iv); - if (it != stage->iter_var_attrs.end() && - (*it).second->bind_thread.defined()) { - ++thread_red; - } else { - ++normal_red; - } - } else { - CHECK_EQ(thread_red, 0) - << "Cross thread reduce cannot swap with normal data axis"; - } - } - CHECK(normal_red == 0 || thread_red == 0) - << "Cannot mix normal reduction with thread reduce"; - return thread_red != 0; -} - -Stmt MakeCrossThreadReduction( - const ComputeOpNode* self, - const Stage& stage, - const std::unordered_map<IterVar, Range>& dom_map) { - Array<Expr> args; - for (IterVar iv : self->axis) { - args.push_back(iv->var); - } - std::unordered_map<IterVar, Expr> value_map; - auto nest = op::MakeLoopNest( - stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map); - auto conds = op::MakeBoundCheck( - stage, dom_map, false, - std::unordered_set<IterVar>(), value_map); - - size_t size = self->body.size(); - CHECK_GT(size, 0); - std::vector<const Reduce*> reduces(size); - for (size_t i = 0; i < size; ++i) { - const Reduce* reduce = self->body[i].as<Reduce>(); - CHECK(reduce); - reduces[i] = reduce; - } - Expr cond = reduces[0]->condition; - for (Expr v : conds) { - cond = cond && v; - } - Array<Expr> freduce_args; - freduce_args.push_back(make_const(UInt(32), static_cast<uint32_t>(size))); - for (size_t i = 0; i < size; ++i) { - freduce_args.push_back(reduces[0]->source[i]); - } - freduce_args.push_back(cond); - std::vector<Var> res_handles(size); - for (size_t idx = 0; idx < size; ++idx) { - res_handles[idx] = Var("reduce_temp" + std::to_string(idx), Handle()); - freduce_args.push_back(res_handles[idx]); - } - - for (IterVar iv : stage->leaf_iter_vars) { - if (iv->iter_type == kCommReduce) { - auto it = stage->iter_var_attrs.find(iv); - if (it != stage->iter_var_attrs.end() && - (*it).second->bind_thread.defined()) { - IterVar tv = 
(*it).second->bind_thread; - freduce_args.push_back(tv->var); - } - } - } - // Checks for the thread. - std::vector<Expr> thread_head_check; - if (stage->store_predicate.defined()) { - thread_head_check.emplace_back(stage->store_predicate); - } - - Stmt reduce_body = Evaluate::make(Call::make( - Handle(), - ir::intrinsic::tvm_thread_allreduce, - freduce_args, Call::Intrinsic)); - reduce_body = AttrStmt::make( - reduces[0]->combiner, - attr::reduce_scope, - make_zero(Handle()), - reduce_body); - std::vector<Stmt> assigns(size); - for (size_t idx = 0; idx < size; ++idx) { - Type t = reduces[idx]->type; - assigns[idx] = Provide::make( - stage->op, idx, - Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args); - } - Stmt assign_body = Block::make(assigns); - assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body); - assign_body = MergeNest(op::MakeIfNest(conds), assign_body); - Stmt body = Block::make(reduce_body, assign_body); - for (size_t idx = size; idx != 0; --idx) { - body = Allocate::make( - res_handles[idx - 1], reduces[idx - 1]->type, {1}, const_true(), body); - body = AttrStmt::make( - res_handles[idx - 1], attr::storage_scope, StringImm::make("local"), body); - } - body = Substitute(body, value_map); - return MergeNest(nest, body); -} - // Normal computation. Stmt MakeProvide(const ComputeOpNode* op, const Tensor& t) { @@ -370,27 +253,56 @@ Stmt MakeProvide(const ComputeOpNode* op, return Provide::make(t->op, t->value_index, op->body[t->value_index], args); } -// loop nest structure for general compute -// This the the loop nest structured used in compute. -// Does not include the loop body. -struct ComputeLoopNest { - // The common number of loops between init and main - size_t num_common_loop; - // predicates for the initialize loop - std::vector<Expr> init_predicates; - // Initialization nest involved. 
- std::vector<std::vector<Stmt> > init_nest; - // Value map for the init code - std::unordered_map<IterVar, Expr> init_vmap; - // Predicates for the main update loop - std::vector<Expr> main_predicates; - // The general loop nest - std::vector<std::vector<Stmt> > main_nest; - // Value map for the IterVar. - std::unordered_map<IterVar, Expr> main_vmap; -}; +Stmt MakeComputeStmt(const ComputeOpNode* self, + const Stage& stage, + const std::unordered_map<IterVar, Range>& dom_map) { + // grab the nest structure + ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map); + // Normal loop structure + n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates)); + n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates)); + if (self->reduce_axis.size() != 0) { + // make reduction. + Stmt init, provide; + Array<Tensor> source; + for (size_t i = 0; i < self->body.size(); ++i) { + source.push_back(stage->op.output(i)); + } + MakeReduction(self, source, &init, &provide); + init = op::Substitute(init, n.init_vmap); + init = MergeNest(n.init_nest, init); + // common nest + std::vector<std::vector<Stmt> > common( + n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); + std::vector<std::vector<Stmt> > reduce( + n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end()); + provide = op::Substitute(provide, n.main_vmap); + provide = MergeNest(reduce, provide); + return MergeNest(common, Block::make(init, provide)); + } else { + std::vector<Stmt> provides; + for (size_t i = 0; i < self->body.size(); ++i) { + provides.emplace_back(MakeProvide(self, stage->op.output(i))); + } + Stmt provide = op::Substitute(Block::make(provides), n.main_vmap); + return MergeNest(n.main_nest, provide); + } +} -ComputeLoopNest MakeComputeLoopNest( +// implement the provide utility. 
+Stmt ComputeOpNode::BuildProvide( + const Stage& stage, + const std::unordered_map<IterVar, Range>& dom_map) const { + CHECK_EQ(stage->op.operator->(), this); + if (IsCrossThreadReduction(this, stage)) { + // specially handle cross thread reduction. + return MakeCrossThreadReduction(this, stage, dom_map); + } else { + return MakeComputeStmt(this, stage, dom_map); + } +} + +ComputeLoopNest ComputeLoopNest::make( const ComputeOpNode* self, const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map) { @@ -446,51 +358,10 @@ ComputeLoopNest MakeComputeLoopNest( e = likely(e); } } else { - ret.num_common_loop = ret.main_nest.size() - 1; + CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1); + ret.num_common_loop = stage->leaf_iter_vars.size(); } // copy elison here. return ret; } - -// implement the provide utility. -Stmt ComputeOpNode::BuildProvide( - const Stage& stage, - const std::unordered_map<IterVar, Range>& dom_map) const { - CHECK_EQ(stage->op.operator->(), this); - if (IsCrossThreadReduction(this, stage)) { - // specially handle cross thread reduction. - return MakeCrossThreadReduction(this, stage, dom_map); - } - // grab the nest structure - ComputeLoopNest n = MakeComputeLoopNest(this, stage, dom_map); - // Normal loop structure - n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates)); - n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates)); - if (this->reduce_axis.size() != 0) { - // make reduction. 
- Stmt init, provide; - Array<Tensor> source; - for (size_t i = 0; i < this->body.size(); ++i) { - source.push_back(stage->op.output(i)); - } - MakeReduction(this, source, &init, &provide); - init = Substitute(init, n.init_vmap); - init = MergeNest(n.init_nest, init); - // common nest - std::vector<std::vector<Stmt> > common( - n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); - std::vector<std::vector<Stmt> > reduce( - n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end()); - provide = Substitute(provide, n.main_vmap); - provide = MergeNest(reduce, provide); - return MergeNest(common, Block::make(init, provide)); - } else { - std::vector<Stmt> provides; - for (size_t i = 0; i < this->body.size(); ++i) { - provides.emplace_back(MakeProvide(this, stage->op.output(i))); - } - Stmt provide = Substitute(Block::make(provides), n.main_vmap); - return MergeNest(n.main_nest, provide); - } -} } // namespace tvm diff --git a/src/op/compute_op.h b/src/op/compute_op.h new file mode 100644 index 000000000..79b1954a7 --- /dev/null +++ b/src/op/compute_op.h @@ -0,0 +1,68 @@ +/*! + * Copyright (c) 2017 by Contributors + * \brief Helper utilities to implement compute_op. + * \file compute_op.h + */ +#ifndef TVM_OP_COMPUTE_OP_H_ +#define TVM_OP_COMPUTE_OP_H_ + +#include <tvm/ir.h> +#include <tvm/expr.h> +#include <tvm/operation.h> +#include <vector> +#include <unordered_map> + +namespace tvm { +// loop nest structure for general compute +// This the the loop nest structured used in compute. +// Does not include the loop body. +struct ComputeLoopNest { + // The common number of loops between init and main + size_t num_common_loop; + // predicates for the initialize loop + std::vector<Expr> init_predicates; + // Initialization nest involved. 
+ std::vector<std::vector<Stmt> > init_nest; + // Value map for the init code + std::unordered_map<IterVar, Expr> init_vmap; + // Predicates for the main update loop + std::vector<Expr> main_predicates; + // The general loop nest + std::vector<std::vector<Stmt> > main_nest; + // Value map for the IterVar. + std::unordered_map<IterVar, Expr> main_vmap; + + /*! + * \brief constructor to build ComputeOpNest + * \param self The pointer to compute op. + * \param stage The scxhedule stage. + * \param dom_map The domain map. + * \return The constructed loop nest + */ + static ComputeLoopNest make( + const ComputeOpNode* self, + const Stage& stage, + const std::unordered_map<IterVar, Range>& dom_map); +}; + +/*! + * \brief Whether compute op is a cross thread reduction structure. + * \param self The pointer to ComputeOpNode + * \param stage the schedule stage. + */ +bool IsCrossThreadReduction(const ComputeOpNode* self, + const Stage& stage); +/*! + * \brief Build body of compute for cross thread reduction pattern. + * \param self The pointer to ComputeOpNode + * \param stage The schedule stage. + * \param dom_map The domain map. + * \return The created statement. + */ +Stmt MakeCrossThreadReduction( + const ComputeOpNode* self, + const Stage& stage, + const std::unordered_map<IterVar, Range>& dom_map); +} // namespace tvm + +#endif // TVM_OP_COMPUTE_OP_H_ diff --git a/src/op/cross_thread_reduction.cc b/src/op/cross_thread_reduction.cc new file mode 100644 index 000000000..2a8091414 --- /dev/null +++ b/src/op/cross_thread_reduction.cc @@ -0,0 +1,120 @@ +/*! + * Copyright (c) 2017 by Contributors + * \brief Logics related to cross thread reduction, used by ComputeOpNode. + * \file cross_thread_reduction.cc + */ +#include <tvm/ir_pass.h> +#include "./compute_op.h" +#include "./op_util.h" + +namespace tvm { +using namespace ir; + +bool IsCrossThreadReduction(const ComputeOpNode* self, + const Stage& stage) { + // Verify correctness of leaf nest. 
+ int normal_red = 0, thread_red = 0; + for (IterVar iv : stage->leaf_iter_vars) { + if (iv->iter_type == kCommReduce) { + auto it = stage->iter_var_attrs.find(iv); + if (it != stage->iter_var_attrs.end() && + (*it).second->bind_thread.defined()) { + ++thread_red; + } else { + ++normal_red; + } + } else { + CHECK_EQ(thread_red, 0) + << "Cross thread reduce cannot swap with normal data axis"; + } + } + CHECK(normal_red == 0 || thread_red == 0) + << "Cannot mix normal reduction with thread reduce"; + return thread_red != 0; +} + +Stmt MakeCrossThreadReduction( + const ComputeOpNode* self, + const Stage& stage, + const std::unordered_map<IterVar, Range>& dom_map) { + Array<Expr> args; + for (IterVar iv : self->axis) { + args.push_back(iv->var); + } + std::unordered_map<IterVar, Expr> value_map; + auto nest = op::MakeLoopNest( + stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map); + auto conds = op::MakeBoundCheck( + stage, dom_map, false, + std::unordered_set<IterVar>(), value_map); + + size_t size = self->body.size(); + CHECK_GT(size, 0); + std::vector<const Reduce*> reduces(size); + for (size_t i = 0; i < size; ++i) { + const Reduce* reduce = self->body[i].as<Reduce>(); + CHECK(reduce); + reduces[i] = reduce; + } + Expr cond = reduces[0]->condition; + for (Expr v : conds) { + cond = cond && v; + } + Array<Expr> freduce_args; + freduce_args.push_back(make_const(UInt(32), static_cast<uint32_t>(size))); + for (size_t i = 0; i < size; ++i) { + freduce_args.push_back(reduces[0]->source[i]); + } + freduce_args.push_back(cond); + std::vector<Var> res_handles(size); + for (size_t idx = 0; idx < size; ++idx) { + res_handles[idx] = Var("reduce_temp" + std::to_string(idx), Handle()); + freduce_args.push_back(res_handles[idx]); + } + + for (IterVar iv : stage->leaf_iter_vars) { + if (iv->iter_type == kCommReduce) { + auto it = stage->iter_var_attrs.find(iv); + if (it != stage->iter_var_attrs.end() && + (*it).second->bind_thread.defined()) { + IterVar tv = 
(*it).second->bind_thread; + freduce_args.push_back(tv->var); + } + } + } + // Checks for the thread. + std::vector<Expr> thread_head_check; + if (stage->store_predicate.defined()) { + thread_head_check.emplace_back(stage->store_predicate); + } + + Stmt reduce_body = Evaluate::make(Call::make( + Handle(), + ir::intrinsic::tvm_thread_allreduce, + freduce_args, Call::Intrinsic)); + reduce_body = AttrStmt::make( + reduces[0]->combiner, + attr::reduce_scope, + make_zero(Handle()), + reduce_body); + std::vector<Stmt> assigns(size); + for (size_t idx = 0; idx < size; ++idx) { + Type t = reduces[idx]->type; + assigns[idx] = Provide::make( + stage->op, idx, + Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args); + } + Stmt assign_body = Block::make(assigns); + assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body); + assign_body = MergeNest(op::MakeIfNest(conds), assign_body); + Stmt body = Block::make(reduce_body, assign_body); + for (size_t idx = size; idx != 0; --idx) { + body = Allocate::make( + res_handles[idx - 1], reduces[idx - 1]->type, {1}, const_true(), body); + body = AttrStmt::make( + res_handles[idx - 1], attr::storage_scope, StringImm::make("local"), body); + } + body = op::Substitute(body, value_map); + return MergeNest(nest, body); +} +} // namespace tvm diff --git a/src/op/op_util.cc b/src/op/op_util.cc index 628c714df..fe597a0cc 100644 --- a/src/op/op_util.cc +++ b/src/op/op_util.cc @@ -223,7 +223,6 @@ std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) { } - // replacer to replace tensors class TensorReplacer : public ir::IRMutator { public: @@ -263,5 +262,16 @@ Expr ReplaceTensor(Expr expr, Expr ret = repl.Mutate(expr); return repl.found ? 
ret : expr; } + + +Stmt Substitute(Stmt s, + const std::unordered_map<IterVar, Expr>& value_map) { + std::unordered_map<const Variable*, Expr> init; + for (const auto& kv : value_map) { + init[kv.first->var.get()] = kv.second; + } + return ir::Substitute(s, init); +} + } // namespace op } // namespace tvm diff --git a/src/op/op_util.h b/src/op/op_util.h index 914815f9a..419035b67 100644 --- a/src/op/op_util.h +++ b/src/op/op_util.h @@ -12,6 +12,7 @@ #include <unordered_set> #include <vector> #include "../pass/ir_util.h" +#include "../pass/arg_binder.h" namespace tvm { namespace op { @@ -74,6 +75,15 @@ Stmt ReplaceTensor(Stmt stmt, Expr ReplaceTensor(Expr expr, const std::unordered_map<Tensor, Tensor>& replace); +/*! + * \brief Substitute the variables of stmt by value map. + * \param stmt the statment + * \param value_map The value map. + * \return Substituted result. + */ +Stmt Substitute(Stmt stmt, + const std::unordered_map<IterVar, Expr>& value_map); + } // namespace op } // namespace tvm #endif // TVM_OP_OP_UTIL_H_ diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index 4ac7998d6..69e376260 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -75,13 +75,38 @@ void ArgBinder::BindArray(const Array<Expr>& arg, void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, - const std::string& arg_name) { + const std::string& arg_name, + bool fuzzy_match) { CHECK_EQ(arg->scope, value->scope) << "Argument " << arg_name << " Buffer bind scope mismatch"; this->Bind(arg->data, value->data, arg_name + ".data"); - this->BindArray(arg->shape, value->shape, arg_name + ".shape"); - this->BindArray(arg->strides, value->strides, arg_name + ".strides"); + if (arg->shape.size() > value->shape.size()) { + CHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch"; + size_t diff = arg->shape.size() - value->shape.size(); + for (size_t i = 0; i < diff; ++i) { + CHECK(is_one(arg->shape[i])) + << "Argument " << arg_name << " shape mismatch" + 
<< arg->shape << " vs " << value->shape; + } + for (size_t i = 0; i < value->shape.size(); ++i) { + std::ostringstream os; + os << arg_name << ".shape[" << i << "]"; + this->Bind(arg->shape[i + diff], value->shape[i], os.str()); + } + if (arg->strides.size() != 0) { + CHECK_EQ(arg->strides.size(), arg->shape.size()); + CHECK_EQ(value->strides.size(), value->shape.size()); + for (size_t i = 0; i < value->strides.size(); ++i) { + std::ostringstream os; + os << arg_name << ".strides[" << i << "]"; + this->Bind(arg->strides[i + diff], value->strides[i], os.str()); + } + } + } else { + this->BindArray(arg->shape, value->shape, arg_name + ".shape"); + this->BindArray(arg->strides, value->strides, arg_name + ".strides"); + } this->Bind(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset"); } diff --git a/src/pass/arg_binder.h b/src/pass/arg_binder.h index 59e4eab55..6d6e6e7ca 100644 --- a/src/pass/arg_binder.h +++ b/src/pass/arg_binder.h @@ -71,10 +71,12 @@ class ArgBinder { * \param arg The argument to be binded. * \param value The target expression value * \param arg_name argument name. + * \param fuzzy_match If enabled, we allow value's dimension to be smaller than arg, as long as arg's higher dimensions are of 1. */ void BindBuffer(const Buffer& arg, const Buffer& value, - const std::string& arg_name); + const std::string& arg_name, + bool fuzzy_match); /*! * \brief Bind symbolic buffer to a DLTensor handle. * \param buffer The argument buffer to be binded. diff --git a/src/pass/ir_deep_compare.cc b/src/pass/ir_deep_compare.cc new file mode 100644 index 000000000..48656a41f --- /dev/null +++ b/src/pass/ir_deep_compare.cc @@ -0,0 +1,417 @@ +/*! 
+ * Copyright (c) 2017 by Contributors + * \file ir_deep_compare.cc + */ +#include <tvm/ir_pass.h> +#include <tvm/ir_functor_ext.h> + +namespace tvm { +namespace ir { + +using ExprComparator = ExprFunctor<void(const Expr& n, const Expr &other)>; +using StmtComparator = StmtFunctor<void(const Stmt& n, const Stmt &other)>; + +#define DEFINE_BIOP_EXPR_CMP_(OP) \ + void VisitExpr_(const OP* op, const Expr& other) final { \ + const OP* rhs = other.as<OP>(); \ + if (CompareExpr(op->a, rhs->a) != 0) return; \ + if (CompareExpr(op->b, rhs->b) != 0) return; \ + } + +// Deep comparison to check if two IR graph are equivalent +class IRDeepCompare : + public ExprComparator, public StmtComparator { + public: + // Equality comparison + bool Equal(const Stmt& lhs, const Stmt& rhs) { + tie_def_ = true; + VisitStmt(lhs, rhs); + return order_ == 0; + } + + bool Equal(const Expr& lhs, const Expr& rhs) { + tie_def_ = true; + VisitExpr(lhs, rhs); + return order_ == 0; + } + + void VisitExpr(const Expr& n, const Expr& other) override { + if (order_ != 0) return; + if (CompareValue(n->type_index(), other->type_index()) != 0) return; + if (CompareType(n.type(), other.type()) != 0) return; + ExprComparator::VisitExpr(n, other); + } + + void VisitStmt(const Stmt& n, const Stmt& other) override { + if (order_ != 0) return; + if (CompareValue(n->type_index(), other->type_index()) != 0) return; + StmtComparator::VisitStmt(n, other); + } + // Stmt + void VisitStmt_(const LetStmt* op, const Stmt& other) final { + const LetStmt* rhs = other.as<LetStmt>(); + if (CompareExpr(op->value, rhs->value) != 0) return; + if (tie_def_) { + vmap_[op->var.get()] = rhs->var.get(); + } else { + if (CompareExpr(op->var, rhs->var) != 0) return; + } + if (CompareStmt(op->body, rhs->body) != 0) return; + } + + void VisitStmt_(const AttrStmt* op, const Stmt& other) final { + const AttrStmt* rhs = other.as<AttrStmt>(); + if (CompareString(op->attr_key, rhs->attr_key) != 0) return; + if (CompareNodeRef(op->node, 
rhs->node) != 0) return; + if (CompareExpr(op->value, rhs->value) != 0) return; + if (CompareStmt(op->body, rhs->body) != 0) return; + } + + void VisitStmt_(const IfThenElse* op, const Stmt& other) final { + const IfThenElse* rhs = other.as<IfThenElse>(); + if (CompareExpr(op->condition, rhs->condition) != 0) return; + if (CompareStmt(op->then_case, rhs->then_case) != 0) return; + if (CompareStmt(op->else_case, rhs->else_case) != 0) return; + } + + void VisitStmt_(const For* op, const Stmt& other) final { + const For* rhs = other.as<For>(); + if (CompareExpr(op->min, rhs->min) != 0) return; + if (CompareExpr(op->extent, rhs->extent) != 0) return; + if (tie_def_) { + vmap_[op->loop_var.get()] = rhs->loop_var.get(); + } else { + if (CompareExpr(op->loop_var, rhs->loop_var) != 0) return; + } + if (CompareStmt(op->body, rhs->body) != 0) return; + } + + void VisitStmt_(const Allocate* op, const Stmt& other) final { + const Allocate* rhs = other.as<Allocate>(); + if (tie_def_) { + vmap_[op->buffer_var.get()] = rhs->buffer_var.get(); + } else { + if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return; + } + if (CompareType(op->type, rhs->type) != 0) return; + if (CompareArray(op->extents, rhs->extents) != 0) return; + if (CompareExpr(op->condition, rhs->condition) != 0) return; + if (CompareStmt(op->body, rhs->body) != 0) return; + if (CompareExpr(op->new_expr, rhs->new_expr) != 0) return; + if (CompareString(op->free_function, rhs->free_function) != 0) return; + } + + void VisitStmt_(const Store* op, const Stmt& other) final { + const Store* rhs = other.as<Store>(); + if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return; + if (CompareExpr(op->value, rhs->value) != 0) return; + if (CompareExpr(op->index, rhs->index) != 0) return; + if (CompareExpr(op->predicate, rhs->predicate) != 0) return; + } + + void VisitStmt_(const Free* op, const Stmt& other) final { + const Free* rhs = other.as<Free>(); + if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) 
return; + } + + void VisitStmt_(const AssertStmt* op, const Stmt& other) final { + const AssertStmt* rhs = other.as<AssertStmt>(); + if (CompareExpr(op->condition, rhs->condition) != 0) return; + if (CompareExpr(op->message, rhs->message) != 0) return; + } + + void VisitStmt_(const ProducerConsumer* op, const Stmt& other) final { + const ProducerConsumer* rhs = other.as<ProducerConsumer>(); + if (CompareNodeRef(op->func, rhs->func) != 0) return; + if (CompareValue(op->is_producer, rhs->is_producer) != 0) return; + if (CompareStmt(op->body, rhs->body) != 0) return; + } + + + void VisitStmt_(const Provide* op, const Stmt& other) final { + const Provide* rhs = other.as<Provide>(); + if (CompareNodeRef(op->func, rhs->func) != 0) return; + if (CompareValue(op->value_index, rhs->value_index) != 0) return; + if (CompareExpr(op->value, rhs->value) != 0) return; + if (CompareArray(op->args, rhs->args) != 0) return; + } + + void VisitStmt_(const Realize* op, const Stmt& other) final { + const Realize* rhs = other.as<Realize>(); + if (CompareNodeRef(op->func, rhs->func) != 0) return; + if (CompareValue(op->value_index, rhs->value_index) != 0) return; + if (CompareType(op->type, rhs->type) != 0) return; + if (CompareRegion(op->bounds, rhs->bounds) != 0) return; + if (CompareStmt(op->body, rhs->body) != 0) return; + } + + void VisitStmt_(const Prefetch* op, const Stmt& other) final { + const Prefetch* rhs = other.as<Prefetch>(); + if (CompareNodeRef(op->func, rhs->func) != 0) return; + if (CompareValue(op->value_index, rhs->value_index) != 0) return; + if (CompareType(op->type, rhs->type) != 0) return; + if (CompareRegion(op->bounds, rhs->bounds) != 0) return; + } + + void VisitStmt_(const Block* op, const Stmt& other) final { + const Block* rhs = other.as<Block>(); + if (CompareStmt(op->first, rhs->first) != 0) return; + if (CompareStmt(op->rest, rhs->rest) != 0) return; + } + + void VisitStmt_(const Evaluate* op, const Stmt& other) final { + const Evaluate* rhs = 
other.as<Evaluate>(); + CompareExpr(op->value, rhs->value); + } + + // Exprs + void VisitExpr_(const Variable* op, const Expr& other) final { + const Variable* rhs = other.as<Variable>(); + auto it = vmap_.find(op); + if (it != vmap_.end()) op = it->second; + if (op < rhs) { + order_ = -1; + } else if (op > rhs) { + order_ = +1; + } + } + void VisitExpr_(const Load* op, const Expr& other) final { + const Load* rhs = other.as<Load>(); + if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return; + if (CompareExpr(op->index, rhs->index) != 0) return; + if (CompareExpr(op->predicate, rhs->predicate) != 0) return; + } + + void VisitExpr_(const Let* op, const Expr& other) final { + const Let* rhs = other.as<Let>(); + if (tie_def_) { + vmap_[op->var.get()] = rhs->var.get(); + } else { + if (CompareExpr(op->var, rhs->var) != 0) return; + } + if (CompareExpr(op->value, rhs->value) != 0) return; + if (CompareExpr(op->body, rhs->body) != 0) return; + } + + void VisitExpr_(const Call* op, const Expr& other) final { + const Call* rhs = other.as<Call>(); + if (CompareString(op->name, rhs->name)) return; + if (CompareArray(op->args, rhs->args)) return; + if (CompareValue(op->call_type, rhs->call_type) != 0) return; + if (CompareNodeRef(op->func, rhs->func) != 0) return; + if (CompareValue(op->value_index, rhs->value_index) != 0) return; + } + + void VisitExpr_(const Reduce *op, const Expr& other) final { + const Reduce* rhs = other.as<Reduce>(); + if (CompareCommReducer(op->combiner, rhs->combiner) != 0) return; + if (CompareValue(op->axis.size(), rhs->axis.size()) != 0) return; + if (CompareValue(op->value_index, rhs->value_index) != 0) return; + for (size_t i = 0; i < op->axis.size(); ++i) { + if (CompareExpr(op->axis[i]->dom->min, rhs->axis[i]->dom->min) != 0) return; + if (CompareExpr(op->axis[i]->dom->extent, rhs->axis[i]->dom->extent) != 0) return; + if (tie_def_) { + vmap_[op->axis[i]->var.get()] = rhs->axis[i]->var.get(); + } else { + if (CompareExpr(op->axis[i]->var, 
rhs->axis[i]->var) != 0) return; + } + } + if (CompareExpr(op->condition, rhs->condition) != 0) return; + if (CompareArray(op->source, rhs->source) != 0) return; + } + + void VisitExpr_(const IntImm *op, const Expr& other) final { + CompareValue(op->value, other.as<IntImm>()->value); + } + + void VisitExpr_(const UIntImm *op, const Expr& other) final { + CompareValue(op->value, other.as<UIntImm>()->value); + } + + void VisitExpr_(const FloatImm *op, const Expr& other) final { + CompareValue(op->value, other.as<FloatImm>()->value); + } + + void VisitExpr_(const StringImm *op, const Expr& other) final { + CompareString(op->value, other.as<StringImm>()->value); + } + + void VisitExpr_(const Cast *op, const Expr& other) final { + CompareExpr(op->value, other.as<Cast>()->value); + } + + void VisitExpr_(const Not *op, const Expr& other) final { + CompareExpr(op->a, other.as<Not>()->a); + } + + void VisitExpr_(const Select *op, const Expr& other) final { + const Select* rhs = other.as<Select>(); + if (CompareExpr(op->condition, rhs->condition) != 0) return; + if (CompareExpr(op->true_value, rhs->true_value) != 0) return; + if (CompareExpr(op->false_value, rhs->false_value) != 0) return; + } + + void VisitExpr_(const Ramp *op, const Expr& other) final { + const Ramp* rhs = other.as<Ramp>(); + if (CompareExpr(op->base, rhs->base) != 0) return; + if (CompareExpr(op->stride, rhs->stride) != 0) return; + if (CompareValue(op->lanes, rhs->lanes) != 0) return; + } + + void VisitExpr_(const Broadcast *op, const Expr& other) final { + const Broadcast* rhs = other.as<Broadcast>(); + if (CompareExpr(op->value, rhs->value) != 0) return; + if (CompareValue(op->lanes, rhs->lanes) != 0) return; + } + + void VisitExpr_(const Shuffle *op, const Expr& other) final { + const Shuffle* rhs = other.as<Shuffle>(); + if (CompareArray(op->vectors, rhs->vectors) != 0) return; + if (CompareArray(op->indices, rhs->indices) != 0) return; + } + + DEFINE_BIOP_EXPR_CMP_(Add) + DEFINE_BIOP_EXPR_CMP_(Sub) 
+ DEFINE_BIOP_EXPR_CMP_(Mul) + DEFINE_BIOP_EXPR_CMP_(Div) + DEFINE_BIOP_EXPR_CMP_(Mod) + DEFINE_BIOP_EXPR_CMP_(Min) + DEFINE_BIOP_EXPR_CMP_(Max) + DEFINE_BIOP_EXPR_CMP_(EQ) + DEFINE_BIOP_EXPR_CMP_(NE) + DEFINE_BIOP_EXPR_CMP_(LT) + DEFINE_BIOP_EXPR_CMP_(LE) + DEFINE_BIOP_EXPR_CMP_(GT) + DEFINE_BIOP_EXPR_CMP_(GE) + DEFINE_BIOP_EXPR_CMP_(And) + DEFINE_BIOP_EXPR_CMP_(Or) + + private: + int CompareExpr(const Expr& lhs, const Expr& rhs) { + if (order_ != 0) return order_; + if (!lhs.defined() && rhs.defined()) { + order_ = -1; return order_; + } + if (!rhs.defined() && lhs.defined()) { + order_ = +1; return order_; + } + VisitExpr(lhs, rhs); + return order_; + } + + int CompareStmt(const Stmt& lhs, const Stmt& rhs) { + if (order_ != 0) return order_; + if (!lhs.defined() && rhs.defined()) { + order_ = -1; return order_; + } + if (!rhs.defined() && lhs.defined()) { + order_ = +1; return order_; + } + VisitStmt(lhs, rhs); + return order_; + } + + int CompareArray(const Array<Expr>& lhs, const Array<Expr>& rhs) { + if (order_ != 0) return order_; + if (CompareValue(lhs.size(), rhs.size()) != 0) return order_; + for (size_t i = 0; i < lhs.size(); ++i) { + if (CompareExpr(lhs[i], rhs[i]) != 0) return order_; + } + return order_; + } + + int CompareRegion(const Halide::Internal::Region& lhs, + const Halide::Internal::Region& rhs) { + if (order_ != 0) return order_; + if (CompareValue(lhs.size(), rhs.size()) != 0) return order_; + for (size_t i = 0; i < lhs.size(); ++i) { + if (CompareExpr(lhs[i]->min, rhs[i]->min) != 0) return order_; + if (CompareExpr(lhs[i]->extent, rhs[i]->extent) != 0) return order_; + } + return order_; + } + + int CompareNodeRef(const NodeRef& lhs, const NodeRef& rhs) { + if (order_ != 0) return order_; + if (lhs.get() < rhs.get()) { + order_ = -1; return order_; + } + if (lhs.get() > rhs.get()) { + order_ = +1; return order_; + } + return order_; + } + + int CompareType(const Type& lhs, const Type& rhs) { + if (order_ != 0) return order_; + if (lhs == 
rhs) return order_; + if (CompareValue(lhs.code(), rhs.code()) != 0) return order_; + if (CompareValue(lhs.bits(), rhs.bits()) != 0) return order_; + if (CompareValue(lhs.lanes(), rhs.lanes()) != 0) return order_; + return order_; + } + + int CompareString(const std::string& lhs, const std::string& rhs) { + if (order_ != 0) return order_; + order_ = lhs.compare(rhs); + return order_; + } + + template<typename T> + int CompareValue(const T& lhs, const T& rhs) { + if (order_ != 0) return order_; + if (lhs < rhs) { + order_ = -1; return order_; + } else if (lhs > rhs) { + order_ = +1; return order_; + } + return order_; + } + + int CompareCommReducer(const CommReducer& lhs, const CommReducer& rhs) { + if (order_ != 0) return order_; + if (lhs == rhs) return order_; + if (CompareValue(lhs->lhs.size(), rhs->lhs.size()) != 0) return order_; + if (CompareValue(lhs->rhs.size(), rhs->rhs.size()) != 0) return order_; + IRDeepCompare cmp; + if (tie_def_) { + for (size_t i = 0; i < lhs->lhs.size(); ++i) { + cmp.vmap_[lhs->lhs[i].get()] = rhs->lhs[i].get(); + } + for (size_t i = 0; i < lhs->rhs.size(); ++i) { + cmp.vmap_[lhs->rhs[i].get()] = rhs->rhs[i].get(); + } + } else { + for (size_t i = 0; i < lhs->lhs.size(); ++i) { + if (CompareExpr(lhs->lhs[i], rhs->lhs[i]) != 0) return order_; + } + for (size_t i = 0; i < lhs->lhs.size(); ++i) { + if (CompareExpr(lhs->rhs[i], rhs->rhs[i]) != 0) return order_; + } + } + order_ = cmp.CompareArray(lhs->result, rhs->result); + return order_; + } + // The order flag, smaller, -1, bigger: +1, equal: 0 + int order_{0}; + // Whether to tie intermediate definitions. + // This allows us to tie definitions of two variables together. + // This enables us to assert equal between (let x in x + 1), (let y in y + 1) + // However, the comparison is no longer in total order. + // Only equality/non-equality information is valid.
+ bool tie_def_{false}; + // variable remap if any + std::unordered_map<const Variable*, const Variable*> vmap_; +}; + + +bool Equal(const Stmt& lhs, const Stmt& rhs) { + return IRDeepCompare().Equal(lhs, rhs); +} + +bool Equal(const Expr& lhs, const Expr& rhs) { + return IRDeepCompare().Equal(lhs, rhs); +} + +} // namespace ir +} // namespace tvm diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc index a33e03496..d7ac0ac6a 100644 --- a/src/pass/storage_flatten.cc +++ b/src/pass/storage_flatten.cc @@ -195,7 +195,7 @@ class StorageFlattener : public IRMutator { } // start binding ArgBinder binder(&var_remap_); - binder.BindBuffer(Buffer(arr[0].node_), slice, buffer->name); + binder.BindBuffer(Buffer(arr[0].node_), slice, buffer->name, true); // Apply the remaps Stmt body = MergeNest(binder.asserts(), op->body); body = MergeNest(binder.init_nest(), body); diff --git a/tests/python/unittest/test_pass_equal.py b/tests/python/unittest/test_pass_equal.py new file mode 100644 index 000000000..1c13b82ea --- /dev/null +++ b/tests/python/unittest/test_pass_equal.py @@ -0,0 +1,48 @@ +import tvm + +def test_equal_expr(): + x = tvm.var('x') + y = tvm.var('y') + + def func1(): + return x + y + 1 + + def func2(): + return tvm.exp((x + y + 1) * y / 4) + + assert tvm.ir_pass.Equal(func1(), func1()) + assert tvm.ir_pass.Equal(func2(), func2()) + assert not tvm.ir_pass.Equal(func2(), func1()) + + +def test_equal_compute(): + x = tvm.var('x') + y = tvm.var('y') + n = 128 + A = tvm.placeholder((n, n), name='A') + B = tvm.placeholder((n, n), name='B') + ii = tvm.var('i') + jj = tvm.var('j') + + def func1(): + k = tvm.reduce_axis((0, n), name='k') + return tvm.sum(A[ii, k] * B[jj, k], axis=k) + + Ab = tvm.decl_buffer((n,), name='A') + n = tvm.var("n") + def func2(): + ib = tvm.ir_builder.create() + A = ib.buffer_ptr(Ab) + with ib.for_range(0, n, name="i") as i: + A[i] = A[i] + 1 + with ib.for_range(0, 10, name="j") as j: + A[j] = A[j] + 2 + return ib.get() + + assert
tvm.ir_pass.Equal(func1(), func1()) + assert tvm.ir_pass.Equal(func2(), func2()) + + +if __name__ == "__main__": + test_equal_expr() + test_equal_compute() -- GitLab