From 866d458c9989d3dbe98d5cb9120f59830e22bbc1 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng <mercy_zheng@sjtu.edu.cn>
Date: Tue, 30 Oct 2018 09:38:52 -0700
Subject: [PATCH] [TOPI][AUTOTVM] Improve style (#2034)

* [TOPI] Improve the style of using autotvm

* fix
---
 topi/python/topi/arm_cpu/conv2d.py | 199 ++++++++++++++---------------
 topi/python/topi/mali/conv2d.py    |  89 +++++++------
 topi/python/topi/nn/conv2d.py      |  11 --
 topi/python/topi/x86/conv2d.py     |   8 +-
 4 files changed, 150 insertions(+), 157 deletions(-)

diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py
index a193e9acf..c34bf2567 100644
--- a/topi/python/topi/arm_cpu/conv2d.py
+++ b/topi/python/topi/arm_cpu/conv2d.py
@@ -12,34 +12,40 @@ from ..util import traverse_inline, get_const_tuple, const_matrix
 from ..nn import pad, conv2d, conv2d_alter_layout, conv2d_winograd_without_weight_transform
 from ..nn.util import get_const_int, get_pad_tuple
 
-def _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype):
-    """convert argument to workload"""
-    if len(kernel.shape) == 4:
-        raw_kernel = kernel
-    else:  # the input kernel is transformed by alter_op_layout
-        shape = get_const_tuple(kernel.shape)
-        raw_kernel = tvm.placeholder((shape[0] * shape[4], shape[1], shape[2], shape[3]),
-                                     dtype=kernel.dtype)
-    return ('conv2d', ) + autotvm.task.args_to_workload(
-        [data, raw_kernel, strides, padding, layout, out_dtype])
-
-@conv2d.register('arm_cpu')
-@autotvm.task.dispatcher
-def conv2d_arm_cpu(data, kernel, strides, padding, layout, out_dtype):
-    """TOPI compute callback. Mark this function as a dispatcher, so
-    this template can assign config according to workload
+@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct'])
+def conv2d_arm_cpu(cfg, data, kernel, strides, padding, layout, out_dtype):
+    """TOPI compute callback for conv2d
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.Tensor
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    kernel : tvm.Tensor
+        4-D with shape [num_filter, in_channel, filter_height, filter_width] or
+        pre-packed 5-D with shape [num_filter_chunk, in_channel, filter_height,
+        filter_width, num_filter_block]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    layout : str
+        layout of data
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
 
     Returns
     -------
-    workload: Tuple
-        Dispatcher will use this workload to query corresponding config.
-        Then use cfg.template_key to call a registered template.
+    output : tvm.Tensor
+        4-D with shape [batch, out_channel, out_height, out_width]
     """
-    return _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype)
-
-@conv2d_arm_cpu.register(['direct'])
-def decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype):
-    """spatial packing template"""
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile=2)
 
 @autotvm.register_topi_schedule(schedule_conv2d_nchw, 'arm_cpu', ['direct', 'winograd'])
@@ -93,8 +99,6 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs):
 def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile):
     assert layout == "NCHW", "Only support NCHW"
     # create workload according to raw arguments
-    wkl = _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype)
-
     out_dtype = out_dtype or data.dtype
     N, CI, IH, IW = get_const_tuple(data.shape)
     if len(kernel.shape) == 4:
@@ -177,8 +181,7 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n
 
     output = tvm.compute(oshape, lambda n, co, h, w:
                          conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
-                         name='output_unpack', tag='spatial_conv2d_output',
-                         attrs={'workload': wkl})
+                         name='output_unpack', tag='spatial_conv2d_output')
     return output
 
 def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
@@ -238,16 +241,13 @@ def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
     return s
 
 
-@conv2d_arm_cpu.register('winograd')
-def decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype):
+@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd'])
+def conv2d_arm_cpu_winograd(cfg, data, kernel, strides, padding, layout, out_dtype):
+    """TOPI compute callback for conv2d; uses the winograd template with tile_size 4."""
     tile_size = 4
     return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size)
 
 def _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
-    # create workload according to raw arguments
-    wkl = _winograd_conv_arg_to_workload(data, kernel, strides, padding, layout,
-                                         out_dtype, tile_size)
-
     N, CI, IH, IW = get_const_tuple(data.shape)
     if len(kernel.shape) == 4:
         pre_computed = False
@@ -368,8 +368,7 @@ def _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_
     # unpack output
     output = tvm.compute((N, K, H, W), lambda n, k, h, w:
                          Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m],
-                         name='output', tag='winograd_conv2d_output',
-                         attrs={'workload': wkl})
+                         name='output', tag='winograd_conv2d_output')
 
     # we have to manually assign effective GFLOP for winograd
     cfg.add_flop(2 * N * K * H * W * KH * KW * C)
@@ -458,36 +457,11 @@ def _schedule_winograd(cfg, s, output, last):
         s[output].compute_inline()
 
 
-def _winograd_conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype, tile_size):
-    """convert argument to workload"""
-    K = 3
-    shape = get_const_tuple(kernel.shape)
-    alpha = tile_size + K - 1
-    if len(kernel.shape) == 4:
-        assert shape[2:] == (K, K)
-        CO, CI = shape[:2]
-    else:
-        assert shape[:2] == (alpha, alpha)
-        CO, CI, VCO = shape[2:]
-        CO *= VCO
-
-    raw_kernel = tvm.placeholder((CO, CI, K, K), dtype=kernel.dtype)
-    return ('conv2d', ) + autotvm.task.args_to_workload(
-        [data, raw_kernel, strides, padding, layout, out_dtype])
-
-
 ##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM #####
-@conv2d_winograd_without_weight_transform.register(['arm_cpu'])
-@autotvm.task.dispatcher
-def winograd_ww_config_dispatcher_(data, kernel, strides, padding, layout, out_dtype, tile_size):
-    return _winograd_conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype,
-                                          tile_size)
-
-
-@winograd_ww_config_dispatcher_.register(['winograd'])
-def decl_winograd_ww(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
-    return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype,
-                          tile_size)
+@autotvm.register_topi_compute(conv2d_winograd_without_weight_transform, 'arm_cpu', ['winograd'])
+def conv2d_winograd_ww(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
+    """TOPI compute callback for conv2d_winograd_without_weight_transform (winograd template)"""
+    return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size)
 
 
 @autotvm.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform,
@@ -514,8 +488,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos):
 
     new_attrs = {k: attrs[k] for k in attrs.keys()}
 
-    assert attrs.get_int_tuple("dilation") == (1, 1), "Does not support dilation " \
-                                                      "when alter_op_layout is enabled"
+    dilation = attrs.get_int_tuple("dilation")
     strides = attrs.get_int_tuple("strides")
     padding = attrs.get_int_tuple("padding")
     groups = attrs.get_int('groups')
@@ -523,38 +496,60 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos):
     out_dtype = attrs["out_dtype"]
     out_dtype = tinfos[0].dtype if out_dtype == "same" else out_dtype
 
-    if groups == 1:
-        # query config of this workload
-        workload = _conv_arg_to_workload(tinfos[0], tinfos[1], strides, padding,
-                                         layout, out_dtype)
-        cfg = autotvm.DispatchContext.current.query(tvm.target.current_target(), workload)
-
-        if cfg.is_fallback:  # if is fallback, clear query cache and return None
-            autotvm.task.clear_fallback_cache(tvm.target.current_target(), workload)
-            return None
-
-        if cfg.template_key == 'direct':  # packing weight tensor
-            new_attrs['kernel_layout'] = 'OIHW%do' % (cfg['tile_co'].size[-1])
-            return sym.conv2d(*copy_inputs, **new_attrs)
-        else:  # pre-compute weight transformation in winograd
-            if "-device=arm_cpu" in tvm.target.current_target().options:
-                tile_size = 4
-                VC = cfg['tile_k'].size[-1]
-            else:
-                from ..mali.conv2d import _pick_tile_size
-                tile_size = _pick_tile_size(tinfos[0], tinfos[1])
-                VC = cfg['tile_bna'].val
-
-            weight = sym.contrib.conv2d_winograd_weight_transform(copy_inputs[1],
-                                                                  tile_size=tile_size)
-            CO, CI, KH, KW = get_const_tuple(tinfos[1].shape)
-            weight = sym.reshape(weight,
-                                 shape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
-            weight = sym.transpose(weight, axes=[0, 1, 2, 4, 3])
-
-            copy_inputs[1] = weight
-            new_attrs['tile_size'] = tile_size
-            return sym.contrib.conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
-
-    # do nothing for depthwise convolution
-    return None
+    if layout != 'NCHW' or groups != 1 or dilation != (1, 1):
+        return None
+
+    data, kernel = tinfos[0:2]
+    N, CI, H, W = get_const_tuple(data.shape)
+    CO, _, KH, KW = get_const_tuple(kernel.shape)
+
+    # query config of this workload
+    workload = autotvm.task.args_to_workload(
+        [data, kernel, strides, padding, layout, out_dtype], conv2d)
+    target = tvm.target.current_target()
+    dispatch_ctx = autotvm.DispatchContext.current
+    cfg = dispatch_ctx.query(target, workload)
+
+    if cfg.is_fallback:  # if is fallback, clear query cache and return None
+        autotvm.task.clear_fallback_cache(target, workload)
+        return None
+
+    if cfg.template_key == 'direct':  # pack weight tensor
+        VC = cfg['tile_co'].size[-1]
+        new_attrs['kernel_layout'] = 'OIHW%do' % VC
+
+        # Store the same config for the altered operator (workload)
+        new_data = data
+        new_kernel = tvm.placeholder((CO // VC, CI, KH, KW, VC), dtype=kernel.dtype)
+        new_workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, strides, padding, 'NCHW', out_dtype], conv2d)
+        dispatch_ctx.update(target, new_workload, cfg)
+
+        return sym.conv2d(*copy_inputs, **new_attrs)
+    else:  # pre-compute weight transformation in winograd
+        if "-device=arm_cpu" in target.options:
+            tile_size = 4
+            VC = cfg['tile_k'].size[-1]
+        else:
+            from ..mali.conv2d import _pick_tile_size
+            tile_size = _pick_tile_size(tinfos[0], tinfos[1])
+            VC = cfg['tile_bna'].val
+
+        weight = sym.contrib.conv2d_winograd_weight_transform(copy_inputs[1], tile_size=tile_size)
+        weight = sym.reshape(weight,
+                             shape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
+        weight = sym.transpose(weight, axes=[0, 1, 2, 4, 3])
+
+        copy_inputs[1] = weight
+        new_attrs['tile_size'] = tile_size
+
+        # Store the same config for the altered workload (shape = reshaped + transposed weight)
+        new_data = data
+        new_weight = tvm.placeholder((KH + tile_size - 1, KW + tile_size - 1, CO // VC, CI, VC),
+                                     kernel.dtype)
+        new_workload = autotvm.task.args_to_workload(
+            [new_data, new_weight, strides, padding, new_attrs['layout'], out_dtype, tile_size],
+            conv2d_winograd_without_weight_transform)
+        dispatch_ctx.update(target, new_workload, cfg)
+
+        return sym.contrib.conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
diff --git a/topi/python/topi/mali/conv2d.py b/topi/python/topi/mali/conv2d.py
index 121498f21..390b60ba6 100644
--- a/topi/python/topi/mali/conv2d.py
+++ b/topi/python/topi/mali/conv2d.py
@@ -12,27 +12,43 @@ from ..nn import conv2d, conv2d_winograd_without_weight_transform, \
     get_pad_tuple, pad, conv2d_alter_layout
 
 # reuse some compute declarations from ARM CPU
-from ..arm_cpu.conv2d import _conv_arg_to_workload, _decl_spatial_pack,\
-    _winograd_conv_arg_to_workload, _alter_conv2d_layout_arm
+from ..arm_cpu.conv2d import _decl_spatial_pack, _alter_conv2d_layout_arm
 
 
-@conv2d.register('mali')
-@autotvm.task.dispatcher
-def conv2d_mali(data, kernel, strides, padding, layout, out_dtype):
-    """TOPI compute callback. Mark this function as a dispatcher, so
-    this template can assign config according to workload
+@autotvm.register_topi_compute(conv2d, 'mali', ['direct'])
+def conv2d_mali(cfg, data, kernel, strides, padding, layout, out_dtype):
+    """TOPI compute callback for conv2d
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.Tensor
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    kernel : tvm.Tensor
+        4-D with shape [num_filter, in_channel, filter_height, filter_width] or
+        pre-packed 5-D with shape [num_filter_chunk, in_channel, filter_height,
+        filter_width, num_filter_block]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    layout : str
+        layout of data
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
 
     Returns
     -------
-    workload: Tuple
-        Dispatcher will use this workload to query corresponding config.
-        Then use cfg.template_key to call a registered template.
+    output : tvm.Tensor
+        4-D with shape [batch, out_channel, out_height, out_width]
     """
-    return _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype)
-
-@conv2d_mali.register(['direct'])
-def decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype):
-    """spatial packing template"""
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile=3)
 
 @autotvm.register_topi_schedule(schedule_conv2d_nchw, 'mali', ['direct', 'winograd'])
@@ -158,8 +174,8 @@ def _pick_tile_size(data, kernel):
     else:
         return 2
 
-@conv2d_mali.register('winograd')
-def decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype):
+@autotvm.register_topi_compute(conv2d, 'mali', ['winograd'])
+def conv2d_mali_winograd(cfg, data, kernel, strides, padding, layout, out_dtype):
     tile_size = _pick_tile_size(data, kernel)
     return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size)
 
@@ -305,9 +321,7 @@ def _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_
                          # thw following term is used to make the padding effective,
                          # otherwise the padding will be eliminated by bound inference
                          + tvm.const(0, out_dtype) * M[alpha-1][alpha-1][CO-1][P_round-1],
-                         name='output', tag='winograd_conv2d_output',
-                         attrs={'workload': _winograd_conv_arg_to_workload(
-                             data, kernel, strides, padding, layout, out_dtype, tile_size)})
+                         name='output', tag='winograd_conv2d_output')
 
     # we have to manually assign effective GFLOP for winograd
     cfg.add_flop(2 * N * CO * H * W * KH * KW * CI)
@@ -410,29 +424,15 @@ def _schedule_winograd(cfg, s, op):
 
     s[Y].compute_at(s[output], tt)
 
-@conv2d_alter_layout.register(["mali"])
-def _alter_conv2d_layout(attrs, inputs, tinfos):
-    try:
-        return _alter_conv2d_layout_arm(attrs, inputs, tinfos)
-    except KeyError:  # to filter out fallback opencl templates
-        return None
-
 ##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM #####
-@conv2d_winograd_without_weight_transform.register(['mali'])
-@autotvm.task.dispatcher
-def winograd_ww_config_dispatcher_(data, kernel, strides, padding, layout, out_dtype, tile_size):
-    return _winograd_conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype,
-                                          tile_size)
-
-
-@winograd_ww_config_dispatcher_.register(['winograd'])
-def decl_winograd_ww(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
-    return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype,
-                          tile_size)
+@autotvm.register_topi_compute(conv2d_winograd_without_weight_transform, 'mali', ['winograd'])
+def conv2d_winograd_ww(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
+    """TOPI compute callback for conv2d_winograd_without_weight_transform (winograd template)"""
+    return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size)
 
 
-@autotvm.task.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform,
-                                     'mali', ['winograd'])
+@autotvm.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform,
+                                'mali', ['winograd'])
 def schedule_conv2d_winograd_without_weight_transform_(cfg, outs):
     """TOPI schedule callback"""
     s = tvm.create_schedule([x.op for x in outs])
@@ -445,6 +445,15 @@ def schedule_conv2d_winograd_without_weight_transform_(cfg, outs):
     return s
 
 
+##### REGISTER ALTER OP LAYOUT #####
+@conv2d_alter_layout.register(["mali"])
+def _alter_conv2d_layout(attrs, inputs, tinfos):
+    try:
+        return _alter_conv2d_layout_arm(attrs, inputs, tinfos)
+    except KeyError:  # to filter out fallback opencl templates
+        return None
+
+
 ##### SCHECULE UTILITIES #####
 def tile_and_bind(s, tensor, y, x, y_factor, x_factor=None):
     """ tile and bind to GPU threads """
diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py
index 7636350df..17b1ceb7a 100644
--- a/topi/python/topi/nn/conv2d.py
+++ b/topi/python/topi/nn/conv2d.py
@@ -85,17 +85,6 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
     return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
 
 
-@tvm.target.generic_func
-def _get_schedule(wkl):
-    # pylint: disable=unreachable
-    """ Get the platform specific schedule. """
-    target = tvm.target.current_target()
-    raise RuntimeError(
-        "No schedule for current target:{}".format(target))
-    # This return has no use, merely to supress pylint warning
-    return wkl
-
-
 def conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
     """Convolution operator in NCHW layout.
 
diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py
index c588e7443..3dc6d5e4b 100644
--- a/topi/python/topi/x86/conv2d.py
+++ b/topi/python/topi/x86/conv2d.py
@@ -3,7 +3,7 @@
 import tvm
 from tvm import autotvm
 from tvm.autotvm.task.nnvm_integration import deserialize_args
-from tvm.autotvm.task import register, get_config
+from tvm.autotvm.task import get_config
 from .. import generic, tag
 from .. import nn
 from ..util import get_const_tuple
@@ -145,7 +145,7 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, layout, out_dtyp
     return unpack
 
 
-@autotvm.task.register_topi_schedule(generic.schedule_conv2d_nchw, 'cpu', ['direct'])
+@autotvm.register_topi_schedule(generic.schedule_conv2d_nchw, 'cpu', ['direct'])
 def schedule_conv2d(cfg, outs):
     """Create schedule for tensors"""
     s = tvm.create_schedule([x.op for x in outs])
@@ -248,7 +248,7 @@ def schedule_conv2d_nhwc(outs):
 # We define schedule template in this function instead of
 # declaration function since actual input arguments need
 # to be altered by the schedule selected.
-@register("topi_x86_conv2d_NCHWc")
+@autotvm.task.register("topi_x86_conv2d_NCHWc")
 def _topi_nn_conv2d_NCHWc(*args, **kwargs):
     assert not kwargs, "Do not support kwargs in template function call"
     data, kernel, strides, padding, origin_layout, dtype = deserialize_args(args)
@@ -311,7 +311,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo):
     # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
     new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
 
-    # Store altered operator's config
+    # Store the same config for the altered operator (workload)
     new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
                                dtype=data.dtype)
     new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn),
-- 
GitLab