From 54d4fe4b0f9b6b11d273420ad9b1e018b04c8d77 Mon Sep 17 00:00:00 2001
From: masahi <masahi129@gmail.com>
Date: Sat, 10 Feb 2018 11:14:34 +0900
Subject: [PATCH] [TOPI] Initial NHWC layout support (#882)

* add 4 dim softmax

* update for NHWC layout

* remove layout param from softmax

* fix typo

* minor fix to pool

support axis=1 ndims=5 softmax.

add softmax axis

* few fix for softmax

* fix typo

* add more doc

* minor doc fix

* fix upsampling output shape

* fix lint

* cleanup softmax

* minor fix

* raise exception instead of assert, handles negative axis

* check axis after axis transformation
---
 topi/python/topi/nn/conv2d.py     |   2 +
 topi/python/topi/nn/pooling.py    | 118 +++++++++++++++++++++++++++++-
 topi/python/topi/nn/softmax.py    |  47 +++++++++---
 topi/python/topi/nn/upsampling.py |  53 +++++++++++++-
 4 files changed, 206 insertions(+), 14 deletions(-)

diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py
index 3bd910e29..d6488b164 100644
--- a/topi/python/topi/nn/conv2d.py
+++ b/topi/python/topi/nn/conv2d.py
@@ -82,6 +82,8 @@ def conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
         return conv2d_nchw(data, kernel, stride, padding, out_dtype)
     elif layout == 'HWCN':
         return conv2d_hwcn(data, kernel, stride, padding, out_dtype)
+    elif layout == 'NHWC':
+        return conv2d_nhwc(data, kernel, stride, padding, out_dtype)
     else:
         raise ValueError("not support this layout {} yet".format(layout))
 
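For illustration, a minimal sketch of exercising the new 'NHWC' branch (the shapes and the HWIO kernel layout [filter_height, filter_width, in_channel, num_filter] are assumptions for the example; conv2d_nhwc itself is not part of this patch):

    import tvm
    import topi

    # NHWC input: [batch, in_height, in_width, in_channel]
    data = tvm.placeholder((1, 224, 224, 3), name='data')
    # Assumed HWIO kernel layout: [filter_height, filter_width, in_channel, num_filter]
    kernel = tvm.placeholder((3, 3, 3, 64), name='kernel')
    # Dispatches to conv2d_nhwc via the new branch above
    out = topi.nn.conv2d(data, kernel, stride=1, padding=1, layout='NHWC')
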
diff --git a/topi/python/topi/nn/pooling.py b/topi/python/topi/nn/pooling.py
index 99b15e18e..0519471ea 100644
--- a/topi/python/topi/nn/pooling.py
+++ b/topi/python/topi/nn/pooling.py
@@ -44,9 +44,50 @@ def global_pool(data, pool_type):
         raise ValueError("Pool type should be 'avg' or 'max'.")
 
 
-def pool(data, kernel, stride, padding, pool_type, ceil_mode=False):
+def pool(data, kernel, stride, padding, pool_type, ceil_mode=False, layout="NCHW"):
     """Perform pooling on the data
 
+    Parameters
+    ----------
+    data : tvm.Tensor
+        4-D with shape [batch, channel, in_height, in_width]
+        or  [batch, in_height, in_width, channel]
+
+    kernel : list/tuple of two ints
+        Kernel size, [kernel_height, kernel_width]
+
+    stride : list/tuple of two ints
+        Stride size, [stride_height, stride_width]
+
+    padding : list/tuple of two ints
+        Pad size, [pad_height, pad_width]
+
+    pool_type : str
+        Pool type, 'max' or 'avg'
+
+    ceil_mode : bool
+        Whether to use ceil when calculating the output size.
+
+    layout: string
+        either "NCHW" or "NHWC"
+
+    Returns
+    -------
+    output : tvm.Tensor
+        4-D with shape [batch, channel, out_height, out_width]
+        or [batch, out_height, out_width, channel]
+    """
+    if layout == "NCHW":
+        return pool_nchw(data, kernel, stride, padding, pool_type, ceil_mode=ceil_mode)
+    elif layout == "NHWC":
+        return pool_nhwc(data, kernel, stride, padding, pool_type, ceil_mode=ceil_mode)
+    else:
+        raise ValueError("not support this layout {} yet".format(layout))
+
+
+def pool_nchw(data, kernel, stride, padding, pool_type, ceil_mode=False):
+    """Perform pooling on the data in NCHW layout
+
     Parameters
     ----------
     data : tvm.Tensor
@@ -117,3 +158,78 @@ def pool(data, kernel, stride, padding, pool_type, ceil_mode=False):
                             tag=tag.ELEMWISE)
     else:
         raise ValueError("Pool type should be 'avg' or 'max'.")
+
+
+def pool_nhwc(data, kernel, stride, padding, pool_type, ceil_mode=False):
+    """Perform pooling on the data in NHWC layout
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        4-D with shape [batch, in_height, in_width, channel]
+
+    kernel : list/tuple of two ints
+        Kernel size, [kernel_height, kernel_width]
+
+    stride : list/tuple of two ints
+        Stride size, [stride_height, stride_width]
+
+    padding : list/tuple of two ints
+        Pad size, [pad_height, pad_width]
+
+    pool_type : str
+        Pool type, 'max' or 'avg'
+
+    ceil_mode : bool
+        Whether to use ceil when calculating the output size.
+
+    Returns
+    -------
+    output : tvm.Tensor
+        4-D with shape [batch, out_height, out_width, channel]
+    """
+    assert len(data.shape) == 4, "only support 4-dim pooling"
+    assert len(stride) == 2, "only support 2-dim stride"
+    kernel_height, kernel_width = kernel
+    stride_height, stride_width = stride
+    batch, height, width, channel = data.shape
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (kernel_height, kernel_width))
+
+    if ceil_mode:
+        # Additional padding to ensure we do ceil instead of floor when divide stride.
+        pad_down += stride_height - 1
+        pad_right += stride_width - 1
+
+    pad_before = [0, pad_top, pad_left, 0]
+    pad_after = [0, pad_down, pad_right, 0]
+
+    out_height = util.simplify((height - kernel_height + pad_top + pad_down) // stride_height + 1)
+    out_width = util.simplify((width - kernel_width + pad_left + pad_right) // stride_width + 1)
+
+    dheight = tvm.reduce_axis((0, kernel_height))
+    dwidth = tvm.reduce_axis((0, kernel_width))
+
+    if pool_type == 'max':
+        temp = pad(data, pad_before, pad_after, name="pad_temp", \
+            pad_value=tvm.min_value(data.dtype))
+        return tvm.compute((batch, out_height, out_width, channel), \
+                            lambda n, h, w, c: \
+                            tvm.max(temp[n, h*stride_height+dheight, w*stride_width+dwidth, c], \
+                                axis=[dheight, dwidth]), \
+                            tag="pool_max")
+    elif pool_type == 'avg':
+        temp = pad(data, pad_before, pad_after, name="pad_temp", \
+            pad_value=tvm.const(0.).astype(data.dtype))
+        tsum = tvm.compute((batch, out_height, out_width, channel), \
+                            lambda n, h, w, c: \
+                            tvm.sum(temp[n, h*stride_height+dheight, w*stride_width+dwidth, c], \
+                                axis=[dheight, dwidth]), \
+                            tag="pool_avg")
+        return tvm.compute((batch, out_height, out_width, channel), \
+                            lambda n, h, w, c: \
+                            tsum[n, h, w, c] / (kernel_height*kernel_width), \
+                            tag=tag.ELEMWISE)
+    else:
+        raise ValueError("Pool type should be 'avg' or 'max'.")
diff --git a/topi/python/topi/nn/softmax.py b/topi/python/topi/nn/softmax.py
index 9060a31f5..5e73f7633 100644
--- a/topi/python/topi/nn/softmax.py
+++ b/topi/python/topi/nn/softmax.py
@@ -4,28 +4,51 @@ from __future__ import absolute_import
 import tvm
 
 @tvm.tag_scope(tag='softmax_output')
-def softmax(x):
+def softmax(x, axis=-1):
     """Perform softmax activation on the data
 
     Parameters
     ----------
     data : tvm.Tensor
-        2-D input data
+        N-D input data
+
+    axis : int
+        The axis along which the softmax is computed
 
     Returns
     -------
     output : tvm.Tensor
-        2-D output with same shape
+        N-D output with the same shape as the input
     """
-    assert len(x.shape) == 2, "only support 2-dim softmax"
-    m, n = x.shape
-    k = tvm.reduce_axis((0, n), name='k')
-    max_elem = tvm.compute((m, ), lambda i: tvm.max(x[i, k], axis=k))
-    k = tvm.reduce_axis((0, n), name='k')
-    expsum = tvm.compute(
-        (m, ), lambda i: tvm.sum(tvm.exp(x[i, k] - max_elem[i]), axis=k))
-    return tvm.compute(
-        x.shape, lambda i, j: tvm.exp(x[i, j] - max_elem[i]) / expsum[i])
+    shape = x.shape
+    if axis < 0:
+        axis = len(shape) + axis
+    if axis < 0 or axis >= len(shape):
+        raise ValueError("axis parameter should be within the input dimension range")
+
+    k1 = tvm.reduce_axis((0, shape[axis]), name='k')
+    k2 = tvm.reduce_axis((0, shape[axis]), name='k')
+
+    def insert_reduce_index(indices, reduce_index):
+        return indices[:axis] + (reduce_index,) + indices[axis:]
+
+    def _compute_max(*indices):
+        eval_range = insert_reduce_index(indices, k1)
+        return tvm.max(x[eval_range], axis=k1)
+
+    def _compute_expsum(max_elem, *indices):
+        eval_range = insert_reduce_index(indices, k2)
+        return tvm.sum(tvm.exp(x[eval_range] - max_elem[indices]), axis=k2)
+
+    def _normalize(max_elem, expsum, *indices):
+        non_reduce_indices = tuple([var for (i, var) in enumerate(indices) if i != axis])
+        return tvm.exp(x[indices] - max_elem[non_reduce_indices]) / expsum[non_reduce_indices]
+
+    reduced_shape = tuple([dim for (i, dim) in enumerate(shape) if i != axis])
+    max_elem = tvm.compute(reduced_shape, _compute_max)
+    expsum = tvm.compute(reduced_shape, lambda *indices: _compute_expsum(max_elem, *indices))
+    return tvm.compute(shape, lambda *indices: _normalize(max_elem, expsum, *indices))
+
 
 @tvm.tag_scope(tag='log_softmax_output')
 def log_softmax(x):
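For example, softmax over the channel axis of an NHWC activation tensor becomes a sketch like the following (the 4-D shape is illustrative):

    import tvm
    import topi

    # NHWC activations; normalize over the last (channel) axis
    x = tvm.placeholder((1, 7, 7, 1000), name='x')
    out = topi.nn.softmax(x, axis=-1)  # axis=-1 is mapped to axis=3 internally
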
diff --git a/topi/python/topi/nn/upsampling.py b/topi/python/topi/nn/upsampling.py
index df77bbdb2..9297eb4ad 100644
--- a/topi/python/topi/nn/upsampling.py
+++ b/topi/python/topi/nn/upsampling.py
@@ -4,10 +4,40 @@ import tvm
 from .. import util
 
 
-def upsampling(data, scale):
+def upsampling(data, scale, layout="NCHW"):
     """Perform nearest neighbor upsampling on the data.
        Bilinear upsampling is not supported.
 
+    Parameters
+    ----------
+    data : tvm.Tensor
+        4-D with shape [batch, channel, in_height, in_width]
+        or  [batch, in_height, in_width, channel]
+
+    scale: int
+        upsampling scaling factor
+
+    layout: string
+        either "NCHW" or "NHWC"
+
+    Returns
+    -------
+    output : tvm.Tensor
+        4-D with shape [batch, channel, in_height*scale, in_width*scale]
+        or [batch, in_height*scale, in_width*scale, channel]
+    """
+
+    if layout == "NCHW":
+        return upsampling_nchw(data, scale)
+    elif layout == "NHWC":
+        return upsampling_nhwc(data, scale)
+    else:
+        raise ValueError("not support this layout {} yet".format(layout))
+
+
+def upsampling_nchw(data, scale):
+    """Perform nearest neighor upsampling on NCHW layout input.
+
     Parameters
     ----------
     data : tvm.Tensor
@@ -27,3 +57,24 @@ def upsampling(data, scale):
 
     return tvm.compute((batch, channel, out_height, out_width), \
                         lambda n, c, h, w: data[n, c, h/scale, w/scale])
+
+
+def upsampling_nhwc(data, scale):
+    """Perform nearest neighor upsampling on NHWC layout input.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        4-D with shape [batch, in_height, in_width, channel]
+
+    scale: int
+        upsampling scaling factor
+
+    """
+
+    batch, height, width, channel = data.shape
+    out_height = util.simplify(height * scale)
+    out_width = util.simplify(width * scale)
+
+    return tvm.compute((batch, out_height, out_width, channel), \
+                        lambda n, h, w, c: data[n, h/scale, w/scale, c])
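And a sketch of the NHWC upsampling path (shape illustrative):

    import tvm
    import topi

    # 2x nearest neighbor upsampling of an NHWC tensor
    data = tvm.placeholder((1, 56, 56, 128), name='data')
    out = topi.nn.upsampling(data, scale=2, layout='NHWC')  # -> (1, 112, 112, 128)
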
-- 
GitLab