From 359e335f38941b36fa0b8d5797e83026efda02f0 Mon Sep 17 00:00:00 2001
From: Yuwei Hu <huyuwei1995@gmail.com>
Date: Wed, 25 Apr 2018 12:03:58 -0400
Subject: [PATCH] [TOPI] support dilated conv2d and depthwise_conv2d (#1129)

* support dilation in conv2d and depthwise_conv2d
* handle dilated conv in extern libs (cudnn, miopen)
---
 topi/python/topi/cuda/conv2d.py             | 28 +++++++++++++++----
 topi/python/topi/cuda/conv2d_hwcn.py        |  2 ++
 topi/python/topi/cuda/conv2d_nchw.py        |  2 ++
 topi/python/topi/cuda/depthwise_conv2d.py   |  4 +++
 topi/python/topi/generic/extern.py          |  2 +-
 topi/python/topi/mali/conv2d.py             | 12 ++++++++
 topi/python/topi/mali/depthwise_conv2d.py   |  2 ++
 topi/python/topi/nn/dilate.py               |  2 +-
 topi/python/topi/nn/pad.py                  |  2 +-
 topi/python/topi/opengl/conv2d_nchw.py      |  3 ++
 topi/python/topi/rasp/conv2d.py             |  2 ++
 topi/python/topi/rasp/depthwise_conv2d.py   |  2 ++
 topi/python/topi/rocm/conv2d.py             | 21 +++++++++++---
 topi/python/topi/x86/conv2d.py              |  7 +++++
 topi/tests/python/test_topi_conv2d_hwcn.py  | 11 +++++---
 topi/tests/python/test_topi_conv2d_nchw.py  | 17 +++++++----
 topi/tests/python/test_topi_conv2d_nhwc.py  | 10 +++++--
 .../python/test_topi_depthwise_conv2d.py    | 21 +++++++++----
 18 files changed, 118 insertions(+), 32 deletions(-)

diff --git a/topi/python/topi/cuda/conv2d.py b/topi/python/topi/cuda/conv2d.py
index 2641bfe49..3c494cdeb 100644
--- a/topi/python/topi/cuda/conv2d.py
+++ b/topi/python/topi/cuda/conv2d.py
@@ -4,6 +4,7 @@ import tvm
 from tvm.contrib import cudnn
 import topi
 from ..nn.conv2d import conv2d
+from ..util import get_const_int
 
 @conv2d.register("cuda")
 def conv2d_cuda(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
@@ -40,6 +41,23 @@ def conv2d_cuda(data, kernel, stride, padding, layout='NCHW', out_dtype='float32
         pad_h = pad_w = padding
     else:
         pad_h, pad_w = padding
+    # handle dilation
+    dilation_h = dilation_w = 1
+    kernel_tvm = kernel
+    kernel_cudnn = kernel
+    if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+        kernel_before_dilation = kernel.op.input_tensors[0]
+        kernel_cudnn = kernel_before_dilation
+        if layout == 'NCHW':
+            dilation_h = (get_const_int(kernel.shape[2]) + get_const_int(kernel_before_dilation.shape[2]) - 1) \
+                         // get_const_int(kernel_before_dilation.shape[2])
+            dilation_w = (get_const_int(kernel.shape[3]) + get_const_int(kernel_before_dilation.shape[3]) - 1) \
+                         // get_const_int(kernel_before_dilation.shape[3])
+        elif layout == 'NHWC':
+            dilation_h = (get_const_int(kernel.shape[1]) + get_const_int(kernel_before_dilation.shape[1]) - 1) \
+                         // get_const_int(kernel_before_dilation.shape[1])
+            dilation_w = (get_const_int(kernel.shape[2]) + get_const_int(kernel_before_dilation.shape[2]) - 1) \
+                         // get_const_int(kernel_before_dilation.shape[2])
     target = tvm.target.current_target()
     if "cudnn" in target.libs:
         assert layout != 'HWCN', "HWCN layout not supported with CUDNN."
@@ -47,19 +65,19 @@ def conv2d_cuda(data, kernel, stride, padding, layout='NCHW', out_dtype='float32
         if layout == 'NHWC':
             tensor_format = 1 # CUDNN_TENSOR_NHWC
         return cudnn.conv2d_forward(data,
-                                    kernel,
+                                    kernel_cudnn,
                                     stride_h,
                                     stride_w,
                                     pad_h,
                                     pad_w,
-                                    1, # dilation_h
-                                    1, # dilation_w
+                                    dilation_h,
+                                    dilation_w,
                                     conv_mode=1,
                                     tensor_format=tensor_format,
                                     algo=-1) # let CUDNN choose the best algo
     elif layout == 'NCHW':
-        return topi.nn.conv2d_nchw(data, kernel, stride, padding, out_dtype)
+        return topi.nn.conv2d_nchw(data, kernel_tvm, stride, padding, out_dtype)
     elif layout == 'HWCN':
-        return topi.nn.conv2d_hwcn(data, kernel, stride, padding, out_dtype)
+        return topi.nn.conv2d_hwcn(data, kernel_tvm, stride, padding, out_dtype)
     else:
         raise ValueError("not support this layout {} yet".format(layout))
diff --git a/topi/python/topi/cuda/conv2d_hwcn.py b/topi/python/topi/cuda/conv2d_hwcn.py
index ec9025855..082966a3c 100644
--- a/topi/python/topi/cuda/conv2d_hwcn.py
+++ b/topi/python/topi/cuda/conv2d_hwcn.py
@@ -110,6 +110,8 @@ def schedule_conv2d_hwcn(outs):
         elif operator.tag == 'conv2d_hwcn':
             Apad = operator.input_tensors[0]
             W = operator.input_tensors[1]
+            if isinstance(W.op, tvm.tensor.ComputeOp) and 'dilate' in W.op.tag:
+                sch[W].compute_inline()
             B = operator.output(0)
             schedule(Apad, W, B)
         else:
diff --git a/topi/python/topi/cuda/conv2d_nchw.py b/topi/python/topi/cuda/conv2d_nchw.py
index e313029e7..71104c59a 100644
--- a/topi/python/topi/cuda/conv2d_nchw.py
+++ b/topi/python/topi/cuda/conv2d_nchw.py
@@ -495,6 +495,8 @@ def schedule_conv2d_small_batch(outs):
         if 'conv2d_nchw' in OP.tag:
             temp = OP.input_tensors[0]
             Filter = OP.input_tensors[1]
+            if isinstance(Filter.op, tvm.tensor.ComputeOp) and 'dilate' in Filter.op.tag:
+                s[Filter].compute_inline()
             Output = OP.output(0)
             schedule(temp, Filter, Output)
 
diff --git a/topi/python/topi/cuda/depthwise_conv2d.py b/topi/python/topi/cuda/depthwise_conv2d.py
index 304577429..1901fcbb7 100644
--- a/topi/python/topi/cuda/depthwise_conv2d.py
+++ b/topi/python/topi/cuda/depthwise_conv2d.py
@@ -114,6 +114,8 @@ def schedule_depthwise_conv2d_nchw(outs):
         if OP.tag == 'depthwise_conv2d_nchw':
             PaddedInput = OP.input_tensors[0]
             Filter = OP.input_tensors[1]
+            if isinstance(Filter.op, tvm.tensor.ComputeOp) and 'dilate' in Filter.op.tag:
+                s[Filter].compute_inline()
             DepthwiseConv2d = OP.output(0)
             _schedule(PaddedInput, Filter, DepthwiseConv2d)
 
@@ -191,6 +193,8 @@ def schedule_depthwise_conv2d_nhwc(outs):
         if OP.tag == 'depthwise_conv2d_nhwc':
             PaddedInput = OP.input_tensors[0]
             Filter = OP.input_tensors[1]
+            if isinstance(Filter.op, tvm.tensor.ComputeOp) and 'dilate' in Filter.op.tag:
+                s[Filter].compute_inline()
             DepthwiseConv2d = OP.output(0)
             _schedule(PaddedInput, Filter, DepthwiseConv2d)
 
diff --git a/topi/python/topi/generic/extern.py b/topi/python/topi/generic/extern.py
index 082c1bca8..92a47f8a9 100644
--- a/topi/python/topi/generic/extern.py
+++ b/topi/python/topi/generic/extern.py
@@ -21,6 +21,6 @@ def schedule_extern(outs):
     """
     target = tvm.target.current_target(allow_none=False)
     if target.target_name != "llvm":
-        raise RuntimeError("schedule_injective not registered for '%s'" % target)
+        raise RuntimeError("schedule_extern not registered for '%s'" % target)
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     return tvm.create_schedule([x.op for x in outs])
diff --git a/topi/python/topi/mali/conv2d.py b/topi/python/topi/mali/conv2d.py
index 5b4cf5bae..ad1dfbe61 100644
--- a/topi/python/topi/mali/conv2d.py
+++ b/topi/python/topi/mali/conv2d.py
@@ -289,6 +289,10 @@ def _schedule_spatialpack_conv2d(s, op):
     if data.dtype == 'float16' and (util.get_const_int(conv.shape[1]) == 4 or output_height == 28):
         num_thread //= 2
 
+    # schedule dilation
+    if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+        s[kernel].compute_inline()
+
     # schedule padding
     if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
         data_pad = data
@@ -431,6 +435,10 @@ def _schedule_im2col_conv2d(s, op):
     num_thread1 = num_thread * 2
     num_thread2 = num_thread // 2
 
+    # schedule dilation
+    if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+        s[kernel].compute_inline()
+
    # schedule padding
     if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
         data_pad = data
@@ -616,6 +624,10 @@ def _schedule_winograd(s, op):
         data_pad = s[d].op.input_tensors[0]
         data = s[data_pad].op.input_tensors[0]
 
+    # dilation
+    if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+        s[kernel].compute_inline()
+
     # padding
     s[data_pad].compute_inline()
 
diff --git a/topi/python/topi/mali/depthwise_conv2d.py b/topi/python/topi/mali/depthwise_conv2d.py
index 428140550..61ec6334e 100644
--- a/topi/python/topi/mali/depthwise_conv2d.py
+++ b/topi/python/topi/mali/depthwise_conv2d.py
@@ -100,6 +100,8 @@ def schedule_depthwise_conv2d_nchw(outs):
         if op.tag == 'depthwise_conv2d_nchw':
             pad_data = op.input_tensors[0]
             kernel = op.input_tensors[1]
+            if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
+                s[kernel].compute_inline()
             conv = op.output(0)
             _schedule(pad_data, kernel, conv)
 
diff --git a/topi/python/topi/nn/dilate.py b/topi/python/topi/nn/dilate.py
index 07433280c..8ea280878 100644
--- a/topi/python/topi/nn/dilate.py
+++ b/topi/python/topi/nn/dilate.py
@@ -5,7 +5,7 @@ import tvm
 from .. import util
 from .. import tag
 
-@tvm.tag_scope(tag=tag.INJECTIVE)
+@tvm.tag_scope(tag=tag.INJECTIVE+",dilate")
 def dilate(data, strides, name="DilatedInput"):
     """Dilate data with zeros.
 
diff --git a/topi/python/topi/nn/pad.py b/topi/python/topi/nn/pad.py
index 9420f8a70..7ebbc566c 100644
--- a/topi/python/topi/nn/pad.py
+++ b/topi/python/topi/nn/pad.py
@@ -6,7 +6,7 @@ from .. import tag
 
 @tvm.tag_scope(tag=tag.INJECTIVE+",pad")
 def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
-    """Dilate Input with zeros.
+    """Pad Input with zeros.
 
     Parameters
     ----------
diff --git a/topi/python/topi/opengl/conv2d_nchw.py b/topi/python/topi/opengl/conv2d_nchw.py
index 7e8b7275f..573270c37 100644
--- a/topi/python/topi/opengl/conv2d_nchw.py
+++ b/topi/python/topi/opengl/conv2d_nchw.py
@@ -43,6 +43,9 @@ def schedule_conv2d_nchw(outs):
         elif OP.tag.startswith('conv2d_nchw'):
             conv2d = OP.output(0)
             data = OP.input_tensors[0]
+            kernel = OP.input_tensors[1]
+            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+                s[kernel].compute_inline()
             _schedule(conv2d, data)
         else:
             raise RuntimeError("Unsupported operator: %s" % OP.tag)
diff --git a/topi/python/topi/rasp/conv2d.py b/topi/python/topi/rasp/conv2d.py
index 6b2ce832e..13af1b937 100644
--- a/topi/python/topi/rasp/conv2d.py
+++ b/topi/python/topi/rasp/conv2d.py
@@ -328,6 +328,8 @@ def schedule_conv2d_nchw(outs):
             conv_out = op.input_tensors[0]
             kernel_vec = conv_out.op.input_tensors[1]
             kernel = kernel_vec.op.input_tensors[0]
+            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+                s[kernel].compute_inline()
             data_vec = conv_out.op.input_tensors[0]
             data = data_vec.op.input_tensors[0]
             data_pad = None
diff --git a/topi/python/topi/rasp/depthwise_conv2d.py b/topi/python/topi/rasp/depthwise_conv2d.py
index bb8ebed34..b2ff78e46 100644
--- a/topi/python/topi/rasp/depthwise_conv2d.py
+++ b/topi/python/topi/rasp/depthwise_conv2d.py
@@ -194,6 +194,8 @@ def schedule_depthwise_conv2d_nchw(outs):
         if op.tag == 'depthwise_conv2d_nchw':
             output = op.output(0)
             kernel = op.input_tensors[1]
+            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+                s[kernel].compute_inline()
             data = op.input_tensors[0]
             data_pad = None
             if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
diff --git a/topi/python/topi/rocm/conv2d.py b/topi/python/topi/rocm/conv2d.py
index 4dd5e5fd0..1aa125f8f 100644
--- a/topi/python/topi/rocm/conv2d.py
+++ b/topi/python/topi/rocm/conv2d.py
@@ -5,6 +5,8 @@ from tvm.contrib import miopen
 import topi
 from .. import generic
 from ..nn.conv2d import conv2d
+from ..util import get_const_int
+
 
 @conv2d.register("rocm")
 def conv2d_rocm(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
@@ -42,18 +44,29 @@ def conv2d_rocm(data, kernel, stride, padding, layout='NCHW', out_dtype='float32
         pad_h = pad_w = padding
     else:
         pad_h, pad_w = padding
+    # handle dilation
+    dilation_h = dilation_w = 1
+    kernel_tvm = kernel
+    kernel_cudnn = kernel
+    if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+        kernel_before_dilation = kernel.op.input_tensors[0]
+        kernel_cudnn = kernel_before_dilation
+        dilation_h = (get_const_int(kernel.shape[2]) + get_const_int(kernel_before_dilation.shape[2]) - 1) \
+                     // get_const_int(kernel_before_dilation.shape[2])
+        dilation_w = (get_const_int(kernel.shape[3]) + get_const_int(kernel_before_dilation.shape[3]) - 1) \
+                     // get_const_int(kernel_before_dilation.shape[3])
     target = tvm.target.current_target()
     if "miopen" in target.libs:
         return miopen.conv2d_forward(data,
-                                     kernel,
+                                     kernel_cudnn,
                                      stride_h,
                                      stride_w,
                                      pad_h,
                                      pad_w,
-                                     1, # dilation_h
-                                     1, # dilation_w
+                                     dilation_h,
+                                     dilation_w,
                                      conv_mode=0)
-    return topi.nn.conv2d_nchw(data, kernel, stride, padding, out_dtype)
+    return topi.nn.conv2d_nchw(data, kernel_tvm, stride, padding, out_dtype)
 
 
 @generic.schedule_conv2d_nchw.register(["rocm"])
diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py
index beee95b13..7e9632ed9 100644
--- a/topi/python/topi/x86/conv2d.py
+++ b/topi/python/topi/x86/conv2d.py
@@ -91,6 +91,8 @@ def schedule_conv2d(outs):
             """NCHW conv2d schedule for non imagenet workloads"""
             conv = op.output(0)
             kernel = op.input_tensors[1]
+            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+                s[kernel].compute_inline()
             data = op.input_tensors[0]
             data_pad = None
             if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
@@ -134,6 +136,8 @@ def schedule_conv2d(outs):
             conv_out = op.input_tensors[0]
             kernel_vec = conv_out.op.input_tensors[1]
             kernel = kernel_vec.op.input_tensors[0]
+            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+                s[kernel].compute_inline()
             data_vec = conv_out.op.input_tensors[0]
             data = data_vec.op.input_tensors[0]
             data_pad = None
@@ -184,6 +188,9 @@ def schedule_conv2d_nhwc(outs):
         if 'conv2d_nhwc' in op.tag:
             conv = op.output(0)
             kernel = op.input_tensors[1]
+            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+                s[kernel].compute_inline()
+
             data = op.input_tensors[0]
             data_pad = None
             if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
diff --git a/topi/tests/python/test_topi_conv2d_hwcn.py b/topi/tests/python/test_topi_conv2d_hwcn.py
index 84962a0b4..645e18702 100644
--- a/topi/tests/python/test_topi_conv2d_hwcn.py
+++ b/topi/tests/python/test_topi_conv2d_hwcn.py
@@ -8,12 +8,13 @@ from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
 
 
-def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, padding):
+def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
     in_height = in_width = in_size
     A = tvm.placeholder((in_height, in_width, in_channel, batch), name='A')
     W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
-    B = topi.nn.conv2d_hwcn(A, W, stride, padding)
+    dW = topi.nn.dilate(W, (dilation, dilation, 1, 1))
+    B = topi.nn.conv2d_hwcn(A, dW, stride, padding)
     C = topi.nn.relu(B)
     s1 = topi.cuda.schedule_conv2d_hwcn([B])
     s2 = topi.cuda.schedule_conv2d_hwcn([C])
 
@@ -26,7 +27,8 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
     def get_ref_data():
         a_np = np.random.uniform(size=a_shape).astype(dtype)
         w_np = np.random.uniform(size=w_shape).astype(dtype)
-        b_np = topi.testing.conv2d_hwcn_python(a_np, w_np, stride, padding)
+        dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
+        b_np = topi.testing.conv2d_hwcn_python(a_np, dw_np, stride, padding)
         c_np = np.maximum(b_np, 0)
         return a_np, w_np, b_np, c_np
     a_np, w_np, b_np, c_np = get_ref_data()
@@ -63,7 +65,8 @@ def test_conv2d_hwcn():
     verify_conv2d_hwcn(1, 256, 32, 256, 3, 1, "VALID")
     verify_conv2d_hwcn(4, 128, 16, 128, 5, 2, "VALID")
     verify_conv2d_hwcn(4, 128, 16, 256, 5, 2, "VALID")
-
+    # dilation = 2
+    verify_conv2d_hwcn(1, 256, 32, 256, 3, 1, "SAME", dilation=2)
 
 if __name__ == "__main__":
     test_conv2d_hwcn()
diff --git a/topi/tests/python/test_topi_conv2d_nchw.py b/topi/tests/python/test_topi_conv2d_nchw.py
index e5a674709..ed8fab524 100644
--- a/topi/tests/python/test_topi_conv2d_nchw.py
+++ b/topi/tests/python/test_topi_conv2d_nchw.py
@@ -3,10 +3,11 @@ import os
 import numpy as np
 import tvm
 import topi
+import topi.testing
 from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
 
-def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
+def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
     in_height = in_width = in_size
 
     A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
@@ -16,11 +17,12 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
     w_shape = get_const_tuple(W.shape)
     dtype = A.dtype
 
-    @memoize("topi.tests.test_topi_conv2d.verify_con2d_nchw")
+    @memoize("topi.tests.test_topi_conv2d_nchw.verify_conv2d_nchw")
     def get_ref_data():
         a_np = np.random.uniform(size=a_shape).astype(dtype)
         w_np = np.random.uniform(size=w_shape).astype(dtype)
-        b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
+        dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+        b_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
         c_np = np.maximum(b_np, 0)
         return a_np, w_np, b_np, c_np
 
@@ -33,7 +35,8 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            B = topi.nn.conv2d(A, W, stride, padding, layout='NCHW')
+            dW = topi.nn.dilate(W, (1, 1, dilation, dilation))
+            B = topi.nn.conv2d(A, dW, stride, padding, layout='NCHW')
             C = topi.nn.relu(B)
             s1 = topi.generic.schedule_conv2d_nchw([B])
             s2 = topi.generic.schedule_conv2d_nchw([C])
@@ -43,8 +46,8 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         with tvm.build_config(auto_unroll_max_step=1400,
                               unroll_explicit=(device != "cuda")):
-            func1 = tvm.build(s1, [A, W, B], device, name="conv2d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
-            func2 = tvm.build(s2, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
+            func1 = tvm.build(s1, [A, W, B], device, name="conv2d_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
+            func2 = tvm.build(s2, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
             func1(a, w, b)
             func2(a, w, c)
             np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
@@ -75,6 +78,8 @@ def test_conv2d_nchw():
     verify_conv2d_nchw(1, 64, 224, 64, 3, 1, 1)
     verify_conv2d_nchw(1, 64, 224, 32, 3, 1, 1)
     verify_conv2d_nchw(1, 32, 224, 9, 3, 1, 1)
+    # dilation = 2
+    verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1, dilation=2)
 
 if __name__ == "__main__":
     test_conv2d_nchw()
diff --git a/topi/tests/python/test_topi_conv2d_nhwc.py b/topi/tests/python/test_topi_conv2d_nhwc.py
index 40aa3e550..7e41517c5 100644
--- a/topi/tests/python/test_topi_conv2d_nhwc.py
+++ b/topi/tests/python/test_topi_conv2d_nhwc.py
@@ -8,12 +8,13 @@ from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
 
 
-def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, padding):
+def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
     in_height = in_width = in_size
 
     A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A')
     W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
-    B = topi.nn.conv2d_nhwc(A, W, stride, padding)
+    dW = topi.nn.dilate(W, (dilation, dilation, 1, 1))
+    B = topi.nn.conv2d_nhwc(A, dW, stride, padding)
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -23,7 +24,8 @@ def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, p
     def get_ref_data():
         a_np = np.random.uniform(size=a_shape).astype(dtype)
         w_np = np.random.uniform(size=w_shape).astype(dtype)
-        b_np = topi.testing.conv2d_nhwc_python(a_np, w_np, stride, padding)
+        dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
+        b_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding)
         return a_np, w_np, b_np
     a_np, w_np, b_np = get_ref_data()
 
@@ -54,6 +56,8 @@ def test_conv2d_nhwc():
     verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "VALID")
     verify_conv2d_nhwc(4, 128, 16, 128, 5, 2, "VALID")
     verify_conv2d_nhwc(4, 128, 16, 256, 5, 2, "VALID")
+    # dilation = 2
+    verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "SAME", dilation=2)
 
 
 if __name__ == "__main__":
diff --git a/topi/tests/python/test_topi_depthwise_conv2d.py b/topi/tests/python/test_topi_depthwise_conv2d.py
index 5a12d7abb..3e9794ac3 100644
--- a/topi/tests/python/test_topi_depthwise_conv2d.py
+++ b/topi/tests/python/test_topi_depthwise_conv2d.py
@@ -8,21 +8,21 @@ from tvm.contrib.pickle_memoize import memoize
 from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nhwc
 
 
-def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding):
+def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1):
     in_width = in_height
     filter_channel = in_channel
     filter_width = filter_height
     # placeholder
     Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input')
     Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter')
+    DilatedFilter = topi.nn.dilate(Filter, (1, 1, dilation, dilation), name='DilatedFilter')
     Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
     Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
     # declare
-    DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, stride=stride, padding=padding)
+    DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, DilatedFilter, stride=stride, padding=padding)
     ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
     Relu = topi.nn.relu(ScaleShift)
-
     def check_device(device):
         ctx = tvm.context(device, 0)
         if not ctx.exist:
@@ -52,11 +52,12 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
     def get_ref_data():
         input_np = np.random.uniform(size=input_shape).astype(dtype)
         filter_np = np.random.uniform(size=filter_shape).astype(dtype)
+        dilated_filter_np = topi.testing.dilate_python(filter_np, (1, 1, dilation, dilation))
         scale_np = np.random.uniform(size=scale_shape).astype(dtype)
         shift_np = np.random.uniform(size=shift_shape).astype(dtype)
         # correctness with scipy
         depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(
-            input_np, filter_np, stride=stride, padding=padding)
+            input_np, dilated_filter_np, stride=stride, padding=padding)
         scale_shift_scipy = np.zeros(shape=scale_shift_shape)
         for c in range(in_channel * channel_multiplier):
             scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
@@ -94,7 +95,8 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
     check_device("vulkan")
 
 
-def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
+def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding, dilation=1):
     in_width = in_height
     filter_channel = in_channel
     filter_width = filter_height
@@ -102,10 +103,11 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
     # placeholder
     Input = tvm.placeholder((batch, in_height, in_width, in_channel), name='Input')
     Filter = tvm.placeholder((filter_height, filter_width, filter_channel, channel_multiplier), name='Filter')
+    DilatedFilter = topi.nn.dilate(Filter, (dilation, dilation, 1, 1), name='DilatedFilter')
     Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
     Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
     # declare
-    DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, Filter, stride=[stride_h, stride_w], padding=padding)
+    DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, DilatedFilter, stride=[stride_h, stride_w], padding=padding)
     ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
     Relu = topi.nn.relu(ScaleShift)
     # schedule
@@ -139,11 +141,12 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
     def get_ref_data():
         input_np = np.random.uniform(size=input_shape).astype(dtype)
         filter_np = np.random.uniform(size=filter_shape).astype(dtype)
+        dilated_filter_np = topi.testing.dilate_python(filter_np, (dilation, dilation, 1, 1))
         scale_np = np.random.uniform(size=scale_shape).astype(dtype)
         shift_np = np.random.uniform(size=shift_shape).astype(dtype)
         # correctness with scipy
         depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nhwc(
-            input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
+            input_np, dilated_filter_np, stride=[stride_h, stride_w], padding=padding)
         scale_shift_scipy = np.zeros(shape=scale_shift_shape)
         for c in range(in_channel * channel_multiplier):
             scale_shift_scipy[:,:,:,c] = depthwise_conv2d_scipy[:,:,:,c] * scale_np[c] + shift_np[c]
@@ -192,6 +195,8 @@ def test_depthwise_conv2d():
     depthwise_conv2d_with_workload_nchw(1, 728, 32, 1, 3, 1, "VALID")
     depthwise_conv2d_with_workload_nchw(4, 256, 64, 2, 5, 2, "VALID")
     depthwise_conv2d_with_workload_nchw(4, 256, 32, 2, 5, 2, "VALID")
+    # dilation = 2
+    depthwise_conv2d_with_workload_nchw(1, 728, 64, 1, 3, 1, "SAME", dilation=2)
     print("testing nhwc")
     depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME")
     depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "SAME")
@@ -201,6 +206,8 @@ def test_depthwise_conv2d():
     depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "VALID")
     depthwise_conv2d_with_workload_nhwc(4, 256, 64, 2, 5, 2, "VALID")
     depthwise_conv2d_with_workload_nhwc(4, 256, 32, 2, 5, 2, "VALID")
+    # dilation = 2
+    depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2)
 
 if __name__ == "__main__":
     test_depthwise_conv2d()
-- 
GitLab
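
Editor's note on the extern-library path: the cuDNN and MIOpen branches above recover the dilation factor purely from shapes, since `dilate()` turns a kernel of spatial size k into one of size (k - 1) * d + 1, and the patch reads d back with a ceiling division. A minimal standalone NumPy sketch of that relationship (not TVM code; the helper names are illustrative only):

```python
import numpy as np

def dilate_hw(kernel, d):
    """Insert d-1 zeros between taps along the two spatial axes of an OIHW kernel."""
    o, i, kh, kw = kernel.shape
    out = np.zeros((o, i, (kh - 1) * d + 1, (kw - 1) * d + 1), dtype=kernel.dtype)
    out[:, :, ::d, ::d] = kernel   # original taps land on a stride-d grid
    return out

def infer_dilation(dilated_size, original_size):
    """Ceiling division used in the patch: exact whenever d <= k, since
    (k - 1) * d + 1 lies in ((d - 1) * k, d * k] for d <= k."""
    return (dilated_size + original_size - 1) // original_size

w = np.random.randn(8, 4, 3, 3).astype("float32")
dw = dilate_hw(w, 2)                              # spatial shape 3x3 -> 5x5
assert infer_dilation(dw.shape[2], w.shape[2]) == 2
assert infer_dilation(dw.shape[3], w.shape[3]) == 2
```

For a 1x1 kernel the dilated shape is still 1x1, so the inferred factor stays 1; that is harmless because dilation has no effect on a single-tap kernel.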
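On the TVM-native path the dilated kernel remains a compute stage tagged `injective,dilate`, and each schedule simply inlines it so the inserted zeros are folded into the convolution instead of being materialized in memory. A rough sketch of that usage pattern, assuming the 2018-era TVM/TOPI API this patch targets (the shapes and tensor names here are made up for illustration):

```python
import tvm
import topi

A = tvm.placeholder((1, 16, 32, 32), name='A')        # NCHW input
W = tvm.placeholder((32, 16, 3, 3), name='W')          # OIHW kernel
dW = topi.nn.dilate(W, (1, 1, 2, 2))                   # dilation = 2 along H and W
B = topi.nn.conv2d_nchw(A, dW, stride=1, padding=2)    # pad for the 5x5 effective kernel

s = tvm.create_schedule(B.op)
# Same check the schedules in this patch perform before inlining the dilate stage.
if isinstance(dW.op, tvm.tensor.ComputeOp) and "dilate" in dW.op.tag:
    s[dW].compute_inline()
print(tvm.lower(s, [A, W, B], simple_mode=True))
```

The `compute_inline` call is the whole trick: without it the zero-padded kernel would be written out as a separate buffer, while with it the convolution reads the original weights directly and the zero taps disappear after simplification.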