diff --git a/topi/python/topi/cuda/conv2d_nchw.py b/topi/python/topi/cuda/conv2d_nchw.py
index b6bd35768112fedb2015ca369d4fc2fa1f9c2008..a59a996242331f5cf47add07ca2cafb4c9766ee2 100644
--- a/topi/python/topi/cuda/conv2d_nchw.py
+++ b/topi/python/topi/cuda/conv2d_nchw.py
@@ -1,15 +1,18 @@
-#pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches
+#pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches, line-too-long
 """Schedule for conv2d_nchw with auto fusion"""
 import tvm
 from .. import util
 from .. import tag
 from .. import generic
 
-def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
+def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
     """Schedule conv2d for specific feature_in_out_filter pattern"""
     # scheduler params
     ofactor = 16
     hfactor = 2
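+    # flag is out_channels + in_channels of the filter; use a taller h tile for wider layers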
+    if flag >= 96:
+        hfactor = 4
     ow_size = util.get_const_int(Out.shape[3])
     num_thread = ow_size * hfactor
     vthread = ofactor
@@ -22,7 +25,9 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
     oh, ih = s[Out].split(h, factor=hfactor)
     s[Out].reorder(ooc, oh, ioc, ih, w)
     oc = s[Out].fuse(ooc, oh)
-    w = s[Out].fuse(w, ih)
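+    # split w so the fused (ow, ih) axis has extent ow_size * hfactor (= num_thread)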
+    ow, _ = s[Out].split(w, nparts=ow_size)
+    w = s[Out].fuse(ow, ih)
     s[Out].bind(w, thread_x)
     s[Out].bind(ioc, thread_xz)
     s[Out].bind(oc, block_x)
@@ -360,7 +365,12 @@ def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L):
     if util.get_const_int(Filter.shape[0]) == 64:
         opart2 = 8
         ifactor = 16
-    sfactor = max(1, ofactor // (opart2*2))
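+    # 224x224 outputs get fewer threads and a wider w tile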
+    if util.get_const_int(Out.shape[2]) == 224:
+        num_thread = 4
+        wfactor = 112
+        ifactor = 4
+    sfactor = max(1, ofactor // (opart2*vthread))
     spart = max(1, (wfactor + vthread-1) // vthread)
 
     block_x = tvm.thread_axis("blockIdx.x")
@@ -368,7 +378,7 @@ def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L):
     block_z = tvm.thread_axis("blockIdx.z")
     thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
     thread_y = tvm.thread_axis((0, wfactor // vthread), "threadIdx.y")
-    thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
+    thread_xz = tvm.thread_axis((0, opart2), "vthread", name="vx")
     thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
 
     i, oc, h, w = s[Out].op.axis
@@ -394,10 +404,10 @@ def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L):
     ic, dh, dw = s[Out_L].op.reduce_axis
     oic, iic = s[Out_L].split(ic, factor=ifactor)
     s[Out_L].reorder(oic, dh, dw, iic, h, w)
-
     fuse_index = s[Out_L].fuse(dw, dh)
     fuse_index = s[Out_L].fuse(fuse_index, oic)
     dw = fuse_index
+
     s[temp_S].compute_at(s[Out_L], dw)
     s[Filter_S].compute_at(s[Out_L], dw)
 
@@ -421,16 +431,7 @@ def schedule_conv2d_small_batch(outs):
 
     def schedule(temp, Filter, Output):
         """Schedule conv2d_nchw"""
-        block_h = util.get_const_int(Output.shape[3])
-        block_w = util.get_const_int(temp.shape[1])
-        if block_h % 48 == 0:
-            block_h = 48
-        elif block_h % 32 == 0:
-            block_h = 32
-        if block_w % 48 == 0:
-            block_w = 48
-        elif block_w % 32 == 0:
-            block_w = 32
 
         flag = util.get_const_int(Filter.shape[0])+util.get_const_int(Filter.shape[1])
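+        # i.e. filter out_channels + in_channels; used below to pick a sub-schedule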
 
@@ -450,7 +451,8 @@ def schedule_conv2d_small_batch(outs):
             s[temp_G].reorder(i, oic, h, w, iic)
             temp_R = s.cache_write(temp_G, "global")
             temp_S = s.cache_read(temp_R, "shared", [temp_G])
-        elif util.get_const_int(Filter.shape[3]) == 7:
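+        # 7x7 filters and narrow (flag < 128) 224x224 layers share this path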
+        elif util.get_const_int(Filter.shape[3]) == 7 or (util.get_const_int(Output.shape[2]) == 224 and flag < 128):
             temp_G = s.cache_read(temp, "global", [Output])
             s[temp_G].compute_inline()
             i, ic, h, w = s[temp_G].op.axis
@@ -472,8 +474,8 @@ def schedule_conv2d_small_batch(outs):
             s[Output].set_scope("local")
             Out_L = Output
 
-        if util.get_const_int(Filter.shape[3]) == 7:
-            conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L)
+        if util.get_const_int(Filter.shape[3]) == 7 or (util.get_const_int(Output.shape[2]) == 224 and flag < 128):
+            conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag)
         elif 128 < flag < 512:
             conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag)
         elif flag >= 512:
diff --git a/topi/python/topi/cuda/conv2d_transpose_nchw.py b/topi/python/topi/cuda/conv2d_transpose_nchw.py
index 82fd451f4ecb98ba2ad2920a5d2dd895fb5d1293..edd255a8fac036ca8da4fd8d48124fbe83869e29 100644
--- a/topi/python/topi/cuda/conv2d_transpose_nchw.py
+++ b/topi/python/topi/cuda/conv2d_transpose_nchw.py
@@ -1,4 +1,4 @@
-#pylint: disable=invalid-name
+#pylint: disable=invalid-name, line-too-long
 """Schedule for conv2d_transpose_nchw with auto fusion"""
 import tvm
 from .. import util
@@ -42,7 +42,7 @@ def schedule_conv2d_transpose_small_batch(outs):
             s[temp_G].reorder(i, oic, h, w, iic)
             temp_R = s.cache_write(temp_G, "global")
             temp_S = s.cache_read(temp_R, "shared", [temp_G])
-        elif util.get_const_int(Filter.shape[3]) == 7:
+        elif util.get_const_int(Filter.shape[3]) == 7 or (util.get_const_int(Output.shape[2]) == 224 and flag < 128):
             temp_G = s.cache_read(temp, "global", [Output])
             s[temp_G].compute_inline()
             i, ic, h, w = s[temp_G].op.axis
@@ -64,8 +64,8 @@ def schedule_conv2d_transpose_small_batch(outs):
             s[Output].set_scope("local")
             Out_L = Output
 
-        if util.get_const_int(Filter.shape[3]) == 7:
-            conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L)
+        if util.get_const_int(Filter.shape[3]) == 7 or (util.get_const_int(Output.shape[2]) == 224 and flag < 128):
+            conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag)
         elif 128 < flag < 512:
             conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag)
         elif flag >= 512: