diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc
index 29d2f87cf04aeb740a3f819837ca874737e7a927..059504efc883341341085df93b385993a122f256 100644
--- a/src/relay/pass/alpha_eq.cc
+++ b/src/relay/pass/alpha_eq.cc
@@ -193,6 +193,12 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
 };
 
 bool AlphaEqual(const Type& t1, const Type& t2) {
+  if (t1.defined() != t2.defined())
+    return false;
+
+  if (!t1.defined())
+    return true;
+
   TypeAlphaEq aeq;
   aeq.VisitType(t1, t2);
   return aeq.equal;
@@ -373,15 +379,11 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
 
  private:
   void MergeVarDecl(const Var& var1, const Var& var2) {
-    if (var1->type_annotation.defined() != var2->type_annotation.defined()) {
-      equal = false;
-      return;
-    }
-    if (var1->type_annotation.defined() &&
-        !AlphaEqual(var1->type_annotation, var2->type_annotation)) {
-      equal = false;
+    equal = equal && AlphaEqual(var1->type_annotation, var2->type_annotation);
+    if (!equal) {
       return;
     }
+
     eq_map.Set(var1, var2);
   }
 };
diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py
index 51c1d4a2715a6ccc685622569a38d4a7ec7c230c..2bfbc7f10a402aee17b2fa0e8e05f9a9ef2dd22f 100644
--- a/tests/python/relay/test_pass_alpha_equal.py
+++ b/tests/python/relay/test_pass_alpha_equal.py
@@ -187,6 +187,25 @@ def test_var_alpha_equal():
     assert alpha_equal(l1, l2)
     assert not alpha_equal(l1, l3)
 
+    # type annotations
+    tt1 = relay.TensorType([], "int32")
+    tt2 = relay.TensorType([], "int32")
+    tt3 = relay.TensorType([], "int64")
+    v3 = relay.Var("v3", tt1)
+    v4 = relay.Var("v4", tt2)
+    v5 = relay.Var("v5", tt3)
+
+    l4 = relay.Let(v3, convert(1), v3)
+    l5 = relay.Let(v4, convert(1), v4)
+    l6 = relay.Let(v5, convert(1), v5)
+
+    # same annotations
+    assert alpha_equal(l4, l5)
+    # different annotations
+    assert not alpha_equal(l4, l6)
+    # one null annotation
+    assert not alpha_equal(l1, l4)
+
 
 def test_global_var_alpha_equal():
     v1 = relay.GlobalVar("v1")
@@ -307,6 +326,14 @@ def test_function_alpha_equal():
     tupled_example = relay.Function(basic_args, relay.Tuple([v3, v4]), tt3)
     assert not alpha_equal(func, tupled_example)
 
+    # nullable
+    no_ret_type = relay.Function(basic_args, v4, None, [tp1, tp2])
+    # both null
+    assert alpha_equal(no_ret_type, no_ret_type)
+    # one null
+    assert not alpha_equal(func, no_ret_type)
+    assert not alpha_equal(no_ret_type, func)
+
 
 def test_call_alpha_equal():
     v1 = relay.Var("v1")