diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc index 29d2f87cf04aeb740a3f819837ca874737e7a927..059504efc883341341085df93b385993a122f256 100644 --- a/src/relay/pass/alpha_eq.cc +++ b/src/relay/pass/alpha_eq.cc @@ -193,6 +193,12 @@ struct TypeAlphaEq : TypeVisitor<const Type&> { }; bool AlphaEqual(const Type& t1, const Type& t2) { + if (t1.defined() != t2.defined()) + return false; + + if (!t1.defined()) + return true; + TypeAlphaEq aeq; aeq.VisitType(t1, t2); return aeq.equal; @@ -373,15 +379,11 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { private: void MergeVarDecl(const Var& var1, const Var& var2) { - if (var1->type_annotation.defined() != var2->type_annotation.defined()) { - equal = false; - return; - } - if (var1->type_annotation.defined() && - !AlphaEqual(var1->type_annotation, var2->type_annotation)) { - equal = false; + equal = equal && AlphaEqual(var1->type_annotation, var2->type_annotation); + if (!equal) { return; } + eq_map.Set(var1, var2); } }; diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index 51c1d4a2715a6ccc685622569a38d4a7ec7c230c..2bfbc7f10a402aee17b2fa0e8e05f9a9ef2dd22f 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -187,6 +187,25 @@ def test_var_alpha_equal(): assert alpha_equal(l1, l2) assert not alpha_equal(l1, l3) + # type annotations + tt1 = relay.TensorType([], "int32") + tt2 = relay.TensorType([], "int32") + tt3 = relay.TensorType([], "int64") + v3 = relay.Var("v3", tt1) + v4 = relay.Var("v4", tt2) + v5 = relay.Var("v5", tt3) + + l4 = relay.Let(v3, convert(1), v3) + l5 = relay.Let(v4, convert(1), v4) + l6 = relay.Let(v5, convert(1), v5) + + # same annotations + assert alpha_equal(l4, l5) + # different annotations + assert not alpha_equal(l4, l6) + # one null annotation + assert not alpha_equal(l1, l4) + def test_global_var_alpha_equal(): v1 = relay.GlobalVar("v1") @@ -307,6 +326,14 @@ def test_function_alpha_equal(): tupled_example = relay.Function(basic_args, relay.Tuple([v3, v4]), tt3) assert not alpha_equal(func, tupled_example) + # nullable + no_ret_type = relay.Function(basic_args, v4, None, [tp1, tp2]) + # both null + assert alpha_equal(no_ret_type, no_ret_type) + # one null + assert not alpha_equal(func, no_ret_type) + assert not alpha_equal(no_ret_type, func) + def test_call_alpha_equal(): v1 = relay.Var("v1")