diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index c6e5573d941357b82cf080d5ab6893a34936c72e..5e50cfc05e67c82b65094a5bc756a65081b307d9 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -118,17 +118,27 @@ class Var; /*! \brief Container for Var */ class VarNode : public ExprNode { public: - /*! \brief The name of the variable, this only acts as a hint to the user, - * and is not used for equality. + /*! + * \brief The name of the variable, + * this only acts as a hint to the user, + * and is not used for equality. */ std::string name_hint; + /*! + * \brief type annotaion of the variable. + * This field records user provided type annotation of the Var. + * This field is optional and can be None. + */ + Type type_annotation; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("name_hint", &name_hint); + v->Visit("type_annotation", &type_annotation); v->Visit("_checked_type_", &checked_type_); } - TVM_DLL static Var make(std::string name_hint); + TVM_DLL static Var make(std::string name_hint, + Type type_annotation); static constexpr const char* _type_key = "relay.Var"; TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode); @@ -162,32 +172,6 @@ class GlobalVarNode : public ExprNode { RELAY_DEFINE_NODE_REF(GlobalVar, GlobalVarNode, Expr); -/*! - * \brief Function parameter declaration. - */ -class Param; -/*! \brief A parameter. */ -class ParamNode : public ExprNode { - public: - /*! \brief The variable */ - Var var; - /*! \brief The type of the parameter */ - Type type; - - void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("var", &var); - v->Visit("type", &type); - v->Visit("span", &span); - } - - TVM_DLL static Param make(Var var, Type type); - - static constexpr const char* _type_key = "relay.Param"; - TVM_DECLARE_NODE_TYPE_INFO(ParamNode, ExprNode); -}; - -RELAY_DEFINE_NODE_REF(Param, ParamNode, Expr); - /*! * \brief Function (subgraph in computational graph) */ @@ -196,7 +180,7 @@ class Function; class FunctionNode : public ExprNode { public: /*! \brief Function parameters */ - tvm::Array<Param> params; + tvm::Array<Var> params; /*! \brief User annotated return type of the function. */ Type ret_type; /*! @@ -224,10 +208,18 @@ class FunctionNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } - Type fn_type() const; + /*! + * \brief Return the derived function annotation of this expression. + * + * \return The function type annotation. + * \note The function type annotation can contain IncompleteType. + */ + TVM_DLL FuncType func_type_annotation() const; - TVM_DLL static Function make(tvm::Array<Param> params, Type ret_type, - Expr body, tvm::Array<TypeParam> ty_params); + TVM_DLL static Function make(tvm::Array<Var> params, + Type ret_type, + Expr body, + tvm::Array<TypeParam> ty_params); static constexpr const char* _type_key = "relay.Function"; TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode); @@ -289,7 +281,7 @@ class CallNode : public ExprNode { TVM_DLL static Call make(Expr op, Array<Expr> args, Attrs attrs = Attrs(), - Array<Type> ty_args = Array<Type>()); + Array<Type> type_args = Array<Type>()); static constexpr const char* _type_key = "relay.Call"; TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode); @@ -318,19 +310,16 @@ class LetNode : public ExprNode { Expr value; /*! \brief The body of the let binding */ Expr body; - /*! \brief Type annotation of value, this can be null */ - Type value_type; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("var", &var); v->Visit("value", &value); v->Visit("body", &body); - v->Visit("value_type", &value_type); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } - TVM_DLL static Let make(Var var, Expr value, Expr body, Type value_type); + TVM_DLL static Let make(Var var, Expr value, Expr body); static constexpr const char* _type_key = "relay.Let"; TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode); @@ -376,11 +365,11 @@ class IfNode : public ExprNode { RELAY_DEFINE_NODE_REF(If, IfNode, Expr); -/*! \brief Get a field out of a tuple. */ +/*! \brief Get index-th field out of a tuple. */ class TupleGetItem; class TupleGetItemNode : public ExprNode { public: - /*! \brief The tuple */ + /*! \brief The tuple Expression */ Expr tuple; /*! \brief which value to get */ int index; diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index be174d33b4c8dc62754d4a9e0033893e14c9f5d4..c10933590f99a81460b07cd2cfcb4a9616d88637 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -80,7 +80,6 @@ class ExprFunctor<R(const Expr& n, Args...)> { Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const ParamNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FunctionNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; @@ -103,7 +102,6 @@ class ExprFunctor<R(const Expr& n, Args...)> { RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode); RELAY_EXPR_FUNCTOR_DISPATCH(VarNode); RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode); - RELAY_EXPR_FUNCTOR_DISPATCH(ParamNode); RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode); RELAY_EXPR_FUNCTOR_DISPATCH(CallNode); RELAY_EXPR_FUNCTOR_DISPATCH(LetNode); @@ -127,7 +125,6 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> { void VisitExpr_(const GlobalVarNode* op) override; void VisitExpr_(const ConstantNode* op) override; void VisitExpr_(const TupleNode* op) override; - void VisitExpr_(const ParamNode* op) override; void VisitExpr_(const FunctionNode* op) override; void VisitExpr_(const CallNode* op) override; void VisitExpr_(const LetNode* op) override; @@ -151,7 +148,6 @@ class ExprMutator Expr VisitExpr_(const GlobalVarNode* op) override; Expr VisitExpr_(const OpNode* op) override; Expr VisitExpr_(const TupleNode* op) override; - Expr VisitExpr_(const ParamNode* op) override; Expr VisitExpr_(const FunctionNode* op) override; Expr VisitExpr_(const CallNode* call_node) override; Expr VisitExpr_(const LetNode* op) override; diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 18c02a416d6b5d5bc5a236d1b06c9e55602a21e3..b1085be2e1e298cf67952e9883ab8b5b6d707025 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -34,7 +34,6 @@ Constant = expr.Constant Tuple = expr.Tuple Var = expr.Var GlobalVar = expr.GlobalVar -Param = expr.Param Function = expr.Function Call = expr.Call Let = expr.Let diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 6ed8df0d736b8a68bac4d668f2834ea429cb6d73..a71fd329ed5bc436d66585148b1523d15c64fa55 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -11,11 +11,11 @@ class Expr(NodeBase): """The base type for all Relay expressions.""" @property def checked_type(self): - """Get the checked type of relay. + """Get the checked type of tvm.relay.Expr. Returns ------- - checked_type : relay.Type + checked_type : tvm.relay.Type The checked type. """ ret = self._checked_type_ @@ -25,70 +25,97 @@ class Expr(NodeBase): return ret def __call__(self, *args): - converted_args = [] - for arg in args: - if isinstance(arg, Param): - converted_args.append(arg.var) - else: - converted_args.append(arg) - return Call(self, args, None, None) @register_relay_node class Constant(Expr): - """A constant tensor in Relay, see tvm/relay/type.h for more details. - """ + """A constant expression in Relay. + Parameters + ---------- + data : tvm.nd.NDArray + The data content of the constant expression. + """ def __init__(self, data): self.__init_handle_by_constructor__(_make.Constant, data) @register_relay_node class Tuple(Expr): - """A hetereogenous sequence of values. - see tvm/relay/type.h for more details. - """ + """Tuple expression that groups several fields together. + Parameters + ---------- + fields : List[tvm.relay.Expr] + The fields in the tuple. + """ def __init__(self, fields): self.__init_handle_by_constructor__(_make.Tuple, fields) @register_relay_node class Var(Expr): - """A local variable in Relay.""" + """A local variable in Tvm.Relay. - def __init__(self, name_hint): - self.__init_handle_by_constructor__(_make.Var, name_hint) + Local variable can be used to declare input + arguments to a function, or intermediate variables. + + Parameters + ---------- + name_hint: str + The name of the variable. + This name only acts as a hint, and is not used + for equality. + + type_annotation: tvm.relay.Type, optional + The type annotation on the variable. + """ + def __init__(self, name_hint, type_annotation=None): + self.__init_handle_by_constructor__( + _make.Var, name_hint, type_annotation) @register_relay_node class GlobalVar(Expr): - """A global variable in Relay.""" + """A global variable in Tvm.Relay. + GlobalVar is used to refer to the global functions + stored in the environment. + + Parameters + ---------- + name_hint: str + The name of the variable. + """ def __init__(self, name_hint): self.__init_handle_by_constructor__(_make.GlobalVar, name_hint) @register_relay_node -class Param(Expr): - """A function type in Relay, see tvm/relay/type.h for more details. - """ +class Function(Expr): + """A function declaration expression. - def __init__(self, var, ty): - self.__init_handle_by_constructor__(_make.Param, var, ty) + Parameters + ---------- + params: List[tvm.relay.Var] + List of input parameters to the function. + ret_type: tvm.relay.Type + The return type annotation of the function. -@register_relay_node -class Function(Expr): - """A function in Relay, see tvm/relay/expr.h for more details.""" + body: tvm.relay.Expr + The body of the function. + type_params: Optional[List[tvm.relay.TypeParam]] + The additional type parameters, this is only + used in advanced usecase of template functions. + """ def __init__(self, params, ret_type, body, - type_params=None - ): + type_params=None): if type_params is None: type_params = convert([]) @@ -98,39 +125,87 @@ class Function(Expr): @register_relay_node class Call(Expr): - """A function call in Relay, see tvm/relay/expr.h for more details.""" + """Function call node in Relay. + + Call node corresponds the operator application node + in computational graph terminology. + + Parameters + ---------- + op: tvm.relay.Op or any tvm.relay.Expr with function type. + The operation to be called. - def __init__(self, op, args, attrs, ty_args=None): - if not ty_args: - ty_args = [] + args: List[tvm.relay.Expr] + The arguments to the call. + attrs: Optional[tvm.Attrs] + Attributes to the call, can be None + + type_args: Optional[List[tvm.relay.Type]] + The additional type arguments, this is only + used in advanced usecase of template functions. + """ + def __init__(self, op, args, attrs=None, type_args=None): + if not type_args: + type_args = [] self.__init_handle_by_constructor__( - _make.Call, op, args, attrs, ty_args) + _make.Call, op, args, attrs, type_args) @register_relay_node class Let(Expr): - """A variable bindings in Relay, see tvm/relay/expr.h for more details.""" + """Let variable binding expression. + + Parameters + ---------- + var: tvm.relay.Var + The local variable to be bound. + + value: tvm.relay.Expr + The value to be bound. - def __init__(self, var, value, body, value_type=None): + body: tvm.relay.Expr + The body of the let binding. + """ + def __init__(self, var, value, body): self.__init_handle_by_constructor__( - _make.Let, var, value, body, value_type) + _make.Let, var, value, body) @register_relay_node class If(Expr): - """A conditional expression in Relay, see tvm/relay/expr.h for more details.""" + """A conditional expression in Relay. + + Parameters + ---------- + cond: tvm.relay.Expr + The condition. - def __init__(self, cond, true_value, false_value): + true_branch: tvm.relay.Expr + The expression evaluated when condition is true. + + false_branch: tvm.relay.Expr + The expression evaluated when condition is false. + """ + def __init__(self, cond, true_branch, false_branch): self.__init_handle_by_constructor__( - _make.If, cond, true_value, false_value) + _make.If, cond, true_branch, false_branch) + @register_relay_node class TupleGetItem(Expr): - """An expression that get field from tuple in Relay, see tvm/relay/expr.h for more details.""" + """Get index-th item from a tuple. + + Parameters + ---------- + tuple_value: tvm.relay.Expr + The input tuple expression. - def __init__(self, tuple_, index): + index: int + The index. + """ + def __init__(self, tuple_value, index): self.__init_handle_by_constructor__( - _make.TupleGetItem, tuple_, index) + _make.TupleGetItem, tuple_value, index) debug_print = _expr._debug_print diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index accb782659dffecc58b40cabc57e966678157f04..a429aea7d5ea6b1ab71865e23835fa99abf44c8a 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -7,7 +7,7 @@ from collections import OrderedDict import numpy as np import tvm from .ty import Type, FuncType, TensorType -from .expr import Expr, Constant, Let, Var, Param, Function, If +from .expr import Expr, Constant, Let, Var, Function, If from .env import Environment @@ -98,7 +98,7 @@ class PartialFunc(object): self.type_params = type_params def param_ids(self): - return [p.var for p in self.params] + return [p for p in self.params] def to_func(self): """Converts a PartialFunc into a :py:class:`~relay.Function`.""" @@ -113,9 +113,8 @@ class PartialFunc(object): def _mk_let(bindings, ret_value): let_expr = ret_value - for var, (value, ty) in reversed(list(bindings.items())): - let_expr = Let(var, value, let_expr, ty) - + for var, value in reversed(list(bindings.items())): + let_expr = Let(var, value, let_expr) return let_expr @@ -168,15 +167,12 @@ class IRBuilder(object): #pylint: disable=invalid-name def bind(self, name, value, ty): - lv = Var(name) + lv = Var(name, ty) self.scopes[-1][name] = lv - self.bindings[-1][lv] = (value, ty) + self.bindings[-1][lv] = value return lv def let(self, name, value, value_type=None): - if isinstance(value, Param): - value = value.var - if not isinstance(value, Expr): value = convert(value) @@ -185,23 +181,18 @@ class IRBuilder(object): def _convert_params(self, raw_params): relay_params = [] for raw_param in raw_params: - if isinstance(raw_param, Param): - var = raw_param.var + if isinstance(raw_param, Var): param = raw_param elif isinstance(raw_param, tuple): var, ty = raw_param - if isinstance(var, str): - var = Var(var) ty = _convert_type(ty) - param = Param(var, ty) - elif isinstance(param, str): - var = Var(raw_param) - ty = None - param = Param(var, ty) + param = Var(var, ty) + elif isinstance(raw_param, str): + param = Var(raw_param, None) else: raise Exception("unknown parameter type") - self.scopes[-1][var.name_hint] = var + self.scopes[-1][param.name_hint] = param relay_params.append(param) return relay_params @@ -265,7 +256,7 @@ class IRBuilder(object): else: ty = _convert_type(ty) - return Param(Var(name), ty) + return Var(name, ty) def global_var(self, name): # type: (str) -> GlobalVar diff --git a/src/relay/ir/debug_printer.cc b/src/relay/ir/debug_printer.cc index 90e82d3b2dd7182e6d7a943a10781cf55382bbc6..cb463ef6975a738f4217440243c42293e6da3035 100644 --- a/src/relay/ir/debug_printer.cc +++ b/src/relay/ir/debug_printer.cc @@ -96,7 +96,9 @@ class TypeDocifier : private TypeFunctor<Doc(const Type& n)> { } std::vector<Doc> DocifyTypeParam(const tvm::Array<TypeParam>& arr) { - return MapDocify<TypeParam>(arr, [=](const TypeParam& tp) { return Docify(tp); }); + return MapDocify<TypeParam>(arr, [=](const TypeParam& tp) { + return Docify(tp); + }); } std::vector<Doc> DocifyTypeConstraint(const tvm::Array<TypeConstraint>& arr) { @@ -188,10 +190,11 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> { return vec; } - std::vector<Doc> DocifyParamArray(const tvm::Array<Param>& arr) { + std::vector<Doc> DocifyParamArray(const tvm::Array<Var>& arr) { std::vector<Doc> vec; - for (size_t i = 0; i < arr.size(); ++i) { - vec.push_back(Docify(arr[i])); + for (Var param : arr) { + vec.emplace_back(TypeAnnotation(DocOfStr(VarName(param)), + param->type_annotation)); } return vec; } @@ -212,10 +215,6 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> { return DocOfStr(g->name_hint); } - Doc VisitExpr_(const ParamNode* p) final { - return TypeAnnotation(Docify(p->var), p->type); - } - Doc VisitExpr_(const FunctionNode* f) final { return Group(TypeAnnotation(Seq("(", DocifyParamArray(f->params), ")"), f->ret_type) + Sep() + DocOfStr("=>") + Sep() + @@ -227,7 +226,8 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> { } Doc VisitExpr_(const LetNode* l) final { - return Group(DocOfStr("let") + Sep() + TypeAnnotation(Docify(l->var), l->value_type) + Sep() + + return Group(DocOfStr("let") + Sep() + + TypeAnnotation(Docify(l->var), l->var->type_annotation) + Sep() + DocOfStr("=") + Sep() + Docify(l->value) + DocOfStr(";") + Endl() + Docify(l->body)); } diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 6b56cb4e844fdbe0b215223e99bcd3ef742c4e83..c248ad0de6f718bd16b974fc2d94b1e252fae403 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -54,20 +54,26 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "Tuple(" << node->fields << ")"; }); -Var VarNode::make(std::string name_hint) { +Var VarNode::make(std::string name_hint, Type type_annotation) { NodePtr<VarNode> n = make_node<VarNode>(); n->name_hint = std::move(name_hint); + n->type_annotation = std::move(type_annotation); return Var(n); } TVM_REGISTER_API("relay._make.Var") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = VarNode::make(args[0]); + *ret = VarNode::make(args[0], args[1]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch<VarNode>([](const VarNode *node, tvm::IRPrinter *p) { - p->stream << "Var(" << node->name_hint << ")"; + p->stream << "Var(" << node->name_hint; + if (node->type_annotation.defined()) { + p->stream << ", ty="; + p->print(node->type_annotation); + } + p->stream << ")"; }); GlobalVar GlobalVarNode::make(std::string name_hint) { @@ -86,24 +92,10 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "GlobalVar(" << node->name_hint << ")"; }); -Param ParamNode::make(Var var, Type type) { - NodePtr<ParamNode> n = make_node<ParamNode>(); - n->var = std::move(var); - n->type = std::move(type); - return Param(n); -} - -TVM_REGISTER_API("relay._make.Param") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = ParamNode::make(args[0], args[1]); -}); -TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch<ParamNode>([](const ParamNode *node, tvm::IRPrinter *p) { - p->stream << "Param(" << node->var << ", " << node->type << ")"; -}); - -Function FunctionNode::make(tvm::Array<Param> params, Type ret_type, Expr body, +Function FunctionNode::make(tvm::Array<Var> params, + Type ret_type, + Expr body, tvm::Array<TypeParam> type_params) { NodePtr<FunctionNode> n = make_node<FunctionNode>(); n->params = std::move(params); @@ -113,12 +105,11 @@ Function FunctionNode::make(tvm::Array<Param> params, Type ret_type, Expr body, return Function(n); } -Type FunctionNode::fn_type() const { +FuncType FunctionNode::func_type_annotation() const { Array<Type> param_types; for (auto param : this->params) { - param_types.push_back(param->type); + param_types.push_back(param->type_annotation); } - return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {}); } @@ -155,24 +146,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << node->attrs << ", " << node->type_args << ")"; }); -Let LetNode::make(Var var, Expr value, Expr body, Type value_type) { +Let LetNode::make(Var var, Expr value, Expr body) { NodePtr<LetNode> n = make_node<LetNode>(); n->var = std::move(var); n->value = std::move(value); n->body = std::move(body); - n->value_type = std::move(value_type); return Let(n); } TVM_REGISTER_API("relay._make.Let") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = LetNode::make(args[0], args[1], args[2], args[3]); -}); + *ret = LetNode::make(args[0], args[1], args[2]); + }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch<LetNode>([](const LetNode *node, tvm::IRPrinter *p) { p->stream << "LetNode(" << node->var << ", " << node->value - << ", " << node->body << ", " << node->value_type << ")"; + << ", " << node->body << ")"; }); If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) { diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 792f99d699ddc652723cbc5ff34503e0873ccbf8..c55e4d672b6c0e02b67f231311bbfd7a9d5409f1 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -24,6 +24,16 @@ Expr ExprMutator::Mutate(const Expr& expr) { } Expr ExprMutator::VisitExpr_(const VarNode* op) { + // NOTE: var will only be mutated once + // Thanks to the memo and reused during rewriting if necessary. + // It is safe to assume that the + if (op->type_annotation.defined()) { + auto type = this->VisitType(op->type_annotation); + if (!op->type_annotation.same_as(type)) { + return VarNode::make(op->name_hint, type); + } + } + // default case return self. return GetRef<Expr>(op); } @@ -55,16 +65,6 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) { } } -Expr ExprMutator::VisitExpr_(const ParamNode* op) { - Var var = Downcast<Var>(this->Mutate(op->var)); - auto type = this->VisitType(op->type); - if (op->var.same_as(var) && op->type.same_as(type)) { - return GetRef<Expr>(op); - } else { - return ParamNode::make(var, type); - } -} - Expr ExprMutator::VisitExpr_(const FunctionNode* op) { tvm::Array<TypeParam> ty_params; bool all_ty_params_changed = true; @@ -75,10 +75,10 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { all_ty_params_changed &= new_ty_param.same_as(ty_param); } - tvm::Array<Param> params; + tvm::Array<Var> params; bool all_params_changed = true; for (auto param : op->params) { - Param new_param = Downcast<Param>(this->Mutate(param)); + Var new_param = Downcast<Var>(this->Mutate(param)); params.push_back(new_param); all_params_changed &= param.same_as(new_param); } @@ -123,17 +123,15 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) { Expr ExprMutator::VisitExpr_(const LetNode* op) { Var var = Downcast<Var>(this->Mutate(op->var)); - auto type = this->VisitType(op->value_type); auto value = this->Mutate(op->value); auto body = this->Mutate(op->body); if (var.same_as(op->var) && - type.same_as(op->value_type) && value.same_as(op->value) && body.same_as(op->body)) { return GetRef<Expr>(op); } else { - return LetNode::make(var, value, body, type); + return LetNode::make(var, value, body); } } @@ -162,6 +160,9 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) { Type ExprMutator::VisitType(const Type& t) { return t; } void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) { + if (op->type_annotation.defined()) { + this->VisitType(op->type_annotation); + } } void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) { @@ -176,10 +177,6 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) { } } -void ExprVisitor::ExprVisitor::VisitExpr_(const ParamNode* op) { - this->VisitExpr(op->var); -} - void ExprVisitor::ExprVisitor::VisitExpr_(const FunctionNode* op) { for (auto param : op->params) { this->VisitExpr(param); diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc index 0ed0e3df3056f3951e02675804b3eedcce1c61db..29d2f87cf04aeb740a3f819837ca874737e7a927 100644 --- a/src/relay/pass/alpha_eq.cc +++ b/src/relay/pass/alpha_eq.cc @@ -252,15 +252,6 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { } } - void VisitExpr_(const ParamNode* p1, const Expr& e2) final { - if (const ParamNode* p2 = e2.as<ParamNode>()) { - eq_map.Set(p1->var, p2->var); - equal = equal && AlphaEqual(p1->type, p2->type); - } else { - equal = false; - } - } - void VisitExpr_(const FunctionNode* func1, const Expr& e2) final { if (const FunctionNode* func2 = e2.as<FunctionNode>()) { if (func1->params.size() != func2->params.size()) { @@ -273,9 +264,10 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { return; } - for (size_t i = 0U; i < func1->params.size(); i++) { - this->VisitExpr(func1->params[i], func2->params[i]); + for (size_t i = 0; i < func1->params.size(); ++i) { + MergeVarDecl(func1->params[i], func2->params[i]); } + if (!equal) return; for (size_t i = 0U; i < func1->type_params.size(); i++) { equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]); @@ -332,19 +324,9 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { void VisitExpr_(const LetNode* op, const Expr& e2) final { if (const LetNode* let = e2.as<LetNode>()) { - eq_map.Set(op->var, let->var); + MergeVarDecl(op->var, let->var); this->VisitExpr(op->value, let->value); this->VisitExpr(op->body, let->body); - - // value_type should match as well (including nulls) - if (op->value_type.defined() != let->value_type.defined()) { - equal = false; - return; - } - - if (op->value_type.defined()) { - equal = equal && AlphaEqual(op->value_type, let->value_type); - } } else { equal = false; } @@ -388,6 +370,20 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { equal = false; } } + + private: + void MergeVarDecl(const Var& var1, const Var& var2) { + if (var1->type_annotation.defined() != var2->type_annotation.defined()) { + equal = false; + return; + } + if (var1->type_annotation.defined() && + !AlphaEqual(var1->type_annotation, var2->type_annotation)) { + equal = false; + return; + } + eq_map.Set(var1, var2); + } }; bool AlphaEqual(const Expr& e1, const Expr& e2) { diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index 05036042a6354a2caf9335a6fa77238da385555d..2e2eca1f2739ee44bec82b642b739f1743551c3d 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -54,12 +54,7 @@ class CalcDep : private ExprMutator { } private: - struct Binder { - Type t; - Expr e; - Binder(const Type& t, const Expr& e) : t(t), e(e) { } - }; - using VarMap = std::unordered_map<Var, Binder, NodeHash, NodeEqual>; + using VarMap = std::unordered_map<Var, Expr, NodeHash, NodeEqual>; VarMap var_map_; Expr VisitExpr_(const IfNode* i) final { @@ -74,9 +69,7 @@ class CalcDep : private ExprMutator { } Expr VisitExpr_(const LetNode* l) final { - var_map_.insert(std::pair<Var, Binder>(l->var, - Binder(l->value_type, - Eliminate(l->value)))); + var_map_[l->var] = Eliminate(l->value); return VisitExpr(l->body); } @@ -92,15 +85,16 @@ class CalcDep : private ExprMutator { explicit GenLet(const VarMap& var_map) : var_map_(var_map) { } friend CalcDep; - void VisitExpr_(const VarNode* vn) final { - Var v = GetRef<Var>(vn); - if (var_map_.count(v) != 0) { - auto val = var_map_.at(v); - var_map_.erase(v); + void VisitExpr_(const VarNode* vnode) final { + Var v = GetRef<Var>(vnode); + auto it = var_map_.find(v); + if (it != var_map_.end()) { + Expr expr = it->second; + var_map_.erase(it); // erase before visit to handle letrec - VisitExpr(val.e); + VisitExpr(expr); // visit before push back so the dependency of dependency is before the dependency - lets_.Push(v, val.t, val.e); + lets_.Push(v, expr); } } }; diff --git a/src/relay/pass/let_list.h b/src/relay/pass/let_list.h index d13358fe0e302f073627a0e93d329d36f17db036..43b8bb8bba1d5aaea44c73d4af1479c23570a2c2 100644 --- a/src/relay/pass/let_list.h +++ b/src/relay/pass/let_list.h @@ -26,57 +26,46 @@ namespace relay { */ class LetList { public: - /*! \brief insert a binding. + /*! + * \brief insert a binding. * - * \param pv the var of the binding. + * \param pv the var of the binding. * - * \param ty the type of the binding. + * \param expr the value of the binding. * - * \param expr the value of the binding. - * - * \return a Var that hold the inserted expr. + * \return a Var that hold the inserted expr. */ - Var Push(const Var& pv, const Type& ty, const Expr& expr) { - std::tuple<Var, Type, Expr> tuple(pv, ty, expr); - lets_.push_back(tuple); + Var Push(Var pv, Expr expr) { + lets_.emplace_back(std::make_pair(pv, expr)); return pv; } - /*! \brief insert a binding. + /*! + * \brief insert a binding. * - * \param ty the type of the binding. + * \param ty the type of the binding. * - * \param expr the value of the binding. + * \param expr the value of the binding. * - * \return a Var that hold the inserted expr. - */ - Var Push(const Type& ty, const Expr& expr) { - return Push(VarNode::make("x"), ty, expr); - } - - /*! \brief insert a binding. - * - * \param pv the var of the binding. - * - * \param expr the value of the binding. - * - * \return a Var that hold the inserted expr. + * \return a Var that hold the inserted expr. */ - Var Push(const Var& pv, const Expr& expr) { - return Push(pv, IncompleteTypeNode::make(TypeParamNode::kType), expr); + Var Push(Type ty, Expr expr) { + return Push(VarNode::make("x", ty), expr); } - /*! \brief insert a binding. + /*! + * \brief insert a binding. * * \param expr the value of the binding. * * \return a Var that hold the inserted expr. */ - Var Push(const Expr& expr) { + Var Push(Expr expr) { return Push(IncompleteTypeNode::make(TypeParamNode::kType), expr); } - /*! \brief wrap an expr around the LetList. + /*! + * \brief wrap an expr around the LetList. * * \param body the Expression to be wrapped around. * @@ -85,7 +74,7 @@ class LetList { Expr Get(const Expr& body) const { Expr ret = body; for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) { - ret = LetNode::make(std::get<0>(*rit), std::get<2>(*rit), ret, std::get<1>(*rit)); + ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret); } return ret; } @@ -118,7 +107,7 @@ class LetList { } private: - std::vector<std::tuple<Var, Type, Expr> > lets_; + std::vector<std::pair<Var, Expr> > lets_; }; } // namespace relay diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 72bdaf69f061e00a2c4add010343ea0c6bb995ec..1b30865eacb1ad6e1d428e3d9a586d059a5ced69 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -87,15 +87,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { // Visitor logics Type VisitExpr_(const VarNode* op) final { - // The type of Var can already been lookedup in type_map_; - LOG(FATAL) << "Cannot find binding for var " << GetRef<Var>(op); - return Type(); - } - - Type VisitExpr_(const ParamNode* op) final { - // directly handled by Funtion - LOG(FATAL) << "not reached"; - return Type(); + if (op->type_annotation.defined()) { + return op->type_annotation; + } else { + return IncompleteTypeNode::make(TypeParamNode::kType); + } } Type VisitExpr_(const GlobalVarNode* op) final { @@ -139,11 +135,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { Type VisitExpr_(const LetNode* op) final { Type vtype = GetType(op->value); - if (op->value_type.defined()) { - vtype = Unify(vtype, op->value_type, op->span); + if (op->var->type_annotation.defined()) { + vtype = Unify(vtype, op->var->type_annotation, op->span); } CHECK(!type_map_.count(op->var)); - // NOTE: no scoping is necessary becase var are unique in program + // NOTE: no scoping is necessary because var are unique in program type_map_[op->var] = vtype; return GetType(op->body); } @@ -256,8 +252,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { Type VisitExpr_(const FunctionNode* f) final { for (auto param : f->params) { - type_map_[param->var] = param->type; - type_map_[param] = param->type; + GetType(param); } Type rtype = GetType(f->body); // Run solver using the currently known information @@ -265,8 +260,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { // Trying to resolve Array<Type> arg_types; for (size_t i = 0; i < f->params.size(); ++i) { - Param param = f->params[i]; - Type atype = solver_.Resolve(param->type); + Type atype = solver_.Resolve(GetType(f->params[i])); CHECK(atype.as<IncompleteTypeNode>() == nullptr) << "Cannot resolve type of " << i << "-th parameter of function at" << f->span; @@ -311,9 +305,6 @@ class TypeInferencer::Resolver : public ExprMutator { return AttachCheckedType(op); } - Expr VisitExpr_(const ParamNode* op) final { - return ExprMutator::VisitExpr_(op); - } Expr VisitExpr_(const FunctionNode* op) final { return AttachCheckedType(op); @@ -380,7 +371,7 @@ Expr InferType(const Environment& env, const GlobalVar& var, const Function& func) { Function func_copy = Function(make_node<FunctionNode>(*func.operator->())); - func_copy->checked_type_ = func_copy->fn_type(); + func_copy->checked_type_ = func_copy->func_type_annotation(); env->functions.Set(var, func_copy); Expr func_ret = TypeInferencer(env).Infer(func_copy); auto map_node = env->functions.CopyOnWrite(); diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 5f87c3d4cb89dc0c9e6b0c13bb65f5a05f937337..c845995b20030f9ee7c9dbd52d1ea28fd4274a2a 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -50,14 +50,17 @@ class FreeVar : public ExprVisitor { if (bound_vars.count(var) == 0) { free_vars.insert(var); } + if (v->type_annotation.defined()) { + VisitType(v->type_annotation); + } } void VisitExpr_(const FunctionNode *f) final { for (const auto& tp : f->type_params) { bound_types.insert(tp); } - for (const auto& p : f->params) { - bound_vars.insert(p->var); + for (const auto& param : f->params) { + bound_vars.insert(param); } VisitExpr(f->body); VisitType(f->ret_type); @@ -67,7 +70,6 @@ class FreeVar : public ExprVisitor { bound_vars.insert(l->var); VisitExpr(l->value); VisitExpr(l->body); - VisitType(l->value_type); } public: diff --git a/src/relay/pass/well_formed.cc b/src/relay/pass/well_formed.cc index a9bce74926bfa59d495e7b9d6d68b6981c88bbf7..e008a72e5d9007fbca0126787d3638dc362b14f5 100644 --- a/src/relay/pass/well_formed.cc +++ b/src/relay/pass/well_formed.cc @@ -34,8 +34,8 @@ class WellFormedChecker : private ExprVisitor { } void VisitExpr_(const FunctionNode * f) final { - for (const Param & p : f->params) { - Check(p->var); + for (const Var & param : f->params) { + Check(param); } CheckWellFormed(f->body); } diff --git a/tests/python/relay/test_ir_builder.py b/tests/python/relay/test_ir_builder.py index c98f920ca49151515bd51c81af21d958193ef3e3..165c66f17ac354c543733b0af2dfaff676db9255 100644 --- a/tests/python/relay/test_ir_builder.py +++ b/tests/python/relay/test_ir_builder.py @@ -14,7 +14,6 @@ def test_let(): assert var == prog.body assert isinstance(value, Constant) assert value.data.asnumpy() == np.array(1) - assert prog.value_type == None if __name__ == "__main__": test_let() diff --git a/tests/python/relay/test_ir_debug_printer.py b/tests/python/relay/test_ir_debug_printer.py index e5f9ad2e69cd88201479e19f4e490a032024c466..b8aa86a8763876be68f798c8723f06034d21fa0d 100644 --- a/tests/python/relay/test_ir_debug_printer.py +++ b/tests/python/relay/test_ir_debug_printer.py @@ -49,18 +49,11 @@ def test_global_var(): show(gv) -def test_param(): - lv = relay.Var('x') - ty = None - param = relay.Param(lv, ty) - show(lv) - - def test_function(): param_names = ['a', 'b', 'c', 'd'] - params = tvm.convert([relay.Param(relay.Var(n), None) for n in param_names]) + params = tvm.convert([relay.Var(n) for n in param_names]) ret_type = None - body = params[0].var + body = params[0] type_params = tvm.convert([]) fn = relay.Function(params, ret_type, body, type_params) show(fn) @@ -76,11 +69,11 @@ def test_call(): def test_let(): - lv = relay.Var('x') ty = relay.ty.TensorType((10, 20), 'float32') + lv = relay.Var('x', ty) arr = tvm.nd.array(10) value = relay.Constant(arr) - let = relay.Let(lv, value, lv, ty) + let = relay.Let(lv, value, lv) show(let) diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index 79883ed225e08b8422290636a2b993d11abdb8f7..e571f2a9c99a0ca66ef9ff04afcfb9d26bd3e2ff 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -99,10 +99,16 @@ def test_tuple(): def test_local_var(): name_hint = 's' lv = relay.Var(name_hint) - lv.name_hint == name_hint + assert lv.name_hint == name_hint + assert lv.type_annotation is None # assert lv.span == None todo(@jroesch): what do we do about spans str(lv) + t1 = relay.ty.TensorType((), "float") + lv = relay.Var(name_hint, t1) + assert lv.name_hint == name_hint + assert lv.type_annotation == t1 + def test_global_var(): name_hint = 'g' @@ -112,19 +118,9 @@ def test_global_var(): str(gv) -def test_param(): - lv = relay.Var('x') - ty = None - param = relay.Param(lv, ty) - assert param.var == lv - assert param.type == ty - assert param.span == None - str(param) - - def test_function(): param_names = ['a', 'b', 'c', 'd'] - params = tvm.convert([relay.Param(relay.Var(n), None) for n in param_names]) + params = tvm.convert([relay.Var(n) for n in param_names]) ret_type = None body = None type_params = tvm.convert([]) @@ -154,10 +150,9 @@ def test_let(): value = relay.Constant(arr) # I would prefer that the order of arguments # matches syntax let x: t = v in b - let = relay.Let(lv, value, lv, ty) + let = relay.Let(lv, value, lv) assert let.var == lv assert let.value == value - assert let.value_type == ty assert let.body == lv assert let.span == None str(let) @@ -194,7 +189,6 @@ if __name__ == "__main__": test_tuple() test_local_var() test_global_var() - test_param() test_function() test_call() test_let() diff --git a/tests/python/relay/test_ir_well_formed.py b/tests/python/relay/test_ir_well_formed.py index c6cb99662bb5e7e2fc7cf2b40fd9a25d56b0ed78..d555c2beb6272a0cda1d200d148b35b650067247 100644 --- a/tests/python/relay/test_ir_well_formed.py +++ b/tests/python/relay/test_ir_well_formed.py @@ -7,23 +7,22 @@ def test_well_formed(): assert well_formed(x) v = relay.Constant(tvm.nd.array(10)) ty = None - let = relay.Let(x, v, x, ty) + let = relay.Let(x, v, x) assert well_formed(let) - assert not well_formed(relay.Let(x, v, let, ty)) - f = relay.Function([relay.Param(x, ty)], ty, x) + assert not well_formed(relay.Let(x, v, let)) + f = relay.Function([x], ty, x) assert well_formed(f) # this test should pass in case of weak uniqueness (only test for shadowing) # but we want all binder to be distinct from each other. assert not well_formed(relay.Let(relay.Var("y"), f, - relay.Let(relay.Var("z"), f, v, ty), ty)) + relay.Let(relay.Var("z"), f, v))) def test_tuple(): x = relay.Var('x') assert well_formed(x) v = relay.Constant(tvm.nd.array(10)) - ty = None - let = relay.Let(x, v, x, ty) + let = relay.Let(x, v, x) assert well_formed(let) assert well_formed(relay.Tuple([v, v])) assert not well_formed(relay.Tuple([let, let])) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index a90f6eb55ae1d5b125543907db7334c4a92abae0..05c02ab5d197a957555db998f0e9ecf92896bac6 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -27,6 +27,8 @@ def test_single_op(): tvm.relay.sigmoid, tvm.relay.tanh]: check_single_op(opfunc) + + def test_expand_dims_infer_type(): ib = relay.ir_builder.IRBuilder() n, t, d = tvm.var("n"), tvm.var("t"), 100 @@ -75,12 +77,13 @@ def test_unary_op(): ib = relay.ir_builder.IRBuilder() x = ib.param("x", relay.TensorType((10, 4), "int32")) with ib.function(x) as func: - ib.ret(op(x.var)) + ib.ret(op(x)) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type assert ftype.ret_type == relay.TensorType((10, 4), "int32") + def test_binary_op(): def check_binary_op(opfunc): """ @@ -94,7 +97,7 @@ def test_binary_op(): x = b.param('x', tensor_type(5, 5, 5)) y = b.param('y', tensor_type(5, 5, 5)) with b.function(x, y) as func: - b.ret(opfunc(x.var, y.var)) + b.ret(opfunc(x, y)) b.ret(func) prog, env = b.get() ttype = tensor_type(5, 5, 5) @@ -118,7 +121,7 @@ def test_binary_broadcast_op(): x = b.param('x', tensor_type(10, 4)) y = b.param('y', tensor_type(5, 10, 1)) with b.function(x, y) as func: - b.ret(opfunc(x.var, y.var)) + b.ret(opfunc(x, y)) b.ret(func) prog, env = b.get() diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index f67faea19be11bb93597e378e74273407272fe72..d0d02aece06d0b4846a211f13fd71ab6ab9baf77 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -11,7 +11,7 @@ def test_conv2d_infer_type(): w = ib.param("w", relay.ty.IncompleteType()) with ib.function(x, w) as func: - ib.ret(relay.nn.conv2d(x.var, w.var, + ib.ret(relay.nn.conv2d(x, w, kernel_size=(3, 3), padding=(1, 1), channels=2)) @@ -29,7 +29,7 @@ def test_conv2d_infer_type(): x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8")) w = ib.param("w", relay.ty.TensorType((2, 10, 3, 3), "int8")) with ib.function(x, w) as func: - ib.ret(relay.nn.conv2d(x.var, w.var, out_dtype="int32")) + ib.ret(relay.nn.conv2d(x, w, out_dtype="int32")) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -42,7 +42,7 @@ def test_conv2d_infer_type(): x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8")) w = ib.param("w", relay.ty.IncompleteType()) with ib.function(x, w) as func: - ib.ret(relay.nn.conv2d(x.var, w.var, + ib.ret(relay.nn.conv2d(x, w, kernel_size=(3, 3), padding=(1, 1), channels=16, @@ -65,7 +65,7 @@ def test_conv2d_transpose_infer_type(): w = ib.param("w", relay.ty.IncompleteType()) with ib.function(x, w) as func: - ib.ret(relay.nn.conv2d_transpose(x.var, w.var, + ib.ret(relay.nn.conv2d_transpose(x, w, kernel_size=(3, 3), padding=(1, 1), channels=15)) @@ -83,7 +83,7 @@ def test_conv2d_transpose_infer_type(): x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) w = ib.param("w", relay.ty.TensorType((12, 11, 5, 5), "float32")) with ib.function(x, w) as func: - ib.ret(relay.nn.conv2d_transpose(x.var, w.var, + ib.ret(relay.nn.conv2d_transpose(x, w, output_padding=(1, 1), channels=11, data_layout="NHWC")) @@ -98,7 +98,7 @@ def test_upsampling_infer_type(): n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) with ib.function(x) as func: - ib.ret(relay.nn.upsampling(x.var, scale=2, layout="NCHW", method="BILINEAR")) + ib.ret(relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR")) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -108,7 +108,7 @@ def test_upsampling_infer_type(): n, c = tvm.var("n"), tvm.var("c") x = ib.param("x", relay.ty.TensorType((n, c, 100, 200), "float32")) with ib.function(x) as func: - ib.ret(relay.nn.upsampling(x.var, scale=2, layout="NCHW", method="BILINEAR")) + ib.ret(relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR")) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -119,7 +119,7 @@ def _test_pool2d_infer_type(opfunc): n, c, h, w = tvm.var("n"), 10, 224, 224 x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) with ib.function(x) as func: - ib.ret(opfunc(x.var, pool_size=(1, 1))) + ib.ret(opfunc(x, pool_size=(1, 1))) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -132,7 +132,7 @@ def _test_pool2d_infer_type(opfunc): n, c, h, w = tvm.var("n"), 10, 224, 224 x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) with ib.function(x) as func: - ib.ret(opfunc(x.var, pool_size=(ph, pw), strides=(sh, sw))) + ib.ret(opfunc(x, pool_size=(ph, pw), strides=(sh, sw))) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -144,7 +144,7 @@ def _test_global_pool2d_infer_type(opfunc): n, c, h, w = tvm.var("n"), tvm.var("c"), 224, 224 x = ib.param("x", relay.ty.TensorType((n, h, w, c), "float32")) with ib.function(x) as func: - ib.ret(opfunc(x.var, layout="NHWC")) + ib.ret(opfunc(x, layout="NHWC")) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -154,7 +154,7 @@ def _test_global_pool2d_infer_type(opfunc): n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) with ib.function(x) as func: - ib.ret(opfunc(x.var)) + ib.ret(opfunc(x)) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -172,7 +172,7 @@ def test_flatten_infer_type(): x = ib.param("x", relay.ty.TensorType((d1, d2, d3, d4), "float32")) with ib.function(x) as func: - ib.ret(relay.nn.batch_flatten(x.var)) + ib.ret(relay.nn.batch_flatten(x)) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -181,7 +181,7 @@ def test_flatten_infer_type(): ib = relay.ir_builder.IRBuilder() x = ib.param("x", relay.ty.TensorType((3, 2, 4, 3), "float32")) with ib.function(x) as func: - ib.ret(relay.nn.batch_flatten(x.var)) + ib.ret(relay.nn.batch_flatten(x)) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -190,7 +190,7 @@ def test_flatten_infer_type(): ib = relay.ir_builder.IRBuilder() x = ib.param("x", relay.ty.TensorType((d1, 2, d3, 3), "float32")) with ib.function(x) as func: - ib.ret(relay.nn.batch_flatten(x.var)) + ib.ret(relay.nn.batch_flatten(x)) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -202,7 +202,7 @@ def test_pad_infer_type(): n, c, h, w = 1, 2, 3, 4 t = ib.param("t", relay.TensorType((n, c, h, w), "float32")) with ib.function(t) as func: - ib.ret(relay.nn.pad(t.var, ((1, 1), (2, 2), (3, 3), (4, 4)))) + ib.ret(relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4)))) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -213,7 +213,7 @@ def test_pad_infer_type(): n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w") t = ib.param("t", relay.TensorType((n, c, h, w), "float32")) with ib.function(t) as func: - ib.ret(relay.nn.pad(t.var, ((1, 1), (2, 2), (3, 3), (4, 4)))) + ib.ret(relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4)))) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -227,4 +227,3 @@ if __name__ == "__main__": test_flatten_infer_type() test_pad_infer_type() test_conv2d_transpose_infer_type() - diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 9515db87e64aae07e60972033dbb71ec943508f2..7d949b21026b4ee5a39e4fc832b0a73091c635cd 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -17,12 +17,13 @@ def test_zeros_ones(): ftype = func.checked_type assert ftype.ret_type == relay.TensorType((124, 50), "float64") + def test_unary_identity(): for op in [relay.zeros_like, relay.ones_like]: ib = relay.ir_builder.IRBuilder() x = ib.param("x", relay.TensorType((8, 9, 4), "int32")) with ib.function(x) as func: - ib.ret(op(x.var)) + ib.ret(op(x)) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -33,7 +34,7 @@ def test_clip_type(): ib = relay.ir_builder.IRBuilder() a = ib.param("a", relay.TensorType((10, 4), "float32")) with ib.function(a) as func: - ib.ret(relay.clip(a.var, 1., 4.)) + ib.ret(relay.clip(a, 1., 4.)) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -106,7 +107,7 @@ def test_take_infer_type(): x = ib.param("x", relay.ty.TensorType(dshape, "float32")) indices = ib.param("indices", relay.ty.TensorType(indices_shape, "int32")) with ib.function(x, indices) as func: - ib.ret(relay.take(x.var, indices.var, axis=axis)) + ib.ret(relay.take(x, indices, axis=axis)) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -127,7 +128,7 @@ def test_full(): ib = relay.ir_builder.IRBuilder() x = ib.param("x", relay.TensorType((), "int8")) with ib.function(x) as func: - ib.ret(relay.full(x.var, ())) + ib.ret(relay.full(x, ())) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -137,7 +138,7 @@ def test_full(): ib = relay.ir_builder.IRBuilder() x = ib.param("x", relay.TensorType((), "float32")) with ib.function(x) as func: - ib.ret(relay.full(x.var, (1, 2), "int8")) + ib.ret(relay.full(x, (1, 2), "int8")) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -150,7 +151,7 @@ def test_full_like(): base = ib.param("base", relay.TensorType((1, 2, 3), "float32")) fill = ib.param("fill", relay.TensorType((), "float32")) with ib.function(base, fill) as func: - ib.ret(relay.full_like(base.var, fill.var)) + ib.ret(relay.full_like(base, fill)) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -162,7 +163,7 @@ def test_full_like(): base = ib.param("base", relay.TensorType((n, c, h, w), "float32")) fill = ib.param("fill", relay.TensorType((), "float32")) with ib.function(base, fill) as func: - ib.ret(relay.full_like(base.var, fill.var)) + ib.ret(relay.full_like(base, fill)) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 807d3a3a964e2875faeabb7631c0e80f16b94820..995e15fb9760d161bfafaa893f3a1a6a9aa575e1 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -24,7 +24,7 @@ def test_cmp_type(): x = ib.param("x", relay.TensorType((10, 4), "float32")) y = ib.param("y", relay.TensorType((5, 10, 1), "float32")) with ib.function(x, y) as func: - ib.ret(op(x.var, y.var)) + ib.ret(op(x, y)) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -39,7 +39,7 @@ def test_binary_broadcast(): x = ib.param("x", relay.TensorType((10, 4), "int32")) y = ib.param("y", relay.TensorType((5, 10, 1), "int32")) with ib.function(x, y) as func: - ib.ret(op(x.var, y.var)) + ib.ret(op(x, y)) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -58,7 +58,7 @@ def test_binary_op(): x = b.param('x', tensor_type(5, 5, 5)) y = b.param('y', tensor_type(5, 5, 5)) with b.function(x, y) as func: - b.ret(opfunc(x.var, y.var)) + b.ret(opfunc(x, y)) b.ret(func) prog, env = b.get() ttype = tensor_type(5, 5, 5) @@ -81,7 +81,7 @@ def test_binary_broadcast_op(): x = b.param('x', tensor_type(10, 4)) y = b.param('y', tensor_type(5, 10, 1)) with b.function(x, y) as func: - b.ret(opfunc(x.var, y.var)) + b.ret(opfunc(x, y)) b.ret(func) prog, env = b.get() @@ -103,7 +103,7 @@ def test_cmp_type(): x = ib.param("x", relay.TensorType((10, 4), "float32")) y = ib.param("y", relay.TensorType((5, 10, 1), "float32")) with ib.function(x, y) as func: - ib.ret(op(x.var, y.var)) + ib.ret(op(x, y)) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -118,7 +118,7 @@ def test_binary_broadcast(): x = ib.param("x", relay.TensorType((10, 4), "int32")) y = ib.param("y", relay.TensorType((5, 10, 1), "int32")) with ib.function(x, y) as func: - ib.ret(op(x.var, y.var)) + ib.ret(op(x, y)) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -131,7 +131,7 @@ def test_where(): x = ib.param("x", relay.TensorType((3, 4), "float32")) y = ib.param("y", relay.TensorType((3, 4), "float32")) with ib.function(cond, x, y) as func: - ib.ret(relay.where(cond.var, x.var, y.var)) + ib.ret(relay.where(cond, x, y)) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 62da592e8249379cca4665c774cf738b9ad416f6..8d871e9ef4f5510dffa516d9ff84894075aca747 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -10,7 +10,7 @@ def test_resize_infer_type(): th, tw = tvm.var("th"), tvm.var("tw") with ib.function(x) as func: - ib.ret(relay.image.resize(x.var, (th, tw))) + ib.ret(relay.image.resize(x, (th, tw))) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type @@ -19,7 +19,7 @@ def test_resize_infer_type(): ib = relay.ir_builder.IRBuilder() x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8")) with ib.function(x) as func: - ib.ret(relay.image.resize(x.var, (100, 200), "NCHW", "BILINEAR", False)) + ib.ret(relay.image.resize(x, (100, 200), "NCHW", "BILINEAR", False)) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) ftype = func.checked_type diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index dd722399dac43acc0cc94a68468ee7d7f00eccb9..04ef3cf3da8f4a65cbf529b5afcd25ff4eb14f24 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -1,4 +1,5 @@ import tvm +import numpy as np from tvm import relay from tvm.relay.ir_pass import alpha_equal from tvm.relay.ir_builder import convert @@ -179,9 +180,9 @@ def test_var_alpha_equal(): assert not alpha_equal(v1, v2) # let node allows for setting the eq_map - l1 = relay.Let(v1, convert(1), v1, None) - l2 = relay.Let(v2, convert(1), v2, None) - l3 = relay.Let(v1, convert(1), v2, None) + l1 = relay.Let(v1, convert(1), v1) + l2 = relay.Let(v2, convert(1), v2) + l3 = relay.Let(v1, convert(1), v2) assert alpha_equal(l1, l2) assert not alpha_equal(l1, l3) @@ -209,10 +210,10 @@ def test_tuple_alpha_equal(): assert alpha_equal(tup, same) # use the eq_map - let_tup = relay.Let(v1, tup, v1, None) + let_tup = relay.Let(v1, tup, v1) let_mapped = relay.Let(v2, relay.Tuple([v2, convert(2), convert(3), relay.Tuple([convert(4)])]), - v2, None) + v2) assert alpha_equal(let_tup, let_mapped) more_fields = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)]), v2]) @@ -242,61 +243,44 @@ def test_tuple_get_item_alpha_equal(): assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1)) -def test_param_alpha_equal(): - # only checks equality of the types - v1 = relay.Var("v1") - v2 = relay.Var("v2") - - p1 = relay.Param(v1, relay.TensorType((1, 2, 3), "float32")) - p2 = relay.Param(v2, relay.TensorType((1, 2, 3), "float32")) - assert alpha_equal(p1, p2) - - p3 = relay.Param(v1, relay.TensorType((4, 5, 6), "int8")) - assert not alpha_equal(p1, p3) - - p4 = relay.Param(v1, relay.TupleType([relay.TensorType((1, 2, 3), - "float32")])) - assert not alpha_equal(p1, p4) - - def test_function_alpha_equal(): - v1 = relay.Var("v1") - v2 = relay.Var("v2") - v3 = relay.Var("v3") - v4 = relay.Var("v4") - tt1 = relay.TensorType((1, 2, 3), "float32") tt2 = relay.TensorType((4, 5, 6), "int8") tt3 = relay.TupleType([tt1, tt2]) + v1 = relay.Var("v1", tt1) + v2 = relay.Var("v2", tt2) + v3 = relay.Var("v3", tt3) + v4 = relay.Var("v4", tt2) + vret = relay.Constant(tvm.nd.array(np.ones(1))) + tp1 = relay.TypeParam("tp1", relay.Kind.Type) tp2 = relay.TypeParam("tp2", relay.Kind.Type) tp3 = relay.TypeParam("tp3", relay.Kind.Shape) tp4 = relay.TypeParam("tp4", relay.Kind.Shape) - basic_args = [relay.Param(v3, tt1), relay.Param(v4, tt2)] + basic_args = [relay.Var("v3", tt1), relay.Var("v4", tt2)] basic_tps = [tp1, tp2] - func = relay.Function([relay.Param(v1, tt1), relay.Param(v2, tt2)], - tt2, v2, basic_tps) - mapped = relay.Function(basic_args, tt2, v4, basic_tps) + func = relay.Function([v1, v2], + tt2, v1, basic_tps) + mapped = relay.Function(basic_args, tt2, basic_args[0], basic_tps) assert alpha_equal(func, mapped) - fewer_params = relay.Function([relay.Param(v4, tt2)], tt2, v4, basic_tps) + fewer_params = relay.Function([relay.Var("v4", tt2)], tt2, v4, basic_tps) assert not alpha_equal(func, fewer_params) - more_params = relay.Function([relay.Param(v3, tt1), relay.Param(v4, tt2), - relay.Param(v2, tt2)], tt2, v4, basic_tps) + more_params = relay.Function([relay.Var("v3", tt1), + relay.Var("v4", tt2), + relay.Var("v2", tt2)], tt2, v4, basic_tps) assert not alpha_equal(func, more_params) - params_unordered = relay.Function([relay.Param(v3, tt2), - relay.Param(v4, tt1)], - tt1, v3, basic_tps) + params_unordered = relay.Function([v2, v1], + tt2, v1, basic_tps) assert not alpha_equal(func, params_unordered) - params_mismatch = relay.Function([relay.Param(v3, tt3), - relay.Param(v4, tt2)], - tt2, v4, basic_tps) + params_mismatch = relay.Function([v1, v3], + tt2, v1, basic_tps) assert not alpha_equal(func, params_mismatch) # also would not typecheck @@ -376,7 +360,10 @@ def test_call_alpha_equal(): def test_let_alpha_equal(): + tt1 = relay.TensorType((), "float32") + tt2 = relay.TensorType((), "int8") v1 = relay.Var("v1") + v1_wtype = relay.Var("v1", tt1) v2 = relay.Var("v2") v3 = relay.Var("v3") @@ -394,14 +381,13 @@ def test_let_alpha_equal(): assert not alpha_equal(let, different_body) # specified types must match - tt1 = relay.TensorType((), "float32") - tt2 = relay.TensorType((), "int8") - let_with_type = relay.Let(v1, convert(2), v1, tt1) - same_type = relay.Let(v1, convert(2), v1, tt1) + + let_with_type = relay.Let(v1_wtype, convert(2), v1_wtype) + same_type = relay.Let(v1_wtype, convert(2), v1_wtype) assert alpha_equal(let_with_type, same_type) assert not alpha_equal(let, let_with_type) - - different_type = relay.Let(v1, convert(2), v1, tt2) + v2 = relay.Var("v1", tt2) + different_type = relay.Let(v2, convert(2), v2) assert not alpha_equal(let_with_type, different_type) @@ -437,16 +423,13 @@ if __name__ == "__main__": test_tensor_type_alpha_equal() test_incomplete_type_alpha_equal() test_constant_alpha_equal() - test_type_param_alpha_equal() test_func_type_alpha_equal() test_tuple_type_alpha_equal() test_type_relation_alpha_equal() test_constant_alpha_equal() - test_var_alpha_equal() test_global_var_alpha_equal() test_tuple_alpha_equal() test_tuple_get_item_alpha_equal() - test_param_alpha_equal() test_function_alpha_equal() test_call_alpha_equal() test_let_alpha_equal() diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py index ce9bda3d254f4d253603cf4234b73e050e46c58e..121cea0081bdd4fd86ee96eb7fd5eb082c54f727 100644 --- a/tests/python/relay/test_pass_dead_code_elimination.py +++ b/tests/python/relay/test_pass_dead_code_elimination.py @@ -28,17 +28,17 @@ e = env() def test_let(): - orig = relay.Let(e.x, e.y, e.z, e.tt) + orig = relay.Let(e.x, e.y, e.z) assert alpha_equal(dead_code_elimination(orig), e.z) def test_used_let(): - orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c, e.tt), e.tt) - assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c, e.tt)) + orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c)) + assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c)) def test_chain_unused_let(): - orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e, e.tt), e.tt) + orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e)) assert alpha_equal(dead_code_elimination(orig), e.e) @@ -56,19 +56,17 @@ def test_recursion(): f(2, 10000); """ f = relay.Var("f") - n = relay.Var("n") - np = relay.Param(n, e.int32) - data = relay.Var("data") - datap = relay.Param(data, e.float32) + n = relay.Var("n", e.int32) + data = relay.Var("data", e.float32) funcbody = relay.If(equal(n, convert(0)), data, f(subtract(n, convert(1.0)), log(data))) - value = relay.Function([np, datap], e.float32, funcbody, []) - orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0)), e.float32) + value = relay.Function([n, data], e.float32, funcbody, []) + orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0))) assert alpha_equal(dead_code_elimination(orig), orig) - assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three, e.float32)), e.three) + assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three) def test_op_let(): - assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three, e.float32), e.two)), add(e.three, e.two)) + assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three), e.two)), add(e.three, e.two)) def test_if(): @@ -80,7 +78,7 @@ def test_tuple_get_item(): t = relay.Var('t') g = relay.TupleGetItem(t, 0) assert alpha_equal(dead_code_elimination(g), g) - assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t, e.float32), 0)), g) + assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t), 0)), g) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_free_vars.py b/tests/python/relay/test_pass_free_vars.py index 989c9f8d25dbcd79b32b8484e99dd58bbfcca612..a4c745de10e0737a14b65f4678c3eac381f1e4c7 100644 --- a/tests/python/relay/test_pass_free_vars.py +++ b/tests/python/relay/test_pass_free_vars.py @@ -3,16 +3,17 @@ from tvm import relay from tvm.relay.ir_pass import free_vars, free_type_vars def test_free_vars(): - x = relay.Var("x") + ty = relay.TensorType([], "int32") + x = relay.Var("x", ty) fvx = free_vars(x) assert len(fvx) == 1 assert fvx[0] == x v = relay.Constant(tvm.nd.array(10)) - ty = relay.TensorType([], "int32") - let = relay.Let(x, v, x, ty) + + let = relay.Let(x, v, x) fvx = free_vars(let) assert len(free_vars(let)) == 0 - f = relay.Function([relay.Param(x, ty)], ty, x) + f = relay.Function([x], ty, x) assert len(free_vars(f)) == 0 @@ -29,9 +30,9 @@ def test_tuple(): def test_free_type_vars(): tp = relay.TypeParam("") ty = relay.TupleType([tp, relay.TensorType([], "int32")]) - x = relay.Var("x") + x = relay.Var("x", ty) y = relay.Var("y") - let = relay.Let(x, y, x, ty) + let = relay.Let(x, y, x) fvl = free_vars(let) assert len(fvl) == 1 assert fvl[0] == y