diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py
index 6f641e99f7dd3357be317090b330a097287660e6..5c580aad24c4455230c94a018a59376e2a09e243 100644
--- a/topi/python/topi/generic/nn.py
+++ b/topi/python/topi/generic/nn.py
@@ -35,6 +35,24 @@ def schedule_conv2d_nchw(outs):
     return _default_schedule(outs, False)
 
 
+@tvm.target.generic_func
+def schedule_conv2d_nhwc(outs):
+    """Schedule for conv2d_nhwc
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of conv2d_nhwc
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
 @tvm.target.generic_func
 def schedule_conv2d_transpose_nchw(outs):
     """Schedule for conv2d_transpose_nchw
diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py
index 11866aedc1010429478319e28de890bda3f2ac41..3bd910e299741803015d691384e0ba8c7e76845c 100644
--- a/topi/python/topi/nn/conv2d.py
+++ b/topi/python/topi/nn/conv2d.py
@@ -337,6 +337,57 @@ def conv2d_hwcn(Input, Filter, stride, padding, out_dtype='float32'):
         name="Conv2dOutput", tag="conv2d_hwcn")
     return Output
 
+
+def conv2d_nhwc(Input, Filter, stride, padding, out_dtype='float32'):
+    """Convolution operator in NHWC layout.
+
+    Parameters
+    ----------
+    Input : tvm.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    Filter : tvm.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, num_filter]
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    out_dtype : str
+        Output data type, default is 'float32'
+
+    Returns
+    -------
+    output : tvm.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
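+
+    Examples
+    --------
+    A minimal usage sketch (shapes taken from the accompanying unit test):
+
+    .. code-block:: python
+
+        A = tvm.placeholder((1, 32, 32, 256), name='A')
+        W = tvm.placeholder((3, 3, 256, 256), name='W')
+        B = topi.nn.conv2d_nhwc(A, W, stride=1, padding='SAME')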
+    """
+    assert isinstance(stride, int) or len(stride) == 2
+    batch, in_height, in_width, in_channel = Input.shape
+    kernel_h, kernel_w, channel, num_filter = Filter.shape
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (kernel_h, kernel_w))
+    # compute the output shape
+    out_channel = num_filter
+    out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1)
+    out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
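+    # pad only the spatial axes; the batch and channel axes need no padding in NHWC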
+    pad_before = [0, pad_top, pad_left, 0]
+    pad_after = [0, pad_down, pad_right, 0]
+    PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
+    rc = tvm.reduce_axis((0, in_channel), name='rc')
+    ry = tvm.reduce_axis((0, kernel_h), name='ry')
+    rx = tvm.reduce_axis((0, kernel_w), name='rx')
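+    # each output element accumulates over the kernel window (ry, rx) and the input channels (rc)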
+    Output = tvm.compute(
+        (batch, out_height, out_width, out_channel),
+        lambda nn, yy, xx, ff: tvm.sum(
+            PaddedInput[nn, yy * stride_h + ry, xx * stride_w + rx, rc].astype(out_dtype) *
+            Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]),
+        name="Conv2dOutput", tag="conv2d_nhwc")
+    return Output
+
 # map from schedule type to declaration function
 _SCH_TO_DECL_FUNC = {
     SpatialPack: _spatial_pack,
diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py
index 6a1b361e30970ce4de852f67a7888130acb7bed9..2a20a1c4f6225cdc8dd86b217bb2242e5a5734fe 100644
--- a/topi/python/topi/testing/__init__.py
+++ b/topi/python/topi/testing/__init__.py
@@ -6,6 +6,7 @@ from __future__ import absolute_import as _abs
 
 from .conv2d_hwcn_python import conv2d_hwcn_python
 from .conv2d_nchw_python import conv2d_nchw_python
+from .conv2d_nhwc_python import conv2d_nhwc_python
 from .conv2d_transpose_nchw_python import conv2d_transpose_nchw_python
 from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
 from .dilate_python import dilate_python
diff --git a/topi/python/topi/testing/conv2d_nhwc_python.py b/topi/python/topi/testing/conv2d_nhwc_python.py
new file mode 100644
index 0000000000000000000000000000000000000000..880088a6f89fb8263bca2842ea2cd5d3395535f8
--- /dev/null
+++ b/topi/python/topi/testing/conv2d_nhwc_python.py
@@ -0,0 +1,67 @@
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
+"""Convolution in python"""
+import numpy as np
+import scipy.signal
+
+
+def conv2d_nhwc_python(a_np, w_np, stride, padding):
+    """Convolution operator in NHWC layout.
+
+    Parameters
+    ----------
+    a_np : numpy.ndarray
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    w_np : numpy.ndarray
+        4-D with shape [filter_height, filter_width, in_channel, num_filter]
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    Returns
+    -------
+    b_np : np.ndarray
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+    batch, in_height, in_width, in_channel = a_np.shape
+    kernel_h, kernel_w, _, num_filter = w_np.shape
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+    if isinstance(padding, int):
+        pad_h = pad_w = padding * 2
+    elif padding == 'VALID':
+        pad_h = 0
+        pad_w = 0
+    else: # 'SAME'
+        pad_h = kernel_h - 1
+        pad_w = kernel_w - 1
+    pad_top = int(np.ceil(float(pad_h) / 2))
+    pad_bottom = pad_h - pad_top
+    pad_left = int(np.ceil(float(pad_w) / 2))
+    pad_right = pad_w - pad_left
+    # compute the output shape
+    out_channel = num_filter
+    out_height = (in_height - kernel_h + pad_h) // stride_h + 1
+    out_width = (in_width - kernel_w + pad_w) // stride_w + 1
+    # change the layout from NHWC to NCHW
+    at = a_np.transpose((0, 3, 1, 2))
+    wt = w_np.transpose((3, 2, 0, 1))
+    bt = np.zeros((batch, out_channel, out_height, out_width))
+    # computation
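+    # scipy.signal.convolve2d flips the kernel (true convolution), so the filter is
+    # rotated by 180 degrees below to recover the cross-correlation that conv2d computes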
+    for n in range(batch):
+        for f in range(out_channel):
+            for c in range(in_channel):
+                if pad_h > 0 or pad_w > 0:
+                    apad = np.zeros((in_height + pad_h, in_width + pad_w))
+                    apad[pad_top:pad_top + in_height, pad_left:pad_left + in_width] = at[n, c]
+                else:
+                    apad = at[n, c]
+                out = scipy.signal.convolve2d(
+                    apad, np.rot90(np.rot90(wt[f, c])), mode='valid')
+                bt[n, f] += out[::stride_h, ::stride_w]
+    return bt.transpose((0, 2, 3, 1))
diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py
index 6ab37b8c03ac33c9e79b3f64b978cd12963cfb87..ef227d035fce540a0e21ca88a4ff9946b9d44b5d 100644
--- a/topi/python/topi/x86/__init__.py
+++ b/topi/python/topi/x86/__init__.py
@@ -2,6 +2,8 @@
 """x86 specific declaration and schedules."""
 from __future__ import absolute_import as _abs
 
-from .conv2d import schedule_conv2d
+from .conv2d import schedule_conv2d, schedule_conv2d_nhwc
 from .binarize_pack import schedule_binarize_pack
 from .binary_dense import schedule_binary_dense
+from .nn import *
+from .injective import *
diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py
index 0c91f8c25c88b04c8b90d80340b6fd0fd8ed2576..cb3571d6a91b7b5833fee211d8a730b0c64beebe 100644
--- a/topi/python/topi/x86/conv2d.py
+++ b/topi/python/topi/x86/conv2d.py
@@ -15,6 +15,12 @@ def schedule_conv2d(outs):
         if tag.is_broadcast(op.tag):
             if op not in s.outputs:
                 s[op].compute_inline()
+            else: # inject custom schedule
+                if len(op.axis) == 4: # schedule bias + bn + relu
+                    n, c, h, w = op.axis
+                    fused = s[op].fuse(n, c)
+                    s[op].parallel(fused)
+                    s[op].vectorize(w)
             for tensor in op.input_tensors:
                 if tensor.op.input_tensors:
                     traverse(tensor.op)
@@ -28,10 +34,68 @@ def schedule_conv2d(outs):
                 data_pad = data
                 data = data_pad.op.input_tensors[0]
 
+            if data_pad is not None:
+                n_pad, c_pad, h_pad, w_pad = data_pad.op.axis
+                pad_fused = s[data_pad].fuse(n_pad, c_pad)
+                s[data_pad].parallel(pad_fused)
             C = conv
             n, c, h, w = C.op.axis
-            s[C].parallel(c)
-            s[C].pragma(n, "parallel_launch_point")
+            rc, ry, rx = C.op.reduce_axis
+            fused = s[C].fuse(n, c)
+            s[C].parallel(fused)
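+            # split the width into 16-wide tiles; the inner tile (wi) is vectorized below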
+            wo, wi = s[C].split(w, factor=16)
+            s[C].reorder(fused, rc, h, wo, ry, rx, wi) # move rc to outer loop
+            s[C].unroll(rx)
+            s[C].unroll(ry)
+            s[C].vectorize(wi)
 
     traverse(outs[0].op)
     return s
+
+
+@generic.schedule_conv2d_nhwc.register(["cpu"])
+def schedule_conv2d_nhwc(outs):
+    """Create schedule for tensors"""
+    s = tvm.create_schedule([x.op for x in outs])
+    output_op = outs[0].op
+
+    def traverse(op):
+        """Traverse operators from computation graph"""
+        # inline all one-to-one-mapping operators except the last stage (output)
+        if tag.is_broadcast(op.tag):
+            if op not in s.outputs:
+                s[op].compute_inline()
+            else: # inject custom schedule
+                if len(op.axis) == 4: # schedule bias + bn + relu
+                    n, h, w, c = op.axis
+                    fused = s[op].fuse(n, h, w)
+                    s[op].parallel(fused)
+                    s[op].vectorize(c)
+            for tensor in op.input_tensors:
+                if tensor.op.input_tensors:
+                    traverse(tensor.op)
+
+        if 'conv2d_nhwc' in op.tag:
+            conv = op.output(0)
+            kernel = op.input_tensors[1]
+            data = op.input_tensors[0]
+            data_pad = None
+            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
+                data_pad = data
+                data = data_pad.op.input_tensors[0]
+
+            if data_pad is not None:
+                n_pad, h_pad, w_pad, c_pad = data_pad.op.axis
+                pad_fused = s[data_pad].fuse(n_pad, h_pad)
+                s[data_pad].parallel(pad_fused)
+            C = conv
+            n, h, w, c = C.op.axis
+            ry, rx, rc = C.op.reduce_axis
+            n_out, h_out, w_out, c_out = output_op.axis
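+            # the channel axis is innermost and contiguous in NHWC, so vectorize over c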
+            s[C].vectorize(c)
+            if op != output_op: # fuse bias + bn + relu into conv
+                s[C].compute_at(s[output_op], c_out)
+            else:
+                fused = s[C].fuse(n, h, w)
+                s[C].parallel(fused)
+
+    traverse(output_op)
+    return s
diff --git a/topi/python/topi/x86/injective.py b/topi/python/topi/x86/injective.py
new file mode 100644
index 0000000000000000000000000000000000000000..0970b76142ae4d9e8437a734839a7281ad614e74
--- /dev/null
+++ b/topi/python/topi/x86/injective.py
@@ -0,0 +1,35 @@
+# pylint: disable=invalid-name
+"""x86 declaration and schedules."""
+from __future__ import absolute_import as _abs
+import tvm
+from .. import generic
+
+@generic.schedule_injective.register(["cpu"])
+def schedule_injective(outs):
+    """X86 schedule for injective op.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of injective in the format
+          of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    x = outs[0]
+    s = tvm.create_schedule([x.op for x in outs])
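+    # inline the injective intermediate ops into their consumers before scheduling the outputs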
+    tvm.schedule.AutoInlineInjective(s)
+    if len(s[x].op.axis) == 4:
+        n, c, _, _ = s[x].op.axis
+        fused = s[x].fuse(n, c)  # fuse the two outermost axes (n and c for NCHW, n and h for NHWC)
+        s[x].parallel(fused)
+    else:
+        s[x].parallel(s[x].op.axis[0])
+    return s
+
+schedule_elemwise = schedule_injective
+schedule_broadcast = schedule_injective
diff --git a/topi/python/topi/x86/nn.py b/topi/python/topi/x86/nn.py
new file mode 100644
index 0000000000000000000000000000000000000000..49aa382589d13b3d16373b739440257d60388d1b
--- /dev/null
+++ b/topi/python/topi/x86/nn.py
@@ -0,0 +1,56 @@
+"""x86 nn operators"""
+from __future__ import absolute_import as _abs
+import tvm
+from .. import generic
+
+def _default_schedule(outs, auto_inline):
+    """Default schedule for x86."""
+    x = outs[0]
+    s = tvm.create_schedule([x.op for x in outs])
+    if auto_inline:
+        tvm.schedule.AutoInlineInjective(s)
+        s[x].fuse(s[x].op.axis)
+        return s
+    if len(s[x].op.axis) == 4:
+        n, c, _, _ = s[x].op.axis
+        fused = s[x].fuse(n, c)  # fuse the two outermost axes (n and c for NCHW, n and h for NHWC)
+        s[x].parallel(fused)
+    else:
+        s[x].parallel(s[x].op.axis[0])
+    return s
+
+
+@generic.schedule_softmax.register(["cpu"])
+def schedule_softmax(outs):
+    """Schedule for softmax
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of softmax
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
+@generic.schedule_pool.register(["cpu"])
+def schedule_pool(outs):
+    """Schedule for pool
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of pool
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
diff --git a/topi/tests/python/test_topi_conv2d_nhwc.py b/topi/tests/python/test_topi_conv2d_nhwc.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fc5b841908f1c32711418ca768494c6bfc7eff3
--- /dev/null
+++ b/topi/tests/python/test_topi_conv2d_nhwc.py
@@ -0,0 +1,59 @@
+"""Example code to do convolution."""
+import numpy as np
+import tvm
+import topi
+from tvm.contrib.pickle_memoize import memoize
+from topi.util import get_const_tuple
+
+
+def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, padding):
+    in_height = in_width = in_size
+
+    A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A')
+    W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
+    B = topi.nn.conv2d_nhwc(A, W, stride, padding)
+
+    a_shape = get_const_tuple(A.shape)
+    w_shape = get_const_tuple(W.shape)
+    dtype = A.dtype
+
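+    # pickle_memoize caches the reference data on disk so repeated test runs reuse it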
+    @memoize("topi.tests.test_topi_conv2d_nhwc.verify_nhwc")
+    def get_ref_data():
+        a_np = np.random.uniform(size=a_shape).astype(dtype)
+        w_np = np.random.uniform(size=w_shape).astype(dtype)
+        b_np = topi.testing.conv2d_nhwc_python(a_np, w_np, stride, padding)
+        return a_np, w_np, b_np
+    a_np, w_np, b_np = get_ref_data()
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = topi.generic.schedule_conv2d_nhwc([B])
+        ctx = tvm.context(device, 0)
+        a = tvm.nd.array(a_np, ctx)
+        w = tvm.nd.array(w_np, ctx)
+        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
+        func = tvm.build(s, [A, W, B], device)
+        func(a, w, b)
+        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+
+    for device in ['llvm']:
+        check_device(device)
+
+
+def test_conv2d_nhwc():
+    verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "SAME")
+    verify_conv2d_nhwc(4, 128, 16, 128, 5, 2, "SAME")
+    verify_conv2d_nhwc(4, 128, 16, 256, 5, 2, "SAME")
+    verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "VALID")
+    verify_conv2d_nhwc(4, 128, 16, 128, 5, 2, "VALID")
+    verify_conv2d_nhwc(4, 128, 16, 256, 5, 2, "VALID")
+
+
+if __name__ == "__main__":
+    test_conv2d_nhwc()