diff --git a/nnvm/src/top/nn/convolution.cc b/nnvm/src/top/nn/convolution.cc
index 8139474921175e1c3a5a46baf5659e6bfcca008f..e6ff72239672054ea964e7f3ced1e0e613c61ad1 100644
--- a/nnvm/src/top/nn/convolution.cc
+++ b/nnvm/src/top/nn/convolution.cc
@@ -73,15 +73,13 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(param.channels % param.groups, 0U)
       << "output channels must divide group size";
 
-  TShape wshape({param.channels / param.groups,
+  TShape wshape({param.channels,
                  dshape[1] / param.groups,
                  param.kernel_size[0],
                  param.kernel_size[1]});
 
   wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
 
-  wshape[kernel_layout.indexof('O')] *= param.groups;
-
   if (in_shape->at(Conv2DParam::kWeight).ndim() == 0) {
     NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape);
   }
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index 608cdab2bacb9834c78004df2823abbb0bb3f2a9..53098b71ff7733f4855893af38572336e5b557d4 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -52,12 +52,11 @@ bool Conv2DRel(const Array<Type>& types,
     CHECK_EQ(param->kernel_size.size(), 2);
     CHECK_EQ(param->dilation.size(), 2);
     std::vector<IndexExpr> wshape(
-       {param->channels / param->groups,
+       {param->channels,
          dshape_nchw[1] / param->groups,
          param->kernel_size[0],
          param->kernel_size[1]});
     wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
-    wshape[kernel_layout.Indexof('O')] *= param->groups;
     channels = param->channels;
     dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
     dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py
index 017c62b77a7b346e54420ac51a248f314c30c3e4..605749d460f717cb377c1a3d0e9e543db276e9ed 100644
--- a/topi/python/topi/arm_cpu/conv2d.py
+++ b/topi/python/topi/arm_cpu/conv2d.py
@@ -11,7 +11,8 @@ from tvm import autotvm
 
 from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform
 from ..util import traverse_inline, get_const_tuple, const_matrix
-from ..nn import dilate, pad, conv2d, conv2d_alter_layout, conv2d_winograd_without_weight_transform
+from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
+                 conv2d_winograd_without_weight_transform, depthwise_conv2d_nchw
 from ..nn.util import get_const_int, get_pad_tuple
 
 @autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct'])
@@ -556,7 +557,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
     if out_dtype == "" or out_dtype == "same":
         out_dtype = tinfos[0].dtype
 
-    if layout != 'NCHW' or groups != 1:
+    if layout != 'NCHW':
         return None
     if dilation != (1, 1):
         warnings.warn("Does not support weight pre-transform for dilated convolution.")
@@ -566,54 +567,84 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
     N, CI, H, W = get_const_tuple(data.shape)
     CO, _, KH, KW = get_const_tuple(kernel.shape)
 
-    # query config of this workload
-    workload = autotvm.task.args_to_workload(
-        [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
-    target = tvm.target.current_target()
-    dispatch_ctx = autotvm.DispatchContext.current
-    cfg = dispatch_ctx.query(target, workload)
-
-    if cfg.is_fallback:  # if is fallback, clear query cache and return None
-        autotvm.task.clear_fallback_cache(target, workload)
-        return None
-
-    if cfg.template_key == 'direct':  # pack weight tensor
-        VC = cfg['tile_co'].size[-1]
-        new_attrs['kernel_layout'] = 'OIHW%do' % VC
-
-        # Store the same config for the altered operator (workload)
-        new_data = data
-        new_kernel = tvm.placeholder((CO // VC, CI, KH, KW, VC), dtype=kernel.dtype)
-        new_workload = autotvm.task.args_to_workload(
-            [new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
-        dispatch_ctx.update(target, new_workload, cfg)
-
-        return F.nn.conv2d(*copy_inputs, **new_attrs)
-    else:  # pre-compute weight transformation in winograd
-        if "-device=arm_cpu" in target.options:
-            tile_size = 4
-            VC = cfg['tile_k'].size[-1]
+    if groups == 1:
+        # query config of this workload
+        workload = autotvm.task.args_to_workload(
+            [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
+        target = tvm.target.current_target()
+        dispatch_ctx = autotvm.DispatchContext.current
+        cfg = dispatch_ctx.query(target, workload)
+
+        if cfg.is_fallback:  # if is fallback, clear query cache and return None
+            autotvm.task.clear_fallback_cache(target, workload)
+            return None
+
+        if cfg.template_key == 'direct':  # pack weight tensor
+            VC = cfg['tile_co'].size[-1]
+            new_attrs['kernel_layout'] = 'OIHW%do' % VC
+
+            # Store the same config for the altered operator (workload)
+            new_data = data
+            new_kernel = tvm.placeholder((CO // VC, CI, KH, KW, VC), dtype=kernel.dtype)
+            new_workload = autotvm.task.args_to_workload(
+                [new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
+            dispatch_ctx.update(target, new_workload, cfg)
+
+            return F.nn.conv2d(*copy_inputs, **new_attrs)
+        else:  # pre-compute weight transformation in winograd
+            if "-device=arm_cpu" in target.options:
+                tile_size = 4
+                VC = cfg['tile_k'].size[-1]
+            else:
+                from ..mali.conv2d import _pick_tile_size
+                tile_size = _pick_tile_size(tinfos[0], tinfos[1])
+                VC = cfg['tile_bna'].val
+
+            weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1],
+                                                                   tile_size=tile_size)
+            weight = F.reshape(weight,
+                               newshape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
+            weight = F.transpose(weight, axes=[0, 1, 2, 4, 3])
+
+            copy_inputs[1] = weight
+            new_attrs['tile_size'] = tile_size
+
+            # Store the same config for the altered operator (workload)
+            new_data = data
+            new_weight = tvm.placeholder((KH + tile_size - 1, KH + tile_size -1, CO // VC, CI, VC),
+                                         kernel.dtype)
+            new_workload = autotvm.task.args_to_workload(
+                [new_data, new_weight, strides, padding, dilation,
+                 new_attrs[data_layout_key], out_dtype, tile_size],
+                conv2d_winograd_without_weight_transform)
+            dispatch_ctx.update(target, new_workload, cfg)
+
+            return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
+    else:
+        workload = autotvm.task.args_to_workload(
+            [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw)
+        target = tvm.target.current_target()
+        dispatch_ctx = autotvm.DispatchContext.current
+        cfg = dispatch_ctx.query(target, workload)
+
+        if cfg.is_fallback:  # if is fallback, clear query cache and return None
+            autotvm.task.clear_fallback_cache(tvm.target.current_target(), workload)
+            return None
+        if cfg.template_key == 'contrib_spatial_pack':
+            VC = cfg['tile_co'].size[-1]
+            new_attrs['kernel_layout'] = 'OIHW%do' % (cfg['tile_co'].size[-1])
+
+            # Store the same config for the altered operator (workload)
+            new_data = data
+            CO, M, KH, KW = get_const_tuple(kernel.shape)
+            new_kernel = tvm.placeholder((CO // VC, M, KH, KW, VC), dtype=kernel.dtype)
+            new_workload = autotvm.task.args_to_workload(
+                [new_data, new_kernel, strides, padding, dilation, out_dtype],
+                depthwise_conv2d_nchw)
+            dispatch_ctx.update(target, new_workload, cfg)
+
+            return F.nn.conv2d(*copy_inputs, **new_attrs)
         else:
-            from ..mali.conv2d import _pick_tile_size
-            tile_size = _pick_tile_size(tinfos[0], tinfos[1])
-            VC = cfg['tile_bna'].val
-
-        weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1], tile_size=tile_size)
-        weight = F.reshape(weight,
-                           newshape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
-        weight = F.transpose(weight, axes=[0, 1, 2, 4, 3])
-
-        copy_inputs[1] = weight
-        new_attrs['tile_size'] = tile_size
-
-        # Store the same config for the altered operator (workload)
-        new_data = data
-        new_weight = tvm.placeholder((KH + tile_size - 1, KH + tile_size -1, CO // VC, CI, VC),
-                                     kernel.dtype)
-        new_workload = autotvm.task.args_to_workload(
-            [new_data, new_weight, strides, padding, dilation,
-             new_attrs[data_layout_key], out_dtype, tile_size],
-            conv2d_winograd_without_weight_transform)
-        dispatch_ctx.update(target, new_workload, cfg)
-
-        return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
+            # currently we only have contrib_spatial_pack and direct template
+            # add more schedule templates.
+            return None
diff --git a/topi/python/topi/arm_cpu/depthwise_conv2d.py b/topi/python/topi/arm_cpu/depthwise_conv2d.py
index 2556af36e5f9c7f76664374189982c2b2937b8a7..1e25eb58dbaee07956896b6ce8dcabfd7c7790c3 100644
--- a/topi/python/topi/arm_cpu/depthwise_conv2d.py
+++ b/topi/python/topi/arm_cpu/depthwise_conv2d.py
@@ -5,15 +5,17 @@ import tvm
 from tvm import autotvm
 
 from ..generic import schedule_depthwise_conv2d_nchw
-from ..nn import depthwise_conv2d_nchw
-from ..util import traverse_inline
+from ..nn import depthwise_conv2d_nchw, pad
+from ..util import traverse_inline, get_const_tuple, get_const_int
+from ..nn.util import get_pad_tuple
 
 # register original implementation of depthwise_conv2d_nchw since we don't need to change this part
 autotvm.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], 'direct',
                               depthwise_conv2d_nchw.fdefault)
 
 # register customized schedule for arm cpu.
-@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], 'direct')
+@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, ['arm_cpu', 'cpu'],
+                                ['direct', 'contrib_spatial_pack'])
 def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
     """Schedule depthwise conv2d
 
@@ -116,5 +118,277 @@ def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
                 data = data_pad.op.input_tensors[0]
             _schedule(cfg, s, data, data_pad, kernel, output)
 
+        if op.tag == 'spatial_depthwise_conv_nchw_output':
+            output = op.output(0)
+            conv = op.input_tensors[0]
+            data_vec = conv.op.input_tensors[0]
+            kernel_vec = conv.op.input_tensors[1]
+            if kernel_vec.op.name == 'kernel_vec':
+                kernel = kernel_vec.op.input_tensors[0]
+            else:
+                kernel = kernel_vec
+            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+                s[kernel].compute_inline()
+
+            _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, outs[0])
+
     traverse_inline(s, outs[0].op, _callback)
     return s
+
+@autotvm.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], ['contrib_spatial_pack'])
+def depthwise_conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nchw
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.Tensor
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    kernel : tvm.Tensor
+        4-D with shape [num_filter, multiplier, filter_height, filter_width] or
+        pre-packed 5-D with shape [num_filter_chunk, multiplier, filter_height,
+        filter_width, num_filter_block]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.Tensor
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+
+    return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
+
+
+def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile):
+    out_dtype = out_dtype or data.dtype
+
+    N, C, IH, IW = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    if len(kernel.shape) == 4:
+        pre_packed = False
+        C, M, KH, KW = get_const_tuple(kernel.shape)
+    else:  # kernel tensor is pre packed
+        pre_packed = True
+        C, M, KH, KW, VC = get_const_tuple(kernel.shape)
+        C = C * VC
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+    # pack data
+    HPAD = pad_top + pad_down
+    WPAD = pad_left + pad_right
+    DOPAD = (HPAD != 0 or WPAD != 0)
+    if DOPAD:
+        data_pad = pad(data, (0, 0, pad_top, pad_left), (0, 0, pad_down, pad_right),
+                       name="data_pad")
+    else:
+        data_pad = data
+
+    # fallback support
+    # Currently, Mali schedule doesn't use it like conv2d.
+    if cfg.is_fallback:
+        ref_log = autotvm.tophub.load_reference_log('arm_cpu', 'rk3399', 'depthwise_conv2d_nchw',
+                                                    'contrib_spatial_pack')
+        cfg.fallback_with_reference_log(ref_log)
+
+    # ==================== define configuration space ====================
+    n, c, oh, ow = cfg.axis(N), cfg.axis(C), cfg.axis(OH), cfg.axis(OW)
+    kh, kw = cfg.reduce_axis(KH), cfg.reduce_axis(KW)
+
+    # Currently, Mali schedule doesn't use it like conv2d.
+    # Leave num_tile for possible future use of Mali schedule
+    if num_tile == 2:     # for arm cpu
+        co, vc = cfg.define_split('tile_co', c, num_outputs=2)
+        oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2)
+        ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2)
+    else:
+        raise RuntimeError("Invalid num_tile")
+
+    cfg.define_reorder("reorder_0",
+                       [n, co, oh, ow, kh, kw, vh, vw, vc],
+                       policy='candidate', candidate=[
+                           [n, co, oh, ow, kh, kw, vh, vw, vc],
+                           [n, co, oh, ow, kh, kw, vc, vh, vw]])
+
+    cfg.define_reorder("reorder_1",
+                       [n, co, oh, ow, vh, vw, vc],
+                       policy='candidate', candidate=[
+                           [n, co, oh, ow, vh, vw, vc],
+                           [n, co, oh, ow, vc, vh, vw],
+                           [n, co, oh, ow, vh, vc, vw]])
+
+    cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
+    cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec')
+    # ====================================================================
+
+    VC = cfg["tile_co"].size[-1]
+    VH = cfg["tile_oh"].size[-1]
+    VW = cfg["tile_ow"].size[-1]
+
+    kvshape = (C // VC, M, KH, KW, VC)
+    ovshape = (N, C * M // VC, OH // VH, OW // VW, VH, VW, VC)
+    oshape = (N, C * M, OH, OW)
+
+    if dilation_h != 1 or dilation_w != 1:
+        # undilate input data
+        dvshape = (N, OH // VH, OW // VW, C, KH, KW, VH, VW)
+        data_vec = tvm.compute(dvshape, lambda n, h, w, c, kh, kw, vh, vw:
+                               data_pad[n][c][(h * VH + vh) * HSTR + kh * dilation_h]
+                               [(w*VW+vw)*WSTR+kw*dilation_w],
+                               name='data_vec_undilated')
+    else:
+        dvshape = (N, OH // VH, OW // VW, C, VH*HSTR + KH-1, VW*WSTR + KW-1)
+        data_vec = tvm.compute(dvshape, lambda n, h, w, c, vh, vw:
+                               data_pad[n][c][h * VH * HSTR + vh][w * VW * WSTR + vw],
+                               name='data_vec')
+
+    if pre_packed:
+        kernel_vec = kernel
+    else:
+        kernel_vec = tvm.compute(kvshape, lambda co, m, kh, kw, vc:
+                                 kernel[co*VC+vc][m][kh][kw],
+                                 name='kernel_vec')
+
+    kh = tvm.reduce_axis((0, KH), name='kh')
+    kw = tvm.reduce_axis((0, KW), name='kw')
+
+    if dilation_h != 1 or dilation_w != 1:
+        conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
+                          tvm.sum(data_vec[n, h, w, (co * VC + vc) // M, kh, kw, vh, vw]
+                                  .astype(out_dtype) *
+                                  kernel_vec[co // M, co % M, kh, kw, vc].astype(out_dtype),
+                                  axis=[kh, kw]), name='depthwise_conv')
+    else:
+        conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
+                           tvm.sum(data_vec[n, h, w, (co * VC + vc) // M, vh * HSTR + kh,
+                                            vw * WSTR + kw].astype(out_dtype) *
+                                   kernel_vec[co // M, co % M, kh, kw, vc].astype(out_dtype),
+                                   axis=[kh, kw]), name='depthwise_conv')
+
+    output = tvm.compute(oshape, lambda n, co, h, w:
+                         conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
+                         name='output_unpack', tag='spatial_depthwise_conv_nchw_output')
+    return output
+
+def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
+                           conv, output, last):
+    """schedule implementation"""
+    n, co, oh, ow, vh, vw, vc = s[conv].op.axis
+    kh, kw = s[conv].op.reduce_axis
+
+    if data_vec.op.name == 'data_vec_undilated':
+        _, dv_oh, dv_ow, dv_c, _, _, dv_vh, dv_vw = s[data_vec].op.axis
+    else:
+        _, dv_oh, dv_ow, dv_c, dv_vh, dv_vw = s[data_vec].op.axis
+
+    data_pad = data_vec.op.input_tensors[0]
+    if data_pad.op.name == "data_pad":
+        assert isinstance(data_pad.op, tvm.tensor.ComputeOp)
+        has_padding = True
+    else:
+        assert isinstance(data_pad.op, tvm.tensor.PlaceholderOp)
+        has_padding = False
+
+    cfg.define_knob('data_pad_inline', [0, 1, 2, 3, 4])
+
+    if cfg['data_pad_inline'].val == 1 and has_padding:
+        s[data_pad].compute_inline()
+    if cfg['data_pad_inline'].val == 2 and has_padding:
+        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
+    if cfg['data_pad_inline'].val == 3 and has_padding:
+        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
+        s[data_pad].compute_at(s[data_vec], dv_oh)
+    if cfg['data_pad_inline'].val == 4 and has_padding:
+        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
+        s[data_pad].compute_at(s[data_vec], dv_ow)
+
+    cfg.define_knob('data_vec_inline', [0, 1, 2, 3])
+    if cfg['data_vec_inline'].val == 1:
+        s[data_vec].compute_at(s[conv], oh)
+    if cfg['data_vec_inline'].val == 2:
+        s[data_vec].compute_at(s[conv], ow)
+    if cfg['data_vec_inline'].val == 3:
+        s[data_vec].compute_at(s[conv], co)
+
+    # schedule conv
+    cfg["reorder_0"].apply(s, conv, [n, co, oh, ow, kh, kw, vh, vw, vc])
+    cfg["ann_reduce"].apply(s, conv, [kh, kw],
+                            axis_lens=[get_const_int(kh.dom.extent),
+                                       get_const_int(kw.dom.extent)],
+                            max_unroll=16,
+                            cfg=cfg)
+    cfg["ann_spatial"].apply(s, conv, [vh, vw, vc],
+                             axis_lens=[cfg['tile_oh'].size[-1],
+                                        cfg['tile_ow'].size[-1],
+                                        cfg['tile_co'].size[-1]],
+                             max_unroll=16,
+                             cfg=cfg)
+
+    # schedule fusion
+    n, co, h, w = s[last].op.axis
+    co, vc = cfg['tile_co'].apply(s, last, co)
+    oh, vh = cfg['tile_oh'].apply(s, last, h)
+    ow, vw = cfg['tile_ow'].apply(s, last, w)
+    cfg["reorder_1"].apply(s, last, [n, co, oh, ow, vh, vw, vc])
+    if last != output:
+        s[output].compute_inline()
+        cfg["ann_spatial"].apply(s, last, [vh, vw, vc],
+                                 axis_lens=[cfg['tile_oh'].size[-1],
+                                            cfg['tile_ow'].size[-1],
+                                            cfg['tile_co'].size[-1]],
+                                 max_unroll=16,
+                                 cfg=cfg)
+    else:
+        s[last].vectorize(vw)
+    cfg.define_knob('conv_inline', [0, 1, 2, 3])
+    if cfg['conv_inline'].val == 1:
+        s[conv].compute_at(s[last], ow)
+    if cfg['conv_inline'].val == 2:
+        s[conv].compute_at(s[last], oh)
+    if cfg['conv_inline'].val == 3:
+        s[conv].compute_at(s[last], co)
+
+    # mark parallel
+    s[last].parallel(co)
+
+    if data_vec.op.name == 'data_vec_undilated':
+        _, h, _, _, _, _, _, _ = s[data_vec].op.axis
+    else:
+        _, h, _, _, _, _ = s[data_vec].op.axis
+    s[data_vec].parallel(h)
+
+    if kernel_vec.op.name == 'kernel_vec':
+        co, _, _, _, _ = s[kernel_vec].op.axis
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            # kernel packing will be pre-computed during compliation, so we skip
+            # this part to make tuning records correct
+            s[kernel_vec].pragma(co, 'debug_skip_region')
+        else:
+            s[kernel_vec].parallel(co)
+
+    return s