Commit 7bafca4e authored by Siva, committed by Tianqi Chen

[RELAY][OP] Operators: pool2d, global_pool2d, batch_flatten, tanh, sigmoid, floor, ceil, trunc, abs, negative, multiply, mod, pow, resize (#1813)
parent 90159022
Showing changed files with 1233 additions and 9 deletions
...@@ -30,6 +30,13 @@ This level enables fully connected multi-layer perceptron.
tvm.relay.expand_dims
tvm.relay.concatenate
tvm.relay.nn.softmax
tvm.relay.subtract
tvm.relay.multiply
tvm.relay.divide
tvm.relay.mod
tvm.relay.tanh
tvm.relay.sigmoid
**Level 2: Convolutions**
...@@ -39,10 +46,18 @@ This level enables typical convnet models.
:nosignatures:
tvm.relay.nn.conv2d
tvm.relay.nn.max_pool2d
tvm.relay.nn.avg_pool2d
tvm.relay.nn.global_max_pool2d
tvm.relay.nn.global_avg_pool2d
tvm.relay.nn.upsampling
tvm.relay.nn.batch_flatten
**Level 3: Additional Math And Transform Operators**
This level enables additional math and transform operators.
.. autosummary::
:nosignatures:
...@@ -51,6 +66,13 @@ This level enables typical convnet models.
tvm.relay.reshape
tvm.relay.copy
tvm.relay.transpose
tvm.relay.floor
tvm.relay.ceil
tvm.relay.trunc
tvm.relay.round
tvm.relay.abs
tvm.relay.negative
**Level 4: Broadcast and Reductions**
...@@ -67,9 +89,15 @@ This level enables typical convnet models.
tvm.relay.less_equal
tvm.relay.maximum
tvm.relay.minimum
tvm.relay.pow
**Level 5: Vision/Image Operators**
.. autosummary::
:nosignatures:
tvm.relay.image.resize
Level 1 Definitions
-------------------
...@@ -78,12 +106,38 @@ Level 1 Definitions
.. autofunction:: tvm.relay.exp
.. autofunction:: tvm.relay.sigmoid
.. autofunction:: tvm.relay.add
.. autofunction:: tvm.relay.subtract
.. autofunction:: tvm.relay.multiply
.. autofunction:: tvm.relay.divide
.. autofunction:: tvm.relay.mod
.. autofunction:: tvm.relay.tanh
.. autofunction:: tvm.relay.concatenate
.. autofunction:: tvm.relay.nn.softmax
Level 2 Definitions
-------------------
.. autofunction:: tvm.relay.nn.conv2d
.. autofunction:: tvm.relay.nn.max_pool2d
.. autofunction:: tvm.relay.nn.avg_pool2d
.. autofunction:: tvm.relay.nn.global_max_pool2d
.. autofunction:: tvm.relay.nn.global_avg_pool2d
.. autofunction:: tvm.relay.nn.upsampling
.. autofunction:: tvm.relay.nn.batch_flatten
Level 3 Definitions
-------------------
.. autofunction:: tvm.relay.floor
.. autofunction:: tvm.relay.ceil
.. autofunction:: tvm.relay.trunc
.. autofunction:: tvm.relay.round
.. autofunction:: tvm.relay.abs
.. autofunction:: tvm.relay.negative
.. autofunction:: tvm.relay.reshape
.. autofunction:: tvm.relay.copy
.. autofunction:: tvm.relay.transpose
Level 4 Definitions
-------------------
...@@ -97,3 +151,8 @@ Level 4 Definitions
.. autofunction:: tvm.relay.less_equal
.. autofunction:: tvm.relay.maximum
.. autofunction:: tvm.relay.minimum
.. autofunction:: tvm.relay.pow
Level 5 Definitions
-------------------
.. autofunction:: tvm.relay.image.resize
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/attrs/image.h
* \brief Auxiliary attributes for image operators.
*/
#ifndef TVM_RELAY_ATTRS_IMAGE_H_
#define TVM_RELAY_ATTRS_IMAGE_H_
#include <tvm/attrs.h>
#include <string>
namespace tvm {
namespace relay {
/*! \brief Attributes used in image resize operator */
struct ResizeAttrs : public tvm::AttrsNode<ResizeAttrs> {
Array<IndexExpr> size;
std::string layout;
std::string method;
bool align_corners;
TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") {
TVM_ATTR_FIELD(size).set_default(NullValue<Array<IndexExpr> >())
.describe("Output Size.");
TVM_ATTR_FIELD(layout).set_default("NCHW")
.describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Resize is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(method).set_default("BILINEAR")
.describe("Specify the mode to use for scaling."
"NEAREST_NEIGHBOR - Nearest Neighbor"
"BILINEAR - Bilinear Interpolation");
TVM_ATTR_FIELD(align_corners).set_default(false)
.describe("Should be true to preserve the values at the corner pixels");
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_IMAGE_H_
...@@ -77,6 +77,102 @@ struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
}
};
/*! \brief Attributes for max pool operator */
struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
Array<IndexExpr> pool_size;
Array<IndexExpr> strides;
Array<IndexExpr> padding;
std::string layout;
bool ceil_mode;
TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relay.attrs.MaxPool2DAttrs") {
TVM_ATTR_FIELD(pool_size)
.describe("Size of the pooling windows.");
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"Padding support both symmetric and asymmetric as"
"one int : same padding used on all sides"
"two int : bottom, right will use same padding as top, left"
"four int : padding width in the order of (top, left, bottom, right)");
TVM_ATTR_FIELD(layout).set_default("NCHW")
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(ceil_mode).set_default(false)
.describe("When true, will use ceil instead of floor to compute the output shape.");
}
};
/*! \brief Attributes for avg pool operator */
struct AvgPool2DAttrs : public tvm::AttrsNode<AvgPool2DAttrs> {
Array<IndexExpr> pool_size;
Array<IndexExpr> strides;
Array<IndexExpr> padding;
std::string layout;
bool ceil_mode;
bool count_include_pad;
TVM_DECLARE_ATTRS(AvgPool2DAttrs, "relay.attrs.AvgPool2DAttrs") {
TVM_ATTR_FIELD(pool_size)
.describe("Size of the pooling windows.");
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"Padding support both symmetric and asymmetric as"
"one int : same padding used on all sides"
"two int : bottom, right will use same padding as top, left"
"four int : padding width in the order of (top, left, bottom, right)");
TVM_ATTR_FIELD(layout).set_default("NCHW")
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(ceil_mode).set_default(false)
.describe("When true, will use ceil instead of floor to compute the output shape.");
TVM_ATTR_FIELD(count_include_pad).set_default(false)
.describe("When true, will include padding to compute the average");
}
};
/*! \brief Attributes for global pool operator */
struct GlobalPool2DAttrs : public tvm::AttrsNode<GlobalPool2DAttrs> {
std::string layout;
TVM_DECLARE_ATTRS(GlobalPool2DAttrs, "relay.attrs.GlobalPool2DAttrs") {
TVM_ATTR_FIELD(layout).set_default("NCHW")
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions.");
}
};
/*! \brief Attributes for upsampling operator */
struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
int scale;
std::string layout;
std::string method;
TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") {
TVM_ATTR_FIELD(scale)
.describe("Should be true to preserve the values at the corner pixels");
TVM_ATTR_FIELD(layout).set_default("NCHW")
.describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Upsampling is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(method).set_default("NEAREST_NEIGHBOR")
.describe("Specify the mode to use for scaling."
"NEAREST_NEIGHBOR - Nearest Neighbor"
"BILINEAR - Bilinear Interpolation");
}
};
} // namespace relay
} // namespace tvm
#endif  // TVM_RELAY_ATTRS_NN_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/attrs/vision.h
* \brief Auxiliary attributes for vision operators.
*/
#ifndef TVM_RELAY_ATTRS_VISION_H_
#define TVM_RELAY_ATTRS_VISION_H_
#include <tvm/attrs.h>
#include <string>
namespace tvm {
namespace relay {
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_VISION_H_
# pylint: disable=wildcard-import, redefined-builtin
"""The Relay IR namespace containing the IR definition and compiler."""
from . import base
from . import ty
...@@ -10,8 +10,10 @@ from . import ir_builder
# Root operators
from .op import Op
from .op.tensor import *
from .op.transform import *
from . import nn
from . import vision
from . import image
# Span
Span = base.Span
...
# pylint: disable=wildcard-import, unused-import, unused-wildcard-import
"""Image nets related operators."""
# Re-export in a specific file name so that autodoc can pick it up
from .op.image import *
#pylint: disable=wildcard-import, redefined-builtin
"""Relay core operators."""
# operator defs
from .op import get, register, Op
# Operators
from .tensor import *
from .transform import *
from . import nn
from . import image
from . import vision
# operator registry
from . import _tensor
...
# pylint: disable=wildcard-import
"""Image network related operators."""
from __future__ import absolute_import as _abs
from .image import *
"""Constructor APIs"""
from ...._ffi.function import _init_api
_init_api("relay.op.image._make", __name__)
"""Image operations."""
from __future__ import absolute_import as _abs
from . import _make
def resize(data,
size,
layout="NCHW",
method="BILINEAR",
align_corners=False):
"""Image resize operator.
This operator takes data as input and resizes the H and W dimensions
to the given output size.
In the default case, where the data_layout is `NCHW`
with data of shape (n, c, h, w),
out will have a shape (n, c, size[0], size[1]).
method indicates the algorithm to be used while computing the out value
and can be one of ("BILINEAR", "NEAREST_NEIGHBOR")
Parameters
----------
data : relay.Expr
The input data to the operator.
size : tuple of int
The output size to which the image will be resized.
layout : str, optional
Layout of the input.
method : str, optional
Scale method to be used [NEAREST_NEIGHBOR, BILINEAR].
align_corners : bool, optional
Should be true to preserve the values at the corner pixels.
Returns
-------
result: relay.Expr
The resized result.
"""
return _make.resize(data, size, layout, method, align_corners)
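As a quick usage sketch (hedged: relay.var and the keyword-argument style shown follow later TVM releases and are assumptions here; the shapes are hypothetical), the new operator composes like any other Relay call:

from tvm import relay

# hypothetical 1x3x100x100 NCHW input; relay.var is assumed from later TVM APIs
x = relay.var("x", shape=(1, 3, 100, 100), dtype="float32")
# resize H and W to (224, 224); BILINEAR is the registered default method
y = relay.image.resize(x, size=(224, 224), layout="NCHW", method="BILINEAR")
# per ResizeRel in resize.cc below, y has type Tensor[(1, 3, 224, 224), float32]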
...@@ -106,3 +106,239 @@ def softmax(data, axis):
"""
return _make.softmax(data, axis)
def max_pool2d(data,
pool_size=(1, 1),
strides=(1, 1),
padding=(0, 0),
layout="NCHW",
ceil_mode=False):
r"""2D maximum pooling operator.
This operator takes data as input and does 2D max value calculation
within a pool_size sized window, moved by the given strides.
In the default case, where the data_layout is `NCHW`,
a data Tensor with shape `(batch_size, in_channels, height, width)`
produces an output Tensor with the following rule:
with data of shape (b, c, h, w) and pool_size (kh, kw)
.. math::
\mbox{out}(b, c, y, x) = \max_{m=0, \ldots, kh-1} \max_{n=0, \ldots, kw-1}
\mbox{data}(b, c, \mbox{stride}[0] * y + m, \mbox{stride}[1] * x + n)
Padding is applied to data before the computation.
ceil_mode is used to take ceil or floor while computing out shape.
This operator accepts data layout specification.
Parameters
----------
data : relay.Expr
The input data to the operator.
pool_size : tuple of int, optional
The size of the pooling window.
strides : tuple of int, optional
The strides of pooling.
padding : tuple of int, optional
The padding for pooling.
layout : str, optional
Layout of the input.
ceil_mode : bool, optional
To enable or disable ceil while pooling.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.max_pool2d(data, pool_size, strides, padding,
layout, ceil_mode)
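A minimal sketch of the shape rule in practice (relay.var is assumed from later TVM releases; the shape is hypothetical):

from tvm import relay

x = relay.var("x", shape=(1, 16, 32, 32), dtype="float32")  # assumed helper
# 2x2 window, stride 2, no padding: out H/W = floor((32 + 0 - 2) / 2) + 1 = 16
y = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2))
# y : Tensor[(1, 16, 16, 16), float32]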
def avg_pool2d(data,
pool_size=(1, 1),
strides=(1, 1),
padding=(0, 0),
layout="NCHW",
ceil_mode=False,
count_include_pad=False):
r"""2D average pooling operator.
This operator takes data as input and does 2D average value calculation
within a pool_size sized window, moved by the given strides.
In the default case, where the data_layout is `NCHW`,
a data Tensor with shape `(batch_size, in_channels, height, width)`
produces an output Tensor with the following rule:
with data of shape (b, c, h, w) and pool_size (kh, kw)
.. math::
\mbox{out}(b, c, y, x) = \frac{1}{kh * kw} \sum_{m=0}^{kh-1} \sum_{n=0}^{kw-1}
\mbox{data}(b, c, \mbox{stride}[0] * y + m, \mbox{stride}[1] * x + n)
Padding is applied to data before the computation.
ceil_mode is used to take ceil or floor while computing out shape.
count_include_pad indicates including or excluding padded input values in computation.
This operator accepts data layout specification.
Parameters
----------
data : relay.Expr
The input data to the operator.
pool_size : tuple of int, optional
The size of the pooling window.
strides : tuple of int, optional
The strides of pooling.
padding : tuple of int, optional
The padding for pooling.
layout : str, optional
Layout of the input.
ceil_mode : bool, optional
To enable or disable ceil while pooling.
count_include_pad : bool, optional
To include padding to compute the average.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.avg_pool2d(data, pool_size, strides, padding,
layout, ceil_mode, count_include_pad)
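A hedged sketch of what count_include_pad changes (relay.var assumed from later TVM releases; shapes hypothetical):

from tvm import relay

x = relay.var("x", shape=(1, 8, 7, 7), dtype="float32")  # assumed helper
# padding=(1, 1) keeps the 7x7 extent: floor((7 + 2 - 3) / 1) + 1 = 7.
# With count_include_pad=False, border windows divide by the number of
# real input values they cover; with True they always divide by kh*kw (here 9).
y = relay.nn.avg_pool2d(x, pool_size=(3, 3), strides=(1, 1),
                        padding=(1, 1), count_include_pad=False)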
def global_max_pool2d(data,
layout="NCHW"):
r"""2D global maximum pooling operator.
This operator takes data as input and does 2D max value calculation
across the entire height and width dimensions.
In the default case, where the data_layout is `NCHW`,
a data Tensor with shape `(batch_size, in_channels, height, width)`
produces an output Tensor with the following rule:
with data of shape (b, c, h, w)
.. math::
\mbox{out}(b, c, 1, 1) = \max_{m=0, \ldots, h} \max_{n=0, \ldots, w}
\mbox{data}(b, c, m, n)
Parameters
----------
data : relay.Expr
The input data to the operator.
layout : str, optional
Layout of the input.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.global_max_pool2d(data, layout)
def global_avg_pool2d(data,
layout="NCHW"):
r"""2D global average pooling operator.
This operator takes data as input and does 2D average value calculation
across the entire height and width dimensions.
In the default case, where the data_layout is `NCHW`,
a data Tensor with shape `(batch_size, in_channels, height, width)`
produces an output Tensor with the following rule:
with data of shape (b, c, h, w)
.. math::
\mbox{out}(b, c, 1, 1) = \frac{1}{h * w} \sum_{m=0}^{h-1} \sum_{n=0}^{w-1}
\mbox{data}(b, c, m, n)
Parameters
----------
data : relay.Expr
The input data to the operator.
layout : str, optional
Layout of the input.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.global_avg_pool2d(data, layout)
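Both global pools collapse H and W to 1 whatever their extent; a small sketch (relay.var assumed from later TVM releases):

from tvm import relay

x = relay.var("x", shape=(4, 64, 28, 28), dtype="float32")  # assumed helper
m = relay.nn.global_max_pool2d(x)  # Tensor[(4, 64, 1, 1), float32]
a = relay.nn.global_avg_pool2d(x)  # Tensor[(4, 64, 1, 1), float32]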
def upsampling(data,
scale=1,
layout="NCHW",
method="NEAREST_NEIGHBOR"):
"""Upsampling.
This operator takes data as input and does 2D scaling to the given scale factor.
In the default case, where the data_layout is `NCHW`
with data of shape (n, c, h, w),
out will have a shape (n, c, h*scale, w*scale).
method indicates the algorithm to be used while computing the out value
and can be one of ("BILINEAR", "NEAREST_NEIGHBOR")
Parameters
----------
data : relay.Expr
The input data to the operator.
scale : int
The scale factor for upsampling.
layout : str, optional
Layout of the input.
method : str, optional
Scale method to be used [NEAREST_NEIGHBOR, BILINEAR].
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.upsampling(data, scale, layout, method)
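A usage sketch (relay.var assumed from later TVM releases; the shape is hypothetical):

from tvm import relay

x = relay.var("x", shape=(1, 3, 32, 32), dtype="float32")  # assumed helper
# scale=2 doubles H and W: output type is Tensor[(1, 3, 64, 64), float32]
y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="NEAREST_NEIGHBOR")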
def batch_flatten(data):
"""BatchFlatten.
This operator flattens all the dimensions except for the batch dimension,
which results in a 2D output.
For data with shape ``(d1, d2, ..., dk)``,
batch_flatten(data) returns the reshaped output of shape ``(d1, d2*...*dk)``.
Parameters
----------
data : relay.Expr
The input data to the operator.
Returns
-------
result: relay.Expr
The flattened result.
"""
return _make.batch_flatten(data)
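For example (relay.var assumed from later TVM releases):

from tvm import relay

x = relay.var("x", shape=(2, 3, 4, 5), dtype="float32")  # assumed helper
# all non-batch dims fold together: (2, 3*4*5) == (2, 60)
y = relay.nn.batch_flatten(x)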
"""Basic tensor operations.""" """Basic tensor operations."""
# pylint: disable=redefined-builtin
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from . import _make from . import _make
from ..expr import Tuple from ..expr import Tuple
...@@ -59,7 +60,6 @@ def sqrt(data): ...@@ -59,7 +60,6 @@ def sqrt(data):
""" """
return _make.sqrt(data) return _make.sqrt(data)
def sigmoid(data): def sigmoid(data):
"""Compute elementwise sigmoid of data. """Compute elementwise sigmoid of data.
...@@ -76,6 +76,118 @@ def sigmoid(data): ...@@ -76,6 +76,118 @@ def sigmoid(data):
return _make.sigmoid(data) return _make.sigmoid(data)
def floor(data):
"""Compute element-wise floor of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.floor(data)
def ceil(data):
"""Compute element-wise ceil of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.ceil(data)
def trunc(data):
"""Compute element-wise trunc of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.trunc(data)
def round(data):
"""Compute element-wise round of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.round(data)
def abs(data):
"""Compute element-wise absolute of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.abs(data)
def tanh(data):
"""Compute element-wise tanh of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.tanh(data)
def negative(data):
"""Compute element-wise negative of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.negative(data)
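The unary ops above compose like ordinary expressions; a short sketch (relay.var assumed from later TVM releases):

from tvm import relay

x = relay.var("x", shape=(10,), dtype="float32")  # assumed helper
# each call is elementwise and preserves the input shape and dtype
y = relay.negative(relay.abs(relay.floor(x)))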
def add(lhs, rhs):
"""Addition with numpy-style broadcasting.
...@@ -102,8 +214,80 @@ def add(lhs, rhs):
return _make.add(lhs, rhs)
def multiply(lhs, rhs):
"""Multiplication with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.multiply(lhs, rhs)
def divide(lhs, rhs):
"""Division with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.divide(lhs, rhs)
def pow(lhs, rhs):
"""Power with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.pow(lhs, rhs)
def mod(lhs, rhs):
"""Mod with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.mod(lhs, rhs)
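All of these binary ops share numpy-style broadcasting; a hedged sketch (relay.var assumed from later TVM releases; shapes hypothetical):

from tvm import relay

# (2, 1, 3) op (4, 3) broadcasts to (2, 4, 3), as in numpy
a = relay.var("a", shape=(2, 1, 3), dtype="float32")  # assumed helper
b = relay.var("b", shape=(4, 3), dtype="float32")
c = relay.multiply(a, b)
d = relay.mod(relay.pow(c, b), relay.divide(c, b))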
def subtract(lhs, rhs):
"""Subtraction with numpy-style broadcasting.
Parameters
----------
...
# pylint: disable=wildcard-import
"""Vision network related operators."""
from __future__ import absolute_import as _abs
"""Constructor APIs"""
from ...._ffi.function import _init_api
_init_api("relay.op.vision._make", __name__)
# pylint: disable=wildcard-import, unused-import, unused-wildcard-import
"""Vision network related operators."""
# Re-export in a specific file name so that autodoc can pick it up
from .op.vision import *
/*!
* Copyright (c) 2018 by Contributors
* \file resize.cc
* \brief Image operators
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/image.h>
#include "../nn/layout.h"
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(ResizeAttrs);
bool ResizeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
static const Layout kNCHW("NCHW");
const ResizeAttrs* param = attrs.as<ResizeAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->layout);
CHECK(in_layout.convertible(kNCHW))
<< "Resize only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
auto oshape = ConvertLayout(data->shape, in_layout, kNCHW);
oshape[2] = param->size[0];
oshape[3] = param->size[1];
// assign output type
reporter->Assign(types[1],
TensorTypeNode::make(ConvertLayout(oshape, kNCHW, in_layout),
data->dtype));
return true;
}
// Positional relay function to create image operator
// used by frontend FFI.
Expr MakeResize(Expr data,
Array<IndexExpr> size,
std::string layout,
std::string method,
bool align_corners) {
auto attrs = make_node<ResizeAttrs>();
attrs->size = std::move(size);
attrs->layout = std::move(layout);
attrs->method = std::move(method);
attrs->align_corners = align_corners;
static const Op& op = Op::Get("image.resize");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.image._make.resize")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 5>(MakeResize, args, rv);
});
RELAY_REGISTER_OP("image.resize")
.describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation.
- **data**: data is 4D array of shape
(batch_size, channels, in_height, in_width) for NCHW
(batch_size, in_height, in_width, channels) for NHWC
- **out**: Output is 4D array of shape
for layout NCHW
(batch_size, channels, size[0], size[1])
for layout NHWC
(batch_size, size[0], size[1], channels)
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(5)
.add_type_rel("Resize", ResizeRel);
} // namespace relay
} // namespace tvm
...@@ -6,12 +6,14 @@
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/image.h>
#include <vector>
#include "../type_relations.h"
#include "layout.h"
namespace tvm {
namespace relay {
TVM_REGISTER_API("relay.op.nn._make.softmax")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
auto make_func = [](Expr data, int axis) {
...@@ -39,5 +41,67 @@ RELAY_REGISTER_OP("nn.softmax")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
// BatchFlatten
bool BatchFlattenRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
if (data->shape.size() == 0) return false;
auto target_dim = make_const(Int(32), 1);
for (uint32_t i = 1; i < data->shape.size(); ++i) {
target_dim = target_dim * data->shape[i];
}
std::vector<IndexExpr> oshape({data->shape[0], target_dim});
// assign output type
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
return true;
}
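A plain-Python mirror of this shape rule, for illustration only (not part of this C++ file):

from functools import reduce

def batch_flatten_shape(shape):
    # keep dim 0; multiply the remaining dims together, as BatchFlattenRel does
    return (shape[0], reduce(lambda a, b: a * b, shape[1:], 1))

assert batch_flatten_shape((2, 3, 4, 5)) == (2, 60)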
Expr MakeBatchFlatten(Expr data) {
static const Op& op = Op::Get("nn.batch_flatten");
return CallNode::make(op, {data}, Attrs(), {});
}
TVM_REGISTER_API("relay.op.nn._make.batch_flatten")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 1>(MakeBatchFlatten, args, rv);
});
RELAY_REGISTER_OP("nn.batch_flatten")
.describe(R"code(Flattens the input into a 2-D array.
For an input array with shape ``(d1, d2, ..., dk)``, `batch_flatten` operation reshapes
the input array into an output array of shape ``(d1, d2*...*dk)``.
Example::
x = [[
[1,2,3],
[4,5,6],
[7,8,9]
],
[ [1,2,3],
[4,5,6],
[7,8,9]
]],
batch_flatten(x) = [[ 1., 2., 3., 4., 5., 6., 7., 8., 9.],
[ 1., 2., 3., 4., 5., 6., 7., 8., 9.]]
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("BatchFlatten", BatchFlattenRel);
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file pooling.cc
* \brief Pooling operators
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <vector>
#include "layout.h"
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs);
template <typename AttrType>
bool Pool2DRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
CHECK(data != nullptr);
const auto dshape = data->shape;
CHECK_NE(dshape.size(), 0);
CHECK_GE(dshape.size(), 2U)
<< "Pool2D only support input >= 2-D: input must have height and width";
const auto param = attrs.as<AttrType>();
CHECK(param != nullptr);
Layout layout(param->layout);
CHECK(layout.contains('H') && layout.contains('W') &&
!layout.contains('h') && !layout.contains('w'))
<< "Invalid layout " << layout
<< ". Pool2D layout must have H and W, which cannot be split";
const auto hidx = layout.indexof('H');
const auto widx = layout.indexof('W');
IndexExpr pad_h, pad_w;
if (param->padding.size() == 1) {
pad_h = param->padding[0] * 2;
pad_w = param->padding[0] * 2;
} else if (param->padding.size() == 2) {
// (top, left)
pad_h = param->padding[0] * 2;
pad_w = param->padding[1] * 2;
} else if (param->padding.size() == 4) {
// (top, left, bottom, right)
pad_h = param->padding[0] + param->padding[2];
pad_w = param->padding[1] + param->padding[3];
} else {
return false;
}
std::vector<IndexExpr> oshape({dshape[0], dshape[1], dshape[2], dshape[3]});
if (param->ceil_mode) {
oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] +
param->strides[0] - 1) / param->strides[0]) + 1;
oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1] +
param->strides[1] - 1) / param->strides[1]) + 1;
} else {
oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0]) / param->strides[0]) + 1;
oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1]) / param->strides[1]) + 1;
}
// assign output type
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
return true;
}
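A plain-Python mirror of the output-dimension arithmetic above, illustrative only (not part of this C++ file):

def pool2d_out_dim(in_dim, pad_lo, pad_hi, pool, stride, ceil_mode):
    padded = in_dim + pad_lo + pad_hi - pool
    if ceil_mode:
        # adding stride-1 before the integer division rounds the quotient up
        return (padded + stride - 1) // stride + 1
    return padded // stride + 1

assert pool2d_out_dim(32, 0, 0, 2, 2, False) == 16
assert pool2d_out_dim(7, 1, 1, 3, 1, False) == 7
assert pool2d_out_dim(5, 0, 0, 2, 2, True) == 3  # ceil picks up the partial window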
// MaxPool2D
Expr MakeMaxPool2D(Expr data,
Array<IndexExpr> pool_size,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
std::string layout,
bool ceil_mode) {
auto attrs = make_node<MaxPool2DAttrs>();
attrs->pool_size = std::move(pool_size);
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->layout = std::move(layout);
attrs->ceil_mode = ceil_mode;
static const Op& op = Op::Get("nn.max_pool2d");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.max_pool2d")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 6>(MakeMaxPool2D, args, rv);
});
RELAY_REGISTER_OP("nn.max_pool2d")
.describe(R"code(Max pooling operation for two dimensional data.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
(batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
out_height and out_width are calculated as::
out_height = floor((height+padding[0]+padding[2]-pool_size[0])/strides[0])+1
out_width = floor((width+padding[1]+padding[3]-pool_size[1])/strides[1])+1
where padding will be an expanded array based on number of values passed as::
one int : all sides same padding used.
two int : bottom, right use same as top and left.
four int: padding width in the order of (top, left, bottom, right).
When `ceil_mode` is `True`, ceil will be used instead of floor in this
equation.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("MaxPool2D", Pool2DRel<MaxPool2DAttrs>);
// AvgPool2D
Expr MakeAvgPool2D(Expr data,
Array<IndexExpr> pool_size,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
std::string layout,
bool ceil_mode,
bool count_include_pad) {
auto attrs = make_node<AvgPool2DAttrs>();
attrs->pool_size = std::move(pool_size);
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->layout = std::move(layout);
attrs->ceil_mode = ceil_mode;
attrs->count_include_pad = count_include_pad;
static const Op& op = Op::Get("nn.avg_pool2d");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.avg_pool2d")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 7>(MakeAvgPool2D, args, rv);
});
RELAY_REGISTER_OP("nn.avg_pool2d")
.describe(R"code(
Average pooling operation for one dimensional data.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
(batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
out_height and out_width are calculated as::
out_height = floor((height+padding[0]+padding[2]-pool_size[0])/strides[0])+1
out_width = floor((width+padding[1]+padding[3]-pool_size[1])/strides[1])+1
where padding will be an expanded array based on number of values passed as::
one int : all sides same padding used.
two int : bottom, right use same as top and left.
four int: padding width in the order of (top, left, bottom, right).
When `ceil_mode` is `True`, ceil will be used instead of floor in this
equation.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("AvgPool2D", Pool2DRel<AvgPool2DAttrs>);
// Global Pool
TVM_REGISTER_NODE_TYPE(GlobalPool2DAttrs);
bool GlobalPool2DRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
CHECK(data != nullptr);
const auto dshape = data->shape;
CHECK_NE(dshape.size(), 0);
CHECK_GE(dshape.size(), 2U)
<< "Pool2D only support input >= 2-D: input must have height and width";
const auto param = attrs.as<GlobalPool2DAttrs>();
CHECK(param != nullptr);
Layout layout(param->layout);
CHECK(layout.contains('H') && layout.contains('W') &&
!layout.contains('h') && !layout.contains('w'))
<< "Invalid layout " << layout
<< ". Pool2D layout must have H and W, which cannot be split";
const auto hidx = layout.indexof('H');
const auto widx = layout.indexof('W');
std::vector<IndexExpr> oshape({dshape[0], dshape[1], dshape[2], dshape[3]});
oshape[hidx] = oshape[widx] = 1;
// assign output type
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
return true;
}
Expr MakeGlobalAvgPool2D(Expr data,
std::string layout) {
auto attrs = make_node<GlobalPool2DAttrs>();
attrs->layout = std::move(layout);
static const Op& op = Op::Get("nn.global_avg_pool2d");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.global_avg_pool2d")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeGlobalAvgPool2D, args, rv);
});
// GlobalAvgPool
RELAY_REGISTER_OP("nn.global_avg_pool2d")
.describe(R"code(Global average pooling operation for 2D data.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
(batch_size, channels, 1, 1) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("GlobalAvgPool2D", GlobalPool2DRel);
// GlobalMaxPool
Expr MakeGlobalMaxPool2D(Expr data,
std::string layout) {
auto attrs = make_node<GlobalPool2DAttrs>();
attrs->layout = std::move(layout);
static const Op& op = Op::Get("nn.global_max_pool2d");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.global_max_pool2d")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeGlobalMaxPool2D, args, rv);
});
RELAY_REGISTER_OP("nn.global_max_pool2d")
.describe(R"code(Global max pooling operation for 2D data.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
(batch_size, channels, 1, 1) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("GlobalMaxPool2D", GlobalPool2DRel);
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file upsampling.cc
* \brief upsampling operator
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include "layout.h"
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(UpSamplingAttrs);
bool UpSamplingRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
static const Layout kNCHW("NCHW");
const UpSamplingAttrs* param = attrs.as<UpSamplingAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->layout);
CHECK(in_layout.convertible(kNCHW))
<< "UpSampling only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
auto oshape = ConvertLayout(data->shape, in_layout, kNCHW);
oshape[2] = oshape[2] * param->scale;
oshape[3] = oshape[3] * param->scale;
// assign output type
reporter->Assign(types[1],
TensorTypeNode::make(ConvertLayout(oshape, kNCHW, in_layout),
data->dtype));
return true;
}
// Positional relay function to create upsampling operator
// used by frontend FFI.
Expr MakeUpSampling(Expr data,
int scale,
std::string layout,
std::string method) {
auto attrs = make_node<UpSamplingAttrs>();
attrs->layout = std::move(layout);
attrs->method = std::move(method);
attrs->scale = scale;
static const Op& op = Op::Get("nn.upsampling");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.upsampling")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 4>(MakeUpSampling, args, rv);
});
RELAY_REGISTER_OP("nn.upsampling")
.describe(R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation.
- **data**: data is 4D array of shape
(batch_size, channels, in_height, in_width) for NCHW
(batch_size, in_height, in_width, channels) for NHWC
- **out**: Output is 4D array of shape
for layout NCHW
(batch_size, channels, in_height*scale, in_width*scale)
for layout NHWC
(batch_size, in_height*scale, in_width*scale, channels)
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("UpSampling", UpSamplingRel);
} // namespace relay
} // namespace tvm
...@@ -22,7 +22,6 @@ namespace relay {
.add_argument("rhs", "Tensor", "The right hand side tensor.") \
.add_type_rel("Broadcast", BroadcastRel)
RELAY_REGISTER_BINARY_OP("add")
.describe("Elementwise add with broadcasting")
.set_support_level(1);
...@@ -49,6 +48,22 @@ RELAY_REGISTER_BINARY_OP("minimum")
.describe("Elementwise minimum of two tensors with broadcasting")
.set_support_level(4);
RELAY_REGISTER_BINARY_OP("divide")
.describe("Elementwise divide with broadcasting")
.set_support_level(1);
RELAY_REGISTER_BINARY_OP("multiply")
.describe("Elementwise multiply with broadcasting")
.set_support_level(1);
RELAY_REGISTER_BINARY_OP("pow")
.describe("Elementwise power with broadcasting")
.set_support_level(4);
RELAY_REGISTER_BINARY_OP("mod")
.describe("Elementwise mod with broadcasting")
.set_support_level(1);
// Comparisons
#define RELAY_REGISTER_CMP_OP(OpName) \
TVM_REGISTER_API("relay.op._make." OpName) \
...