diff --git a/topi/python/topi/testing/conv2d_nhwc_python.py b/topi/python/topi/testing/conv2d_nhwc_python.py index 461d7a6a0e0673e3b0f068de0890c8c1731568d8..a872bddab09bc34ca821728ca79d53861417b0a5 100644 --- a/topi/python/topi/testing/conv2d_nhwc_python.py +++ b/topi/python/topi/testing/conv2d_nhwc_python.py @@ -13,7 +13,7 @@ def conv2d_nhwc_python(a_np, w_np, stride, padding): 4-D with shape [batch, in_height, in_width, in_channel] w_np : numpy.ndarray - 4-D with shape [num_filter, filter_height, filter_width, in_channel] + 4-D with shape [filter_height, filter_width, in_channel, num_filter] stride : int or a list/tuple of two ints Stride size, or [stride_height, stride_width]