From 52e55baa47c08ceec52c67b8cd21b26b092860e2 Mon Sep 17 00:00:00 2001
From: Josh Pollock <joshpollock1997@gmail.com>
Date: Sun, 2 Dec 2018 18:58:40 -0800
Subject: [PATCH] [Relay] Parser Tests (#2209)

---
 src/relay/ir/alpha_equal.cc          |  12 +-
 src/relay/ir/text_printer.cc         |   4 +-
 tests/python/relay/test_ir_parser.py | 562 +++++++++++++++++++++++++++
 3 files changed, 570 insertions(+), 8 deletions(-)
 create mode 100644 tests/python/relay/test_ir_parser.py

diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc
index 873210321..16af572a9 100644
--- a/src/relay/ir/alpha_equal.cc
+++ b/src/relay/ir/alpha_equal.cc
@@ -26,7 +26,7 @@ class AlphaEqualHandler:
    * Check equality of two nodes.
    * \param lhs The left hand operand.
    * \param rhs The right hand operand.
-   * \return the compare result.
+   * \return The compare result.
    */
   bool Equal(const NodeRef& lhs, const NodeRef& rhs) {
     if (lhs.same_as(rhs)) return true;
@@ -46,7 +46,7 @@ class AlphaEqualHandler:
    * Check equality of two attributes.
    * \param lhs The left hand operand.
    * \param rhs The right hand operand.
-   * \return the compare result.
+   * \return The compare result.
    */
   bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) {
     return AttrsEqualHandler::Equal(lhs, rhs);
@@ -55,7 +55,7 @@ class AlphaEqualHandler:
    * Check equality of two types.
    * \param lhs The left hand operand.
    * \param rhs The right hand operand.
-   * \return the compare result.
+   * \return The compare result.
    */
   bool TypeEqual(const Type& lhs, const Type& rhs) {
     if (lhs.same_as(rhs)) return true;
@@ -72,7 +72,7 @@ class AlphaEqualHandler:
    *
    * \param lhs The left hand operand.
    * \param rhs The right hand operand.
-   * \return the compare result.
+   * \return The compare result.
    */
   bool ExprEqual(const Expr& lhs, const Expr& rhs) {
     if (lhs.same_as(rhs)) return true;
@@ -94,7 +94,7 @@ class AlphaEqualHandler:
    * \brief Check if data type equals each other.
    * \param lhs The left hand operand.
    * \param rhs The right hand operand.
-   * \return the compare result.
+   * \return The compare result.
    */
   bool DataTypeEqual(const DataType& lhs, const DataType& rhs) {
     return lhs == rhs;
@@ -104,7 +104,7 @@ class AlphaEqualHandler:
    *  if map_free_var_ is set to true, try to map via equal node.
    * \param lhs The left hand operand.
    * \param rhs The right hand operand.
-   * \return the compare result.
+   * \return The compare result.
    */
   bool LeafNodeEqual(const NodeRef& lhs, const NodeRef& rhs) {
     if (lhs.same_as(rhs)) return true;
diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc
index 2664c4756..46b0d25b3 100644
--- a/src/relay/ir/text_printer.cc
+++ b/src/relay/ir/text_printer.cc
@@ -38,7 +38,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) {  // NO
  * It can be hard to design a text format for all the possible nodes
  * as the set of nodes can grow when we do more extensions.
  *
- * Instead of trying to design readable text format for every nodes,
+ * Instead of trying to design readable text format for every node,
  * we support a meta-data section in the text format.
  * We allow the text format to refer to a node in the meta-data section.
  *
@@ -73,7 +73,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) {  // NO
  * \endcode
  *
  * Note that we store tvm.var("n") in the meta data section.
- * Since it is stored in the index-0 in the meta-data seciton,
+ * Since it is stored in the index-0 in the meta-data section,
  * we print it as meta.Variable(0).
  *
  * The text parser can recover this object by loading from the corresponding
diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py
new file mode 100644
index 000000000..c2c83df7e
--- /dev/null
+++ b/tests/python/relay/test_ir_parser.py
@@ -0,0 +1,562 @@
+import tvm
+from tvm import relay
+from tvm.relay.parser import enabled
+from tvm.relay.ir_pass import alpha_equal
+from nose.tools import nottest, raises
+from numpy import isclose
+from typing import Union
+from functools import wraps
+if enabled():
+    from tvm.relay._parser import ParseError
+    raises_parse_error = raises(ParseError)
+else:
+    raises_parse_error = lambda x: x
+
+BINARY_OPS = {
+    "*": relay.multiply,
+    "/": relay.divide,
+    "+": relay.add,
+    "-": relay.subtract,
+    "<": relay.less,
+    ">": relay.greater,
+    "<=": relay.less_equal,
+    ">=": relay.greater_equal,
+    "==": relay.equal,
+    "!=": relay.not_equal,
+}
+
+TYPES = {
+    "int8",
+    "int16",
+    "int32",
+    "int64",
+
+    "uint8",
+    "uint16",
+    "uint32",
+    "uint64",
+
+    "float16",
+    "float32",
+    "float64",
+
+    "bool",
+
+    "int8x4",
+    "uint1x4",
+    "float16x4",
+}
+
+def get_scalar(x):
+    # type: (relay.Constant) -> (Union[float, int, bool])
+    return x.data.asnumpy().item()
+
+int32 = relay.scalar_type("int32")
+
+_ = relay.Var("_")
+X = relay.Var("x")
+Y = relay.Var("y")
+X_ANNO = relay.Var("x", int32)
+Y_ANNO = relay.Var("y", int32)
+
+UNIT = relay.Tuple([])
+
+# decorator to determine if parser is enabled
+def if_parser_enabled(func):
+    # https://stackoverflow.com/q/7727678
+    @wraps(func)
+    def wrapper():
+        if not enabled():
+            return
+        func()
+    return wrapper
+
+@if_parser_enabled
+def test_comments():
+    assert alpha_equal(
+        relay.fromtext("""
+            // This is a line comment!
+            ()
+        """),
+        UNIT
+    )
+
+    assert alpha_equal(
+        relay.fromtext("""
+            /* This is a block comment!
+               This is still a block comment!
+            */
+            ()
+        """),
+        UNIT
+    )
+
+@if_parser_enabled
+def test_int_literal():
+    assert isinstance(relay.fromtext("1"), relay.Constant)
+    assert isinstance(relay.fromtext("1").data, tvm.ndarray.NDArray)
+    
+    assert get_scalar(relay.fromtext("1")) == 1
+    assert get_scalar(relay.fromtext("10")) == 10
+    assert get_scalar(relay.fromtext("0")) == 0
+    assert get_scalar(relay.fromtext("-100")) == -100
+    assert get_scalar(relay.fromtext("-05")) == -5
+
+@if_parser_enabled
+def test_float_literal():
+    assert get_scalar(relay.fromtext("1.0")) == 1.0
+    assert isclose(get_scalar(relay.fromtext("1.56667")), 1.56667)
+    assert get_scalar(relay.fromtext("0.0")) == 0.0
+    assert get_scalar(relay.fromtext("-10.0")) == -10.0
+
+    # scientific notation
+    assert isclose(get_scalar(relay.fromtext("1e-1")), 1e-1)
+    assert get_scalar(relay.fromtext("1e+1")) == 1e+1
+    assert isclose(get_scalar(relay.fromtext("1E-1")), 1E-1)
+    assert get_scalar(relay.fromtext("1E+1")) == 1E+1
+    assert isclose(get_scalar(relay.fromtext("1.0e-1")), 1.0e-1)
+    assert get_scalar(relay.fromtext("1.0e+1")) == 1.0e+1
+    assert isclose(get_scalar(relay.fromtext("1.0E-1")), 1.0E-1)
+    assert get_scalar(relay.fromtext("1.0E+1")) == 1.0E+1
+
+@if_parser_enabled
+def test_bool_literal():
+    assert get_scalar(relay.fromtext("True")) == True
+    assert get_scalar(relay.fromtext("False")) == False
+
+@if_parser_enabled
+def test_negative():
+    assert isinstance(relay.fromtext("let %x = 1; -%x").body, relay.Call)
+    assert get_scalar(relay.fromtext("--10")) == 10
+    assert get_scalar(relay.fromtext("---10")) == -10
+
+@if_parser_enabled
+def test_bin_op():
+    for bin_op in BINARY_OPS.keys():
+        assert alpha_equal(
+            relay.fromtext("1 {} 1".format(bin_op)),
+            BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1))
+        )
+
+@if_parser_enabled
+def test_parens():
+    assert alpha_equal(relay.fromtext("1 * 1 + 1"), relay.fromtext("(1 * 1) + 1"))
+    assert not alpha_equal(relay.fromtext("1 * 1 + 1"), relay.fromtext("1 * (1 + 1)"))
+
+@if_parser_enabled
+def test_op_assoc():
+    assert alpha_equal(relay.fromtext("1 * 1 + 1 < 1 == 1"), relay.fromtext("(((1 * 1) + 1) < 1) == 1"))
+    assert alpha_equal(relay.fromtext("1 == 1 < 1 + 1 * 1"), relay.fromtext("1 == (1 < (1 + (1 * 1)))"))
+
+@nottest
+@if_parser_enabled
+def test_vars():
+    # temp vars won't work b/c they start with a digit
+    # # temp var
+    # temp_var = relay.fromtext("%1")
+    # assert isinstance(temp_var, relay.Var)
+    # assert temp_var.name == "1"
+
+    # var
+    var = relay.fromtext("let %foo = (); %foo")
+    assert isinstance(var.body, relay.Var)
+    assert var.body.name_hint == "foo"
+
+    # global var
+    global_var = relay.fromtext("@foo")
+    assert isinstance(global_var, relay.GlobalVar)
+    assert global_var.name_hint == "foo"
+
+    # operator id
+    op = relay.fromtext("foo")
+    assert isinstance(op, relay.Op)
+    assert op.name == "foo"
+
+@if_parser_enabled
+def test_let():
+    assert alpha_equal(
+        relay.fromtext("let %x = 1; ()"),
+        relay.Let(
+            X,
+            relay.const(1),
+            UNIT
+        )
+    )
+
+@if_parser_enabled
+def test_seq():
+    assert alpha_equal(
+        relay.fromtext("(); ()"),
+        relay.Let(
+            _,
+            UNIT,
+            UNIT)
+    )
+
+    assert alpha_equal(
+        relay.fromtext("let %_ = { 1 }; ()"),
+        relay.Let(
+            X,
+            relay.const(1),
+            UNIT
+        )
+    )
+
+@raises_parse_error
+@if_parser_enabled
+def test_let_global_var():
+    relay.fromtext("let @x = 1; ()")
+
+@raises_parse_error
+@if_parser_enabled
+def test_let_op():
+    relay.fromtext("let x = 1; ()")
+
+@if_parser_enabled
+def test_tuple():
+    assert alpha_equal(relay.fromtext("()"), relay.Tuple([]))
+
+    assert alpha_equal(relay.fromtext("(0,)"), relay.Tuple([relay.const(0)]))
+
+    assert alpha_equal(relay.fromtext("(0, 1)"), relay.Tuple([relay.const(0), relay.const(1)]))
+
+    assert alpha_equal(relay.fromtext("(0, 1, 2)"), relay.Tuple([relay.const(0), relay.const(1), relay.const(2)]))
+
+@if_parser_enabled
+def test_func():
+    # 0 args
+    assert alpha_equal(
+        relay.fromtext("fn () { 0 }"),
+        relay.Function(
+            [],
+            relay.const(0),
+            None,
+            []
+        )
+    )
+
+    # 1 arg
+    assert alpha_equal(
+        relay.fromtext("fn (%x) { %x }"),
+        relay.Function(
+            [X],
+            X,
+            None,
+            []
+        )
+    )
+
+    # 2 args
+    assert alpha_equal(
+        relay.fromtext("fn (%x, %y) { %x + %y }"),
+        relay.Function(
+            [X, Y],
+            relay.add(X, Y),
+            None,
+            []
+        )
+    )
+
+    # annotations
+    assert alpha_equal(
+        relay.fromtext("fn (%x: int32) -> int32 { %x }"),
+        relay.Function(
+            [X_ANNO],
+            X_ANNO,
+            int32,
+            []
+        )
+    )
+
+# TODO(@jmp): Crashes if %x isn't annnotated.
+# @nottest
+@if_parser_enabled
+def test_defn():
+    id_defn = relay.fromtext(
+        """
+        def @id(%x: int32) -> int32 {
+            %x
+        }
+        """)
+    assert isinstance(id_defn, relay.Module)
+
+@if_parser_enabled
+def test_ifelse():
+    assert alpha_equal(
+        relay.fromtext(
+        """
+        if (True) {
+            0
+        } else {
+            1
+        }
+        """
+        ),
+        relay.If(
+            relay.const(True),
+            relay.const(0),
+            relay.const(1)
+        )
+    )
+
+@raises_parse_error
+@if_parser_enabled
+def test_ifelse_scope():
+    relay.fromtext(
+        """
+        if (True) {
+            let %x = ();
+            ()
+        } else {
+            %x
+        }
+        """
+    )
+
+@if_parser_enabled
+def test_call():
+    # 0 args
+    constant = relay.Var("constant")
+    assert alpha_equal(
+        relay.fromtext(
+        """
+        let %constant = fn () { 0 };
+        %constant()
+        """
+        ),
+        relay.Let(
+            constant,
+            relay.Function([], relay.const(0), None, []),
+            relay.Call(constant, [], None, None)
+        )
+    )
+
+    # 1 arg
+    id_var = relay.Var("id")
+    assert alpha_equal(
+        relay.fromtext(
+            """
+            let %id = fn (%x) { %x };
+            %id(1)
+            """
+        ),
+        relay.Let(
+            id_var,
+            relay.Function([X], X, None, []),
+            relay.Call(id_var, [relay.const(1)], None, None)
+        )
+    )
+
+    # 2 args
+    multiply = relay.Var("multiply")
+    assert alpha_equal(
+        relay.fromtext(
+        """
+        let %multiply = fn (%x, %y) { %x * %y };
+        %multiply(0, 0)
+        """
+        ),
+        relay.Let(
+            multiply,
+            relay.Function(
+                [X, Y],
+                relay.multiply(X, Y),
+                None,
+                []
+            ),
+            relay.Call(multiply, [relay.const(0), relay.const(0)], None, None)
+        )
+    )
+
+    # anonymous function
+    assert alpha_equal(
+        relay.fromtext(
+        """
+        (fn (%x) { %x })(0)
+        """
+        ),
+        relay.Call(
+            relay.Function(
+                [X],
+                X,
+                None,
+                []
+            ),
+            [relay.const(0)],
+            None,
+            None
+        )
+    )
+
+    # curried function
+    curried_mult = relay.Var("curried_mult")
+    alpha_equal(
+        relay.fromtext(
+            """
+            let %curried_mult =
+                fn (%x) {
+                fn (%y) {
+                    %x * %y
+                }
+                };
+            %curried_mult(0);
+            %curried_mult(0)(0)
+            """
+        ),
+        relay.Let(
+            curried_mult,
+            relay.Function(
+                [X],
+                relay.Function(
+                    [Y],
+                    relay.multiply(X, Y),
+                    None,
+                    []
+                ),
+                None,
+                []
+            ),
+            relay.Let(
+                _,
+                relay.Call(curried_mult, [relay.const(0)], None, None),
+                relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None)
+            )
+        )
+    )
+
+    # op
+    alpha_equal(
+        relay.fromtext("abs(1)"),
+        relay.Call(relay.op.get("abs"), [relay.const(1)], None, None)
+    )
+
+# Types
+
+@if_parser_enabled
+def test_incomplete_type():
+    assert alpha_equal(
+        relay.fromtext("let %_ : _ = (); ()"),
+        relay.Let(
+            _,
+            UNIT,
+            UNIT
+        )
+    )
+
+@if_parser_enabled
+def test_builtin_types():
+    for builtin_type in TYPES:
+        relay.fromtext("let %_ : {} = (); ()".format(builtin_type))
+
+@nottest
+@if_parser_enabled
+def test_call_type():
+    assert False
+
+@if_parser_enabled
+def test_tensor_type():
+    assert alpha_equal(
+        relay.fromtext("let %_ : Tensor[(), float32] = (); ()"),
+        relay.Let(
+            relay.Var("_", relay.TensorType((), "float32")),
+            UNIT,
+            UNIT
+        )
+    )
+
+    assert alpha_equal(
+        relay.fromtext("let %_ : Tensor[(1,), float32] = (); ()"),
+        relay.Let(
+            relay.Var("_", relay.TensorType((1,), "float32")),
+            UNIT,
+            UNIT
+        )
+    )
+
+    assert alpha_equal(
+        relay.fromtext("let %_ : Tensor[(1, 1), float32] = (); ()"),
+        relay.Let(
+            relay.Var("_", relay.TensorType((1, 1), "float32")),
+            UNIT,
+            UNIT
+        )
+    )
+
+@if_parser_enabled
+def test_function_type():
+    assert alpha_equal(
+        relay.fromtext(
+            """
+            let %_: fn () -> int32 = fn () -> int32 { 0 }; ()
+            """
+        ),
+        relay.Let(
+            relay.Var("_", relay.FuncType([], int32, [], [])),
+            relay.Function([], relay.const(0), int32, []),
+            UNIT
+        )
+    )
+
+    assert alpha_equal(
+        relay.fromtext(
+            """
+            let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; ()
+            """
+        ),
+        relay.Let(
+            relay.Var("_", relay.FuncType([int32], int32, [], [])),
+            relay.Function([relay.Var("x", int32)], relay.const(0), int32, []),
+            UNIT
+        )
+    )
+
+    assert alpha_equal(
+        relay.fromtext(
+            """
+            let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; ()
+            """
+        ),
+        relay.Let(
+            relay.Var("_", relay.FuncType([int32, int32], int32, [], [])),
+            relay.Function([relay.Var("x", int32), relay.Var("y", int32)], relay.const(0), int32, []),
+            UNIT
+        )
+    )
+
+@if_parser_enabled
+def test_tuple_type():
+    assert alpha_equal(
+        relay.fromtext(
+        """
+        let %_: () = (); ()
+        """),
+        relay.Let(
+            relay.Var("_", relay.TupleType([])),
+            UNIT,
+            UNIT
+        )
+    )
+
+    assert alpha_equal(
+        relay.fromtext(
+        """
+        let %_: (int32,) = (0,); ()
+        """),
+        relay.Let(
+            relay.Var("_", relay.TupleType([int32])),
+            relay.Tuple([relay.const(0)]),
+            UNIT
+        )
+    )
+
+    assert alpha_equal(
+        relay.fromtext(
+        """
+        let %_: (int32, int32) = (0, 1); ()
+        """),
+        relay.Let(
+            relay.Var("_", relay.TupleType([int32, int32])),
+            relay.Tuple([relay.const(0), relay.const(1)]),
+            UNIT
+        )
+    )
-- 
GitLab