diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 18a53be92815154db0139d179275ed7f7226ba6f..7007028af6c7e6871985f39ff0f0258451ffb54c 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -15,6 +15,7 @@ Span = base.Span # Type Type = ty.Type +TupleType = ty.TupleType TensorType = ty.TensorType Kind = ty.Kind TypeParam = ty.TypeParam diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 84189c840d71a5dccdc08b92a22eb837b2fb5405..8a96124203271dcf9fc08786ce0effe812e240dd 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -12,3 +12,5 @@ from . import _ir_pass check_expr = _ir_pass.check_expr well_formed = _ir_pass.well_formed + +check_kind = _ir_pass.check_kind diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index 10e267a539779a4acb43d64eaa7039c57e799ac6..d2a256e77f5bc15b87cc558ff5ce123f1fa1da64 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -21,7 +21,6 @@ class Type(NodeBase): """Compares two Relay types by referential equality.""" return super().__eq__(other) - @register_relay_node class TensorType(Type): """A concrete TensorType in Relay, see tvm/relay/type.h for more details. @@ -94,6 +93,27 @@ class TypeConstraint(Type): pass +@register_relay_node +class TupleType(Type): + """A tuple type in Relay, see tvm/relay/type.h for more details. + + Lists the type of each field in the tuple. + """ + + def __init__(self, fields): + """Constructs a tuple type + + Parameters + ---------- + fields: list of tvm.Type + + Returns + ------- + tuple_type: the tuple type + """ + self.__init_handle_by_constructor__(_make.TupleType, fields) + + @register_relay_node class FuncType(Type): """A function type in Relay, see tvm/relay/type.h for more details. diff --git a/python/tvm/relay/ty.pyi b/python/tvm/relay/ty.pyi index 0581847598d470aa7d29ceac15185f104c4df8d1..1aba99e42a27c2563e1d9b8ccef1cf0cb4c84b9f 100644 --- a/python/tvm/relay/ty.pyi +++ b/python/tvm/relay/ty.pyi @@ -94,6 +94,27 @@ class TypeConstraint(Type): pass +@register_relay_node +class TupleType(Type): + """A tuple type in Relay, see tvm/relay/type.h for more details. + + Lists the type of each field in the tuple. + """ + + def __init__(self, fields): + """Constructs a tuple type + + Parameters + ---------- + fields: list of tvm.Type + + Returns + ------- + tuple_type: the tuple type + """ + self.__init_handle_by_constructor__(_make.TupleType, fields) + + @register_relay_node class FuncType(Type): """A function type in Relay, see tvm/relay/type.h for more details. diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 91d2d582211084541e1cf8538f3b7360c96c9ef6..83f52d8873e37dd6e18093ee4a636942adcb72cb 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -20,12 +20,72 @@ namespace tvm { namespace relay { using namespace tvm::runtime; +using Kind = TypeParamNode::Kind; struct KindChecker : TypeVisitor<> { bool valid; KindChecker() : valid(true) {} + // checks if t is an incomplete node of kind k or a type param of kind k + bool MatchKind(const Type& t, Kind k) { + if (const IncompleteTypeNode *tv = t.as<IncompleteTypeNode>()) { + return tv->kind == k; + } + + if (const TypeParamNode *tp = t.as<TypeParamNode>()) { + return tp->kind == k; + } + + return false; + } + + bool IsTypeKind(const Type& t) { + if (MatchKind(t, Kind::kType)) { + return true; + } + + return t.as<TensorTypeNode>() || t.as<BaseTensorTypeNode>() + || t.as<TupleTypeNode>() || t.as<FuncTypeNode>(); + } + + void VisitType_(const TupleTypeNode* op) override { + // tuples should only contain normal types + for (const Type& t : op->fields) { + this->VisitType(t); + valid = valid && IsTypeKind(t); + if (!valid) { + return; + } + } + } + + void VisitType_(const FuncTypeNode* op) override { + // func types should only take normal types for arguments + // and only return a normal type + for (const Type& t : op->arg_types) { + this->VisitType(t); + valid = valid && IsTypeKind(t); + if (!valid) { + return; + } + } + + this->VisitType(op->ret_type); + valid = valid && IsTypeKind(op->ret_type); + } + + void VisitType_(const TypeRelationNode* op) override { + // arguments to type relation should be normal types + for (const Type& t : op->args) { + this->VisitType(t); + valid = valid && IsTypeKind(t); + if (!valid) { + return; + } + } + } + bool Check(const Type &t) { this->VisitType(t); return valid; @@ -37,5 +97,14 @@ bool KindCheck(const Environment& env, const Type &t) { return kc.Check(t); } +TVM_REGISTER_API("relay._ir_pass.check_kind") + .set_body([](TVMArgs args, TVMRetValue* ret) { + if (args.size() == 1) { + *ret = KindCheck(EnvironmentNode::make({}), args[0]); + } else { + *ret = KindCheck(args[0], args[1]); + } + }); + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_check_kind.py b/tests/python/relay/test_check_kind.py new file mode 100644 index 0000000000000000000000000000000000000000..413e6d7051d6dd90617dedae46a0eacebe8d2e77 --- /dev/null +++ b/tests/python/relay/test_check_kind.py @@ -0,0 +1,79 @@ +import tvm +from tvm import relay +from tvm.relay.ir_pass import check_kind + +def test_tuple_kinds(): + # only contain type kinds + tp = relay.TypeParam('tp', relay.Kind.Type) + tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') + tf = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([])) + fields = tvm.convert([tp, tf, tt]) + + tup_ty = relay.TupleType(fields) + assert check_kind(tup_ty) + +def test_func_kind(): + # only contain type kinds + tp1 = relay.TypeParam('tp1', relay.Kind.Type) + tp2 = relay.TypeParam('tp2', relay.Kind.Type) + + shape = tvm.convert([1, 2, 3]) + dtype = 'float32' + tensor_type = relay.TensorType(shape, dtype) + + type_params = tvm.convert([tp1, tp2]) + type_constraints = tvm.convert([]) + arg_types = tvm.convert([tp1, tensor_type]) + ret_type = relay.TupleType(tvm.convert([tp2, tensor_type])) + + tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints) + assert check_kind(tf) + +def test_invalid_tuple_kinds(): + tp1 = relay.TypeParam('tp1', relay.Kind.Shape) + tp2 = relay.TypeParam('tp2', relay.Kind.BaseType) + tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar) + fields = tvm.convert([tp1, tp2, tp3]) + + tup_ty = relay.TupleType(fields) + assert not check_kind(tup_ty) + +def test_invalid_func_kind(): + tp1 = relay.TypeParam('tp1', relay.Kind.Shape) + tp2 = relay.TypeParam('tp2', relay.Kind.BaseType) + tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar) + + type_params = tvm.convert([tp1, tp2, tp3]) + type_constraints = tvm.convert([]) + arg_types = tvm.convert([tp1, tp2]) + ret_type = tp3 + + tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints) + assert not check_kind(tf) + +def test_func_with_invalid_ret_type(): + tp1 = relay.TypeParam('tp1', relay.Kind.Type) + tp2 = relay.TypeParam('tp2', relay.Kind.Shape) + tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([])) + +def test_func_with_invalid_arg_types(): + tp1 = relay.TypeParam('tp1', relay.Kind.Shape) + tp2 = relay.TypeParam('tp2', relay.Kind.Type) + tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([])) + +def test_func_with_invalid_tuple(): + tp1 = relay.TypeParam('tp1', relay.Kind.Shape) + + ret_type = relay.TupleType(tvm.convert([tp1, tp1, tp1])) + + tf = relay.FuncType(tvm.convert([]), ret_type, tvm.convert([tp1]), tvm.convert([])) + assert not check_kind(tf) + +def test_tuple_with_invalid_func(): + tensor_type = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') + + tp1 = relay.TypeParam('tp1', relay.Kind.Shape) + tf = relay.FuncType(tvm.convert([]), tp1, tvm.convert([tp1]), tvm.convert([])) + + tup_ty = relay.TupleType(tvm.convert([tensor_type, tf])) + assert not check_kind(tup_ty) diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index 803b3d0faa0c04a750dfb7e41eedd601d6cd9dd4..fc5d8ee0777d464f4784acd2986e9812c3ecc5d6 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -28,8 +28,8 @@ def test_tensor_type(): def test_type_param(): tp = relay.TypeParam('name', relay.Kind.Shape) - tp.kind == relay.Kind.Shape - tp.span # TODO allow us to set span + assert tp.kind == relay.Kind.Shape + # assert tp.span # TODO allow us to set span str(tp) @@ -48,6 +48,16 @@ def test_func_type(): str(tf) +def test_tuple_type(): + tp = relay.TypeParam('tp', relay.Kind.Type) + tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([])) + tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') + fields = tvm.convert([tp, tf, tt]) + + tup_ty = relay.TupleType(fields) + assert tup_ty.fields == fields + + def test_constant(): arr = tvm.nd.array(10) const = relay.Constant(arr)