diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 6eba6b25d9fd88a65ee86bf2334d829809e6683e..42883f5f77da2b32730052c0209448eb5cab2659 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -40,6 +40,8 @@ This level enables fully connected multi-layer perceptron. tvm.relay.nn.relu tvm.relay.nn.dropout tvm.relay.nn.batch_norm + tvm.relay.nn.bias_add + **Level 2: Convolutions** @@ -85,8 +87,13 @@ This level enables additional math and transform operators. tvm.relay.abs tvm.relay.negative tvm.relay.take + tvm.relay.zeros + tvm.relay.zeros_like + tvm.relay.ones + tvm.relay.ones_like tvm.relay.full tvm.relay.full_like + tvm.relay.cast **Level 4: Broadcast and Reductions** @@ -151,6 +158,9 @@ Level 1 Definitions .. autofunction:: tvm.relay.nn.softmax .. autofunction:: tvm.relay.nn.log_softmax .. autofunction:: tvm.relay.nn.relu +.. autofunction:: tvm.relay.nn.dropout +.. autofunction:: tvm.relay.nn.batch_norm +.. autofunction:: tvm.relay.nn.bias_add Level 2 Definitions @@ -185,6 +195,9 @@ Level 3 Definitions .. autofunction:: tvm.relay.zeros_like .. autofunction:: tvm.relay.ones .. autofunction:: tvm.relay.ones_like +.. autofunction:: tvm.relay.full +.. autofunction:: tvm.relay.full_like +.. autofunction:: tvm.relay.cast Level 4 Definitions diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 6b522ef3bfd08950ec52dce1a851ca2f39e93df6..eb044ccb29fd7cbbb452a398c435b8e75108c25f 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -12,6 +12,23 @@ namespace tvm { namespace relay { +/*! + * \brief Add a 1D Tensor to an axis of a data. + * + * \note bias_add is a special add operator that is in nn + * and enables automatic derivation of bias's shape. + * You can directly use add for more generalized case. + */ +struct BiasAddAttrs : public tvm::AttrsNode<BiasAddAttrs> { + int axis; + + TVM_DECLARE_ATTRS(BiasAddAttrs, "relay.attrs.BiasAddAttrs") { + TVM_ATTR_FIELD(axis) + .describe("The axis to add the bias") + .set_default(1); + } +}; + /*! \brief Attributes used in convolution operators */ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> { Array<IndexExpr> strides; diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 8e2b741091b3bfd1705e63a832f4cfe2aa055e85..1941e045ed8d46d43b6c8107411485d03c3c6635 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -12,6 +12,16 @@ namespace tvm { namespace relay { +/*! \brief data type cast */ +struct CastAttrs : public tvm::AttrsNode<CastAttrs> { + DataType dtype; + + TVM_DECLARE_ATTRS(CastAttrs, "relay.attrs.CastAttrs") { + TVM_ATTR_FIELD(dtype) + .describe("Target data type"); + } +}; // struct CastAttrs. + /*! \brief Attributes used in expand_dims operators */ struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> { int axis; diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index c10933590f99a81460b07cd2cfcb4a9616d88637..c0256cf3a1c37bd81a1a5a2c09c3d9c4c289f125 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -112,15 +112,17 @@ class ExprFunctor<R(const Expr& n, Args...)> { } }; -/*! \brief A simple visitor wrapper around ExprFunctor. +/*! + * \brief A simple visitor wrapper around ExprFunctor. + * Recursively visit the content. 
* - * Exposes two visitors with default traversal strategies, one - * which doesn't compute a result but can mutate internal state, - * and another which functionally builds a new Expr. + * ExprVisitor treats Expr as dataflow graph, + * and only visit each Expr node once. */ - -class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> { +class ExprVisitor + : public ::tvm::relay::ExprFunctor<void(const Expr& n)> { public: + void VisitExpr(const Expr& expr) override; void VisitExpr_(const VarNode* op) override; void VisitExpr_(const GlobalVarNode* op) override; void VisitExpr_(const ConstantNode* op) override; @@ -132,13 +134,19 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> { void VisitExpr_(const OpNode* op) override; void VisitExpr_(const TupleGetItemNode* op) override; virtual void VisitType(const Type& t); + + private: + // internal visited flag. + std::unordered_set<const Node*> visited_; }; -/*! \brief A wrapper around ExprFunctor which functionally updates the AST. -* -* ExprMutator uses memoization and self return in order to amortize -* the cost of using functional updates. -*/ +/*! + * \brief A wrapper around ExprFunctor which functionally updates the AST. + * + * ExprMutator treats Expr as dataflow graph, and only Mutate each Expr once. + * The mutated results are memoized in a map and reused so that + * local transformation on the dataflow preserves the graph structure. + */ class ExprMutator : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> { public: diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 9a3b75364167279988396b7ac7712bc5848b0db8..1b3462659e18a015191136f6c2a9c8e05f746393 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -102,35 +102,26 @@ bool AlphaEqual(const Type& t1, const Type& t2); */ bool WellFormed(const Expr& e); -/*! \brief Get free variables from expression e. +/*! \brief Get free Vars from expr in PostDFS order. * - * Free variables are variables that are not bound by a let or a function parameter in the context. + * Free variables are variables that are not bound by a + * let or a function parameter in the context. * - * \param e the expression. + * \param expr the expression. * - * \return the set of free variable. + * \return List of free vars, in the PostDFS order visited by expr. */ -tvm::Array<Var> FreeVariables(const Expr& e); +tvm::Array<Var> FreeVars(const Expr& expr); -/*! \brief Get free type parameters from expression e. +/*! \brief Get free TypeVars from expression expr. * * Free type parameters are type parameters that are not bound by a function type in the context. * - * \param e the expression. + * \param expr the expression. * - * \return the set of free type variables. + * \return List of free vars, in the PostDFS order visited by expr. */ -tvm::Array<TypeVar> FreeTypeVariables(const Expr& e); - -/*! \brief Get free type parameters from type t. - * - * Free type parameters are type parameters that are not bound by a function type in the context. - * - * \param t the type. - * - * \return the set of free type variables. - */ -tvm::Array<TypeVar> FreeTypeVariables(const Type& t); +tvm::Array<TypeVar> FreeTypeVars(const Expr& expr); /*! \brief Remove expressions which does not effect the program result. 
* diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 00a523416c855ccba554cfe6e9c925518564c461..bdb253d21582487efcc2421f6b028f0933878d8a 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -299,6 +299,9 @@ class IntImm(ConstExpr): self.__init_handle_by_constructor__( _make.IntImm, dtype, value) + def __int__(self): + return self.value + @register_node class UIntImm(ConstExpr): diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 36116d07d6011597431337afd02ecfbe4a0d51a9..655379066c74aeb427a3a624e5654eda0400c6be 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -6,7 +6,7 @@ import numpy as _np from .base import RelayNode, register_relay_node from . import _make from . import ty as _ty -from .._ffi import base as _base, node as _node +from .._ffi import base as _base from .. import nd as _nd from .. import convert @@ -28,6 +28,25 @@ class Expr(RelayNode): " the checked_type for this node") return ret + def astype(self, dtype): + """Cast the content type of the current data to dtype. + + Parameters + ---------- + dtype : str + The target data type. + + Note + ---- + This function only works for TensorType Exprs. + + Returns + ------- + result : tvm.relay.Expr + The result expression. + """ + return _make.dtype_cast(self, dtype) + @register_relay_node class Constant(Expr): @@ -62,6 +81,9 @@ class Tuple(Expr): def __len__(self): return len(self.fields) + def astype(self, _): + raise TypeError("astype cannot be used on tuple") + @register_relay_node class Var(Expr): @@ -238,7 +260,7 @@ class TupleGetItem(Expr): _make.TupleGetItem, tuple_value, index) -class TupleWrapper(_node.NodeGeneric): +class TupleWrapper(object): """TupleWrapper. This class is a Python wrapper for a Relay tuple of known size. @@ -257,10 +279,9 @@ class TupleWrapper(_node.NodeGeneric): self.tuple_value = tuple_value self.size = size - def asnode(self): + def astuple(self): """Returns the underlying Relay tuple if this wrapper is passed as an argument to an FFI function.""" - return self.tuple_value def __getitem__(self, index): @@ -275,6 +296,9 @@ class TupleWrapper(_node.NodeGeneric): return ("TupleWrapper(" + self.tuple_value.__repr__() + ", " + self.size + ")") + def astype(self, _): + raise TypeError("astype cannot be used on tuple") + def var(name_hint, type_annotation=None, diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 22ee918039b581f615c8204b87e31842f7f67cdf..c6d5aa7515bccaec74d5bc0785fa86f8edc88501 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -15,16 +15,16 @@ def infer_type(expr, env=None): Parameters ---------- expr: tvm.relay.Expr - The input expression. + The input expression. env: Optional[tvm.relay.Environment] - The global environment. + The global environment. Returns ------- checked_expr : tvm.relay.Expr - The checked expression. + The checked expression. """ return _ir_pass.infer_type(expr, env) @@ -35,12 +35,12 @@ def well_formed(expr): Parameters ---------- expr: tvm.relay.Expr - The input expression + The input expression Returns ------- well_form : bool - whether the input expression is well formed + Whether the input expression is well formed """ return _ir_pass.well_formed(expr) @@ -52,15 +52,15 @@ def check_kind(t, env=None): Parameters ---------- t: tvm.relay.Type - The type to check + The type to check env: tvm.relay.Environment, optional - The global environment + The global environment Returns ------- well_kinded : bool - whether the input type is well kinded. 
+ whether the input type is well kinded. Examples -------- @@ -75,20 +75,26 @@ def check_kind(t, env=None): return _ir_pass.check_kind(t) -def free_vars(e): - """Get free variables from expression e. +def free_vars(expr): + """Get free Vars from expression expr in Post DFS order. Parameters ---------- - e: tvm.relay.Expr - The input expression + expr: tvm.relay.Expr + The input expression Returns ------- free : List[tvm.relay.Var] - The list of free variables + The list of free variables in post DFS order. + + Note + ---- + The fact that Vars are post-DFS ordred are useful in + neural networks: usually this means weights of previous + are ordered first. """ - return _ir_pass.free_vars(e) + return _ir_pass.free_vars(expr) def free_type_vars(expr): @@ -130,15 +136,15 @@ def alpha_equal(lhs, rhs): Parameters ---------- lhs: tvm.relay.Expr - One of the input Expression. + One of the input Expression. rhs: tvm.relay.Expr - One of the input Expression. + One of the input Expression. Returns ------- result: bool - True iff lhs is alpha equal to rhs. + True iff lhs is alpha equal to rhs. """ return bool(_make._alpha_equal(lhs, rhs)) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 8a5357e4a2df1a804c5732dd953d19d18008ea22..d0ccfcb44899c1e0bde081549a64addd68c524da 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -43,10 +43,10 @@ def conv2d(data, Parameters ---------- - data : relay.Expr + data : tvm.relay.Expr The input data to the operator. - weight : relay.Expr + weight : tvm.relay.Expr The weight expressions. strides : tuple of int, optional @@ -81,7 +81,7 @@ def conv2d(data, Returns ------- - result : relay.Expr + result : tvm.relay.Expr The computed result. """ return _make.conv2d(data, weight, strides, padding, dilation, @@ -105,10 +105,10 @@ def conv2d_transpose(data, Parameters ---------- - data : relay.Expr + data : tvm.relay.Expr The input data to the operator. - weight : relay.Expr + weight : tvm.relay.Expr The weight expressions. strides : Tuple[int], optional @@ -137,7 +137,7 @@ def conv2d_transpose(data, Returns ------- - result : relay.Expr + result : tvm.relay.Expr The computed result. """ return _make.conv2d_transpose(data, weight, strides, padding, dilation, @@ -155,7 +155,7 @@ def softmax(data, axis=1): Parameters ---------- - data: relay.Expr + data: tvm.relay.Expr The input data to the operator. axis: int, optional @@ -163,7 +163,7 @@ def softmax(data, axis=1): Returns ------- - result : relay.Expr + result : tvm.relay.Expr The computed result. """ return _make.softmax(data, axis) @@ -181,7 +181,7 @@ def log_softmax(data, axis): Parameters ---------- - data: relay.Expr + data: tvm.relay.Expr The input data to the operator. axis: int @@ -189,7 +189,7 @@ def log_softmax(data, axis): Returns ------- - result : relay.Expr + result : tvm.relay.Expr The computed result. """ return _make.log_softmax(data, axis) @@ -224,7 +224,7 @@ def max_pool2d(data, Parameters ---------- - data : relay.Expr + data : tvm.relay.Expr The input data to the operator. strides : tuple of int, optional @@ -241,7 +241,7 @@ def max_pool2d(data, Returns ------- - result : relay.Expr + result : tvm.relay.Expr The computed result. """ return _make.max_pool2d(data, pool_size, strides, padding, @@ -278,7 +278,7 @@ def avg_pool2d(data, Parameters ---------- - data : relay.Expr + data : tvm.relay.Expr The input data to the operator. 
strides : tuple of int, optional @@ -298,7 +298,7 @@ def avg_pool2d(data, Returns ------- - result : relay.Expr + result : tvm.relay.Expr The computed result. """ return _make.avg_pool2d(data, pool_size, strides, padding, @@ -325,7 +325,7 @@ def global_max_pool2d(data, Parameters ---------- - data : relay.Expr + data : tvm.relay.Expr The input data to the operator. layout : str, optional @@ -333,7 +333,7 @@ def global_max_pool2d(data, Returns ------- - result : relay.Expr + result : tvm.relay.Expr The computed result. """ return _make.global_max_pool2d(data, layout) @@ -359,7 +359,7 @@ def global_avg_pool2d(data, Parameters ---------- - data : relay.Expr + data : tvm.relay.Expr The input data to the operator. layout : str, optional @@ -367,7 +367,7 @@ def global_avg_pool2d(data, Returns ------- - result : relay.Expr + result : tvm.relay.Expr The computed result. """ return _make.global_avg_pool2d(data, layout) @@ -389,10 +389,10 @@ def upsampling(data, Parameters ---------- - data : relay.Expr + data : tvm.relay.Expr The input data to the operator. - scale : relay.Expr + scale : tvm.relay.Expr The scale factor for upsampling. layout : str, optional @@ -403,11 +403,12 @@ def upsampling(data, Returns ------- - result : relay.Expr + result : tvm.relay.Expr The computed result. """ return _make.upsampling(data, scale, layout, method) + def batch_flatten(data): """BatchFlatten. @@ -420,17 +421,43 @@ def batch_flatten(data): Parameters ---------- - data : relay.Expr + data : tvm.relay.Expr The input data to the operator. Returns ------- - result: relay.Expr + result : tvm.relay.Expr The Flattened result. """ return _make.batch_flatten(data) +def bias_add(data, bias, axis=1): + """add_bias operator. + + Add 1D bias to the axis of data. + This function is a special case of add which allows + inference of shape of the bias from data. + + Parameters + ---------- + data : tvm.relay.Expr + The input data to the operator. + + bias : tvm.relay.Expr + The bias to be added. + + axis : int, optional + The axis to add the bias. + + Returns + ------- + result : tvm.relay.Expr + The final result. + """ + return _make.bias_add(data, bias, axis) + + def dense(data, weight, units=None): """Dense operator. Applies a linear transformation @@ -441,10 +468,10 @@ def dense(data, weight, units=None): Parameters ---------- - data : relay.Expr + data : tvm.relay.Expr The input data to the operator. - weight : relay.Expr + weight : tvm.relay.Expr The weight expressions. units : int, optional @@ -452,7 +479,7 @@ def dense(data, weight, units=None): Returns ------- - result : relay.Expr + result : tvm.relay.Expr The computed result. """ return _make.dense(data, weight, units) @@ -466,12 +493,12 @@ def relu(data): Parameters ---------- - data : relay.Expr + data : tvm.relay.Expr The input data Returns ------- - result : relay.Expr + result : tvm.relay.Expr The computed result. """ return _make.relu(data) @@ -487,7 +514,7 @@ def leaky_relu(data, alpha): Parameters ---------- - data : relay.Expr + data : tvm.relay.Expr The input data to the operator. alpha : float @@ -495,7 +522,7 @@ def leaky_relu(data, alpha): Returns ------- - result : relay.Expr + result : tvm.relay.Expr The computed result. 
""" return _make.leaky_relu(data, alpha) @@ -511,7 +538,7 @@ def pad(data, Parameters ---------- - data: relay.Expr + data: tvm.relay.Expr The input data to the operator pad_width: tuple of <tuple of <int>>, required Number of values padded to the edges of each axis, in the format @@ -521,7 +548,7 @@ def pad(data, Returns ------- - result : relay.Expr + result : tvm.relay.Expr The computed result. """ return _make.pad(data, pad_width, pad_value) @@ -540,7 +567,7 @@ def lrn(data, size=5, axis=1, bias=2, alpha=.00001, beta=0.75): Parameters ---------- - data : relay.Expr + data : tvm.relay.Expr The input data to the operator. size : int, optional @@ -560,7 +587,7 @@ def lrn(data, size=5, axis=1, bias=2, alpha=.00001, beta=0.75): Returns ------- - result : relay.Expr + result : tvm.relay.Expr The computed result. """ return _make.lrn(data, size, axis, alpha, beta, bias) @@ -574,7 +601,7 @@ def l2_normalize(data, eps, axis=None): Parameters ---------- - data : relay.Expr + data : tvm.relay.Expr The input data to the operator. eps : float @@ -585,11 +612,12 @@ def l2_normalize(data, eps, axis=None): Returns ------- - result : relay.Expr + result : tvm.relay.Expr The computed result. """ return _make.l2_normalize(data, eps, axis) + def dropout(data, rate=0.5): """Applies the dropout operation to the input array. @@ -599,7 +627,7 @@ def dropout(data, rate=0.5): Parameters ---------- - data : relay.Expr + data : tvm.relay.Expr The input data to the operator. rate : float, optional (default=0.5) @@ -607,17 +635,22 @@ def dropout(data, rate=0.5): Returns ------- - result : relay.Tuple([relay.Expr, relay.Expr]) - The first member of the tuple is the result of dropping elements from ``data`` - and rescaling. The second member is a "mask" tensor, which is of the same - shape and data type as ``data`` and, for each element in ``data``, is 1.0 - if the element was not dropped and 0.0 if it was. + result : tvm.relay.Expr + The result of dropout """ result = _make.dropout(data, rate) - return TupleWrapper(result, 2) - -def batch_norm(data, gamma, beta, moving_mean, moving_var, - axis=1, epsilon=1e-5, center=True, scale=True): + return TupleWrapper(result, 2)[0] + + +def batch_norm(data, + gamma, + beta, + moving_mean, + moving_var, + axis=1, + epsilon=1e-5, + center=True, + scale=True): r""" Batch normalization layer (Ioffe and Szegedy, 2014). Normalizes the input at each batch, i.e. applies a transformation @@ -658,34 +691,50 @@ def batch_norm(data, gamma, beta, moving_mean, moving_var, Parameters ---------- - data : relay.Expr + data : tvm.relay.Expr Input to which batch_norm will be applied. - gamma : relay.Expr + + gamma : tvm.relay.Expr The gamma scale factor. - beta : relay.Expr + + beta : tvm.relay.Expr The beta offset factor. - moving_mean : relay.Expr + + moving_mean : tvm.relay.Expr Running mean of input, - moving_var : relay.Expr + + moving_var : tvm.relay.Expr Running variance of input. + axis : int, optional, default=1 Specify along which shape axis the channel is specified. + epsilon : double, optional, default=1e-5 Small float added to variance to avoid diving by zero. + center : boolean, optional, default=True If True, add offset of beta to normalized tensor, If False, beta is ignored. + scale : boolean, optional, default=True If true, multiply by gamma. If False, gamma is not used. When the next layer is piecewise linear (also e.g. nn.relu), - this can be disabled since the scalingwill be done by the next layer. + this can be disabled since the scaling will be done by the next layer. 
Returns ------- - result : relay.Tuple([relay.Expr, relay.Expr, relay.Expr]) - Tuple of normed data (same shape as input), new running mean (k-length vector), + result : relay.Tuple([tvm.relay.Expr, tvm.relay.Expr, tvm.relay.Expr]) + Tuple of normed data (same shape as input), + new running mean (k-length vector), and new running variance (k-length vector) """ - result = _make.batch_norm(data, gamma, beta, moving_mean, moving_var, - axis, epsilon, center, scale) + result = _make.batch_norm(data, + gamma, + beta, + moving_mean, + moving_var, + axis, + epsilon, + center, + scale) return TupleWrapper(result, 3) diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9bfda3c7abc78424f26653e18b6af856d10da845 --- /dev/null +++ b/python/tvm/relay/testing/__init__.py @@ -0,0 +1,5 @@ +"""Utilities for testing and benchmarks""" +from __future__ import absolute_import as _abs + +from . import mlp +from . import resnet diff --git a/python/tvm/relay/testing/init.py b/python/tvm/relay/testing/init.py new file mode 100644 index 0000000000000000000000000000000000000000..fdbde9d289d6650c9ba72d8777866e7d0769c025 --- /dev/null +++ b/python/tvm/relay/testing/init.py @@ -0,0 +1,149 @@ +"""Initializer of parameters.""" +import tvm +from tvm import relay +import numpy as np + +class Initializer(object): + """The base class of an initializer.""" + def __init__(self, **kwargs): + self._kwargs = kwargs + + def __call__(self, desc, arr): + """Initialize an array + + Parameters + ---------- + desc : str + Initialization pattern descriptor. + + arr : NDArray + The array to be initialized. + """ + if desc.endswith('weight'): + self._init_weight(desc, arr) + elif desc.endswith('bias'): + self._init_bias(desc, arr) + elif desc.endswith('gamma'): + self._init_gamma(desc, arr) + elif desc.endswith('beta'): + self._init_beta(desc, arr) + elif desc.endswith('mean'): + self._init_mean(desc, arr) + elif desc.endswith('var'): + self._init_var(desc, arr) + else: + self._init_default(desc, arr) + + def _init_bias(self, _, arr): + arr[:] = 0.0 + + def _init_gamma(self, _, arr): + arr[:] = 1.0 + + def _init_beta(self, _, arr): + arr[:] = 0.0 + + def _init_mean(self, _, arr): + arr[:] = 0.0 + + def _init_var(self, _, arr): + arr[:] = 1.0 + + def _init_weight(self, name, arr): + """Abstract method to Initialize weight.""" + raise NotImplementedError("Must override it") + + def _init_default(self, name, _): + raise ValueError( + 'Unknown initialization pattern for %s. ' \ + 'Default initialization is now limited to '\ + '"weight", "bias", "gamma" (1.0), and "beta" (0.0).' \ + 'Please use mx.sym.Variable(init=mx.init.*) to set initialization pattern' % name) + + +class Xavier(Initializer): + """ "Xavier" initialization for weights + + Parameters + ---------- + rnd_type: str, optional + Random generator type, can be ``'gaussian'`` or ``'uniform'``. + + factor_type: str, optional + Can be ``'avg'``, ``'in'``, or ``'out'``. + + magnitude: float, optional + Scale of random number. + """ + def __init__(self, rnd_type="uniform", factor_type="avg", magnitude=3): + super(Xavier, self).__init__(rnd_type=rnd_type, + factor_type=factor_type, + magnitude=magnitude) + self.rnd_type = rnd_type + self.factor_type = factor_type + self.magnitude = float(magnitude) + + def _init_weight(self, name, arr): + shape = arr.shape + hw_scale = 1. + if len(shape) < 2: + raise ValueError('Xavier initializer cannot be applied to vector {0}. 
It requires at' + ' least 2D.'.format(name)) + if len(shape) > 2: + hw_scale = np.prod(shape[2:]) + fan_in, fan_out = shape[1] * hw_scale, shape[0] * hw_scale + factor = 1. + if self.factor_type == "avg": + factor = (fan_in + fan_out) / 2.0 + elif self.factor_type == "in": + factor = fan_in + elif self.factor_type == "out": + factor = fan_out + else: + raise ValueError("Incorrect factor type") + # Hack for mobilenet, because there is less connectivity + if "depthwise" in name: + factor = 3 * 3 + scale = np.sqrt(self.magnitude / factor) + if self.rnd_type == "uniform": + arr[:] = np.random.uniform(-scale, scale, size=arr.shape) + else: + raise ValueError("Unknown random type") + + +def create_workload(net, initializer=None, seed=0): + """Helper function to create benchmark image classification workload. + + Parameters + ---------- + net : tvm.relay.Function + The selected function of the network. + + initializer : Initializer + The initializer used + + seed : int + The seed used in initialization. + + Returns + ------- + net : tvm.relay.Function + The updated dataflow + + params : dict of str to NDArray + The parameters. + """ + net = relay.ir_pass.infer_type(net) + shape_dict = { + v.name_hint : v.checked_type for v in net.params} + net.astext() + np.random.seed(seed) + initializer = initializer if initializer else Xavier() + params = {} + for k, v in shape_dict.items(): + if k == "data": + continue + init_value = np.zeros(v.concrete_shape).astype(v.dtype) + initializer(k, init_value) + params[k] = tvm.nd.array(init_value, ctx=tvm.cpu(0)) + return net, params diff --git a/python/tvm/relay/testing/layers.py b/python/tvm/relay/testing/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..fc06ca229f771b8b04839d6cb52c04b7de44887a --- /dev/null +++ b/python/tvm/relay/testing/layers.py @@ -0,0 +1,114 @@ +"""Simple Layer DSL wrapper to ease creation of neural nets.""" +from tvm import relay + +def batch_norm_infer(data, + gamma=None, + beta=None, + moving_mean=None, + moving_var=None, + **kwargs): + """Wrapper of batch_norm. + + This function automatically creates weights and return + the first output(normalized result). + + Parameters + ---------- + data : relay.Expr + The input expression. + + gamma : relay.Expr + The gamma scale factor. + + beta : relay.Expr + The beta offset factor. + + moving_mean : relay.Expr + Running mean of input, + + moving_var : relay.Expr + Running variance of input. + + kwargs : dict + Additional arguments. + + Returns + ------- + result : relay.Expr + The result. + """ + name = kwargs.get("name") + kwargs.pop("name") + if not gamma: + gamma = relay.var(name + "_gamma") + if not beta: + beta = relay.var(name + "_beta") + if not moving_mean: + moving_mean = relay.var(name + "_moving_mean") + if not moving_var: + moving_var = relay.var(name + "_moving_var") + return relay.nn.batch_norm(data, + gamma=gamma, + beta=beta, + moving_mean=moving_mean, + moving_var=moving_var, + **kwargs)[0] + + +def conv2d(data, weight=None, **kwargs): + """Wrapper of conv2d which automatically creates weights if not given. + + Parameters + ---------- + data : relay.Expr + The input expression. + + weight : relay.Expr + The weight to conv2d. + + kwargs : dict + Additional arguments. + + Returns + ------- + result : relay.Expr + The result. 
+ """ + name = kwargs.get("name") + kwargs.pop("name") + if not weight: + weight = relay.var(name + "_weight") + return relay.nn.conv2d(data, weight, **kwargs) + + +def dense_add_bias(data, weight=None, bias=None, **kwargs): + """Wrapper of dense which automatically creates weights if not given. + + Parameters + ---------- + data : relay.Expr + The input expression. + + weight : relay.Expr + The weight to conv2d. + + bias : relay.Expr + The bias. + + kwargs : dict + Additional arguments. + + Returns + ------- + result : relay.Expr + The result. + """ + name = kwargs.get("name") + kwargs.pop("name") + if not weight: + weight = relay.var(name + "_weight") + if not bias: + bias = relay.var(name + "_bias") + data = relay.nn.dense(data, weight, **kwargs) + data = relay.nn.bias_add(data, bias) + return data diff --git a/python/tvm/relay/testing/mlp.py b/python/tvm/relay/testing/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..67fa0d90c643f702becfce781e717293ed1535f8 --- /dev/null +++ b/python/tvm/relay/testing/mlp.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +a simple multilayer perceptron +""" +from tvm import relay +from .init import create_workload + +def get_net(batch_size, + num_classes=10, + image_shape=(1, 28, 28), + dtype="float32"): + """Get network a simple multilayer perceptron. + + batch_size : int + The batch size used in the model + + num_classes : int, optional + Number of claseses + + image_shape : tuple, optional + The input image shape + + dtype : str, optional + The data type + + Returns + ------- + net : relay.Function + The dataflow. + """ + data_shape = (batch_size,) + image_shape + data = relay.var("data", + shape=data_shape, + dtype=dtype) + data = relay.nn.batch_flatten(data) + fc1 = relay.nn.dense(data, relay.var("fc1_weight"), units=128) + fc1 = relay.nn.bias_add(fc1, relay.var("fc2_bias")) + act1 = relay.nn.relu(fc1) + fc2 = relay.nn.dense(act1, relay.var("fc2_weight"), units=64) + fc2 = relay.nn.bias_add(fc2, relay.var("fc2_bias")) + act2 = relay.nn.relu(fc2) + fc3 = relay.nn.dense(act2, relay.var("fc3_weight"), units=num_classes) + fc3 = relay.nn.bias_add(fc3, relay.var("fc3_bias")) + mlp = relay.nn.softmax(data=fc3) + args = relay.ir_pass.free_vars(mlp) + return relay.Function(args, mlp) + + +def get_workload(batch_size, + num_classes=10, + image_shape=(1, 28, 28), + dtype="float32"): + """Get benchmark workload for a simple multilayer perceptron. + + Parameters + ---------- + batch_size : int + The batch size used in the model + + num_classes : int, optional + Number of claseses + + image_shape : tuple, optional + The input image shape + + dtype : str, optional + The data type + + Returns + ------- + net : relay.Function + The dataflow. 
+ + params : dict of str to NDArray + The parameters. + """ + net = get_net(batch_size, num_classes, image_shape, dtype) + return create_workload(net) diff --git a/python/tvm/relay/testing/resnet.py b/python/tvm/relay/testing/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..cb489f6e2471e3628bc5aa9ca680063f4f3c3c87 --- /dev/null +++ b/python/tvm/relay/testing/resnet.py @@ -0,0 +1,276 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Adapted from https://github.com/tornadomeet/ResNet/blob/master/symbol_resnet.py +Original author Wei Wu + +Implemented the following paper: + +Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks" +""" +# pylint: disable=unused-argument +from tvm import relay +from .init import create_workload +from . import layers + +def residual_unit(data, + num_filter, + stride, + dim_match, + name, + bottle_neck=True): + """Return ResNet Unit symbol for building ResNet + + Parameters + ---------- + data : str + Input data + + num_filter : int + Number of output channels + + bnf : int + Bottle neck channels factor with regard to num_filter + + stride : tuple + Stride used in convolution + + dim_match : bool + True means channel number between input and output is the same, + otherwise means differ + + name : str + Base name of the operators + """ + if bottle_neck: + bn1 = layers.batch_norm_infer(data=data, + epsilon=2e-5, + name=name + '_bn1') + act1 = relay.relu(data=bn1) + conv1 = layers.conv2d( + data=act1, + channels=int(num_filter*0.25), + kernel_size=(1, 1), + strides=stride, + padding=(0, 0), + name=name + '_conv1') + bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2') + act2 = relay.relu(data=bn2) + conv2 = layers.conv2d( + data=act2, channels=int(num_filter*0.25), kernel_size=(3, 3), + strides=(1, 1), padding=(1, 1), name=name + '_conv2') + bn3 = layers.batch_norm_infer(data=conv2, epsilon=2e-5, name=name + '_bn3') + act3 = relay.relu(data=bn3) + conv3 = layers.conv2d( + data=act3, channels=num_filter, kernel_size=(1, 1), + strides=(1, 1), padding=(0, 0), name=name + '_conv3') + if dim_match: + shortcut = data + else: + shortcut = layers.conv2d( + data=act1, channels=num_filter, kernel_size=(1, 1), + strides=stride, name=name+'_sc') + return relay.add(conv3, shortcut) + else: + bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, name=name + '_bn1') + act1 = relay.nn.relu(data=bn1) + conv1 = layers.conv2d( + data=act1, channels=num_filter, kernel_size=(3, 3), + strides=stride, padding=(1, 1), name=name + '_conv1') + bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2') + act2 = relay.nn.relu(data=bn2) + conv2 = layers.conv2d( + data=act2, channels=num_filter, kernel_size=(3, 3), + strides=(1, 1), padding=(1, 
1), name=name + '_conv2') + if dim_match: + shortcut = data + else: + shortcut = layers.conv2d( + data=act1, channels=num_filter, kernel_size=(1, 1), + strides=stride, name=name+'_sc') + return relay.add(conv2, shortcut) + + +def resnet(units, + num_stages, + filter_list, + num_classes, + data_shape, + bottle_neck=True, + dtype="float32"): + """Return ResNet Program. + + Parameters + ---------- + units : list + Number of units in each stage + + num_stages : int + Number of stage + + filter_list : list + Channel size of each stage + + num_classes : int + Ouput size of symbol + + data_shape : tuple of int. + The shape of input data. + + bottle_neck : bool + Whether apply bottleneck transformation. + + dtype : str + The global data type. + """ + num_unit = len(units) + assert num_unit == num_stages + data = relay.var("data", shape=data_shape, dtype=dtype) + data = layers.batch_norm_infer(data=data, epsilon=2e-5, scale=False, name='bn_data') + (_, _, height, _) = data_shape + if height <= 32: # such as cifar10 + body = layers.conv2d( + data=data, channels=filter_list[0], kernel_size=(3, 3), + strides=(1, 1), padding=(1, 1), name="conv0") + else: # often expected to be 224 such as imagenet + body = layers.conv2d( + data=data, channels=filter_list[0], kernel_size=(7, 7), + strides=(2, 2), padding=(3, 3), name="conv0") + body = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn0') + body = relay.nn.relu(data=body) + body = relay.nn.max_pool2d(data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1)) + + for i in range(num_stages): + body = residual_unit( + body, filter_list[i+1], (1 if i == 0 else 2, 1 if i == 0 else 2), + False, name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck) + for j in range(units[i]-1): + body = residual_unit( + body, filter_list[i+1], (1, 1), True, + name='stage%d_unit%d' % (i + 1, j + 2), bottle_neck=bottle_neck) + bn1 = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn1') + relu1 = relay.nn.relu(data=bn1) + # Although kernel is not used here when global_pool=True, we should put one + pool1 = relay.nn.global_avg_pool2d(data=relu1) + flat = relay.nn.batch_flatten(data=pool1) + fc1 = layers.dense_add_bias(data=flat, units=num_classes, name='fc1') + net = relay.nn.softmax(data=fc1) + return relay.Function(relay.ir_pass.free_vars(net), net) + + +def get_net(batch_size, + num_classes, + num_layers=50, + image_shape=(3, 224, 224), + dtype="float32", + **kwargs): + """ + Adapted from https://github.com/tornadomeet/ResNet/blob/master/train_resnet.py + Original author Wei Wu + """ + (_, height, _) = image_shape + data_shape = (batch_size,) + image_shape + if height <= 28: + num_stages = 3 + if (num_layers-2) % 9 == 0 and num_layers >= 164: + per_unit = [(num_layers-2)//9] + filter_list = [16, 64, 128, 256] + bottle_neck = True + elif (num_layers-2) % 6 == 0 and num_layers < 164: + per_unit = [(num_layers-2)//6] + filter_list = [16, 16, 32, 64] + bottle_neck = False + else: + raise ValueError("no experiments done on num_layers {}".format(num_layers)) + units = per_unit * num_stages + else: + if num_layers >= 50: + filter_list = [64, 256, 512, 1024, 2048] + bottle_neck = True + else: + filter_list = [64, 64, 128, 256, 512] + bottle_neck = False + num_stages = 4 + if num_layers == 18: + units = [2, 2, 2, 2] + elif num_layers == 34: + units = [3, 4, 6, 3] + elif num_layers == 50: + units = [3, 4, 6, 3] + elif num_layers == 101: + units = [3, 4, 23, 3] + elif num_layers == 152: + units = [3, 8, 36, 3] + elif num_layers == 200: + units = [3, 24, 36, 3] + elif 
num_layers == 269: + units = [3, 30, 48, 8] + else: + raise ValueError("no experiments done on num_layers {}".format(num_layers)) + + return resnet(units=units, + num_stages=num_stages, + filter_list=filter_list, + num_classes=num_classes, + data_shape=data_shape, + bottle_neck=bottle_neck, + dtype=dtype) + + +def get_workload(batch_size=1, + num_classes=1000, + num_layers=18, + image_shape=(3, 224, 224), + dtype="float32", + **kwargs): + """Get benchmark workload for resnet + + Parameters + ---------- + batch_size : int + The batch size used in the model + + num_classes : int, optional + Number of classes + + num_layers : int, optional + Number of layers + + image_shape : tuple, optional + The input image shape + + dtype : str, optional + The data type + + kwargs : dict + Extra arguments + + Returns + ------- + net : relay.Function + The computational graph + + params : dict of str to NDArray + The parameters. + """ + net = get_net(batch_size=batch_size, + num_classes=num_classes, + num_layers=num_layers, + image_shape=image_shape, + dtype=dtype, + **kwargs) + return create_workload(net) diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index 0f7e0e82ad4df8cde1d439f931b96982d6c05446..7ea63e6200bf8885e2a466efb891b5e94af45b71 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -47,6 +47,21 @@ class TensorType(Type): self.__init_handle_by_constructor__( _make.TensorType, shape, dtype) + @property + def concrete_shape(self): + """Get shape of the type as concrete tuple of int. + + Returns + ------- + shape : List[int] + The concrete shape of the Type. + + Raises + ------ + TypeError : If the shape is symbolic + """ + return tuple(int(x) for x in self.shape) + class Kind(IntEnum): """The kind of a type parameter, represents a variable shape, diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index a7367c384cb3b42c8180dd771e5c89e948b86490..557daa98e89988c1f4f32a844029e1f6147bc356 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -159,6 +159,13 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) { Type ExprMutator::VisitType(const Type& t) { return t; } +void ExprVisitor::VisitExpr(const Expr& expr) { + if (visited_.count(expr.get())) return; + using TParent = ExprFunctor<void(const Expr&)>; + TParent::VisitExpr(expr); + visited_.insert(expr.get()); +} + void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) { if (op->type_annotation.defined()) { this->VisitType(op->type_annotation); @@ -197,8 +204,8 @@ void ExprVisitor::VisitExpr_(const CallNode* op) { } void ExprVisitor::VisitExpr_(const LetNode* op) { - this->VisitExpr(op->var); this->VisitExpr(op->value); + this->VisitExpr(op->var); this->VisitExpr(op->body); } diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 3cbe1e00b9ca84e8434dd67815b911d6dbc85f4b..8056adc9a8b88263a89690d386bd3873eae90070 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -63,7 +63,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO * * \code * - * function(%x: Tensor[(meta.Variable(id=0),), float32]) { + * fn (%x: Tensor[(meta.Variable(id=0),), float32]) { * %x * } * # Meta data section is a json-serialized string @@ -154,7 +154,7 @@ class TextPrinter : } void PrintFunc(const Function& func) { - this->PrintFuncInternal("function", func); + this->PrintFuncInternal("fn ", func); stream_ << "\n"; } @@ -343,7 +343,7 @@ class TextPrinter : TextValue tuple = GetValue(op->tuple); TextValue id = 
this->AllocTempVar(); this->PrintIndent(); - stream_ << id << " = " << tuple << "[" << op->index << "]"; + stream_ << id << " = " << tuple << "." << op->index << ""; this->PrintEndInst("\n"); return id; } @@ -379,6 +379,17 @@ class TextPrinter : os << "), " << runtime::TVMType2String(Type2TVMType(node->dtype)) << "]"; } + void VisitType_(const TupleTypeNode* node, std::ostream& os) final { // NOLINT(*) + os << "Tuple["; + for (size_t i = 0; i < node->fields.size(); ++i) { + this->PrintType(node->fields[i], os); + if (i + 1 != node->fields.size()) { + os << ", "; + } + } + os << "]"; + } + void VisitTypeDefault_(const Node* node, std::ostream& os) final { // NOLINT(*) // by default always print as meta-data os << meta_.GetMetaNode(GetRef<NodeRef>(node)); diff --git a/src/relay/ir/type_functor.h b/src/relay/ir/type_functor.h index 03bb4db1f59eebbd975706abcd12cf4bd6a46632..f51c8c746eb9192c87f26fc4701f3d869847d619 100644 --- a/src/relay/ir/type_functor.h +++ b/src/relay/ir/type_functor.h @@ -96,40 +96,41 @@ class TypeFunctor<R(const Type& n, Args...)> { * * We recursively visit each type contained inside the visitor. */ -template <typename... Args> -struct TypeVisitor : ::tvm::relay::TypeFunctor<void(const Type& n, Args...)> { - void VisitType_(const TypeVarNode* op, Args... args) override {} +class TypeVisitor : + public ::tvm::relay::TypeFunctor<void(const Type& n)> { + public: + void VisitType_(const TypeVarNode* op) override {} - void VisitType_(const FuncTypeNode* op, Args... args) override { + void VisitType_(const FuncTypeNode* op) override { for (auto type_param : op->type_params) { - this->VisitType(type_param, std::forward<Args>(args)...); + this->VisitType(type_param); } for (auto type_cs : op->type_constraints) { - this->VisitType(type_cs, std::forward<Args>(args)...); + this->VisitType(type_cs); } for (auto arg_type : op->arg_types) { - this->VisitType(arg_type, std::forward<Args>(args)...); + this->VisitType(arg_type); } - this->VisitType(op->ret_type, std::forward<Args>(args)...); + this->VisitType(op->ret_type); } - void VisitType_(const TensorTypeNode* op, Args... args) override {} + void VisitType_(const TensorTypeNode* op) override {} - void VisitType_(const TupleTypeNode* op, Args... args) override { + void VisitType_(const TupleTypeNode* op) override { for (const Type& t : op->fields) { - this->VisitType(t, std::forward<Args>(args)...); + this->VisitType(t); } } - void VisitType_(const TypeRelationNode* op, Args... args) override { + void VisitType_(const TypeRelationNode* op) override { for (const Type& t : op->args) { - this->VisitType(t, std::forward<Args>(args)...); + this->VisitType(t); } } - void VisitType_(const IncompleteTypeNode* op, Args... args) override {} + void VisitType_(const IncompleteTypeNode* op) override {} }; // A functional visitor for rebuilding an AST in place. 
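(Not part of the patch itself — a minimal usage sketch, assuming this revision of Relay, that mirrors the tests this patch adds in tests/python/relay: test_bias_add, test_cast, test_batch_norm, and the new tvm.relay.testing workloads.)

import tvm
from tvm import relay
import tvm.relay.testing

# bias_add: the bias shape is inferred from `data` along `axis` (default 1).
x = relay.var("x", shape=(10, 2, 3, 4))
bias = relay.var("bias")
z = relay.nn.bias_add(x, bias)
zz = relay.ir_pass.infer_type(z)
assert zz.args[1].checked_type == relay.TensorType((2,))

# astype lowers to the new `cast` operator.
y = x.astype("int32")

# batch_norm now returns a TupleWrapper; index it for the normed data,
# or call .astuple() before passing it to an FFI function such as infer_type.
data = relay.var("data", relay.TensorType((3, 2, 1)))
bn = relay.nn.batch_norm(data, relay.var("gamma"), relay.var("beta"),
                         relay.var("moving_mean"), relay.var("moving_var"))
normed = bn[0]

# The new testing package builds whole benchmark workloads (net + params).
net, params = tvm.relay.testing.mlp.get_workload(batch_size=1)
print(net.astext())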
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 8a7cffd2cd278d6a3203d0bf691a50a806935046..8459a99cde23639a94deeb9b62dbb56f1bc5aba8 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -15,6 +15,62 @@ namespace tvm { namespace relay { +TVM_REGISTER_NODE_TYPE(BiasAddAttrs); + +bool BiasAddRel(const Array<Type>& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as<TensorTypeNode>(); + if (data == nullptr) return false; + + const BiasAddAttrs* param = attrs.as<BiasAddAttrs>(); + CHECK(param != nullptr); + int axis = param->axis; + if (axis < 0) { + axis = data->shape.size() + axis; + } + CHECK_LE(axis, static_cast<int>(data->shape.size())) + << "axis " << param->axis << " is out of range"; + + // assign output type + reporter->Assign(types[1], TensorTypeNode::make( + {data->shape[axis]}, data->dtype)); + reporter->Assign(types[2], types[0]); + return true; +} + + +// Positional relay function to create dense operator used by frontend FFI. +Expr MakeBiasAdd(Expr data, + Expr bias, + int axis) { + auto attrs = make_node<BiasAddAttrs>(); + attrs->axis = axis; + static const Op& op = Op::Get("nn.bias_add"); + return CallNode::make(op, {data, bias}, Attrs(attrs), {}); +} + + +TVM_REGISTER_API("relay.op.nn._make.bias_add") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call<Expr, 3>(MakeBiasAdd, args, rv); + }); + + +RELAY_REGISTER_OP("nn.bias_add") +.describe(R"code(Add bias to an axis of the input. + +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.BiasAddAttrs") +.set_num_inputs(2) +.add_argument("data", "nD Tensor", "Input data.") +.add_argument("bias", "1D Tensor", "Bias.") +.set_support_level(1) +.add_type_rel("BiasAdd", BiasAddRel); + + TVM_REGISTER_NODE_TYPE(DenseAttrs); @@ -82,7 +138,7 @@ RELAY_REGISTER_OP("nn.dense") .set_num_inputs(2) .add_argument("data", "nD Tensor", "Input data.") .add_argument("weight", "2D Tensor", "Weight matrix.") -.set_support_level(2) +.set_support_level(1) .add_type_rel("Dense", DenseRel); @@ -235,13 +291,23 @@ Example:: .set_support_level(2) .add_type_rel("BatchFlatten", BatchFlattenRel); -RELAY_REGISTER_UNARY_OP("relay.op.nn._make.", "relu") + +// relu +TVM_REGISTER_API("relay.op.nn._make.relu") +.set_body_typed<Expr(Expr)>([](Expr data) { + static const Op& op = Op::Get("nn.relu"); + return CallNode::make(op, {data}, Attrs(), {}); + }); + +RELAY_REGISTER_OP("nn.relu") .describe(R"code(Returns the relu input array, computed element-wise. .. 
math:: max(x, 0) )code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") .set_support_level(1) .add_type_rel("Identity", IdentityRel); @@ -371,24 +437,6 @@ The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input // batch_norm TVM_REGISTER_NODE_TYPE(BatchNormAttrs); -bool CheckVectorLength(int64_t dim, const DataType& dtype, Type vector, const char* name) { - const auto* candidate = vector.as<TensorTypeNode>(); - CHECK(candidate != nullptr) - << name << " should be a vector but is not a tensor type,"; - CHECK_EQ(dtype, candidate->dtype) - << name << " should be of the same data type as the original but it is not."; - CHECK_EQ(candidate->shape.size(), 1) - << name << " should be a vector but has a shape of " - << candidate->shape.size() << " dimensions instead of 1."; - - const int64_t* length = as_const_int(candidate->shape[0]); - if (length == nullptr) return false; - CHECK(*length == dim) - << name << " should be as long as the channel but has length " - << *length << " instead of " << dim << "."; - return true; -} - bool BatchNormRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, @@ -396,33 +444,19 @@ bool BatchNormRel(const Array<Type>& types, CHECK_EQ(types.size(), 6); const auto* data = types[0].as<TensorTypeNode>(); if (data == nullptr) return false; - if (data->shape.size() == 0) return false; const BatchNormAttrs* param = attrs.as<BatchNormAttrs>(); // axis of -1 means use the last dimension CHECK(param->axis >= -1 && param->axis < (int)data->shape.size()); int axis = (param->axis != -1) ? param->axis : data->shape.size() - 1; - - auto dim = as_const_int(data->shape[axis]); - if (dim == nullptr) return false; + auto axis_size = data->shape[axis]; // if we are using beta and gamma, they need to be of shape (dim,) - if (param->scale && !CheckVectorLength(*dim, data->dtype, types[1], "The gamma scale factor")) { - return false; - } - - if (param->center && !CheckVectorLength(*dim, data->dtype, types[2], "The beta offset factor")) { - return false; - } - - // the two running averages must also be vectors of length dim - if (!CheckVectorLength(*dim, data->dtype, types[3], "The moving mean")) { - return false; - } - if (!CheckVectorLength(*dim, data->dtype, types[4], "The moving variance")) { - return false; - } + reporter->Assign(types[1], TensorTypeNode::make({axis_size}, data->dtype)); + reporter->Assign(types[2], TensorTypeNode::make({axis_size}, data->dtype)); + reporter->Assign(types[3], TensorTypeNode::make({axis_size}, data->dtype)); + reporter->Assign(types[4], TensorTypeNode::make({axis_size}, data->dtype)); // output is a tuple of the normed data (same shape as input), new running mean, // and new running average (the latter two are both vectors of length dim) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index e3c8bcef217ed14206cd6390b2b80e2716dd220d..bab875fd190ec6b9d766801a6c8bccd7fc6cc013 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -13,8 +13,52 @@ namespace tvm { namespace relay { -/* relay.expand_dims */ +// relay.cast +TVM_REGISTER_NODE_TYPE(CastAttrs); +bool CastRel(const Array<Type>& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as<TensorTypeNode>(); + if (data == nullptr) { + CHECK(types[0].as<IncompleteTypeNode>()) + << "cast: expect input type to be TensorType but get " + << types[0]; + return false; + } + const 
auto* param = attrs.as<CastAttrs>(); + reporter->Assign(types[1], TensorTypeNode::make( + data->shape, param->dtype)); + return true; +} + +Expr MakeCast(Expr data, + DataType dtype) { + auto attrs = make_node<CastAttrs>(); + attrs->dtype = dtype; + static const Op& op = Op::Get("cast"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay._make.dtype_cast") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call<Expr, 2>(MakeCast, args, rv); +}); + +RELAY_REGISTER_OP("cast") +.describe(R"code(Cast the data into a new data type. + +)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.set_attrs_type_key("relay.attrs.CastAttrs") +.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(3) +.add_type_rel("Cast", CastRel); + + +// relay.expand_dims TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs); bool ExpandDimsRel(const Array<Type>& types, @@ -25,6 +69,9 @@ bool ExpandDimsRel(const Array<Type>& types, CHECK_EQ(types.size(), 2); const auto* data = types[0].as<TensorTypeNode>(); if (data == nullptr) { + CHECK(types[0].as<IncompleteTypeNode>()) + << "expand_dims: expect input type to be TensorType but get " + << types[0]; return false; } const auto* param = attrs.as<ExpandDimsAttrs>(); @@ -91,6 +138,9 @@ bool ConcatenateRel(const Array<Type>& types, CHECK_EQ(types.size(), 2); const auto* tensor_tuple = types[0].as<TupleTypeNode>(); if (tensor_tuple == nullptr) { + CHECK(types[0].as<TupleTypeNode>()) + << "cast: expect input type to be TupleType but get " + << types[0]; return false; } const auto* param = attrs.as<ConcatenateAttrs>(); @@ -161,6 +211,9 @@ bool TransposeRel(const Array<Type>& types, CHECK_EQ(types.size(), 2); const auto* data = types[0].as<TensorTypeNode>(); if (data == nullptr) { + CHECK(types[0].as<IncompleteTypeNode>()) + << "transpose: expect input type to be TensorType but get " + << types[0]; return false; } const auto* param = attrs.as<TransposeAttrs>(); @@ -243,6 +296,9 @@ bool ReshapeRel(const Array<Type>& types, CHECK_EQ(types.size(), 2); const auto* data = types[0].as<TensorTypeNode>(); if (data == nullptr) { + CHECK(types[0].as<IncompleteTypeNode>()) + << "reshape: expect input type to be TensorType but get " + << types[0]; return false; } const auto* param = attrs.as<ReshapeAttrs>(); diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index c3d16c2976bf5ab7901a674440f258da76c9c16a..81e72c6d7df8d33acbc2dc8bf85e86dd882a8ac0 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -22,7 +22,7 @@ namespace relay { using namespace tvm::runtime; using Kind = TypeVarNode::Kind; -struct KindChecker : TypeVisitor<> { +struct KindChecker : TypeVisitor { bool valid; KindChecker() : valid(true) {} diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 87fdb1c0ffbae9e36d8abcd52da21b06199414b8..7c8eeef92c5d59c5ebf03e48a9e953900721bbdc 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -471,6 +471,5 @@ TVM_REGISTER_API("relay._ir_pass.infer_type") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = InferType(args[0], args[1]); }); - } // namespace relay } // namespace tvm diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index d69f1bce70d4404ecc1b91d7e0a5365f98ae0e0d..c1f00c7b65e020b35f945be6c820398b2acd0155 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -12,107 +12,120 @@ namespace tvm { namespace relay { -class FreeVar; -class FreeTypeVar : private TypeVisitor<> { - 
std::unordered_set<TypeVar, NodeHash, NodeEqual>* free_vars;
-  std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars;
-  FreeTypeVar(std::unordered_set<TypeVar, NodeHash, NodeEqual>* free_vars,
-              std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars) :
-      free_vars(free_vars), bound_vars(bound_vars) { }
+// FreeTypeVar
+
+class FreeTypeVarTVisitor : public TypeVisitor {
+ public:
+  FreeTypeVarTVisitor(
+      Array<TypeVar>* free_vars,
+      std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars)
+    : free_vars_(free_vars), bound_vars_(bound_vars) { }

   void VisitType_(const TypeVarNode* tp) final {
-    auto var = GetRef<TypeVar>(tp);
-    if (bound_vars->count(var) == 0) {
-      free_vars->insert(var);
+    TypeVar var = GetRef<TypeVar>(tp);
+    if (bound_vars_->count(var) == 0) {
+      free_vars_->push_back(var);
     }
   }

   void VisitType_(const FuncTypeNode* f) final {
     for (auto type_param : f->type_params) {
-      bound_vars->insert(type_param);
+      bound_vars_->insert(type_param);
     }
+    TypeVisitor::VisitType_(f);
+  }

-    for (auto type_cs : f->type_constraints) {
-      this->VisitType(type_cs);
-    }
+ private:
+  Array<TypeVar>* free_vars_;
+  std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars_;
+};

-    for (auto arg_type : f->arg_types) {
-      this->VisitType(arg_type);
-    }
-    this->VisitType(f->ret_type);
+class FreeTypeVarEVisitor : private ExprVisitor {
+ public:
+  Array<TypeVar> Find(const Expr& expr) {
+    this->VisitExpr(expr);
+    return free_vars_;
   }
-  friend FreeVar;
-};

-class FreeVar : public ExprVisitor {
-  void VisitExpr_(const VarNode* v) final {
-    auto var = GetRef<Var>(v);
-    if (bound_vars.count(var) == 0) {
-      free_vars.insert(var);
-    }
-    if (v->type_annotation.defined()) {
-      VisitType(v->type_annotation);
-    }
+  Array<TypeVar> Find(const Type& type) {
+    this->VisitType(type);
+    return free_vars_;
   }

   void VisitExpr_(const FunctionNode* f) final {
     for (const auto& tp : f->type_params) {
-      bound_types.insert(tp);
-    }
-    for (const auto& param : f->params) {
-      bound_vars.insert(param);
+      bound_vars_.insert(tp);
     }
-    VisitExpr(f->body);
-    VisitType(f->ret_type);
+    ExprVisitor::VisitExpr_(f);
   }

-  void VisitExpr_(const LetNode* l) final {
-    bound_vars.insert(l->var);
-    VisitExpr(l->value);
-    VisitExpr(l->body);
+  void VisitType(const Type& t) final {
+    FreeTypeVarTVisitor(&free_vars_, &bound_vars_)
+        .VisitType(t);
   }

+ private:
+  // The result list
+  Array<TypeVar> free_vars_;
+  std::unordered_set<TypeVar, NodeHash, NodeEqual> bound_vars_;
+};
+
+class FreeVarVisitor : protected ExprVisitor {
  public:
-  std::unordered_set<Var, NodeHash, NodeEqual> free_vars;
-  std::unordered_set<Var, NodeHash, NodeEqual> bound_vars;
-  std::unordered_set<TypeVar, NodeHash, NodeEqual> free_types;
-  std::unordered_set<TypeVar, NodeHash, NodeEqual> bound_types;
+  Array<Var> Find(const Expr& expr) {
+    this->VisitExpr(expr);
+    return free_vars_;
+  }

-  void VisitType(const Type& t) final {
-    FreeTypeVar(&free_types, &bound_types)(t);
+  void VisitExpr_(const VarNode* var) final {
+    if (bound_vars_.count(var) == 0) {
+      free_vars_.push_back(GetRef<Var>(var));
+    }
   }
+
+  void VisitExpr_(const FunctionNode* op) final {
+    for (const auto& param : op->params) {
+      bound_vars_.insert(param.operator->());
+    }
+    VisitExpr(op->body);
+  }
+
+  void VisitExpr_(const LetNode* op) final {
+    bound_vars_.insert(op->var.operator->());
+    VisitExpr(op->value);
+    VisitExpr(op->body);
+  }
+
+ private:
+  // The result list
+  Array<Var> free_vars_;
+  std::unordered_set<const VarNode*> bound_vars_;
 };

-tvm::Array<Var> FreeVariables(const Expr& e) {
-  FreeVar fv;
-  fv.VisitExpr(e);
-  return tvm::Array<Var>(fv.free_vars.begin(), fv.free_vars.end());
+tvm::Array<TypeVar> FreeTypeVars(const Expr& expr) {
+  return FreeTypeVarEVisitor().Find(expr);
 }

-tvm::Array<TypeVar> FreeTypeVariables(const Expr& e) {
-  FreeVar fv;
-  fv.VisitExpr(e);
-  return tvm::Array<TypeVar>(fv.free_types.begin(), fv.free_types.end());
+tvm::Array<TypeVar> FreeTypeVars(const Type& type) {
+  return FreeTypeVarEVisitor().Find(type);
 }

-tvm::Array<TypeVar> FreeTypeVariables(const Type& t) {
-  FreeVar fv;
-  fv.VisitType(t);
-  return tvm::Array<TypeVar>(fv.free_types.begin(), fv.free_types.end());
+tvm::Array<Var> FreeVars(const Expr& expr) {
+  return FreeVarVisitor().Find(expr);
 }

 TVM_REGISTER_API("relay._ir_pass.free_vars")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = FreeVariables(args[0]);
+    *ret = FreeVars(args[0]);
   });

 TVM_REGISTER_API("relay._ir_pass.free_type_vars")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     NodeRef x = args[0];
     if (x.as<TypeNode>()) {
-      *ret = FreeTypeVariables(Downcast<Type>(x));
+      *ret = FreeTypeVars(Downcast<Type>(x));
     } else {
-      *ret = FreeTypeVariables(Downcast<Expr>(x));
+      *ret = FreeTypeVars(Downcast<Expr>(x));
     }
   });
diff --git a/src/relay/pass/well_formed.cc b/src/relay/pass/well_formed.cc
index a37969f9e317579295de569015eed2e48da2c1cf..d9c6b617ca5f4ce303e76a18b0fd8659ef163c8a 100644
--- a/src/relay/pass/well_formed.cc
+++ b/src/relay/pass/well_formed.cc
@@ -10,7 +10,6 @@
 namespace tvm {
 namespace relay {

-struct NotWellFormed { };

 //! brief make sure each Var is bind at most once.
 class WellFormedChecker : private ExprVisitor {
diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py
index 29814ecc5eb7729b80fe4261db1339000a0e74a0..69ba4797a1c76b1a8ca5fa6206f4bfe633a49fc6 100644
--- a/tests/python/relay/test_ir_text_printer.py
+++ b/tests/python/relay/test_ir_text_printer.py
@@ -1,7 +1,9 @@
 import tvm
+import tvm.relay.testing
 import numpy as np
 from tvm import relay

+
 do_print = [False]

 def show(text):
@@ -94,9 +96,18 @@ def test_variable_name():
     v1 = relay.var("1")
     assert "%v1" in v1.astext()

+def test_mlp():
+    net, params = tvm.relay.testing.mlp.get_workload(batch_size=1)
+    net.astext()
+
+def test_resnet():
+    net, params = tvm.relay.testing.resnet.get_workload(batch_size=1)
+    net.astext()

 if __name__ == "__main__":
     do_print[0] = True
+    test_resnet()
+    test_mlp()
     test_func()
     test_env()
     test_meta_data()
diff --git a/tests/python/relay/test_ir_well_formed.py b/tests/python/relay/test_ir_well_formed.py
index 7ccc96d271ac393a72d246b95504195fa947beb4..725b2fbd3c3d0110bde51fa06df633a3f1e8c2b1 100644
--- a/tests/python/relay/test_ir_well_formed.py
+++ b/tests/python/relay/test_ir_well_formed.py
@@ -12,10 +12,9 @@ def test_well_formed():
     assert not well_formed(relay.Let(x, v, let))
     f = relay.Function([x], x, ty)
     assert well_formed(f)
-    # this test should pass in case of weak uniqueness (only test for shadowing)
-    # but we want all binder to be distinct from each other.
-    assert not well_formed(relay.Let(relay.Var("y"), f,
-                                     relay.Let(relay.Var("z"), f, v)))
+    assert well_formed(
+        relay.Let(relay.Var("y"), f,
+                  relay.Let(relay.Var("z"), f, v)))


 def test_tuple():
@@ -25,7 +24,7 @@ def test_tuple():
     let = relay.Let(x, v, x)
     assert well_formed(let)
     assert well_formed(relay.Tuple([v, v]))
-    assert not well_formed(relay.Tuple([let, let]))
+    assert not well_formed(relay.Tuple([let, relay.Let(x, v, x)]))


 def test_tuple_get_item():
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index 5afae6e872d1b3f42925983169bac0ff873f74ed..fd01dbdde01227bcf28360e7308ec724d58bc033 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -42,6 +42,15 @@ def test_binary_op():
         check_binary_op(opfunc)


+def test_bias_add():
+    x = relay.var("x", shape=(10, 2, 3, 4))
+    bias = relay.var("bias")
+    z = relay.nn.bias_add(x, bias)
+    zz = relay.ir_pass.infer_type(z)
+    assert "axis=" not in zz.astext()
+    assert zz.args[1].checked_type == relay.TensorType((2,))
+
+
 def test_expand_dims_infer_type():
     n, t, d = tvm.var("n"), tvm.var("t"), 100
     x = relay.var("x", shape=(n, t, d))
@@ -91,7 +100,7 @@ def test_dropout():
     n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d")
     input_ty = relay.TensorType((n, t, d), "float32")
     x = relay.var("x", input_ty)
-    y, _ = relay.nn.dropout(x, rate=0.75)
+    y = relay.nn.dropout(x, rate=0.75)
     assert "rate=" in y.astext()
     yy = relay.ir_pass.infer_type(y)
     assert yy.checked_type == input_ty
@@ -106,7 +115,7 @@ def test_batch_norm():
     moving_var = relay.var("moving_var", relay.TensorType((2,)))
     y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
                             center=False, scale=False)
-    yy = relay.ir_pass.infer_type(y)
+    yy = relay.ir_pass.infer_type(y.astuple())
     assert "center=" in yy.astext()
     assert yy.checked_type == relay.ty.TupleType(tvm.convert([
         relay.TensorType((3, 2, 1), "float32"),
@@ -121,7 +130,7 @@ def test_batch_norm():

     y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
                             axis=0, center=False, scale=False)
-    yy = relay.ir_pass.infer_type(y)
+    yy = relay.ir_pass.infer_type(y.astuple())
     assert yy.checked_type == relay.ty.TupleType(tvm.convert([
         relay.ty.TensorType((3, 2, 1), "float32"),
         relay.ty.TensorType((3,), "float32"),
@@ -136,7 +145,7 @@ def test_batch_norm():
     moving_var = relay.var("moving_var", relay.TensorType((3,)))
     y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
                             axis=-1, center=False, scale=False)
-    yy = relay.ir_pass.infer_type(y)
+    yy = relay.ir_pass.infer_type(y.astuple())
     assert yy.checked_type == relay.ty.TupleType(tvm.convert([
         relay.ty.TensorType((1, 2, 3), "float32"),
         relay.ty.TensorType((3,), "float32"),
@@ -145,6 +154,7 @@


 if __name__ == "__main__":
+    test_bias_add()
     test_unary_op()
     test_binary_op()
     test_expand_dims_infer_type()
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index d1bff29404577a0152f109bc5db0a0970815ef79..8ab3c41c079d72d9f0bff731217fa0a280dd3176 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -27,6 +27,14 @@ def test_unary_identity():
     assert yy.checked_type == relay.TensorType((8, 9, 4), "float32")


+def test_cast():
+    x = relay.var("x", relay.TensorType((8, 9, 4), "float32"))
+    y = x.astype("int32")
+    yy = relay.ir_pass.infer_type(y)
+    assert "dtype=" in yy.astext()
+    assert yy.checked_type == relay.TensorType((8, 9, 4), "int32")
+
+
 def test_clip_type():
     a = relay.var("a", relay.TensorType((10, 4), "float32"))
     y = relay.clip(a, 1., 4.)
@@ -139,7 +147,9 @@ def test_infer_type_leaky_relu():
     yy = relay.ir_pass.infer_type(y)
     assert yy.checked_type == relay.TensorType((n, c, h, w), "float32")

+
 if __name__ == "__main__":
+    test_cast()
     test_zeros_ones()
     test_unary_identity()
     test_clip_type()