From 6616355d37820df4767cd301576db267e890da5d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?=
 <lolisa@marisa.moe>
Date: Wed, 10 Oct 2018 14:25:38 -0700
Subject: [PATCH] [Relay] GetItem (#1861)

---
 include/tvm/relay/expr.h                      | 24 +++++++++++++++++--
 include/tvm/relay/expr_functor.h              |  4 ++++
 python/tvm/relay/__init__.py                  |  1 +
 python/tvm/relay/expr.py                      |  8 +++++++
 src/relay/ir/debug_printer.cc                 |  6 +++--
 src/relay/ir/expr.cc                          | 16 +++++++++++++
 src/relay/ir/expr_functor.cc                  | 15 ++++++++++--
 src/relay/pass/alpha_eq.cc                    |  9 +++++++
 src/relay/pass/type_functor.h                 |  9 ++++---
 src/relay/pass/type_infer.cc                  | 18 ++++++++++++++
 tests/python/relay/test_ir_debug_printer.py   |  7 +++++-
 tests/python/relay/test_ir_nodes.py           |  8 +++++++
 tests/python/relay/test_ir_well_formed.py     | 18 +++++++++++++-
 tests/python/relay/test_pass_alpha_equal.py   |  8 +++++++
 .../relay/test_pass_dead_code_elimination.py  | 16 +++++++++++++
 tests/python/relay/test_pass_free_vars.py     | 11 +++++++++
 tests/python/relay/test_type_infer.py         | 12 ++++++++++
 17 files changed, 177 insertions(+), 13 deletions(-)

diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index 909b702bc..c6e5573d9 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -360,8 +360,6 @@ class IfNode : public ExprNode {
   /*! \brief The expression evaluated when condition is false */
   Expr false_branch;
 
-  IfNode() {}
-
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("cond", &cond);
     v->Visit("true_branch", &true_branch);
@@ -378,6 +376,28 @@ class IfNode : public ExprNode {
 
 RELAY_DEFINE_NODE_REF(If, IfNode, Expr);
 
+/*! \brief Get a field out of a tuple. */
+class TupleGetItem;
+class TupleGetItemNode : public ExprNode {
+ public:
+  /*! \brief The tuple */
+  Expr tuple;
+  /*! \brief which value to get */
+  int index;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("tuple", &tuple);
+    v->Visit("index", &index);
+  }
+
+  TVM_DLL static TupleGetItem make(Expr tuple, int index);
+
+  static constexpr const char * _type_key = "relay.GetItem";
+  TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);
+
 /*! \brief Print a debug representation of the expression to the stream.
  *  \param env The environment.
  *  \param e The expression
diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h
index 1da66bc95..be174d33b 100644
--- a/include/tvm/relay/expr_functor.h
+++ b/include/tvm/relay/expr_functor.h
@@ -89,6 +89,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
                        Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const OpNode* op,
                        Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExprDefault_(const Node* op, Args...) {
     throw Error(std::string("Do not have a default for ") + op->type_key());
   }
@@ -108,6 +109,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
     RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
     RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
     RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
+    RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
     return vtable;
   }
 };
@@ -131,6 +133,7 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
   void VisitExpr_(const LetNode* op) override;
   void VisitExpr_(const IfNode* op) override;
   void VisitExpr_(const OpNode* op) override;
+  void VisitExpr_(const TupleGetItemNode* op) override;
   virtual void VisitType(const Type& t);
 };
 
@@ -153,6 +156,7 @@ class ExprMutator
   Expr VisitExpr_(const CallNode* call_node) override;
   Expr VisitExpr_(const LetNode* op) override;
   Expr VisitExpr_(const IfNode* op) override;
+  Expr VisitExpr_(const TupleGetItemNode* op) override;
   /*! \brief Used to visit the types inside of expressions.
    *
    * Can be overloaded to transform the types in arbitrary
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index dd48d213f..18c02a416 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -39,3 +39,4 @@ Function = expr.Function
 Call = expr.Call
 Let = expr.Let
 If = expr.If
+TupleGetItem = expr.TupleGetItem
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index 9b292a74e..05214ca09 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -125,4 +125,12 @@ class If(Expr):
         self.__init_handle_by_constructor__(
             _make.If, cond, true_value, false_value)
 
+@register_relay_node
+class TupleGetItem(Expr):
+    """An expression that get field from tuple in Relay, see tvm/relay/expr.h for more details."""
+
+    def __init__(self, tuple_, index):
+        self.__init_handle_by_constructor__(
+            _make.TupleGetItem, tuple_, index)
+
 debug_print = _expr._debug_print
diff --git a/src/relay/ir/debug_printer.cc b/src/relay/ir/debug_printer.cc
index e216faa0f..90e82d3b2 100644
--- a/src/relay/ir/debug_printer.cc
+++ b/src/relay/ir/debug_printer.cc
@@ -223,7 +223,6 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
   }
 
   Doc VisitExpr_(const CallNode* c) final {
-    auto args = DocifyExprArray(c->args);
     return Docify(c->op) + Seq("<", DocifyExprArray(c->args), ">");
   }
 
@@ -244,6 +243,10 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
     return DocOfStr(o->name);
   }
 
+  Doc VisitExpr_(const TupleGetItemNode* g) final {
+    return Docify(g->tuple) + DocOfStr(std::string(".") + std::to_string(g->index));
+  }
+
  public:
   ExprDocifier(const Environment& env) : env(env), td(env) { }
 
@@ -291,7 +294,6 @@ std::string PrintType(const Environment& env, const Type& t) {
 TVM_REGISTER_API("relay._expr._debug_print")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     NodeRef x = args[1];
-    std::cout << x << std::endl;
     if (x.as<TypeNode>()) {
       *ret = PrintType(args[0], Downcast<Type>(x));
     } else {
diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc
index dbbb5b84f..6b56cb4e8 100644
--- a/src/relay/ir/expr.cc
+++ b/src/relay/ir/expr.cc
@@ -193,5 +193,21 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
             << ", " << node->false_branch << ")";
 });
 
+TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
+  NodePtr<TupleGetItemNode> n = make_node<TupleGetItemNode>();
+  n->tuple = std::move(tuple);
+  n->index = index;
+  return TupleGetItem(n);
+}
+
+TVM_REGISTER_API("relay._make.TupleGetItem").set_body([](TVMArgs args, TVMRetValue* ret) {
+  *ret = TupleGetItemNode::make(args[0], args[1]);
+});
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<TupleGetItemNode>([](const TupleGetItemNode* node, tvm::IRPrinter* p) {
+  p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
+});
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc
index e3393bdb0..792f99d69 100644
--- a/src/relay/ir/expr_functor.cc
+++ b/src/relay/ir/expr_functor.cc
@@ -150,10 +150,17 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) {
   }
 }
 
-Type ExprMutator::VisitType(const Type& t) {
-  return t;
+Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
+  auto t = this->Mutate(g->tuple);
+  if (g->tuple == t) {
+    return GetRef<Expr>(g);
+  } else {
+    return TupleGetItemNode::make(t, g->index);
+  }
 }
 
+Type ExprMutator::VisitType(const Type& t) { return t; }
+
 void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
 }
 
@@ -206,6 +213,10 @@ void ExprVisitor::VisitExpr_(const IfNode* op) {
 
 void ExprVisitor::VisitExpr_(const OpNode* op) { return; }
 
+void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
+  this->VisitExpr(op->tuple);
+}
+
 void ExprVisitor::VisitType(const Type& t) { return; }
 
 }  // namespace relay
diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc
index 3c4c3d780..0e13a598c 100644
--- a/src/relay/pass/alpha_eq.cc
+++ b/src/relay/pass/alpha_eq.cc
@@ -335,6 +335,15 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
       equal = false;
     }
   }
+
+  void VisitExpr_(const TupleGetItemNode* op, const Expr& e2) final {
+    if (const TupleGetItemNode* proj = e2.as<TupleGetItemNode>()) {
+      this->VisitExpr(op->tuple, proj->tuple);
+      equal = equal && (op->index == proj->index);
+    } else {
+      equal = false;
+    }
+  }
 };
 
 bool AlphaEqual(const Expr& e1, const Expr& e2) {
diff --git a/src/relay/pass/type_functor.h b/src/relay/pass/type_functor.h
index a451fbe16..70a2d9347 100644
--- a/src/relay/pass/type_functor.h
+++ b/src/relay/pass/type_functor.h
@@ -8,7 +8,6 @@
 
 #include <tvm/node/ir_functor.h>
 #include <tvm/relay/expr.h>
-#include <tvm/relay/error.h>
 #include <string>
 
 namespace tvm {
@@ -21,11 +20,11 @@ class TypeFunctor;
 #define TYPE_FUNCTOR_DEFAULT \
   { return VisitTypeDefault_(op, std::forward<Args>(args)...); }
 
-#define RELAY_TYPE_FUNCTOR_DISPATCH(OP)                       \
-  vtable.template set_dispatch<OP>(                           \
-      [](const NodeRef& n, TSelf* self, Args... args) {       \
+#define RELAY_TYPE_FUNCTOR_DISPATCH(OP)                                   \
+  vtable.template set_dispatch<OP>(                                       \
+      [](const NodeRef& n, TSelf* self, Args... args) {                   \
         return self->VisitType_(static_cast<const OP*>(n.node_.get()),    \
-                                std::forward<Args>(args)...); \
+                                std::forward<Args>(args)...);             \
       });
 
 template <typename R, typename... Args>
diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index 1e2100fa9..72bdaf69f 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -119,6 +119,20 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
     return TupleTypeNode::make(fields);
   }
 
+  Type VisitExpr_(const TupleGetItemNode* op) final {
+    // TODO(M.K.)
+    // handle case where field type is not known
+    Type tuple_type = GetType(op->tuple);
+    auto tuple_ty_node = tuple_type.as<TupleTypeNode>();
+    if (!tuple_ty_node) {
+      LOG(FATAL) << "only expressions with tuple types is accepted" << GetRef<TupleGetItem>(op);
+    }
+    if (static_cast<int>(tuple_ty_node->fields.size()) <= op->index) {
+      LOG(FATAL) << "tuple not big enough" << GetRef<TupleGetItem>(op);
+    }
+    return tuple_ty_node->fields[op->index];
+  }
+
   Type VisitExpr_(const OpNode* op) final {
     return op->op_type;
   }
@@ -293,6 +307,10 @@ class TypeInferencer::Resolver : public ExprMutator {
     return AttachCheckedType(op);
   }
 
+  Expr VisitExpr_(const TupleGetItemNode* op) final {
+    return AttachCheckedType(op);
+  }
+
   Expr VisitExpr_(const ParamNode* op) final {
     return ExprMutator::VisitExpr_(op);
   }
diff --git a/tests/python/relay/test_ir_debug_printer.py b/tests/python/relay/test_ir_debug_printer.py
index 2ea0b7575..e5f9ad2e6 100644
--- a/tests/python/relay/test_ir_debug_printer.py
+++ b/tests/python/relay/test_ir_debug_printer.py
@@ -77,7 +77,7 @@ def test_call():
 
 def test_let():
     lv = relay.Var('x')
-    ty = relay.ty.TensorType((10, 20), "float32")
+    ty = relay.ty.TensorType((10, 20), 'float32')
     arr = tvm.nd.array(10)
     value = relay.Constant(arr)
     let = relay.Let(lv, value, lv, ty)
@@ -90,3 +90,8 @@ def test_if():
     right = relay.Var('right')
     ife = relay.If(cond, left, right)
     show(ife)
+
+def test_tuple_get_item():
+    t = relay.Var('t')
+    g = relay.TupleGetItem(t, 0)
+    show(g)
diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py
index d3dae9b2c..79883ed22 100644
--- a/tests/python/relay/test_ir_nodes.py
+++ b/tests/python/relay/test_ir_nodes.py
@@ -175,6 +175,13 @@ def test_if():
     str(ife)
 
 
+def test_tuple_get_item():
+    tup = relay.Var("tuple")
+    get = relay.TupleGetItem(tup, 1)
+    assert get.tuple == tup
+    assert get.index == 1
+    str(get)
+
 if __name__ == "__main__":
     test_bad_constructor()
     test_span()
@@ -192,3 +199,4 @@ if __name__ == "__main__":
     test_call()
     test_let()
     test_if()
+    test_tuple_get_item()
diff --git a/tests/python/relay/test_ir_well_formed.py b/tests/python/relay/test_ir_well_formed.py
index 8bdef4d0e..c6cb99662 100644
--- a/tests/python/relay/test_ir_well_formed.py
+++ b/tests/python/relay/test_ir_well_formed.py
@@ -3,7 +3,7 @@ from tvm import relay
 from tvm.relay.ir_pass import well_formed
 
 def test_well_formed():
-    x = relay.Var("x")
+    x = relay.Var('x')
     assert well_formed(x)
     v = relay.Constant(tvm.nd.array(10))
     ty = None
@@ -16,3 +16,19 @@ def test_well_formed():
     # but we want all binder to be distinct from each other.
     assert not well_formed(relay.Let(relay.Var("y"), f,
                                      relay.Let(relay.Var("z"), f, v, ty), ty))
+
+
+def test_tuple():
+    x = relay.Var('x')
+    assert well_formed(x)
+    v = relay.Constant(tvm.nd.array(10))
+    ty = None
+    let = relay.Let(x, v, x, ty)
+    assert well_formed(let)
+    assert well_formed(relay.Tuple([v, v]))
+    assert not well_formed(relay.Tuple([let, let]))
+
+
+def test_tuple_get_item():
+    t = relay.Var('t')
+    assert well_formed(relay.TupleGetItem(t, 2))
diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py
index 93f8a8fbc..9fa1a554a 100644
--- a/tests/python/relay/test_pass_alpha_equal.py
+++ b/tests/python/relay/test_pass_alpha_equal.py
@@ -167,11 +167,19 @@ def test_type_relation_alpha_equal():
 
     assert bigger != diff_num_inputs
 
+def test_tuple_get_item_alpha_equal():
+    x = relay.Var('x')
+    y = relay.Var('y')
+    assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(y, 1))
+    assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 2))
+    assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
 
 if __name__ == "__main__":
     test_tensor_type_alpha_equal()
     test_incomplete_type_alpha_equal()
+    test_constant_alpha_equal()
     test_type_param_alpha_equal()
     test_func_type_alpha_equal()
     test_tuple_type_alpha_equal()
     test_type_relation_alpha_equal()
+    test_tuple_get_item_alpha_equal()
diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py
index db73fb5c5..ce9bda3d2 100644
--- a/tests/python/relay/test_pass_dead_code_elimination.py
+++ b/tests/python/relay/test_pass_dead_code_elimination.py
@@ -4,6 +4,7 @@ from tvm.relay.ir_pass import dead_code_elimination, alpha_equal
 from tvm.relay.ir_builder import convert, IRBuilder
 from tvm.relay.op import log, add, equal, subtract
 
+
 class env:
     def __init__(self):
         self.a = relay.Var("a")
@@ -22,20 +23,25 @@ class env:
         self.two = convert(2.0)
         self.three = convert(3.0)
 
+
 e = env()
 
+
 def test_let():
     orig = relay.Let(e.x, e.y, e.z, e.tt)
     assert alpha_equal(dead_code_elimination(orig), e.z)
 
+
 def test_used_let():
     orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c, e.tt), e.tt)
     assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c, e.tt))
 
+
 def test_chain_unused_let():
     orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e, e.tt), e.tt)
     assert alpha_equal(dead_code_elimination(orig), e.e)
 
+
 # make sure we dont infinite loop
 def test_recursion():
     """
@@ -60,14 +66,23 @@ def test_recursion():
     assert alpha_equal(dead_code_elimination(orig), orig)
     assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three, e.float32)), e.three)
 
+
 def test_op_let():
     assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three, e.float32), e.two)), add(e.three, e.two))
 
+
 def test_if():
     orig = relay.If(convert(True), e.a, e.b)
     assert alpha_equal(dead_code_elimination(orig), e.a)
 
 
+def test_tuple_get_item():
+    t = relay.Var('t')
+    g = relay.TupleGetItem(t, 0)
+    assert alpha_equal(dead_code_elimination(g), g)
+    assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t, e.float32), 0)), g)
+
+
 if __name__ == "__main__":
     test_let()
     test_used_let()
@@ -75,3 +90,4 @@ if __name__ == "__main__":
     test_recursion()
     test_op_let()
     test_if()
+    test_tuple_get_item()
diff --git a/tests/python/relay/test_pass_free_vars.py b/tests/python/relay/test_pass_free_vars.py
index 002646ada..989c9f8d2 100644
--- a/tests/python/relay/test_pass_free_vars.py
+++ b/tests/python/relay/test_pass_free_vars.py
@@ -15,6 +15,17 @@ def test_free_vars():
     f = relay.Function([relay.Param(x, ty)], ty, x)
     assert len(free_vars(f)) == 0
 
+
+def test_tuple():
+    t = relay.Var('t')
+    fv = free_vars(relay.Tuple([t, t]))
+    assert len(fv) == 1
+    assert fv[0] == t
+    fv = free_vars(relay.TupleGetItem(t, 123))
+    assert len(fv) == 1
+    assert fv[0] == t
+
+
 def test_free_type_vars():
     tp = relay.TypeParam("")
     ty = relay.TupleType([tp, relay.TensorType([], "int32")])
diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py
index 662993292..77b04590d 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -9,6 +9,7 @@ from tvm.relay.ir_builder import scalar_type, convert, tensor_type
 from tvm.relay.env import Environment
 from tvm.relay.op import log, add, equal, subtract, concatenate
 from tvm.relay.expr import Function
+from tvm import relay
 
 def assert_has_type(expr, typ, env=Environment({})):
     checked_expr = infer_type(env, expr)
@@ -110,6 +111,16 @@ def test_concat():
     fn_ty = func_type([tensor_type(3, 2), tensor_type(2, 2)], tensor_type(5, 2))
     assert_decl_has_type(ib.env, try_concat2, fn_ty)
 
+def test_tuple():
+    ib = IRBuilder()
+    dup = ib.global_var('dup')
+    x = ib.param('x')
+    with ib.decl(dup, x):
+        ib.ret(relay.Tuple([x, x]))
+    # todo: why is this not generalized?
+    fn_ty = func_type([tensor_type()], relay.TupleType([tensor_type(), tensor_type()]))
+    assert_decl_has_type(ib.env, dup, fn_ty)
+
 if __name__ == "__main__":
     test_dual_op()
     test_recursion()
@@ -117,3 +128,4 @@ if __name__ == "__main__":
     test_decl()
     test_recursion()
     test_concat()
+    test_tuple()
-- 
GitLab