From 81db03abb91e510de70f91d6a0f69d1355f0f247 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Tue, 23 Oct 2018 10:54:20 -0700
Subject: [PATCH] [RELAY] Refactor AlphaEqual to support deep comparison of
 Attrs. (#1958)

---
 include/tvm/attrs.h                           | 197 +++++----
 src/api/api_pass.cc                           |   8 +-
 src/lang/attr_functor.h                       | 124 +++++-
 src/lang/attrs.cc                             | 316 ++++++++-----
 src/relay/ir/alpha_equal.cc                   | 384 ++++++++++++++++
 src/relay/ir/text_printer.cc                  |   7 +-
 .../type_visitor.h => ir/type_functor.h}      |  94 +++-
 src/relay/pass/alpha_eq.cc                    | 418 ------------------
 src/relay/pass/kind_check.cc                  |  16 +-
 src/relay/pass/type_functor.h                 |  94 ----
 src/relay/pass/type_subst.cc                  |   2 +-
 src/relay/pass/util.cc                        |   2 +-
 tests/python/relay/test_pass_alpha_equal.py   |  43 +-
 .../unittest/test_pass_attrs_hash_equal.py    |   6 +
 14 files changed, 980 insertions(+), 731 deletions(-)
 create mode 100644 src/relay/ir/alpha_equal.cc
 rename src/relay/{pass/type_visitor.h => ir/type_functor.h} (52%)
 delete mode 100644 src/relay/pass/alpha_eq.cc
 delete mode 100644 src/relay/pass/type_functor.h

diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h
index 3b7beaa37..33d84cece 100644
--- a/include/tvm/attrs.h
+++ b/include/tvm/attrs.h
@@ -108,6 +108,90 @@ class AttrFieldInfoNode : public Node {
 /*! \brief AttrFieldInfo */
 TVM_DEFINE_NODE_REF(AttrFieldInfo, AttrFieldInfoNode);
 
+class AttrsHashHandler;
+class AttrsEqualHandler;
+/*!
+ * \brief Content-aware Equality comparator for attrs.
+ *
+ * This comparator will recursively deep compare the following Attributes.
+ *
+ * - IntImm, UIntImm, FloatImm, StringImm
+ * - Any subclass of BaseAttrsNode
+ * - Array of Attributes.
+ * - Map from string to Attributes.
+ */
+class AttrsEqual {
+ public:
+  bool operator()(const double& lhs, const double& rhs) const {
+    return lhs == rhs;
+  }
+  bool operator()(const int64_t& lhs, const int64_t& rhs) const {
+    return lhs == rhs;
+  }
+  bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
+    return lhs == rhs;
+  }
+  bool operator()(const int& lhs, const int& rhs) const {
+    return lhs == rhs;
+  }
+  bool operator()(const bool& lhs, const bool& rhs) const {
+    return lhs == rhs;
+  }
+  bool operator()(const std::string& lhs, const std::string& rhs) const {
+    return lhs == rhs;
+  }
+  bool operator()(const Type& lhs, const Type& rhs) const {
+    return lhs == rhs;
+  }
+  // node comparator
+  TVM_DLL bool operator()(const NodeRef& lhs, const NodeRef& rhs) const;
+
+ protected:
+  friend class AttrsEqualHandler;
+  /*! \brief internal handle. */
+  AttrsEqualHandler* handler_{nullptr};
+};
+
+/*!
+ * \brief Content-aware hash function.
+ *
+ * This hash functor will recursively hash the content of the Attributes.
+ * It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b);
+ */
+class AttrsHash {
+ public:
+  size_t operator()(const double& value) const {
+    return std::hash<double>()(value);
+  }
+  size_t operator()(const int64_t& value) const {
+    return std::hash<int64_t>()(value);
+  }
+  size_t operator()(const uint64_t& value) const {
+    return std::hash<uint64_t>()(value);
+  }
+  size_t operator()(const int& value) const {
+    return std::hash<int>()(value);
+  }
+  size_t operator()(const bool& value) const {
+    return std::hash<bool>()(value);
+  }
+  size_t operator()(const std::string& value) const {
+    return std::hash<std::string>()(value);
+  }
+  size_t operator()(const Type& value) const {
+    return std::hash<int>()(
+        static_cast<int>(value.code()) |
+        (static_cast<int>(value.bits()) << 8) |
+        (static_cast<int>(value.lanes()) << 16));
+  }
+  TVM_DLL size_t operator()(const NodeRef& value) const;
+
+ private:
+  friend class AttrsHashHandler;
+  /*! \brief internal handle. */
+  AttrsHashHandler* handler_{nullptr};
+};
+
 /*!
  * \brief Base class of all attribute class
  * \note Do not subclass AttrBaseNode directly,
@@ -153,14 +237,17 @@ class BaseAttrsNode : public Node {
   /*!
    * \brief Whether this attribute's content equals to another node.
    * \param other The pointer to another node.
+   * \param equal The equal comparator
    * \return The comparison result.
    */
-  TVM_DLL virtual bool ContentEqual(const Node* other) const = 0;
+  TVM_DLL virtual bool ContentEqual(
+      const Node* other, AttrsEqual equal) const = 0;
   /*!
    * \brief Content aware hash.
+   * \param hasher The hasher to run the hash.
    * \return the hash result.
    */
-  TVM_DLL virtual size_t ContentHash() const = 0;
+  TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;
 
   static constexpr const char* _type_key = "Attrs";
   TVM_DECLARE_BASE_NODE_INFO(BaseAttrsNode, Node);
@@ -209,92 +296,13 @@ class DictAttrsNode : public BaseAttrsNode {
   void VisitNonDefaultAttrs(AttrVisitor* v) final;
   void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
   Array<AttrFieldInfo> ListFieldInfo() const final;
-  bool ContentEqual(const Node* other) const final;
-  size_t ContentHash() const final;
+  bool ContentEqual(const Node* other, AttrsEqual equal) const final;
+  size_t ContentHash(AttrsHash hasher) const final;
   // type info
   static constexpr const char* _type_key = "DictAttrs";
   TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode);
 };
 
-/*!
- * \brief Content-aware Equality comparator for attrs.
- *
- * This comparator will recursively deep compare the following Attributes.
- *
- * - IntImm, UIntImm, FloatImm, StringImm
- * - Any subclass of BaseAttrsNode
- * - Array of Attributes.
- * - Map from string to Attributes.
- */
-class AttrsEqual {
- public:
-  bool operator()(const double& lhs, const double& rhs) const {
-    return lhs == rhs;
-  }
-  bool operator()(const int64_t& lhs, const int64_t& rhs) const {
-    return lhs == rhs;
-  }
-  bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
-    return lhs == rhs;
-  }
-  bool operator()(const int& lhs, const int& rhs) const {
-    return lhs == rhs;
-  }
-  bool operator()(const bool& lhs, const bool& rhs) const {
-    return lhs == rhs;
-  }
-  bool operator()(const std::string& lhs, const std::string& rhs) const {
-    return lhs == rhs;
-  }
-  bool operator()(const Type& lhs, const Type& rhs) const {
-    return lhs == rhs;
-  }
-  bool operator()(const NodeRef& lhs, const NodeRef& rhs) const {
-    return AttrsEqual::Equal(lhs, rhs);
-  }
-
-  // comparator of NodeRef types.
-  static TVM_DLL bool Equal(const NodeRef& lhs, const NodeRef& rhs);
-};
-
-/*!
- * \brief Content-aware hash function.
- *
- * This hash functor will recursively hash the content of the Attributes.
- * It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b);
- */
-class AttrsHash {
- public:
-  size_t operator()(const double& value) const {
-    return std::hash<double>()(value);
-  }
-  size_t operator()(const int64_t& value) const {
-    return std::hash<int64_t>()(value);
-  }
-  size_t operator()(const uint64_t& value) const {
-    return std::hash<uint64_t>()(value);
-  }
-  size_t operator()(const int& value) const {
-    return std::hash<int>()(value);
-  }
-  size_t operator()(const bool& value) const {
-    return std::hash<bool>()(value);
-  }
-  size_t operator()(const std::string& value) const {
-    return std::hash<std::string>()(value);
-  }
-  size_t operator()(const Type& value) const {
-    return std::hash<int>()(
-        static_cast<int>(value.code()) |
-        (static_cast<int>(value.bits()) << 8) |
-        (static_cast<int>(value.lanes()) << 16));
-  }
-  size_t operator()(const NodeRef& value) const {
-    return AttrsHash::Hash(value);
-  }
-  // hash function of the attribute and attribute fields.
-  static TVM_DLL size_t Hash(const NodeRef& lhs);
-};
 
 // Namespace containing detail implementations
 namespace detail {
@@ -342,8 +350,8 @@ class AttrsEqualVisitor {
  public:
   bool result_{true};
   // constructor
-  AttrsEqualVisitor(const Node* lhs, const Node* rhs)
-      : lhs_(lhs), rhs_(rhs) {
+  AttrsEqualVisitor(const Node* lhs, const Node* rhs, const AttrsEqual& equal)
+      : lhs_(lhs), rhs_(rhs), equal_(equal) {
   }
   template<typename T>
   AttrNopEntry operator()(const char* key, T* lhs_value) {
@@ -353,7 +361,7 @@ class AttrsEqualVisitor {
             reinterpret_cast<const char*>(rhs_) +
             (reinterpret_cast<const char*>(lhs_value) -
              reinterpret_cast<const char*>(lhs_)));
-    if (!AttrsEqual()(*lhs_value, *rhs_value)) {
+    if (!equal_(*lhs_value, *rhs_value)) {
       result_ = false;
     }
     return AttrNopEntry();
@@ -362,17 +370,24 @@ class AttrsEqualVisitor {
  private:
   const Node* lhs_;
   const Node* rhs_;
+  const AttrsEqual& equal_;
 };
 
 class AttrsHashVisitor {
  public:
+  explicit AttrsHashVisitor(const AttrsHash& hasher)
+      : hasher_(hasher) {}
+
   size_t result_{0};
 
   template<typename T>
   AttrNopEntry operator()(const char* key, T* value) {
-    result_ = dmlc::HashCombine(result_, AttrsHash()(*value));
+    result_ = dmlc::HashCombine(result_, hasher_(*value));
     return AttrNopEntry();
   }
+
+ private:
+  const AttrsHash& hasher_;
 };
 
 // helper entry that does initialization, set default.
@@ -793,18 +808,18 @@ class AttrsNode : public BaseAttrsNode {
     return visitor.fields_;
   }
 
-  bool ContentEqual(const Node* other) const final {
+  bool ContentEqual(const Node* other, AttrsEqual equal) const final {
     DerivedType* pself = self();
     if (pself == other) return true;
     if (other == nullptr) return false;
     if (pself->type_index() != other->type_index()) return false;
-    detail::AttrsEqualVisitor visitor(pself, other);
+    detail::AttrsEqualVisitor visitor(pself, other, equal);
     self()->__VisitAttrs__(visitor);
     return visitor.result_;
   }
 
-  size_t ContentHash() const final {
-    detail::AttrsHashVisitor visitor;
+  size_t ContentHash(AttrsHash hasher) const final {
+    detail::AttrsHashVisitor visitor(hasher);
     visitor.result_ = std::hash<std::string>()(this->type_key());
     self()->__VisitAttrs__(visitor);
     return visitor.result_;
diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc
index 66e4529ac..1e571ca0d 100644
--- a/src/api/api_pass.cc
+++ b/src/api/api_pass.cc
@@ -68,10 +68,14 @@ TVM_REGISTER_API("ir_pass.Equal")
 
 
 TVM_REGISTER_API("ir_pass.AttrsEqual")
-.set_body_typed<bool(const NodeRef&, const NodeRef&)>(AttrsEqual::Equal);
+.set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) {
+    return AttrsEqual()(lhs, rhs);
+  });
 
 TVM_REGISTER_API("ir_pass.AttrsHash")
-.set_body_typed<int64_t(const NodeRef&)>(AttrsHash::Hash);
+.set_body_typed<int64_t(const NodeRef&)>([](const NodeRef &node) {
+    return AttrsHash()(node);
+  });
 
 
 TVM_REGISTER_API("ir_pass.ExprUseVar")
diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h
index 8aa39a774..ef1d06101 100644
--- a/src/lang/attr_functor.h
+++ b/src/lang/attr_functor.h
@@ -52,13 +52,33 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
       return VisitAttrDefault_(n.get(), std::forward<Args>(args)...);
     }
   }
+  virtual R VisitAttrDefault_(const Node* node, Args... args) = 0;
   virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const ir::IntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const ir::UIntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const ir::FloatImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const ir::StringImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttrDefault_(const Node* node, Args... args) = 0;
+  // deep comparison of symbolic integer expressions.
+  virtual R VisitAttr_(const Variable* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::Add* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::Sub* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::GE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::GT* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::LT* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::LE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::EQ* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::NE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::And* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::Or* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::Not* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::Cast* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::Call* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::Select* op, Args... args) ATTR_FUNCTOR_DEFAULT;
 
  private:
   // initialize the vtable.
@@ -72,9 +92,111 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
     ATTR_FUNCTOR_DISPATCH(UIntImm);
     ATTR_FUNCTOR_DISPATCH(FloatImm);
     ATTR_FUNCTOR_DISPATCH(StringImm);
+    ATTR_FUNCTOR_DISPATCH(Variable);
+    ATTR_FUNCTOR_DISPATCH(Add);
+    ATTR_FUNCTOR_DISPATCH(Sub);
+    ATTR_FUNCTOR_DISPATCH(Mul);
+    ATTR_FUNCTOR_DISPATCH(Min);
+    ATTR_FUNCTOR_DISPATCH(Max);
+    ATTR_FUNCTOR_DISPATCH(GE);
+    ATTR_FUNCTOR_DISPATCH(GT);
+    ATTR_FUNCTOR_DISPATCH(LE);
+    ATTR_FUNCTOR_DISPATCH(LT);
+    ATTR_FUNCTOR_DISPATCH(EQ);
+    ATTR_FUNCTOR_DISPATCH(NE);
+    ATTR_FUNCTOR_DISPATCH(And);
+    ATTR_FUNCTOR_DISPATCH(Or);
+    ATTR_FUNCTOR_DISPATCH(Not);
+    ATTR_FUNCTOR_DISPATCH(Cast);
+    ATTR_FUNCTOR_DISPATCH(Call);
+    ATTR_FUNCTOR_DISPATCH(Select);
     return vtable;
   }
 };
 
+class AttrsEqualHandler :
+      protected AttrFunctor<bool(const NodeRef&, const NodeRef&)> {
+ public:
+  /*!
+   * \brief Check if lhs equals rhs
+   * \param lhs The left operand.
+   * \param rhs The right operand.
+   */
+  bool Equal(const NodeRef& lhs, const NodeRef& rhs);
+
+ protected:
+  bool VisitAttrDefault_(const Node* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ArrayNode* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const StrMapNode* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::IntImm* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::UIntImm* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::FloatImm* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::StringImm* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::Add* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::Sub* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::GE* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::GT* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::LT* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::LE* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::EQ* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::NE* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::And* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::Or* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::Not* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::Cast* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::Call* lhs, const NodeRef& other) final;
+  bool VisitAttr_(const ir::Select* lhs, const NodeRef& other) final;
+};
+
+class AttrsHashHandler :
+      protected AttrFunctor<size_t(const NodeRef&)> {
+ public:
+  /*!
+   * \brief Get hash value of node
+   * \param node The node to be hashed.
+   */
+  size_t Hash(const NodeRef& node) {
+    return this->VisitAttr(node);
+  }
+
+ protected:
+  size_t VisitAttrDefault_(const Node* lhs) final;
+  size_t VisitAttr_(const ir::IntImm* lhs) final;
+  size_t VisitAttr_(const ir::UIntImm* lhs) final;
+  size_t VisitAttr_(const ir::FloatImm* lhs) final;
+  size_t VisitAttr_(const ir::StringImm* lhs) final;
+  size_t VisitAttr_(const ArrayNode* lhs) final;
+  size_t VisitAttr_(const StrMapNode* lhs) final;
+  size_t VisitAttr_(const ir::Add* op) final;
+  size_t VisitAttr_(const ir::Sub* op) final;
+  size_t VisitAttr_(const ir::Mul* op) final;
+  size_t VisitAttr_(const ir::Mod* op) final;
+  size_t VisitAttr_(const ir::Min* op) final;
+  size_t VisitAttr_(const ir::Max* op) final;
+  size_t VisitAttr_(const ir::GE* op) final;
+  size_t VisitAttr_(const ir::GT* op) final;
+  size_t VisitAttr_(const ir::LE* op) final;
+  size_t VisitAttr_(const ir::LT* op) final;
+  size_t VisitAttr_(const ir::EQ* op) final;
+  size_t VisitAttr_(const ir::NE* op) final;
+  size_t VisitAttr_(const ir::And* op) final;
+  size_t VisitAttr_(const ir::Or* op) final;
+  size_t VisitAttr_(const ir::Not* op) final;
+  size_t VisitAttr_(const ir::Cast* op) final;
+  size_t VisitAttr_(const ir::Call* op) final;
+  size_t VisitAttr_(const ir::Select* op) final;
+  /*!
+   * \brief alias of dmlc::HashCombine
+   * \param lhs The first hash value.
+   * \param rhs The second hash value.
+   */
+  static size_t Combine(size_t lhs, size_t rhs) {
+    return dmlc::HashCombine(lhs, rhs);
+  }
+};
 }  // namespace tvm
 #endif  // TVM_LANG_ATTR_FUNCTOR_H_
diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc
index e467018ad..9aa067c09 100644
--- a/src/lang/attrs.cc
+++ b/src/lang/attrs.cc
@@ -51,156 +51,272 @@ TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode);
 
 
 using namespace ir;
+// Equal handler.
+bool AttrsEqualHandler::Equal(const NodeRef& lhs, const NodeRef& rhs) {
+  if (lhs.same_as(rhs)) return true;
+  if (!lhs.defined() || !rhs.defined()) return false;
+  return this->VisitAttr(lhs, rhs);
+}
 
-class AttrsEqualChecker :
-      public AttrFunctor<bool(const NodeRef&, const NodeRef&)> {
- public:
-  bool Check(const NodeRef& lhs, const NodeRef& rhs) {
-    if (!equal_) return false;
-    if (lhs.same_as(rhs)) return true;
-    if (!lhs.defined() || !rhs.defined()) return false;
-    if (!this->VisitAttr(lhs, rhs)) {
-      equal_ = false;
-    }
-    return equal_;
+bool AttrsEqualHandler::VisitAttrDefault_(const Node* lhs, const NodeRef& other) {
+  if (lhs->derived_from<BaseAttrsNode>()) {
+    AttrsEqual equal;
+    equal.handler_ = this;
+    return static_cast<const BaseAttrsNode*>(lhs)->ContentEqual(
+        other.get(), equal);
   }
+  return lhs == other.get();
+}
 
-  bool VisitAttrDefault_(const Node* lhs, const NodeRef& other) final {
-    if (lhs->derived_from<BaseAttrsNode>()) {
-      return static_cast<const BaseAttrsNode*>(lhs)->ContentEqual(other.get());
-    }
-    return lhs == other.get();
+bool AttrsEqualHandler::VisitAttr_(const IntImm* lhs, const NodeRef& other) {
+  if (const auto* rhs = other.as<IntImm>()) {
+    return lhs->value == rhs->value;
   }
+  return false;
+}
 
-  bool VisitAttr_(const IntImm* lhs, const NodeRef& other) final {
-    if (const auto* rhs = other.as<IntImm>()) {
-      return lhs->value == rhs->value;
-    }
-    return false;
+bool AttrsEqualHandler::VisitAttr_(const UIntImm* lhs, const NodeRef& other) {
+  if (const auto* rhs = other.as<UIntImm>()) {
+    return lhs->value == rhs->value;
   }
+  return false;
+}
 
-  bool VisitAttr_(const UIntImm* lhs, const NodeRef& other) final {
-    if (const auto* rhs = other.as<UIntImm>()) {
-      return lhs->value == rhs->value;
-    }
-    return false;
+bool AttrsEqualHandler::VisitAttr_(const FloatImm* lhs, const NodeRef& other) {
+  if (const auto* rhs = other.as<FloatImm>()) {
+    return lhs->value == rhs->value;
   }
+  return false;
+}
 
-  bool VisitAttr_(const FloatImm* lhs, const NodeRef& other) final {
-    if (const auto* rhs = other.as<FloatImm>()) {
-      return lhs->value == rhs->value;
-    }
-    return false;
+bool AttrsEqualHandler::VisitAttr_(const StringImm* lhs, const NodeRef& other) {
+  if (const auto* rhs = other.as<StringImm>()) {
+    return lhs->value == rhs->value;
   }
+  return false;
+}
 
-  bool VisitAttr_(const StringImm* lhs, const NodeRef& other) final {
-    if (const auto* rhs = other.as<StringImm>()) {
-      return lhs->value == rhs->value;
+bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const NodeRef& other) {
+  if (const auto* rhs = other.as<ArrayNode>()) {
+    if (rhs->data.size() != lhs->data.size()) return false;
+    for (size_t  i = 0; i < lhs->data.size(); ++i) {
+      if (!Equal(NodeRef(lhs->data[i]), NodeRef(rhs->data[i]))) return false;
     }
-    return false;
   }
+  return true;
+}
 
-  bool VisitAttr_(const ArrayNode* lhs, const NodeRef& other) final {
-    if (const auto* rhs = other.as<ArrayNode>()) {
-      if (rhs->data.size() != lhs->data.size()) return false;
-      for (size_t  i = 0; i < lhs->data.size(); ++i) {
-        if (!Check(NodeRef(lhs->data[i]), NodeRef(rhs->data[i]))) return false;
-      }
+bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const NodeRef& other) {
+  if (const auto* rhs = other.as<StrMapNode>()) {
+    if (rhs->data.size() != lhs->data.size()) return false;
+    for (const auto& kv : lhs->data) {
+      auto it = rhs->data.find(kv.first);
+      if (it == rhs->data.end()) return false;
+      if (!Equal(NodeRef(kv.second), NodeRef(it->second))) return false;
     }
-    return true;
   }
+  return true;
+}
 
-  bool VisitAttr_(const StrMapNode* lhs, const NodeRef& other) final {
-    if (const auto* rhs = other.as<StrMapNode>()) {
-      if (rhs->data.size() != lhs->data.size()) return false;
-      for (const auto& kv : lhs->data) {
-        auto it = rhs->data.find(kv.first);
-        if (it == rhs->data.end()) return false;
-        if (!Check(NodeRef(kv.second), NodeRef(it->second))) return false;
-      }
-    }
-    return true;
+#define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName)                          \
+  bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const NodeRef& other) { \
+    if (const auto* rhs = other.as<NodeName>()) {                       \
+      if (!Equal(lhs->a, rhs->a)) return false;                         \
+      if (!Equal(lhs->b, rhs->b)) return false;                         \
+      return true;                                                      \
+    } else {                                                            \
+      return false;                                                     \
+    }                                                                   \
+  }                                                                     \
+
+TVM_DEFINE_ATTRS_BINOP_EQUAL(Add);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(Max);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(Min);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(GE);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(GT);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(LE);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(LT);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(EQ);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(NE);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(And);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(Or);
+
+bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const NodeRef& other) {
+  if (const auto* rhs = other.as<Not>()) {
+    return Equal(lhs->a, rhs->a);
+  } else {
+    return false;
   }
+}
 
- private:
-  bool equal_{true};
-};
-
-class AttrContentHasher :
-      public AttrFunctor<void(const NodeRef&)> {
- public:
-  size_t result_{0};
-
-  void VisitAttrDefault_(const Node* value) final {
-    if (value->derived_from<BaseAttrsNode>()) {
-      Update(static_cast<const BaseAttrsNode*>(value)->ContentHash());
-    } else {
-      Update(NodeHash()(GetRef<NodeRef>(value)));
-    }
+bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const NodeRef& other) {
+  if (const auto* rhs = other.as<Cast>()) {
+    if (lhs->type != rhs->type) return false;
+    return Equal(lhs->value, rhs->value);
+  } else {
+    return false;
   }
+}
 
-  void VisitAttr_(const IntImm* op) final {
-    Update(std::hash<int64_t>()(op->value));
+bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const NodeRef& other) {
+  if (const auto* rhs = other.as<Call>()) {
+    return
+        lhs->name == rhs->name &&
+        lhs->type == rhs->type &&
+        lhs->call_type == rhs->call_type &&
+        Equal(lhs->args, rhs->args);
+  } else {
+    return false;
   }
+}
 
-  void VisitAttr_(const UIntImm* op) final {
-    Update(std::hash<uint64_t>()(op->value));
+bool AttrsEqualHandler::VisitAttr_(const Select* lhs, const NodeRef& other) {
+  if (const auto* rhs = other.as<Select>()) {
+    return
+        Equal(lhs->condition, rhs->condition) &&
+        Equal(lhs->true_value, rhs->true_value) &&
+        Equal(lhs->false_value, rhs->false_value);
+  } else {
+    return false;
   }
+}
 
-  void VisitAttr_(const FloatImm* op) final {
-    Update(std::hash<double>()(op->value));
+// Hash Handler.
+size_t AttrsHashHandler::VisitAttrDefault_(const Node* value) {
+  if (value->derived_from<BaseAttrsNode>()) {
+    AttrsHash hasher;
+    hasher.handler_ = this;
+    return static_cast<const BaseAttrsNode*>(value)->ContentHash(hasher);
+  } else {
+    return NodeHash()(GetRef<NodeRef>(value));
   }
+}
 
-  void VisitAttr_(const StringImm* op) final {
-    Update(std::hash<std::string>()(op->value));
-  }
+size_t AttrsHashHandler::VisitAttr_(const IntImm* op) {
+  return std::hash<int64_t>()(op->value);
+}
 
-  void VisitAttr_(const ArrayNode* op) final {
-    Update(op->data.size());
-    for (size_t  i = 0; i < op->data.size(); ++i) {
-      this->VisitAttr(NodeRef(op->data[i]));
-    }
+size_t AttrsHashHandler::VisitAttr_(const UIntImm* op) {
+  return std::hash<uint64_t>()(op->value);
+}
+
+size_t AttrsHashHandler::VisitAttr_(const FloatImm* op) {
+  return std::hash<double>()(op->value);
+}
+
+size_t AttrsHashHandler::VisitAttr_(const StringImm* op) {
+  return std::hash<std::string>()(op->value);
+}
+
+size_t AttrsHashHandler::VisitAttr_(const ArrayNode* op) {
+  size_t result = op->data.size();
+  for (size_t  i = 0; i < op->data.size(); ++i) {
+    result = Combine(result, this->Hash(NodeRef(op->data[i])));
   }
+  return result;
+}
 
-  void VisitAttr_(const StrMapNode* lhs) final {
+size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) {
     using Entry = std::pair<std::string, NodePtr<Node> >;
     std::vector<Entry> data(lhs->data.begin(), lhs->data.end());
     std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) {
         return a.first < b.first;
       });
+    size_t result = 0;
     for (const Entry& kv : data) {
-      Update(std::hash<std::string>()(kv.first));
-      this->VisitAttr(NodeRef(kv.second));
+      result = Combine(result, std::hash<std::string>()(kv.first));
+      result = Combine(result, this->Hash(NodeRef(kv.second)));
     }
-  }
+    return result;
+}
 
-  void Update(size_t value) {
-    result_ = dmlc::HashCombine(result_, value);
-  }
-};
 
-bool AttrsEqual::Equal(const NodeRef& lhs, const NodeRef& rhs) {
+#define TVM_DEFINE_ATTRS_BINOP_HASH(NodeName)                           \
+  size_t AttrsHashHandler::VisitAttr_(const NodeName* op) {             \
+    static size_t key = std::hash<std::string>()(NodeName::_type_key);  \
+    return Combine(key, Combine(Hash(op->a), Hash(op->b)));             \
+  }                                                                     \
+
+TVM_DEFINE_ATTRS_BINOP_HASH(Add);
+TVM_DEFINE_ATTRS_BINOP_HASH(Sub);
+TVM_DEFINE_ATTRS_BINOP_HASH(Mul);
+TVM_DEFINE_ATTRS_BINOP_HASH(Mod);
+TVM_DEFINE_ATTRS_BINOP_HASH(Max);
+TVM_DEFINE_ATTRS_BINOP_HASH(Min);
+TVM_DEFINE_ATTRS_BINOP_HASH(GE);
+TVM_DEFINE_ATTRS_BINOP_HASH(GT);
+TVM_DEFINE_ATTRS_BINOP_HASH(LE);
+TVM_DEFINE_ATTRS_BINOP_HASH(LT);
+TVM_DEFINE_ATTRS_BINOP_HASH(EQ);
+TVM_DEFINE_ATTRS_BINOP_HASH(NE);
+TVM_DEFINE_ATTRS_BINOP_HASH(And);
+TVM_DEFINE_ATTRS_BINOP_HASH(Or);
+
+size_t AttrsHashHandler::VisitAttr_(const Not* op) {
+  static size_t key = std::hash<std::string>()(Not::_type_key);
+  return Combine(key, Hash(op->a));
+}
+
+size_t AttrsHashHandler::VisitAttr_(const Cast* op) {
+  static size_t key = std::hash<std::string>()(Cast::_type_key);
+  AttrsHash hasher;
+  size_t res = key;
+  res = Combine(res, hasher(op->type));
+  res = Combine(res, Hash(op->value));
+  return res;
+}
+
+size_t AttrsHashHandler::VisitAttr_(const Call* op) {
+  static size_t key = std::hash<std::string>()(Call::_type_key);
+  AttrsHash hasher;
+  size_t res = key;
+  res = Combine(res, hasher(op->name));
+  res = Combine(res, hasher(op->type));
+  res = Combine(res, Hash(op->args));
+  return res;
+}
+
+size_t AttrsHashHandler::VisitAttr_(const Select* op) {
+  static size_t key = std::hash<std::string>()(Select::_type_key);
+  size_t res = key;
+  res = Combine(res, Hash(op->condition));
+  res = Combine(res, Hash(op->true_value));
+  res = Combine(res, Hash(op->false_value));
+  return res;
+}
+
+
+// Default case
+bool AttrsEqual::operator()(const NodeRef& lhs, const NodeRef& rhs) const {
   if (lhs.same_as(rhs)) return true;
-  AttrsEqualChecker checker;
-  return checker.Check(lhs, rhs);
+  if (handler_ == nullptr) {
+    return AttrsEqualHandler().Equal(lhs, rhs);
+  } else {
+    return handler_->Equal(lhs, rhs);
+  }
 }
 
-size_t AttrsHash::Hash(const NodeRef& node) {
+size_t AttrsHash::operator()(const NodeRef& node) const {
   if (!node.defined()) return 0;
-  AttrContentHasher hasher;
-  hasher.VisitAttr(node);
-  return hasher.result_;
+  if (handler_ == nullptr) {
+    return AttrsHashHandler().Hash(node);
+  } else {
+    return handler_->Hash(node);
+  }
 }
 
-size_t DictAttrsNode::ContentHash() const {
-  return AttrsHash()(this->dict);
+size_t DictAttrsNode::ContentHash(AttrsHash hasher) const {
+  return hasher(this->dict);
 }
 
-bool DictAttrsNode::ContentEqual(const Node* other) const {
+bool DictAttrsNode::ContentEqual(const Node* other, AttrsEqual equal) const {
   if (this == other) return true;
   if (other == nullptr) return false;
   if (this->type_index() != other->type_index()) return false;
-  return AttrsEqual()(this->dict, static_cast<const DictAttrsNode*>(other)->dict);
+  return equal(this->dict, static_cast<const DictAttrsNode*>(other)->dict);
 }
 
 }  // namespace tvm
diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc
new file mode 100644
index 000000000..f227970fc
--- /dev/null
+++ b/src/relay/ir/alpha_equal.cc
@@ -0,0 +1,384 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file src/tvm/relay/ir/alpha_equal.cc
+ * \brief Alpha equality check by deep comparing two nodes.
+ */
+#include <tvm/ir_pass.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/relay/pass.h>
+#include "type_functor.h"
+#include "../../lang/attr_functor.h"
+
+namespace tvm {
+namespace relay {
+
+// Alpha equal handler for relay.
+class AlphaEqualHandler:
+      public AttrsEqualHandler,
+      public TypeFunctor<bool(const Type&, const Type&)>,
+      public ExprFunctor<bool(const Expr&, const Expr&)> {
+ public:
+  explicit AlphaEqualHandler(bool map_free_var)
+      : map_free_var_(map_free_var) {}
+
+  /*!
+   * Check equality of two nodes.
+   * \param lhs The left hand operand.
+   * \param rhs The right hand operand.
+   * \return the compare result.
+   */
+  bool Equal(const NodeRef& lhs, const NodeRef& rhs) {
+    if (lhs.same_as(rhs)) return true;
+    if (!lhs.defined() || !rhs.defined()) return false;
+    if (lhs->derived_from<TypeNode>()) {
+      if (!rhs->derived_from<TypeNode>()) return false;
+      return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
+    }
+    if (lhs->derived_from<ExprNode>()) {
+      if (!rhs->derived_from<ExprNode>()) return false;
+      return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
+    }
+    return AttrEqual(lhs, rhs);
+  }
+
+  /*!
+   * Check equality of two attributes.
+   * \param lhs The left hand operand.
+   * \param rhs The right hand operand.
+   * \return the compare result.
+   */
+  bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) {
+    return AttrsEqualHandler::Equal(lhs, rhs);
+  }
+  /*!
+   * Check equality of two types.
+   * \param lhs The left hand operand.
+   * \param rhs The right hand operand.
+   * \return the compare result.
+   */
+  bool TypeEqual(const Type& lhs, const Type& rhs) {
+    if (lhs.same_as(rhs)) return true;
+    if (!lhs.defined() || !rhs.defined()) return false;
+    return this->VisitType(lhs, rhs);
+  }
+  /*!
+   * Check equality of two expressions.
+   *
+   * \note We run graph structural equality checking when comparing two Exprs.
+   *   This means that AlphaEqualHandler can only be used once for each pair.
+   *   The equality checker checks data-flow equvalence of the Expr DAG.
+   *   This function also runs faster as it memomizes equal_map.
+   *
+   * \param lhs The left hand operand.
+   * \param rhs The right hand operand.
+   * \return the compare result.
+   */
+  bool ExprEqual(const Expr& lhs, const Expr& rhs) {
+    if (lhs.same_as(rhs)) return true;
+    if (!lhs.defined() || !rhs.defined()) return false;
+    auto it = equal_map_.find(lhs);
+    if (it != equal_map_.end()) {
+      return it->second.same_as(rhs);
+    }
+    if (this->VisitExpr(lhs, rhs)) {
+      equal_map_[lhs] = rhs;
+      return true;
+    } else {
+      return false;
+    }
+  }
+
+ protected:
+  /*!
+   * \brief Check if data type equals each other.
+   * \param lhs The left hand operand.
+   * \param rhs The right hand operand.
+   * \return the compare result.
+   */
+  bool DataTypeEqual(const DataType& lhs, const DataType& rhs) {
+    return lhs == rhs;
+  }
+  /*!
+   * \brief Check Equality of leaf node of the graph.
+   *  if map_free_var_ is set to true, try to map via equal node.
+   * \param lhs The left hand operand.
+   * \param rhs The right hand operand.
+   * \return the compare result.
+   */
+  bool LeafNodeEqual(const NodeRef& lhs, const NodeRef& rhs) {
+    if (lhs.same_as(rhs)) return true;
+    auto it = equal_map_.find(lhs);
+    if (it != equal_map_.end()) {
+      return it->second.same_as(rhs);
+    } else {
+      if (map_free_var_) {
+        if (lhs->type_index() != rhs->type_index()) return false;
+        equal_map_[lhs] = rhs;
+        return true;
+      } else {
+        return false;
+      }
+    }
+  }
+  using AttrsEqualHandler::VisitAttr_;
+  bool VisitAttr_(const Variable* lhs, const NodeRef& other) final {
+    return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
+  }
+
+  // Type equality
+  bool VisitType_(const TensorTypeNode* lhs, const Type& other) final {
+    if (const TensorTypeNode* rhs = other.as<TensorTypeNode>()) {
+      return (lhs->dtype == rhs->dtype &&
+              AttrEqual(lhs->shape, rhs->shape));
+    } else {
+      return false;
+    }
+  }
+
+  bool VisitType_(const IncompleteTypeNode* lhs, const Type& other) final {
+    return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
+  }
+
+  bool VisitType_(const TypeVarNode* lhs, const Type& other) final {
+    if (const TypeVarNode* rhs = other.as<TypeVarNode>()) {
+      if (lhs->kind != rhs->kind) return false;
+      return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
+    } else {
+      return false;
+    }
+  }
+
+  bool VisitType_(const FuncTypeNode* lhs, const Type& other) final {
+    if (const FuncTypeNode* rhs = other.as<FuncTypeNode>()) {
+      if (lhs->arg_types.size() != rhs->arg_types.size()) return false;
+      if (lhs->type_params.size() != rhs->type_params.size()) return false;
+      if (lhs->type_constraints.size() != rhs->type_constraints.size()) return false;
+      for (size_t i = 0; i < lhs->type_params.size(); ++i) {
+        if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) {
+          return false;
+        }
+        equal_map_[lhs->type_params[i]] = rhs->type_params[i];
+        // set up type parameter equal
+        if (lhs->type_params[i]->kind == TypeVarNode::Kind::kShapeVar) {
+          // map variable
+          equal_map_[lhs->type_params[i]->var] = rhs->type_params[i]->var;
+        }
+      }
+      for (size_t i = 0; i < lhs->arg_types.size(); i++) {
+        if (!TypeEqual(lhs->arg_types[i], rhs->arg_types[i])) return false;
+      }
+      if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
+      for (size_t i = 0; i < lhs->type_constraints.size(); i++) {
+        if (!TypeEqual(lhs->type_constraints[i],
+                       rhs->type_constraints[i])) {
+          return false;
+        }
+      }
+      return true;
+    } else {
+      return false;
+    }
+  }
+
+  bool VisitType_(const TypeRelationNode* lhs, const Type& other) final {
+    if (const TypeRelationNode* rhs = other.as<TypeRelationNode>()) {
+      if (!lhs->func.same_as(rhs->func)) return false;
+      if (lhs->num_inputs != rhs->num_inputs) return false;
+      if (!this->AttrEqual(lhs->attrs, rhs->attrs)) return false;
+      if (lhs->args.size() != rhs->args.size()) return false;
+      for (size_t i = 0; i < lhs->args.size(); ++i) {
+        if (!TypeEqual(lhs->args[i], rhs->args[i])) return false;
+      }
+      return true;
+    } else {
+      return false;
+    }
+  }
+
+  bool VisitType_(const TupleTypeNode* lhs, const Type& other) final {
+    if (const TupleTypeNode* rhs = other.as<TupleTypeNode>()) {
+      if (lhs->fields.size() != rhs->fields.size()) return false;
+      for (size_t i = 0; i < lhs->fields.size(); ++i) {
+        if (!TypeEqual(lhs->fields[i], rhs->fields[i])) return false;
+      }
+      return true;
+    } else {
+      return false;
+    }
+  }
+  // Expr equal checking.
+  bool NDArrayEqual(const runtime::NDArray& lhs,
+                    const runtime::NDArray& rhs) {
+    if (lhs.defined() != rhs.defined()) {
+      return false;
+    } else if (lhs.same_as(rhs)) {
+      return true;
+    } else {
+      auto ldt = lhs->dtype;
+      auto rdt = rhs->dtype;
+      CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
+      CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
+      if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
+        size_t data_size = runtime::GetDataSize(*lhs.operator->());
+        return std::memcmp(lhs->data, rhs->data, data_size) == 0;
+      } else {
+        return false;
+      }
+    }
+  }
+  // merge declaration of two variables together.
+  bool MergeVarDecl(const Var& lhs, const Var& rhs) {
+    if (lhs.same_as(rhs)) return true;
+    if (!lhs.defined() || !rhs.defined()) return false;
+    if (!TypeEqual(lhs->type_annotation,
+                   rhs->type_annotation)) return false;
+    CHECK(!equal_map_.count(lhs))
+        << "Duplicated declaration of variable " <<  lhs;
+    equal_map_[lhs] = rhs;
+    return true;
+  }
+
+  bool VisitExpr_(const VarNode* lhs, const Expr& other) final {
+    if (const VarNode* rhs = other.as<VarNode>()) {
+      if (lhs->name_hint != rhs->name_hint) return false;
+      if (!TypeEqual(lhs->type_annotation, rhs->type_annotation)) return false;
+      return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
+    } else {
+      return false;
+    }
+  }
+
+  bool VisitExpr_(const GlobalVarNode* lhs, const Expr& other) final {
+    if (const GlobalVarNode* rhs = other.as<GlobalVarNode>()) {
+      // use name equality for global var for now.
+      if (lhs->name_hint != rhs->name_hint) return false;
+      return true;
+    } else {
+      return false;
+    }
+  }
+
+  bool VisitExpr_(const TupleNode* lhs, const Expr& other) final {
+    if (const TupleNode* rhs = other.as<TupleNode>()) {
+      if (lhs->fields.size() != rhs->fields.size()) return false;
+      for (size_t i = 0; i < lhs->fields.size(); ++i) {
+        if (!ExprEqual(lhs->fields[i], rhs->fields[i])) return false;
+      }
+      return true;
+    } else {
+      return false;
+    }
+  }
+
+  bool VisitExpr_(const FunctionNode* lhs, const Expr& other) final {
+    if (const FunctionNode* rhs = other.as<FunctionNode>()) {
+      if (lhs->params.size() != rhs->params.size()) return false;
+      if (lhs->type_params.size() != rhs->type_params.size()) return false;
+      // map type parameter to be the same
+      for (size_t i = 0; i < lhs->type_params.size(); ++i) {
+        if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) return false;
+        equal_map_[lhs->type_params[i]] = rhs->type_params[i];
+      }
+      // check parameter type annotations
+      for (size_t i = 0; i < lhs->params.size(); ++i) {
+        if (!MergeVarDecl(lhs->params[i], rhs->params[i])) return false;
+      }
+      // check return types.
+      if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
+      return ExprEqual(lhs->body, rhs->body);
+    } else {
+      return false;
+    }
+  }
+
+  bool VisitExpr_(const CallNode* lhs, const Expr& other) final {
+    if (const CallNode* rhs = other.as<CallNode>()) {
+      if (!ExprEqual(lhs->op, rhs->op)) return false;
+      if (lhs->args.size() != rhs->args.size()) return false;
+      if (lhs->type_args.size() != rhs->type_args.size()) return false;
+
+      for (size_t i = 0; i < lhs->args.size(); ++i) {
+        if (!ExprEqual(lhs->args[i], rhs->args[i])) return false;
+      }
+      for (size_t i = 0; i < lhs->type_args.size(); ++i) {
+        if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;
+      }
+      return AttrEqual(lhs->attrs, rhs->attrs);
+    } else {
+      return false;
+    }
+  }
+
+  bool VisitExpr_(const LetNode* lhs, const Expr& other) final {
+    if (const LetNode* rhs = other.as<LetNode>()) {
+      if (!ExprEqual(lhs->value, rhs->value)) return false;
+      if (!MergeVarDecl(lhs->var, rhs->var)) return false;
+      return ExprEqual(lhs->body, rhs->body);
+    } else {
+      return false;
+    }
+  }
+
+  bool VisitExpr_(const IfNode* lhs, const Expr& other) final {
+    if (const IfNode* rhs = other.as<IfNode>()) {
+      return ExprEqual(lhs->cond, rhs->cond) &&
+          ExprEqual(lhs->true_branch, rhs->true_branch) &&
+          ExprEqual(lhs->false_branch, rhs->false_branch);
+    } else {
+      return false;
+    }
+  }
+
+  bool VisitExpr_(const OpNode* op, const Expr& other) final {
+    return op == other.get();
+  }
+
+  bool VisitExpr_(const ConstantNode* lhs, const Expr& other) final {
+    if (const ConstantNode* rhs = other.as<ConstantNode>()) {
+      return NDArrayEqual(lhs->data, rhs->data);
+    } else {
+      return false;
+    }
+  }
+
+  bool VisitExpr_(const TupleGetItemNode* lhs, const Expr& other) final {
+    if (const TupleGetItemNode* rhs = other.as<TupleGetItemNode>()) {
+      return ExprEqual(lhs->tuple, rhs->tuple) && lhs->index == rhs->index;
+    } else {
+      return false;
+    }
+  }
+
+ private:
+  // whether to map open terms.
+  bool map_free_var_{false};
+  // renaming of NodeRef to indicate two nodes equals to each other
+  std::unordered_map<NodeRef, NodeRef, NodeHash, NodeEqual> equal_map_;
+};
+
+bool AlphaEqual(const Type& lhs, const Type& rhs) {
+  return AlphaEqualHandler(false).TypeEqual(lhs, rhs);
+}
+
+bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
+  return AlphaEqualHandler(false).ExprEqual(lhs, rhs);
+}
+
+// TODO(@jroesch): move to correct namespace?
+TVM_REGISTER_API("relay._make._alpha_equal")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = AlphaEqualHandler(false).Equal(args[0], args[1]);
+  });
+
+TVM_REGISTER_API("relay._make._type_alpha_equal")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = AlphaEqualHandler(false).TypeEqual(args[0], args[1]);
+  });
+
+TVM_REGISTER_API("relay._make._graph_equal")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = AlphaEqualHandler(true).Equal(args[0], args[1]);
+  });
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc
index 66ef86641..0ebe111ab 100644
--- a/src/relay/ir/text_printer.cc
+++ b/src/relay/ir/text_printer.cc
@@ -6,7 +6,7 @@
 #include <tvm/relay/environment.h>
 #include <tvm/relay/expr_functor.h>
 #include <sstream>
-#include "../pass/type_functor.h"
+#include "type_functor.h"
 #include "../../lang/attr_functor.h"
 
 namespace tvm {
@@ -245,6 +245,9 @@ class TextPrinter :
         stream_ << ", ";
       }
     }
+    if (fields.size() == 1) {
+      stream_ << ',';
+    }
     stream_ << ')';
     this->PrintEndInst("\n");
     return id;
@@ -648,7 +651,7 @@ class TextPrinter :
       name = "%" + name;
     }
     TextValue val(GetUniqueName(name));
-    CHECK(!memo_.count(var));
+    CHECK(!memo_.count(var)) << "Duplicated variable " << var;
     memo_[var] = val;
     return val;
   }
diff --git a/src/relay/pass/type_visitor.h b/src/relay/ir/type_functor.h
similarity index 52%
rename from src/relay/pass/type_visitor.h
rename to src/relay/ir/type_functor.h
index c1b2c3e1a..03bb4db1f 100644
--- a/src/relay/pass/type_visitor.h
+++ b/src/relay/ir/type_functor.h
@@ -1,18 +1,97 @@
 /*!
  *  Copyright (c) 2018 by Contributors
- * \file type_visitor.h
- * \brief A wrapper around TypeFunctor for common use cases.
+ * \file type_functor.h
+ * \brief A way to defined arbitrary function signature with dispatch on types.
  */
-#ifndef TVM_RELAY_PASS_TYPE_VISITOR_H_
-#define TVM_RELAY_PASS_TYPE_VISITOR_H_
+#ifndef TVM_RELAY_IR_TYPE_FUNCTOR_H_
+#define TVM_RELAY_IR_TYPE_FUNCTOR_H_
 
+#include <tvm/node/ir_functor.h>
+#include <tvm/relay/expr.h>
+#include <string>
 #include <vector>
-#include "./type_functor.h"
 
 namespace tvm {
 namespace relay {
 
-/*! \brief A type visitor for vistiors which make use of internal
+template <typename FType>
+class TypeFunctor;
+
+// functions to be overriden.
+#define TYPE_FUNCTOR_DEFAULT \
+  { return VisitTypeDefault_(op, std::forward<Args>(args)...); }
+
+
+#define RELAY_TYPE_FUNCTOR_DISPATCH(OP)                                   \
+  vtable.template set_dispatch<OP>(                                       \
+      [](const NodeRef& n, TSelf* self, Args... args) {                   \
+        return self->VisitType_(static_cast<const OP*>(n.node_.get()),    \
+                                std::forward<Args>(args)...);             \
+      });
+
+template <typename R, typename... Args>
+class TypeFunctor<R(const Type& n, Args...)> {
+ private:
+  using TSelf = TypeFunctor<R(const Type& n, Args...)>;
+  using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;
+
+ public:
+  /*! \brief the result type of this functor */
+  using result_type = R;
+  /*! \brief virtual destructor */
+  virtual ~TypeFunctor() {}
+  /*!
+   * \brief Same as call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  R operator()(const Type& n, Args... args) {
+    return VisitType(n, std::forward<Args>(args)...);
+  }
+  /*!
+   * \brief The functor call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  virtual R VisitType(const Type& n, Args... args) {
+    static FType vtable = InitVTable();
+    return vtable(n, this, std::forward<Args>(args)...);
+  }
+  // Functions that can be overriden by subclass
+  virtual R VisitType_(const TensorTypeNode* op,
+                       Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const TypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+
+  virtual R VisitTypeDefault_(const Node* op, Args...) {
+    LOG(FATAL) << "Do not have a default for " << op->type_key();
+    throw;  // unreachable, written to stop compiler warning
+  }
+
+ private:
+  // initialize the vtable.
+  static FType InitVTable() {
+    FType vtable;
+    // Set dispatch
+    RELAY_TYPE_FUNCTOR_DISPATCH(TensorTypeNode);
+    RELAY_TYPE_FUNCTOR_DISPATCH(TypeVarNode);
+    RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode);
+    RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode);
+    RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
+    RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
+    RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
+    return vtable;
+  }
+};
+
+/*!
+ * \brief A type visitor for vistiors which make use of internal
  * mutable state.
  *
  * We recursively visit each type contained inside the visitor.
@@ -118,7 +197,6 @@ struct TypeMutator : TypeFunctor<Type(const Type& n)> {
     return GetRef<Type>(op);
   }
 };
-
 }  // namespace relay
 }  // namespace tvm
-#endif  // TVM_RELAY_PASS_TYPE_VISITOR_H_
+#endif  // TVM_RELAY_IR_TYPE_FUNCTOR_H_
diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc
deleted file mode 100644
index 41ec3f1e0..000000000
--- a/src/relay/pass/alpha_eq.cc
+++ /dev/null
@@ -1,418 +0,0 @@
-/*!
- *  Copyright (c) 2018 by Contributors
- * \file src/tvm/relay/pass/alpha_eq.cc
- * \brief Check that two type are syntactically equal up to alpha equivalence.
- */
-#include <tvm/ir_pass.h>
-#include <tvm/relay/expr_functor.h>
-#include <tvm/runtime/ndarray.h>
-#include "./type_visitor.h"
-#include "tvm/relay/pass.h"
-
-namespace tvm {
-namespace relay {
-
-using namespace tvm::runtime;
-
-bool SameNDArray(const NDArray& lhs, const NDArray& rhs) {
-  if (lhs.defined() != rhs.defined()) {
-    return false;
-  } else if (lhs.same_as(rhs)) {
-    return true;
-  } else {
-    auto ldt = lhs->dtype;
-    auto rdt = rhs->dtype;
-    CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
-    CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
-    if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
-      size_t s = GetDataSize(*lhs.operator->());
-      return memcmp(lhs->data, rhs->data, s) == 0;
-    } else {
-      return false;
-    }
-  }
-}
-
-struct TypeAlphaEq : TypeVisitor<const Type&> {
-  tvm::Map<TypeVar, TypeVar> eq_map;
-  bool equal;
-
-  TypeAlphaEq() : eq_map(), equal(true) {}
-
-  void DataTypeEqual(const DataType& dt1, const DataType& dt2) {
-    if (dt1 != dt2) {
-      equal = false;
-    }
-  }
-
-  void ShapeEqual(const Array<IndexExpr>& s1, const Array<IndexExpr>& s2) {
-    if (s1.size() != s2.size()) {
-      equal = false;
-      return;
-    }
-    for (size_t i = 0; i < s1.size(); ++i) {
-      if (!tvm::ir::Equal(s1[i], s2[i])) {
-        equal = false;
-        return;
-      }
-    }
-  }
-
-  void VisitType_(const TensorTypeNode* tt1, const Type& t2) final {
-    if (const TensorTypeNode* tt2 = t2.as<TensorTypeNode>()) {
-      DataTypeEqual(tt1->dtype, tt2->dtype);
-      ShapeEqual(tt1->shape, tt2->shape);
-    } else {
-      equal = false;
-    }
-  }
-
-  void VisitType_(const IncompleteTypeNode* bt1, const Type& t2) final {
-    if (const IncompleteTypeNode* bt2 = t2.as<IncompleteTypeNode>()) {
-      equal = equal && bt1 == bt2;
-      return;
-    } else {
-      equal = false;
-    }
-  }
-
-  void VisitType_(const TypeVarNode* ti1, const Type& t2) final {
-    if (const TypeVarNode* ti2 = t2.as<TypeVarNode>()) {
-      auto tid1 = GetRef<TypeVar>(ti1);
-      auto tid2 = GetRef<TypeVar>(ti2);
-
-      // We handle open terms with this rule assuming variables are identical.
-      //
-      // Not sure if we should do this.
-      if (tid1 == tid2) {
-        return;
-      }
-
-      // Check that they are same kind
-      if (tid1->kind != tid2->kind) {
-        equal = false;
-        return;
-      }
-
-      // Next we see if there is mapping for local1 into the rhs term.
-      // If there is we check to see if those are equal.
-      if (eq_map.find(tid1) != eq_map.end()) {
-        equal = equal && eq_map[tid1] == tid2;
-      } else {
-        equal = false;
-      }
-    } else {
-      equal = false;
-    }
-  }
-
-  void VisitType_(const FuncTypeNode* op, const Type& t2) final {
-    if (const FuncTypeNode* ta2 = t2.as<FuncTypeNode>()) {
-      if (op->arg_types.size() != ta2->arg_types.size()
-          || op->type_params.size() != ta2->type_params.size()
-          || op->type_constraints.size() != ta2->type_constraints.size()) {
-        equal = false;
-        return;
-      }
-
-      // must visit params first so they are appropriate entered
-      // into equality map
-      for (size_t i = 0; i < op->type_params.size(); i++) {
-        eq_map.Set(op->type_params[i], ta2->type_params[i]);
-        this->VisitType(op->type_params[i], ta2->type_params[i]);
-        if (!equal) {
-          return;
-        }
-      }
-
-      for (size_t i = 0; i < op->arg_types.size(); i++) {
-        this->VisitType(op->arg_types[i], ta2->arg_types[i]);
-        if (!equal) {
-          return;
-        }
-      }
-
-      this->VisitType(op->ret_type, ta2->ret_type);
-      if (!equal) {
-        return;
-      }
-
-      for (size_t i = 0; i < op->type_constraints.size(); i++) {
-        this->VisitType(op->type_constraints[i], ta2->type_constraints[i]);
-        if (!equal) {
-          return;
-        }
-      }
-    } else {
-      equal = false;
-    }
-  }
-
-  void VisitType_(const TypeRelationNode* tr1, const Type& t2) final {
-    if (const TypeRelationNode* tr2 = t2.as<TypeRelationNode>()) {
-      if (tr1->func != tr2->func
-          || tr1->num_inputs != tr2->num_inputs
-          || tr1->attrs != tr2->attrs) {
-        equal = false;
-        return;
-      }
-
-      if (tr1->args.size() != tr2->args.size()) {
-        equal = false;
-        return;
-      }
-
-      for (size_t i = 0; i < tr1->args.size(); i++) {
-        this->VisitType(tr1->args[i], tr2->args[i]);
-        if (!equal) {
-          return;
-        }
-      }
-    } else {
-      equal = false;
-    }
-  }
-
-  void VisitType_(const TupleTypeNode* op, const Type& t2) final {
-    if (const TupleTypeNode* pt = t2.as<TupleTypeNode>()) {
-      if (op->fields.size() != pt->fields.size()) {
-        equal = false;
-        return;
-      }
-
-      for (size_t i = 0U; i < op->fields.size(); i++) {
-        if (!equal) {
-          return;
-        }
-        this->VisitType(op->fields[i], pt->fields[i]);
-      }
-    } else {
-      equal = false;
-    }
-  }
-};
-
-bool AlphaEqual(const Type& t1, const Type& t2) {
-  if (t1.defined() != t2.defined()) {
-    return false;
-  }
-
-  if (!t1.defined()) {
-    return true;
-  }
-
-  TypeAlphaEq aeq;
-  aeq.VisitType(t1, t2);
-  return aeq.equal;
-}
-
-struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
- public:
-  tvm::Map<Var, Var> eq_map;
-
-  bool equal;
-  AlphaEq() : eq_map(), equal(true) {}
-
-  void VisitExpr_(const VarNode* e1, const Expr& e2) final {
-    if (const VarNode* id2 = e2.as<VarNode>()) {
-      auto local1 = GetRef<Var>(e1);
-      auto local2 = GetRef<Var>(id2);
-      // We handle open terms with this rule assuming variables are identical.
-      if (local1 == local2) {
-        equal = true;
-        return;
-      }
-
-      // Next we see if there is mapping for local1 into the rhs term.
-      // If there is we check to see if those are equal.
-      if (eq_map.find(local1) != eq_map.end()) {
-        equal = equal && eq_map[local1] == local2;
-      } else {
-        equal = false;
-      }
-    } else {
-      equal = false;
-    }
-  }
-
-  void VisitExpr_(const GlobalVarNode* g1, const Expr& e2) final {
-    if (const GlobalVarNode* g2 = e2.as<GlobalVarNode>()) {
-      equal = equal && g1 == g2;
-    } else {
-      equal = false;
-    }
-  }
-
-  void VisitExpr_(const TupleNode* pl1, const Expr& e2) final {
-    Tuple prod1 = GetRef<Tuple>(pl1);
-    if (const TupleNode* pl2 = e2.as<TupleNode>()) {
-      Tuple prod2 = GetRef<Tuple>(pl2);
-      if (prod1->fields.size() != prod2->fields.size()) {
-        equal = false;
-        return;
-      }
-
-      for (size_t i = 0U; i < prod1->fields.size(); i++) {
-        this->VisitExpr(prod1->fields[i], prod2->fields[i]);
-      }
-    } else {
-      equal = false;
-    }
-  }
-
-  void VisitExpr_(const FunctionNode* func1, const Expr& e2) final {
-    if (const FunctionNode* func2 = e2.as<FunctionNode>()) {
-      if (func1->params.size() != func2->params.size()) {
-        equal = false;
-        return;
-      }
-
-      if (func1->type_params.size() != func2->type_params.size()) {
-        equal = false;
-        return;
-      }
-
-      for (size_t i = 0; i < func1->params.size(); ++i) {
-        MergeVarDecl(func1->params[i], func2->params[i]);
-      }
-
-      if (!equal) {
-        return;
-      }
-
-      for (size_t i = 0U; i < func1->type_params.size(); i++) {
-        equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]);
-        if (!equal) {
-          return;
-        }
-      }
-
-      equal = equal && AlphaEqual(func1->ret_type, func2->ret_type);
-      if (!equal) {
-        return;
-      }
-
-      this->VisitExpr(func1->body, func2->body);
-    } else {
-      equal = false;
-    }
-  }
-
-  void VisitExpr_(const CallNode* op, const Expr& e2) final {
-    if (const CallNode* call = e2.as<CallNode>()) {
-      this->VisitExpr(op->op, call->op);
-
-      if (op->args.size() != call->args.size()) {
-        equal = false;
-        return;
-      }
-
-      if (op->type_args.size() != call->type_args.size()) {
-        equal = false;
-        return;
-      }
-
-      // checking attrs by pointer equality for now
-      equal = equal && (op->attrs == call->attrs);
-      if (!equal) {
-        return;
-      }
-
-      for (size_t i = 0U; i < op->args.size(); i++) {
-        this->VisitExpr(op->args[i], call->args[i]);
-      }
-
-      for (size_t i = 0U; i < op->type_args.size(); i++) {
-        equal = equal && AlphaEqual(op->type_args[i], call->type_args[i]);
-        if (!equal) {
-          return;
-        }
-      }
-    } else {
-      equal = false;
-    }
-  }
-
-  void VisitExpr_(const LetNode* op, const Expr& e2) final {
-    if (const LetNode* let = e2.as<LetNode>()) {
-      MergeVarDecl(op->var, let->var);
-      this->VisitExpr(op->value, let->value);
-      this->VisitExpr(op->body, let->body);
-    } else {
-      equal = false;
-    }
-  }
-
-  void VisitExpr_(const IfNode* op, const Expr& e2) final {
-    if (const IfNode* i = e2.as<IfNode>()) {
-      VisitExpr(op->cond, i->cond);
-      VisitExpr(op->true_branch, i->true_branch);
-      VisitExpr(op->false_branch, i->false_branch);
-    } else {
-      equal = false;
-    }
-  }
-
-  void VisitExpr_(const OpNode* op, const Expr& e2) final {
-    if (const OpNode* o = e2.as<OpNode>()) {
-      equal = equal && op->name == o->name;
-    } else {
-      equal = false;
-    }
-  }
-
-  void VisitExpr_(const ConstantNode* op, const Expr& e2) final {
-    if (const ConstantNode* c = e2.as<ConstantNode>()) {
-      if (AlphaEqual(op->tensor_type(), c->tensor_type())) {
-        equal = equal && SameNDArray(op->data, c->data);
-      } else {
-        equal = false;
-      }
-    } else {
-      equal = false;
-    }
-  }
-
-  void VisitExpr_(const TupleGetItemNode* op, const Expr& e2) final {
-    if (const TupleGetItemNode* proj = e2.as<TupleGetItemNode>()) {
-      this->VisitExpr(op->tuple, proj->tuple);
-      equal = equal && (op->index == proj->index);
-    } else {
-      equal = false;
-    }
-  }
-
- private:
-  void MergeVarDecl(const Var& var1, const Var& var2) {
-    equal = equal && AlphaEqual(var1->type_annotation, var2->type_annotation);
-    if (!equal) {
-      return;
-    }
-
-    eq_map.Set(var1, var2);
-  }
-};
-
-bool AlphaEqual(const Expr& e1, const Expr& e2) {
-  AlphaEq eq;
-  eq.VisitExpr(e1, e2);
-  return eq.equal;
-}
-
-// TODO(@jroesch): move to correct namespace?
-TVM_REGISTER_API("relay._make._alpha_equal")
-    .set_body([](TVMArgs args, TVMRetValue* ret) {
-      Expr e1 = args[0];
-      Expr e2 = args[1];
-      *ret = AlphaEqual(e1, e2);
-    });
-
-TVM_REGISTER_API("relay._make._type_alpha_equal")
-    .set_body([](TVMArgs args, TVMRetValue* ret) {
-      Type t1 = args[0];
-      Type t2 = args[1];
-      *ret = AlphaEqual(t1, t2);
-    });
-
-}  // namespace relay
-}  // namespace tvm
diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc
index 8fd77a71e..c3d16c297 100644
--- a/src/relay/pass/kind_check.cc
+++ b/src/relay/pass/kind_check.cc
@@ -14,7 +14,7 @@
  * contains a data type such as `int`, `float`, `uint`.
  */
 #include <tvm/relay/pass.h>
-#include "./type_visitor.h"
+#include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
@@ -105,13 +105,13 @@ bool KindCheck(const Type& t, const Environment& env) {
 }
 
 TVM_REGISTER_API("relay._ir_pass.check_kind")
-    .set_body([](TVMArgs args, TVMRetValue* ret) {
-      if (args.size() == 1) {
-        *ret = KindCheck(args[0], EnvironmentNode::make({}));
-      } else {
-        *ret = KindCheck(args[0], args[1]);
-      }
-    });
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    if (args.size() == 1) {
+      *ret = KindCheck(args[0], EnvironmentNode::make({}));
+    } else {
+      *ret = KindCheck(args[0], args[1]);
+    }
+  });
 
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/pass/type_functor.h b/src/relay/pass/type_functor.h
deleted file mode 100644
index b8eaa85a7..000000000
--- a/src/relay/pass/type_functor.h
+++ /dev/null
@@ -1,94 +0,0 @@
-/*!
- *  Copyright (c) 2018 by Contributors
- * \file type_functor.h
- * \brief A way to defined arbitrary function signature with dispatch on types.
- */
-#ifndef TVM_RELAY_PASS_TYPE_FUNCTOR_H_
-#define TVM_RELAY_PASS_TYPE_FUNCTOR_H_
-
-#include <tvm/node/ir_functor.h>
-#include <tvm/relay/expr.h>
-#include <string>
-
-namespace tvm {
-namespace relay {
-
-template <typename FType>
-class TypeFunctor;
-
-// functions to be overriden.
-#define TYPE_FUNCTOR_DEFAULT \
-  { return VisitTypeDefault_(op, std::forward<Args>(args)...); }
-
-
-#define RELAY_TYPE_FUNCTOR_DISPATCH(OP)                                   \
-  vtable.template set_dispatch<OP>(                                       \
-      [](const NodeRef& n, TSelf* self, Args... args) {                   \
-        return self->VisitType_(static_cast<const OP*>(n.node_.get()),    \
-                                std::forward<Args>(args)...);             \
-      });
-
-template <typename R, typename... Args>
-class TypeFunctor<R(const Type& n, Args...)> {
- private:
-  using TSelf = TypeFunctor<R(const Type& n, Args...)>;
-  using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;
-
- public:
-  /*! \brief the result type of this functor */
-  using result_type = R;
-  /*! \brief virtual destructor */
-  virtual ~TypeFunctor() {}
-  /*!
-   * \brief Same as call.
-   * \param n The expression node.
-   * \param args Additional arguments.
-   * \return The result of the call
-   */
-  R operator()(const Type& n, Args... args) {
-    return VisitType(n, std::forward<Args>(args)...);
-  }
-  /*!
-   * \brief The functor call.
-   * \param n The expression node.
-   * \param args Additional arguments.
-   * \return The result of the call
-   */
-  virtual R VisitType(const Type& n, Args... args) {
-    static FType vtable = InitVTable();
-    return vtable(n, this, std::forward<Args>(args)...);
-  }
-  // Functions that can be overriden by subclass
-  virtual R VisitType_(const TensorTypeNode* op,
-                       Args... args) TYPE_FUNCTOR_DEFAULT;
-  virtual R VisitType_(const TypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
-  virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
-  virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
-  virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
-  virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
-  virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
-
-  virtual R VisitTypeDefault_(const Node* op, Args...) {
-    LOG(FATAL) << "Do not have a default for " << op->type_key();
-    throw;  // unreachable, written to stop compiler warning
-  }
-
- private:
-  // initialize the vtable.
-  static FType InitVTable() {
-    FType vtable;
-    // Set dispatch
-    RELAY_TYPE_FUNCTOR_DISPATCH(TensorTypeNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(TypeVarNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
-    return vtable;
-  }
-};
-
-}  // namespace relay
-}  // namespace tvm
-#endif  // TVM_RELAY_PASS_TYPE_FUNCTOR_H_
diff --git a/src/relay/pass/type_subst.cc b/src/relay/pass/type_subst.cc
index bffd779d1..76507058f 100644
--- a/src/relay/pass/type_subst.cc
+++ b/src/relay/pass/type_subst.cc
@@ -4,7 +4,7 @@
  * \brief Function for substituting a concrete type in place of a type ID
  */
 #include "./type_subst.h"
-#include "./type_visitor.h"
+#include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc
index ff4bb55b7..d69f1bce7 100644
--- a/src/relay/pass/util.cc
+++ b/src/relay/pass/util.cc
@@ -7,7 +7,7 @@
  */
 #include <tvm/relay/pass.h>
 #include <tvm/relay/expr_functor.h>
-#include "./type_visitor.h"
+#include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py
index de4df7c84..d16c2df53 100644
--- a/tests/python/relay/test_pass_alpha_equal.py
+++ b/tests/python/relay/test_pass_alpha_equal.py
@@ -139,7 +139,8 @@ def test_type_relation_alpha_equal():
 
     # attrs are also compared only by pointer equality
     attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
-    attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
+    attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
+    attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4))
 
     tr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
     same = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
@@ -147,6 +148,7 @@ def test_type_relation_alpha_equal():
     diff_order = relay.TypeRelation(broadcast, tvm.convert([t2, t1]), 1, attr1)
     diff_args = relay.TypeRelation(broadcast, tvm.convert([t2, t3]), 1, attr1)
     diff_attr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr2)
+    same_attr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1_same)
 
     bigger = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 2, attr1)
     diff_num_inputs = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 1, attr2)
@@ -157,6 +159,7 @@ def test_type_relation_alpha_equal():
     assert tr != diff_order
     assert tr != diff_args
     assert tr != diff_attr
+    assert tr == same_attr
     assert tr != bigger
 
     assert bigger != diff_num_inputs
@@ -216,22 +219,26 @@ def test_global_var_alpha_equal():
 
 
 def test_tuple_alpha_equal():
+    v0 = relay.Var("v0")
     v1 = relay.Var("v1")
     v2 = relay.Var("v2")
 
     # unit value is a valid tuple
     assert alpha_equal(relay.Tuple([]), relay.Tuple([]))
 
-    tup = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])
-    same = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])
+    tup = relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])
+    same = relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])
 
     assert alpha_equal(tup, same)
 
     # use the eq_map
+
+
     let_tup = relay.Let(v1, tup, v1)
-    let_mapped = relay.Let(v2, relay.Tuple([v2, relay.const(2), relay.const(3),
+    let_mapped = relay.Let(v2, relay.Tuple([v0, relay.const(2), relay.const(3),
                                             relay.Tuple([relay.const(4)])]),
                            v2)
+
     assert alpha_equal(let_tup, let_mapped)
 
     more_fields = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)]), v2])
@@ -340,7 +347,8 @@ def test_call_alpha_equal():
 
     # attrs are compared only by pointer equality
     attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
-    attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
+    attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
+    attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4))
 
     tt1 = relay.TensorType((1, 2, 3), "float32")
     tt2 = relay.TensorType((), "int8")
@@ -375,6 +383,9 @@ def test_call_alpha_equal():
     different_attrs = relay.Call(v1, basic_args, attr2, [tt1])
     assert not alpha_equal(call, different_attrs)
 
+    same_attrs = relay.Call(v1, basic_args, attr1_same, [tt1])
+    assert alpha_equal(call, same_attrs)
+
     no_type_args = relay.Call(v1, basic_args, attr1)
     assert not alpha_equal(call, no_type_args)
 
@@ -445,6 +456,27 @@ def test_op_alpha_equal():
     assert not alpha_equal(op1, op3)
 
 
+def test_graph_equal():
+    x = relay.var("x")
+
+    y0 = relay.add(x, x)
+    z0 = relay.add(y0, y0)
+
+    y1 = relay.add(x, x)
+    z1 = relay.add(y1, y1)
+
+    z3 = relay.add(relay.add(x, x), relay.add(x, x))
+
+    assert alpha_equal(z0, z1)
+
+    # z3's dataflow format is different from z0
+    # z0 is computed from a common y0 node
+    # Relay view them as different programs
+    # Check the difference in the text format.
+    assert not alpha_equal(z0, z3)
+
+
+
 if __name__ == "__main__":
     test_tensor_type_alpha_equal()
     test_incomplete_type_alpha_equal()
@@ -462,3 +494,4 @@ if __name__ == "__main__":
     test_if_alpha_equal()
     test_op_alpha_equal()
     test_var_alpha_equal()
+    test_graph_equal()
diff --git a/tests/python/unittest/test_pass_attrs_hash_equal.py b/tests/python/unittest/test_pass_attrs_hash_equal.py
index 23f0e6374..2d6987aeb 100644
--- a/tests/python/unittest/test_pass_attrs_hash_equal.py
+++ b/tests/python/unittest/test_pass_attrs_hash_equal.py
@@ -17,6 +17,12 @@ def test_attrs_equal():
     assert tvm.ir_pass.AttrsEqual({"x": [x, x]}, {"x": [y, x]})
     assert not tvm.ir_pass.AttrsEqual({"x": [x, 1]}, {"x": [y, 2]})
 
+    n = tvm.var("n")
+    assert tvm.ir_pass.AttrsEqual({"x": n+1}, {"x": n+1})
+
+
+
+
 
 def test_attrs_hash():
     fhash = tvm.ir_pass.AttrsHash
-- 
GitLab