diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py
index 39bd50686fa3c0ee81a30e467c2d92a0dee95a27..5b5c976862064a2f9717fd1792ab0e5c28aa7bb5 100644
--- a/topi/python/topi/cuda/__init__.py
+++ b/topi/python/topi/cuda/__init__.py
@@ -7,3 +7,4 @@ from .conv2d_hwcn import schedule_conv2d_hwcn
 from .depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc
 from .reduction import schedule_reduce
 from .broadcast import schedule_broadcast_to
+from .softmax import schedule_softmax
diff --git a/topi/python/topi/cuda/conv2d_nchw.py b/topi/python/topi/cuda/conv2d_nchw.py
index d06ca1d65241054f394ae196d92a51f0335e4329..57e8be2951d54ce4247aaf68a032d703a7f34003 100644
--- a/topi/python/topi/cuda/conv2d_nchw.py
+++ b/topi/python/topi/cuda/conv2d_nchw.py
@@ -115,7 +115,7 @@ def schedule_conv2d_small_batch(outs):
     return s
 
 def schedule_conv2d_nchw(outs):
-    """Schedule for conv2d_nchw and any element-wise operations.
+    """Schedule for conv2d_nchw.
 
     Parameters
     ----------
diff --git a/topi/python/topi/cuda/softmax.py b/topi/python/topi/cuda/softmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..de990e086ac8750f7dbf420d24788654b0ac750a
--- /dev/null
+++ b/topi/python/topi/cuda/softmax.py
@@ -0,0 +1,42 @@
+# pylint: disable=invalid-name, unused-variable
+"""Schedule for softmax operator"""
+import tvm
+
+def schedule_softmax(outs):
+    """Schedule for softmax op.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of softmax in the format
+          of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+    softmax = outs[0]
+    max_elem = softmax.op.input_tensors[1]
+    expsum = softmax.op.input_tensors[2]
+
+    num_thread = 64
+    block_x = tvm.thread_axis("blockIdx.x")
+    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
+
+    s[max_elem].bind(max_elem.op.axis[0], block_x)
+
+    k = expsum.op.reduce_axis[0]
+    ko, ki = s[expsum].split(k, factor=num_thread)
+    EF = s.rfactor(expsum, ki)
+    s[expsum].bind(s[expsum].op.axis[0], block_x)
+    s[expsum].bind(s[expsum].op.reduce_axis[0], thread_x)
+    s[EF].compute_at(s[expsum], s[expsum].op.reduce_axis[0])
+
+    tx, xi = s[softmax].split(softmax.op.axis[1], nparts=num_thread)
+    s[softmax].bind(softmax.op.axis[0], block_x)
+    s[softmax].bind(tx, thread_x)
+
+    return s
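Note: the heart of this schedule is the cross-thread reduction for expsum. The reduction axis is split by the thread count and rfactor'd, so each of the 64 threads accumulates a partial sum and TVM combines the partials into the per-row total. Below is a minimal standalone sketch of the same split + rfactor pattern on a plain row sum, assuming the 0.x-era tvm API used throughout this diff (n, m, A, B, BF are illustrative names):

import tvm

n = tvm.var("n")
m = tvm.var("m")
A = tvm.placeholder((n, m), name="A")
k = tvm.reduce_axis((0, m), name="k")
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name="B")

s = tvm.create_schedule(B.op)
ko, ki = s[B].split(B.op.reduce_axis[0], factor=64)
BF = s.rfactor(B, ki)  # BF holds one partial sum per (thread, row) pair
# One row per block; 64 threads cooperate on that row's reduction.
s[B].bind(s[B].op.axis[0], tvm.thread_axis("blockIdx.x"))
s[B].bind(s[B].op.reduce_axis[0], tvm.thread_axis((0, 64), "threadIdx.x"))
s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
print(tvm.lower(s, [A, B], simple_mode=True))  # inspect the generated IR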
diff --git a/topi/python/topi/nn/softmax.py b/topi/python/topi/nn/softmax.py
index 8394b1afb0ab54760f33859003789088b3fef04a..4c39a9f1f78422b5de65b9351983ce287fd2cdbc 100644
--- a/topi/python/topi/nn/softmax.py
+++ b/topi/python/topi/nn/softmax.py
@@ -1,8 +1,9 @@
+# pylint: disable=invalid-name
 """TVM operator softmax compute."""
 from __future__ import absolute_import
 import tvm
 
-@tvm.tag_scope(tag='softmax')
+@tvm.tag_scope(tag='softmax_output')
 def softmax(x):
     """Perform softmax activation on the data
 
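For context, schedule_softmax above leans on the stage layout of this compute: the final stage reads the input, the per-row max, and the per-row exp-sum, in that order, which is why the schedule indexes input_tensors[1] and input_tensors[2]. A sketch of the compute structure it expects (identifier names are illustrative, not the exact upstream body):

import tvm

def softmax_sketch(x):
    m, n = x.shape
    k1 = tvm.reduce_axis((0, n), name="k1")
    max_elem = tvm.compute((m,), lambda i: tvm.max(x[i, k1], axis=k1))
    k2 = tvm.reduce_axis((0, n), name="k2")
    expsum = tvm.compute(
        (m,), lambda i: tvm.sum(tvm.exp(x[i, k2] - max_elem[i]), axis=k2))
    # This stage's input_tensors are [x, max_elem, expsum].
    return tvm.compute(
        x.shape, lambda i, j: tvm.exp(x[i, j] - max_elem[i]) / expsum[i])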
diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py
index 1cd4c54dfecef64462c4bb699c7eb1ffc45275b4..1a715eb4fdc824cbcdeac75613498305790205b6 100644
--- a/topi/python/topi/testing/__init__.py
+++ b/topi/python/topi/testing/__init__.py
@@ -8,3 +8,4 @@ from .conv2d_hwcn_python import conv2d_hwcn_python
 from .conv2d_nchw_python import conv2d_nchw_python
 from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
 from .dilate_python import dilate_python
+from .softmax_python import softmax_python
diff --git a/topi/python/topi/testing/softmax_python.py b/topi/python/topi/testing/softmax_python.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dc561a56e2ad67a8390465b7048f1cd581743c6
--- /dev/null
+++ b/topi/python/topi/testing/softmax_python.py
@@ -0,0 +1,24 @@
+# pylint: disable=invalid-name
+"""Softmax operation in python"""
+import numpy as np
+
+def softmax_python(a_np):
+    """Softmax operator.
+    Parameters
+    ----------
+    a_np : numpy.ndarray
+        2-D input data
+
+    Returns
+    -------
+    output_np : numpy.ndarray
+        2-D output with same shape
+    """
+    assert len(a_np.shape) == 2, "only support 2-dim softmax"
+    max_elem = np.amax(a_np, axis=1)
+    max_elem = max_elem.reshape(max_elem.shape[0], 1)
+    e = np.exp(a_np - max_elem)
+    expsum = np.sum(e, axis=1)
+    out_np = e / expsum[:, None]
+    return out_np
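A quick sanity check of the reference implementation (the input values are arbitrary):

import numpy as np
from topi.testing import softmax_python

a = np.random.uniform(size=(4, 10)).astype("float32")
out = softmax_python(a)
# Each row of a softmax output sums to 1.
assert np.allclose(out.sum(axis=1), 1.0, atol=1e-5)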
diff --git a/topi/tests/python/test_topi_softmax.py b/topi/tests/python/test_topi_softmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..0afd74606343559c862e25b22d481c39c4bb3dc1
--- /dev/null
+++ b/topi/tests/python/test_topi_softmax.py
@@ -0,0 +1,34 @@
+"""Test code for softmax"""
+import numpy as np
+import tvm
+import topi
+from topi.util import get_const_tuple
+
+def verify_softmax(m, n):
+    A = tvm.placeholder((m, n), name='A')
+    B = topi.nn.softmax(A)
+    s = topi.cuda.schedule_softmax(B)
+
+    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
+    b_np = topi.testing.softmax_python(a_np)
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
+        foo = tvm.build(s, [A, B], device, name="softmax")
+        foo(a, b)
+        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+
+    for device in ['cuda', 'opencl', 'metal']:
+        check_device(device)
+
+def test_softmax():
+    verify_softmax(32, 10)
+
+
+if __name__ == "__main__":
+    test_softmax()
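If one of the device checks fails, printing the lowered IR is usually the quickest way to see what the schedule produced; a one-line sketch, reusing A, B, and s from verify_softmax:

print(tvm.lower(s, [A, B], simple_mode=True))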