diff --git a/topi/python/topi/arm_cpu/depthwise_conv2d.py b/topi/python/topi/arm_cpu/depthwise_conv2d.py index e066a1e2943559871e1b737e8e7c1a394dc4e5f0..c341d1a5b3259916db49d5270467f738022e89f3 100644 --- a/topi/python/topi/arm_cpu/depthwise_conv2d.py +++ b/topi/python/topi/arm_cpu/depthwise_conv2d.py @@ -9,11 +9,11 @@ from ..nn import depthwise_conv2d_nchw from ..util import traverse_inline # register original implementation of depthwise_conv2d_nchw since we don't need to change this part -autotvm.task.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', 'direct', +autotvm.task.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], 'direct', depthwise_conv2d_nchw.fdefault) # register customized schedule for arm cpu. -@autotvm.task.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'arm_cpu', 'direct') +@autotvm.task.register_topi_schedule(schedule_depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], 'direct') def schedule_depthwise_conv2d_nchw_arm(cfg, outs): """Schedule depthwise conv2d @@ -44,15 +44,15 @@ def schedule_depthwise_conv2d_nchw_arm(cfg, outs): cfg.define_split('tile_w', w, num_outputs=2) if cfg.is_fallback: - cfg.fallback_split('tile_c', [-1, 8]) + cfg.fallback_split('tile_c', [-1, 4]) cfg.fallback_split('tile_h', [-1, 2]) - cfg.fallback_split('tile_w', [-1, 8]) + cfg.fallback_split('tile_w', [-1, 4]) # park data to vector form [n, c, h, w] -> [n, C, h, w, VC] A0 = s.cache_read(data_pad, "global", C) - _, c, h, w = s[A0].op.axis + n, c, h, w = s[A0].op.axis c, vc = cfg['tile_c'].apply(s, A0, c) - s[A0].reorder(c, h, w, vc) + s[A0].reorder(n, c, h, w, vc) A1 = s.cache_write(A0, 'global') s[A0].compute_inline() @@ -64,9 +64,9 @@ def schedule_depthwise_conv2d_nchw_arm(cfg, outs): B1 = s.cache_write(B0, 'global') s[B0].compute_inline() - _, c, h, w = s[C].op.axis + n, c, h, w = s[C].op.axis c, vc, = cfg['tile_c'].apply(s, C, c) - s[C].reorder(c, h, w, vc) + s[C].reorder(n, c, h, w, vc) # depthwise conv C0 = s.cache_write(C, 'global') @@ -86,9 +86,14 @@ def schedule_depthwise_conv2d_nchw_arm(cfg, outs): max_unroll=16, cfg=cfg) + # fusion + if C.op not in s.outputs: + s[C].compute_inline() + # mark parallel - n, c, h, w = s[C].op.axis - s[C].parallel(c) + last = outs[0] + n, c, h, w = s[last].op.axis + s[last].parallel(c) n, c, h, w, vc = s[C0].op.axis s[C0].parallel(c)