diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py
index 174a37b1d451a0ef26de5550fc221a8a4f1bb403..74e92384a9e0f6c080e9317beb29a3c39bc2f38f 100644
--- a/topi/python/topi/cuda/__init__.py
+++ b/topi/python/topi/cuda/__init__.py
@@ -16,3 +16,4 @@ from .pooling import schedule_pool, schedule_global_pool
 from .conv2d_transpose_nchw import schedule_conv2d_transpose_nchw
 from .extern import schedule_extern
 from .vision import schedule_region
+from .nn import schedule_lrn, schedule_l2norm
diff --git a/topi/python/topi/cuda/nn.py b/topi/python/topi/cuda/nn.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8757970505b477ca14a120ee136e5716a998913
--- /dev/null
+++ b/topi/python/topi/cuda/nn.py
@@ -0,0 +1,91 @@
+# pylint: disable=invalid-name
+"""scheduler functions for cuda backend"""
+from __future__ import absolute_import as _abs
+
+import tvm
+from .. import generic
+from .. import tag
+from .reduction import _schedule_reduce
+
+@generic.schedule_lrn.register(["cuda"])
+def schedule_lrn(outs):
+    """Schedule for LRN
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of LRN
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+    num_thread = 64
+    block_x = tvm.thread_axis("blockIdx.x")
+    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
+
+    lrn = outs[0]
+    sqr_sum_up = lrn.op.input_tensors[1]
+    sqr_sum = sqr_sum_up.op.input_tensors[0]
+    set_pad = sqr_sum.op.input_tensors[0]
+    s[set_pad].bind(set_pad.op.axis[0], block_x)
+    rxk = sqr_sum.op.reduce_axis[0]
+    _, xki = s[sqr_sum].split(rxk, factor=num_thread)
+    srf = s.rfactor(sqr_sum, xki)
+    s[sqr_sum].bind(s[sqr_sum].op.axis[0], block_x)
+    s[sqr_sum].bind(s[sqr_sum].op.reduce_axis[0], thread_x)
+    s[srf].compute_at(s[sqr_sum], s[sqr_sum].op.reduce_axis[0])
+    s[sqr_sum_up].bind(sqr_sum_up.op.axis[0], block_x)
+    xto, _ = s[lrn].split(lrn.op.axis[1], nparts=num_thread)
+    s[lrn].bind(lrn.op.axis[0], block_x)
+    s[lrn].bind(xto, thread_x)
+    return s
+
+@generic.schedule_l2norm.register(["cuda"])
+def schedule_l2norm(outs):
+    """Schedule for L2norm
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of L2norm
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def traverse(OP):
+        '''inline all one-to-one-mapping operators
+        except the last stage (output)'''
+        if tag.is_injective(OP.tag) or OP.tag == 'l2norm':
+            if OP not in s.outputs:
+                s[OP].compute_inline()
+            for tensor in OP.input_tensors:
+                if tensor.op.input_tensors:
+                    traverse(tensor.op)
+        elif OP.tag == 'comm_reduce':
+            _schedule_reduce(OP, s, is_idx_reduce=False)
+            for tensor in OP.input_tensors:
+                traverse(tensor.op)
+        else:
+            raise RuntimeError("Unsupported operator tag: %s" % OP.tag)
+    traverse(outs[0].op)
+
+    num_thread = 64
+    l2norm = outs[0]
+    block_x = tvm.thread_axis("blockIdx.x")
+    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
+    xto, _ = s[l2norm].split(l2norm.op.axis[1], nparts=num_thread)
+    s[l2norm].bind(l2norm.op.axis[0], block_x)
+    s[l2norm].bind(xto, thread_x)
+
+    return s
diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py
index 7fe76e1739f01e5862e844053370bbf19af695d0..7252a23b90c3c9ba7477231868c372c4ecbeb1d5 100644
--- a/topi/python/topi/generic/nn.py
+++ b/topi/python/topi/generic/nn.py
@@ -229,3 +229,39 @@ def schedule_binary_dense(outs):
         The computation schedule for the op.
     """
     return _default_schedule(outs, False)
+
+
+@tvm.target.generic_func
+def schedule_lrn(outs):
+    """Schedule for lrn
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of lrn
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
+@tvm.target.generic_func
+def schedule_l2norm(outs):
+    """Schedule for l2norm
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of l2norm
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
diff --git a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py
index 918f399f503ce6a856b7ba2051157dbde1deeca3..056d1a76339a2414df8bf6198ce70439851dbdd1 100644
--- a/topi/python/topi/nn/__init__.py
+++ b/topi/python/topi/nn/__init__.py
@@ -15,3 +15,5 @@ from .softmax import *
 from .conv2d_transpose import *
 from .bnn import *
 from .upsampling import *
+from .local_response_norm import *
+from .l2_norm import *
diff --git a/topi/python/topi/nn/l2_norm.py b/topi/python/topi/nn/l2_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b5381a8559973b70bbd4f019677be2ce68f7509
--- /dev/null
+++ b/topi/python/topi/nn/l2_norm.py
@@ -0,0 +1,35 @@
+# pylint: disable=invalid-name
+"""TVM operator for l2norm"""
+from __future__ import absolute_import
+import tvm
+import topi
+
+@tvm.target.generic_func
+def l2norm_instance(data, eps, axis=None):
+    """Perform L2norm on the input data
+
+    For axis=None, y(i, j) = x(i, j) / sqrt(max(sum(x^2), eps))
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        4-D with NCHW or NHWC layout
+
+    eps : float
+        epsilon value
+
+    axis : list of int
+        axis over the normalization applied
+
+    Returns
+    -------
+    output : tvm.Tensor
+        4-D output with same shape
+    """
+    assert len(data.shape) == 4, "only support 4-dim lrn"
+    dot_value = topi.cpp.pow(data, 2.0)
+    sum_value = topi.sum(dot_value, axis=axis, keepdims=True)
+    expand_sum = topi.broadcast_to(sum_value, data.shape)
+    return topi.broadcast_div(data, topi.sqrt(\
+                tvm.compute(expand_sum.shape, lambda i, j, k, l:\
+                tvm.max(expand_sum[i, j, k, l], eps), tag='l2norm')))
diff --git a/topi/python/topi/nn/local_response_norm.py b/topi/python/topi/nn/local_response_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b44e02214acc0370bb9a5e0cb6f9fc044d20944b
--- /dev/null
+++ b/topi/python/topi/nn/local_response_norm.py
@@ -0,0 +1,68 @@
+# pylint: disable=invalid-name
+"""TVM operator for local response norm compute."""
+from __future__ import absolute_import
+import tvm
+import topi
+from .pad import pad
+
+@tvm.target.generic_func
+def lrn(data, size, axis=1, alpha=0.0001, beta=0.75, bias=2):
+    """Perform the across channels local response normalisation
+    on the input data.
+
+    sum_sqr_up^i{x, y} = (bias+((alpha/size)* \
+                                {sum_{j=max(0, i-size/2)}^{min(N-1,i+size/2)} \
+                                     (data^j{x,y})^2}))^beta
+    output^i{x, y} = data^i{x, y}/sum_sqr_up^i{x, y}
+    N is the number for input channels
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        4-D with shape [batch, channel, height, width]
+
+    size : int
+        normalisation window size
+
+    axis : int
+        input data layout channel axis
+        default value is 1 for NCHW format
+
+    bias : float
+        offset to avoid dividing by 0
+
+    alpha : float
+        to be divided
+
+    beta : float
+        exponent
+
+    Returns
+    -------
+    output : tvm.Tensor
+        4-D output with same shape
+    """
+    assert len(data.shape) == 4, "only support 4-dim lrn"
+    assert (size % 2) == 1, "size should be odd number"
+    assert (axis == 1) or (axis == 3), "axis should 1 or 3 for NCHW and NHWC"
+    ##Add padding on left & right of size radius first
+    pad_after = pad_before = [0, 0, 0, 0]
+    pad_after[axis] = pad_before[axis] = (size//2)
+    pad_data = pad(data, pad_before, pad_after, name="pad_data")
+
+    rxs = tvm.reduce_axis((0, size), name='rxs')
+    if axis == 1:
+        #NCHW layout
+        sqr_sum = tvm.compute(data.shape, lambda i, j, k, l: tvm.sum(
+            pad_data[i, j + rxs, k, l] * pad_data[i, j + rxs, k, l],
+            axis=rxs))
+    elif axis == 3:
+        #NHWC layout
+        sqr_sum = tvm.compute(data.shape, lambda i, j, k, l: tvm.sum(
+            pad_data[i, j, k, l + rxs] * pad_data[i, j, k, l + rxs],
+            axis=rxs))
+
+    sqr_sum_up = tvm.compute(data.shape, lambda i, j, k, l: tvm.power(
+        (bias + (alpha * sqr_sum[i, j, k, l] / size)), beta))
+
+    return topi.broadcast_div(data, sqr_sum_up)
diff --git a/topi/python/topi/rocm/__init__.py b/topi/python/topi/rocm/__init__.py
index a5b4ee30dc37368a48dd6ae00faeef7b1726cd67..96a04794c680513978479b571323be399efc5ab8 100644
--- a/topi/python/topi/rocm/__init__.py
+++ b/topi/python/topi/rocm/__init__.py
@@ -5,3 +5,4 @@ from __future__ import absolute_import as _abs
 from .conv2d import *
 from .dense import *
 from .vision import *
+from .nn import *
diff --git a/topi/python/topi/rocm/nn.py b/topi/python/topi/rocm/nn.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9c529155f7b7d1a620b5bc00d15a7507d521509
--- /dev/null
+++ b/topi/python/topi/rocm/nn.py
@@ -0,0 +1,13 @@
+"""scheduler for normalization functions on rocm backend"""
+from __future__ import absolute_import as _abs
+
+import topi
+from .. import generic
+
+@generic.schedule_lrn.register(["rocm", "gpu"])
+def schedule_lrn(outs):
+    return topi.cuda.schedule_lrn(outs)
+
+@generic.schedule_l2norm.register(["rocm", "gpu"])
+def schedule_l2norm(outs):
+    return topi.cuda.schedule_l2norm(outs)
diff --git a/topi/tests/python/test_topi_l2norm.py b/topi/tests/python/test_topi_l2norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa7970125b4406e85a062588c964628569c4cc23
--- /dev/null
+++ b/topi/tests/python/test_topi_l2norm.py
@@ -0,0 +1,70 @@
+"""Test code for L2 norm"""
+import numpy as np
+import tvm
+import topi
+from topi.util import get_const_tuple
+
+def l2norm_instance_python(a_np, eps, axis=None):
+    """L2 norm operator in NCHW layout.
+
+    Parameters
+    ----------
+    a_np : numpy.ndarray
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    eps : float
+        epsilon constant value
+    axis : list of int
+        axis over the normalization applied
+
+    Returns
+    -------
+    l2norm_out : np.ndarray
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    batch, axis1, axis2, axis3 = a_np.shape
+    sqr_sum = np.zeros(shape=(batch,)).astype(a_np.dtype)
+    sqrt_sum = np.zeros(shape=(batch,)).astype(a_np.dtype)
+    l2norm_out = np.zeros(shape=a_np.shape).astype(a_np.dtype)
+    dot_value = np.power(a_np, 2.0)
+    sqr_sum = np.sum(dot_value, axis, keepdims=True)
+    sqrt_sum = np.sqrt(np.maximum(np.broadcast_to(sqr_sum, a_np.shape), eps))
+    return np.divide(a_np, sqrt_sum)
+
+def verify_l2norm(n, c, h, w, eps, axis=None):
+
+    A = tvm.placeholder((n, c, h, w), name='A')
+    B = topi.nn.l2norm_instance(A, eps, axis)
+    dtype = A.dtype
+
+    a_np = np.random.uniform(size=(n, c, h, w)).astype(dtype)
+    b_np = l2norm_instance_python(a_np, eps, axis)
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = topi.generic.schedule_l2norm(B)
+        ctx = tvm.context(device, 0)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
+        f = tvm.build(s, [A, B], device)
+        f(a, b)
+        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+
+    for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
+        check_device(device)
+
+def test_l2norm():
+    verify_l2norm(1, 3, 20, 20, 0.001)
+    verify_l2norm(1, 3, 20, 20, 0.001, 1)
+    verify_l2norm(1, 3, 20, 20, 0.001, (1, 2))
+    verify_l2norm(1, 3, 20, 20, 0.001, (2, 3))
+    verify_l2norm(1, 3, 20, 20, 0.001, (0, 3))
+    verify_l2norm(1, 3, 20, 20, 0.001, (0, 2, 3))
+
+
+if __name__ == "__main__":
+    test_l2norm()
diff --git a/topi/tests/python/test_topi_lrn.py b/topi/tests/python/test_topi_lrn.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c4714077e1ebb880c4c2b94a71e51d177a4c4da
--- /dev/null
+++ b/topi/tests/python/test_topi_lrn.py
@@ -0,0 +1,95 @@
+"""Test code for local response normalization"""
+import numpy as np
+import tvm
+import topi
+from topi.util import get_const_tuple
+
+def lrn_python(a_np, size, axis, bias, alpha, beta):
+    """Local response norm operator in NCHW layout.
+
+    Parameters
+    ----------
+    a_np : numpy.ndarray
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    size : int
+        normalisation window size
+
+    axis : int
+        input data layout channel axis
+
+    bias : float
+        offset to avoid dividing by 0. constant value
+
+    alpha : float
+        contant valie
+
+    beta : float
+        exponent constant value
+
+    Returns
+    -------
+    b_np : np.ndarray
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    axis0, axis1, axis2, axis3 = a_np.shape
+    radius = size // 2
+    sqr_sum = np.zeros(shape=a_np.shape).astype(a_np.dtype)
+    sqr_sum_up = np.zeros(shape=a_np.shape).astype(a_np.dtype)
+    lrn_out = np.zeros(shape=a_np.shape).astype(a_np.dtype)
+    def sum_dot_values(i, j, k, l):
+        axis_size = a_np.shape[axis]
+        if (axis == 1):
+            #NCHW layout
+            sum_start = j-radius if j-radius >= 0 else 0
+            sum_end = j+radius+1 if j+radius+1 < axis_size else axis_size
+            sqr_sum[i, j, k, l] = sum(a_np[i, sum_start:sum_end, k, l] * \
+                                      a_np[i, sum_start:sum_end, k, l])
+        elif (axis == 3):
+            #NHWC layout
+            sum_start = l-radius if l-radius >= 0 else 0
+            sum_end = l+radius+1 if l+radius+1 < axis_size else axis_size
+            sqr_sum[i, j, k, l] = sum(a_np[i, j, k, sum_start:sum_end] * \
+                                      a_np[i, j, k, sum_start:sum_end])
+
+    for i in range(axis0):
+        for j in range(axis1):
+            for k in range(axis2):
+                for l in range(axis3):
+                    sum_dot_values(i, j, k, l)
+
+    sqr_sum_up = np.power((bias + (alpha * sqr_sum /size)), beta)
+    return np.divide(a_np, sqr_sum_up)
+
+def verify_lrn(shape, size, axis, bias, alpha, beta):
+    A = tvm.placeholder(shape, name='A')
+    B = topi.nn.lrn(A, size, axis, alpha, beta, bias)
+    dtype = A.dtype
+
+    a_np = np.random.uniform(size=shape).astype(dtype)
+    b_np = lrn_python(a_np, size, axis, bias, alpha, beta)
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = topi.generic.schedule_lrn(B)
+        ctx = tvm.context(device, 0)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
+        f = tvm.build(s, [A, B], device)
+        f(a, b)
+        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+
+    for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
+        check_device(device)
+
+def test_lrn():
+    verify_lrn((1, 3, 5, 5), 3, 1, 1, 1, 0.5)
+    verify_lrn((1, 3, 5, 5), 3, 3, 1, 1, 0.5)
+    verify_lrn((1, 3, 20, 20), 3, 1, 2, 1, 0.75)
+
+if __name__ == "__main__":
+    test_lrn()