From 359e335f38941b36fa0b8d5797e83026efda02f0 Mon Sep 17 00:00:00 2001
From: Yuwei Hu <huyuwei1995@gmail.com>
Date: Wed, 25 Apr 2018 12:03:58 -0400
Subject: [PATCH] [TOPI] support dilated conv2d and depthwise_conv2d (#1129)

* support dilation in conv2d and depthwise_conv2d

* handle dilated conv in the extern libraries (cuDNN, MIOpen)
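
A note on the extern-library paths: the dilation factor is not passed in
explicitly; it is recovered from the shapes of the dilated kernel and its
pre-dilation input. A minimal sketch of that arithmetic (the function name
below is illustrative, not part of the patch):

    # dilate() grows a kernel of size k to (k - 1) * d + 1, so the factor d
    # can be recovered with a ceiling division; this is exact for d <= k
    # (exact recovery for any d would be (dilated - 1) // (k - 1) when k > 1).
    def infer_dilation(dilated_size, original_size):
        return (dilated_size + original_size - 1) // original_size

    assert infer_dilation((3 - 1) * 2 + 1, 3) == 2   # k = 3, d = 2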
---
 topi/python/topi/cuda/conv2d.py               | 28 +++++++++++++++----
 topi/python/topi/cuda/conv2d_hwcn.py          |  2 ++
 topi/python/topi/cuda/conv2d_nchw.py          |  2 ++
 topi/python/topi/cuda/depthwise_conv2d.py     |  4 +++
 topi/python/topi/generic/extern.py            |  2 +-
 topi/python/topi/mali/conv2d.py               | 12 ++++++++
 topi/python/topi/mali/depthwise_conv2d.py     |  2 ++
 topi/python/topi/nn/dilate.py                 |  2 +-
 topi/python/topi/nn/pad.py                    |  2 +-
 topi/python/topi/opengl/conv2d_nchw.py        |  3 ++
 topi/python/topi/rasp/conv2d.py               |  2 ++
 topi/python/topi/rasp/depthwise_conv2d.py     |  2 ++
 topi/python/topi/rocm/conv2d.py               | 21 +++++++++++---
 topi/python/topi/x86/conv2d.py                |  7 +++++
 topi/tests/python/test_topi_conv2d_hwcn.py    | 11 +++++---
 topi/tests/python/test_topi_conv2d_nchw.py    | 17 +++++++----
 topi/tests/python/test_topi_conv2d_nhwc.py    | 10 +++++--
 .../python/test_topi_depthwise_conv2d.py      | 21 +++++++++-----
 18 files changed, 118 insertions(+), 32 deletions(-)

diff --git a/topi/python/topi/cuda/conv2d.py b/topi/python/topi/cuda/conv2d.py
index 2641bfe49..3c494cdeb 100644
--- a/topi/python/topi/cuda/conv2d.py
+++ b/topi/python/topi/cuda/conv2d.py
@@ -4,6 +4,7 @@ import tvm
 from tvm.contrib import cudnn
 import topi
 from ..nn.conv2d import conv2d
+from ..util import get_const_int
 
 @conv2d.register("cuda")
 def conv2d_cuda(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
@@ -40,6 +41,23 @@ def conv2d_cuda(data, kernel, stride, padding, layout='NCHW', out_dtype='float32
         pad_h = pad_w = padding
     else:
         pad_h, pad_w = padding
+    # handle dilation
+    dilation_h = dilation_w = 1
+    kernel_tvm = kernel
+    kernel_cudnn = kernel
+    if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+        kernel_before_dilation = kernel.op.input_tensors[0]
+        kernel_cudnn = kernel_before_dilation
+        if layout == 'NCHW':
+            dilation_h = (get_const_int(kernel.shape[2]) + get_const_int(kernel_before_dilation.shape[2]) - 1) \
+                // get_const_int(kernel_before_dilation.shape[2])
+            dilation_w = (get_const_int(kernel.shape[3]) + get_const_int(kernel_before_dilation.shape[3]) - 1) \
+                // get_const_int(kernel_before_dilation.shape[3])
+        elif layout == 'NHWC':
+            dilation_h = (get_const_int(kernel.shape[1]) + get_const_int(kernel_before_dilation.shape[1]) - 1) \
+                // get_const_int(kernel_before_dilation.shape[1])
+            dilation_w = (get_const_int(kernel.shape[2]) + get_const_int(kernel_before_dilation.shape[2]) - 1) \
+                // get_const_int(kernel_before_dilation.shape[2])
     target = tvm.target.current_target()
     if "cudnn" in target.libs:
         assert layout != 'HWCN', "HWCN layout not supported with CUDNN."
@@ -47,19 +65,19 @@ def conv2d_cuda(data, kernel, stride, padding, layout='NCHW', out_dtype='float32
         if layout == 'NHWC':
             tensor_format = 1 # CUDNN_TENSOR_NHWC
         return cudnn.conv2d_forward(data,
-                                    kernel,
+                                    kernel_cudnn,
                                     stride_h,
                                     stride_w,
                                     pad_h,
                                     pad_w,
-                                    1,  # dilation_h
-                                    1,  # dilation_w
+                                    dilation_h,
+                                    dilation_w,
                                     conv_mode=1,
                                     tensor_format=tensor_format,
                                     algo=-1) # let CUDNN choose the best algo
     elif layout == 'NCHW':
-        return topi.nn.conv2d_nchw(data, kernel, stride, padding, out_dtype)
+        return topi.nn.conv2d_nchw(data, kernel_tvm, stride, padding, out_dtype)
     elif layout == 'HWCN':
-        return topi.nn.conv2d_hwcn(data, kernel, stride, padding, out_dtype)
+        return topi.nn.conv2d_hwcn(data, kernel_tvm, stride, padding, out_dtype)
     else:
         raise ValueError("not support this layout {} yet".format(layout))
diff --git a/topi/python/topi/cuda/conv2d_hwcn.py b/topi/python/topi/cuda/conv2d_hwcn.py
index ec9025855..082966a3c 100644
--- a/topi/python/topi/cuda/conv2d_hwcn.py
+++ b/topi/python/topi/cuda/conv2d_hwcn.py
@@ -110,6 +110,8 @@ def schedule_conv2d_hwcn(outs):
         elif operator.tag == 'conv2d_hwcn':
             Apad = operator.input_tensors[0]
             W = operator.input_tensors[1]
+            if isinstance(W.op, tvm.tensor.ComputeOp) and 'dilate' in W.op.tag:
+                sch[W].compute_inline()
             B = operator.output(0)
             schedule(Apad, W, B)
         else:
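
The guard added above recurs in every schedule this patch touches: when the
kernel feeding a conv is itself a dilate stage, it is inlined so that the
interleaved zeros are never materialized in memory. Condensed, the pattern
is (W is the kernel tensor read by the conv op, sch the schedule):

    if isinstance(W.op, tvm.tensor.ComputeOp) and 'dilate' in W.op.tag:
        sch[W].compute_inline()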
diff --git a/topi/python/topi/cuda/conv2d_nchw.py b/topi/python/topi/cuda/conv2d_nchw.py
index e313029e7..71104c59a 100644
--- a/topi/python/topi/cuda/conv2d_nchw.py
+++ b/topi/python/topi/cuda/conv2d_nchw.py
@@ -495,6 +495,8 @@ def schedule_conv2d_small_batch(outs):
         if 'conv2d_nchw' in OP.tag:
             temp = OP.input_tensors[0]
             Filter = OP.input_tensors[1]
+            if isinstance(Filter.op, tvm.tensor.ComputeOp) and 'dilate' in Filter.op.tag:
+                s[Filter].compute_inline()
             Output = OP.output(0)
             schedule(temp, Filter, Output)
 
diff --git a/topi/python/topi/cuda/depthwise_conv2d.py b/topi/python/topi/cuda/depthwise_conv2d.py
index 304577429..1901fcbb7 100644
--- a/topi/python/topi/cuda/depthwise_conv2d.py
+++ b/topi/python/topi/cuda/depthwise_conv2d.py
@@ -114,6 +114,8 @@ def schedule_depthwise_conv2d_nchw(outs):
         if OP.tag == 'depthwise_conv2d_nchw':
             PaddedInput = OP.input_tensors[0]
             Filter = OP.input_tensors[1]
+            if isinstance(Filter.op, tvm.tensor.ComputeOp) and 'dilate' in Filter.op.tag:
+                s[Filter].compute_inline()
             DepthwiseConv2d = OP.output(0)
             _schedule(PaddedInput, Filter, DepthwiseConv2d)
 
@@ -191,6 +193,8 @@ def schedule_depthwise_conv2d_nhwc(outs):
         if OP.tag == 'depthwise_conv2d_nhwc':
             PaddedInput = OP.input_tensors[0]
             Filter = OP.input_tensors[1]
+            if isinstance(Filter.op, tvm.tensor.ComputeOp) and 'dilate' in Filter.op.tag:
+                s[Filter].compute_inline()
             DepthwiseConv2d = OP.output(0)
             _schedule(PaddedInput, Filter, DepthwiseConv2d)
 
diff --git a/topi/python/topi/generic/extern.py b/topi/python/topi/generic/extern.py
index 082c1bca8..92a47f8a9 100644
--- a/topi/python/topi/generic/extern.py
+++ b/topi/python/topi/generic/extern.py
@@ -21,6 +21,6 @@ def schedule_extern(outs):
     """
     target = tvm.target.current_target(allow_none=False)
     if target.target_name != "llvm":
-        raise RuntimeError("schedule_injective not registered for '%s'" % target)
+        raise RuntimeError("schedule_extern not registered for '%s'" % target)
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     return tvm.create_schedule([x.op for x in outs])
diff --git a/topi/python/topi/mali/conv2d.py b/topi/python/topi/mali/conv2d.py
index 5b4cf5bae..ad1dfbe61 100644
--- a/topi/python/topi/mali/conv2d.py
+++ b/topi/python/topi/mali/conv2d.py
@@ -289,6 +289,10 @@ def _schedule_spatialpack_conv2d(s, op):
     if data.dtype == 'float16' and (util.get_const_int(conv.shape[1]) == 4 or output_height == 28):
         num_thread //= 2
 
+    # schedule dilation
+    if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+        s[kernel].compute_inline()
+
     # schedule padding
     if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
         data_pad = data
@@ -431,6 +435,10 @@ def _schedule_im2col_conv2d(s, op):
             num_thread1 = num_thread * 2
             num_thread2 = num_thread // 2
 
+    # schedule dilation
+    if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+        s[kernel].compute_inline()
+
     # schedule padding
     if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
         data_pad = data
@@ -616,6 +624,10 @@ def _schedule_winograd(s, op):
     data_pad = s[d].op.input_tensors[0]
     data = s[data_pad].op.input_tensors[0]
 
+    # dilation
+    if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+        s[kernel].compute_inline()
+
     # padding
     s[data_pad].compute_inline()
 
diff --git a/topi/python/topi/mali/depthwise_conv2d.py b/topi/python/topi/mali/depthwise_conv2d.py
index 428140550..61ec6334e 100644
--- a/topi/python/topi/mali/depthwise_conv2d.py
+++ b/topi/python/topi/mali/depthwise_conv2d.py
@@ -100,6 +100,8 @@ def schedule_depthwise_conv2d_nchw(outs):
         if op.tag == 'depthwise_conv2d_nchw':
             pad_data = op.input_tensors[0]
             kernel = op.input_tensors[1]
+            if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
+                s[kernel].compute_inline()
             conv = op.output(0)
             _schedule(pad_data, kernel, conv)
 
diff --git a/topi/python/topi/nn/dilate.py b/topi/python/topi/nn/dilate.py
index 07433280c..8ea280878 100644
--- a/topi/python/topi/nn/dilate.py
+++ b/topi/python/topi/nn/dilate.py
@@ -5,7 +5,7 @@ import tvm
 from .. import util
 from .. import tag
 
-@tvm.tag_scope(tag=tag.INJECTIVE)
+@tvm.tag_scope(tag=tag.INJECTIVE+",dilate")
 def dilate(data, strides, name="DilatedInput"):
     """Dilate data with zeros.
 
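
Appending ",dilate" to the injective tag is what makes the detection in the
schedules possible; consumers match on the substring. A quick illustration,
assuming tag.INJECTIVE is the plain "injective" string:

    import tvm
    import topi

    W = tvm.placeholder((256, 256, 3, 3), name='W')
    dW = topi.nn.dilate(W, (1, 1, 2, 2))
    assert 'dilate' in dW.op.tag   # tag is now "injective,dilate"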
diff --git a/topi/python/topi/nn/pad.py b/topi/python/topi/nn/pad.py
index 9420f8a70..7ebbc566c 100644
--- a/topi/python/topi/nn/pad.py
+++ b/topi/python/topi/nn/pad.py
@@ -6,7 +6,7 @@ from .. import tag
 
 @tvm.tag_scope(tag=tag.INJECTIVE+",pad")
 def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
-    """Dilate Input with zeros.
+    """Pad Input with zeros.
 
     Parameters
     ----------
diff --git a/topi/python/topi/opengl/conv2d_nchw.py b/topi/python/topi/opengl/conv2d_nchw.py
index 7e8b7275f..573270c37 100644
--- a/topi/python/topi/opengl/conv2d_nchw.py
+++ b/topi/python/topi/opengl/conv2d_nchw.py
@@ -43,6 +43,9 @@ def schedule_conv2d_nchw(outs):
         elif OP.tag.startswith('conv2d_nchw'):
             conv2d = OP.output(0)
             data = OP.input_tensors[0]
+            kernel = OP.input_tensors[1]
+            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+                s[kernel].compute_inline()
             _schedule(conv2d, data)
         else:
             raise RuntimeError("Unsupported operator: %s" % OP.tag)
diff --git a/topi/python/topi/rasp/conv2d.py b/topi/python/topi/rasp/conv2d.py
index 6b2ce832e..13af1b937 100644
--- a/topi/python/topi/rasp/conv2d.py
+++ b/topi/python/topi/rasp/conv2d.py
@@ -328,6 +328,8 @@ def schedule_conv2d_nchw(outs):
             conv_out = op.input_tensors[0]
             kernel_vec = conv_out.op.input_tensors[1]
             kernel = kernel_vec.op.input_tensors[0]
+            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+                s[kernel].compute_inline()
             data_vec = conv_out.op.input_tensors[0]
             data = data_vec.op.input_tensors[0]
             data_pad = None
diff --git a/topi/python/topi/rasp/depthwise_conv2d.py b/topi/python/topi/rasp/depthwise_conv2d.py
index bb8ebed34..b2ff78e46 100644
--- a/topi/python/topi/rasp/depthwise_conv2d.py
+++ b/topi/python/topi/rasp/depthwise_conv2d.py
@@ -194,6 +194,8 @@ def schedule_depthwise_conv2d_nchw(outs):
         if op.tag == 'depthwise_conv2d_nchw':
             output = op.output(0)
             kernel = op.input_tensors[1]
+            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+                s[kernel].compute_inline()
             data = op.input_tensors[0]
             data_pad = None
             if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
diff --git a/topi/python/topi/rocm/conv2d.py b/topi/python/topi/rocm/conv2d.py
index 4dd5e5fd0..1aa125f8f 100644
--- a/topi/python/topi/rocm/conv2d.py
+++ b/topi/python/topi/rocm/conv2d.py
@@ -5,6 +5,8 @@ from tvm.contrib import miopen
 import topi
 from .. import generic
 from ..nn.conv2d import conv2d
+from ..util import get_const_int
+
 
 @conv2d.register("rocm")
 def conv2d_rocm(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
@@ -42,18 +44,29 @@ def conv2d_rocm(data, kernel, stride, padding, layout='NCHW', out_dtype='float32
         pad_h = pad_w = padding
     else:
         pad_h, pad_w = padding
+    # handle dilation
+    dilation_h = dilation_w = 1
+    kernel_tvm = kernel
+    kernel_miopen = kernel
+    if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+        kernel_before_dilation = kernel.op.input_tensors[0]
+        kernel_miopen = kernel_before_dilation
+        dilation_h = (get_const_int(kernel.shape[2]) + get_const_int(kernel_before_dilation.shape[2]) - 1) \
+            // get_const_int(kernel_before_dilation.shape[2])
+        dilation_w = (get_const_int(kernel.shape[3]) + get_const_int(kernel_before_dilation.shape[3]) - 1) \
+            // get_const_int(kernel_before_dilation.shape[3])
     target = tvm.target.current_target()
     if "miopen" in target.libs:
         return miopen.conv2d_forward(data,
-                                     kernel,
+                                     kernel_miopen,
                                      stride_h,
                                      stride_w,
                                      pad_h,
                                      pad_w,
-                                     1,  # dilation_h
-                                     1,  # dilation_w
+                                     dilation_h,
+                                     dilation_w,
                                      conv_mode=0)
-    return topi.nn.conv2d_nchw(data, kernel, stride, padding, out_dtype)
+    return topi.nn.conv2d_nchw(data, kernel_tvm, stride, padding, out_dtype)
 
 
 @generic.schedule_conv2d_nchw.register(["rocm"])
diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py
index beee95b13..7e9632ed9 100644
--- a/topi/python/topi/x86/conv2d.py
+++ b/topi/python/topi/x86/conv2d.py
@@ -91,6 +91,8 @@ def schedule_conv2d(outs):
         """NCHW conv2d schedule for non imagenet workloads"""
         conv = op.output(0)
         kernel = op.input_tensors[1]
+        if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+            s[kernel].compute_inline()
         data = op.input_tensors[0]
         data_pad = None
         if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
@@ -134,6 +136,8 @@ def schedule_conv2d(outs):
                     conv_out = op.input_tensors[0]
                     kernel_vec = conv_out.op.input_tensors[1]
                     kernel = kernel_vec.op.input_tensors[0]
+                    if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+                        s[kernel].compute_inline()
                     data_vec = conv_out.op.input_tensors[0]
                     data = data_vec.op.input_tensors[0]
                     data_pad = None
@@ -184,6 +188,9 @@ def schedule_conv2d_nhwc(outs):
         if 'conv2d_nhwc' in op.tag:
             conv = op.output(0)
             kernel = op.input_tensors[1]
+            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+                s[kernel].compute_inline()
+
             data = op.input_tensors[0]
             data_pad = None
             if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
diff --git a/topi/tests/python/test_topi_conv2d_hwcn.py b/topi/tests/python/test_topi_conv2d_hwcn.py
index 84962a0b4..645e18702 100644
--- a/topi/tests/python/test_topi_conv2d_hwcn.py
+++ b/topi/tests/python/test_topi_conv2d_hwcn.py
@@ -8,12 +8,13 @@ from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
 
 
-def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, padding):
+def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
     in_height = in_width = in_size
 
     A = tvm.placeholder((in_height, in_width, in_channel, batch), name='A')
     W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
-    B = topi.nn.conv2d_hwcn(A, W, stride, padding)
+    dW = topi.nn.dilate(W, (dilation, dilation, 1, 1))
+    B = topi.nn.conv2d_hwcn(A, dW, stride, padding)
     C = topi.nn.relu(B)
     s1 = topi.cuda.schedule_conv2d_hwcn([B])
     s2 = topi.cuda.schedule_conv2d_hwcn([C])
@@ -26,7 +27,8 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
     def get_ref_data():
         a_np = np.random.uniform(size=a_shape).astype(dtype)
         w_np = np.random.uniform(size=w_shape).astype(dtype)
-        b_np = topi.testing.conv2d_hwcn_python(a_np, w_np, stride, padding)
+        dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
+        b_np = topi.testing.conv2d_hwcn_python(a_np, dw_np, stride, padding)
         c_np = np.maximum(b_np, 0)
         return a_np, w_np, b_np, c_np
     a_np, w_np, b_np, c_np = get_ref_data()
@@ -63,7 +65,8 @@ def test_conv2d_hwcn():
     verify_conv2d_hwcn(1, 256, 32, 256, 3, 1, "VALID")
     verify_conv2d_hwcn(4, 128, 16, 128, 5, 2, "VALID")
     verify_conv2d_hwcn(4, 128, 16, 256, 5, 2, "VALID")
-
+    # dilation = 2
+    verify_conv2d_hwcn(1, 256, 32, 256, 3, 1, "SAME", dilation=2)
 
 if __name__ == "__main__":
     test_conv2d_hwcn()
diff --git a/topi/tests/python/test_topi_conv2d_nchw.py b/topi/tests/python/test_topi_conv2d_nchw.py
index e5a674709..ed8fab524 100644
--- a/topi/tests/python/test_topi_conv2d_nchw.py
+++ b/topi/tests/python/test_topi_conv2d_nchw.py
@@ -3,10 +3,11 @@ import os
 import numpy as np
 import tvm
 import topi
+import topi.testing
 from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
 
-def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
+def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
     in_height = in_width = in_size
 
     A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
@@ -16,11 +17,12 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
     w_shape = get_const_tuple(W.shape)
     dtype = A.dtype
 
-    @memoize("topi.tests.test_topi_conv2d.verify_con2d_nchw")
+    @memoize("topi.tests.test_topi_conv2d_nchw.verify_conv2d_nchw")
     def get_ref_data():
         a_np = np.random.uniform(size=a_shape).astype(dtype)
         w_np = np.random.uniform(size=w_shape).astype(dtype)
-        b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
+        dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+        b_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
         c_np = np.maximum(b_np, 0)
         return a_np, w_np, b_np, c_np
 
@@ -33,7 +35,8 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            B = topi.nn.conv2d(A, W, stride, padding, layout='NCHW')
+            dW = topi.nn.dilate(W, (1, 1, dilation, dilation))
+            B = topi.nn.conv2d(A, dW, stride, padding, layout='NCHW')
             C = topi.nn.relu(B)
             s1 = topi.generic.schedule_conv2d_nchw([B])
             s2 = topi.generic.schedule_conv2d_nchw([C])
@@ -43,8 +46,8 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         with tvm.build_config(auto_unroll_max_step=1400,
                               unroll_explicit=(device != "cuda")):
-            func1 = tvm.build(s1, [A, W, B], device, name="conv2d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
-            func2 = tvm.build(s2, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
+            func1 = tvm.build(s1, [A, W, B], device, name="conv2d_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
+            func2 = tvm.build(s2, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
             func1(a, w, b)
             func2(a, w, c)
             np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
@@ -75,6 +78,8 @@ def test_conv2d_nchw():
     verify_conv2d_nchw(1, 64, 224, 64, 3, 1, 1)
     verify_conv2d_nchw(1, 64, 224, 32, 3, 1, 1)
     verify_conv2d_nchw(1, 32, 224, 9, 3, 1, 1)
+    # dilation = 2
+    verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1, dilation=2)
 
 if __name__ == "__main__":
     test_conv2d_nchw()
diff --git a/topi/tests/python/test_topi_conv2d_nhwc.py b/topi/tests/python/test_topi_conv2d_nhwc.py
index 40aa3e550..7e41517c5 100644
--- a/topi/tests/python/test_topi_conv2d_nhwc.py
+++ b/topi/tests/python/test_topi_conv2d_nhwc.py
@@ -8,12 +8,13 @@ from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
 
 
-def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, padding):
+def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
     in_height = in_width = in_size
 
     A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A')
     W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
-    B = topi.nn.conv2d_nhwc(A, W, stride, padding)
+    dW = topi.nn.dilate(W, (dilation, dilation, 1, 1))
+    B = topi.nn.conv2d_nhwc(A, dW, stride, padding)
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -23,7 +24,8 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, p
     def get_ref_data():
         a_np = np.random.uniform(size=a_shape).astype(dtype)
         w_np = np.random.uniform(size=w_shape).astype(dtype)
-        b_np = topi.testing.conv2d_nhwc_python(a_np, w_np, stride, padding)
+        dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
+        b_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding)
         return a_np, w_np, b_np
     a_np, w_np, b_np = get_ref_data()
 
@@ -54,6 +56,8 @@ def test_conv2d_nhwc():
     verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "VALID")
     verify_conv2d_nhwc(4, 128, 16, 128, 5, 2, "VALID")
     verify_conv2d_nhwc(4, 128, 16, 256, 5, 2, "VALID")
+    # dilation = 2
+    verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "SAME", dilation=2)
 
 
 if __name__ == "__main__":
diff --git a/topi/tests/python/test_topi_depthwise_conv2d.py b/topi/tests/python/test_topi_depthwise_conv2d.py
index 5a12d7abb..3e9794ac3 100644
--- a/topi/tests/python/test_topi_depthwise_conv2d.py
+++ b/topi/tests/python/test_topi_depthwise_conv2d.py
@@ -8,21 +8,21 @@ from tvm.contrib.pickle_memoize import memoize
 from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nhwc
 
 
-def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding):
+def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1):
     in_width = in_height
     filter_channel = in_channel
     filter_width = filter_height
     # placeholder
     Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input')
     Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter')
+    DilatedFilter = topi.nn.dilate(Filter, (1, 1, dilation, dilation), name='DilatedFilter')
     Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
     Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
     # declare
-    DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, stride=stride, padding=padding)
+    DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, DilatedFilter, stride=stride, padding=padding)
     ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
     Relu = topi.nn.relu(ScaleShift)
 
-
     def check_device(device):
         ctx = tvm.context(device, 0)
         if not ctx.exist:
@@ -52,11 +52,12 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
         def get_ref_data():
             input_np = np.random.uniform(size=input_shape).astype(dtype)
             filter_np = np.random.uniform(size=filter_shape).astype(dtype)
+            dilated_filter_np = topi.testing.dilate_python(filter_np, (1, 1, dilation, dilation))
             scale_np = np.random.uniform(size=scale_shape).astype(dtype)
             shift_np = np.random.uniform(size=shift_shape).astype(dtype)
             # correctness with scipy
             depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(
-                input_np, filter_np, stride=stride, padding=padding)
+                input_np, dilated_filter_np, stride=stride, padding=padding)
             scale_shift_scipy = np.zeros(shape=scale_shift_shape)
             for c in range(in_channel * channel_multiplier):
                 scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
@@ -94,7 +95,7 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
     check_device("vulkan")
 
 
-def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
+def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding, dilation=1):
     in_width = in_height
     filter_channel = in_channel
     filter_width = filter_height
@@ -102,10 +103,11 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
     # placeholder
     Input = tvm.placeholder((batch, in_height, in_width, in_channel), name='Input')
     Filter = tvm.placeholder((filter_height, filter_width,filter_channel, channel_multiplier), name='Filter')
+    DilatedFilter = topi.nn.dilate(Filter, (dilation, dilation, 1, 1), name='DilatedFilter')
     Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
     Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
     # declare
-    DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, Filter, stride=[stride_h, stride_w], padding=padding)
+    DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, DilatedFilter, stride=[stride_h, stride_w], padding=padding)
     ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
     Relu = topi.nn.relu(ScaleShift)
     # schedule
@@ -139,11 +141,12 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
         def get_ref_data():
             input_np = np.random.uniform(size=input_shape).astype(dtype)
             filter_np = np.random.uniform(size=filter_shape).astype(dtype)
+            dilated_filter_np = topi.testing.dilate_python(filter_np, (dilation, dilation, 1, 1))
             scale_np = np.random.uniform(size=scale_shape).astype(dtype)
             shift_np = np.random.uniform(size=shift_shape).astype(dtype)
             # correctness with scipy
             depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nhwc(
-                input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
+                input_np, dilated_filter_np, stride=[stride_h, stride_w], padding=padding)
             scale_shift_scipy = np.zeros(shape=scale_shift_shape)
             for c in range(in_channel * channel_multiplier):
                 scale_shift_scipy[:,:,:,c] = depthwise_conv2d_scipy[:,:,:,c] * scale_np[c] + shift_np[c]
@@ -192,6 +195,8 @@ def test_depthwise_conv2d():
     depthwise_conv2d_with_workload_nchw(1, 728, 32, 1, 3, 1, "VALID")
     depthwise_conv2d_with_workload_nchw(4, 256, 64, 2, 5, 2, "VALID")
     depthwise_conv2d_with_workload_nchw(4, 256, 32, 2, 5, 2, "VALID")
+    # dilation = 2
+    depthwise_conv2d_with_workload_nchw(1, 728, 64, 1, 3, 1, "SAME", dilation=2)
     print("testing nhwc")
     depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME")
     depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "SAME")
@@ -201,6 +206,8 @@ def test_depthwise_conv2d():
     depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "VALID")
     depthwise_conv2d_with_workload_nhwc(4, 256, 64, 2, 5, 2, "VALID")
     depthwise_conv2d_with_workload_nhwc(4, 256, 32, 2, 5, 2, "VALID")
+    # dilation = 2
+    depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2)
 
 if __name__ == "__main__":
     test_depthwise_conv2d()
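
For reference, the numerical effect being tested: dilating a kernel
interleaves zeros between its taps. A standalone NumPy sketch of what
topi.testing.dilate_python computes for one 2-D slice:

    import numpy as np

    k = np.arange(9, dtype='float32').reshape(3, 3)
    dk = np.zeros((5, 5), dtype='float32')   # (3 - 1) * 2 + 1 = 5 per axis
    dk[::2, ::2] = k                          # original taps, zeros between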
-- 
GitLab