diff --git a/topi/python/topi/cuda/depthwise_conv2d.py b/topi/python/topi/cuda/depthwise_conv2d.py
index 352c04c5e8a8f69e189e9a2723cef005e0463b29..d0cfde03f3d95996e879971aa397a1f89b79e20c 100644
--- a/topi/python/topi/cuda/depthwise_conv2d.py
+++ b/topi/python/topi/cuda/depthwise_conv2d.py
@@ -36,64 +36,62 @@ def schedule_depthwise_conv2d_nchw(outs):
             Output = outs[0].op.output(0)
             s[DepthwiseConv2d].set_scope("local")
         # schedule parameters
-        num_thread_x = 8
         num_thread_y = 8
-        num_vthread_x = 1
+        num_thread_x = 8
         num_vthread_y = 1
+        num_vthread_x = 1
         blocking_h = out_height
         blocking_w = out_width
         if out_height % 32 == 0:
             blocking_h = 32
-            num_thread_x = 2
-            num_vthread_x = 2
         if out_width % 32 == 0:
             blocking_w = 32
-            num_thread_y = 16
-            num_vthread_y = 2
-        block_x = tvm.thread_axis("blockIdx.x")
+            num_thread_x = 16
+            num_vthread_x = 2
         block_y = tvm.thread_axis("blockIdx.y")
-        thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
+        block_x = tvm.thread_axis("blockIdx.x")
         thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
-        thread_vx = tvm.thread_axis((0, num_vthread_x), "vthread", name="vx")
+        thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
         thread_vy = tvm.thread_axis((0, num_vthread_y), "vthread", name="vy")
+        thread_vx = tvm.thread_axis((0, num_vthread_x), "vthread", name="vx")
         # split and bind
-        bx, bxi = s[Output].split(Output.op.axis[1], factor=channel_multiplier)
-        s[Output].reorder(Output.op.axis[2], Output.op.axis[3], bxi)
-        bx = s[Output].fuse(Output.op.axis[0], bx)
-        s[Output].bind(bx, block_x)
-        by1, y1i = s[Output].split(Output.op.axis[2], factor=blocking_h)
-        tvx, vxi = s[Output].split(y1i, nparts=num_vthread_x)
-        tx, xi = s[Output].split(vxi, nparts=num_thread_x)
-        by2, y2i = s[Output].split(Output.op.axis[3], factor=blocking_w)
-        tvy, vyi = s[Output].split(y2i, nparts=num_vthread_y)
+        by, byi = s[Output].split(Output.op.axis[1], factor=channel_multiplier)
+        s[Output].reorder(Output.op.axis[2], Output.op.axis[3], byi)
+        by = s[Output].fuse(Output.op.axis[0], by)
+        s[Output].bind(by, block_y)
+        bx1, x1i = s[Output].split(Output.op.axis[2], factor=blocking_h)
+        tvy, vyi = s[Output].split(x1i, nparts=num_vthread_y)
         ty, yi = s[Output].split(vyi, nparts=num_thread_y)
-        s[Output].reorder(by1, by2, tvx, tvy, tx, ty, xi, yi)
-        by = s[Output].fuse(by1, by2)
-        s[Output].bind(tvx, thread_vx)
+        bx2, x2i = s[Output].split(Output.op.axis[3], factor=blocking_w)
+        tvx, vxi = s[Output].split(x2i, nparts=num_vthread_x)
+        tx, xi = s[Output].split(vxi, nparts=num_thread_x)
+        s[Output].reorder(bx1, bx2, tvy, tvx, ty, tx, yi, xi)
+        bx = s[Output].fuse(bx1, bx2)
+        s[Output].bind(bx, block_x)
         s[Output].bind(tvy, thread_vy)
-        s[Output].bind(tx, thread_x)
+        s[Output].bind(tvx, thread_vx)
         s[Output].bind(ty, thread_y)
-        s[Output].bind(by, block_y)
+        s[Output].bind(tx, thread_x)
         # local memory load
-        s[IL].compute_at(s[Output], ty)
-        s[FL].compute_at(s[Output], ty)
+        s[IL].compute_at(s[Output], tx)
+        s[FL].compute_at(s[Output], tx)
         if DepthwiseConv2d.op in s.outputs:
-            s[CL].compute_at(s[Output], ty)
+            s[CL].compute_at(s[Output], tx)
         else:
-            s[DepthwiseConv2d].compute_at(s[Output], ty)
+            s[DepthwiseConv2d].compute_at(s[Output], tx)
         # input's shared memory load
-        s[IS].compute_at(s[Output], by)
-        tx, xi = s[IS].split(IS.op.axis[2], nparts=num_thread_x)
-        ty, yi = s[IS].split(IS.op.axis[3], nparts=num_thread_y)
-        s[IS].bind(tx, thread_x)
+        s[IS].compute_at(s[Output], bx)
+        ty, yi = s[IS].split(IS.op.axis[2], nparts=num_thread_y)
+        tx, xi = s[IS].split(IS.op.axis[3], nparts=num_thread_x)
         s[IS].bind(ty, thread_y)
+        s[IS].bind(tx, thread_x)
         # filter's shared memory load
-        s[FS].compute_at(s[Output], by)
+        s[FS].compute_at(s[Output], bx)
         s[FS].reorder(FS.op.axis[2], FS.op.axis[3], FS.op.axis[1])
-        tx, xi = s[FS].split(FS.op.axis[2], nparts=num_thread_x)
-        ty, yi = s[FS].split(FS.op.axis[3], nparts=num_thread_y)
-        s[FS].bind(tx, thread_x)
+        ty, yi = s[FS].split(FS.op.axis[2], nparts=num_thread_y)
+        tx, xi = s[FS].split(FS.op.axis[3], nparts=num_thread_x)
         s[FS].bind(ty, thread_y)
+        s[FS].bind(tx, thread_x)
 
     def traverse(OP):
         # inline all one-to-one-mapping operators except the last stage (output)
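
The hunk above swaps which CUDA dimension covers which output axis: height is now split across threadIdx.y and width across threadIdx.x, with the thread/vthread counts retuned to match (16 threads and 2 vthreads along x when the width is a multiple of 32). Since width is the innermost, contiguous axis of an NCHW tensor, consecutive threads now read consecutive addresses and the global loads coalesce. Below is a minimal sketch of the same bind/compute_at pattern, written against the same pre-tvm.te API this file uses; A, B, and the elementwise stage are hypothetical stand-ins for the depthwise operator, chosen to isolate the axis-to-thread mapping rather than the kernel math.

import tvm

# hypothetical 64x64 elementwise stage standing in for the depthwise output
A = tvm.placeholder((64, 64), name="A")
B = tvm.compute((64, 64), lambda y, x: A[y, x] * 2.0, name="B")
s = tvm.create_schedule(B.op)
AS = s.cache_read(A, "shared", [B])   # plays the role of IS/FS

block_x = tvm.thread_axis("blockIdx.x")
thread_y = tvm.thread_axis((0, 8), "threadIdx.y")
thread_x = tvm.thread_axis((0, 8), "threadIdx.x")

# one 32x32 output tile per block (blocking_h = blocking_w = 32)
by, yi = s[B].split(B.op.axis[0], factor=32)
bx, xi = s[B].split(B.op.axis[1], factor=32)
s[B].reorder(by, bx, yi, xi)
fused = s[B].fuse(by, bx)
s[B].bind(fused, block_x)

# height -> threadIdx.y, width -> threadIdx.x: consecutive threads cover
# consecutive addresses along the innermost axis, which is the coalescing
# rationale for the x/y swap in the schedule above
ty, yi = s[B].split(yi, nparts=8)
tx, xi = s[B].split(xi, nparts=8)
s[B].reorder(ty, tx, yi, xi)
s[B].bind(ty, thread_y)
s[B].bind(tx, thread_x)

# cooperative shared-memory load at block scope, split the same way and
# bound to the same thread axes, mirroring what the hunk does for IS and FS
s[AS].compute_at(s[B], fused)
aty, ayi = s[AS].split(AS.op.axis[0], nparts=8)
atx, axi = s[AS].split(AS.op.axis[1], nparts=8)
s[AS].bind(aty, thread_y)
s[AS].bind(atx, thread_x)

print(tvm.lower(s, [A, B], simple_mode=True))
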
diff --git a/topi/recipe/conv/depthwise_conv2d_test.py b/topi/recipe/conv/depthwise_conv2d_test.py
index 0f3a14dde2aa4e0da31bc9f5e6e2678d92deaaed..64dc10e11158e3622166af4123d962b420109955 100644
--- a/topi/recipe/conv/depthwise_conv2d_test.py
+++ b/topi/recipe/conv/depthwise_conv2d_test.py
@@ -97,9 +97,9 @@ def test_depthwise_conv2d_nchw():
         print("Stride = (%d, %d)" % (stride_h, stride_w))
         print("padding = %s\n" % padding)
         print("Output shape = " + str(get_const_tuple(DepthwiseConv2d.shape)))
-        print("average time cost of 1000 runs (depthwise_conv2d) = %g sec" % tcost_1)
-        print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g sec" % tcost_2)
-        print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g sec" % tcost_3)
+        print("average time cost of 1000 runs (depthwise_conv2d) = %g us" % (tcost_1*1e6))
+        print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g us" % (tcost_2*1e6))
+        print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g us" % (tcost_3*1e6))
         # correctness
         depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
         scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
@@ -186,9 +186,9 @@ def test_depthwise_conv2d_nhwc():
         print("Stride = (%d, %d)" % (stride_h, stride_w))
         print("padding = %s\n" % padding)
         print("Output shape = " + str(get_const_tuple(DepthwiseConv2d.shape)))
-        print("average time cost of 1000 runs (depthwise_conv2d) = %g sec" % tcost_1)
-        print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g sec" % tcost_2)
-        print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g sec" % tcost_3)
+        print("average time cost of 1000 runs (depthwise_conv2d) = %g us" % (tcost_1*1e6))
+        print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g us" % (tcost_2*1e6))
+        print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g us" % (tcost_3*1e6))
         # correctness
         depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nhwc(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
         scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
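
On the test side, the two hunks make the same change for the NCHW and NHWC drivers: the per-run average reported by the evaluator is in seconds, which at these kernel sizes prints as an awkwardly small float, so the reports are rescaled to microseconds. A minimal sketch of the conversion, with a hypothetical measured value:

# hypothetical per-run mean in seconds, as a time evaluator would return it
tcost_1 = 3.2e-05
print("average time cost over 1000 runs (depthwise_conv2d) = %g us" % (tcost_1 * 1e6))
# prints: average time cost over 1000 runs (depthwise_conv2d) = 32 us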