From 020b6398c4c69c4a33c24f2826c9df52c102513d Mon Sep 17 00:00:00 2001
From: Siva <sivar.b@huawei.com>
Date: Thu, 11 Oct 2018 10:26:02 +0530
Subject: [PATCH] [RELAY][OP] conv2d_transpose (#1862)

---
 docs/langref/relay_op.rst            |   2 +
 include/tvm/relay/attrs/nn.h         |  51 +++++++++
 python/tvm/relay/op/nn/nn.py         |  68 +++++++++++-
 src/relay/op/nn/convolution.cc       | 148 +++++++++++++++++++++++++++
 tests/python/relay/test_op_level2.py |  37 +++++++
 5 files changed, 304 insertions(+), 2 deletions(-)

diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index 97f2d4cb9..fe5356557 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -48,6 +48,7 @@ This level enables typical convnet models.
    :nosignatures:
 
    tvm.relay.nn.conv2d
+   tvm.relay.nn.conv2d_transpose
    tvm.relay.nn.max_pool2d
    tvm.relay.nn.avg_pool2d
    tvm.relay.nn.global_max_pool2d
@@ -129,6 +130,7 @@ Level 1 Definitions
 Level 2 Definitions
 -------------------
 .. autofunction:: tvm.relay.nn.conv2d
+.. autofunction:: tvm.relay.nn.conv2d_transpose
 .. autofunction:: tvm.relay.nn.max_pool2d
 .. autofunction:: tvm.relay.nn.avg_pool2d
 .. autofunction:: tvm.relay.nn.global_max_pool2d
diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index ce80407f1..7eb7a8360 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -77,6 +77,57 @@ struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
   }
 };
 
+/*! \brief Attributes used in transposed convolution operator */
+struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
+  IndexExpr channels;
+  Array<IndexExpr> kernel_size;
+  Array<IndexExpr> strides;
+  Array<IndexExpr> padding;
+  Array<IndexExpr> output_padding;
+  Array<IndexExpr> dilation;
+  int groups;
+  std::string data_layout;
+  std::string weight_layout;
+  DataType out_dtype;
+
+  TVM_DECLARE_ATTRS(Conv2DTransposeAttrs, "relay.attrs.Conv2DTransposeAttrs") {
+    TVM_ATTR_FIELD(channels)
+      .set_default(NullValue<IndexExpr>())
+      .describe("The dimensionality of the output space"
+                "i.e. the number of output channels in the convolution.");
+    TVM_ATTR_FIELD(kernel_size)
+      .describe("The dimensions of the convolution window.")
+      .set_default(NullValue<Array<IndexExpr> >());
+    TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
+      .describe("The strides of the convolution.");
+    TVM_ATTR_FIELD(output_padding).set_default(Array<IndexExpr>({0, 0}))
+      .describe("Zero-padding added to one side of the output.");
+    TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
+      .describe("If padding is non-zero, then the input is implicitly zero-padded"
+                "on both sides for padding number of points");
+    TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
+      .describe("Specifies the dilation rate to use for dilated convolution.");
+    TVM_ATTR_FIELD(groups).set_default(1)
+      .describe("Controls the connections between inputs and outputs."
+                "At groups=1, all inputs are convolved to all outputs."
+                "At groups=2, the operation becomes equivalent to having two convolution"
+                "layers side by side, each seeing half the input channels, and producing"
+                "half the output channels, and both subsequently concatenated.");
+    TVM_ATTR_FIELD(data_layout).set_default("NCHW")
+      .describe("Dimension ordering of data. Can be 'NCHW', 'NHWC', etc."
+                "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+                "dimensions respectively. Convolution is applied on the 'H' and"
+                "'W' dimensions.");
+    TVM_ATTR_FIELD(weight_layout).set_default("OIHW")
+      .describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc."
+                "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
+                "dimensions respectively.");
+    TVM_ATTR_FIELD(out_dtype)
+        .set_default(Int(0))
+        .describe("Output data type, set to explicit type under mixed precision setting");
+  }
+};
+
 /*! \brief Attributes for max pool operator */
 struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
   Array<IndexExpr> pool_size;
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 7985d57c9..52414df8e 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -88,6 +88,62 @@ def conv2d(data,
                         weight_layout, out_layout, out_dtype)
 
 
+def conv2d_transpose(data,
+                     weight,
+                     strides=(1, 1),
+                     padding=(0, 0),
+                     dilation=(1, 1),
+                     groups=1,
+                     channels=None,
+                     kernel_size=None,
+                     data_layout="NCHW",
+                     weight_layout="OIHW",
+                     output_padding=(0, 0),
+                     out_dtype=""):
+    """Two dimensional trnasposed convolution operator.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    weight : relay.Expr
+        The weight expressions.
+
+    strides : Tuple[int], optional
+        The strides of convoltution.
+
+    padding : Tuple[int], optional
+        The padding of convolution on both sides of inputs.
+
+    dilation : Tuple[int], optional
+        Specifies the dilation rate to be used for dilated convolution.
+
+    groups : int, optional
+        Number of groups for grouped convolution.
+
+    data_layout : str, optional
+        Layout of the input.
+
+    weight_layout : str, optional
+        Layout of the weight.
+
+    output_padding : Tuple[int], optional
+        Additional zero-padding to be added to one side of the output.
+
+    out_dtype : str, optional
+        Specifies the output data type for mixed precision conv2d.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.conv2d_transpose(data, weight, strides, padding, dilation,
+                                  groups, channels, kernel_size, data_layout,
+                                  weight_layout, output_padding, out_dtype)
+
+
 def softmax(data, axis):
     r"""Computes softmax.
 
@@ -103,8 +159,12 @@ def softmax(data, axis):
 
     axis: int
         The axis to sum over when computing softmax
-    """
 
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
     return _make.softmax(data, axis)
 
 
@@ -125,8 +185,12 @@ def log_softmax(data, axis):
 
     axis: int
         The axis to sum over when computing softmax
-    """
 
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
     return _make.log_softmax(data, axis)
 
 
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index ba4241286..4717e3fe0 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -154,5 +154,153 @@ with the layer input to produce a tensor of outputs.
 .set_support_level(2)
 .add_type_rel("Conv2D", Conv2DRel);
 
+
+// Conv2DTranspose
+TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);
+
+bool Conv2DTransposeRel(const Array<Type>& types,
+                        int num_inputs,
+                        const Attrs& attrs,
+                        const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* weight = types[1].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  static const Layout kNCHW("NCHW");
+  static const Layout kOIHW("OIHW");
+
+  const Conv2DTransposeAttrs* param = attrs.as<Conv2DTransposeAttrs>();
+  CHECK(param != nullptr);
+  const Layout in_layout(param->data_layout);
+  const Layout kernel_layout(param->weight_layout);
+  CHECK(in_layout.convertible(kNCHW))
+    << "Conv only support input layouts that are convertible from NCHW."
+    << " But got " << in_layout;
+  CHECK(kernel_layout.convertible(kOIHW))
+    << "Conv only support kernel layouts that are convertible from OIHW."
+    << " But got "<< kernel_layout;
+
+  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
+  const auto dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW);
+  // infer weight if the kernel_size and channels are defined
+  if (param->kernel_size.defined() && param->channels.defined()) {
+    CHECK_EQ(param->kernel_size.size(), 2);
+    CHECK_EQ(param->dilation.size(), 2);
+
+    std::vector<IndexExpr> wshape({dshape_nchw[1],
+                                   param->channels / param->groups,
+                                   param->kernel_size[0],
+                                   param->kernel_size[1]});
+
+    wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
+    dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
+    dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
+    channels = param->channels;
+
+    // assign result to reporter
+    reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
+  } else {
+    // use weight to infer the conv shape.
+    if (weight == nullptr) return false;
+    auto wshape = ConvertLayout(weight->shape, kernel_layout, kOIHW);
+    if (param->kernel_size.defined()) {
+      CHECK_EQ(param->kernel_size.size(), 2);
+      // check the size
+      CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
+            reporter->AssertEQ(param->kernel_size[1], wshape[3]))
+          << "Conv2D: shape of weight is inconsistent with kernel_size, "
+          << " kernel_size=" << param->kernel_size
+          << " wshape=" << Array<IndexExpr>(wshape);
+    }
+    if (param->channels.defined()) {
+      CHECK(reporter->AssertEQ(param->channels, wshape[1]))
+          << "Conv2D: shape of weight is inconsistent with channels, "
+          << " channels=" << param->channels
+          << " wshape=" << Array<IndexExpr>(wshape);
+    }
+    CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[0]));
+    channels = wshape[1];
+    dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
+    dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
+  }
+  // dilation
+  std::vector<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
+  oshape[2] = (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
+               2 * param->padding[0] + param->output_padding[0]);
+  oshape[3] = (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
+               2 * param->padding[1] + param->output_padding[1]);
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  oshape = ConvertLayout(oshape, kNCHW, in_layout);
+  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+  return true;
+}
+
+
+Expr MakeConv2DTranspose(Expr data,
+                         Expr weight,
+                         Array<IndexExpr> strides,
+                         Array<IndexExpr> padding,
+                         Array<IndexExpr> dilation,
+                         int groups,
+                         IndexExpr channels,
+                         Array<IndexExpr> kernel_size,
+                         std::string data_layout,
+                         std::string weight_layout,
+                         Array<IndexExpr> output_padding,
+                         DataType out_dtype) {
+  auto attrs = make_node<Conv2DTransposeAttrs>();
+  attrs->channels = channels;
+  attrs->kernel_size = kernel_size;
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->output_padding = std::move(output_padding);
+  attrs->dilation = std::move(dilation);
+  attrs->groups = groups;
+  attrs->data_layout = std::move(data_layout);
+  attrs->weight_layout = std::move(weight_layout);
+  attrs->out_dtype = std::move(out_dtype);
+  static const Op& op = Op::Get("nn.conv2d_transpose");
+  return CallNode::make(op, {data, weight}, Attrs(attrs), {});
+}
+
+
+TVM_REGISTER_API("relay.op.nn._make.conv2d_transpose")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 12>(MakeConv2DTranspose, args, rv);
+  });
+
+RELAY_REGISTER_OP("nn.conv2d_transpose")
+.describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution).
+
+The need for transposed convolutions generally arises
+from the desire to use a transformation going in the opposite direction
+of a normal convolution, i.e., from something that has the shape of the
+output of some convolution to something that has the shape of its input
+while maintaining a connectivity pattern that is compatible with
+said convolution.
+
+- **data**: This depends on the `layout` parameter. Input is 4D array of shape
+            (batch_size, in_channels, height, width) if `layout` is `NCHW`.
+- **weight**: (in_channels, channels, kernel_size[0], kernel_size[1])
+- **bias**: (channels,)
+- **out**:  This depends on the `layout` parameter. Output is 4D array of shape
+v            (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
+
+            out_height and out_width are calculated as::
+                out_height = (height-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0]
+                out_width = (width-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1]
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(2)
+.add_argument("data", "Tensor", "The input tensor.")
+.add_argument("weight", "Tensor", "The weight tensor.")
+.set_support_level(2)
+.add_type_rel("Conv2DTranspose", Conv2DTransposeRel);
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index b9599982a..1d6d00277 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -57,6 +57,42 @@ def test_conv2d_infer_type():
     assert ftype.arg_types[1] == relay.ty.TensorType(
         (4, 8, 3, 3, 4, 4), "int8")
 
+def test_conv2d_transpose_infer_type():
+    # symbolic in batch dimension
+    ib = relay.ir_builder.IRBuilder()
+    n, c, h, w = tvm.var("n"), 10, 10, 12
+    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
+    w = ib.param("w", relay.ty.IncompleteType())
+
+    with ib.function(x, w) as func:
+        ib.ret(relay.nn.conv2d_transpose(x.var, w.var,
+                                         kernel_size=(3, 3),
+                                         padding=(1, 1),
+                                         channels=15))
+    ib.ret(func)
+    func = relay.ir_pass.infer_type(ib.env, func.to_func())
+    ftype = func.checked_type
+    assert ftype.ret_type == relay.ty.TensorType(
+        (n, 15, 10, 12), "float32")
+    assert ftype.arg_types[1] == relay.ty.TensorType(
+        (10, 15, 3, 3), "float32")
+
+    # infer by shape of w, mixed precision
+    ib = relay.ir_builder.IRBuilder()
+    n, c, h, w = tvm.var("n"), 10, 10, 12
+    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
+    w = ib.param("w", relay.ty.TensorType((12, 11, 5, 5), "float32"))
+    with ib.function(x, w) as func:
+        ib.ret(relay.nn.conv2d_transpose(x.var, w.var,
+                                         output_padding=(1, 1),
+                                         channels=11,
+                                         data_layout="NHWC"))
+    ib.ret(func)
+    func = relay.ir_pass.infer_type(ib.env, func.to_func())
+    ftype = func.checked_type
+    assert ftype.ret_type == relay.ty.TensorType(
+        (n, 15, 15, 11), "float32")
+
 def test_upsampling_infer_type():
     ib = relay.ir_builder.IRBuilder()
     n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
@@ -166,3 +202,4 @@ if __name__ == "__main__":
     test_pool2d_infer_type()
     test_upsampling_infer_type()
     test_flatten_infer_type()
+    test_conv2d_transpose_infer_type()
-- 
GitLab