From 8a66ac230f4c92ff3b5b6fcb8dec4aa7ec8e6eb4 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Tue, 4 Jul 2017 21:53:15 -0700
Subject: [PATCH] [PASS/OP/REFACTOR] IRDeepCompare, isolate computeop part,
 allow fuzzy bind (#218)

---
 include/tvm/ir_functor_ext.h             |   1 +
 include/tvm/ir_pass.h                    |  24 +-
 src/op/compute_op.cc                     | 233 +++----------
 src/op/compute_op.h                      |  68 ++++
 src/op/cross_thread_reduction.cc         | 120 +++++++
 src/op/op_util.cc                        |  12 +-
 src/op/op_util.h                         |  10 +
 src/pass/arg_binder.cc                   |  31 +-
 src/pass/arg_binder.h                    |   4 +-
 src/pass/ir_deep_compare.cc              | 417 +++++++++++++++++++++++
 src/pass/storage_flatten.cc              |   2 +-
 tests/python/unittest/test_pass_equal.py |  48 +++
 12 files changed, 775 insertions(+), 195 deletions(-)
 create mode 100644 src/op/compute_op.h
 create mode 100644 src/op/cross_thread_reduction.cc
 create mode 100644 src/pass/ir_deep_compare.cc
 create mode 100644 tests/python/unittest/test_pass_equal.py

diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h
index 6feb75566..55368fbea 100644
--- a/include/tvm/ir_functor_ext.h
+++ b/include/tvm/ir_functor_ext.h
@@ -137,6 +137,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
   virtual R VisitExpr_(const Select* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const Ramp* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const Broadcast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const Shuffle* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const IntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index 714733a19..872fca353 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -23,14 +23,6 @@
 namespace tvm {
 namespace ir {
 
-inline bool Equal(Expr a, Expr b) {
-  return Halide::Internal::equal(a, b);
-}
-
-inline bool Equal(Stmt a, Stmt b) {
-  return Halide::Internal::equal(a, b);
-}
-
 inline Expr Simplify(Expr a) {
   return Halide::Internal::simplify(a);
 }
@@ -39,6 +31,22 @@ inline Stmt Simplify(Stmt a) {
   return Halide::Internal::simplify(a);
 }
 
+/*!
+ * \brief Deep compare lhs and rhs
+ * \param lhs The left operand
+ * \param rhs The right operand
+ * \return The comparison result.
+ */
+bool Equal(const Expr& lhs, const Expr& rhs);
+
+/*!
+ * \brief Deep compare lhs and rhs
+ * \param lhs The left operand
+ * \param rhs The right operand
+ * \return The comparison result.
+ */
+bool Equal(const Stmt& lhs, const Stmt& rhs);
+
 /*!
  * \brief verifies whether the IR stmt or Expr is in SSA form.
  *  That is: each VarExpr is defined and assigned once(in Let/For)
diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index a4bb99e1b..abb83b6ec 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -9,6 +9,7 @@
 #include <tvm/ir_visitor.h>
 #include <tvm/ir_pass.h>
 #include <unordered_set>
+#include "./compute_op.h"
 #include "./op_util.h"
 #include "../schedule/message_passing.h"
 
@@ -242,124 +243,6 @@ void MakeReduction(const ComputeOpNode* op,
   }
 }
 
-Stmt Substitute(Stmt s,
-                const std::unordered_map<IterVar, Expr>& value_map) {
-  Map<Var, Expr> temp;
-  for (const auto& kv : value_map) {
-    temp.Set(kv.first->var, kv.second);
-  }
-  return ir::Substitute(s, temp);
-}
-
-// Cross Thread reduction
-bool IsCrossThreadReduction(const ComputeOpNode* self,
-                            const Stage& stage) {
-  // Verify correctness of leaf nest.
-  int normal_red = 0, thread_red = 0;
-  for (IterVar iv : stage->leaf_iter_vars) {
-    if (iv->iter_type == kCommReduce) {
-      auto it = stage->iter_var_attrs.find(iv);
-      if (it != stage->iter_var_attrs.end() &&
-          (*it).second->bind_thread.defined()) {
-        ++thread_red;
-      } else {
-        ++normal_red;
-      }
-    } else {
-      CHECK_EQ(thread_red, 0)
-          << "Cross thread reduce cannot swap with normal data axis";
-    }
-  }
-  CHECK(normal_red == 0 || thread_red == 0)
-      << "Cannot mix normal reduction with thread reduce";
-  return thread_red != 0;
-}
-
-Stmt MakeCrossThreadReduction(
-    const ComputeOpNode* self,
-    const Stage& stage,
-    const std::unordered_map<IterVar, Range>& dom_map) {
-  Array<Expr>  args;
-  for (IterVar iv : self->axis) {
-    args.push_back(iv->var);
-  }
-  std::unordered_map<IterVar, Expr> value_map;
-  auto nest = op::MakeLoopNest(
-      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
-  auto conds = op::MakeBoundCheck(
-      stage, dom_map, false,
-      std::unordered_set<IterVar>(), value_map);
-
-  size_t size = self->body.size();
-  CHECK_GT(size, 0);
-  std::vector<const Reduce*> reduces(size);
-  for (size_t i = 0; i < size; ++i) {
-    const Reduce* reduce = self->body[i].as<Reduce>();
-    CHECK(reduce);
-    reduces[i] = reduce;
-  }
-  Expr cond = reduces[0]->condition;
-  for (Expr v : conds) {
-    cond = cond && v;
-  }
-  Array<Expr> freduce_args;
-  freduce_args.push_back(make_const(UInt(32), static_cast<uint32_t>(size)));
-  for (size_t i = 0; i < size; ++i) {
-    freduce_args.push_back(reduces[0]->source[i]);
-  }
-  freduce_args.push_back(cond);
-  std::vector<Var> res_handles(size);
-  for (size_t idx = 0; idx < size; ++idx) {
-    res_handles[idx] = Var("reduce_temp" + std::to_string(idx), Handle());
-    freduce_args.push_back(res_handles[idx]);
-  }
-
-  for (IterVar iv : stage->leaf_iter_vars) {
-    if (iv->iter_type == kCommReduce) {
-      auto it = stage->iter_var_attrs.find(iv);
-      if (it != stage->iter_var_attrs.end() &&
-          (*it).second->bind_thread.defined()) {
-        IterVar tv = (*it).second->bind_thread;
-        freduce_args.push_back(tv->var);
-      }
-    }
-  }
-  // Checks for the thread.
-  std::vector<Expr> thread_head_check;
-  if (stage->store_predicate.defined()) {
-    thread_head_check.emplace_back(stage->store_predicate);
-  }
-
-  Stmt reduce_body = Evaluate::make(Call::make(
-      Handle(),
-      ir::intrinsic::tvm_thread_allreduce,
-      freduce_args, Call::Intrinsic));
-  reduce_body = AttrStmt::make(
-      reduces[0]->combiner,
-      attr::reduce_scope,
-      make_zero(Handle()),
-      reduce_body);
-  std::vector<Stmt> assigns(size);
-  for (size_t idx = 0; idx < size; ++idx) {
-    Type t = reduces[idx]->type;
-    assigns[idx] = Provide::make(
-      stage->op, idx,
-      Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
-  }
-  Stmt assign_body = Block::make(assigns);
-  assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
-  assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
-  Stmt body = Block::make(reduce_body, assign_body);
-  for (size_t idx = size; idx != 0; --idx) {
-    body = Allocate::make(
-      res_handles[idx - 1], reduces[idx - 1]->type, {1}, const_true(), body);
-    body = AttrStmt::make(
-      res_handles[idx - 1], attr::storage_scope, StringImm::make("local"), body);
-  }
-  body = Substitute(body, value_map);
-  return MergeNest(nest, body);
-}
-
 // Normal computation.
 Stmt MakeProvide(const ComputeOpNode* op,
                  const Tensor& t) {
@@ -370,27 +253,56 @@ Stmt MakeProvide(const ComputeOpNode* op,
   return Provide::make(t->op, t->value_index, op->body[t->value_index], args);
 }
 
-// loop nest structure for general compute
-// This the the loop nest structured used in compute.
-// Does not include the loop body.
-struct ComputeLoopNest {
-  // The common number of loops between init and main
-  size_t num_common_loop;
-  // predicates for the initialize loop
-  std::vector<Expr> init_predicates;
-  // Initialization nest involved.
-  std::vector<std::vector<Stmt> > init_nest;
-  // Value map for the init code
-  std::unordered_map<IterVar, Expr> init_vmap;
-  // Predicates for the main update loop
-  std::vector<Expr> main_predicates;
-  // The general loop nest
-  std::vector<std::vector<Stmt> > main_nest;
-  // Value map for the IterVar.
-  std::unordered_map<IterVar, Expr> main_vmap;
-};
+Stmt MakeComputeStmt(const ComputeOpNode* self,
+                     const Stage& stage,
+                     const std::unordered_map<IterVar, Range>& dom_map) {
+  // grab the nest structure
+  ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map);
+  // Normal loop structure
+  n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
+  n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates));
+  if (self->reduce_axis.size() != 0) {
+    // make reduction.
+    Stmt init, provide;
+    Array<Tensor> source;
+    for (size_t i = 0; i < self->body.size(); ++i) {
+      source.push_back(stage->op.output(i));
+    }
+    MakeReduction(self, source, &init, &provide);
+    init = op::Substitute(init, n.init_vmap);
+    init = MergeNest(n.init_nest, init);
+    // common nest
+    std::vector<std::vector<Stmt> > common(
+        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
+    std::vector<std::vector<Stmt> > reduce(
+        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
+    provide = op::Substitute(provide, n.main_vmap);
+    provide = MergeNest(reduce, provide);
+    return MergeNest(common, Block::make(init, provide));
+  } else {
+    std::vector<Stmt> provides;
+    for (size_t i = 0; i < self->body.size(); ++i) {
+      provides.emplace_back(MakeProvide(self, stage->op.output(i)));
+    }
+    Stmt provide = op::Substitute(Block::make(provides), n.main_vmap);
+    return MergeNest(n.main_nest, provide);
+  }
+}
 
-ComputeLoopNest MakeComputeLoopNest(
+// implement the provide utility.
+Stmt ComputeOpNode::BuildProvide(
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map) const {
+  CHECK_EQ(stage->op.operator->(), this);
+  if (IsCrossThreadReduction(this, stage)) {
+    // specially handle cross thread reduction.
+    return MakeCrossThreadReduction(this, stage, dom_map);
+  } else {
+    return MakeComputeStmt(this, stage, dom_map);
+  }
+}
+
+ComputeLoopNest ComputeLoopNest::make(
     const ComputeOpNode* self,
     const Stage& stage,
     const std::unordered_map<IterVar, Range>& dom_map) {
@@ -446,51 +358,10 @@ ComputeLoopNest MakeComputeLoopNest(
       e = likely(e);
     }
   } else {
-    ret.num_common_loop = ret.main_nest.size() - 1;
+    CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
+    ret.num_common_loop = stage->leaf_iter_vars.size();
   }
   // copy elison here.
   return ret;
 }
-
-// implement the provide utility.
-Stmt ComputeOpNode::BuildProvide(
-    const Stage& stage,
-    const std::unordered_map<IterVar, Range>& dom_map) const {
-  CHECK_EQ(stage->op.operator->(), this);
-  if (IsCrossThreadReduction(this, stage)) {
-    // specially handle cross thread reduction.
-    return MakeCrossThreadReduction(this, stage, dom_map);
-  }
-  // grab the nest structure
-  ComputeLoopNest n = MakeComputeLoopNest(this, stage, dom_map);
-  // Normal loop structure
-  n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
-  n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates));
-  if (this->reduce_axis.size() != 0) {
-    // make reduction.
-    Stmt init, provide;
-    Array<Tensor> source;
-    for (size_t i = 0; i < this->body.size(); ++i) {
-      source.push_back(stage->op.output(i));
-    }
-    MakeReduction(this, source, &init, &provide);
-    init = Substitute(init, n.init_vmap);
-    init = MergeNest(n.init_nest, init);
-    // common nest
-    std::vector<std::vector<Stmt> > common(
-        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
-    std::vector<std::vector<Stmt> > reduce(
-        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
-    provide = Substitute(provide, n.main_vmap);
-    provide = MergeNest(reduce, provide);
-    return MergeNest(common, Block::make(init, provide));
-  } else {
-    std::vector<Stmt> provides;
-    for (size_t i = 0; i < this->body.size(); ++i) {
-      provides.emplace_back(MakeProvide(this, stage->op.output(i)));
-    }
-    Stmt provide = Substitute(Block::make(provides), n.main_vmap);
-    return MergeNest(n.main_nest, provide);
-  }
-}
 }  // namespace tvm
diff --git a/src/op/compute_op.h b/src/op/compute_op.h
new file mode 100644
index 000000000..79b1954a7
--- /dev/null
+++ b/src/op/compute_op.h
@@ -0,0 +1,68 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \brief Helper utilities to implement compute_op.
+ * \file compute_op.h
+ */
+#ifndef TVM_OP_COMPUTE_OP_H_
+#define TVM_OP_COMPUTE_OP_H_
+
+#include <tvm/ir.h>
+#include <tvm/expr.h>
+#include <tvm/operation.h>
+#include <vector>
+#include <unordered_map>
+
+namespace tvm {
+// loop nest structure for general compute
+// This is the loop nest structure used in compute.
+// Does not include the loop body.
+struct ComputeLoopNest {
+  // The common number of loops between init and main
+  size_t num_common_loop;
+  // predicates for the initialize loop
+  std::vector<Expr> init_predicates;
+  // Initialization nest involved.
+  std::vector<std::vector<Stmt> > init_nest;
+  // Value map for the init code
+  std::unordered_map<IterVar, Expr> init_vmap;
+  // Predicates for the main update loop
+  std::vector<Expr> main_predicates;
+  // The general loop nest
+  std::vector<std::vector<Stmt> > main_nest;
+  // Value map for the IterVar.
+  std::unordered_map<IterVar, Expr> main_vmap;
+
+  /*!
+   * \brief constructor to build ComputeLoopNest
+   * \param self The pointer to compute op.
+   * \param stage The schedule stage.
+   * \param dom_map The domain map.
+   * \return The constructed loop nest
+   */
+  static ComputeLoopNest make(
+      const ComputeOpNode* self,
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& dom_map);
+};
+
+/*!
+ * \brief Whether compute op is a cross thread reduction structure.
+ * \param self The pointer to ComputeOpNode
+ * \param stage the schedule stage.
+ */
+bool IsCrossThreadReduction(const ComputeOpNode* self,
+                            const Stage& stage);
+/*!
+ * \brief Build body of compute for cross thread reduction pattern.
+ * \param self The pointer to ComputeOpNode
+ * \param stage The schedule stage.
+ * \param dom_map The domain map.
+ * \return The created statement.
+ */
+Stmt MakeCrossThreadReduction(
+    const ComputeOpNode* self,
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map);
+}  // namespace tvm
+
+#endif  // TVM_OP_COMPUTE_OP_H_
diff --git a/src/op/cross_thread_reduction.cc b/src/op/cross_thread_reduction.cc
new file mode 100644
index 000000000..2a8091414
--- /dev/null
+++ b/src/op/cross_thread_reduction.cc
@@ -0,0 +1,120 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \brief Logics related to cross thread reduction, used by ComputeOpNode.
+ * \file cross_thread_reduction.cc
+ */
+#include <tvm/ir_pass.h>
+#include "./compute_op.h"
+#include "./op_util.h"
+
+namespace tvm {
+using namespace ir;
+
+bool IsCrossThreadReduction(const ComputeOpNode* self,
+                            const Stage& stage) {
+  // Verify correctness of leaf nest.
+  int normal_red = 0, thread_red = 0;
+  for (IterVar iv : stage->leaf_iter_vars) {
+    if (iv->iter_type == kCommReduce) {
+      auto it = stage->iter_var_attrs.find(iv);
+      if (it != stage->iter_var_attrs.end() &&
+          (*it).second->bind_thread.defined()) {
+        ++thread_red;
+      } else {
+        ++normal_red;
+      }
+    } else {
+      CHECK_EQ(thread_red, 0)
+          << "Cross thread reduce cannot swap with normal data axis";
+    }
+  }
+  CHECK(normal_red == 0 || thread_red == 0)
+      << "Cannot mix normal reduction with thread reduce";
+  return thread_red != 0;
+}
+
+Stmt MakeCrossThreadReduction(
+    const ComputeOpNode* self,
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map) {
+  Array<Expr>  args;
+  for (IterVar iv : self->axis) {
+    args.push_back(iv->var);
+  }
+  std::unordered_map<IterVar, Expr> value_map;
+  auto nest = op::MakeLoopNest(
+      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
+  auto conds = op::MakeBoundCheck(
+      stage, dom_map, false,
+      std::unordered_set<IterVar>(), value_map);
+
+  size_t size = self->body.size();
+  CHECK_GT(size, 0);
+  std::vector<const Reduce*> reduces(size);
+  for (size_t i = 0; i < size; ++i) {
+    const Reduce* reduce = self->body[i].as<Reduce>();
+    CHECK(reduce);
+    reduces[i] = reduce;
+  }
+  Expr cond = reduces[0]->condition;
+  for (Expr v : conds) {
+    cond = cond && v;
+  }
+  Array<Expr> freduce_args;
+  freduce_args.push_back(make_const(UInt(32), static_cast<uint32_t>(size)));
+  for (size_t i = 0; i < size; ++i) {
+    freduce_args.push_back(reduces[0]->source[i]);
+  }
+  freduce_args.push_back(cond);
+  std::vector<Var> res_handles(size);
+  for (size_t idx = 0; idx < size; ++idx) {
+    res_handles[idx] = Var("reduce_temp" + std::to_string(idx), Handle());
+    freduce_args.push_back(res_handles[idx]);
+  }
+
+  for (IterVar iv : stage->leaf_iter_vars) {
+    if (iv->iter_type == kCommReduce) {
+      auto it = stage->iter_var_attrs.find(iv);
+      if (it != stage->iter_var_attrs.end() &&
+          (*it).second->bind_thread.defined()) {
+        IterVar tv = (*it).second->bind_thread;
+        freduce_args.push_back(tv->var);
+      }
+    }
+  }
+  // Checks for the thread.
+  std::vector<Expr> thread_head_check;
+  if (stage->store_predicate.defined()) {
+    thread_head_check.emplace_back(stage->store_predicate);
+  }
+
+  Stmt reduce_body = Evaluate::make(Call::make(
+      Handle(),
+      ir::intrinsic::tvm_thread_allreduce,
+      freduce_args, Call::Intrinsic));
+  reduce_body = AttrStmt::make(
+      reduces[0]->combiner,
+      attr::reduce_scope,
+      make_zero(Handle()),
+      reduce_body);
+  std::vector<Stmt> assigns(size);
+  for (size_t idx = 0; idx < size; ++idx) {
+    Type t = reduces[idx]->type;
+    assigns[idx] = Provide::make(
+      stage->op, idx,
+      Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
+  }
+  Stmt assign_body = Block::make(assigns);
+  assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
+  assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
+  Stmt body = Block::make(reduce_body, assign_body);
+  for (size_t idx = size; idx != 0; --idx) {
+    body = Allocate::make(
+      res_handles[idx - 1], reduces[idx - 1]->type, {1}, const_true(), body);
+    body = AttrStmt::make(
+      res_handles[idx - 1], attr::storage_scope, StringImm::make("local"), body);
+  }
+  body = op::Substitute(body, value_map);
+  return MergeNest(nest, body);
+}
+}  // namespace tvm
diff --git a/src/op/op_util.cc b/src/op/op_util.cc
index 628c714df..fe597a0cc 100644
--- a/src/op/op_util.cc
+++ b/src/op/op_util.cc
@@ -223,7 +223,6 @@ std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
 }
 
 
-
 // replacer to replace tensors
 class TensorReplacer : public ir::IRMutator {
  public:
@@ -263,5 +262,16 @@ Expr ReplaceTensor(Expr expr,
   Expr ret = repl.Mutate(expr);
   return repl.found ? ret : expr;
 }
+
+
+Stmt Substitute(Stmt s,
+                const std::unordered_map<IterVar, Expr>& value_map) {
+  std::unordered_map<const Variable*, Expr> init;
+  for (const auto& kv : value_map) {
+    init[kv.first->var.get()] = kv.second;
+  }
+  return ir::Substitute(s, init);
+}
+
 }  // namespace op
 }  // namespace tvm
diff --git a/src/op/op_util.h b/src/op/op_util.h
index 914815f9a..419035b67 100644
--- a/src/op/op_util.h
+++ b/src/op/op_util.h
@@ -12,6 +12,7 @@
 #include <unordered_set>
 #include <vector>
 #include "../pass/ir_util.h"
+#include "../pass/arg_binder.h"
 
 namespace tvm {
 namespace op {
@@ -74,6 +75,15 @@ Stmt ReplaceTensor(Stmt stmt,
 Expr ReplaceTensor(Expr expr,
                    const std::unordered_map<Tensor, Tensor>& replace);
 
+/*!
+ * \brief Substitute the variables of stmt by value map.
+ * \param stmt the statement
+ * \param value_map The value map.
+ * \return Substituted result.
+ */
+Stmt Substitute(Stmt stmt,
+                const std::unordered_map<IterVar, Expr>& value_map);
+
 }  // namespace op
 }  // namespace tvm
 #endif  // TVM_OP_OP_UTIL_H_
diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc
index 4ac7998d6..69e376260 100644
--- a/src/pass/arg_binder.cc
+++ b/src/pass/arg_binder.cc
@@ -75,13 +75,38 @@ void ArgBinder::BindArray(const Array<Expr>& arg,
 
 void ArgBinder::BindBuffer(const Buffer& arg,
                            const Buffer& value,
-                           const std::string& arg_name) {
+                           const std::string& arg_name,
+                           bool fuzzy_match) {
   CHECK_EQ(arg->scope, value->scope)
       << "Argument " << arg_name
       << " Buffer bind scope mismatch";
   this->Bind(arg->data, value->data, arg_name + ".data");
-  this->BindArray(arg->shape, value->shape, arg_name + ".shape");
-  this->BindArray(arg->strides, value->strides, arg_name + ".strides");
+  if (arg->shape.size() > value->shape.size()) {
+    CHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch";
+    size_t diff = arg->shape.size() - value->shape.size();
+    for (size_t i = 0; i < diff; ++i) {
+      CHECK(is_one(arg->shape[i]))
+          << "Argument " << arg_name << " shape mismatch"
+          << arg->shape << " vs " << value->shape;
+    }
+    for (size_t i = 0; i < value->shape.size(); ++i) {
+      std::ostringstream os;
+      os << arg_name << ".shape[" << i << "]";
+      this->Bind(arg->shape[i + diff], value->shape[i], os.str());
+    }
+    if (arg->strides.size() != 0) {
+      CHECK_EQ(arg->strides.size(), arg->shape.size());
+      CHECK_EQ(value->strides.size(), value->shape.size());
+      for (size_t i = 0; i < value->strides.size(); ++i) {
+        std::ostringstream os;
+        os << arg_name << ".strides[" << i << "]";
+        this->Bind(arg->strides[i + diff], value->strides[i], os.str());
+      }
+    }
+  } else {
+    this->BindArray(arg->shape, value->shape, arg_name + ".shape");
+    this->BindArray(arg->strides, value->strides, arg_name + ".strides");
+  }
   this->Bind(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset");
 }
 
diff --git a/src/pass/arg_binder.h b/src/pass/arg_binder.h
index 59e4eab55..6d6e6e7ca 100644
--- a/src/pass/arg_binder.h
+++ b/src/pass/arg_binder.h
@@ -71,10 +71,12 @@ class ArgBinder {
    * \param arg The argument to be binded.
    * \param value The target expression value
    * \param arg_name argument name.
+   * \param fuzzy_match If enabled, we allow value's dimension to be smaller than arg, as long as arg's extra leading dimensions are 1.
    */
   void BindBuffer(const Buffer& arg,
                   const Buffer& value,
-                  const std::string& arg_name);
+                  const std::string& arg_name,
+                  bool fuzzy_match);
   /*!
    * \brief Bind symbolic buffer to a DLTensor handle.
    * \param buffer The argument buffer to be binded.
diff --git a/src/pass/ir_deep_compare.cc b/src/pass/ir_deep_compare.cc
new file mode 100644
index 000000000..48656a41f
--- /dev/null
+++ b/src/pass/ir_deep_compare.cc
@@ -0,0 +1,417 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file ir_deep_compare.cc
+ */
+#include <tvm/ir_pass.h>
+#include <tvm/ir_functor_ext.h>
+
+namespace tvm {
+namespace ir {
+
+using ExprComparator = ExprFunctor<void(const Expr& n, const Expr &other)>;
+using StmtComparator = StmtFunctor<void(const Stmt& n, const Stmt &other)>;
+
+#define DEFINE_BIOP_EXPR_CMP_(OP)                                 \
+  void VisitExpr_(const OP* op, const Expr& other) final {        \
+    const OP* rhs = other.as<OP>();                               \
+    if (CompareExpr(op->a, rhs->a) != 0) return;                      \
+    if (CompareExpr(op->b, rhs->b) != 0) return;                      \
+  }
+
+// Deep comparison to check if two IR graph are equivalent
+class IRDeepCompare :
+      public ExprComparator, public StmtComparator {
+ public:
+  // Equality comparison
+  bool Equal(const Stmt& lhs, const Stmt& rhs) {
+    tie_def_ = true;
+    VisitStmt(lhs, rhs);
+    return order_ == 0;
+  }
+
+  bool Equal(const Expr& lhs, const Expr& rhs) {
+    tie_def_ = true;
+    VisitExpr(lhs, rhs);
+    return order_ == 0;
+  }
+
+  void VisitExpr(const Expr& n, const Expr& other) override {
+    if (order_ != 0) return;
+    if (CompareValue(n->type_index(), other->type_index()) != 0) return;
+    if (CompareType(n.type(), other.type()) != 0) return;
+    ExprComparator::VisitExpr(n, other);
+  }
+
+  void VisitStmt(const Stmt& n, const Stmt& other) override {
+    if (order_ != 0) return;
+    if (CompareValue(n->type_index(), other->type_index()) != 0) return;
+    StmtComparator::VisitStmt(n, other);
+  }
+  // Stmt
+  void VisitStmt_(const LetStmt* op, const Stmt& other) final {
+    const LetStmt* rhs = other.as<LetStmt>();
+    if (CompareExpr(op->value, rhs->value) != 0) return;
+    if (tie_def_) {
+      vmap_[op->var.get()] = rhs->var.get();
+    } else {
+      if (CompareExpr(op->var, rhs->var) != 0) return;
+    }
+    if (CompareStmt(op->body, rhs->body) != 0) return;
+  }
+
+  void VisitStmt_(const AttrStmt* op, const Stmt& other) final {
+    const AttrStmt* rhs = other.as<AttrStmt>();
+    if (CompareString(op->attr_key, rhs->attr_key) != 0) return;
+    if (CompareNodeRef(op->node, rhs->node) != 0) return;
+    if (CompareExpr(op->value, rhs->value) != 0) return;
+    if (CompareStmt(op->body, rhs->body) != 0) return;
+  }
+
+  void VisitStmt_(const IfThenElse* op, const Stmt& other) final {
+    const IfThenElse* rhs = other.as<IfThenElse>();
+    if (CompareExpr(op->condition, rhs->condition) != 0) return;
+    if (CompareStmt(op->then_case, rhs->then_case) != 0) return;
+    if (CompareStmt(op->else_case, rhs->else_case) != 0) return;
+  }
+
+  void VisitStmt_(const For* op, const Stmt& other) final {
+    const For* rhs = other.as<For>();
+    if (CompareExpr(op->min, rhs->min) != 0) return;
+    if (CompareExpr(op->extent, rhs->extent) != 0) return;
+    if (tie_def_) {
+      vmap_[op->loop_var.get()] = rhs->loop_var.get();
+    } else {
+      if (CompareExpr(op->loop_var, rhs->loop_var) != 0) return;
+    }
+    if (CompareStmt(op->body, rhs->body) != 0) return;
+  }
+
+  void VisitStmt_(const Allocate* op, const Stmt& other) final {
+    const Allocate* rhs = other.as<Allocate>();
+    if (tie_def_) {
+      vmap_[op->buffer_var.get()] = rhs->buffer_var.get();
+    } else {
+      if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
+    }
+    if (CompareType(op->type, rhs->type) != 0) return;
+    if (CompareArray(op->extents, rhs->extents) != 0) return;
+    if (CompareExpr(op->condition, rhs->condition) != 0) return;
+    if (CompareStmt(op->body, rhs->body) != 0) return;
+    if (CompareExpr(op->new_expr, rhs->new_expr) != 0) return;
+    if (CompareString(op->free_function, rhs->free_function) != 0) return;
+  }
+
+  void VisitStmt_(const Store* op, const Stmt& other) final {
+    const Store* rhs = other.as<Store>();
+    if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
+    if (CompareExpr(op->value, rhs->value) != 0) return;
+    if (CompareExpr(op->index, rhs->index) != 0) return;
+    if (CompareExpr(op->predicate, rhs->predicate) != 0) return;
+  }
+
+  void VisitStmt_(const Free* op, const Stmt& other) final {
+    const Free* rhs = other.as<Free>();
+    if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
+  }
+
+  void VisitStmt_(const AssertStmt* op, const Stmt& other) final {
+    const AssertStmt* rhs = other.as<AssertStmt>();
+    if (CompareExpr(op->condition, rhs->condition) != 0) return;
+    if (CompareExpr(op->message, rhs->message) != 0) return;
+  }
+
+  void VisitStmt_(const ProducerConsumer* op, const Stmt& other) final {
+    const ProducerConsumer* rhs = other.as<ProducerConsumer>();
+    if (CompareNodeRef(op->func, rhs->func) != 0) return;
+    if (CompareValue(op->is_producer, rhs->is_producer) != 0) return;
+    if (CompareStmt(op->body, rhs->body) != 0) return;
+  }
+
+
+  void VisitStmt_(const Provide* op, const Stmt& other) final {
+    const Provide* rhs = other.as<Provide>();
+    if (CompareNodeRef(op->func, rhs->func) != 0) return;
+    if (CompareValue(op->value_index, rhs->value_index) != 0) return;
+    if (CompareExpr(op->value, rhs->value) != 0) return;
+    if (CompareArray(op->args, rhs->args) != 0) return;
+  }
+
+  void VisitStmt_(const Realize* op, const Stmt& other) final {
+    const Realize* rhs = other.as<Realize>();
+    if (CompareNodeRef(op->func, rhs->func) != 0) return;
+    if (CompareValue(op->value_index, rhs->value_index) != 0) return;
+    if (CompareType(op->type, rhs->type) != 0) return;
+    if (CompareRegion(op->bounds, rhs->bounds) != 0) return;
+    if (CompareStmt(op->body, rhs->body) != 0) return;
+  }
+
+  void VisitStmt_(const Prefetch* op, const Stmt& other) final {
+    const Prefetch* rhs = other.as<Prefetch>();
+    if (CompareNodeRef(op->func, rhs->func) != 0) return;
+    if (CompareValue(op->value_index, rhs->value_index) != 0) return;
+    if (CompareType(op->type, rhs->type) != 0) return;
+    if (CompareRegion(op->bounds, rhs->bounds) != 0) return;
+  }
+
+  void VisitStmt_(const Block* op, const Stmt& other) final {
+    const Block* rhs = other.as<Block>();
+    if (CompareStmt(op->first, rhs->first) != 0) return;
+    if (CompareStmt(op->rest, rhs->rest) != 0) return;
+  }
+
+  void VisitStmt_(const Evaluate* op, const Stmt& other) final {
+    const Evaluate* rhs = other.as<Evaluate>();
+    CompareExpr(op->value, rhs->value);
+  }
+
+  // Exprs
+  void VisitExpr_(const Variable* op, const Expr& other) final {
+    const Variable* rhs = other.as<Variable>();
+    auto it = vmap_.find(op);
+    if (it != vmap_.end()) op = it->second;
+    if (op < rhs) {
+      order_ = -1;
+    } else if (op > rhs) {
+      order_ = +1;
+    }
+  }
+  void VisitExpr_(const Load* op, const Expr& other) final {
+    const Load* rhs = other.as<Load>();
+    if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
+    if (CompareExpr(op->index, rhs->index) != 0) return;
+    if (CompareExpr(op->predicate, rhs->predicate) != 0) return;
+  }
+
+  void VisitExpr_(const Let* op, const Expr& other) final {
+    const Let* rhs = other.as<Let>();
+    if (tie_def_) {
+      vmap_[op->var.get()] = rhs->var.get();
+    } else {
+      if (CompareExpr(op->var, rhs->var) != 0) return;
+    }
+    if (CompareExpr(op->value, rhs->value) != 0) return;
+    if (CompareExpr(op->body, rhs->body) != 0) return;
+  }
+
+  void VisitExpr_(const Call* op, const Expr& other) final {
+    const Call* rhs = other.as<Call>();
+    if (CompareString(op->name, rhs->name)) return;
+    if (CompareArray(op->args, rhs->args)) return;
+    if (CompareValue(op->call_type, rhs->call_type) != 0) return;
+    if (CompareNodeRef(op->func, rhs->func) != 0) return;
+    if (CompareValue(op->value_index, rhs->value_index) != 0) return;
+  }
+
+  // Reduce: compare combiner, axis count, value index, each axis domain,
+  // condition and source expressions. Axis variables are tied through
+  // vmap_ when tie_def_ is set (like Let), so two reductions built over
+  // freshly created axes can still compare equal. The axes must be tied
+  // before condition/source are compared, since those refer to the axes.
+  void VisitExpr_(const Reduce *op, const Expr& other) final {
+    const Reduce* rhs = other.as<Reduce>();
+    if (CompareCommReducer(op->combiner, rhs->combiner) != 0) return;
+    if (CompareValue(op->axis.size(), rhs->axis.size()) != 0) return;
+    if (CompareValue(op->value_index, rhs->value_index) != 0) return;
+    for (size_t i = 0; i < op->axis.size(); ++i) {
+      if (CompareExpr(op->axis[i]->dom->min, rhs->axis[i]->dom->min) != 0) return;
+      if (CompareExpr(op->axis[i]->dom->extent, rhs->axis[i]->dom->extent) != 0) return;
+      if (tie_def_) {
+        vmap_[op->axis[i]->var.get()] = rhs->axis[i]->var.get();
+      } else {
+        if (CompareExpr(op->axis[i]->var, rhs->axis[i]->var) != 0) return;
+      }
+    }
+    if (CompareExpr(op->condition, rhs->condition) != 0) return;
+    if (CompareArray(op->source, rhs->source) != 0) return;
+  }
+
+  // Integer immediates order by their signed value.
+  void VisitExpr_(const IntImm *op, const Expr& other) final {
+    const IntImm* rhs = other.as<IntImm>();
+    CompareValue(op->value, rhs->value);
+  }
+
+  // Unsigned integer immediates order by their value.
+  void VisitExpr_(const UIntImm *op, const Expr& other) final {
+    const UIntImm* rhs = other.as<UIntImm>();
+    CompareValue(op->value, rhs->value);
+  }
+
+  // Float immediates order by value. NOTE(review): NaN is neither < nor >
+  // anything, so two NaN immediates compare as equal here.
+  void VisitExpr_(const FloatImm *op, const Expr& other) final {
+    const FloatImm* rhs = other.as<FloatImm>();
+    CompareValue(op->value, rhs->value);
+  }
+
+  // String immediates order lexicographically.
+  void VisitExpr_(const StringImm *op, const Expr& other) final {
+    const StringImm* rhs = other.as<StringImm>();
+    CompareString(op->value, rhs->value);
+  }
+
+  // Cast: compare the operand. The target dtype is presumably checked by
+  // the dispatching caller before this visitor runs — verify (not visible here).
+  void VisitExpr_(const Cast *op, const Expr& other) final {
+    const Cast* rhs = other.as<Cast>();
+    CompareExpr(op->value, rhs->value);
+  }
+
+  // Logical not: compare the single operand.
+  void VisitExpr_(const Not *op, const Expr& other) final {
+    const Not* rhs = other.as<Not>();
+    CompareExpr(op->a, rhs->a);
+  }
+
+  // Select: compare condition, then the two value branches.
+  void VisitExpr_(const Select *op, const Expr& other) final {
+    const Select* rhs = other.as<Select>();
+    if (CompareExpr(op->condition, rhs->condition) != 0) return;
+    if (CompareExpr(op->true_value, rhs->true_value) != 0) return;
+    if (CompareExpr(op->false_value, rhs->false_value) != 0) return;
+  }
+
+  // Ramp: compare base, stride, then lane count.
+  void VisitExpr_(const Ramp *op, const Expr& other) final {
+    const Ramp* rhs = other.as<Ramp>();
+    if (CompareExpr(op->base, rhs->base) != 0) return;
+    if (CompareExpr(op->stride, rhs->stride) != 0) return;
+    if (CompareValue(op->lanes, rhs->lanes) != 0) return;
+  }
+
+  // Broadcast: compare the scalar value, then lane count.
+  void VisitExpr_(const Broadcast *op, const Expr& other) final {
+    const Broadcast* rhs = other.as<Broadcast>();
+    if (CompareExpr(op->value, rhs->value) != 0) return;
+    if (CompareValue(op->lanes, rhs->lanes) != 0) return;
+  }
+
+  // Shuffle: compare input vectors, then index list.
+  void VisitExpr_(const Shuffle *op, const Expr& other) final {
+    const Shuffle* rhs = other.as<Shuffle>();
+    if (CompareArray(op->vectors, rhs->vectors) != 0) return;
+    if (CompareArray(op->indices, rhs->indices) != 0) return;
+  }
+
+  // Binary expression nodes: visitors generated by the
+  // DEFINE_BIOP_EXPR_CMP_ macro compare the (a, b) operand pair.
+  DEFINE_BIOP_EXPR_CMP_(Add)
+  DEFINE_BIOP_EXPR_CMP_(Sub)
+  DEFINE_BIOP_EXPR_CMP_(Mul)
+  DEFINE_BIOP_EXPR_CMP_(Div)
+  DEFINE_BIOP_EXPR_CMP_(Mod)
+  DEFINE_BIOP_EXPR_CMP_(Min)
+  DEFINE_BIOP_EXPR_CMP_(Max)
+  DEFINE_BIOP_EXPR_CMP_(EQ)
+  DEFINE_BIOP_EXPR_CMP_(NE)
+  DEFINE_BIOP_EXPR_CMP_(LT)
+  DEFINE_BIOP_EXPR_CMP_(LE)
+  DEFINE_BIOP_EXPR_CMP_(GT)
+  DEFINE_BIOP_EXPR_CMP_(GE)
+  DEFINE_BIOP_EXPR_CMP_(And)
+  DEFINE_BIOP_EXPR_CMP_(Or)
+
+ private:
+  // Three-way compare of two expressions; updates and returns order_.
+  // An undefined Expr orders before any defined one, and two undefined
+  // Exprs compare equal. The explicit both-undefined case is required:
+  // without it neither guard fires and VisitExpr would be dispatched on a
+  // null node (null dereference).
+  int CompareExpr(const Expr& lhs, const Expr& rhs) {
+    if (order_ != 0) return order_;
+    if (!lhs.defined() && !rhs.defined()) return order_;
+    if (!lhs.defined() && rhs.defined()) {
+      order_ = -1; return order_;
+    }
+    if (!rhs.defined() && lhs.defined()) {
+      order_ = +1; return order_;
+    }
+    VisitExpr(lhs, rhs);
+    return order_;
+  }
+
+  // Three-way compare of two statements; updates and returns order_.
+  // An undefined Stmt orders before any defined one, and two undefined
+  // Stmts compare equal. The both-undefined early return prevents
+  // dispatching VisitStmt on a null node — reachable e.g. when both
+  // IfThenElse nodes have no else_case.
+  int CompareStmt(const Stmt& lhs, const Stmt& rhs) {
+    if (order_ != 0) return order_;
+    if (!lhs.defined() && !rhs.defined()) return order_;
+    if (!lhs.defined() && rhs.defined()) {
+      order_ = -1; return order_;
+    }
+    if (!rhs.defined() && lhs.defined()) {
+      order_ = +1; return order_;
+    }
+    VisitStmt(lhs, rhs);
+    return order_;
+  }
+
+  // Lexicographic comparison of two Expr arrays: size first, then elements.
+  int CompareArray(const Array<Expr>& lhs, const Array<Expr>& rhs) {
+    if (order_ != 0) return order_;
+    if (CompareValue(lhs.size(), rhs.size()) != 0) return order_;
+    for (size_t i = 0; i < lhs.size(); ++i) {
+      if (CompareExpr(lhs[i], rhs[i]) != 0) return order_;
+    }
+    return order_;
+  }
+
+  // Compare two regions: size first, then each range's (min, extent) pair.
+  int CompareRegion(const Halide::Internal::Region& lhs,
+                    const Halide::Internal::Region& rhs) {
+    if (order_ != 0) return order_;
+    if (CompareValue(lhs.size(), rhs.size()) != 0) return order_;
+    for (size_t i = 0; i < lhs.size(); ++i) {
+      if (CompareExpr(lhs[i]->min, rhs[i]->min) != 0) return order_;
+      if (CompareExpr(lhs[i]->extent, rhs[i]->extent) != 0) return order_;
+    }
+    return order_;
+  }
+
+  // Fallback for opaque node references: order by object address.
+  // NOTE(review): address order is process-local, so only the
+  // equality/inequality outcome is portable across runs.
+  int CompareNodeRef(const NodeRef& lhs, const NodeRef& rhs) {
+    if (order_ != 0) return order_;
+    if (lhs.get() < rhs.get()) {
+      order_ = -1; return order_;
+    }
+    if (lhs.get() > rhs.get()) {
+      order_ = +1; return order_;
+    }
+    return order_;
+  }
+
+  // Compare types by (code, bits, lanes), with a fast equality short-circuit.
+  int CompareType(const Type& lhs, const Type& rhs) {
+    if (order_ != 0) return order_;
+    if (lhs == rhs) return order_;
+    if (CompareValue(lhs.code(), rhs.code()) != 0) return order_;
+    if (CompareValue(lhs.bits(), rhs.bits()) != 0) return order_;
+    if (CompareValue(lhs.lanes(), rhs.lanes()) != 0) return order_;
+    return order_;
+  }
+
+  // Lexicographic string compare, normalized to {-1, 0, +1}.
+  // std::string::compare returns a value of unspecified magnitude; the
+  // class documents order_ as exactly -1/0/+1, so clamp the sign here.
+  int CompareString(const std::string& lhs, const std::string& rhs) {
+    if (order_ != 0) return order_;
+    int cmp = lhs.compare(rhs);
+    if (cmp < 0) {
+      order_ = -1;
+    } else if (cmp > 0) {
+      order_ = +1;
+    }
+    return order_;
+  }
+
+  // Generic three-way compare for any type supporting < and >.
+  // Values that are neither smaller nor greater leave order_ unchanged,
+  // i.e. compare equal (this also makes NaN floats compare equal).
+  template<typename T>
+  int CompareValue(const T& lhs, const T& rhs) {
+    if (order_ != 0) return order_;
+    if (lhs < rhs) {
+      order_ = -1; return order_;
+    } else if (lhs > rhs) {
+      order_ = +1; return order_;
+    }
+    return order_;
+  }
+
+  // Compare two commutative reducers.
+  // Under tie_def_, the formal lhs/rhs placeholder variables of the two
+  // reducers are tied together in a fresh comparator before the combiner
+  // result expressions are compared; otherwise the placeholders must
+  // match directly.
+  int CompareCommReducer(const CommReducer& lhs, const CommReducer& rhs) {
+    if (order_ != 0) return order_;
+    if (lhs == rhs) return order_;
+    if (CompareValue(lhs->lhs.size(), rhs->lhs.size()) != 0) return order_;
+    if (CompareValue(lhs->rhs.size(), rhs->rhs.size()) != 0) return order_;
+    IRDeepCompare cmp;
+    if (tie_def_) {
+      for (size_t i = 0; i < lhs->lhs.size(); ++i) {
+        cmp.vmap_[lhs->lhs[i].get()] = rhs->lhs[i].get();
+      }
+      for (size_t i = 0; i < lhs->rhs.size(); ++i) {
+        cmp.vmap_[lhs->rhs[i].get()] = rhs->rhs[i].get();
+      }
+    } else {
+      for (size_t i = 0; i < lhs->lhs.size(); ++i) {
+        if (CompareExpr(lhs->lhs[i], rhs->lhs[i]) != 0) return order_;
+      }
+      // Fixed: bound this loop by lhs->rhs.size(); the original reused
+      // lhs->lhs.size() while indexing into lhs->rhs / rhs->rhs.
+      for (size_t i = 0; i < lhs->rhs.size(); ++i) {
+        if (CompareExpr(lhs->rhs[i], rhs->rhs[i]) != 0) return order_;
+      }
+    }
+    // NOTE(review): cmp.tie_def_ stays false, so nested definitions inside
+    // the combiner result exprs are not tied — confirm this is intended.
+    order_ = cmp.CompareArray(lhs->result, rhs->result);
+    return order_;
+  }
+  // The order flag: smaller: -1, bigger: +1, equal: 0.
+  int order_{0};
+  // Whether to tie intermediate definitions.
+  // This allows us to tie definitions of two variables together.
+  // This enables us to assert equal between (let x in x + 1), (let y in y + 1).
+  // However, the comparison is no longer a total order.
+  // Only equality/non-equality information is valid.
+  bool tie_def_{false};
+  // Variable remap (lhs Variable -> equivalent rhs Variable), if any.
+  std::unordered_map<const Variable*, const Variable*> vmap_;
+};
+
+
+// Structural (deep) equality of two statements.
+bool Equal(const Stmt& lhs, const Stmt& rhs) {
+  IRDeepCompare cmp;
+  return cmp.Equal(lhs, rhs);
+}
+
+// Structural (deep) equality of two expressions.
+bool Equal(const Expr& lhs, const Expr& rhs) {
+  IRDeepCompare cmp;
+  return cmp.Equal(lhs, rhs);
+}
+
+}  // namespace ir
+}  // namespace tvm
diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc
index a33e03496..d7ac0ac6a 100644
--- a/src/pass/storage_flatten.cc
+++ b/src/pass/storage_flatten.cc
@@ -195,7 +195,7 @@ class StorageFlattener : public IRMutator {
     }
     // start binding
     ArgBinder binder(&var_remap_);
-    binder.BindBuffer(Buffer(arr[0].node_), slice, buffer->name);
+    binder.BindBuffer(Buffer(arr[0].node_), slice, buffer->name, true);
     // Apply the remaps
     Stmt body = MergeNest(binder.asserts(), op->body);
     body = MergeNest(binder.init_nest(), body);
diff --git a/tests/python/unittest/test_pass_equal.py b/tests/python/unittest/test_pass_equal.py
new file mode 100644
index 000000000..1c13b82ea
--- /dev/null
+++ b/tests/python/unittest/test_pass_equal.py
@@ -0,0 +1,48 @@
+import tvm
+
+def test_equal_expr():
+    """Equal() holds for structurally identical exprs, fails for different ones."""
+    x = tvm.var('x')
+    y = tvm.var('y')
+
+    def func1():
+        return x + y + 1
+
+    def func2():
+        return tvm.exp((x + y + 1) * y / 4)
+
+    assert tvm.ir_pass.Equal(func1(), func1())
+    assert tvm.ir_pass.Equal(func2(), func2())
+    assert not tvm.ir_pass.Equal(func2(), func1())
+
+
+def test_equal_compute():
+    x = tvm.var('x')
+    y = tvm.var('y')
+    n = 128
+    A = tvm.placeholder((n, n), name='A')
+    B = tvm.placeholder((n, n), name='B')
+    ii = tvm.var('i')
+    jj = tvm.var('j')
+
+    def func1():
+        k = tvm.reduce_axis((0, n), name='k')
+        return tvm.sum(A[ii, k] * B[jj, k], axis=k)
+
+    Ab = tvm.decl_buffer((n,), name='A')
+    n = tvm.var("n")
+    def func2():
+        ib = tvm.ir_builder.create()
+        A = ib.buffer_ptr(Ab)
+        with ib.for_range(0, n, name="i") as i:
+            A[i] = A[i] + 1
+            with ib.for_range(0, 10, name="j") as j:
+                A[j] = A[j] + 2
+        return ib.get()
+
+    assert tvm.ir_pass.Equal(func1(), func1())
+    assert tvm.ir_pass.Equal(func2(), func2())
+
+
+if __name__ == "__main__":
+    test_equal_expr()
+    test_equal_compute()
-- 
GitLab