diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py
index 174a37b1d451a0ef26de5550fc221a8a4f1bb403..74e92384a9e0f6c080e9317beb29a3c39bc2f38f 100644
--- a/topi/python/topi/cuda/__init__.py
+++ b/topi/python/topi/cuda/__init__.py
@@ -16,3 +16,4 @@ from .pooling import schedule_pool, schedule_global_pool
 from .conv2d_transpose_nchw import schedule_conv2d_transpose_nchw
 from .extern import schedule_extern
 from .vision import schedule_region
+from .nn import schedule_lrn, schedule_l2norm
diff --git a/topi/python/topi/cuda/nn.py b/topi/python/topi/cuda/nn.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8757970505b477ca14a120ee136e5716a998913
--- /dev/null
+++ b/topi/python/topi/cuda/nn.py
@@ -0,0 +1,91 @@
+# pylint: disable=invalid-name
+"""scheduler functions for cuda backend"""
+from __future__ import absolute_import as _abs
+
+import tvm
+from .. import generic
+from .. import tag
+from .reduction import _schedule_reduce
+
+@generic.schedule_lrn.register(["cuda"])
+def schedule_lrn(outs):
+    """Schedule for LRN
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of LRN
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+    num_thread = 64
+    block_x = tvm.thread_axis("blockIdx.x")
+    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
+
+    lrn = outs[0]
+    sqr_sum_up = lrn.op.input_tensors[1]
+    sqr_sum = sqr_sum_up.op.input_tensors[0]
+    set_pad = sqr_sum.op.input_tensors[0]
+    s[set_pad].bind(set_pad.op.axis[0], block_x)
+    rxk = sqr_sum.op.reduce_axis[0]
+    _, xki = s[sqr_sum].split(rxk, factor=num_thread)
+    srf = s.rfactor(sqr_sum, xki)
+    s[sqr_sum].bind(s[sqr_sum].op.axis[0], block_x)
+    s[sqr_sum].bind(s[sqr_sum].op.reduce_axis[0], thread_x)
+    s[srf].compute_at(s[sqr_sum], s[sqr_sum].op.reduce_axis[0])
+    s[sqr_sum_up].bind(sqr_sum_up.op.axis[0], block_x)
+    xto, _ = s[lrn].split(lrn.op.axis[1], nparts=num_thread)
+    s[lrn].bind(lrn.op.axis[0], block_x)
+    s[lrn].bind(xto, thread_x)
+    return s
+
+@generic.schedule_l2norm.register(["cuda"])
+def schedule_l2norm(outs):
+    """Schedule for L2norm
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of L2norm
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def traverse(OP):
+        '''inline all one-to-one-mapping operators
+        except the last stage (output)'''
+        if tag.is_injective(OP.tag) or OP.tag == 'l2norm':
+            if OP not in s.outputs:
+                s[OP].compute_inline()
+            for tensor in OP.input_tensors:
+                if tensor.op.input_tensors:
+                    traverse(tensor.op)
+        elif OP.tag == 'comm_reduce':
+            _schedule_reduce(OP, s, is_idx_reduce=False)
+            for tensor in OP.input_tensors:
+                traverse(tensor.op)
+        else:
+            raise RuntimeError("Unsupported operator tag: %s" % OP.tag)
+    traverse(outs[0].op)
+
+    num_thread = 64
+    l2norm = outs[0]
+    block_x = tvm.thread_axis("blockIdx.x")
+    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
+    xto, _ = s[l2norm].split(l2norm.op.axis[1], nparts=num_thread)
+    s[l2norm].bind(l2norm.op.axis[0], block_x)
+    s[l2norm].bind(xto, thread_x)
+
+    return s
diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py
index 7fe76e1739f01e5862e844053370bbf19af695d0..7252a23b90c3c9ba7477231868c372c4ecbeb1d5 100644
--- a/topi/python/topi/generic/nn.py
+++ b/topi/python/topi/generic/nn.py
@@ -229,3 +229,39 @@ def schedule_binary_dense(outs):
         The computation schedule for the op.
     """
     return _default_schedule(outs, False)
+
+
+@tvm.target.generic_func
+def schedule_lrn(outs):
+    """Schedule for lrn
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of lrn
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
+@tvm.target.generic_func
+def schedule_l2norm(outs):
+    """Schedule for l2norm
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of l2norm
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
diff --git a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py
index 918f399f503ce6a856b7ba2051157dbde1deeca3..056d1a76339a2414df8bf6198ce70439851dbdd1 100644
--- a/topi/python/topi/nn/__init__.py
+++ b/topi/python/topi/nn/__init__.py
@@ -15,3 +15,5 @@ from .softmax import *
 from .conv2d_transpose import *
 from .bnn import *
 from .upsampling import *
+from .local_response_norm import *
+from .l2_norm import *
diff --git a/topi/python/topi/nn/l2_norm.py b/topi/python/topi/nn/l2_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b5381a8559973b70bbd4f019677be2ce68f7509
--- /dev/null
+++ b/topi/python/topi/nn/l2_norm.py
@@ -0,0 +1,35 @@
+# pylint: disable=invalid-name
+"""TVM operator for l2norm"""
+from __future__ import absolute_import
+import tvm
+import topi
+
+@tvm.target.generic_func
+def l2norm_instance(data, eps, axis=None):
+    """Perform L2norm on the input data
+
+    For axis=None, y(i, j) = x(i, j) / sqrt(max(sum(x^2), eps))
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        4-D with NCHW or NHWC layout
+
+    eps : float
+        epsilon value
+
+    axis : list of int
+        axis over which the normalization is applied
+
+    Returns
+    -------
+    output : tvm.Tensor
+        4-D output with same shape
+    """
+    assert len(data.shape) == 4, "only support 4-dim l2norm"
+    dot_value = topi.cpp.pow(data, 2.0)
+    sum_value = topi.sum(dot_value, axis=axis, keepdims=True)
+    expand_sum = topi.broadcast_to(sum_value, data.shape)
+    return topi.broadcast_div(data, topi.sqrt(\
+        tvm.compute(expand_sum.shape, lambda i, j, k, l:\
+        tvm.max(expand_sum[i, j, k, l], eps), tag='l2norm')))
diff --git a/topi/python/topi/nn/local_response_norm.py b/topi/python/topi/nn/local_response_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b44e02214acc0370bb9a5e0cb6f9fc044d20944b
--- /dev/null
+++ b/topi/python/topi/nn/local_response_norm.py
@@ -0,0 +1,68 @@
+# pylint: disable=invalid-name
+"""TVM operator for local response norm compute."""
+from __future__ import absolute_import
+import tvm
+import topi
+from .pad import pad
+
+@tvm.target.generic_func
+def lrn(data, size, axis=1, alpha=0.0001, beta=0.75, bias=2):
+    """Perform the across channels local response normalisation
+    on the input data.
+
+    sum_sqr_up^i{x, y} = (bias+((alpha/size)* \
+                         {sum_{j=max(0, i-size/2)}^{min(N-1,i+size/2)} \
+                              (data^j{x,y})^2}))^beta
+    output^i{x, y} = data^i{x, y}/sum_sqr_up^i{x, y}
+    N is the number of input channels
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        4-D with shape [batch, channel, height, width]
+
+    size : int
+        normalisation window size
+
+    axis : int
+        input data layout channel axis
+        default value is 1 for NCHW format
+
+    bias : float
+        offset to avoid dividing by 0
+
+    alpha : float
+        scaling constant
+
+    beta : float
+        exponent
+
+    Returns
+    -------
+    output : tvm.Tensor
+        4-D output with same shape
+    """
+    assert len(data.shape) == 4, "only support 4-dim lrn"
+    assert (size % 2) == 1, "size should be odd number"
+    assert (axis == 1) or (axis == 3), "axis should be 1 or 3 for NCHW and NHWC"
+    # Add padding of radius size//2 on both sides of the normalised axis first
+    pad_after = pad_before = [0, 0, 0, 0]
+    pad_after[axis] = pad_before[axis] = (size//2)
+    pad_data = pad(data, pad_before, pad_after, name="pad_data")
+
+    rxs = tvm.reduce_axis((0, size), name='rxs')
+    if axis == 1:
+        #NCHW layout
+        sqr_sum = tvm.compute(data.shape, lambda i, j, k, l: tvm.sum(
+            pad_data[i, j + rxs, k, l] * pad_data[i, j + rxs, k, l],
+            axis=rxs))
+    elif axis == 3:
+        #NHWC layout
+        sqr_sum = tvm.compute(data.shape, lambda i, j, k, l: tvm.sum(
+            pad_data[i, j, k, l + rxs] * pad_data[i, j, k, l + rxs],
+            axis=rxs))
+
+    sqr_sum_up = tvm.compute(data.shape, lambda i, j, k, l: tvm.power(
+        (bias + (alpha * sqr_sum[i, j, k, l] / size)), beta))
+
+    return topi.broadcast_div(data, sqr_sum_up)
diff --git a/topi/python/topi/rocm/__init__.py b/topi/python/topi/rocm/__init__.py
index a5b4ee30dc37368a48dd6ae00faeef7b1726cd67..96a04794c680513978479b571323be399efc5ab8 100644
--- a/topi/python/topi/rocm/__init__.py
+++ b/topi/python/topi/rocm/__init__.py
@@ -5,3 +5,4 @@ from __future__ import absolute_import as _abs
 from .conv2d import *
 from .dense import *
 from .vision import *
+from .nn import *
diff --git a/topi/python/topi/rocm/nn.py b/topi/python/topi/rocm/nn.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9c529155f7b7d1a620b5bc00d15a7507d521509
--- /dev/null
+++ b/topi/python/topi/rocm/nn.py
@@ -0,0 +1,13 @@
+"""scheduler for normalization functions on rocm backend"""
+from __future__ import absolute_import as _abs
+
+import topi
+from .. import generic
+
+@generic.schedule_lrn.register(["rocm", "gpu"])
+def schedule_lrn(outs):
+    return topi.cuda.schedule_lrn(outs)
+
+@generic.schedule_l2norm.register(["rocm", "gpu"])
+def schedule_l2norm(outs):
+    return topi.cuda.schedule_l2norm(outs)
diff --git a/topi/tests/python/test_topi_l2norm.py b/topi/tests/python/test_topi_l2norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa7970125b4406e85a062588c964628569c4cc23
--- /dev/null
+++ b/topi/tests/python/test_topi_l2norm.py
@@ -0,0 +1,70 @@
+"""Test code for L2 norm"""
+import numpy as np
+import tvm
+import topi
+from topi.util import get_const_tuple
+
+def l2norm_instance_python(a_np, eps, axis=None):
+    """L2 norm operator in NCHW layout.
+
+    Parameters
+    ----------
+    a_np : numpy.ndarray
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    eps : float
+        epsilon constant value
+    axis : list of int
+        axis over which the normalization is applied
+
+    Returns
+    -------
+    l2norm_out : np.ndarray
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    batch, axis1, axis2, axis3 = a_np.shape
+    sqr_sum = np.zeros(shape=(batch,)).astype(a_np.dtype)
+    sqrt_sum = np.zeros(shape=(batch,)).astype(a_np.dtype)
+    l2norm_out = np.zeros(shape=a_np.shape).astype(a_np.dtype)
+    dot_value = np.power(a_np, 2.0)
+    sqr_sum = np.sum(dot_value, axis, keepdims=True)
+    sqrt_sum = np.sqrt(np.maximum(np.broadcast_to(sqr_sum, a_np.shape), eps))
+    return np.divide(a_np, sqrt_sum)
+
+def verify_l2norm(n, c, h, w, eps, axis=None):
+
+    A = tvm.placeholder((n, c, h, w), name='A')
+    B = topi.nn.l2norm_instance(A, eps, axis)
+    dtype = A.dtype
+
+    a_np = np.random.uniform(size=(n, c, h, w)).astype(dtype)
+    b_np = l2norm_instance_python(a_np, eps, axis)
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = topi.generic.schedule_l2norm(B)
+        ctx = tvm.context(device, 0)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
+        f = tvm.build(s, [A, B], device)
+        f(a, b)
+        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+
+    for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
+        check_device(device)
+
+def test_l2norm():
+    verify_l2norm(1, 3, 20, 20, 0.001)
+    verify_l2norm(1, 3, 20, 20, 0.001, 1)
+    verify_l2norm(1, 3, 20, 20, 0.001, (1, 2))
+    verify_l2norm(1, 3, 20, 20, 0.001, (2, 3))
+    verify_l2norm(1, 3, 20, 20, 0.001, (0, 3))
+    verify_l2norm(1, 3, 20, 20, 0.001, (0, 2, 3))
+
+
+if __name__ == "__main__":
+    test_l2norm()
diff --git a/topi/tests/python/test_topi_lrn.py b/topi/tests/python/test_topi_lrn.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c4714077e1ebb880c4c2b94a71e51d177a4c4da
--- /dev/null
+++ b/topi/tests/python/test_topi_lrn.py
@@ -0,0 +1,95 @@
+"""Test code for local response normalization"""
+import numpy as np
+import tvm
+import topi
+from topi.util import get_const_tuple
+
+def lrn_python(a_np, size, axis, bias, alpha, beta):
+    """Local response norm operator in NCHW layout.
+
+    Parameters
+    ----------
+    a_np : numpy.ndarray
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    size : int
+        normalisation window size
+
+    axis : int
+        input data layout channel axis
+
+    bias : float
+        offset constant value to avoid dividing by 0
+
+    alpha : float
+        scaling constant value
+
+    beta : float
+        exponent constant value
+
+    Returns
+    -------
+    b_np : np.ndarray
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    axis0, axis1, axis2, axis3 = a_np.shape
+    radius = size // 2
+    sqr_sum = np.zeros(shape=a_np.shape).astype(a_np.dtype)
+    sqr_sum_up = np.zeros(shape=a_np.shape).astype(a_np.dtype)
+    lrn_out = np.zeros(shape=a_np.shape).astype(a_np.dtype)
+    def sum_dot_values(i, j, k, l):
+        axis_size = a_np.shape[axis]
+        if (axis == 1):
+            #NCHW layout
+            sum_start = j-radius if j-radius >= 0 else 0
+            sum_end = j+radius+1 if j+radius+1 < axis_size else axis_size
+            sqr_sum[i, j, k, l] = sum(a_np[i, sum_start:sum_end, k, l] * \
+                                      a_np[i, sum_start:sum_end, k, l])
+        elif (axis == 3):
+            #NHWC layout
+            sum_start = l-radius if l-radius >= 0 else 0
+            sum_end = l+radius+1 if l+radius+1 < axis_size else axis_size
+            sqr_sum[i, j, k, l] = sum(a_np[i, j, k, sum_start:sum_end] * \
+                                      a_np[i, j, k, sum_start:sum_end])
+
+    for i in range(axis0):
+        for j in range(axis1):
+            for k in range(axis2):
+                for l in range(axis3):
+                    sum_dot_values(i, j, k, l)
+
+    sqr_sum_up = np.power((bias + (alpha * sqr_sum / size)), beta)
+    return np.divide(a_np, sqr_sum_up)
+
+def verify_lrn(shape, size, axis, bias, alpha, beta):
+    A = tvm.placeholder(shape, name='A')
+    B = topi.nn.lrn(A, size, axis, alpha, beta, bias)
+    dtype = A.dtype
+
+    a_np = np.random.uniform(size=shape).astype(dtype)
+    b_np = lrn_python(a_np, size, axis, bias, alpha, beta)
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = topi.generic.schedule_lrn(B)
+        ctx = tvm.context(device, 0)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
+        f = tvm.build(s, [A, B], device)
+        f(a, b)
+        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+
+    for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
+        check_device(device)
+
+def test_lrn():
+    verify_lrn((1, 3, 5, 5), 3, 1, 1, 1, 0.5)
+    verify_lrn((1, 3, 5, 5), 3, 3, 1, 1, 0.5)
+    verify_lrn((1, 3, 20, 20), 3, 1, 2, 1, 0.75)
+
+if __name__ == "__main__":
+    test_lrn()
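
For reviewers, a minimal usage sketch of the two new operators on the llvm target, mirroring the test files above; the input shape and the size/eps values here are illustrative only and not part of the patch:

    import numpy as np
    import tvm
    import topi

    # 4-D NCHW input; build both new ops through the generic schedule dispatch.
    A = tvm.placeholder((1, 3, 20, 20), name='A')
    L = topi.nn.lrn(A, size=3, axis=1, alpha=0.0001, beta=0.75, bias=2)
    N = topi.nn.l2norm_instance(A, eps=0.001, axis=(1, 2, 3))

    with tvm.target.create('llvm'):
        s_lrn = topi.generic.schedule_lrn(L)
        s_l2n = topi.generic.schedule_l2norm(N)

    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.uniform(size=(1, 3, 20, 20)).astype(A.dtype), ctx)
    b = tvm.nd.array(np.zeros((1, 3, 20, 20), dtype=A.dtype), ctx)
    c = tvm.nd.array(np.zeros((1, 3, 20, 20), dtype=A.dtype), ctx)

    tvm.build(s_lrn, [A, L], 'llvm')(a, b)   # b holds the LRN output
    tvm.build(s_l2n, [A, N], 'llvm')(a, c)   # c holds the instance L2 norm output

On llvm this falls through to the generic _default_schedule; on cuda and rocm the same calls dispatch to the schedule_lrn/schedule_l2norm registrations added in this patch.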