diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py
index 39bd50686fa3c0ee81a30e467c2d92a0dee95a27..5b5c976862064a2f9717fd1792ab0e5c28aa7bb5 100644
--- a/topi/python/topi/cuda/__init__.py
+++ b/topi/python/topi/cuda/__init__.py
@@ -7,3 +7,4 @@ from .conv2d_hwcn import schedule_conv2d_hwcn
 from .depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc
 from .reduction import schedule_reduce
 from .broadcast import schedule_broadcast_to
+from .softmax import schedule_softmax
diff --git a/topi/python/topi/cuda/conv2d_nchw.py b/topi/python/topi/cuda/conv2d_nchw.py
index d06ca1d65241054f394ae196d92a51f0335e4329..57e8be2951d54ce4247aaf68a032d703a7f34003 100644
--- a/topi/python/topi/cuda/conv2d_nchw.py
+++ b/topi/python/topi/cuda/conv2d_nchw.py
@@ -115,7 +115,7 @@ def schedule_conv2d_small_batch(outs):
     return s

 def schedule_conv2d_nchw(outs):
-    """Schedule for conv2d_nchw and any element-wise operations.
+    """Schedule for conv2d_nchw.

     Parameters
     ----------
diff --git a/topi/python/topi/cuda/softmax.py b/topi/python/topi/cuda/softmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..de990e086ac8750f7dbf420d24788654b0ac750a
--- /dev/null
+++ b/topi/python/topi/cuda/softmax.py
@@ -0,0 +1,42 @@
+# pylint: disable=invalid-name, unused-variable, trailing-whitespace
+"""Schedule for softmax operator"""
+import tvm
+
+def schedule_softmax(outs):
+    """Schedule for softmax op.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of softmax in the format
+        of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+    softmax = outs[0]
+    max_elem = softmax.op.input_tensors[1]
+    expsum = softmax.op.input_tensors[2]
+
+    num_thread = 64
+    block_x = tvm.thread_axis("blockIdx.x")
+    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
+
+    s[max_elem].bind(max_elem.op.axis[0], block_x)
+
+    k = expsum.op.reduce_axis[0]
+    ko, ki = s[expsum].split(k, factor=num_thread)
+    EF = s.rfactor(expsum, ki)
+    s[expsum].bind(s[expsum].op.axis[0], block_x)
+    s[expsum].bind(s[expsum].op.reduce_axis[0], thread_x)
+    s[EF].compute_at(s[expsum], s[expsum].op.reduce_axis[0])
+
+    tx, xi = s[softmax].split(softmax.op.axis[1], nparts=num_thread)
+    s[softmax].bind(softmax.op.axis[0], block_x)
+    s[softmax].bind(tx, thread_x)
+
+    return s
diff --git a/topi/python/topi/nn/softmax.py b/topi/python/topi/nn/softmax.py
index 8394b1afb0ab54760f33859003789088b3fef04a..4c39a9f1f78422b5de65b9351983ce287fd2cdbc 100644
--- a/topi/python/topi/nn/softmax.py
+++ b/topi/python/topi/nn/softmax.py
@@ -1,8 +1,9 @@
+# pylint: disable=invalid-name
 """TVM operator softmax compute."""
 from __future__ import absolute_import
 import tvm

-@tvm.tag_scope(tag='softmax')
+@tvm.tag_scope(tag='softmax_output')
 def softmax(x):
     """Perform softmax activation on the data

diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py
index 1cd4c54dfecef64462c4bb699c7eb1ffc45275b4..1a715eb4fdc824cbcdeac75613498305790205b6 100644
--- a/topi/python/topi/testing/__init__.py
+++ b/topi/python/topi/testing/__init__.py
@@ -8,3 +8,4 @@ from .conv2d_hwcn_python import conv2d_hwcn_python
 from .conv2d_nchw_python import conv2d_nchw_python
 from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
 from .dilate_python import dilate_python
+from .softmax_python import softmax_python
diff --git a/topi/python/topi/testing/softmax_python.py b/topi/python/topi/testing/softmax_python.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dc561a56e2ad67a8390465b7048f1cd581743c6
--- /dev/null
+++ b/topi/python/topi/testing/softmax_python.py
@@ -0,0 +1,23 @@
+# pylint: disable=invalid-name, trailing-whitespace
+"""Softmax operation in python"""
+import numpy as np
+
+def softmax_python(a_np):
+    """Softmax operator.
+    Parameters
+    ----------
+    a_np : numpy.ndarray
+        2-D input data
+
+    Returns
+    -------
+    output_np : numpy.ndarray
+        2-D output with same shape
+    """
+    assert len(a_np.shape) == 2, "only support 2-dim softmax"
+    max_elem = np.amax(a_np, axis=1)
+    max_elem = max_elem.reshape(max_elem.shape[0], 1)
+    e = np.exp(a_np - max_elem)
+    expsum = np.sum(e, axis=1)
+    out_np = e / expsum[:, None]
+    return out_np
diff --git a/topi/tests/python/test_topi_softmax.py b/topi/tests/python/test_topi_softmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..0afd74606343559c862e25b22d481c39c4bb3dc1
--- /dev/null
+++ b/topi/tests/python/test_topi_softmax.py
@@ -0,0 +1,36 @@
+"""Test code for softmax"""
+import os
+import numpy as np
+import tvm
+import topi
+from topi.util import get_const_tuple
+
+def verify_softmax(m, n):
+
+    A = tvm.placeholder((m, n), name='A')
+    B = topi.nn.softmax(A)
+    s = topi.cuda.schedule_softmax(B)
+
+    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
+    b_np = topi.testing.softmax_python(a_np)
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
+        foo = tvm.build(s, [A, B], device, name="softmax")
+        foo(a, b)
+        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+
+    for device in ['cuda', 'opencl', 'metal']:
+        check_device(device)
+
+def test_softmax():
+    verify_softmax(32, 10)
+
+
+if __name__ == "__main__":
+    test_softmax()
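
Usage note (not part of the patch): a minimal end-to-end sketch of the new op and schedule, assuming a CUDA-enabled TVM build; it simply mirrors what test_topi_softmax.py above exercises.

import numpy as np
import tvm
import topi

# Declare a 2-D row-wise softmax and schedule it with the new CUDA schedule.
m, n = 32, 10
A = tvm.placeholder((m, n), name='A')
B = topi.nn.softmax(A)                 # compute is now tagged 'softmax_output'
s = topi.cuda.schedule_softmax(B)      # one block per row, 64 threads cooperate on the row sum

# Requires a CUDA-enabled build and a GPU at device 0.
func = tvm.build(s, [A, B], "cuda", name="softmax")
ctx = tvm.gpu(0)
a = tvm.nd.array(np.random.uniform(size=(m, n)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros((m, n), dtype=B.dtype), ctx)
func(a, b)

# Compare against the NumPy reference added in topi.testing.
np.testing.assert_allclose(b.asnumpy(), topi.testing.softmax_python(a.asnumpy()), rtol=1e-5)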