diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py
index 6f641e99f7dd3357be317090b330a097287660e6..5c580aad24c4455230c94a018a59376e2a09e243 100644
--- a/topi/python/topi/generic/nn.py
+++ b/topi/python/topi/generic/nn.py
@@ -35,6 +35,24 @@ def schedule_conv2d_nchw(outs):
     return _default_schedule(outs, False)
 
 
+@tvm.target.generic_func
+def schedule_conv2d_nhwc(outs):
+    """Schedule for conv2d_nhwc
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of conv2d_nhwc
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
 @tvm.target.generic_func
 def schedule_conv2d_transpose_nchw(outs):
     """Schedule for conv2d_transpose_nchw
diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py
index 11866aedc1010429478319e28de890bda3f2ac41..3bd910e299741803015d691384e0ba8c7e76845c 100644
--- a/topi/python/topi/nn/conv2d.py
+++ b/topi/python/topi/nn/conv2d.py
@@ -337,6 +337,57 @@ def conv2d_hwcn(Input, Filter, stride, padding, out_dtype='float32'):
         name="Conv2dOutput", tag="conv2d_hwcn")
     return Output
 
+
+def conv2d_nhwc(Input, Filter, stride, padding, out_dtype='float32'):
+    """Convolution operator in NHWC layout.
+
+    Parameters
+    ----------
+    Input : tvm.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    Filter : tvm.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, num_filter]
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    Returns
+    -------
+    output : tvm.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+    assert isinstance(stride, int) or len(stride) == 2
+    batch, in_height, in_width, in_channel = Input.shape
+    kernel_h, kernel_w, channel, num_filter = Filter.shape
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (kernel_h, kernel_w))
+    # compute the output shape
+    out_channel = num_filter
+    out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1)
+    out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
+    pad_before = [0, pad_top, pad_left, 0]
+    pad_after = [0, pad_down, pad_right, 0]
+    PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
+    rc = tvm.reduce_axis((0, in_channel), name='rc')
+    ry = tvm.reduce_axis((0, kernel_h), name='ry')
+    rx = tvm.reduce_axis((0, kernel_w), name='rx')
+    Output = tvm.compute(
+        (batch, out_height, out_width, out_channel),
+        lambda nn, yy, xx, ff: tvm.sum(
+            PaddedInput[nn, yy * stride_h + ry, xx * stride_w + rx, rc].astype(out_dtype) *
+            Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]),
+        name="Conv2dOutput", tag="conv2d_nhwc")
+    return Output
+
 # map from schedule type to declaration function
 _SCH_TO_DECL_FUNC = {
     SpatialPack: _spatial_pack,
diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py
index 6a1b361e30970ce4de852f67a7888130acb7bed9..2a20a1c4f6225cdc8dd86b217bb2242e5a5734fe 100644
--- a/topi/python/topi/testing/__init__.py
+++ b/topi/python/topi/testing/__init__.py
@@ -6,6 +6,7 @@ from __future__ import absolute_import as _abs
 
 from .conv2d_hwcn_python import conv2d_hwcn_python
 from .conv2d_nchw_python import conv2d_nchw_python
+from .conv2d_nhwc_python import conv2d_nhwc_python
 from .conv2d_transpose_nchw_python import conv2d_transpose_nchw_python
 from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
 from .dilate_python import dilate_python
diff --git a/topi/python/topi/testing/conv2d_nhwc_python.py b/topi/python/topi/testing/conv2d_nhwc_python.py
new file mode 100644
index 0000000000000000000000000000000000000000..880088a6f89fb8263bca2842ea2cd5d3395535f8
--- /dev/null
+++ b/topi/python/topi/testing/conv2d_nhwc_python.py
@@ -0,0 +1,67 @@
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
+"""Convolution in python"""
+import numpy as np
+import scipy.signal
+
+
+def conv2d_nhwc_python(a_np, w_np, stride, padding):
+    """Convolution operator in NHWC layout.
+
+    Parameters
+    ----------
+    a_np : numpy.ndarray
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    w_np : numpy.ndarray
+        4-D with shape [filter_height, filter_width, in_channel, num_filter]
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    Returns
+    -------
+    b_np : np.ndarray
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+    batch, in_height, in_width, in_channel = a_np.shape
+    kernel_h, kernel_w, _, num_filter = w_np.shape
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+    if isinstance(padding, int):
+        pad_h = pad_w = padding * 2
+    elif padding == 'VALID':
+        pad_h = 0
+        pad_w = 0
+    else: # 'SAME'
+        pad_h = kernel_h - 1
+        pad_w = kernel_w - 1
+    pad_top = int(np.ceil(float(pad_h) / 2))
+    pad_bottom = pad_h - pad_top
+    pad_left = int(np.ceil(float(pad_w) / 2))
+    pad_right = pad_w - pad_left
+    # compute the output shape
+    out_channel = num_filter
+    out_height = (in_height - kernel_h + pad_h) // stride_h + 1
+    out_width = (in_width - kernel_w + pad_w) // stride_w + 1
+    # change the layout from NHWC to NCHW
+    at = a_np.transpose((0, 3, 1, 2))
+    wt = w_np.transpose((3, 2, 0, 1))
+    bt = np.zeros((batch, out_channel, out_height, out_width))
+    # computation
+    for n in range(batch):
+        for f in range(out_channel):
+            for c in range(in_channel):
+                if pad_h > 0:
+                    apad = np.zeros((in_height + pad_h, in_width + pad_w))
+                    apad[pad_top:-pad_bottom, pad_left:-pad_right] = at[n, c]
+                else:
+                    apad = at[n, c]
+                out = scipy.signal.convolve2d(
+                    apad, np.rot90(np.rot90(wt[f, c])), mode='valid')
+                bt[n, f] += out[::stride_h, ::stride_w]
+    return bt.transpose((0, 2, 3, 1))
diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py
index 6ab37b8c03ac33c9e79b3f64b978cd12963cfb87..ef227d035fce540a0e21ca88a4ff9946b9d44b5d 100644
--- a/topi/python/topi/x86/__init__.py
+++ b/topi/python/topi/x86/__init__.py
@@ -2,6 +2,8 @@
 """x86 specific declaration and schedules."""
 from __future__ import absolute_import as _abs
 
-from .conv2d import schedule_conv2d
+from .conv2d import schedule_conv2d, schedule_conv2d_nhwc
 from .binarize_pack import schedule_binarize_pack
 from .binary_dense import schedule_binary_dense
+from .nn import *
+from .injective import *
diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py
index 0c91f8c25c88b04c8b90d80340b6fd0fd8ed2576..cb3571d6a91b7b5833fee211d8a730b0c64beebe 100644
--- a/topi/python/topi/x86/conv2d.py
+++ b/topi/python/topi/x86/conv2d.py
@@ -15,6 +15,12 @@ def schedule_conv2d(outs):
         if tag.is_broadcast(op.tag):
             if op not in s.outputs:
                 s[op].compute_inline()
+            else: # inject custom schedule
+                if len(op.axis) == 4: # schedule bias + bn + relu
+                    n, c, h, w = op.axis
+                    fused = s[op].fuse(n, c)
+                    s[op].parallel(fused)
+                    s[op].vectorize(w)
             for tensor in op.input_tensors:
                 if tensor.op.input_tensors:
                     traverse(tensor.op)
@@ -28,10 +34,68 @@ def schedule_conv2d(outs):
                 data_pad = data
                 data = data_pad.op.input_tensors[0]
 
+            n_pad, c_pad, h_pad, w_pad = data_pad.op.axis
+            pad_fused = s[data_pad].fuse(n_pad, c_pad)
+            s[data_pad].parallel(pad_fused)
             C = conv
             n, c, h, w = C.op.axis
-            s[C].parallel(c)
-            s[C].pragma(n, "parallel_launch_point")
+            rc, ry, rx = C.op.reduce_axis
+            fused = s[C].fuse(n, c)
+            s[C].parallel(fused)
+            wo, wi = s[C].split(w, factor=16)
+            s[C].reorder(fused, rc, h, wo, ry, rx, wi) # move rc to outer loop
+            s[C].unroll(rx)
+            s[C].unroll(ry)
+            s[C].vectorize(wi)
 
     traverse(outs[0].op)
     return s
+
+
+@generic.schedule_conv2d_nhwc.register(["cpu"])
+def schedule_conv2d_nhwc(outs):
+    """Create schedule for tensors"""
+    s = tvm.create_schedule([x.op for x in outs])
+    output_op = outs[0].op
+
+    def traverse(op):
+        """Traverse operators from computation graph"""
+        # inline all one-to-one-mapping operators except the last stage (output)
+        if tag.is_broadcast(op.tag):
+            if op not in s.outputs:
+                s[op].compute_inline()
+            else: # inject custom schedule
+                if len(op.axis) == 4: # schedule bias + bn + relu
+                    n, h, w, c = op.axis
+                    fused = s[op].fuse(n, h, w)
+                    s[op].parallel(fused)
+                    s[op].vectorize(c)
+            for tensor in op.input_tensors:
+                if tensor.op.input_tensors:
+                    traverse(tensor.op)
+
+        if 'conv2d_nhwc' in op.tag:
+            conv = op.output(0)
+            kernel = op.input_tensors[1]
+            data = op.input_tensors[0]
+            data_pad = None
+            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
+                data_pad = data
+                data = data_pad.op.input_tensors[0]
+
+            n_pad, h_pad, w_pad, c_pad = data_pad.op.axis
+            pad_fused = s[data_pad].fuse(n_pad, h_pad)
+            s[data_pad].parallel(pad_fused)
+            C = conv
+            n, h, w, c = C.op.axis
+            ry, rx, rc = C.op.reduce_axis
+            n_out, h_out, w_out, c_out = output_op.axis
+            s[C].vectorize(c)
+            if op != output_op: # fuse bias + bn + relu into conv
+                s[C].compute_at(s[output_op], c_out)
+            else:
+                fused = s[C].fuse(n, h, w)
+                s[C].parallel(fused)
+
+    traverse(output_op)
+    return s
diff --git a/topi/python/topi/x86/injective.py b/topi/python/topi/x86/injective.py
new file mode 100644
index 0000000000000000000000000000000000000000..0970b76142ae4d9e8437a734839a7281ad614e74
--- /dev/null
+++ b/topi/python/topi/x86/injective.py
@@ -0,0 +1,35 @@
+# pylint: disable=invalid-name
+"""x86 declaration and schedules."""
+from __future__ import absolute_import as _abs
+import tvm
+from .. import generic
+
+@generic.schedule_injective.register(["cpu"])
+def schedule_injective(outs):
+    """X86 schedule for injective op.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of injective in the format
+        of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+ """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + x = outs[0] + s = tvm.create_schedule([x.op for x in outs]) + tvm.schedule.AutoInlineInjective(s) + if len(s[x].op.axis) == 4: + n, c, _, _ = s[x].op.axis + fused = s[x].fuse(n, c) # for nhwc layout, fuse n and h + s[x].parallel(fused) + else: + s[x].parallel(s[x].op.axis[0]) + return s + +schedule_elemwise = schedule_injective +schedule_broadcast = schedule_injective diff --git a/topi/python/topi/x86/nn.py b/topi/python/topi/x86/nn.py new file mode 100644 index 0000000000000000000000000000000000000000..49aa382589d13b3d16373b739440257d60388d1b --- /dev/null +++ b/topi/python/topi/x86/nn.py @@ -0,0 +1,56 @@ +"""x86 nn operators""" +from __future__ import absolute_import as _abs +import tvm +from .. import generic + +def _default_schedule(outs, auto_inline): + """Default schedule for x86.""" + x = outs[0] + s = tvm.create_schedule([x.op for x in outs]) + if auto_inline: + tvm.schedule.AutoInlineInjective(s) + s[x].fuse(s[x].op.axis) + return s + if len(s[x].op.axis) == 4: + n, c, _, _ = s[x].op.axis + fused = s[x].fuse(n, c) # for nhwc layout, fuse n and h + s[x].parallel(fused) + else: + s[x].parallel(s[x].op.axis[0]) + return s + + +@generic.schedule_softmax.register(["cpu"]) +def schedule_softmax(outs): + """Schedule for softmax + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of softmax + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) + + +@generic.schedule_pool.register(["cpu"]) +def schedule_pool(outs): + """Schedule for pool + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of pool + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. 
+ """ + return _default_schedule(outs, False) diff --git a/topi/tests/python/test_topi_conv2d_nhwc.py b/topi/tests/python/test_topi_conv2d_nhwc.py new file mode 100644 index 0000000000000000000000000000000000000000..7fc5b841908f1c32711418ca768494c6bfc7eff3 --- /dev/null +++ b/topi/tests/python/test_topi_conv2d_nhwc.py @@ -0,0 +1,59 @@ +"""Example code to do convolution.""" +import os +import numpy as np +import tvm +import topi +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + + +def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, padding): + in_height = in_width = in_size + + A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A') + W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W') + B = topi.nn.conv2d_nhwc(A, W, stride, padding) + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + dtype = A.dtype + + @memoize("topi.tests.test_topi_conv2d_nhwc.verify_nhwc") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = topi.testing.conv2d_nhwc_python(a_np, w_np, stride, padding) + return a_np, w_np, b_np + a_np, w_np, b_np = get_ref_data() + + def check_device(device): + if not tvm.module.enabled(device): + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.generic.schedule_conv2d_nhwc([B]) + ctx = tvm.context(device, 0) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) + func = tvm.build(s, [A, W, B], device) + func(a, w, b) + np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + + for device in ['llvm']: + check_device(device) + + +def test_conv2d_nhwc(): + verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "SAME") + verify_conv2d_nhwc(4, 128, 16, 128, 5, 2, "SAME") + verify_conv2d_nhwc(4, 128, 16, 256, 5, 2, "SAME") + verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "VALID") + verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "VALID") + verify_conv2d_nhwc(4, 128, 16, 128, 5, 2, "VALID") + verify_conv2d_nhwc(4, 128, 16, 256, 5, 2, "VALID") + + +if __name__ == "__main__": + test_conv2d_nhwc()