From bb7df695cd5835065a2ade3d5c8ae6e8ff00d139 Mon Sep 17 00:00:00 2001
From: Siva <sivar.b@huawei.com>
Date: Sun, 24 Jun 2018 21:39:09 +0530
Subject: [PATCH] [NNVM][CONVOLUTION] Group convolution generalization for NHWC
 (#1232)

---
 nnvm/python/nnvm/frontend/tensorflow.py       | 135 ++++++++++++++++--
 nnvm/python/nnvm/top/nn.py                    |  20 ++-
 nnvm/src/top/nn/convolution.cc                |   3 +-
 nnvm/tests/python/compiler/test_top_level2.py |  27 +++-
 4 files changed, 172 insertions(+), 13 deletions(-)

diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py
index 17f08cc3e..3ab997617 100644
--- a/nnvm/python/nnvm/frontend/tensorflow.py
+++ b/nnvm/python/nnvm/frontend/tensorflow.py
@@ -33,6 +33,8 @@ class AttrCvt(object):
         self._ignores.append('_input_shapes')
         self._ignores.append('T')
         self._ignores.append('use_cudnn_on_gpu')
+        self._ignores.append('_node_name')
+        self._ignores.append('is_training')
         return AttrConvert(self._op_name, self._transforms, self._excludes,
                            self._disables, self._ignores, self._extras,
                            self._custom_check)(inputs, attrs, *args)
@@ -230,6 +232,85 @@ def _conv():
             custom_check=_dimension_constraint())(inputs, attr)
     return _impl
 
+def _depthwise_conv():
+    def _impl(inputs, attr, params):
+        attr['data_format'] = attr['data_format'].decode("utf-8")
+        input_shapes = attr['_input_shapes'][inputs[0]]
+
+        # Extract kernel shape from params
+        conv_param_weights = params[inputs[1].list_output_names()[0]]
+
+        if attr['data_format'] == 'NHWC':
+            kernel_h, kernel_w, _, depth_mult = conv_param_weights.shape
+            attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
+            attr['channels'] = input_shapes[0][3] * depth_mult
+            if 'dilations' in attr:
+                attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
+        elif attr['data_format'] == 'NCHW':
+            depth_mult, _, kernel_h, kernel_w = conv_param_weights.shape
+            attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
+            attr['channels'] = input_shapes[0][1] * depth_mult
+            if 'dilations' in attr:
+                attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
+        else:
+            raise TypeError("Unsupported data format type : {}".format(attr['data_format']))
+
+        # Fix strides
+        attr['strides'] = (attr['strides'][1], attr['strides'][2])
+
+        # Fix groups
+        attr['groups'] = attr['channels']
+
+        # Fix padding
+        attr['padding'] = attr['padding'].decode("utf-8")
+
+        if attr['padding'] == 'VALID':
+            attr['padding'] = [0, 0]
+        elif attr['padding'] == 'SAME':
+            stride_h, stride_w = attr['strides']
+            kernel_h, kernel_w = attr['kernel_shape']
+            if attr['data_format'] == 'NHWC':
+                in_h = input_shapes[0][1]
+                in_w = input_shapes[0][2]
+            else:
+                in_h = input_shapes[0][2]
+                in_w = input_shapes[0][3]
+
+            pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
+            pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
+
+            if attr['data_format'] == 'NHWC':
+                inputs[0] = _sym.pad(data=inputs[0],
+                                     pad_width=((0, 0),
+                                                (pad_v[0], pad_v[1]),
+                                                (pad_h[0], pad_h[1]),
+                                                (0, 0)))
+            else:
+                inputs[0] = _sym.pad(data=inputs[0],
+                                     pad_width=((0, 0),
+                                                (0, 0),
+                                                (pad_v[0], pad_v[1]),
+                                                (pad_h[0], pad_h[1])))
+
+            attr['padding'] = [0, 0]
+
+        else:
+            raise TypeError("Unsupported padding type : {}".format(attr['padding']))
+
+        if 'kernel_layout' not in attr:
+            attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
+
+        return AttrCvt(
+            op_name=_dimension_picker('conv'),
+            transforms={
+                'kernel_shape': 'kernel_size',
+                'data_format': 'layout',
+                'dilations': ('dilation', (0, 0)),
+                'group': ('groups', 1)},
+            extras={'use_bias': len(inputs) == 3},
+            custom_check=_dimension_constraint())(inputs, attr)
+    return _impl
+
 def _decode_image():
     def _impl(inputs, attr, params):
         # Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
@@ -358,9 +439,27 @@ def _batch_norm():
             op_name='batch_norm',
             transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'},
             extras={'axis': 3}, # Fix axis
+            ignores=['data_format'],
             disables=['momentum'])(new_inputs, attr)
     return _impl
 
+def _relu6():
+    def _impl(inputs, attr, params):
+        return _sym.clip(inputs[0], a_min=0, a_max=6)
+    return _impl
+
+def _shape():
+    def _impl(inputs, attr, params):
+        input_shapes = attr['_input_shapes'][inputs[0]]
+
+        # Fix the -1 dimensions to 1
+        input_shapes[0] = [1 if x == -1 else x for x in input_shapes[0]]
+        params[attr['_node_name']] = tvm.nd.array(input_shapes[0])
+
+        return _sym.Variable(name=attr['_node_name'],
+                             shape=params[attr['_node_name']].shape)
+    return _impl
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -392,6 +491,10 @@ _convert_map = {
     'Add'                               : _elemwise('add'),
     'Rsqrt'                             : _rsqrt(),
     'Squeeze'                           : _squeeze(),
+    'FusedBatchNorm'                    : _batch_norm(),
+    'Relu6'                             : _relu6(),
+    'DepthwiseConv2dNative'             : _depthwise_conv(),
+    'Shape'                             : _shape(),
 }
 
 
@@ -458,9 +561,13 @@ class GraphProto(object):
                 self._num_input += 1
                 self._nodes[node.name] = _sym.Variable(name=node.name)
 
-                self._output_shapes[node.name] = \
-                     [tensor_util.TensorShapeProtoToList(shape) \
-                     for shape in self._parse_attr(node.attr)['_output_shapes']]
+                try:
+                    self._output_shapes[node.name] = \
+                         [tensor_util.TensorShapeProtoToList(shape) \
+                         for shape in self._parse_attr(node.attr)['_output_shapes']]
+                except KeyError:
+                    raise NotImplementedError( \
+                        "Please freeze the graph with add_shapes=True")
             elif node.op == "Const":
                 # Assuming first Const node as Graph Input node
                 if self._input_node == '':
@@ -476,17 +583,29 @@ class GraphProto(object):
                         raise NotImplementedError( \
                             "Const {} couldn't be converted to Param.".format(node.name))
 
-                self._output_shapes[node.name] = \
-                     [tensor_util.TensorShapeProtoToList(shape) \
-                     for shape in self._parse_attr(node.attr)['_output_shapes']]
+                try:
+                    self._output_shapes[node.name] = \
+                         [tensor_util.TensorShapeProtoToList(shape) \
+                         for shape in self._parse_attr(node.attr)['_output_shapes']]
+                except KeyError:
+                    raise NotImplementedError( \
+                        "Please freeze the graph with add_shapes=True")
             else:
                 attr = self._parse_attr(node.attr)
-                self._output_shapes[node.name] = \
-                     [tensor_util.TensorShapeProtoToList(shape) for shape in attr['_output_shapes']]
+                try:
+                    self._output_shapes[node.name] = \
+                         [tensor_util.TensorShapeProtoToList(shape) \
+                          for shape in attr['_output_shapes']]
+                except KeyError:
+                    raise NotImplementedError( \
+                        "Please freeze the graph with add_shapes=True")
 
                 # Pass the parsed shapes instead
                 attr["_output_shapes"] = self._output_shapes[node.name]
 
+                # Pass the node name too in attr
+                attr["_node_name"] = node.name
+
                 try:
                     inputs = [self._nodes[i] for i in node.input]
                     input_shapes = {}
diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py
index e4e5415a6..2432cc84f 100644
--- a/nnvm/python/nnvm/top/nn.py
+++ b/nnvm/python/nnvm/top/nn.py
@@ -84,6 +84,7 @@ def compute_conv2d(attrs, inputs, _):
     groups = attrs.get_int("groups")
     channels = attrs.get_int("channels")
     layout = attrs["layout"]
+    kernel_layout = attrs["kernel_layout"]
     assert layout == "NCHW" or layout == "NHWC"
     (dilation_h, dilation_w) = dilation
     if dilation_h < 1 or dilation_w < 1:
@@ -97,10 +98,18 @@ def compute_conv2d(attrs, inputs, _):
 
     if groups == 1:
         out = topi.nn.conv2d(inputs[0], kernel, strides, padding, layout)
-    elif groups == get_const_int(inputs[0].shape[1]) and groups == channels:
+    elif layout == "NCHW" and \
+         groups == get_const_int(inputs[0].shape[1]) and \
+         groups == channels:
         out = topi.nn.depthwise_conv2d_nchw(inputs[0], kernel, strides, padding)
+    elif layout == "NHWC" and \
+         kernel_layout == "HWOI" and \
+         groups == get_const_int(inputs[0].shape[3]) and \
+         groups == channels:
+        out = topi.nn.depthwise_conv2d_nhwc(inputs[0], kernel, strides, padding)
     else:
         raise ValueError("not support arbitrary group number for now")
+
     if attrs.get_bool("use_bias"):
         bias = inputs[2]
         expand_axis = 1 if layout == "NCHW" else 0
@@ -112,13 +121,20 @@ def compute_conv2d(attrs, inputs, _):
 def schedule_conv2d(attrs, outs, target):
     """Schedule definition of conv2d"""
     groups = attrs.get_int("groups")
+    channels = attrs.get_int("channels")
     layout = attrs["layout"]
+    kernel_layout = attrs["kernel_layout"]
     with tvm.target.create(target):
         if groups == 1 and layout == "NCHW":
             return topi.generic.schedule_conv2d_nchw(outs)
         elif groups == 1 and layout == "NHWC":
             return topi.generic.schedule_conv2d_nhwc(outs)
-        return topi.generic.schedule_depthwise_conv2d_nchw(outs)
+        elif groups == channels and layout == "NCHW":
+            return topi.generic.schedule_depthwise_conv2d_nchw(outs)
+        elif groups == channels and layout == "NHWC" and kernel_layout == "HWOI":
+            return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
+        else:
+            raise ValueError("No compatible schedule")
 
 @reg.register_alter_op_layout("conv2d")
 def alter_conv2d_layout(attrs, inputs, tinfos):
diff --git a/nnvm/src/top/nn/convolution.cc b/nnvm/src/top/nn/convolution.cc
index 8b66f9757..40c4d7a0e 100644
--- a/nnvm/src/top/nn/convolution.cc
+++ b/nnvm/src/top/nn/convolution.cc
@@ -79,7 +79,8 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
                  param.kernel_size[1]});
 
   wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
-  wshape[0] *= param.groups;
+
+  wshape[kernel_layout.indexof('O')] *= param.groups;
 
   NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape);
   if (param.use_bias) {
diff --git a/nnvm/tests/python/compiler/test_top_level2.py b/nnvm/tests/python/compiler/test_top_level2.py
index ed67dab44..8f4c330d2 100644
--- a/nnvm/tests/python/compiler/test_top_level2.py
+++ b/nnvm/tests/python/compiler/test_top_level2.py
@@ -58,7 +58,7 @@ def test_dilated_conv2d():
         np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
 
 
-def test_grouped_conv2d():
+def test_grouped_conv2d_nchw():
     x = sym.Variable("x")
     y = sym.conv2d(x, channels=32, kernel_size=(3,3), groups=32,
                    name="y", padding=(1,1))
@@ -80,6 +80,28 @@ def test_grouped_conv2d():
         c_np = c_np + bias.asnumpy().reshape(kshape[0], 1, 1)
         np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
 
+def test_grouped_conv2d_nhwc():
+    x = sym.Variable("x")
+    y = sym.conv2d(x, channels=32, kernel_size=(3,3), groups=32,
+                   name="y", padding=(1,1), layout="NHWC", kernel_layout ='HWOI')
+    dtype = "float32"
+    dshape = (1, 18, 18, 32)
+    kshape = (3, 3, 32, 1)
+    oshape = (1, 18, 18, 32)
+    shape_dict = {"x": dshape}
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
+        m = graph_runtime.create(graph, lib, ctx)
+        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+        kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
+        bias = tvm.nd.array(np.random.uniform(size=kshape[2]).astype(dtype))
+        m.run(x=data, y_weight=kernel, y_bias=bias)
+        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
+        c_np = topi.testing.depthwise_conv2d_python_nhwc(
+            data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME')
+        c_np = c_np + bias.asnumpy().reshape(1, 1, kshape[2])
+        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
+
 
 def test_conv2d_transpose():
     x = sym.Variable("x")
@@ -269,7 +291,8 @@ def test_resize_bilinear():
 if __name__ == "__main__":
     test_conv2d()
     test_dilated_conv2d()
-    test_grouped_conv2d()
+    test_grouped_conv2d_nchw()
+    test_grouped_conv2d_nhwc()
     test_conv2d_transpose()
     test_max_pool2d()
     test_avg_pool2d()
-- 
GitLab