diff --git a/topi/python/topi/rasp/depthwise_conv2d.py b/topi/python/topi/rasp/depthwise_conv2d.py
index 1446556dc2071937f72cafc4853a05e1b7b5dd4c..00bab8c1b17477ed70ae574dc296b158501fc91e 100644
--- a/topi/python/topi/rasp/depthwise_conv2d.py
+++ b/topi/python/topi/rasp/depthwise_conv2d.py
@@ -1,25 +1,147 @@
 # pylint: disable=invalid-name,unused-variable
 """Schedule for depthwise_conv2d with auto fusion"""
+from __future__ import absolute_import as _abs
+from collections import namedtuple
 import tvm
 from .. import tag
+from ..nn.util import infer_pad, infer_stride, get_pad_tuple
+
+
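+# A workload is keyed by the input spatial size, channel count, channel
+# multiplier, kernel size, padding and stride.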
+_Workload = namedtuple('Workload',
+                       ['height', 'width', 'channel', 'multiplier',
+                        'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
+
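+# Schedule parameters: vh/vw are the spatial tile sizes, vc is the
+# vectorization width along the channel axis, bc a channel blocking factor
+# (currently unused), and unroll controls unrolling of the inner width loop.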
+_Schedule = namedtuple('Schedule', ['vh', 'vw', 'vc', 'bc', 'unroll'])
+
+# depthwise conv2d workloads of MobileNet on ImageNet
+_WORKLOADS = [
+    _Workload(112, 112,   32, 1, 3, 3, 1, 1, 1, 1),
+    _Workload(112, 112,   64, 1, 3, 3, 1, 1, 2, 2),
+    _Workload( 56,  56,  128, 1, 3, 3, 1, 1, 1, 1),
+    _Workload( 56,  56,  128, 1, 3, 3, 1, 1, 2, 2),
+    _Workload( 28,  28,  256, 1, 3, 3, 1, 1, 1, 1),
+    _Workload( 28,  28,  256, 1, 3, 3, 1, 1, 2, 2),
+    _Workload( 14,  14,  512, 1, 3, 3, 1, 1, 1, 1),
+    _Workload( 14,  14,  512, 1, 3, 3, 1, 1, 2, 2),
+    _Workload( 14,  14, 1024, 1, 3, 3, 1, 1, 1, 1),
+]
+
+_SCHEDULES = [
+    _Schedule(2, 1, 4, 1, True),
+    _Schedule(2, 4, 4, 2, True),
+    _Schedule(2, 1, 4, 2, False),
+    _Schedule(2, 4, 4, 1, True),
+    _Schedule(4, 1, 4, 8, True),
+    _Schedule(1, 1, 4, 2, True),
+    _Schedule(1, 1, 8, 8, True),
+    _Schedule(1, 1, 4, 1, False),
+    _Schedule(2, 1, 4, 16, False),
+]
+
+def _get_workload(data, kernel, stride, padding):
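+    """Build the _Workload key for NCHW data and a (channel, multiplier, kh, kw) kernel."""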
+    _, C, IH, IW = [x.value for x in data.shape]
+    _, MT, KH, KW = [x.value for x in kernel.shape]
+    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
+    if isinstance(stride, (tuple, list)):
+        HSTR, WSTR = stride
+    else:
+        HSTR, WSTR = stride, stride
+    return _Workload(IH, IW, C, MT, KH, KW, HPAD, WPAD, HSTR, WSTR)
+
 
 def _schedule(s, data, data_pad, kernel, output, last):
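+    # reconstruct padding and stride from the tensor shapes so the workload
+    # can be matched against the tuned entries above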
+    padding = infer_pad(data, data_pad)
+    if data_pad is None:
+        stride = infer_stride(data, kernel, output)
+    else:
+        stride = infer_stride(data_pad, kernel, output)
+    wkl = _get_workload(data, kernel, stride, padding)
+
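+    # no tuned schedule for this workload: leave the schedule untouched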
+    if wkl not in _WORKLOADS:
+        return s
+
+    # use the hand-tuned schedule that corresponds to this workload
+    sch = _SCHEDULES[_WORKLOADS.index(wkl)]
+
+    H, W = wkl.height, wkl.width
+    CN = wkl.channel
+    MT = wkl.multiplier
+
+    HK, WK = wkl.hkernel, wkl.wkernel
+    HPAD, WPAD = wkl.hpad, wkl.wpad
+    HSTR, WSTR = wkl.hstride, wkl.wstride
+
+    VH, VW = sch.vh, sch.vw
+    BC = sch.bc
+    VC = sch.vc
+
+    TH = H + 2*HPAD
+    TW = W + 2*WPAD
+    OH = (TH - HK) // HSTR + 1
+    OW = (TW - WK) // WSTR + 1
+
     A, B, C = data, kernel, output
     A0 = data_pad
-    C0 = last
+
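+    # cache the padded input and pack its channel axis into vectors of VC
+    # (every tuned workload uses padding, so data_pad is not None here)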
+    A1 = s.cache_read(A0, "global", C)
+    _, c, h, w = s[A1].op.axis
+    c, vc = s[A1].split(c, VC)
+    s[A1].reorder(c, h, w, vc)
+
+    A2 = s.cache_write(A1, 'global')
+    s[A0].compute_inline()
+    s[A1].compute_inline()
+
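+    # pack the kernel the same way, splitting its channel axis by VC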
+    B0 = s.cache_read(B, "global", C)
+    c, m, h, w = s[B0].op.axis
+    c, vc = s[B0].split(c, VC)
+    s[B0].reorder(c, m, h, w, vc)
+
+    B1 = s.cache_write(B0, 'global')
+    s[B0].compute_inline()
 
     _, c, h, w = s[C].op.axis
-    dh, dw = s[C].op.reduce_axis
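+    # split the channel axis of the final output by VC to match the packed layout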
+    c, vc = s[C].split(c, VC)
+    s[C].reorder(c, h, w, vc)
+
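+    # compute the convolution into a cached stage, tiled VH x VW over the
+    # spatial axes and vectorized over the VC channel sub-axis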
+    C0 = s.cache_write(C, 'global')
+    _, c, h, w, vc = s[C0].op.axis
+    dh, dw = s[C0].op.reduce_axis
+    oh, ow, ih, iw = s[C0].tile(h, w, VH, VW)
+    s[C0].reorder(c, oh, ow, dh, dw, ih, iw, vc)
+    if sch.unroll:
+        s[C0].unroll(iw)
+    s[C0].vectorize(vc)
 
-    oh, ow, ih, iw = s[C].tile(h, w, 2, 4)
-    s[C].reorder(oh, ow, dh, dw, ih, iw)
-    s[C].unroll(ih)
-    s[C].vectorize(iw)
+
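+    # launch the parallel runtime once at the batch axis of the final stage
+    # and parallelize its channel axis within that scope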
+    launch, c, _, _ = s[C].op.axis
+    s[C].pragma(launch, "parallel_launch_point")
 
     s[C].parallel(c)
-    s[C].pragma(c, "parallel_launch_point")
     s[C].pragma(c, "parallel_stride_pattern")
     s[C].pragma(c, "parallel_barrier_when_finish")
+
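+    # compute the cached convolution stage inside the parallel launch scope,
+    # also parallelized over channels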
+    s[C0].compute_at(s[C], launch)
+    _, c, h, w, vc = s[C0].op.axis
+    s[C0].parallel(c)
+    s[C0].pragma(c, "parallel_stride_pattern")
+    s[C0].pragma(c, "parallel_barrier_when_finish")
+
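+    # materialize the packed input at the outer height loop of the cached output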
+    s[A2].compute_at(s[C0], oh)
+    # parallel(s[A2], s[A2].op.axis[1], BC)
+
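+    # pack the kernel once per parallel launch, parallelized over channels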
+    s[B1].compute_at(s[C], launch)
+    c, m, h, w, vc = s[B1].op.axis
+    s[B1].parallel(c)
+    s[B1].pragma(c, "parallel_stride_pattern")
+    s[B1].pragma(c, "parallel_barrier_when_finish")
+
     return s