From 07399e023916ce1767a0730fc2f68019baca0f36 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Tue, 30 Oct 2018 09:53:30 -0700
Subject: [PATCH] [RELAY][OP] Maketuple to be resolved when containing
 incompleteType (#2031)

---
 src/relay/op/tensor/transform.cc     |  2 +-
 src/relay/pass/type_infer.cc         | 38 +++++++++++++++++++++++-----
 tests/python/relay/test_op_level1.py |  1 +
 3 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 635f04668..20e0e3adb 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -140,7 +140,7 @@ bool ConcatenateRel(const Array<Type>& types,
   CHECK_EQ(types.size(), 2);
   const auto* tensor_tuple = types[0].as<TupleTypeNode>();
   if (tensor_tuple == nullptr) {
-    CHECK(types[0].as<TupleTypeNode>())
+    CHECK(types[0].as<IncompleteTypeNode>())
         << "cast: expect input type to be TupleType but get "
         << types[0];
     return false;
diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index c0f1db97b..e3e8ad7ec 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -56,11 +56,31 @@ bool TupleGetItemRel(const Array<Type>& types,
   return true;
 }
 
+bool MakeTupleRel(const Array<Type>& types,
+                  int num_inputs,
+                  const Attrs& attrs,
+                  const TypeReporter& reporter) {
+  CHECK_EQ(static_cast<size_t>(num_inputs + 1), types.size());
+  for (int i = 0; i < num_inputs; ++i) {
+    if (types[i].as<IncompleteTypeNode>()) return false;
+  }
+  Array<Type> fields;
+  for (int i = 0; i < num_inputs; ++i) {
+    fields.push_back(types[i]);
+  }
+  reporter->Assign(types[num_inputs], TupleTypeNode::make(fields));
+  return true;
+}
+
 TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs);
 TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem")
 .set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
     TupleGetItemRel);
 
+TVM_REGISTER_API("tvm.relay.type_relation.MakeTuple")
+.set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
+    MakeTupleRel);
+
 struct ResolvedTypeInfo {
   explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args)
       : checked_type(checked_type), type_args(type_args) {}
@@ -104,6 +124,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
   TypeSolver solver_;
   // relation function
   TypeRelationFn tuple_getitem_rel_;
+  TypeRelationFn make_tuple_rel_;
  // Unify two types
  Type Unify(const Type& t1, const Type& t2, const Span& span) {
    // TODO(tqchen, jroesch): propagate span to solver
@@ -154,14 +175,19 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
   }
 
   Type VisitExpr_(const TupleNode* op) final {
-    // TODO(tqchen, jroesch)
-    // tuple should be a constraint in the type solver
-    // to handle cases where the field type is not known.
-    Array<Type> fields;
+    if (!make_tuple_rel_.defined()) {
+      make_tuple_rel_ = TypeRelationFn(
+          EnvFunc::Get("tvm.relay.type_relation.MakeTuple").node_);
+    }
+    Array<Type> types;
     for (Expr field : op->fields) {
-      fields.push_back(GetType(field));
+      types.push_back(GetType(field));
     }
-    return TupleTypeNode::make(fields);
+    Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
+    types.push_back(rtype);
+    solver_.AddConstraint(TypeRelationNode::make(
+        make_tuple_rel_, types, op->fields.size(), Attrs()));
+    return rtype;
   }
 
   Type VisitExpr_(const TupleGetItemNode* op) final {
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index fd01dbdde..a622dfc2c 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -87,6 +87,7 @@ def test_concatenate_infer_type():
     zz = relay.ir_pass.infer_type(z)
     assert zz.checked_type == relay.TensorType((n, t, 200))
 
+    x = relay.exp(x)
     z = relay.concatenate((x, y), axis=2)
     zz = relay.ir_pass.infer_type(z)
     assert zz.checked_type == relay.TensorType((n, t, 200))
-- 
GitLab
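
For reference, a minimal repro sketch of the behaviour this patch enables. It is hedged: it assumes a TVM build contemporary with this commit (the pre-0.5 Relay API with tvm.var, relay.var, relay.exp, and relay.ir_pass.infer_type) and assumes (n, t, 100) input shapes, which the patch does not show but which are implied by the test's (n, t, 200) concatenation result on axis=2. Wrapping x in relay.exp means the tuple passed to concatenate contains a field whose type is still unresolved when the Tuple node is visited; with this patch the tuple is handled by the tvm.relay.type_relation.MakeTuple constraint and resolved later, instead of ConcatenateRel seeing a not-yet-complete tuple type during inference.

    # Sketch only: assumes the Relay API of this era (tvm.var, relay.var,
    # relay.ir_pass.infer_type) and hypothetical (n, t, 100) input shapes.
    import tvm
    from tvm import relay

    n, t = tvm.var("n"), tvm.var("t")
    x = relay.var("x", shape=(n, t, 100))
    y = relay.var("y", shape=(n, t, 100))

    # relay.exp(x) has no resolved type yet when the tuple is built, so the
    # tuple's type starts out incomplete and is filled in by the MakeTuple
    # relation during constraint solving; inference still succeeds.
    z = relay.concatenate((relay.exp(x), y), axis=2)
    zz = relay.ir_pass.infer_type(z)
    assert zz.checked_type == relay.TensorType((n, t, 200))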