From 65016b6528d2c6aed7816f8f813c7f2eccc31256 Mon Sep 17 00:00:00 2001
From: "Steven S. Lyubomirsky" <slyubomirsky@gmail.com>
Date: Thu, 11 Oct 2018 09:24:26 -0700
Subject: [PATCH] [Relay] Alpha equality tests for Relay exprs (#1871)

---
 python/tvm/relay/expr.py                    |   2 +-
 src/relay/pass/alpha_eq.cc                  |  44 +++
 tests/python/relay/test_pass_alpha_equal.py | 281 +++++++++++++++++++-
 3 files changed, 320 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index 05214ca09..6ed8df0d7 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -112,7 +112,7 @@ class Call(Expr):
 class Let(Expr):
     """A variable bindings in Relay, see tvm/relay/expr.h for more details."""
 
-    def __init__(self, var, value, body, value_type):
+    def __init__(self, var, value, body, value_type=None):
         self.__init_handle_by_constructor__(
             _make.Let, var, value, body, value_type)
 
diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc
index 0e13a598c..0ed0e3df3 100644
--- a/src/relay/pass/alpha_eq.cc
+++ b/src/relay/pass/alpha_eq.cc
@@ -268,10 +268,27 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
         return;
       }
 
+      if (func1->type_params.size() != func2->type_params.size()) {
+        equal = false;
+        return;
+      }
+
       for (size_t i = 0U; i < func1->params.size(); i++) {
         this->VisitExpr(func1->params[i], func2->params[i]);
       }
 
+      for (size_t i = 0U; i < func1->type_params.size(); i++) {
+        equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]);
+        if (!equal) {
+          return;
+        }
+      }
+
+      equal = equal && AlphaEqual(func1->ret_type, func2->ret_type);
+      if (!equal) {
+        return;
+      }
+
       this->VisitExpr(func1->body, func2->body);
     } else {
       equal = false;
@@ -287,10 +304,27 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
         return;
       }
 
+      if (op->type_args.size() != call->type_args.size()) {
+        equal = false;
+        return;
+      }
+
+      // checking attrs by pointer equality for now
+      equal = equal && (op->attrs == call->attrs);
+      if (!equal) {
+        return;
+      }
+
       for (size_t i = 0U; i < op->args.size(); i++) {
         this->VisitExpr(op->args[i], call->args[i]);
       }
 
+      for (size_t i = 0U; i < op->type_args.size(); i++) {
+        equal = equal && AlphaEqual(op->type_args[i], call->type_args[i]);
+        if (!equal) {
+          return;
+        }
+      }
     } else {
       equal = false;
     }
@@ -301,6 +335,16 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
       eq_map.Set(op->var, let->var);
       this->VisitExpr(op->value, let->value);
       this->VisitExpr(op->body, let->body);
+
+      // value_type should match as well (including nulls)
+      if (op->value_type.defined() != let->value_type.defined()) {
+        equal = false;
+        return;
+      }
+
+      if (op->value_type.defined()) {
+        equal = equal && AlphaEqual(op->value_type, let->value_type);
+      }
     } else {
       equal = false;
     }
diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py
index 9fa1a554a..dd722399d 100644
--- a/tests/python/relay/test_pass_alpha_equal.py
+++ b/tests/python/relay/test_pass_alpha_equal.py
@@ -14,12 +14,6 @@ def test_tensor_type_alpha_equal():
     t2 = relay.TensorType((), "float32")
     assert t1 == t2
 
-def test_constant_alpha_equal():
-    x = convert(1)
-    y = convert(2)
-    assert alpha_equal(x, x)
-    assert not alpha_equal(x, y)
-    assert alpha_equal(x, convert(1))
 
 def test_incomplete_type_alpha_equal():
     t1 = relay.IncompleteType(relay.Kind.Shape)
@@ -167,6 +161,79 @@ def test_type_relation_alpha_equal():
 
     assert bigger != diff_num_inputs
 
+
+def test_constant_alpha_equal():
+    x = convert(1)
+    y = convert(2)
+    assert alpha_equal(x, x)
+    assert not alpha_equal(x, y)
+    assert alpha_equal(x, convert(1))
+
+
+def test_var_alpha_equal():
+    v1 = relay.Var("v1")
+    v2 = relay.Var("v2")
+
+    # normally only pointer equality
+    assert alpha_equal(v1, v1)
+    assert not alpha_equal(v1, v2)
+
+    # a let node populates the eq_map, so bound vars can be matched
+    l1 = relay.Let(v1, convert(1), v1, None)
+    l2 = relay.Let(v2, convert(1), v2, None)
+    l3 = relay.Let(v1, convert(1), v2, None)
+
+    assert alpha_equal(l1, l2)
+    assert not alpha_equal(l1, l3)
+
+
+def test_global_var_alpha_equal():
+    v1 = relay.GlobalVar("v1")
+    v2 = relay.GlobalVar("v2")
+
+    # only pointer equality suffices (smoke test)
+    assert alpha_equal(v1, v1)
+    assert not alpha_equal(v1, v2)
+
+
+def test_tuple_alpha_equal():
+    v1 = relay.Var("v1")
+    v2 = relay.Var("v2")
+
+    # unit value is a valid tuple
+    assert alpha_equal(relay.Tuple([]), relay.Tuple([]))
+
+    tup = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)])])
+    same = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)])])
+
+    assert alpha_equal(tup, same)
+
+    # use the eq_map
+    let_tup = relay.Let(v1, tup, v1, None)
+    let_mapped = relay.Let(v2, relay.Tuple([v2, convert(2), convert(3),
+                                            relay.Tuple([convert(4)])]),
+                           v2, None)
+    assert alpha_equal(let_tup, let_mapped)
+
+    more_fields = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)]), v2])
+    assert not alpha_equal(tup, more_fields)
+
+    fewer_fields = relay.Tuple([v1, convert(2), convert(3)])
+    assert not alpha_equal(tup, fewer_fields)
+
+    different_end = relay.Tuple([v1, convert(2), convert(3),
+                                 relay.Tuple([convert(5)])])
+    assert not alpha_equal(tup, different_end)
+
+    different_start = relay.Tuple([v2, convert(2), convert(3),
+                                   relay.Tuple([convert(4)])])
+    assert not alpha_equal(tup, different_start)
+
+    longer_at_end = relay.Tuple([v1, convert(2), convert(3),
+                                 relay.Tuple([convert(4), convert(5)])])
+    assert not alpha_equal(tup, longer_at_end)
+
+
 def test_tuple_get_item_alpha_equal():
     x = relay.Var('x')
     y = relay.Var('y')
@@ -174,6 +241,198 @@ def test_tuple_get_item_alpha_equal():
     assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 2))
     assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
 
+
+def test_param_alpha_equal():
+    # only checks equality of the types
+    v1 = relay.Var("v1")
+    v2 = relay.Var("v2")
+
+    p1 = relay.Param(v1, relay.TensorType((1, 2, 3), "float32"))
+    p2 = relay.Param(v2, relay.TensorType((1, 2, 3), "float32"))
+    assert alpha_equal(p1, p2)
+
+    p3 = relay.Param(v1, relay.TensorType((4, 5, 6), "int8"))
+    assert not alpha_equal(p1, p3)
+
+    p4 = relay.Param(v1, relay.TupleType([relay.TensorType((1, 2, 3),
+                                                           "float32")]))
+    assert not alpha_equal(p1, p4)
+
+
+def test_function_alpha_equal():
+    v1 = relay.Var("v1")
+    v2 = relay.Var("v2")
+    v3 = relay.Var("v3")
+    v4 = relay.Var("v4")
+
+    tt1 = relay.TensorType((1, 2, 3), "float32")
+    tt2 = relay.TensorType((4, 5, 6), "int8")
+    tt3 = relay.TupleType([tt1, tt2])
+
+    tp1 = relay.TypeParam("tp1", relay.Kind.Type)
+    tp2 = relay.TypeParam("tp2", relay.Kind.Type)
+    tp3 = relay.TypeParam("tp3", relay.Kind.Shape)
+    tp4 = relay.TypeParam("tp4", relay.Kind.Shape)
+
+    basic_args = [relay.Param(v3, tt1), relay.Param(v4, tt2)]
+    basic_tps = [tp1, tp2]
+
+    func = relay.Function([relay.Param(v1, tt1), relay.Param(v2, tt2)],
+                          tt2, v2, basic_tps)
+    mapped = relay.Function(basic_args, tt2, v4, basic_tps)
+    assert alpha_equal(func, mapped)
+
+    fewer_params = relay.Function([relay.Param(v4, tt2)], tt2, v4, basic_tps)
+    assert not alpha_equal(func, fewer_params)
+
+    more_params = relay.Function([relay.Param(v3, tt1), relay.Param(v4, tt2),
+                                  relay.Param(v2, tt2)], tt2, v4, basic_tps)
+    assert not alpha_equal(func, more_params)
+
+    params_unordered = relay.Function([relay.Param(v3, tt2),
+                                       relay.Param(v4, tt1)],
+                                      tt1, v3, basic_tps)
+    assert not alpha_equal(func, params_unordered)
+
+    params_mismatch = relay.Function([relay.Param(v3, tt3),
+                                      relay.Param(v4, tt2)],
+                                     tt2, v4, basic_tps)
+    assert not alpha_equal(func, params_mismatch)
+
+    # also would not typecheck
+    ret_type_mismatch = relay.Function(basic_args, tt1, v4, basic_tps)
+    assert not alpha_equal(func, ret_type_mismatch)
+
+    # also mis-typed
+    different_body = relay.Function(basic_args, tt2, v3, basic_tps)
+    assert not alpha_equal(func, different_body)
+
+    fewer_type_params = relay.Function(basic_args, tt2, v4, [tp1])
+    assert not alpha_equal(func, fewer_type_params)
+
+    more_type_params = relay.Function(basic_args, tt2, v4, [tp1, tp2, tp3])
+    assert not alpha_equal(func, more_type_params)
+
+    type_params_unordered = relay.Function(basic_args, tt2, v4, [tp2, tp1])
+    assert not alpha_equal(func, type_params_unordered)
+
+    different_type_params = relay.Function(basic_args, tt2, v4, [tp3, tp4])
+    assert not alpha_equal(func, different_type_params)
+
+    # a well-typed example that also differs in body, ret type, and type params
+    tupled_example = relay.Function(basic_args, tt3, relay.Tuple([v3, v4]))
+    assert not alpha_equal(func, tupled_example)
+
+
+def test_call_alpha_equal():
+    v1 = relay.Var("v1")
+    v2 = relay.Var("v2")
+
+    # attrs are compared only by pointer equality
+    attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3, 4))
+    attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3, 4))
+
+    tt1 = relay.TensorType((1, 2, 3), "float32")
+    tt2 = relay.TensorType((), "int8")
+
+    basic_args = [convert(1), convert(2), v2, relay.Tuple([])]
+
+    # write out the args by hand so the comparison cannot rely on
+    # pointer equality of the argument nodes
+    call = relay.Call(v1, [convert(1), convert(2), v2, relay.Tuple([])],
+                      attr1, [tt1])
+    same = relay.Call(v1, basic_args, attr1, [tt1])
+    assert alpha_equal(call, same)
+
+    different_fn = relay.Call(v2, basic_args, attr1, [tt1])
+    assert not alpha_equal(call, different_fn)
+
+    fewer_args = relay.Call(v1, [convert(1), convert(2), v2], attr1, [tt1])
+    assert not alpha_equal(call, fewer_args)
+
+    reordered_args = relay.Call(v1, [convert(2), convert(1),
+                                     relay.Tuple([]), v2], attr1, [tt1])
+    assert not alpha_equal(call, reordered_args)
+
+    different_args = relay.Call(v1, [convert(1), convert(2), convert(3)],
+                                attr1, [tt1])
+    assert not alpha_equal(call, different_args)
+
+    more_args = relay.Call(v1, [convert(1), convert(2), v2, relay.Tuple([]),
+                                convert(3), convert(4)], attr1, [tt1])
+    assert not alpha_equal(call, more_args)
+
+    different_attrs = relay.Call(v1, basic_args, attr2, [tt1])
+    assert not alpha_equal(call, different_attrs)
+
+    no_type_args = relay.Call(v1, basic_args, attr1)
+    assert not alpha_equal(call, no_type_args)
+
+    more_type_args = relay.Call(v1, basic_args, attr1, [tt1, tt2])
+    assert not alpha_equal(call, more_type_args)
+
+    different_type_arg = relay.Call(v1, basic_args, attr1, [tt2])
+    assert not alpha_equal(call, different_type_arg)
+
+
+def test_let_alpha_equal():
+    v1 = relay.Var("v1")
+    v2 = relay.Var("v2")
+    v3 = relay.Var("v3")
+
+    let = relay.Let(v1, convert(2), v1)
+    mapped = relay.Let(v2, convert(2), v2)
+    assert alpha_equal(let, mapped)
+
+    mismatched_var = relay.Let(v2, convert(2), v3)
+    assert not alpha_equal(let, mismatched_var)
+
+    different_value = relay.Let(v2, convert(3), v2)
+    assert not alpha_equal(let, different_value)
+
+    different_body = relay.Let(v2, convert(2), convert(12))
+    assert not alpha_equal(let, different_body)
+
+    # specified types must match
+    tt1 = relay.TensorType((), "float32")
+    tt2 = relay.TensorType((), "int8")
+    let_with_type = relay.Let(v1, convert(2), v1, tt1)
+    same_type = relay.Let(v1, convert(2), v1, tt1)
+    assert alpha_equal(let_with_type, same_type)
+    assert not alpha_equal(let, let_with_type)
+
+    different_type = relay.Let(v1, convert(2), v1, tt2)
+    assert not alpha_equal(let_with_type, different_type)
+
+
+def test_if_alpha_equal():
+    v1 = relay.Var("v1")
+    v2 = relay.Var("v2")
+
+    if_sample = relay.If(v1, convert(1), relay.Tuple([convert(2), convert(3)]))
+    same = relay.If(v1, convert(1), relay.Tuple([convert(2), convert(3)]))
+    assert alpha_equal(if_sample, same)
+
+    different_cond = relay.If(v2, convert(1), relay.Tuple([convert(2), convert(3)]))
+    assert not alpha_equal(if_sample, different_cond)
+
+    different_true = relay.If(v1, convert(2), relay.Tuple([convert(2), convert(3)]))
+    assert not alpha_equal(if_sample, different_true)
+
+    different_false = relay.If(v1, convert(1), relay.Tuple([]))
+    assert not alpha_equal(if_sample, different_false)
+
+
+def test_op_alpha_equal():
+    # only checks names
+    op1 = relay.op.get("add")
+    op2 = relay.op.get("add")
+    assert alpha_equal(op1, op2)
+
+    op3 = relay.op.get("take")
+    assert not alpha_equal(op1, op3)
+
+
 if __name__ == "__main__":
     test_tensor_type_alpha_equal()
     test_incomplete_type_alpha_equal()
@@ -182,4 +441,14 @@ if __name__ == "__main__":
     test_func_type_alpha_equal()
     test_tuple_type_alpha_equal()
     test_type_relation_alpha_equal()
+    test_constant_alpha_equal()
+    test_var_alpha_equal()
+    test_global_var_alpha_equal()
+    test_tuple_alpha_equal()
     test_tuple_get_item_alpha_equal()
+    test_param_alpha_equal()
+    test_function_alpha_equal()
+    test_call_alpha_equal()
+    test_let_alpha_equal()
+    test_if_alpha_equal()
+    test_op_alpha_equal()
-- 
GitLab