From 65016b6528d2c6aed7816f8f813c7f2eccc31256 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" <slyubomirsky@gmail.com> Date: Thu, 11 Oct 2018 09:24:26 -0700 Subject: [PATCH] [Relay] Alpha equality tests for Relay exprs (#1871) --- python/tvm/relay/expr.py | 2 +- src/relay/pass/alpha_eq.cc | 44 +++ tests/python/relay/test_pass_alpha_equal.py | 281 +++++++++++++++++++- 3 files changed, 320 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 05214ca09..6ed8df0d7 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -112,7 +112,7 @@ class Call(Expr): class Let(Expr): """A variable bindings in Relay, see tvm/relay/expr.h for more details.""" - def __init__(self, var, value, body, value_type): + def __init__(self, var, value, body, value_type=None): self.__init_handle_by_constructor__( _make.Let, var, value, body, value_type) diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc index 0e13a598c..0ed0e3df3 100644 --- a/src/relay/pass/alpha_eq.cc +++ b/src/relay/pass/alpha_eq.cc @@ -268,10 +268,27 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { return; } + if (func1->type_params.size() != func2->type_params.size()) { + equal = false; + return; + } + for (size_t i = 0U; i < func1->params.size(); i++) { this->VisitExpr(func1->params[i], func2->params[i]); } + for (size_t i = 0U; i < func1->type_params.size(); i++) { + equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]); + if (!equal) { + return; + } + } + + equal = equal && AlphaEqual(func1->ret_type, func2->ret_type); + if (!equal) { + return; + } + this->VisitExpr(func1->body, func2->body); } else { equal = false; @@ -287,10 +304,27 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { return; } + if (op->type_args.size() != call->type_args.size()) { + equal = false; + return; + } + + // checking attrs by pointer equality for now + equal = equal && (op->attrs == call->attrs); + if (!equal) { + return; + } + for (size_t i = 0U; i < op->args.size(); i++) { this->VisitExpr(op->args[i], call->args[i]); } + for (size_t i = 0U; i < op->type_args.size(); i++) { + equal = equal && AlphaEqual(op->type_args[i], call->type_args[i]); + if (!equal) { + return; + } + } } else { equal = false; } @@ -301,6 +335,16 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { eq_map.Set(op->var, let->var); this->VisitExpr(op->value, let->value); this->VisitExpr(op->body, let->body); + + // value_type should match as well (including nulls) + if (op->value_type.defined() != let->value_type.defined()) { + equal = false; + return; + } + + if (op->value_type.defined()) { + equal = equal && AlphaEqual(op->value_type, let->value_type); + } } else { equal = false; } diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index 9fa1a554a..dd722399d 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -14,12 +14,6 @@ def test_tensor_type_alpha_equal(): t2 = relay.TensorType((), "float32") assert t1 == t2 -def test_constant_alpha_equal(): - x = convert(1) - y = convert(2) - assert alpha_equal(x, x) - assert not alpha_equal(x, y) - assert alpha_equal(x, convert(1)) def test_incomplete_type_alpha_equal(): t1 = relay.IncompleteType(relay.Kind.Shape) @@ -167,6 +161,79 @@ def test_type_relation_alpha_equal(): assert bigger != diff_num_inputs + +def test_constant_alpha_equal(): + x = convert(1) + y = convert(2) + assert alpha_equal(x, x) + assert not alpha_equal(x, y) + assert alpha_equal(x, convert(1)) + + +def test_var_alpha_equal(): + v1 = relay.Var("v1") + v2 = relay.Var("v2") + + # normally only pointer equality + assert alpha_equal(v1, v1) + assert not alpha_equal(v1, v2) + + # let node allows for setting the eq_map + l1 = relay.Let(v1, convert(1), v1, None) + l2 = relay.Let(v2, convert(1), v2, None) + l3 = relay.Let(v1, convert(1), v2, None) + + assert alpha_equal(l1, l2) + assert not alpha_equal(l1, l3) + + +def test_global_var_alpha_equal(): + v1 = relay.GlobalVar("v1") + v2 = relay.GlobalVar("v2") + + # only pointer equality suffices (smoke test) + assert alpha_equal(v1, v1) + assert not alpha_equal(v1, v2) + + +def test_tuple_alpha_equal(): + v1 = relay.Var("v1") + v2 = relay.Var("v2") + + # unit value is a valid tuple + assert alpha_equal(relay.Tuple([]), relay.Tuple([])) + + tup = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)])]) + same = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)])]) + + assert alpha_equal(tup, same) + + # use the eq_map + let_tup = relay.Let(v1, tup, v1, None) + let_mapped = relay.Let(v2, relay.Tuple([v2, convert(2), convert(3), + relay.Tuple([convert(4)])]), + v2, None) + assert alpha_equal(let_tup, let_mapped) + + more_fields = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)]), v2]) + assert not alpha_equal(tup, more_fields) + + fewer_fields = relay.Tuple([v1, convert(2), convert(3)]) + assert not alpha_equal(tup, fewer_fields) + + different_end = relay.Tuple([v1, convert(2), convert(3), + relay.Tuple([convert(5)])]) + assert not alpha_equal(tup, different_end) + + different_start = relay.Tuple([v2, convert(2), convert(3), + relay.Tuple([convert(4)])]) + assert not alpha_equal(tup, different_start) + + longer_at_end = relay.Tuple([v1, convert(2), convert(3), + relay.Tuple([convert(4), convert(5)])]) + assert not alpha_equal(tup, longer_at_end) + + def test_tuple_get_item_alpha_equal(): x = relay.Var('x') y = relay.Var('y') @@ -174,6 +241,198 @@ def test_tuple_get_item_alpha_equal(): assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 2)) assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1)) + +def test_param_alpha_equal(): + # only checks equality of the types + v1 = relay.Var("v1") + v2 = relay.Var("v2") + + p1 = relay.Param(v1, relay.TensorType((1, 2, 3), "float32")) + p2 = relay.Param(v2, relay.TensorType((1, 2, 3), "float32")) + assert alpha_equal(p1, p2) + + p3 = relay.Param(v1, relay.TensorType((4, 5, 6), "int8")) + assert not alpha_equal(p1, p3) + + p4 = relay.Param(v1, relay.TupleType([relay.TensorType((1, 2, 3), + "float32")])) + assert not alpha_equal(p1, p4) + + +def test_function_alpha_equal(): + v1 = relay.Var("v1") + v2 = relay.Var("v2") + v3 = relay.Var("v3") + v4 = relay.Var("v4") + + tt1 = relay.TensorType((1, 2, 3), "float32") + tt2 = relay.TensorType((4, 5, 6), "int8") + tt3 = relay.TupleType([tt1, tt2]) + + tp1 = relay.TypeParam("tp1", relay.Kind.Type) + tp2 = relay.TypeParam("tp2", relay.Kind.Type) + tp3 = relay.TypeParam("tp3", relay.Kind.Shape) + tp4 = relay.TypeParam("tp4", relay.Kind.Shape) + + basic_args = [relay.Param(v3, tt1), relay.Param(v4, tt2)] + basic_tps = [tp1, tp2] + + func = relay.Function([relay.Param(v1, tt1), relay.Param(v2, tt2)], + tt2, v2, basic_tps) + mapped = relay.Function(basic_args, tt2, v4, basic_tps) + assert alpha_equal(func, mapped) + + fewer_params = relay.Function([relay.Param(v4, tt2)], tt2, v4, basic_tps) + assert not alpha_equal(func, fewer_params) + + more_params = relay.Function([relay.Param(v3, tt1), relay.Param(v4, tt2), + relay.Param(v2, tt2)], tt2, v4, basic_tps) + assert not alpha_equal(func, more_params) + + params_unordered = relay.Function([relay.Param(v3, tt2), + relay.Param(v4, tt1)], + tt1, v3, basic_tps) + assert not alpha_equal(func, params_unordered) + + params_mismatch = relay.Function([relay.Param(v3, tt3), + relay.Param(v4, tt2)], + tt2, v4, basic_tps) + assert not alpha_equal(func, params_mismatch) + + # also would not typecheck + ret_type_mismatch = relay.Function(basic_args, tt1, v4, basic_tps) + assert not alpha_equal(func, ret_type_mismatch) + + # also mis-typed + different_body = relay.Function(basic_args, tt2, v3, basic_tps) + assert not alpha_equal(func, different_body) + + fewer_type_params = relay.Function(basic_args, tt2, v4, [tp1]) + assert not alpha_equal(func, fewer_type_params) + + more_type_params = relay.Function(basic_args, tt2, v4, [tp1, tp2, tp3]) + assert not alpha_equal(func, more_type_params) + + type_params_unordered = relay.Function(basic_args, tt2, v4, [tp2, tp1]) + assert not alpha_equal(func, type_params_unordered) + + different_type_params = relay.Function(basic_args, tt2, v4, [tp3, tp4]) + assert not alpha_equal(func, different_type_params) + + # a well-typed example that also differs in body, ret type, and type params + tupled_example = relay.Function(basic_args, tt3, relay.Tuple([v3, v4])) + assert not alpha_equal(func, tupled_example) + + +def test_call_alpha_equal(): + v1 = relay.Var("v1") + v2 = relay.Var("v2") + + # attrs are compared only by pointer equality + attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) + attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) + + tt1 = relay.TensorType((1, 2, 3), "float32") + tt2 = relay.TensorType((), "int8") + + basic_args = [convert(1), convert(2), v2, relay.Tuple([])] + + # manually writing out args to ensure that args does not rely on + # pointer equality + call = relay.Call(v1, [convert(1), convert(2), v2, relay.Tuple([])], + attr1, [tt1]) + same = relay.Call(v1, basic_args, attr1, [tt1]) + assert alpha_equal(call, same) + + different_fn = relay.Call(v2, basic_args, attr1, [tt1]) + assert not alpha_equal(call, different_fn) + + fewer_args = relay.Call(v1, [convert(1), convert(2), v2], attr1, [tt1]) + assert not alpha_equal(call, fewer_args) + + reordered_args = relay.Call(v1, [convert(2), convert(1), + relay.Tuple([]), v2], attr1, [tt1]) + assert not alpha_equal(call, reordered_args) + + different_args = relay.Call(v1, [convert(1), convert(2), convert(3)], + attr1, [tt1]) + assert not alpha_equal(call, different_args) + + more_args = relay.Call(v1, [convert(1), convert(2), v2, relay.Tuple([]), + convert(3), convert(4)], attr1, [tt1]) + assert not alpha_equal(call, more_args) + + different_attrs = relay.Call(v1, basic_args, attr2, [tt1]) + assert not alpha_equal(call, different_attrs) + + no_type_args = relay.Call(v1, basic_args, attr1) + assert not alpha_equal(call, no_type_args) + + more_type_args = relay.Call(v1, basic_args, attr1, [tt1, tt2]) + assert not alpha_equal(call, more_type_args) + + different_type_arg = relay.Call(v1, basic_args, attr1, [tt2]) + assert not alpha_equal(call, different_type_arg) + + +def test_let_alpha_equal(): + v1 = relay.Var("v1") + v2 = relay.Var("v2") + v3 = relay.Var("v3") + + let = relay.Let(v1, convert(2), v1) + mapped = relay.Let(v2, convert(2), v2) + assert alpha_equal(let, mapped) + + mismatched_var = relay.Let(v2, convert(2), v3) + assert not alpha_equal(let, mismatched_var) + + different_value = relay.Let(v2, convert(3), v2) + assert not alpha_equal(let, different_value) + + different_body = relay.Let(v2, convert(3), convert(12)) + assert not alpha_equal(let, different_body) + + # specified types must match + tt1 = relay.TensorType((), "float32") + tt2 = relay.TensorType((), "int8") + let_with_type = relay.Let(v1, convert(2), v1, tt1) + same_type = relay.Let(v1, convert(2), v1, tt1) + assert alpha_equal(let_with_type, same_type) + assert not alpha_equal(let, let_with_type) + + different_type = relay.Let(v1, convert(2), v1, tt2) + assert not alpha_equal(let_with_type, different_type) + + +def test_if_alpha_equal(): + v1 = relay.Var("v1") + v2 = relay.Var("v2") + + if_sample = relay.If(v1, convert(1), relay.Tuple([convert(2), convert(3)])) + same = relay.If(v1, convert(1), relay.Tuple([convert(2), convert(3)])) + assert alpha_equal(if_sample, same) + + different_cond = relay.If(v2, convert(1), relay.Tuple([convert(2), convert(3)])) + assert not alpha_equal(if_sample, different_cond) + + different_true = relay.If(v1, convert(2), relay.Tuple([convert(2), convert(3)])) + assert not alpha_equal(if_sample, different_true) + + different_false = relay.If(v1, convert(1), relay.Tuple([])) + assert not alpha_equal(if_sample, different_false) + + +def test_op_alpha_equal(): + # only checks names + op1 = relay.op.get("add") + op2 = relay.op.get("add") + assert alpha_equal(op1, op2) + + op3 = relay.op.get("take") + assert not alpha_equal(op1, op3) + + if __name__ == "__main__": test_tensor_type_alpha_equal() test_incomplete_type_alpha_equal() @@ -182,4 +441,14 @@ if __name__ == "__main__": test_func_type_alpha_equal() test_tuple_type_alpha_equal() test_type_relation_alpha_equal() + test_constant_alpha_equal() + test_var_alpha_equal() + test_global_var_alpha_equal() + test_tuple_alpha_equal() test_tuple_get_item_alpha_equal() + test_param_alpha_equal() + test_function_alpha_equal() + test_call_alpha_equal() + test_let_alpha_equal() + test_if_alpha_equal() + test_op_alpha_equal() -- GitLab