From 07399e023916ce1767a0730fc2f68019baca0f36 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Tue, 30 Oct 2018 09:53:30 -0700
Subject: [PATCH] [RELAY][OP]  Maketuple to be resolved when containing
 incompleteType (#2031)

---
 src/relay/op/tensor/transform.cc     |  2 +-
 src/relay/pass/type_infer.cc         | 38 +++++++++++++++++++++++-----
 tests/python/relay/test_op_level1.py |  1 +
 3 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 635f04668..20e0e3adb 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -140,7 +140,7 @@ bool ConcatenateRel(const Array<Type>& types,
   CHECK_EQ(types.size(), 2);
   const auto* tensor_tuple = types[0].as<TupleTypeNode>();
   if (tensor_tuple == nullptr) {
-    CHECK(types[0].as<TupleTypeNode>())
+    CHECK(types[0].as<IncompleteTypeNode>())
         << "cast: expect input type to be TupleType but get "
         << types[0];
     return false;
diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index c0f1db97b..e3e8ad7ec 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -56,11 +56,31 @@ bool TupleGetItemRel(const Array<Type>& types,
   return true;
 }
 
+bool MakeTupleRel(const Array<Type>& types,
+                  int num_inputs,
+                  const Attrs& attrs,
+                  const TypeReporter& reporter) {
+  CHECK_EQ(static_cast<size_t>(num_inputs + 1), types.size());
+  for (int i = 0; i < num_inputs; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) return false;
+  }
+  Array<Type> fields;
+  for (int i = 0; i < num_inputs; ++i) {
+    fields.push_back(types[i]);
+  }
+  reporter->Assign(types[num_inputs], TupleTypeNode::make(fields));
+  return true;
+}
+
 TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs);
 TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem")
 .set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
     TupleGetItemRel);
 
+TVM_REGISTER_API("tvm.relay.type_relation.MakeTuple")
+.set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
+    MakeTupleRel);
+
 struct ResolvedTypeInfo {
   explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args)
       : checked_type(checked_type), type_args(type_args) {}
@@ -104,6 +124,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
   TypeSolver solver_;
   // relation function
   TypeRelationFn tuple_getitem_rel_;
+  TypeRelationFn make_tuple_rel_;
   // Unify two types
   Type Unify(const Type& t1, const Type& t2, const Span& span) {
     // TODO(tqchen, jroesch): propagate span to solver
@@ -154,14 +175,19 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
   }
 
   Type VisitExpr_(const TupleNode* op) final {
-    // TODO(tqchen, jroesch)
-    // tuple should be a constraint in the type solver
-    // to handle cases where the field type is not known.
-    Array<Type> fields;
+    if (!make_tuple_rel_.defined())  {
+      make_tuple_rel_ = TypeRelationFn(
+          EnvFunc::Get("tvm.relay.type_relation.MakeTuple").node_);
+    }
+    Array<Type> types;
     for (Expr field : op->fields) {
-      fields.push_back(GetType(field));
+      types.push_back(GetType(field));
     }
-    return TupleTypeNode::make(fields);
+    Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
+    types.push_back(rtype);
+    solver_.AddConstraint(TypeRelationNode::make(
+        make_tuple_rel_, types, op->fields.size(), Attrs()));
+    return rtype;
   }
 
   Type VisitExpr_(const TupleGetItemNode* op) final {
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index fd01dbdde..a622dfc2c 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -87,6 +87,7 @@ def test_concatenate_infer_type():
     zz = relay.ir_pass.infer_type(z)
     assert zz.checked_type == relay.TensorType((n, t, 200))
 
+    x = relay.exp(x)
     z = relay.concatenate((x, y), axis=2)
     zz = relay.ir_pass.infer_type(z)
     assert zz.checked_type == relay.TensorType((n, t, 200))
-- 
GitLab