From b154e6b904bcf8e5308108b771d8a11614ff7c20 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Thu, 28 Jun 2018 10:37:04 -0700
Subject: [PATCH] [NNVM] Initial mixed precision support of conv2d (#1356)

---
 nnvm/include/nnvm/top/nn.h                    |  7 ++++
 nnvm/python/nnvm/top/nn.py                    | 12 +++++--
 nnvm/src/top/nn/convolution.cc                | 30 ++++++++++++++--
 nnvm/tests/python/compiler/test_top_level2.py | 34 +++++++++++++++++--
 .../tests/python/unittest/test_infer_shape.py | 22 ++++++++++++
 topi/python/topi/nn/depthwise_conv2d.py       | 17 +++++++---
 6 files changed, 110 insertions(+), 12 deletions(-)
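
Why a separate accumulation type: multiplying two int8 values easily
overflows int8, so quantized convolutions are normally accumulated in a
wider type such as int32. A minimal NumPy illustration of the failure
mode this patch addresses (not part of the patch itself):

import numpy as np

a = np.array([100], dtype=np.int8)
b = np.array([100], dtype=np.int8)
# int8 * int8 stays int8 in NumPy and silently wraps: 10000 % 256 == 16
print((a * b)[0])                    # 16
# widening one operand first keeps the mathematically correct product
print((a.astype(np.int32) * b)[0])   # 10000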

diff --git a/nnvm/include/nnvm/top/nn.h b/nnvm/include/nnvm/top/nn.h
index 69d49cd99..f37811315 100644
--- a/nnvm/include/nnvm/top/nn.h
+++ b/nnvm/include/nnvm/top/nn.h
@@ -11,6 +11,7 @@
 #include <nnvm/tuple.h>
 #include <nnvm/layout.h>
 #include <string>
+#include "./tensor.h"
 
 namespace nnvm {
 namespace top {
@@ -122,6 +123,7 @@ struct Conv2DParam : public dmlc::Parameter<Conv2DParam> {
   std::string layout;
   std::string kernel_layout;
   std::string out_layout;
+  int out_dtype;
   bool use_bias;
 
   DMLC_DECLARE_PARAMETER(Conv2DParam) {
@@ -156,6 +158,11 @@ struct Conv2DParam : public dmlc::Parameter<Conv2DParam> {
       .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
                 "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
                 "dimensions respectively.");
+    DMLC_DECLARE_DTYPE_FIELD(out_dtype)
+      .add_enum("same", -1)
+      .set_default(-1)
+      .describe("Output data type, set to explicit type under mixed precision setting");
+
     DMLC_DECLARE_FIELD(use_bias).set_default(true)
       .describe("Whether the layer uses a bias vector.");
   }
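
With the new field in place, the symbol-level API can request a wider
accumulator directly. A minimal sketch mirroring the test added later in
this patch; out_dtype="same" (the default) keeps the old behavior of
inferring the output dtype from the inputs:

import nnvm.symbol as sym

x = sym.Variable("x")
# int8 data and weights, accumulated and returned as int32
y = sym.conv2d(x, channels=10, kernel_size=(3, 3), padding=(1, 1),
               use_bias=False, out_dtype="int32")
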
diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py
index ca1fae1d2..e86d54573 100644
--- a/nnvm/python/nnvm/top/nn.py
+++ b/nnvm/python/nnvm/top/nn.py
@@ -88,6 +88,8 @@ def compute_conv2d(attrs, inputs, _):
     channels = attrs.get_int("channels")
     layout = attrs["layout"]
     kernel_layout = attrs["kernel_layout"]
+    out_dtype = attrs["out_dtype"]
+    out_dtype = None if out_dtype == "same" else out_dtype
     assert layout == "NCHW" or layout == "NHWC"
     (dilation_h, dilation_w) = dilation
     if dilation_h < 1 or dilation_w < 1:
@@ -100,16 +102,19 @@ def compute_conv2d(attrs, inputs, _):
         kernel = topi.nn.dilate(inputs[1], [1, dilation_h, dilation_w, 1])
 
     if groups == 1:
-        out = topi.nn.conv2d(inputs[0], kernel, strides, padding, layout)
+        out = topi.nn.conv2d(
+            inputs[0], kernel, strides, padding, layout, out_dtype=out_dtype)
     elif layout == "NCHW" and \
          groups == get_const_int(inputs[0].shape[1]) and \
          groups == channels:
-        out = topi.nn.depthwise_conv2d_nchw(inputs[0], kernel, strides, padding)
+        out = topi.nn.depthwise_conv2d_nchw(
+            inputs[0], kernel, strides, padding, out_dtype=out_dtype)
     elif layout == "NHWC" and \
          kernel_layout == "HWOI" and \
          groups == get_const_int(inputs[0].shape[3]) and \
          groups == channels:
-        out = topi.nn.depthwise_conv2d_nhwc(inputs[0], kernel, strides, padding)
+        out = topi.nn.depthwise_conv2d_nhwc(
+            inputs[0], kernel, strides, padding, out_dtype=out_dtype)
     else:
         raise ValueError("not support arbitrary group number for now")
 
@@ -127,6 +132,7 @@ def schedule_conv2d(attrs, outs, target):
     channels = attrs.get_int("channels")
     layout = attrs["layout"]
     kernel_layout = attrs["kernel_layout"]
+
     with tvm.target.create(target):
         if groups == 1 and layout == "NCHW":
             return topi.generic.schedule_conv2d_nchw(outs)
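
The out_dtype threaded through compute_conv2d ends up in the TOPI compute,
which casts both operands to the accumulation type before the
multiply-accumulate (see the depthwise kernel at the end of this patch).
A standalone sketch of that widening pattern, assuming the 2018-era
tvm.placeholder/tvm.compute API:

import tvm

n = 16
A = tvm.placeholder((n,), dtype="int8", name="A")
B = tvm.placeholder((n,), dtype="int8", name="B")
# cast each operand up front so the product is computed in int32, not int8
C = tvm.compute((n,), lambda i: A[i].astype("int32") * B[i].astype("int32"),
                name="C")
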
diff --git a/nnvm/src/top/nn/convolution.cc b/nnvm/src/top/nn/convolution.cc
index 40c4d7a0e..6a0dad17a 100644
--- a/nnvm/src/top/nn/convolution.cc
+++ b/nnvm/src/top/nn/convolution.cc
@@ -130,6 +130,30 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+
+inline bool Conv2DInferType(const nnvm::NodeAttrs& attrs,
+                            std::vector<int>* in_type,
+                            std::vector<int>* out_type) {
+  const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
+  if (param.use_bias) {
+    CHECK_EQ(in_type->size(), 3U) << "Input:[data, weight, bias]";
+  } else {
+    CHECK_EQ(in_type->size(), 2U) << "Input:[data, weight]";
+  }
+  CHECK_EQ(out_type->size(), 1U);
+  if (param.out_dtype != -1) {
+    CHECK(!type_is_none((*in_type)[0]));
+    for (size_t i = 1; i < in_type->size(); ++i) {
+      NNVM_ASSIGN_INPUT_TYPE(attrs, *in_type, i, (*in_type)[0]);
+    }
+    NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_type, 0, param.out_dtype);
+  } else {
+    ElemwiseType<-1, 1>(attrs, in_type, out_type);
+  }
+  return true;
+}
+
+
 inline bool Conv2DCorrectLayout(const NodeAttrs& attrs,
                                 std::vector<Layout> *ilayouts,
                                 const std::vector<Layout> *last_ilayouts,
@@ -189,7 +213,7 @@ a bias vector is created and added to the outputs.
 .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Conv2DParam>)
 .set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>)
 .set_attr<FInferShape>("FInferShape", Conv2DInferShape)
-.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
+.set_attr<FInferType>("FInferType", Conv2DInferType)
 .set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout)
 .set_num_outputs(1)
 .set_num_inputs(UseBiasNumInputs<Conv2DParam>)
@@ -214,7 +238,7 @@ NNVM_REGISTER_OP(_contrib_conv2d_NCHWc)
 .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Conv2DParam>)
 .set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>)
 .set_attr<FInferShape>("FInferShape", Conv2DInferShape)
-.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
+.set_attr<FInferType>("FInferType", Conv2DInferType)
 .set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout)
 .set_num_outputs(1)
 .set_num_inputs(UseBiasNumInputs<Conv2DParam>)
@@ -348,7 +372,7 @@ said convolution.
 - **weight**: (in_channels, channels, kernel_size[0], kernel_size[1])
 - **bias**: (channels,)
 - **out**:  This depends on the `layout` parameter. Output is 4D array of shape
-            (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
+            (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
 
             out_height and out_width are calculated as::
                 out_height = (height-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0]
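
The inference rule above is small but worth spelling out: with out_dtype
left at "same" (-1) it falls back to the old ElemwiseType behavior, where
all inputs and the output share one dtype; otherwise the weight (and bias)
are unified to the data dtype and the output is forced to out_dtype. A
hypothetical pure-Python restatement, not code from the patch:

def conv2d_infer_type(in_dtypes, out_dtype):
    """in_dtypes is [data, weight] or [data, weight, bias]."""
    if out_dtype == "same":
        # old behavior: everything takes the data dtype
        return [in_dtypes[0]] * len(in_dtypes), [in_dtypes[0]]
    # mixed precision: inputs stay in the data dtype, output widens
    return [in_dtypes[0]] * len(in_dtypes), [out_dtype]

assert conv2d_infer_type(["int8", "int8"], "int32") == \
    (["int8", "int8"], ["int32"])
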
diff --git a/nnvm/tests/python/compiler/test_top_level2.py b/nnvm/tests/python/compiler/test_top_level2.py
index 8f4c330d2..19387a586 100644
--- a/nnvm/tests/python/compiler/test_top_level2.py
+++ b/nnvm/tests/python/compiler/test_top_level2.py
@@ -32,6 +32,35 @@ def test_conv2d():
         np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
 
 
+def test_mixed_precision():
+    x = sym.Variable("x")
+    dtype = "int8"
+    out_dtype="int32"
+    y = sym.conv2d(x,
+                   channels=10,
+                   kernel_size=(3,3),
+                   name="y",
+                   padding=(1,1),
+                   use_bias=False,
+                   out_dtype=out_dtype)
+    dshape = (1, 3, 18, 18)
+    kshape = (10, 3, 3, 3)
+    oshape = (1, 10, 18, 18)
+    shape_dict = {"x": dshape}
+    dtype_dict = {"x": dtype}
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict, dtype_dict)
+        m = graph_runtime.create(graph, lib, ctx)
+        data = tvm.nd.array(np.random.uniform(-127, 127, size=dshape).astype(dtype))
+        kernel = tvm.nd.array(np.random.uniform(-127, 127, size=kshape).astype(dtype))
+        m.run(x=data, y_weight=kernel)
+        out = m.get_output(0, tvm.nd.empty(oshape, out_dtype))
+        c_np = topi.testing.conv2d_nchw_python(
+            data.asnumpy().astype(out_dtype),
+            kernel.asnumpy().astype(out_dtype), 1, 1)
+        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
+
+
 def test_dilated_conv2d():
     dilation = 3
     x = sym.Variable("x")
@@ -167,7 +196,7 @@ def test_avg_pool2d_no_count_pad():
     kh, kw = (4, 4)
     sh, sw = (2, 2)
     ph, pw = (2, 2)
-    
+
     x = sym.Variable("x")
     y = sym.avg_pool2d(x, pool_size=(kh, kw), strides=(sw, sw), padding=(ph, pw),
                        name="y", count_include_pad=False)
@@ -181,7 +210,7 @@ def test_avg_pool2d_no_count_pad():
     no_zero = (range(n), range(ic), (range(ph, ih+ph)), (range(pw, iw+pw)))
     pad_np[np.ix_(*no_zero)] = a_np
     b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype)
-    
+
     for i in range(oh):
         for j in range(ow):
             pad_count = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2,3))
@@ -289,6 +318,7 @@ def test_resize_bilinear():
         np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
 
 if __name__ == "__main__":
+    test_mixed_precision()
     test_conv2d()
     test_dilated_conv2d()
     test_grouped_conv2d_nchw()
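
To check the inferred dtypes without running a full build, NNVM's graph
utilities can be used. The sketch below assumes
nnvm.compiler.graph_util.infer_dtype with this signature; that helper is
my assumption about the existing API, not something this patch adds:

import nnvm
from nnvm.compiler import graph_util

g = nnvm.graph.create(y)  # y as constructed in test_mixed_precision
_, out_dtypes = graph_util.infer_dtype(g, x="int8")
assert out_dtypes[0] == "int32"
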
diff --git a/nnvm/tests/python/unittest/test_infer_shape.py b/nnvm/tests/python/unittest/test_infer_shape.py
index 9fbc93c07..51e0e9576 100644
--- a/nnvm/tests/python/unittest/test_infer_shape.py
+++ b/nnvm/tests/python/unittest/test_infer_shape.py
@@ -168,6 +168,27 @@ def test_conv2d():
           layout="NHWC")
 
 
+def test_conv2d_packed():
+    def check(in_shape,
+              out_shape,
+              kernel_shape,
+              **kwargs):
+        x = sym.Variable("x", shape=in_shape)
+        y = sym.conv2d(x, name="y", **kwargs)
+        sdict = infer_shape(y)
+        assert tuple(sdict["y"][0]) == tuple(out_shape)
+        assert tuple(sdict["y_weight"][0]) == tuple(kernel_shape)
+
+    check((4, 10, 10, 12, 1, 8),
+          (4, 10, 10, 2, 1, 8),
+          (2, 12, 3, 3, 8, 8),
+          channels=8 * 2,
+          kernel_size=(3,3),
+          padding=(1,1),
+          layout="NHWC1n8c",
+          kernel_layout="OIHW8o8i")
+
+
 def test_conv2d_transpose():
     def check(in_shape, out_shape, **kwargs):
         x = sym.Variable("x", shape=in_shape)
@@ -332,6 +353,7 @@ def test_reduce():
 
 
 if __name__ == "__main__":
+    test_conv2d_packed()
     test_expand_dims()
     test_dense()
     test_matmul()
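
To unpack the shapes in test_conv2d_packed: layout NHWC1n8c packs batch by
a factor of 1 and channel by a factor of 8, so (4, 10, 10, 12, 1, 8) is a
packed view of plain NHWC (4, 10, 10, 96), and OIHW8o8i likewise packs
(2, 12, 3, 3, 8, 8) into plain OIHW (16, 96, 3, 3). A hypothetical helper
making that arithmetic explicit:

def unpack_nhwc_1n8c(shape):
    """NHWC1n8c -> plain NHWC."""
    n, h, w, c, nb, cb = shape
    return (n * nb, h, w, c * cb)

def unpack_oihw_8o8i(shape):
    """OIHW8o8i -> plain OIHW."""
    o, i, h, w, ob, ib = shape
    return (o * ob, i * ib, h, w)

assert unpack_nhwc_1n8c((4, 10, 10, 12, 1, 8)) == (4, 10, 10, 96)
assert unpack_oihw_8o8i((2, 12, 3, 3, 8, 8)) == (16, 96, 3, 3)
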
diff --git a/topi/python/topi/nn/depthwise_conv2d.py b/topi/python/topi/nn/depthwise_conv2d.py
index 8d348d461..c7906d3a4 100644
--- a/topi/python/topi/nn/depthwise_conv2d.py
+++ b/topi/python/topi/nn/depthwise_conv2d.py
@@ -27,12 +27,15 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
     padding : int or str
         Padding size, or ['VALID', 'SAME']
 
+    out_dtype: str, optional
+        Output data type
+
     Returns
     -------
     Output : tvm.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    out_dtype = Input.dtype
+    out_dtype = Input.dtype if out_dtype is None else out_dtype
 
     batch, in_channel, in_height, in_width = Input.shape
     filter_channel, channel_multiplier, filter_height, filter_width = Filter.shape
@@ -65,7 +68,7 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
 
 
 @tvm.target.generic_func
-def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
+def depthwise_conv2d_nhwc(Input, Filter, stride, padding, out_dtype=None):
     """Depthwise convolution nhwc forward operator.
 
     Parameters
@@ -82,11 +85,16 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
     padding : int or str
         Padding size, or ['VALID', 'SAME']
 
+    out_dtype: str, optional
+        Output data type
+
     Returns
     -------
     Output : tvm.Tensor
         4-D with shape [batch, out_height, out_width, out_channel]
     """
+    out_dtype = Input.dtype if out_dtype is None else out_dtype
+
     batch, in_height, in_width, in_channel = Input.shape
     filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape
     if isinstance(stride, int):
@@ -110,8 +118,9 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
     Output = tvm.compute(
         (batch, out_height, out_width, out_channel),
         lambda b, i, j, c: tvm.sum(
-            (PaddedInput[b, i*stride_h + di, j*stride_w + dj, c/channel_multiplier] *
-             Filter[di, dj, c/channel_multiplier, c%channel_multiplier]),
+            (PaddedInput[b, i*stride_h + di, j*stride_w + dj, c/channel_multiplier].astype(
+                out_dtype) *
+             Filter[di, dj, c/channel_multiplier, c%channel_multiplier].astype(out_dtype)),
             axis=[di, dj]),
         name='DepthwiseConv2d', tag="depthwise_conv2d_nhwc")
     return Output
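
A minimal sketch of calling the NHWC depthwise operator with the new
argument; the placeholder shapes are illustrative only, using the same
2018-era tvm/topi API as above:

import tvm
import topi

data = tvm.placeholder((1, 18, 18, 32), dtype="int8", name="data")
kern = tvm.placeholder((3, 3, 32, 1), dtype="int8", name="kernel")
# accumulate in int32; out_dtype=None would keep the old same-dtype behavior
out = topi.nn.depthwise_conv2d_nhwc(data, kern, stride=(1, 1),
                                    padding=(1, 1), out_dtype="int32")
assert out.dtype == "int32"
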
-- 
GitLab