From c6a1241e74a4901ab95ba6c3806a474f2474c029 Mon Sep 17 00:00:00 2001
From: ziheng <ziheng@apache.org>
Date: Tue, 14 Nov 2017 10:11:29 -0800
Subject: [PATCH] [TOPI] Add out_dtype argument for conv2d; Add x86 schedules
 (#646)

* [TOPI] Add out_dtype argument for conv2d; Add x86 schedules

* Fix

* Fix lint

* Fix
---
 topi/python/topi/__init__.py              |  1 +
 topi/python/topi/nn/conv2d.py             | 80 ++++++++++++-----------
 topi/python/topi/nn/depthwise_conv2d.py   |  6 +-
 topi/python/topi/rasp/conv2d.py           | 19 +++---
 topi/python/topi/rasp/depthwise_conv2d.py | 28 ++++----
 topi/python/topi/x86/__init__.py          |  5 ++
 topi/python/topi/x86/conv2d.py            | 37 +++++++++++
 7 files changed, 111 insertions(+), 65 deletions(-)
 create mode 100644 topi/python/topi/x86/__init__.py
 create mode 100644 topi/python/topi/x86/conv2d.py

diff --git a/topi/python/topi/__init__.py b/topi/python/topi/__init__.py
index 1306f9d9c..62a9ae153 100644
--- a/topi/python/topi/__init__.py
+++ b/topi/python/topi/__init__.py
@@ -14,6 +14,7 @@ from .reduction import *
 from .transform import *
 from .broadcast import *
 from . import nn
+from . import x86
 from . import cuda
 from . import rasp
 from . import testing
diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py
index db3a6079f..cc1ee0198 100644
--- a/topi/python/topi/nn/conv2d.py
+++ b/topi/python/topi/nn/conv2d.py
@@ -9,7 +9,7 @@ from ..util import simplify
 
 # workload description of conv2d
 Workload = namedtuple('Workload',
-                      ['height', 'width', 'in_filter', 'out_filter',
+                      ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'out_filter',
                        'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
 
 # schedule description of spatial
@@ -22,36 +22,36 @@ Im2ColPack = namedtuple('Im2ColPack',
 
 _WORKLOADS = [
     # workloads of resnet18 on imagenet
-    Workload(224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
-    Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
-    Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
-    Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
-    Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
-    Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
-    Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
-    Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
-    Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
-    Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
-    Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
-    Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
+    Workload('float32', 'float32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
+    Workload('float32', 'float32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
+    Workload('float32', 'float32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
+    Workload('float32', 'float32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
+    Workload('float32', 'float32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
+    Workload('float32', 'float32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
+    Workload('float32', 'float32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
+    Workload('float32', 'float32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
+    Workload('float32', 'float32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
+    Workload('float32', 'float32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
+    Workload('float32', 'float32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
     # workloads of mobile net on imagenet
-    Workload(224, 224, 3, 32, 3, 3, 1, 1, 2, 2),
-    Workload(112, 112, 32, 64, 1, 1, 0, 0, 1, 1),
-    Workload(56, 56, 64, 128, 1, 1, 0, 0, 1, 1),
-    Workload(56, 56, 128, 128, 1, 1, 0, 0, 1, 1),
-    Workload(28, 28, 128, 256, 1, 1, 0, 0, 1, 1),
-    Workload(28, 28, 256, 256, 1, 1, 0, 0, 1, 1),
-    Workload(14, 14, 256, 512, 1, 1, 0, 0, 1, 1),
-    Workload(14, 14, 512, 512, 1, 1, 0, 0, 1, 1),
-    Workload(7, 7, 512, 1024, 1, 1, 0, 0, 1, 1),
-    Workload(7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 224, 224, 3, 32, 3, 3, 1, 1, 2, 2),
+    Workload('float32', 'float32', 112, 112, 32, 64, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 56, 56, 64, 128, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 56, 56, 128, 128, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 28, 28, 128, 256, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 28, 28, 256, 256, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 14, 14, 256, 512, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 14, 14, 512, 512, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1),
 ]
 
 # platform specific schedule
 _CONV_SCHEDULE = {}
 
 @tvm.target.generic_func
-def conv2d(data, kernel, stride, padding, layout='NCHW'):
+def conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
     """Conv2D operator.
 
     Parameters
@@ -79,14 +79,14 @@ def conv2d(data, kernel, stride, padding, layout='NCHW'):
     # search platform specific declaration first
     # default declaration
     if layout == 'NCHW':
-        return conv2d_nchw(data, kernel, stride, padding)
+        return conv2d_nchw(data, kernel, stride, padding, out_dtype)
     elif layout == 'HWCN':
-        return conv2d_hwcn(data, kernel, stride, padding)
+        return conv2d_hwcn(data, kernel, stride, padding, out_dtype)
     else:
         raise ValueError("not support this layout {} yet".format(layout))
 
 
-def _get_workload(data, kernel, stride, padding):
+def _get_workload(data, kernel, stride, padding, out_dtype):
     """ Get the workload structure. """
     _, CI, IH, IW = [x.value for x in data.shape]
     CO, _, KH, KW = [x.value for x in kernel.shape]
@@ -95,7 +95,8 @@ def _get_workload(data, kernel, stride, padding):
         HSTR, WSTR = stride
     else:
         HSTR, WSTR = stride, stride
-    return Workload(IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
+    assert data.dtype == kernel.dtype, "Do not support inputs with different data types now."
+    return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
 
 
 @tvm.target.generic_func
@@ -108,10 +109,10 @@ def _get_schedule(wkl):
     # This return has no use, merely to supress pylint warning
     return wkl
 
-def _spatial_pack(data, kernel, stride, padding):
+def _spatial_pack(data, kernel, stride, padding, out_dtype):
     """ Compute convolution with pack on spatial axes. """
     assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
-    wkl = _get_workload(data, kernel, stride, padding)
+    wkl = _get_workload(data, kernel, stride, padding, out_dtype)
     sch = _get_schedule(wkl)
 
     H, W = wkl.height, wkl.width
@@ -158,8 +159,8 @@ def _spatial_pack(data, kernel, stride, padding):
     dw = tvm.reduce_axis((0, KW), name='dw')
 
     conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
-        tvm.sum(data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw] *
-                kernel_vec[co, ci, dh, dw, vc],
+        tvm.sum(data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw].astype(out_dtype) *
+                kernel_vec[co, ci, dh, dw, vc].astype(out_dtype),
                 axis=[ci, dh, dw]), name='conv')
 
     output = tvm.compute(oshape, lambda n, co, h, w:
@@ -169,10 +170,10 @@ def _spatial_pack(data, kernel, stride, padding):
     return output
 
 
-def _im2col_pack(data, kernel, stride, padding):
+def _im2col_pack(data, kernel, stride, padding, out_dtype):
     """ Compute convolution with im2col pack layout. """
     assert data.shape[0].value == 1, "im2col pack convolution only support batch size=1"
-    wkl = _get_workload(data, kernel, stride, padding)
+    wkl = _get_workload(data, kernel, stride, padding, out_dtype)
     sch = _get_schedule(wkl)
 
     N = 1
@@ -234,7 +235,7 @@ def _im2col_pack(data, kernel, stride, padding):
     return output
 
 
-def conv2d_nchw(Input, Filter, stride, padding):
+def conv2d_nchw(Input, Filter, stride, padding, out_dtype='float32'):
     """Convolution operator in NCHW layout.
 
     Parameters
@@ -280,11 +281,12 @@ def conv2d_nchw(Input, Filter, stride, padding):
     return tvm.compute(
         (batch, out_channel, out_height, out_width),
         lambda nn, ff, yy, xx: tvm.sum(
-            temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx] * Filter[ff, rc, ry, rx],
+            temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx].astype(out_dtype) *
+            Filter[ff, rc, ry, rx].astype(out_dtype),
             axis=[rc, ry, rx]), tag="conv2d_nchw")
 
 
-def conv2d_hwcn(Input, Filter, stride, padding):
+def conv2d_hwcn(Input, Filter, stride, padding, out_dtype='float32'):
     """Convolution operator in HWCN layout.
 
     Parameters
@@ -329,8 +331,8 @@ def conv2d_hwcn(Input, Filter, stride, padding):
     Output = tvm.compute(
         (out_height, out_width, out_channel, batch),
         lambda yy, xx, ff, nn: tvm.sum(
-            PaddedInput[yy * stride_h + ry, xx * stride_w + rx, rc, nn] * Filter[ry, rx, rc, ff],
-            axis=[ry, rx, rc]),
+            PaddedInput[yy * stride_h + ry, xx * stride_w + rx, rc, nn].astype(out_dtype) *
+            Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]),
         name="Conv2dOutput", tag="conv2d_hwcn")
     return Output
 
diff --git a/topi/python/topi/nn/depthwise_conv2d.py b/topi/python/topi/nn/depthwise_conv2d.py
index 40aed1572..785bdab27 100644
--- a/topi/python/topi/nn/depthwise_conv2d.py
+++ b/topi/python/topi/nn/depthwise_conv2d.py
@@ -9,7 +9,7 @@ from .util import get_pad_tuple
 from ..util import simplify
 
 
-def depthwise_conv2d_nchw(Input, Filter, stride, padding):
+def depthwise_conv2d_nchw(Input, Filter, stride, padding, out_dtype='float32'):
     """Depthwise convolution nchw forward operator.
 
     Parameters
@@ -51,8 +51,8 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding):
     Output = tvm.compute(
         (batch, out_channel, out_height, out_width),
         lambda b, c, i, j: tvm.sum(
-            (PaddedInput[b, c/channel_multiplier, i*stride_h + di, j*stride_w + dj] *
-             Filter[c/channel_multiplier, c%channel_multiplier, di, dj]),
+            (PaddedInput[b, c/channel_multiplier, i*stride_h+di, j*stride_w+dj].astype(out_dtype) *
+             Filter[c/channel_multiplier, c%channel_multiplier, di, dj].astype(out_dtype)),
             axis=[di, dj]),
         name='DepthwiseConv2d', tag="depthwise_conv2d_nchw")
     return Output
diff --git a/topi/python/topi/rasp/conv2d.py b/topi/python/topi/rasp/conv2d.py
index 86cb8a9d0..6e5a1b335 100644
--- a/topi/python/topi/rasp/conv2d.py
+++ b/topi/python/topi/rasp/conv2d.py
@@ -12,6 +12,7 @@ from ..nn.util import infer_pad, infer_stride
 from .. import generic
 
 _SCHEDULES = [
+    # float32 imagenet
     SpatialPack(1, 8, 4, 1, 4, True),
     SpatialPack(1, 7, 4, 2, 4, True),
     SpatialPack(1, 4, 8, 4, 1, True),
@@ -25,6 +26,7 @@ _SCHEDULES = [
     Im2ColPack(7, 4, 1, 8, False),
     Im2ColPack(7, 4, 1, 16, False),
 
+    # float32 mobilenet
     SpatialPack(2, 2, 4, 28, 1, True),
     SpatialPack(1, 4, 8, 14, 1, False),
     SpatialPack(1, 2, 16, 8, 1, True),
@@ -47,12 +49,12 @@ def _schedule_conv2d(wkl):
 
 
 @conv2d.register("rasp")
-def _declaration_conv2d(data, kernel, stride, padding, layout):
+def _declaration_conv2d(data, kernel, stride, padding, layout, out_dtype):
     assert layout == 'NCHW', "only support NCHW convolution on rasp"
     assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
-    wkl = _get_workload(data, kernel, stride, padding)
+    wkl = _get_workload(data, kernel, stride, padding, out_dtype)
     sch = _get_schedule(wkl)
-    return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding)
+    return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, out_dtype)
 
 
 def _schedule_spatial_conv2d(s, data, data_pad, data_vec,
@@ -64,10 +66,8 @@ def _schedule_spatial_conv2d(s, data, data_pad, data_vec,
         stride = infer_stride(data, kernel, output)
     else:
         stride = infer_stride(data_pad, kernel, output)
-    wkl = _get_workload(data, kernel, stride, padding)
-
-    with tvm.target.rasp():
-        sch = _get_schedule(wkl)
+    wkl = _get_workload(data, kernel, stride, padding, output.dtype)
+    sch = _get_schedule(wkl)
 
     H, W = wkl.height, wkl.width
     CI, CO = wkl.in_filter, wkl.out_filter
@@ -172,7 +172,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
         stride = infer_stride(data, kernel, output)
     else:
         stride = infer_stride(data_pad, kernel, output)
-    wkl = _get_workload(data, kernel, stride, padding)
+    wkl = _get_workload(data, kernel, stride, padding, output.dtype)
 
     with _target.rasp():
         sch = _get_schedule(wkl)
@@ -280,7 +280,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
 
     return s
 
-@generic.schedule_conv2d_nchw.register(["cpu", "rasp"])
+@generic.schedule_conv2d_nchw.register(["rasp"])
 def schedule_conv2d(outs):
     """Create schedule for tensors"""
     s = tvm.create_schedule([x.op for x in outs])
@@ -294,6 +294,7 @@ def schedule_conv2d(outs):
             for tensor in op.input_tensors:
                 if tensor.op.input_tensors:
                     traverse(tensor.op)
+
         if 'spatial_conv_output' in op.tag:
             output = op.output(0)
             conv_out = op.input_tensors[0]
diff --git a/topi/python/topi/rasp/depthwise_conv2d.py b/topi/python/topi/rasp/depthwise_conv2d.py
index e695f0463..a6fd691f8 100644
--- a/topi/python/topi/rasp/depthwise_conv2d.py
+++ b/topi/python/topi/rasp/depthwise_conv2d.py
@@ -8,22 +8,22 @@ from ..nn.util import infer_pad, infer_stride, get_pad_tuple
 from .. import generic
 
 _Workload = namedtuple('Workload',
-                       ['height', 'width', 'channel', 'multiplier',
+                       ['in_dtype', 'out_dtype', 'height', 'width', 'channel', 'multiplier',
                         'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
 
 _Schedule = namedtuple('Schedule', ['vh', 'vw', 'vc', 'bc', 'unroll'])
 
 # workloads of depthwise conv mobile net on imagenet
 _WORKLOADS = [
-    _Workload(112, 112, 32, 1, 3, 3, 1, 1, 1, 1),
-    _Workload(112, 112, 64, 1, 3, 3, 1, 1, 2, 2),
-    _Workload(56, 56, 128, 1, 3, 3, 1, 1, 1, 1),
-    _Workload(56, 56, 128, 1, 3, 3, 1, 1, 2, 2),
-    _Workload(28, 28, 256, 1, 3, 3, 1, 1, 1, 1),
-    _Workload(28, 28, 256, 1, 3, 3, 1, 1, 2, 2),
-    _Workload(14, 14, 512, 1, 3, 3, 1, 1, 1, 1),
-    _Workload(14, 14, 512, 1, 3, 3, 1, 1, 2, 2),
-    _Workload(14, 14, 1024, 1, 3, 3, 1, 1, 1, 1),
+    _Workload('float32', 'float32', 112, 112, 32, 1, 3, 3, 1, 1, 1, 1),
+    _Workload('float32', 'float32', 112, 112, 64, 1, 3, 3, 1, 1, 2, 2),
+    _Workload('float32', 'float32', 56, 56, 128, 1, 3, 3, 1, 1, 1, 1),
+    _Workload('float32', 'float32', 56, 56, 128, 1, 3, 3, 1, 1, 2, 2),
+    _Workload('float32', 'float32', 28, 28, 256, 1, 3, 3, 1, 1, 1, 1),
+    _Workload('float32', 'float32', 28, 28, 256, 1, 3, 3, 1, 1, 2, 2),
+    _Workload('float32', 'float32', 14, 14, 512, 1, 3, 3, 1, 1, 1, 1),
+    _Workload('float32', 'float32', 14, 14, 512, 1, 3, 3, 1, 1, 2, 2),
+    _Workload('float32', 'float32', 7, 7, 1024, 1, 3, 3, 1, 1, 1, 1),
 ]
 
 _SCHEDULES = [
@@ -35,10 +35,10 @@ _SCHEDULES = [
     _Schedule(1, 1, 4, 2, True),
     _Schedule(1, 1, 8, 8, True),
     _Schedule(1, 1, 4, 1, False),
-    _Schedule(2, 1, 4, 16, False),
+    _Schedule(1, 1, 4, 4, False),
 ]
 
-def _get_workload(data, kernel, stride, padding):
+def _get_workload(data, kernel, stride, padding, out_dtype):
     _, C, IH, IW = [x.value for x in data.shape]
     _, MT, KH, KW = [x.value for x in kernel.shape]
     HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
@@ -46,7 +46,7 @@ def _get_workload(data, kernel, stride, padding):
         HSTR, WSTR = stride
     else:
         HSTR, WSTR = stride, stride
-    return _Workload(IH, IW, C, MT, KH, KW, HPAD, WPAD, HSTR, WSTR)
+    return _Workload(data.dtype, out_dtype, IH, IW, C, MT, KH, KW, HPAD, WPAD, HSTR, WSTR)
 
 
 def _schedule(s, data, data_pad, kernel, output, last):
@@ -55,7 +55,7 @@ def _schedule(s, data, data_pad, kernel, output, last):
         stride = infer_stride(data, kernel, output)
     else:
         stride = infer_stride(data_pad, kernel, output)
-    wkl = _get_workload(data, kernel, stride, padding)
+    wkl = _get_workload(data, kernel, stride, padding, output.dtype)
 
     if wkl not in _WORKLOADS:
         return s
diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py
new file mode 100644
index 000000000..d9912de28
--- /dev/null
+++ b/topi/python/topi/x86/__init__.py
@@ -0,0 +1,5 @@
+# pylint: disable=redefined-builtin, wildcard-import
+"""x86 specific declaration and schedules."""
+from __future__ import absolute_import as _abs
+
+from .conv2d import schedule_conv2d
diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py
new file mode 100644
index 000000000..0c91f8c25
--- /dev/null
+++ b/topi/python/topi/x86/conv2d.py
@@ -0,0 +1,37 @@
+# pylint: disable=invalid-name,unused-variable,invalid-name
+"""Conv2D schedule on x86"""
+import tvm
+from .. import generic
+from .. import tag
+
+@generic.schedule_conv2d_nchw.register(["cpu"])
+def schedule_conv2d(outs):
+    """Create schedule for tensors"""
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def traverse(op):
+        """Traverse operators from computation graph"""
+        # inline all one-to-one-mapping operators except the last stage (output)
+        if tag.is_broadcast(op.tag):
+            if op not in s.outputs:
+                s[op].compute_inline()
+            for tensor in op.input_tensors:
+                if tensor.op.input_tensors:
+                    traverse(tensor.op)
+
+        if 'conv2d_nchw' in op.tag:
+            conv = op.output(0)
+            kernel = op.input_tensors[1]
+            data = op.input_tensors[0]
+            data_pad = None
+            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
+                data_pad = data
+                data = data_pad.op.input_tensors[0]
+
+            C = conv
+            n, c, h, w = C.op.axis
+            s[C].parallel(c)
+            s[C].pragma(n, "parallel_launch_point")
+
+    traverse(outs[0].op)
+    return s
-- 
GitLab