diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py
index 7a51449838381359ea7dea9a8adb0eae59ad1d5c..ae9853db7cff3e9d62310eb34d91821a49f766f3 100644
--- a/nnvm/python/nnvm/top/nn.py
+++ b/nnvm/python/nnvm/top/nn.py
@@ -85,11 +85,18 @@ def compute_conv2d(attrs, inputs, _):
     channels = attrs.get_int("channels")
     layout = attrs["layout"]
    assert layout == "NCHW" or layout == "NHWC"
-    assert dilation == (1, 1), "not support dilate now"
+    (dilation_h, dilation_w) = dilation
+    if dilation_h < 1 or dilation_w < 1:
+        raise ValueError("dilation should be positive value")
+    elif layout == "NCHW":
+        kernel = topi.nn.dilate(inputs[1], [1, 1, dilation_h, dilation_w])
+    else: #layout == NHWC
+        kernel = topi.nn.dilate(inputs[1], [1, dilation_h, dilation_w, 1])
+
     if groups == 1:
-        out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding, layout)
+        out = topi.nn.conv2d(inputs[0], kernel, strides, padding, layout)
     elif groups == get_const_int(inputs[0].shape[1]) and groups == channels:
-        out = topi.nn.depthwise_conv2d_nchw(inputs[0], inputs[1], strides, padding)
+        out = topi.nn.depthwise_conv2d_nchw(inputs[0], kernel, strides, padding)
     else:
         raise ValueError("not support arbitrary group number for now")
     if attrs.get_bool("use_bias"):
diff --git a/nnvm/tests/python/compiler/test_top_level2.py b/nnvm/tests/python/compiler/test_top_level2.py
index 89c582fedefacd4d446e6d4b4777a7675e8f2069..44767da4541f11ce2b628ebed8867dc7e68fda39 100644
--- a/nnvm/tests/python/compiler/test_top_level2.py
+++ b/nnvm/tests/python/compiler/test_top_level2.py
@@ -32,6 +32,32 @@ def test_conv2d():
         np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
 
 
+def test_dilated_conv2d():
+    dilation = 3
+    x = sym.Variable("x")
+    y = sym.conv2d(x, channels=10, kernel_size=(3, 3), dilation=(dilation, dilation),
+                   name="y", padding=(1, 1))
+    dtype = "float32"
+    dshape = (1, 3, 18, 18)
+    kshape = (10, 3, 3, 3)
+    oshape = (1, 10, 14, 14)
+    shape_dict = {"x": dshape}
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
+        m = graph_runtime.create(graph, lib, ctx)
+        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+        bias = tvm.nd.array(np.random.uniform(size=kshape[0]).astype(dtype))
+        kernel_np = np.random.uniform(size=kshape).astype(dtype)
+        kernel = tvm.nd.array(kernel_np)
+        dkernel_np = topi.testing.dilate_python(kernel_np, (1, 1, dilation, dilation))
+        m.run(x=data, y_weight=kernel, y_bias=bias)
+        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
+        c_np = topi.testing.conv2d_nchw_python(
+            data.asnumpy(), dkernel_np, 1, 1)
+        c_np = c_np + bias.asnumpy().reshape(kshape[0], 1, 1)
+        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
+
+
 def test_grouped_conv2d():
     x = sym.Variable("x")
     y = sym.conv2d(x, channels=32, kernel_size=(3,3), groups=32,
@@ -170,6 +196,7 @@ def test_upsampling():
 
 if __name__ == "__main__":
     test_conv2d()
+    test_dilated_conv2d()
     test_grouped_conv2d()
     test_conv2d_transpose()
     test_max_pool2d()
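
Note on the shape arithmetic in `test_dilated_conv2d`: as I understand it, `topi.testing.dilate_python` (and `topi.nn.dilate` in the compute rule above) inserts zeros between kernel taps, so a 3x3 kernel with dilation 3 behaves like a 7x7 kernel, and an 18x18 input with padding 1 and stride 1 yields a 14x14 output, which is why `oshape` is `(1, 10, 14, 14)`. The sketch below reproduces that zero-insertion in plain NumPy; the helper `dilate_kernel_hw` is hypothetical and not part of this patch or of TOPI.

```python
# Minimal sketch (assumption: mirrors topi.testing.dilate_python with
# strides (1, 1, dilation_h, dilation_w) on an OIHW weight tensor).
import numpy as np

def dilate_kernel_hw(kernel, dilation_h, dilation_w):
    """Insert zeros between kernel taps along the H and W axes."""
    o, i, kh, kw = kernel.shape
    dkh = (kh - 1) * dilation_h + 1          # dilated kernel height
    dkw = (kw - 1) * dilation_w + 1          # dilated kernel width
    dkernel = np.zeros((o, i, dkh, dkw), dtype=kernel.dtype)
    # Original taps land on a strided grid; everything in between stays zero.
    dkernel[:, :, ::dilation_h, ::dilation_w] = kernel
    return dkernel

kernel = np.random.uniform(size=(10, 3, 3, 3)).astype("float32")
dkernel = dilate_kernel_hw(kernel, 3, 3)
print(dkernel.shape)                          # (10, 3, 7, 7)

# Output spatial size with the effective 7x7 kernel:
# 18 + 2*1 (padding) - 7 + 1 = 14, matching oshape in the test.
```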