From 52e55baa47c08ceec52c67b8cd21b26b092860e2 Mon Sep 17 00:00:00 2001 From: Josh Pollock <joshpollock1997@gmail.com> Date: Sun, 2 Dec 2018 18:58:40 -0800 Subject: [PATCH] [Relay] Parser Tests (#2209) --- src/relay/ir/alpha_equal.cc | 12 +- src/relay/ir/text_printer.cc | 4 +- tests/python/relay/test_ir_parser.py | 562 +++++++++++++++++++++++++++ 3 files changed, 570 insertions(+), 8 deletions(-) create mode 100644 tests/python/relay/test_ir_parser.py diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 873210321..16af572a9 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -26,7 +26,7 @@ class AlphaEqualHandler: * Check equality of two nodes. * \param lhs The left hand operand. * \param rhs The right hand operand. - * \return the compare result. + * \return The compare result. */ bool Equal(const NodeRef& lhs, const NodeRef& rhs) { if (lhs.same_as(rhs)) return true; @@ -46,7 +46,7 @@ class AlphaEqualHandler: * Check equality of two attributes. * \param lhs The left hand operand. * \param rhs The right hand operand. - * \return the compare result. + * \return The compare result. */ bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) { return AttrsEqualHandler::Equal(lhs, rhs); @@ -55,7 +55,7 @@ class AlphaEqualHandler: * Check equality of two types. * \param lhs The left hand operand. * \param rhs The right hand operand. - * \return the compare result. + * \return The compare result. */ bool TypeEqual(const Type& lhs, const Type& rhs) { if (lhs.same_as(rhs)) return true; @@ -72,7 +72,7 @@ class AlphaEqualHandler: * * \param lhs The left hand operand. * \param rhs The right hand operand. - * \return the compare result. + * \return The compare result. */ bool ExprEqual(const Expr& lhs, const Expr& rhs) { if (lhs.same_as(rhs)) return true; @@ -94,7 +94,7 @@ class AlphaEqualHandler: * \brief Check if data type equals each other. * \param lhs The left hand operand. * \param rhs The right hand operand. - * \return the compare result. + * \return The compare result. */ bool DataTypeEqual(const DataType& lhs, const DataType& rhs) { return lhs == rhs; @@ -104,7 +104,7 @@ class AlphaEqualHandler: * if map_free_var_ is set to true, try to map via equal node. * \param lhs The left hand operand. * \param rhs The right hand operand. - * \return the compare result. + * \return The compare result. */ bool LeafNodeEqual(const NodeRef& lhs, const NodeRef& rhs) { if (lhs.same_as(rhs)) return true; diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 2664c4756..46b0d25b3 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -38,7 +38,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO * It can be hard to design a text format for all the possible nodes * as the set of nodes can grow when we do more extensions. * - * Instead of trying to design readable text format for every nodes, + * Instead of trying to design readable text format for every node, * we support a meta-data section in the text format. * We allow the text format to refer to a node in the meta-data section. * @@ -73,7 +73,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO * \endcode * * Note that we store tvm.var("n") in the meta data section. - * Since it is stored in the index-0 in the meta-data seciton, + * Since it is stored in the index-0 in the meta-data section, * we print it as meta.Variable(0). * * The text parser can recover this object by loading from the corresponding diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py new file mode 100644 index 000000000..c2c83df7e --- /dev/null +++ b/tests/python/relay/test_ir_parser.py @@ -0,0 +1,562 @@ +import tvm +from tvm import relay +from tvm.relay.parser import enabled +from tvm.relay.ir_pass import alpha_equal +from nose.tools import nottest, raises +from numpy import isclose +from typing import Union +from functools import wraps +if enabled(): + from tvm.relay._parser import ParseError + raises_parse_error = raises(ParseError) +else: + raises_parse_error = lambda x: x + +BINARY_OPS = { + "*": relay.multiply, + "/": relay.divide, + "+": relay.add, + "-": relay.subtract, + "<": relay.less, + ">": relay.greater, + "<=": relay.less_equal, + ">=": relay.greater_equal, + "==": relay.equal, + "!=": relay.not_equal, +} + +TYPES = { + "int8", + "int16", + "int32", + "int64", + + "uint8", + "uint16", + "uint32", + "uint64", + + "float16", + "float32", + "float64", + + "bool", + + "int8x4", + "uint1x4", + "float16x4", +} + +def get_scalar(x): + # type: (relay.Constant) -> (Union[float, int, bool]) + return x.data.asnumpy().item() + +int32 = relay.scalar_type("int32") + +_ = relay.Var("_") +X = relay.Var("x") +Y = relay.Var("y") +X_ANNO = relay.Var("x", int32) +Y_ANNO = relay.Var("y", int32) + +UNIT = relay.Tuple([]) + +# decorator to determine if parser is enabled +def if_parser_enabled(func): + # https://stackoverflow.com/q/7727678 + @wraps(func) + def wrapper(): + if not enabled(): + return + func() + return wrapper + +@if_parser_enabled +def test_comments(): + assert alpha_equal( + relay.fromtext(""" + // This is a line comment! + () + """), + UNIT + ) + + assert alpha_equal( + relay.fromtext(""" + /* This is a block comment! + This is still a block comment! + */ + () + """), + UNIT + ) + +@if_parser_enabled +def test_int_literal(): + assert isinstance(relay.fromtext("1"), relay.Constant) + assert isinstance(relay.fromtext("1").data, tvm.ndarray.NDArray) + + assert get_scalar(relay.fromtext("1")) == 1 + assert get_scalar(relay.fromtext("10")) == 10 + assert get_scalar(relay.fromtext("0")) == 0 + assert get_scalar(relay.fromtext("-100")) == -100 + assert get_scalar(relay.fromtext("-05")) == -5 + +@if_parser_enabled +def test_float_literal(): + assert get_scalar(relay.fromtext("1.0")) == 1.0 + assert isclose(get_scalar(relay.fromtext("1.56667")), 1.56667) + assert get_scalar(relay.fromtext("0.0")) == 0.0 + assert get_scalar(relay.fromtext("-10.0")) == -10.0 + + # scientific notation + assert isclose(get_scalar(relay.fromtext("1e-1")), 1e-1) + assert get_scalar(relay.fromtext("1e+1")) == 1e+1 + assert isclose(get_scalar(relay.fromtext("1E-1")), 1E-1) + assert get_scalar(relay.fromtext("1E+1")) == 1E+1 + assert isclose(get_scalar(relay.fromtext("1.0e-1")), 1.0e-1) + assert get_scalar(relay.fromtext("1.0e+1")) == 1.0e+1 + assert isclose(get_scalar(relay.fromtext("1.0E-1")), 1.0E-1) + assert get_scalar(relay.fromtext("1.0E+1")) == 1.0E+1 + +@if_parser_enabled +def test_bool_literal(): + assert get_scalar(relay.fromtext("True")) == True + assert get_scalar(relay.fromtext("False")) == False + +@if_parser_enabled +def test_negative(): + assert isinstance(relay.fromtext("let %x = 1; -%x").body, relay.Call) + assert get_scalar(relay.fromtext("--10")) == 10 + assert get_scalar(relay.fromtext("---10")) == -10 + +@if_parser_enabled +def test_bin_op(): + for bin_op in BINARY_OPS.keys(): + assert alpha_equal( + relay.fromtext("1 {} 1".format(bin_op)), + BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1)) + ) + +@if_parser_enabled +def test_parens(): + assert alpha_equal(relay.fromtext("1 * 1 + 1"), relay.fromtext("(1 * 1) + 1")) + assert not alpha_equal(relay.fromtext("1 * 1 + 1"), relay.fromtext("1 * (1 + 1)")) + +@if_parser_enabled +def test_op_assoc(): + assert alpha_equal(relay.fromtext("1 * 1 + 1 < 1 == 1"), relay.fromtext("(((1 * 1) + 1) < 1) == 1")) + assert alpha_equal(relay.fromtext("1 == 1 < 1 + 1 * 1"), relay.fromtext("1 == (1 < (1 + (1 * 1)))")) + +@nottest +@if_parser_enabled +def test_vars(): + # temp vars won't work b/c they start with a digit + # # temp var + # temp_var = relay.fromtext("%1") + # assert isinstance(temp_var, relay.Var) + # assert temp_var.name == "1" + + # var + var = relay.fromtext("let %foo = (); %foo") + assert isinstance(var.body, relay.Var) + assert var.body.name_hint == "foo" + + # global var + global_var = relay.fromtext("@foo") + assert isinstance(global_var, relay.GlobalVar) + assert global_var.name_hint == "foo" + + # operator id + op = relay.fromtext("foo") + assert isinstance(op, relay.Op) + assert op.name == "foo" + +@if_parser_enabled +def test_let(): + assert alpha_equal( + relay.fromtext("let %x = 1; ()"), + relay.Let( + X, + relay.const(1), + UNIT + ) + ) + +@if_parser_enabled +def test_seq(): + assert alpha_equal( + relay.fromtext("(); ()"), + relay.Let( + _, + UNIT, + UNIT) + ) + + assert alpha_equal( + relay.fromtext("let %_ = { 1 }; ()"), + relay.Let( + X, + relay.const(1), + UNIT + ) + ) + +@raises_parse_error +@if_parser_enabled +def test_let_global_var(): + relay.fromtext("let @x = 1; ()") + +@raises_parse_error +@if_parser_enabled +def test_let_op(): + relay.fromtext("let x = 1; ()") + +@if_parser_enabled +def test_tuple(): + assert alpha_equal(relay.fromtext("()"), relay.Tuple([])) + + assert alpha_equal(relay.fromtext("(0,)"), relay.Tuple([relay.const(0)])) + + assert alpha_equal(relay.fromtext("(0, 1)"), relay.Tuple([relay.const(0), relay.const(1)])) + + assert alpha_equal(relay.fromtext("(0, 1, 2)"), relay.Tuple([relay.const(0), relay.const(1), relay.const(2)])) + +@if_parser_enabled +def test_func(): + # 0 args + assert alpha_equal( + relay.fromtext("fn () { 0 }"), + relay.Function( + [], + relay.const(0), + None, + [] + ) + ) + + # 1 arg + assert alpha_equal( + relay.fromtext("fn (%x) { %x }"), + relay.Function( + [X], + X, + None, + [] + ) + ) + + # 2 args + assert alpha_equal( + relay.fromtext("fn (%x, %y) { %x + %y }"), + relay.Function( + [X, Y], + relay.add(X, Y), + None, + [] + ) + ) + + # annotations + assert alpha_equal( + relay.fromtext("fn (%x: int32) -> int32 { %x }"), + relay.Function( + [X_ANNO], + X_ANNO, + int32, + [] + ) + ) + +# TODO(@jmp): Crashes if %x isn't annnotated. +# @nottest +@if_parser_enabled +def test_defn(): + id_defn = relay.fromtext( + """ + def @id(%x: int32) -> int32 { + %x + } + """) + assert isinstance(id_defn, relay.Module) + +@if_parser_enabled +def test_ifelse(): + assert alpha_equal( + relay.fromtext( + """ + if (True) { + 0 + } else { + 1 + } + """ + ), + relay.If( + relay.const(True), + relay.const(0), + relay.const(1) + ) + ) + +@raises_parse_error +@if_parser_enabled +def test_ifelse_scope(): + relay.fromtext( + """ + if (True) { + let %x = (); + () + } else { + %x + } + """ + ) + +@if_parser_enabled +def test_call(): + # 0 args + constant = relay.Var("constant") + assert alpha_equal( + relay.fromtext( + """ + let %constant = fn () { 0 }; + %constant() + """ + ), + relay.Let( + constant, + relay.Function([], relay.const(0), None, []), + relay.Call(constant, [], None, None) + ) + ) + + # 1 arg + id_var = relay.Var("id") + assert alpha_equal( + relay.fromtext( + """ + let %id = fn (%x) { %x }; + %id(1) + """ + ), + relay.Let( + id_var, + relay.Function([X], X, None, []), + relay.Call(id_var, [relay.const(1)], None, None) + ) + ) + + # 2 args + multiply = relay.Var("multiply") + assert alpha_equal( + relay.fromtext( + """ + let %multiply = fn (%x, %y) { %x * %y }; + %multiply(0, 0) + """ + ), + relay.Let( + multiply, + relay.Function( + [X, Y], + relay.multiply(X, Y), + None, + [] + ), + relay.Call(multiply, [relay.const(0), relay.const(0)], None, None) + ) + ) + + # anonymous function + assert alpha_equal( + relay.fromtext( + """ + (fn (%x) { %x })(0) + """ + ), + relay.Call( + relay.Function( + [X], + X, + None, + [] + ), + [relay.const(0)], + None, + None + ) + ) + + # curried function + curried_mult = relay.Var("curried_mult") + alpha_equal( + relay.fromtext( + """ + let %curried_mult = + fn (%x) { + fn (%y) { + %x * %y + } + }; + %curried_mult(0); + %curried_mult(0)(0) + """ + ), + relay.Let( + curried_mult, + relay.Function( + [X], + relay.Function( + [Y], + relay.multiply(X, Y), + None, + [] + ), + None, + [] + ), + relay.Let( + _, + relay.Call(curried_mult, [relay.const(0)], None, None), + relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None) + ) + ) + ) + + # op + alpha_equal( + relay.fromtext("abs(1)"), + relay.Call(relay.op.get("abs"), [relay.const(1)], None, None) + ) + +# Types + +@if_parser_enabled +def test_incomplete_type(): + assert alpha_equal( + relay.fromtext("let %_ : _ = (); ()"), + relay.Let( + _, + UNIT, + UNIT + ) + ) + +@if_parser_enabled +def test_builtin_types(): + for builtin_type in TYPES: + relay.fromtext("let %_ : {} = (); ()".format(builtin_type)) + +@nottest +@if_parser_enabled +def test_call_type(): + assert False + +@if_parser_enabled +def test_tensor_type(): + assert alpha_equal( + relay.fromtext("let %_ : Tensor[(), float32] = (); ()"), + relay.Let( + relay.Var("_", relay.TensorType((), "float32")), + UNIT, + UNIT + ) + ) + + assert alpha_equal( + relay.fromtext("let %_ : Tensor[(1,), float32] = (); ()"), + relay.Let( + relay.Var("_", relay.TensorType((1,), "float32")), + UNIT, + UNIT + ) + ) + + assert alpha_equal( + relay.fromtext("let %_ : Tensor[(1, 1), float32] = (); ()"), + relay.Let( + relay.Var("_", relay.TensorType((1, 1), "float32")), + UNIT, + UNIT + ) + ) + +@if_parser_enabled +def test_function_type(): + assert alpha_equal( + relay.fromtext( + """ + let %_: fn () -> int32 = fn () -> int32 { 0 }; () + """ + ), + relay.Let( + relay.Var("_", relay.FuncType([], int32, [], [])), + relay.Function([], relay.const(0), int32, []), + UNIT + ) + ) + + assert alpha_equal( + relay.fromtext( + """ + let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; () + """ + ), + relay.Let( + relay.Var("_", relay.FuncType([int32], int32, [], [])), + relay.Function([relay.Var("x", int32)], relay.const(0), int32, []), + UNIT + ) + ) + + assert alpha_equal( + relay.fromtext( + """ + let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; () + """ + ), + relay.Let( + relay.Var("_", relay.FuncType([int32, int32], int32, [], [])), + relay.Function([relay.Var("x", int32), relay.Var("y", int32)], relay.const(0), int32, []), + UNIT + ) + ) + +@if_parser_enabled +def test_tuple_type(): + assert alpha_equal( + relay.fromtext( + """ + let %_: () = (); () + """), + relay.Let( + relay.Var("_", relay.TupleType([])), + UNIT, + UNIT + ) + ) + + assert alpha_equal( + relay.fromtext( + """ + let %_: (int32,) = (0,); () + """), + relay.Let( + relay.Var("_", relay.TupleType([int32])), + relay.Tuple([relay.const(0)]), + UNIT + ) + ) + + assert alpha_equal( + relay.fromtext( + """ + let %_: (int32, int32) = (0, 1); () + """), + relay.Let( + relay.Var("_", relay.TupleType([int32, int32])), + relay.Tuple([relay.const(0), relay.const(1)]), + UNIT + ) + ) -- GitLab