From 81db03abb91e510de70f91d6a0f69d1355f0f247 Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Tue, 23 Oct 2018 10:54:20 -0700 Subject: [PATCH] [RELAY] Refactor AlphaEqual to support deep comparison of Attrs. (#1958) --- include/tvm/attrs.h | 197 +++++---- src/api/api_pass.cc | 8 +- src/lang/attr_functor.h | 124 +++++- src/lang/attrs.cc | 316 ++++++++----- src/relay/ir/alpha_equal.cc | 384 ++++++++++++++++ src/relay/ir/text_printer.cc | 7 +- .../type_visitor.h => ir/type_functor.h} | 94 +++- src/relay/pass/alpha_eq.cc | 418 ------------------ src/relay/pass/kind_check.cc | 16 +- src/relay/pass/type_functor.h | 94 ---- src/relay/pass/type_subst.cc | 2 +- src/relay/pass/util.cc | 2 +- tests/python/relay/test_pass_alpha_equal.py | 43 +- .../unittest/test_pass_attrs_hash_equal.py | 6 + 14 files changed, 980 insertions(+), 731 deletions(-) create mode 100644 src/relay/ir/alpha_equal.cc rename src/relay/{pass/type_visitor.h => ir/type_functor.h} (52%) delete mode 100644 src/relay/pass/alpha_eq.cc delete mode 100644 src/relay/pass/type_functor.h diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index 3b7beaa37..33d84cece 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -108,6 +108,90 @@ class AttrFieldInfoNode : public Node { /*! \brief AttrFieldInfo */ TVM_DEFINE_NODE_REF(AttrFieldInfo, AttrFieldInfoNode); +class AttrsHashHandler; +class AttrsEqualHandler; +/*! + * \brief Content-aware Equality comparator for attrs. + * + * This comparator will recursively deep compare the following Attributes. + * + * - IntImm, UIntImm, FloatImm, StringImm + * - Any subclass of BaseAttrsNode + * - Array of Attributes. + * - Map from string to Attributes. + */ +class AttrsEqual { + public: + bool operator()(const double& lhs, const double& rhs) const { + return lhs == rhs; + } + bool operator()(const int64_t& lhs, const int64_t& rhs) const { + return lhs == rhs; + } + bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { + return lhs == rhs; + } + bool operator()(const int& lhs, const int& rhs) const { + return lhs == rhs; + } + bool operator()(const bool& lhs, const bool& rhs) const { + return lhs == rhs; + } + bool operator()(const std::string& lhs, const std::string& rhs) const { + return lhs == rhs; + } + bool operator()(const Type& lhs, const Type& rhs) const { + return lhs == rhs; + } + // node comparator + TVM_DLL bool operator()(const NodeRef& lhs, const NodeRef& rhs) const; + + protected: + friend class AttrsEqualHandler; + /*! \brief internal handle. */ + AttrsEqualHandler* handler_{nullptr}; +}; + +/*! + * \brief Content-aware hash function. + * + * This hash functor will recursively hash the content of the Attributes. + * It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b); + */ +class AttrsHash { + public: + size_t operator()(const double& value) const { + return std::hash<double>()(value); + } + size_t operator()(const int64_t& value) const { + return std::hash<int64_t>()(value); + } + size_t operator()(const uint64_t& value) const { + return std::hash<uint64_t>()(value); + } + size_t operator()(const int& value) const { + return std::hash<int>()(value); + } + size_t operator()(const bool& value) const { + return std::hash<bool>()(value); + } + size_t operator()(const std::string& value) const { + return std::hash<std::string>()(value); + } + size_t operator()(const Type& value) const { + return std::hash<int>()( + static_cast<int>(value.code()) | + (static_cast<int>(value.bits()) << 8) | + (static_cast<int>(value.lanes()) << 16)); + } + TVM_DLL size_t operator()(const NodeRef& value) const; + + private: + friend class AttrsHashHandler; + /*! \brief internal handle. */ + AttrsHashHandler* handler_{nullptr}; +}; + /*! * \brief Base class of all attribute class * \note Do not subclass AttrBaseNode directly, @@ -153,14 +237,17 @@ class BaseAttrsNode : public Node { /*! * \brief Whether this attribute's content equals to another node. * \param other The pointer to another node. + * \param equal The equal comparator * \return The comparison result. */ - TVM_DLL virtual bool ContentEqual(const Node* other) const = 0; + TVM_DLL virtual bool ContentEqual( + const Node* other, AttrsEqual equal) const = 0; /*! * \brief Content aware hash. + * \param hasher The hasher to run the hash. * \return the hash result. */ - TVM_DLL virtual size_t ContentHash() const = 0; + TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0; static constexpr const char* _type_key = "Attrs"; TVM_DECLARE_BASE_NODE_INFO(BaseAttrsNode, Node); @@ -209,92 +296,13 @@ class DictAttrsNode : public BaseAttrsNode { void VisitNonDefaultAttrs(AttrVisitor* v) final; void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; Array<AttrFieldInfo> ListFieldInfo() const final; - bool ContentEqual(const Node* other) const final; - size_t ContentHash() const final; + bool ContentEqual(const Node* other, AttrsEqual equal) const final; + size_t ContentHash(AttrsHash hasher) const final; // type info static constexpr const char* _type_key = "DictAttrs"; TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode); }; -/*! - * \brief Content-aware Equality comparator for attrs. - * - * This comparator will recursively deep compare the following Attributes. - * - * - IntImm, UIntImm, FloatImm, StringImm - * - Any subclass of BaseAttrsNode - * - Array of Attributes. - * - Map from string to Attributes. - */ -class AttrsEqual { - public: - bool operator()(const double& lhs, const double& rhs) const { - return lhs == rhs; - } - bool operator()(const int64_t& lhs, const int64_t& rhs) const { - return lhs == rhs; - } - bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { - return lhs == rhs; - } - bool operator()(const int& lhs, const int& rhs) const { - return lhs == rhs; - } - bool operator()(const bool& lhs, const bool& rhs) const { - return lhs == rhs; - } - bool operator()(const std::string& lhs, const std::string& rhs) const { - return lhs == rhs; - } - bool operator()(const Type& lhs, const Type& rhs) const { - return lhs == rhs; - } - bool operator()(const NodeRef& lhs, const NodeRef& rhs) const { - return AttrsEqual::Equal(lhs, rhs); - } - - // comparator of NodeRef types. - static TVM_DLL bool Equal(const NodeRef& lhs, const NodeRef& rhs); -}; - -/*! - * \brief Content-aware hash function. - * - * This hash functor will recursively hash the content of the Attributes. - * It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b); - */ -class AttrsHash { - public: - size_t operator()(const double& value) const { - return std::hash<double>()(value); - } - size_t operator()(const int64_t& value) const { - return std::hash<int64_t>()(value); - } - size_t operator()(const uint64_t& value) const { - return std::hash<uint64_t>()(value); - } - size_t operator()(const int& value) const { - return std::hash<int>()(value); - } - size_t operator()(const bool& value) const { - return std::hash<bool>()(value); - } - size_t operator()(const std::string& value) const { - return std::hash<std::string>()(value); - } - size_t operator()(const Type& value) const { - return std::hash<int>()( - static_cast<int>(value.code()) | - (static_cast<int>(value.bits()) << 8) | - (static_cast<int>(value.lanes()) << 16)); - } - size_t operator()(const NodeRef& value) const { - return AttrsHash::Hash(value); - } - // hash function of the attribute and attribute fields. - static TVM_DLL size_t Hash(const NodeRef& lhs); -}; // Namespace containing detail implementations namespace detail { @@ -342,8 +350,8 @@ class AttrsEqualVisitor { public: bool result_{true}; // constructor - AttrsEqualVisitor(const Node* lhs, const Node* rhs) - : lhs_(lhs), rhs_(rhs) { + AttrsEqualVisitor(const Node* lhs, const Node* rhs, const AttrsEqual& equal) + : lhs_(lhs), rhs_(rhs), equal_(equal) { } template<typename T> AttrNopEntry operator()(const char* key, T* lhs_value) { @@ -353,7 +361,7 @@ class AttrsEqualVisitor { reinterpret_cast<const char*>(rhs_) + (reinterpret_cast<const char*>(lhs_value) - reinterpret_cast<const char*>(lhs_))); - if (!AttrsEqual()(*lhs_value, *rhs_value)) { + if (!equal_(*lhs_value, *rhs_value)) { result_ = false; } return AttrNopEntry(); @@ -362,17 +370,24 @@ class AttrsEqualVisitor { private: const Node* lhs_; const Node* rhs_; + const AttrsEqual& equal_; }; class AttrsHashVisitor { public: + explicit AttrsHashVisitor(const AttrsHash& hasher) + : hasher_(hasher) {} + size_t result_{0}; template<typename T> AttrNopEntry operator()(const char* key, T* value) { - result_ = dmlc::HashCombine(result_, AttrsHash()(*value)); + result_ = dmlc::HashCombine(result_, hasher_(*value)); return AttrNopEntry(); } + + private: + const AttrsHash& hasher_; }; // helper entry that does initialization, set default. @@ -793,18 +808,18 @@ class AttrsNode : public BaseAttrsNode { return visitor.fields_; } - bool ContentEqual(const Node* other) const final { + bool ContentEqual(const Node* other, AttrsEqual equal) const final { DerivedType* pself = self(); if (pself == other) return true; if (other == nullptr) return false; if (pself->type_index() != other->type_index()) return false; - detail::AttrsEqualVisitor visitor(pself, other); + detail::AttrsEqualVisitor visitor(pself, other, equal); self()->__VisitAttrs__(visitor); return visitor.result_; } - size_t ContentHash() const final { - detail::AttrsHashVisitor visitor; + size_t ContentHash(AttrsHash hasher) const final { + detail::AttrsHashVisitor visitor(hasher); visitor.result_ = std::hash<std::string>()(this->type_key()); self()->__VisitAttrs__(visitor); return visitor.result_; diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 66e4529ac..1e571ca0d 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -68,10 +68,14 @@ TVM_REGISTER_API("ir_pass.Equal") TVM_REGISTER_API("ir_pass.AttrsEqual") -.set_body_typed<bool(const NodeRef&, const NodeRef&)>(AttrsEqual::Equal); +.set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) { + return AttrsEqual()(lhs, rhs); + }); TVM_REGISTER_API("ir_pass.AttrsHash") -.set_body_typed<int64_t(const NodeRef&)>(AttrsHash::Hash); +.set_body_typed<int64_t(const NodeRef&)>([](const NodeRef &node) { + return AttrsHash()(node); + }); TVM_REGISTER_API("ir_pass.ExprUseVar") diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h index 8aa39a774..ef1d06101 100644 --- a/src/lang/attr_functor.h +++ b/src/lang/attr_functor.h @@ -52,13 +52,33 @@ class AttrFunctor<R(const NodeRef& n, Args...)> { return VisitAttrDefault_(n.get(), std::forward<Args>(args)...); } } + virtual R VisitAttrDefault_(const Node* node, Args... args) = 0; virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::IntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::UIntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::FloatImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::StringImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttrDefault_(const Node* node, Args... args) = 0; + // deep comparison of symbolic integer expressions. + virtual R VisitAttr_(const Variable* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::Add* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::Sub* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::GE* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::GT* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::LT* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::LE* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::EQ* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::NE* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::And* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::Or* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::Not* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::Cast* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::Call* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::Select* op, Args... args) ATTR_FUNCTOR_DEFAULT; private: // initialize the vtable. @@ -72,9 +92,111 @@ class AttrFunctor<R(const NodeRef& n, Args...)> { ATTR_FUNCTOR_DISPATCH(UIntImm); ATTR_FUNCTOR_DISPATCH(FloatImm); ATTR_FUNCTOR_DISPATCH(StringImm); + ATTR_FUNCTOR_DISPATCH(Variable); + ATTR_FUNCTOR_DISPATCH(Add); + ATTR_FUNCTOR_DISPATCH(Sub); + ATTR_FUNCTOR_DISPATCH(Mul); + ATTR_FUNCTOR_DISPATCH(Min); + ATTR_FUNCTOR_DISPATCH(Max); + ATTR_FUNCTOR_DISPATCH(GE); + ATTR_FUNCTOR_DISPATCH(GT); + ATTR_FUNCTOR_DISPATCH(LE); + ATTR_FUNCTOR_DISPATCH(LT); + ATTR_FUNCTOR_DISPATCH(EQ); + ATTR_FUNCTOR_DISPATCH(NE); + ATTR_FUNCTOR_DISPATCH(And); + ATTR_FUNCTOR_DISPATCH(Or); + ATTR_FUNCTOR_DISPATCH(Not); + ATTR_FUNCTOR_DISPATCH(Cast); + ATTR_FUNCTOR_DISPATCH(Call); + ATTR_FUNCTOR_DISPATCH(Select); return vtable; } }; +class AttrsEqualHandler : + protected AttrFunctor<bool(const NodeRef&, const NodeRef&)> { + public: + /*! + * \brief Check if lhs equals rhs + * \param lhs The left operand. + * \param rhs The right operand. + */ + bool Equal(const NodeRef& lhs, const NodeRef& rhs); + + protected: + bool VisitAttrDefault_(const Node* lhs, const NodeRef& other) final; + bool VisitAttr_(const ArrayNode* lhs, const NodeRef& other) final; + bool VisitAttr_(const StrMapNode* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::IntImm* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::UIntImm* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::FloatImm* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::StringImm* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::Add* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::Sub* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::GE* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::GT* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::LT* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::LE* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::EQ* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::NE* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::And* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::Or* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::Not* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::Cast* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::Call* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::Select* lhs, const NodeRef& other) final; +}; + +class AttrsHashHandler : + protected AttrFunctor<size_t(const NodeRef&)> { + public: + /*! + * \brief Get hash value of node + * \param node The node to be hashed. + */ + size_t Hash(const NodeRef& node) { + return this->VisitAttr(node); + } + + protected: + size_t VisitAttrDefault_(const Node* lhs) final; + size_t VisitAttr_(const ir::IntImm* lhs) final; + size_t VisitAttr_(const ir::UIntImm* lhs) final; + size_t VisitAttr_(const ir::FloatImm* lhs) final; + size_t VisitAttr_(const ir::StringImm* lhs) final; + size_t VisitAttr_(const ArrayNode* lhs) final; + size_t VisitAttr_(const StrMapNode* lhs) final; + size_t VisitAttr_(const ir::Add* op) final; + size_t VisitAttr_(const ir::Sub* op) final; + size_t VisitAttr_(const ir::Mul* op) final; + size_t VisitAttr_(const ir::Mod* op) final; + size_t VisitAttr_(const ir::Min* op) final; + size_t VisitAttr_(const ir::Max* op) final; + size_t VisitAttr_(const ir::GE* op) final; + size_t VisitAttr_(const ir::GT* op) final; + size_t VisitAttr_(const ir::LE* op) final; + size_t VisitAttr_(const ir::LT* op) final; + size_t VisitAttr_(const ir::EQ* op) final; + size_t VisitAttr_(const ir::NE* op) final; + size_t VisitAttr_(const ir::And* op) final; + size_t VisitAttr_(const ir::Or* op) final; + size_t VisitAttr_(const ir::Not* op) final; + size_t VisitAttr_(const ir::Cast* op) final; + size_t VisitAttr_(const ir::Call* op) final; + size_t VisitAttr_(const ir::Select* op) final; + /*! + * \brief alias of dmlc::HashCombine + * \param lhs The first hash value. + * \param rhs The second hash value. + */ + static size_t Combine(size_t lhs, size_t rhs) { + return dmlc::HashCombine(lhs, rhs); + } +}; } // namespace tvm #endif // TVM_LANG_ATTR_FUNCTOR_H_ diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index e467018ad..9aa067c09 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -51,156 +51,272 @@ TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode); using namespace ir; +// Equal handler. +bool AttrsEqualHandler::Equal(const NodeRef& lhs, const NodeRef& rhs) { + if (lhs.same_as(rhs)) return true; + if (!lhs.defined() || !rhs.defined()) return false; + return this->VisitAttr(lhs, rhs); +} -class AttrsEqualChecker : - public AttrFunctor<bool(const NodeRef&, const NodeRef&)> { - public: - bool Check(const NodeRef& lhs, const NodeRef& rhs) { - if (!equal_) return false; - if (lhs.same_as(rhs)) return true; - if (!lhs.defined() || !rhs.defined()) return false; - if (!this->VisitAttr(lhs, rhs)) { - equal_ = false; - } - return equal_; +bool AttrsEqualHandler::VisitAttrDefault_(const Node* lhs, const NodeRef& other) { + if (lhs->derived_from<BaseAttrsNode>()) { + AttrsEqual equal; + equal.handler_ = this; + return static_cast<const BaseAttrsNode*>(lhs)->ContentEqual( + other.get(), equal); } + return lhs == other.get(); +} - bool VisitAttrDefault_(const Node* lhs, const NodeRef& other) final { - if (lhs->derived_from<BaseAttrsNode>()) { - return static_cast<const BaseAttrsNode*>(lhs)->ContentEqual(other.get()); - } - return lhs == other.get(); +bool AttrsEqualHandler::VisitAttr_(const IntImm* lhs, const NodeRef& other) { + if (const auto* rhs = other.as<IntImm>()) { + return lhs->value == rhs->value; } + return false; +} - bool VisitAttr_(const IntImm* lhs, const NodeRef& other) final { - if (const auto* rhs = other.as<IntImm>()) { - return lhs->value == rhs->value; - } - return false; +bool AttrsEqualHandler::VisitAttr_(const UIntImm* lhs, const NodeRef& other) { + if (const auto* rhs = other.as<UIntImm>()) { + return lhs->value == rhs->value; } + return false; +} - bool VisitAttr_(const UIntImm* lhs, const NodeRef& other) final { - if (const auto* rhs = other.as<UIntImm>()) { - return lhs->value == rhs->value; - } - return false; +bool AttrsEqualHandler::VisitAttr_(const FloatImm* lhs, const NodeRef& other) { + if (const auto* rhs = other.as<FloatImm>()) { + return lhs->value == rhs->value; } + return false; +} - bool VisitAttr_(const FloatImm* lhs, const NodeRef& other) final { - if (const auto* rhs = other.as<FloatImm>()) { - return lhs->value == rhs->value; - } - return false; +bool AttrsEqualHandler::VisitAttr_(const StringImm* lhs, const NodeRef& other) { + if (const auto* rhs = other.as<StringImm>()) { + return lhs->value == rhs->value; } + return false; +} - bool VisitAttr_(const StringImm* lhs, const NodeRef& other) final { - if (const auto* rhs = other.as<StringImm>()) { - return lhs->value == rhs->value; +bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const NodeRef& other) { + if (const auto* rhs = other.as<ArrayNode>()) { + if (rhs->data.size() != lhs->data.size()) return false; + for (size_t i = 0; i < lhs->data.size(); ++i) { + if (!Equal(NodeRef(lhs->data[i]), NodeRef(rhs->data[i]))) return false; } - return false; } + return true; +} - bool VisitAttr_(const ArrayNode* lhs, const NodeRef& other) final { - if (const auto* rhs = other.as<ArrayNode>()) { - if (rhs->data.size() != lhs->data.size()) return false; - for (size_t i = 0; i < lhs->data.size(); ++i) { - if (!Check(NodeRef(lhs->data[i]), NodeRef(rhs->data[i]))) return false; - } +bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const NodeRef& other) { + if (const auto* rhs = other.as<StrMapNode>()) { + if (rhs->data.size() != lhs->data.size()) return false; + for (const auto& kv : lhs->data) { + auto it = rhs->data.find(kv.first); + if (it == rhs->data.end()) return false; + if (!Equal(NodeRef(kv.second), NodeRef(it->second))) return false; } - return true; } + return true; +} - bool VisitAttr_(const StrMapNode* lhs, const NodeRef& other) final { - if (const auto* rhs = other.as<StrMapNode>()) { - if (rhs->data.size() != lhs->data.size()) return false; - for (const auto& kv : lhs->data) { - auto it = rhs->data.find(kv.first); - if (it == rhs->data.end()) return false; - if (!Check(NodeRef(kv.second), NodeRef(it->second))) return false; - } - } - return true; +#define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName) \ + bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const NodeRef& other) { \ + if (const auto* rhs = other.as<NodeName>()) { \ + if (!Equal(lhs->a, rhs->a)) return false; \ + if (!Equal(lhs->b, rhs->b)) return false; \ + return true; \ + } else { \ + return false; \ + } \ + } \ + +TVM_DEFINE_ATTRS_BINOP_EQUAL(Add); +TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub); +TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul); +TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod); +TVM_DEFINE_ATTRS_BINOP_EQUAL(Max); +TVM_DEFINE_ATTRS_BINOP_EQUAL(Min); +TVM_DEFINE_ATTRS_BINOP_EQUAL(GE); +TVM_DEFINE_ATTRS_BINOP_EQUAL(GT); +TVM_DEFINE_ATTRS_BINOP_EQUAL(LE); +TVM_DEFINE_ATTRS_BINOP_EQUAL(LT); +TVM_DEFINE_ATTRS_BINOP_EQUAL(EQ); +TVM_DEFINE_ATTRS_BINOP_EQUAL(NE); +TVM_DEFINE_ATTRS_BINOP_EQUAL(And); +TVM_DEFINE_ATTRS_BINOP_EQUAL(Or); + +bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const NodeRef& other) { + if (const auto* rhs = other.as<Not>()) { + return Equal(lhs->a, rhs->a); + } else { + return false; } +} - private: - bool equal_{true}; -}; - -class AttrContentHasher : - public AttrFunctor<void(const NodeRef&)> { - public: - size_t result_{0}; - - void VisitAttrDefault_(const Node* value) final { - if (value->derived_from<BaseAttrsNode>()) { - Update(static_cast<const BaseAttrsNode*>(value)->ContentHash()); - } else { - Update(NodeHash()(GetRef<NodeRef>(value))); - } +bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const NodeRef& other) { + if (const auto* rhs = other.as<Cast>()) { + if (lhs->type != rhs->type) return false; + return Equal(lhs->value, rhs->value); + } else { + return false; } +} - void VisitAttr_(const IntImm* op) final { - Update(std::hash<int64_t>()(op->value)); +bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const NodeRef& other) { + if (const auto* rhs = other.as<Call>()) { + return + lhs->name == rhs->name && + lhs->type == rhs->type && + lhs->call_type == rhs->call_type && + Equal(lhs->args, rhs->args); + } else { + return false; } +} - void VisitAttr_(const UIntImm* op) final { - Update(std::hash<uint64_t>()(op->value)); +bool AttrsEqualHandler::VisitAttr_(const Select* lhs, const NodeRef& other) { + if (const auto* rhs = other.as<Select>()) { + return + Equal(lhs->condition, rhs->condition) && + Equal(lhs->true_value, rhs->true_value) && + Equal(lhs->false_value, rhs->false_value); + } else { + return false; } +} - void VisitAttr_(const FloatImm* op) final { - Update(std::hash<double>()(op->value)); +// Hash Handler. +size_t AttrsHashHandler::VisitAttrDefault_(const Node* value) { + if (value->derived_from<BaseAttrsNode>()) { + AttrsHash hasher; + hasher.handler_ = this; + return static_cast<const BaseAttrsNode*>(value)->ContentHash(hasher); + } else { + return NodeHash()(GetRef<NodeRef>(value)); } +} - void VisitAttr_(const StringImm* op) final { - Update(std::hash<std::string>()(op->value)); - } +size_t AttrsHashHandler::VisitAttr_(const IntImm* op) { + return std::hash<int64_t>()(op->value); +} - void VisitAttr_(const ArrayNode* op) final { - Update(op->data.size()); - for (size_t i = 0; i < op->data.size(); ++i) { - this->VisitAttr(NodeRef(op->data[i])); - } +size_t AttrsHashHandler::VisitAttr_(const UIntImm* op) { + return std::hash<uint64_t>()(op->value); +} + +size_t AttrsHashHandler::VisitAttr_(const FloatImm* op) { + return std::hash<double>()(op->value); +} + +size_t AttrsHashHandler::VisitAttr_(const StringImm* op) { + return std::hash<std::string>()(op->value); +} + +size_t AttrsHashHandler::VisitAttr_(const ArrayNode* op) { + size_t result = op->data.size(); + for (size_t i = 0; i < op->data.size(); ++i) { + result = Combine(result, this->Hash(NodeRef(op->data[i]))); } + return result; +} - void VisitAttr_(const StrMapNode* lhs) final { +size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) { using Entry = std::pair<std::string, NodePtr<Node> >; std::vector<Entry> data(lhs->data.begin(), lhs->data.end()); std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) { return a.first < b.first; }); + size_t result = 0; for (const Entry& kv : data) { - Update(std::hash<std::string>()(kv.first)); - this->VisitAttr(NodeRef(kv.second)); + result = Combine(result, std::hash<std::string>()(kv.first)); + result = Combine(result, this->Hash(NodeRef(kv.second))); } - } + return result; +} - void Update(size_t value) { - result_ = dmlc::HashCombine(result_, value); - } -}; -bool AttrsEqual::Equal(const NodeRef& lhs, const NodeRef& rhs) { +#define TVM_DEFINE_ATTRS_BINOP_HASH(NodeName) \ + size_t AttrsHashHandler::VisitAttr_(const NodeName* op) { \ + static size_t key = std::hash<std::string>()(NodeName::_type_key); \ + return Combine(key, Combine(Hash(op->a), Hash(op->b))); \ + } \ + +TVM_DEFINE_ATTRS_BINOP_HASH(Add); +TVM_DEFINE_ATTRS_BINOP_HASH(Sub); +TVM_DEFINE_ATTRS_BINOP_HASH(Mul); +TVM_DEFINE_ATTRS_BINOP_HASH(Mod); +TVM_DEFINE_ATTRS_BINOP_HASH(Max); +TVM_DEFINE_ATTRS_BINOP_HASH(Min); +TVM_DEFINE_ATTRS_BINOP_HASH(GE); +TVM_DEFINE_ATTRS_BINOP_HASH(GT); +TVM_DEFINE_ATTRS_BINOP_HASH(LE); +TVM_DEFINE_ATTRS_BINOP_HASH(LT); +TVM_DEFINE_ATTRS_BINOP_HASH(EQ); +TVM_DEFINE_ATTRS_BINOP_HASH(NE); +TVM_DEFINE_ATTRS_BINOP_HASH(And); +TVM_DEFINE_ATTRS_BINOP_HASH(Or); + +size_t AttrsHashHandler::VisitAttr_(const Not* op) { + static size_t key = std::hash<std::string>()(Not::_type_key); + return Combine(key, Hash(op->a)); +} + +size_t AttrsHashHandler::VisitAttr_(const Cast* op) { + static size_t key = std::hash<std::string>()(Cast::_type_key); + AttrsHash hasher; + size_t res = key; + res = Combine(res, hasher(op->type)); + res = Combine(res, Hash(op->value)); + return res; +} + +size_t AttrsHashHandler::VisitAttr_(const Call* op) { + static size_t key = std::hash<std::string>()(Call::_type_key); + AttrsHash hasher; + size_t res = key; + res = Combine(res, hasher(op->name)); + res = Combine(res, hasher(op->type)); + res = Combine(res, Hash(op->args)); + return res; +} + +size_t AttrsHashHandler::VisitAttr_(const Select* op) { + static size_t key = std::hash<std::string>()(Select::_type_key); + size_t res = key; + res = Combine(res, Hash(op->condition)); + res = Combine(res, Hash(op->true_value)); + res = Combine(res, Hash(op->false_value)); + return res; +} + + +// Default case +bool AttrsEqual::operator()(const NodeRef& lhs, const NodeRef& rhs) const { if (lhs.same_as(rhs)) return true; - AttrsEqualChecker checker; - return checker.Check(lhs, rhs); + if (handler_ == nullptr) { + return AttrsEqualHandler().Equal(lhs, rhs); + } else { + return handler_->Equal(lhs, rhs); + } } -size_t AttrsHash::Hash(const NodeRef& node) { +size_t AttrsHash::operator()(const NodeRef& node) const { if (!node.defined()) return 0; - AttrContentHasher hasher; - hasher.VisitAttr(node); - return hasher.result_; + if (handler_ == nullptr) { + return AttrsHashHandler().Hash(node); + } else { + return handler_->Hash(node); + } } -size_t DictAttrsNode::ContentHash() const { - return AttrsHash()(this->dict); +size_t DictAttrsNode::ContentHash(AttrsHash hasher) const { + return hasher(this->dict); } -bool DictAttrsNode::ContentEqual(const Node* other) const { +bool DictAttrsNode::ContentEqual(const Node* other, AttrsEqual equal) const { if (this == other) return true; if (other == nullptr) return false; if (this->type_index() != other->type_index()) return false; - return AttrsEqual()(this->dict, static_cast<const DictAttrsNode*>(other)->dict); + return equal(this->dict, static_cast<const DictAttrsNode*>(other)->dict); } } // namespace tvm diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc new file mode 100644 index 000000000..f227970fc --- /dev/null +++ b/src/relay/ir/alpha_equal.cc @@ -0,0 +1,384 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/relay/ir/alpha_equal.cc + * \brief Alpha equality check by deep comparing two nodes. + */ +#include <tvm/ir_pass.h> +#include <tvm/relay/expr_functor.h> +#include <tvm/runtime/ndarray.h> +#include <tvm/relay/pass.h> +#include "type_functor.h" +#include "../../lang/attr_functor.h" + +namespace tvm { +namespace relay { + +// Alpha equal handler for relay. +class AlphaEqualHandler: + public AttrsEqualHandler, + public TypeFunctor<bool(const Type&, const Type&)>, + public ExprFunctor<bool(const Expr&, const Expr&)> { + public: + explicit AlphaEqualHandler(bool map_free_var) + : map_free_var_(map_free_var) {} + + /*! + * Check equality of two nodes. + * \param lhs The left hand operand. + * \param rhs The right hand operand. + * \return the compare result. + */ + bool Equal(const NodeRef& lhs, const NodeRef& rhs) { + if (lhs.same_as(rhs)) return true; + if (!lhs.defined() || !rhs.defined()) return false; + if (lhs->derived_from<TypeNode>()) { + if (!rhs->derived_from<TypeNode>()) return false; + return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs)); + } + if (lhs->derived_from<ExprNode>()) { + if (!rhs->derived_from<ExprNode>()) return false; + return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs)); + } + return AttrEqual(lhs, rhs); + } + + /*! + * Check equality of two attributes. + * \param lhs The left hand operand. + * \param rhs The right hand operand. + * \return the compare result. + */ + bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) { + return AttrsEqualHandler::Equal(lhs, rhs); + } + /*! + * Check equality of two types. + * \param lhs The left hand operand. + * \param rhs The right hand operand. + * \return the compare result. + */ + bool TypeEqual(const Type& lhs, const Type& rhs) { + if (lhs.same_as(rhs)) return true; + if (!lhs.defined() || !rhs.defined()) return false; + return this->VisitType(lhs, rhs); + } + /*! + * Check equality of two expressions. + * + * \note We run graph structural equality checking when comparing two Exprs. + * This means that AlphaEqualHandler can only be used once for each pair. + * The equality checker checks data-flow equvalence of the Expr DAG. + * This function also runs faster as it memomizes equal_map. + * + * \param lhs The left hand operand. + * \param rhs The right hand operand. + * \return the compare result. + */ + bool ExprEqual(const Expr& lhs, const Expr& rhs) { + if (lhs.same_as(rhs)) return true; + if (!lhs.defined() || !rhs.defined()) return false; + auto it = equal_map_.find(lhs); + if (it != equal_map_.end()) { + return it->second.same_as(rhs); + } + if (this->VisitExpr(lhs, rhs)) { + equal_map_[lhs] = rhs; + return true; + } else { + return false; + } + } + + protected: + /*! + * \brief Check if data type equals each other. + * \param lhs The left hand operand. + * \param rhs The right hand operand. + * \return the compare result. + */ + bool DataTypeEqual(const DataType& lhs, const DataType& rhs) { + return lhs == rhs; + } + /*! + * \brief Check Equality of leaf node of the graph. + * if map_free_var_ is set to true, try to map via equal node. + * \param lhs The left hand operand. + * \param rhs The right hand operand. + * \return the compare result. + */ + bool LeafNodeEqual(const NodeRef& lhs, const NodeRef& rhs) { + if (lhs.same_as(rhs)) return true; + auto it = equal_map_.find(lhs); + if (it != equal_map_.end()) { + return it->second.same_as(rhs); + } else { + if (map_free_var_) { + if (lhs->type_index() != rhs->type_index()) return false; + equal_map_[lhs] = rhs; + return true; + } else { + return false; + } + } + } + using AttrsEqualHandler::VisitAttr_; + bool VisitAttr_(const Variable* lhs, const NodeRef& other) final { + return LeafNodeEqual(GetRef<NodeRef>(lhs), other); + } + + // Type equality + bool VisitType_(const TensorTypeNode* lhs, const Type& other) final { + if (const TensorTypeNode* rhs = other.as<TensorTypeNode>()) { + return (lhs->dtype == rhs->dtype && + AttrEqual(lhs->shape, rhs->shape)); + } else { + return false; + } + } + + bool VisitType_(const IncompleteTypeNode* lhs, const Type& other) final { + return LeafNodeEqual(GetRef<NodeRef>(lhs), other); + } + + bool VisitType_(const TypeVarNode* lhs, const Type& other) final { + if (const TypeVarNode* rhs = other.as<TypeVarNode>()) { + if (lhs->kind != rhs->kind) return false; + return LeafNodeEqual(GetRef<NodeRef>(lhs), other); + } else { + return false; + } + } + + bool VisitType_(const FuncTypeNode* lhs, const Type& other) final { + if (const FuncTypeNode* rhs = other.as<FuncTypeNode>()) { + if (lhs->arg_types.size() != rhs->arg_types.size()) return false; + if (lhs->type_params.size() != rhs->type_params.size()) return false; + if (lhs->type_constraints.size() != rhs->type_constraints.size()) return false; + for (size_t i = 0; i < lhs->type_params.size(); ++i) { + if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) { + return false; + } + equal_map_[lhs->type_params[i]] = rhs->type_params[i]; + // set up type parameter equal + if (lhs->type_params[i]->kind == TypeVarNode::Kind::kShapeVar) { + // map variable + equal_map_[lhs->type_params[i]->var] = rhs->type_params[i]->var; + } + } + for (size_t i = 0; i < lhs->arg_types.size(); i++) { + if (!TypeEqual(lhs->arg_types[i], rhs->arg_types[i])) return false; + } + if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false; + for (size_t i = 0; i < lhs->type_constraints.size(); i++) { + if (!TypeEqual(lhs->type_constraints[i], + rhs->type_constraints[i])) { + return false; + } + } + return true; + } else { + return false; + } + } + + bool VisitType_(const TypeRelationNode* lhs, const Type& other) final { + if (const TypeRelationNode* rhs = other.as<TypeRelationNode>()) { + if (!lhs->func.same_as(rhs->func)) return false; + if (lhs->num_inputs != rhs->num_inputs) return false; + if (!this->AttrEqual(lhs->attrs, rhs->attrs)) return false; + if (lhs->args.size() != rhs->args.size()) return false; + for (size_t i = 0; i < lhs->args.size(); ++i) { + if (!TypeEqual(lhs->args[i], rhs->args[i])) return false; + } + return true; + } else { + return false; + } + } + + bool VisitType_(const TupleTypeNode* lhs, const Type& other) final { + if (const TupleTypeNode* rhs = other.as<TupleTypeNode>()) { + if (lhs->fields.size() != rhs->fields.size()) return false; + for (size_t i = 0; i < lhs->fields.size(); ++i) { + if (!TypeEqual(lhs->fields[i], rhs->fields[i])) return false; + } + return true; + } else { + return false; + } + } + // Expr equal checking. + bool NDArrayEqual(const runtime::NDArray& lhs, + const runtime::NDArray& rhs) { + if (lhs.defined() != rhs.defined()) { + return false; + } else if (lhs.same_as(rhs)) { + return true; + } else { + auto ldt = lhs->dtype; + auto rdt = rhs->dtype; + CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor"; + CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor"; + if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { + size_t data_size = runtime::GetDataSize(*lhs.operator->()); + return std::memcmp(lhs->data, rhs->data, data_size) == 0; + } else { + return false; + } + } + } + // merge declaration of two variables together. + bool MergeVarDecl(const Var& lhs, const Var& rhs) { + if (lhs.same_as(rhs)) return true; + if (!lhs.defined() || !rhs.defined()) return false; + if (!TypeEqual(lhs->type_annotation, + rhs->type_annotation)) return false; + CHECK(!equal_map_.count(lhs)) + << "Duplicated declaration of variable " << lhs; + equal_map_[lhs] = rhs; + return true; + } + + bool VisitExpr_(const VarNode* lhs, const Expr& other) final { + if (const VarNode* rhs = other.as<VarNode>()) { + if (lhs->name_hint != rhs->name_hint) return false; + if (!TypeEqual(lhs->type_annotation, rhs->type_annotation)) return false; + return LeafNodeEqual(GetRef<NodeRef>(lhs), other); + } else { + return false; + } + } + + bool VisitExpr_(const GlobalVarNode* lhs, const Expr& other) final { + if (const GlobalVarNode* rhs = other.as<GlobalVarNode>()) { + // use name equality for global var for now. + if (lhs->name_hint != rhs->name_hint) return false; + return true; + } else { + return false; + } + } + + bool VisitExpr_(const TupleNode* lhs, const Expr& other) final { + if (const TupleNode* rhs = other.as<TupleNode>()) { + if (lhs->fields.size() != rhs->fields.size()) return false; + for (size_t i = 0; i < lhs->fields.size(); ++i) { + if (!ExprEqual(lhs->fields[i], rhs->fields[i])) return false; + } + return true; + } else { + return false; + } + } + + bool VisitExpr_(const FunctionNode* lhs, const Expr& other) final { + if (const FunctionNode* rhs = other.as<FunctionNode>()) { + if (lhs->params.size() != rhs->params.size()) return false; + if (lhs->type_params.size() != rhs->type_params.size()) return false; + // map type parameter to be the same + for (size_t i = 0; i < lhs->type_params.size(); ++i) { + if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) return false; + equal_map_[lhs->type_params[i]] = rhs->type_params[i]; + } + // check parameter type annotations + for (size_t i = 0; i < lhs->params.size(); ++i) { + if (!MergeVarDecl(lhs->params[i], rhs->params[i])) return false; + } + // check return types. + if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false; + return ExprEqual(lhs->body, rhs->body); + } else { + return false; + } + } + + bool VisitExpr_(const CallNode* lhs, const Expr& other) final { + if (const CallNode* rhs = other.as<CallNode>()) { + if (!ExprEqual(lhs->op, rhs->op)) return false; + if (lhs->args.size() != rhs->args.size()) return false; + if (lhs->type_args.size() != rhs->type_args.size()) return false; + + for (size_t i = 0; i < lhs->args.size(); ++i) { + if (!ExprEqual(lhs->args[i], rhs->args[i])) return false; + } + for (size_t i = 0; i < lhs->type_args.size(); ++i) { + if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false; + } + return AttrEqual(lhs->attrs, rhs->attrs); + } else { + return false; + } + } + + bool VisitExpr_(const LetNode* lhs, const Expr& other) final { + if (const LetNode* rhs = other.as<LetNode>()) { + if (!ExprEqual(lhs->value, rhs->value)) return false; + if (!MergeVarDecl(lhs->var, rhs->var)) return false; + return ExprEqual(lhs->body, rhs->body); + } else { + return false; + } + } + + bool VisitExpr_(const IfNode* lhs, const Expr& other) final { + if (const IfNode* rhs = other.as<IfNode>()) { + return ExprEqual(lhs->cond, rhs->cond) && + ExprEqual(lhs->true_branch, rhs->true_branch) && + ExprEqual(lhs->false_branch, rhs->false_branch); + } else { + return false; + } + } + + bool VisitExpr_(const OpNode* op, const Expr& other) final { + return op == other.get(); + } + + bool VisitExpr_(const ConstantNode* lhs, const Expr& other) final { + if (const ConstantNode* rhs = other.as<ConstantNode>()) { + return NDArrayEqual(lhs->data, rhs->data); + } else { + return false; + } + } + + bool VisitExpr_(const TupleGetItemNode* lhs, const Expr& other) final { + if (const TupleGetItemNode* rhs = other.as<TupleGetItemNode>()) { + return ExprEqual(lhs->tuple, rhs->tuple) && lhs->index == rhs->index; + } else { + return false; + } + } + + private: + // whether to map open terms. + bool map_free_var_{false}; + // renaming of NodeRef to indicate two nodes equals to each other + std::unordered_map<NodeRef, NodeRef, NodeHash, NodeEqual> equal_map_; +}; + +bool AlphaEqual(const Type& lhs, const Type& rhs) { + return AlphaEqualHandler(false).TypeEqual(lhs, rhs); +} + +bool AlphaEqual(const Expr& lhs, const Expr& rhs) { + return AlphaEqualHandler(false).ExprEqual(lhs, rhs); +} + +// TODO(@jroesch): move to correct namespace? +TVM_REGISTER_API("relay._make._alpha_equal") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = AlphaEqualHandler(false).Equal(args[0], args[1]); + }); + +TVM_REGISTER_API("relay._make._type_alpha_equal") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = AlphaEqualHandler(false).TypeEqual(args[0], args[1]); + }); + +TVM_REGISTER_API("relay._make._graph_equal") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = AlphaEqualHandler(true).Equal(args[0], args[1]); + }); +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 66ef86641..0ebe111ab 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -6,7 +6,7 @@ #include <tvm/relay/environment.h> #include <tvm/relay/expr_functor.h> #include <sstream> -#include "../pass/type_functor.h" +#include "type_functor.h" #include "../../lang/attr_functor.h" namespace tvm { @@ -245,6 +245,9 @@ class TextPrinter : stream_ << ", "; } } + if (fields.size() == 1) { + stream_ << ','; + } stream_ << ')'; this->PrintEndInst("\n"); return id; @@ -648,7 +651,7 @@ class TextPrinter : name = "%" + name; } TextValue val(GetUniqueName(name)); - CHECK(!memo_.count(var)); + CHECK(!memo_.count(var)) << "Duplicated variable " << var; memo_[var] = val; return val; } diff --git a/src/relay/pass/type_visitor.h b/src/relay/ir/type_functor.h similarity index 52% rename from src/relay/pass/type_visitor.h rename to src/relay/ir/type_functor.h index c1b2c3e1a..03bb4db1f 100644 --- a/src/relay/pass/type_visitor.h +++ b/src/relay/ir/type_functor.h @@ -1,18 +1,97 @@ /*! * Copyright (c) 2018 by Contributors - * \file type_visitor.h - * \brief A wrapper around TypeFunctor for common use cases. + * \file type_functor.h + * \brief A way to defined arbitrary function signature with dispatch on types. */ -#ifndef TVM_RELAY_PASS_TYPE_VISITOR_H_ -#define TVM_RELAY_PASS_TYPE_VISITOR_H_ +#ifndef TVM_RELAY_IR_TYPE_FUNCTOR_H_ +#define TVM_RELAY_IR_TYPE_FUNCTOR_H_ +#include <tvm/node/ir_functor.h> +#include <tvm/relay/expr.h> +#include <string> #include <vector> -#include "./type_functor.h" namespace tvm { namespace relay { -/*! \brief A type visitor for vistiors which make use of internal +template <typename FType> +class TypeFunctor; + +// functions to be overriden. +#define TYPE_FUNCTOR_DEFAULT \ + { return VisitTypeDefault_(op, std::forward<Args>(args)...); } + + +#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch<OP>( \ + [](const NodeRef& n, TSelf* self, Args... args) { \ + return self->VisitType_(static_cast<const OP*>(n.node_.get()), \ + std::forward<Args>(args)...); \ + }); + +template <typename R, typename... Args> +class TypeFunctor<R(const Type& n, Args...)> { + private: + using TSelf = TypeFunctor<R(const Type& n, Args...)>; + using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! \brief virtual destructor */ + virtual ~TypeFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const Type& n, Args... args) { + return VisitType(n, std::forward<Args>(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitType(const Type& n, Args... args) { + static FType vtable = InitVTable(); + return vtable(n, this, std::forward<Args>(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitType_(const TensorTypeNode* op, + Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + + virtual R VisitTypeDefault_(const Node* op, Args...) { + LOG(FATAL) << "Do not have a default for " << op->type_key(); + throw; // unreachable, written to stop compiler warning + } + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + RELAY_TYPE_FUNCTOR_DISPATCH(TensorTypeNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeVarNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode); + RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode); + RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode); + return vtable; + } +}; + +/*! + * \brief A type visitor for vistiors which make use of internal * mutable state. * * We recursively visit each type contained inside the visitor. @@ -118,7 +197,6 @@ struct TypeMutator : TypeFunctor<Type(const Type& n)> { return GetRef<Type>(op); } }; - } // namespace relay } // namespace tvm -#endif // TVM_RELAY_PASS_TYPE_VISITOR_H_ +#endif // TVM_RELAY_IR_TYPE_FUNCTOR_H_ diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc deleted file mode 100644 index 41ec3f1e0..000000000 --- a/src/relay/pass/alpha_eq.cc +++ /dev/null @@ -1,418 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file src/tvm/relay/pass/alpha_eq.cc - * \brief Check that two type are syntactically equal up to alpha equivalence. - */ -#include <tvm/ir_pass.h> -#include <tvm/relay/expr_functor.h> -#include <tvm/runtime/ndarray.h> -#include "./type_visitor.h" -#include "tvm/relay/pass.h" - -namespace tvm { -namespace relay { - -using namespace tvm::runtime; - -bool SameNDArray(const NDArray& lhs, const NDArray& rhs) { - if (lhs.defined() != rhs.defined()) { - return false; - } else if (lhs.same_as(rhs)) { - return true; - } else { - auto ldt = lhs->dtype; - auto rdt = rhs->dtype; - CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor"; - CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor"; - if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { - size_t s = GetDataSize(*lhs.operator->()); - return memcmp(lhs->data, rhs->data, s) == 0; - } else { - return false; - } - } -} - -struct TypeAlphaEq : TypeVisitor<const Type&> { - tvm::Map<TypeVar, TypeVar> eq_map; - bool equal; - - TypeAlphaEq() : eq_map(), equal(true) {} - - void DataTypeEqual(const DataType& dt1, const DataType& dt2) { - if (dt1 != dt2) { - equal = false; - } - } - - void ShapeEqual(const Array<IndexExpr>& s1, const Array<IndexExpr>& s2) { - if (s1.size() != s2.size()) { - equal = false; - return; - } - for (size_t i = 0; i < s1.size(); ++i) { - if (!tvm::ir::Equal(s1[i], s2[i])) { - equal = false; - return; - } - } - } - - void VisitType_(const TensorTypeNode* tt1, const Type& t2) final { - if (const TensorTypeNode* tt2 = t2.as<TensorTypeNode>()) { - DataTypeEqual(tt1->dtype, tt2->dtype); - ShapeEqual(tt1->shape, tt2->shape); - } else { - equal = false; - } - } - - void VisitType_(const IncompleteTypeNode* bt1, const Type& t2) final { - if (const IncompleteTypeNode* bt2 = t2.as<IncompleteTypeNode>()) { - equal = equal && bt1 == bt2; - return; - } else { - equal = false; - } - } - - void VisitType_(const TypeVarNode* ti1, const Type& t2) final { - if (const TypeVarNode* ti2 = t2.as<TypeVarNode>()) { - auto tid1 = GetRef<TypeVar>(ti1); - auto tid2 = GetRef<TypeVar>(ti2); - - // We handle open terms with this rule assuming variables are identical. - // - // Not sure if we should do this. - if (tid1 == tid2) { - return; - } - - // Check that they are same kind - if (tid1->kind != tid2->kind) { - equal = false; - return; - } - - // Next we see if there is mapping for local1 into the rhs term. - // If there is we check to see if those are equal. - if (eq_map.find(tid1) != eq_map.end()) { - equal = equal && eq_map[tid1] == tid2; - } else { - equal = false; - } - } else { - equal = false; - } - } - - void VisitType_(const FuncTypeNode* op, const Type& t2) final { - if (const FuncTypeNode* ta2 = t2.as<FuncTypeNode>()) { - if (op->arg_types.size() != ta2->arg_types.size() - || op->type_params.size() != ta2->type_params.size() - || op->type_constraints.size() != ta2->type_constraints.size()) { - equal = false; - return; - } - - // must visit params first so they are appropriate entered - // into equality map - for (size_t i = 0; i < op->type_params.size(); i++) { - eq_map.Set(op->type_params[i], ta2->type_params[i]); - this->VisitType(op->type_params[i], ta2->type_params[i]); - if (!equal) { - return; - } - } - - for (size_t i = 0; i < op->arg_types.size(); i++) { - this->VisitType(op->arg_types[i], ta2->arg_types[i]); - if (!equal) { - return; - } - } - - this->VisitType(op->ret_type, ta2->ret_type); - if (!equal) { - return; - } - - for (size_t i = 0; i < op->type_constraints.size(); i++) { - this->VisitType(op->type_constraints[i], ta2->type_constraints[i]); - if (!equal) { - return; - } - } - } else { - equal = false; - } - } - - void VisitType_(const TypeRelationNode* tr1, const Type& t2) final { - if (const TypeRelationNode* tr2 = t2.as<TypeRelationNode>()) { - if (tr1->func != tr2->func - || tr1->num_inputs != tr2->num_inputs - || tr1->attrs != tr2->attrs) { - equal = false; - return; - } - - if (tr1->args.size() != tr2->args.size()) { - equal = false; - return; - } - - for (size_t i = 0; i < tr1->args.size(); i++) { - this->VisitType(tr1->args[i], tr2->args[i]); - if (!equal) { - return; - } - } - } else { - equal = false; - } - } - - void VisitType_(const TupleTypeNode* op, const Type& t2) final { - if (const TupleTypeNode* pt = t2.as<TupleTypeNode>()) { - if (op->fields.size() != pt->fields.size()) { - equal = false; - return; - } - - for (size_t i = 0U; i < op->fields.size(); i++) { - if (!equal) { - return; - } - this->VisitType(op->fields[i], pt->fields[i]); - } - } else { - equal = false; - } - } -}; - -bool AlphaEqual(const Type& t1, const Type& t2) { - if (t1.defined() != t2.defined()) { - return false; - } - - if (!t1.defined()) { - return true; - } - - TypeAlphaEq aeq; - aeq.VisitType(t1, t2); - return aeq.equal; -} - -struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { - public: - tvm::Map<Var, Var> eq_map; - - bool equal; - AlphaEq() : eq_map(), equal(true) {} - - void VisitExpr_(const VarNode* e1, const Expr& e2) final { - if (const VarNode* id2 = e2.as<VarNode>()) { - auto local1 = GetRef<Var>(e1); - auto local2 = GetRef<Var>(id2); - // We handle open terms with this rule assuming variables are identical. - if (local1 == local2) { - equal = true; - return; - } - - // Next we see if there is mapping for local1 into the rhs term. - // If there is we check to see if those are equal. - if (eq_map.find(local1) != eq_map.end()) { - equal = equal && eq_map[local1] == local2; - } else { - equal = false; - } - } else { - equal = false; - } - } - - void VisitExpr_(const GlobalVarNode* g1, const Expr& e2) final { - if (const GlobalVarNode* g2 = e2.as<GlobalVarNode>()) { - equal = equal && g1 == g2; - } else { - equal = false; - } - } - - void VisitExpr_(const TupleNode* pl1, const Expr& e2) final { - Tuple prod1 = GetRef<Tuple>(pl1); - if (const TupleNode* pl2 = e2.as<TupleNode>()) { - Tuple prod2 = GetRef<Tuple>(pl2); - if (prod1->fields.size() != prod2->fields.size()) { - equal = false; - return; - } - - for (size_t i = 0U; i < prod1->fields.size(); i++) { - this->VisitExpr(prod1->fields[i], prod2->fields[i]); - } - } else { - equal = false; - } - } - - void VisitExpr_(const FunctionNode* func1, const Expr& e2) final { - if (const FunctionNode* func2 = e2.as<FunctionNode>()) { - if (func1->params.size() != func2->params.size()) { - equal = false; - return; - } - - if (func1->type_params.size() != func2->type_params.size()) { - equal = false; - return; - } - - for (size_t i = 0; i < func1->params.size(); ++i) { - MergeVarDecl(func1->params[i], func2->params[i]); - } - - if (!equal) { - return; - } - - for (size_t i = 0U; i < func1->type_params.size(); i++) { - equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]); - if (!equal) { - return; - } - } - - equal = equal && AlphaEqual(func1->ret_type, func2->ret_type); - if (!equal) { - return; - } - - this->VisitExpr(func1->body, func2->body); - } else { - equal = false; - } - } - - void VisitExpr_(const CallNode* op, const Expr& e2) final { - if (const CallNode* call = e2.as<CallNode>()) { - this->VisitExpr(op->op, call->op); - - if (op->args.size() != call->args.size()) { - equal = false; - return; - } - - if (op->type_args.size() != call->type_args.size()) { - equal = false; - return; - } - - // checking attrs by pointer equality for now - equal = equal && (op->attrs == call->attrs); - if (!equal) { - return; - } - - for (size_t i = 0U; i < op->args.size(); i++) { - this->VisitExpr(op->args[i], call->args[i]); - } - - for (size_t i = 0U; i < op->type_args.size(); i++) { - equal = equal && AlphaEqual(op->type_args[i], call->type_args[i]); - if (!equal) { - return; - } - } - } else { - equal = false; - } - } - - void VisitExpr_(const LetNode* op, const Expr& e2) final { - if (const LetNode* let = e2.as<LetNode>()) { - MergeVarDecl(op->var, let->var); - this->VisitExpr(op->value, let->value); - this->VisitExpr(op->body, let->body); - } else { - equal = false; - } - } - - void VisitExpr_(const IfNode* op, const Expr& e2) final { - if (const IfNode* i = e2.as<IfNode>()) { - VisitExpr(op->cond, i->cond); - VisitExpr(op->true_branch, i->true_branch); - VisitExpr(op->false_branch, i->false_branch); - } else { - equal = false; - } - } - - void VisitExpr_(const OpNode* op, const Expr& e2) final { - if (const OpNode* o = e2.as<OpNode>()) { - equal = equal && op->name == o->name; - } else { - equal = false; - } - } - - void VisitExpr_(const ConstantNode* op, const Expr& e2) final { - if (const ConstantNode* c = e2.as<ConstantNode>()) { - if (AlphaEqual(op->tensor_type(), c->tensor_type())) { - equal = equal && SameNDArray(op->data, c->data); - } else { - equal = false; - } - } else { - equal = false; - } - } - - void VisitExpr_(const TupleGetItemNode* op, const Expr& e2) final { - if (const TupleGetItemNode* proj = e2.as<TupleGetItemNode>()) { - this->VisitExpr(op->tuple, proj->tuple); - equal = equal && (op->index == proj->index); - } else { - equal = false; - } - } - - private: - void MergeVarDecl(const Var& var1, const Var& var2) { - equal = equal && AlphaEqual(var1->type_annotation, var2->type_annotation); - if (!equal) { - return; - } - - eq_map.Set(var1, var2); - } -}; - -bool AlphaEqual(const Expr& e1, const Expr& e2) { - AlphaEq eq; - eq.VisitExpr(e1, e2); - return eq.equal; -} - -// TODO(@jroesch): move to correct namespace? -TVM_REGISTER_API("relay._make._alpha_equal") - .set_body([](TVMArgs args, TVMRetValue* ret) { - Expr e1 = args[0]; - Expr e2 = args[1]; - *ret = AlphaEqual(e1, e2); - }); - -TVM_REGISTER_API("relay._make._type_alpha_equal") - .set_body([](TVMArgs args, TVMRetValue* ret) { - Type t1 = args[0]; - Type t2 = args[1]; - *ret = AlphaEqual(t1, t2); - }); - -} // namespace relay -} // namespace tvm diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 8fd77a71e..c3d16c297 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -14,7 +14,7 @@ * contains a data type such as `int`, `float`, `uint`. */ #include <tvm/relay/pass.h> -#include "./type_visitor.h" +#include "../ir/type_functor.h" namespace tvm { namespace relay { @@ -105,13 +105,13 @@ bool KindCheck(const Type& t, const Environment& env) { } TVM_REGISTER_API("relay._ir_pass.check_kind") - .set_body([](TVMArgs args, TVMRetValue* ret) { - if (args.size() == 1) { - *ret = KindCheck(args[0], EnvironmentNode::make({})); - } else { - *ret = KindCheck(args[0], args[1]); - } - }); +.set_body([](TVMArgs args, TVMRetValue* ret) { + if (args.size() == 1) { + *ret = KindCheck(args[0], EnvironmentNode::make({})); + } else { + *ret = KindCheck(args[0], args[1]); + } + }); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/type_functor.h b/src/relay/pass/type_functor.h deleted file mode 100644 index b8eaa85a7..000000000 --- a/src/relay/pass/type_functor.h +++ /dev/null @@ -1,94 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file type_functor.h - * \brief A way to defined arbitrary function signature with dispatch on types. - */ -#ifndef TVM_RELAY_PASS_TYPE_FUNCTOR_H_ -#define TVM_RELAY_PASS_TYPE_FUNCTOR_H_ - -#include <tvm/node/ir_functor.h> -#include <tvm/relay/expr.h> -#include <string> - -namespace tvm { -namespace relay { - -template <typename FType> -class TypeFunctor; - -// functions to be overriden. -#define TYPE_FUNCTOR_DEFAULT \ - { return VisitTypeDefault_(op, std::forward<Args>(args)...); } - - -#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch<OP>( \ - [](const NodeRef& n, TSelf* self, Args... args) { \ - return self->VisitType_(static_cast<const OP*>(n.node_.get()), \ - std::forward<Args>(args)...); \ - }); - -template <typename R, typename... Args> -class TypeFunctor<R(const Type& n, Args...)> { - private: - using TSelf = TypeFunctor<R(const Type& n, Args...)>; - using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>; - - public: - /*! \brief the result type of this functor */ - using result_type = R; - /*! \brief virtual destructor */ - virtual ~TypeFunctor() {} - /*! - * \brief Same as call. - * \param n The expression node. - * \param args Additional arguments. - * \return The result of the call - */ - R operator()(const Type& n, Args... args) { - return VisitType(n, std::forward<Args>(args)...); - } - /*! - * \brief The functor call. - * \param n The expression node. - * \param args Additional arguments. - * \return The result of the call - */ - virtual R VisitType(const Type& n, Args... args) { - static FType vtable = InitVTable(); - return vtable(n, this, std::forward<Args>(args)...); - } - // Functions that can be overriden by subclass - virtual R VisitType_(const TensorTypeNode* op, - Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitType_(const TypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - - virtual R VisitTypeDefault_(const Node* op, Args...) { - LOG(FATAL) << "Do not have a default for " << op->type_key(); - throw; // unreachable, written to stop compiler warning - } - - private: - // initialize the vtable. - static FType InitVTable() { - FType vtable; - // Set dispatch - RELAY_TYPE_FUNCTOR_DISPATCH(TensorTypeNode); - RELAY_TYPE_FUNCTOR_DISPATCH(TypeVarNode); - RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode); - RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode); - RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode); - RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode); - RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode); - return vtable; - } -}; - -} // namespace relay -} // namespace tvm -#endif // TVM_RELAY_PASS_TYPE_FUNCTOR_H_ diff --git a/src/relay/pass/type_subst.cc b/src/relay/pass/type_subst.cc index bffd779d1..76507058f 100644 --- a/src/relay/pass/type_subst.cc +++ b/src/relay/pass/type_subst.cc @@ -4,7 +4,7 @@ * \brief Function for substituting a concrete type in place of a type ID */ #include "./type_subst.h" -#include "./type_visitor.h" +#include "../ir/type_functor.h" namespace tvm { namespace relay { diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index ff4bb55b7..d69f1bce7 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -7,7 +7,7 @@ */ #include <tvm/relay/pass.h> #include <tvm/relay/expr_functor.h> -#include "./type_visitor.h" +#include "../ir/type_functor.h" namespace tvm { namespace relay { diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index de4df7c84..d16c2df53 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -139,7 +139,8 @@ def test_type_relation_alpha_equal(): # attrs are also compared only by pointer equality attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) - attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) + attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) + attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4)) tr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1) same = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1) @@ -147,6 +148,7 @@ def test_type_relation_alpha_equal(): diff_order = relay.TypeRelation(broadcast, tvm.convert([t2, t1]), 1, attr1) diff_args = relay.TypeRelation(broadcast, tvm.convert([t2, t3]), 1, attr1) diff_attr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr2) + same_attr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1_same) bigger = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 2, attr1) diff_num_inputs = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 1, attr2) @@ -157,6 +159,7 @@ def test_type_relation_alpha_equal(): assert tr != diff_order assert tr != diff_args assert tr != diff_attr + assert tr == same_attr assert tr != bigger assert bigger != diff_num_inputs @@ -216,22 +219,26 @@ def test_global_var_alpha_equal(): def test_tuple_alpha_equal(): + v0 = relay.Var("v0") v1 = relay.Var("v1") v2 = relay.Var("v2") # unit value is a valid tuple assert alpha_equal(relay.Tuple([]), relay.Tuple([])) - tup = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])]) - same = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])]) + tup = relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])]) + same = relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])]) assert alpha_equal(tup, same) # use the eq_map + + let_tup = relay.Let(v1, tup, v1) - let_mapped = relay.Let(v2, relay.Tuple([v2, relay.const(2), relay.const(3), + let_mapped = relay.Let(v2, relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])]), v2) + assert alpha_equal(let_tup, let_mapped) more_fields = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)]), v2]) @@ -340,7 +347,8 @@ def test_call_alpha_equal(): # attrs are compared only by pointer equality attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) - attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) + attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) + attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4)) tt1 = relay.TensorType((1, 2, 3), "float32") tt2 = relay.TensorType((), "int8") @@ -375,6 +383,9 @@ def test_call_alpha_equal(): different_attrs = relay.Call(v1, basic_args, attr2, [tt1]) assert not alpha_equal(call, different_attrs) + same_attrs = relay.Call(v1, basic_args, attr1_same, [tt1]) + assert alpha_equal(call, same_attrs) + no_type_args = relay.Call(v1, basic_args, attr1) assert not alpha_equal(call, no_type_args) @@ -445,6 +456,27 @@ def test_op_alpha_equal(): assert not alpha_equal(op1, op3) +def test_graph_equal(): + x = relay.var("x") + + y0 = relay.add(x, x) + z0 = relay.add(y0, y0) + + y1 = relay.add(x, x) + z1 = relay.add(y1, y1) + + z3 = relay.add(relay.add(x, x), relay.add(x, x)) + + assert alpha_equal(z0, z1) + + # z3's dataflow format is different from z0 + # z0 is computed from a common y0 node + # Relay view them as different programs + # Check the difference in the text format. + assert not alpha_equal(z0, z3) + + + if __name__ == "__main__": test_tensor_type_alpha_equal() test_incomplete_type_alpha_equal() @@ -462,3 +494,4 @@ if __name__ == "__main__": test_if_alpha_equal() test_op_alpha_equal() test_var_alpha_equal() + test_graph_equal() diff --git a/tests/python/unittest/test_pass_attrs_hash_equal.py b/tests/python/unittest/test_pass_attrs_hash_equal.py index 23f0e6374..2d6987aeb 100644 --- a/tests/python/unittest/test_pass_attrs_hash_equal.py +++ b/tests/python/unittest/test_pass_attrs_hash_equal.py @@ -17,6 +17,12 @@ def test_attrs_equal(): assert tvm.ir_pass.AttrsEqual({"x": [x, x]}, {"x": [y, x]}) assert not tvm.ir_pass.AttrsEqual({"x": [x, 1]}, {"x": [y, 2]}) + n = tvm.var("n") + assert tvm.ir_pass.AttrsEqual({"x": n+1}, {"x": n+1}) + + + + def test_attrs_hash(): fhash = tvm.ir_pass.AttrsHash -- GitLab