diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py
index e0d2c403d4b4d3e9f39e6062027d990027e7d12e..3e06f6f6fed56e51837087c3737ec2070b87a6e5 100644
--- a/topi/python/topi/nn/conv2d.py
+++ b/topi/python/topi/nn/conv2d.py
@@ -79,12 +79,27 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
         HSTR, WSTR = stride
     else:
         HSTR, WSTR = stride, stride
-    assert data.dtype == kernel.dtype, \
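+    # uint8 data with an int8 kernel is the only mixed-dtype pair supported;
+    # it is used by the Intel int8 convolution path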
+    assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
+        "Do not support inputs with different data types now. " \
+        "{} vs. {}".format(data.dtype, kernel.dtype)
+    return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
+
+def _get_workload_int8(data, kernel, stride, padding, out_dtype):
+    """ Get the workload structure. """
+    _, CI, IH, IW = [x.value for x in data.shape]
+    CO, _, KH, KW = [x.value for x in kernel.shape]
+    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
+    if isinstance(stride, (tuple, list)):
+        HSTR, WSTR = stride
+    else:
+        HSTR, WSTR = stride, stride
+    assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
         "Do not support inputs with different data types now. ' \
         '{} vs. {}".format(data.dtype, kernel.dtype)
     return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
 
 
 @tvm.target.generic_func
 def _get_alter_layout_schedule(wkl):
     # pylint: disable=unreachable
@@ -118,6 +133,17 @@ def _get_schedule_NCHWc(wkl, layout, out_layout):
     return wkl
 
 
+@tvm.target.generic_func
+def _get_schedule_NCHWc_int8(wkl, layout, out_layout):
+    # pylint: disable=unreachable
+    """ Get the platform specific schedule. """
+    target = tvm.target.current_target()
+    raise RuntimeError(
+        "No schedule for current target:{}".format(target))
+    # This return has no use, merely to suppress pylint warning
+    return wkl
+
+
 def conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
     """Convolution operator in NCHW layout.
 
diff --git a/topi/python/topi/x86/check_targets.py b/topi/python/topi/x86/check_targets.py
new file mode 100644
index 0000000000000000000000000000000000000000..fad74eaf582aa5bf4ffa7c911e1029a7f78ba90b
--- /dev/null
+++ b/topi/python/topi/x86/check_targets.py
@@ -0,0 +1,12 @@
+# pylint: disable=invalid-name,unused-variable,unused-argument
+"""Checks different x86 targets for target specific schedules"""
+
+def check_skylake(target):
+    """
+    Checks whether the target is Skylake (compiled with -mcpu=skylake-avx512)
+    """
+
+    for opt in target.options:
+        if opt == '-mcpu=skylake-avx512':
+            return True
+    return False
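+
+# A minimal usage sketch (the target string is illustrative):
+#
+#   with tvm.target.create('llvm -mcpu=skylake-avx512') as target:
+#       assert check_skylake(target)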
diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py
index 721c7c169d99e2b9275fcd6a8ec0d0a4456d47c0..6fe59a9095107b5fd1b6d31bbc14b7025256b707 100644
--- a/topi/python/topi/x86/conv2d.py
+++ b/topi/python/topi/x86/conv2d.py
@@ -5,12 +5,13 @@ from .. import generic, tag
 from .. import nn
 from ..nn.util import infer_pad, infer_stride
 from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, \
-    _get_workload, _get_schedule, _get_schedule_NCHWc, \
-    _get_alter_layout_schedule, Workload
+    _get_workload, _get_workload_int8, _get_schedule, _get_schedule_NCHWc, \
+    _get_schedule_NCHWc_int8, _get_alter_layout_schedule, Workload
 
 from . import conv2d_avx_1x1, conv2d_avx_common
 from .conv2d_avx_common import AVXConvCommonFwd
 from .conv2d_avx_1x1 import AVXConv1x1Fwd
+from .check_targets import check_skylake
 
 @_get_schedule.register("cpu")
 def _get_schedule_conv(wkl):
@@ -100,10 +101,95 @@ def _get_schedule_conv(wkl):
     sch = _SCHEDULES_AVX[idx]
     return sch
 
+def _get_schedule_conv_int8(wkl):
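+    """ Get the x86 AVX schedule for an int8 conv workload. """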
+    _WORKLOADS_AVX = [
+        # Following are for INT8 kernels
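+        # workloads of resnet18_v1 on imagenet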
+        Workload('uint8', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
+        Workload('uint8', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
+        Workload('uint8', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
+        Workload('uint8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
+        Workload('uint8', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
+        Workload('uint8', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
+        Workload('uint8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
+        Workload('uint8', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
+        Workload('uint8', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
+        Workload('uint8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
+        Workload('uint8', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
+        # workloads of resnet34_v1 on imagenet, no extra workload required
+        # workloads of resnet50_v1 on imagenet
+        Workload('uint8', 'int32', 56, 56, 64, 256, 1, 1, 0, 0, 1, 1),
+        Workload('uint8', 'int32', 56, 56, 256, 64, 1, 1, 0, 0, 1, 1),
+        Workload('uint8', 'int32', 56, 56, 256, 128, 1, 1, 0, 0, 2, 2),
+        Workload('uint8', 'int32', 28, 28, 128, 512, 1, 1, 0, 0, 1, 1),
+        Workload('uint8', 'int32', 56, 56, 256, 512, 1, 1, 0, 0, 2, 2),
+        Workload('uint8', 'int32', 28, 28, 512, 128, 1, 1, 0, 0, 1, 1),
+        Workload('uint8', 'int32', 28, 28, 512, 256, 1, 1, 0, 0, 2, 2),
+        Workload('uint8', 'int32', 14, 14, 256, 1024, 1, 1, 0, 0, 1, 1),
+        Workload('uint8', 'int32', 28, 28, 512, 1024, 1, 1, 0, 0, 2, 2),
+        Workload('uint8', 'int32', 14, 14, 1024, 256, 1, 1, 0, 0, 1, 1),
+        Workload('uint8', 'int32', 14, 14, 1024, 512, 1, 1, 0, 0, 2, 2),
+        Workload('uint8', 'int32', 7, 7, 512, 2048, 1, 1, 0, 0, 1, 1),
+        Workload('uint8', 'int32', 14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2),
+        Workload('uint8', 'int32', 7, 7, 2048, 512, 1, 1, 0, 0, 1, 1),
+    ]
+
+    fp32_vec_len = 8
+    target = tvm.target.current_target(allow_none=False)
+    if check_skylake(target):
+        fp32_vec_len = 16
+
+    _SCHEDULES_AVX = [
+        # Following are for INT8 operations
+        # workloads of resnet18_v1 on imagenet
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 28),
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 28),
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 14, False),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 14, True),
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 7, True),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 7),
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 7, True),
+        # workloads of resnet34_v1 on imagenet, no extra workload required
+        # workloads of resnet50_v1 on imagenet
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
+        # workloads of resnet101_v1 on imagenet, no extra workload required
+        # workloads of resnet152_v1 on imagenet, no extra workload required
+        # workloads of resnet18_v2 on imagenet, no extra workload required
+        # workloads of resnet34_v2 on imagenet, no extra workload required
+    ]
+
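+    # Fall back to a default AVX schedule for workloads not listed above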
+    if wkl not in _WORKLOADS_AVX:
+        if wkl.hkernel == 1 and wkl.wkernel == 1:
+            return conv2d_avx_1x1._get_default_schedule(wkl, fp32_vec_len)
+        return conv2d_avx_common._get_default_schedule(wkl, fp32_vec_len)
+    idx = _WORKLOADS_AVX.index(wkl)
+    sch = _SCHEDULES_AVX[idx]
+    return sch
+
 @_get_schedule_NCHWc.register("cpu")
 def _get_schedule_NCHWc_x86(wkl, layout, out_layout):
     return _get_schedule_conv(wkl)
 
+@_get_schedule_NCHWc_int8.register("cpu")
+def _get_schedule_NCHWc_x86_int8(wkl, layout, out_layout):
+    return _get_schedule_conv_int8(wkl)
+
 @_get_alter_layout_schedule.register("cpu")
 def _get_alter_layout_schedule_x86(wkl):
     return _get_schedule_conv(wkl)
@@ -162,6 +248,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
     return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
 
 
+
 @conv2d_NCHWc.register("cpu")
 def _declaration_conv_NCHWc(data, kernel, num_filter, kernel_size, stride,
                             padding, layout, out_layout, out_dtype):
@@ -169,13 +256,29 @@ def _declaration_conv_NCHWc(data, kernel, num_filter, kernel_size, stride,
         AVXConvCommonFwd: conv2d_avx_common._declaration_conv_NCHWc,
         AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv_NCHWc
     }
+
+    # Use the int8 schedules if the input data is quantized (uint8 data, int8 kernel)
+    if data.dtype == 'uint8':
+        _AVX_SCH_TO_DECL_FUNC = {
+            AVXConvCommonFwd: conv2d_avx_common._declaration_conv_NCHWc_int8,
+            AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv_NCHWc_int8
+        }
+
     n, ic_chunk, h, w, ic_block = [x.value for x in data.shape]
     ic = ic_chunk * ic_block
     kh, kw = kernel_size
-    wkl = _get_workload(tvm.placeholder((n, ic, h, w), dtype=out_dtype),
-                        tvm.placeholder((num_filter, ic, kh, kw), dtype=out_dtype),
-                        stride, padding, out_dtype)
-    sch = _get_schedule_NCHWc(wkl, layout, out_layout)
+    if data.dtype == 'uint8':
+        wkl = _get_workload_int8(tvm.placeholder((n, ic, h, w), dtype=data.dtype),
+                                 tvm.placeholder((num_filter, ic, kh, kw),
+                                                 dtype=kernel.dtype),
+                                 stride, padding, out_dtype)
+        sch = _get_schedule_NCHWc_int8(wkl, layout, out_layout)
+    else:
+        wkl = _get_workload(tvm.placeholder((n, ic, h, w), dtype=data.dtype),
+                            tvm.placeholder((num_filter, ic, kh, kw),
+                                            dtype=kernel.dtype),
+                            stride, padding, out_dtype)
+        sch = _get_schedule_NCHWc(wkl, layout, out_layout)
     return _AVX_SCH_TO_DECL_FUNC[type(sch)](wkl, sch, data, kernel)
 
 
@@ -289,10 +392,6 @@ def schedule_conv2d_nhwc(outs):
 def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding,
                           layout, out_layout, outs):
     """Create schedule for tensors"""
-    _AVX_SCH_TO_SCH_FUNC = {
-        AVXConvCommonFwd: conv2d_avx_common._schedule_conv_NCHWc,
-        AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc
-    }
     s = tvm.create_schedule([x.op for x in outs])
     scheduled_ops = []
 
@@ -317,15 +416,33 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding,
                 data_pad = data
                 data = data_pad.op.input_tensors[0]
 
+            _AVX_SCH_TO_SCH_FUNC = {
+                AVXConvCommonFwd: conv2d_avx_common._schedule_conv_NCHWc,
+                AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc
+            }
+
+            # Use the int8 schedules if the input data is quantized (uint8)
+            if data.dtype == 'uint8':
+                _AVX_SCH_TO_SCH_FUNC = {
+                    AVXConvCommonFwd: conv2d_avx_common._schedule_conv_NCHWc_int8,
+                    AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc_int8
+                }
+
             n, ic_chunk, h, w, ic_block = [x.value for x in data.shape]
             ic = ic_chunk * ic_block
-            original_data = tvm.placeholder((n, ic, h, w), dtype=conv_out.dtype)
+            original_data = tvm.placeholder((n, ic, h, w), dtype=data.dtype)
 
             kh, kw = kernel_size
-            original_kernel = tvm.placeholder((num_filter, ic, kh, kw), dtype=conv_out.dtype)
+            original_kernel = tvm.placeholder((num_filter, ic, kh, kw),
+                                              dtype=kernel.dtype)
 
-            wkl = _get_workload(original_data, original_kernel, stride, padding, conv_out.dtype)
-            sch = _get_schedule_NCHWc(wkl, layout, out_layout)
+            if data.dtype == 'uint8':
+                wkl = _get_workload_int8(original_data, original_kernel,
+                                         stride, padding, conv_out.dtype)
+                sch = _get_schedule_NCHWc_int8(wkl, layout, out_layout)
+            else:
+                wkl = _get_workload(original_data, original_kernel, stride, padding, conv_out.dtype)
+                sch = _get_schedule_NCHWc(wkl, layout, out_layout)
             _AVX_SCH_TO_SCH_FUNC[type(sch)](s, wkl, sch, data_vec,
                                             kernel, conv_out, outs[0])
 
diff --git a/topi/python/topi/x86/conv2d_avx_1x1.py b/topi/python/topi/x86/conv2d_avx_1x1.py
index 7d820701e1f48057d23145d732010596adcc2ead..bace7451d665f37a500a1d57cd14209bc8426b8f 100644
--- a/topi/python/topi/x86/conv2d_avx_1x1.py
+++ b/topi/python/topi/x86/conv2d_avx_1x1.py
@@ -3,11 +3,14 @@
 from __future__ import absolute_import as _abs
 from collections import namedtuple
 import tvm
+import topi
 
 from ..util import get_const_tuple
 from ..nn.conv2d import _get_schedule, _get_workload
 from ..nn.util import infer_pad, infer_stride
 from ..nn.pad import pad
+from .tensor_intrin import dot_16x1x16_int8_int8_int32
+from .check_targets import check_skylake
 
 AVXConv1x1Fwd = namedtuple('AVXConv1x1Fwd', ['ic_bn', 'oc_bn', 'oh_factor', 'ow_factor'])
 
@@ -229,3 +232,117 @@ def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last):
         s[O].parallel(parallel_axis)
 
     return s
+
+
+def _declaration_conv_NCHWc_int8(wkl, sch, data, kernel):
+    """ Declaration for int8 conv"""
+    out_dtype = wkl.out_dtype
+    HPAD, WPAD = wkl.hpad, wkl.wpad
+    HSTR, WSTR = wkl.hstride, wkl.wstride
+
+    batch_size = data.shape[0]
+    out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
+    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
+
+    DOPAD = (HPAD != 0 or WPAD != 0)
+    if DOPAD:
+        data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
+    else:
+        data_pad = data
+
+    oshape = (batch_size, wkl.out_filter//sch.oc_bn, out_height, out_width, sch.oc_bn)
+
+    # The Intel int8 instruction computes a dot product over groups of
+    # 4 int8 values, so ic_bn must be a multiple of 4
+    n_elems = 4
+    assert sch.ic_bn % n_elems == 0
+    ic_outer = tvm.reduce_axis((0, wkl.in_filter//sch.ic_bn), name='ic_outer')
+    ic_f_inner = tvm.reduce_axis((0, sch.ic_bn//n_elems), name='ic_f_inner')
+    ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
+
+    # Fold the trailing 1x1 spatial dimensions (k_h, k_w) into the innermost kernel axis
+    k_shape = kernel.shape
+    kernel = topi.reshape(kernel, (k_shape[0], k_shape[1], k_shape[2], k_shape[3],
+                                   k_shape[4] * k_shape[5] * k_shape[6]))
+
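+    # The flattened input-channel index decomposes as
+    #   ic = ic_outer * ic_bn + ic_f_inner * n_elems + ic_s_inner,
+    # and the innermost group of 4 (ic_s_inner) is reduced by the intrinsic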
+    conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
+                       tvm.sum(data_pad[n, ic_outer, oh*HSTR, ow*WSTR,
+                                        ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) *
+                               kernel[oc_chunk, ic_outer, ic_f_inner,
+                                      oc_block, ic_s_inner].astype(out_dtype),
+                               axis=[ic_outer, ic_f_inner, ic_s_inner]),
+                       name='conv2d_NCHWc_int8',
+                       tag="conv2d_NCHWc_int8")
+
+    return conv
+
+
+def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
+    """
+    Defines the 1x1 int8 conv schedule for Intel machines, using the
+    Intel AVX-512 int8 intrinsics.
+    More details - https://software.intel.com/en-us/articles/
+    lower-numerical-precision-deep-learning-inference-and-training
+    """
+
+    # Int8 schedules are currently supported only on Skylake (AVX-512,
+    # 16 int32 lanes); fall back to the unscheduled compute otherwise
+    target = tvm.target.current_target(allow_none=False)
+    if not check_skylake(target):
+        return s
+    int32_lanes = 16
+
+    # schedule data
+    A = data
+    if isinstance(s[A].op, tvm.tensor.ComputeOp):
+        batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
+        parallel_axis = s[A].fuse(ic_chunk, ih)
+        s[A].parallel(parallel_axis)
+
+    C, O = conv_out, last
+    CC = s.cache_write(C, 'global')
+
+    batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
+    oh_outer, oh_inner = s[C].split(oh, factor=sch.oh_factor)
+    ow_outer, ow_inner = s[C].split(ow, factor=sch.ow_factor)
+    s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
+    s[C].vectorize(oc_block)
+
+    parallel_axis = s[C].fuse(oc_chunk, oh_outer)
+    s[CC].compute_at(s[C], parallel_axis)
+    if C == O:
+        s[C].parallel(parallel_axis)
+
+    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
+    ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
+
+    # Skylake and future processors have 16 vector lanes
+    assert sch.oc_bn % int32_lanes == 0
+
+    oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
+
+    oh_outer, oh_inner = s[CC].split(oh, factor=sch.oh_factor)
+    ow_outer, ow_inner = s[CC].split(ow, factor=sch.ow_factor)
+
+    s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_outer, ic_f_inner, oh_inner,
+                  ow_inner, oc_f_inner, oc_s_inner, ic_s_inner)
+    s[CC].fuse(oc_chunk, oh_outer)
+
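+    # Tensorize the innermost 16-lane output-channel axis with the AVX-512
+    # int8 dot-product intrinsic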
+    pc = dot_16x1x16_int8_int8_int32()
+    s[CC].tensorize(oc_s_inner, pc)
+    s[CC].unroll(ow_inner)
+    s[CC].unroll(oh_inner)
+
+    if C != O:
+        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
+        oh_outer, oh_inner = s[O].split(oh, factor=sch.oh_factor)
+        ow_outer, ow_inner = s[O].split(ow, factor=sch.ow_factor)
+        s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
+
+        parallel_axis = s[O].fuse(oc_chunk, oh_outer)
+        s[C].compute_at(s[O], parallel_axis)
+        s[O].vectorize(oc_block)
+        s[O].parallel(parallel_axis)
+
+    return s
diff --git a/topi/python/topi/x86/conv2d_avx_common.py b/topi/python/topi/x86/conv2d_avx_common.py
index 8f8086fdebb4c989ac8a1aee9d87b579321510b4..0d7aba23d236d42330550cddf6cc68d5318ce3ff 100644
--- a/topi/python/topi/x86/conv2d_avx_common.py
+++ b/topi/python/topi/x86/conv2d_avx_common.py
@@ -8,6 +8,8 @@ from ..util import get_const_tuple
 from ..nn.conv2d import _get_schedule, _get_workload
 from ..nn.util import infer_pad, infer_stride
 from ..nn.pad import pad
+from .tensor_intrin import dot_16x1x16_int8_int8_int32
+from .check_targets import check_skylake
 
 AVXConvCommonFwd = namedtuple('AVXConvCommonFwd', ['ic_bn', 'oc_bn', 'reg_n', 'unroll_kw'])
 
@@ -252,3 +254,124 @@ def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last):
         s[O].parallel(parallel_axis)
 
     return s
+
+
+def _declaration_conv_NCHWc_int8(wkl, sch, data, kernel):
+    """
+    This function sets up the compute for INT8 conv 2d
+    Inputs are in INT8 datatype
+    Output is in INT32 datatype
+    """
+
+    out_dtype = wkl.out_dtype
+    HPAD, WPAD = wkl.hpad, wkl.wpad
+    HSTR, WSTR = wkl.hstride, wkl.wstride
+
+    batch_size = data.shape[0]
+    out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
+    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
+
+    # pack data
+    DOPAD = (HPAD != 0 or WPAD != 0)
+    if DOPAD:
+        data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
+    else:
+        data_pad = data
+
+    # convolution
+    oshape = (batch_size, wkl.out_filter//sch.oc_bn, out_height, out_width, sch.oc_bn)
+    kh = tvm.reduce_axis((0, wkl.hkernel), name='kh')
+    kw = tvm.reduce_axis((0, wkl.wkernel), name='kw')
+
+    # The Intel int8 instruction computes a dot product over groups of
+    # 4 int8 values; the current implementation therefore requires ic_bn
+    # to be a multiple of 4
+    n_elems = 4
+    assert sch.ic_bn % n_elems == 0
+
+    ic_outer = tvm.reduce_axis((0, wkl.in_filter//sch.ic_bn), name='ic_outer')
+    ic_f_inner = tvm.reduce_axis((0, sch.ic_bn//n_elems), name='ic_f_inner')
+    ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
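+    # Input channels decompose as ic = ic_outer * ic_bn +
+    # ic_f_inner * n_elems + ic_s_inner; the innermost group of 4
+    # channels is what the int8 dot-product intrinsic reduces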
+    conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
+                       tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh, ow*WSTR+kw,
+                                        ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) *
+                               kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner,
+                                      oc_block, ic_s_inner].astype(out_dtype),
+                               axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
+                       name='conv2d_NCHWc_int8',
+                       tag="conv2d_NCHWc_int8")
+    return conv
+
+
+def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
+    """
+    Defines the int8 conv schedule for Intel machines, using the
+    Intel AVX-512 int8 intrinsics.
+    More details - https://software.intel.com/en-us/articles/
+    lower-numerical-precision-deep-learning-inference-and-training
+    """
+
+    # Int8 operations are currently supported only on Skylake; in future,
+    # dot_16x1x16_int8_int8_int32 will be updated for the VNNI instructions.
+    # For an unsupported target, the schedule falls back to the original
+    # (unscheduled) compute
+    target = tvm.target.current_target(allow_none=False)
+    if not check_skylake(target):
+        return s
+    int32_lanes = 16
+
+    A = data
+    if isinstance(s[A].op, tvm.tensor.ComputeOp):
+        batch, ic_chunk, ih, iw, _ = s[A].op.axis
+        parallel_axis = s[A].fuse(ic_chunk, ih)
+        s[A].parallel(parallel_axis)
+
+    # schedule 5-D NCHW[x]c conv
+    C, O = conv_out, last
+    CC = s.cache_write(C, 'global')
+
+    _, oc_chunk, oh, ow, oc_block = s[C].op.axis
+    ow_chunk, ow_block = s[C].split(ow, factor=sch.reg_n)
+    s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+    parallel_axis = s[C].fuse(oc_chunk, oh)
+    s[C].vectorize(oc_block)
+    if C == O:
+        s[C].parallel(parallel_axis)
+
+    s[CC].compute_at(s[C], ow_chunk)
+    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
+    kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
+
+    ow_chunk, ow_block = s[CC].split(ow, factor=sch.reg_n)
+
+    # Skylake and future processors have 16 vector lanes
+    assert sch.oc_bn % int32_lanes == 0
+
+    oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
+
+    if sch.unroll_kw:
+        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, ic_f_inner, kw,
+                      ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
+        s[CC].unroll(kw)
+    else:
+        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, kw, ic_f_inner,
+                      ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
+
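+    # Replace the innermost 16-lane output-channel computation with the
+    # AVX-512 int8 dot-product intrinsic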
+    pc = dot_16x1x16_int8_int8_int32()
+    s[CC].tensorize(oc_s_inner, pc)
+    s[CC].unroll(ow_block)
+    s[CC].unroll(oc_f_inner)
+
+    if C != O:
+        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
+        ow_chunk, ow_block = s[O].split(ow, factor=sch.reg_n)
+        s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+        parallel_axis = s[O].fuse(oc_chunk, oh)
+        s[C].compute_at(s[O], parallel_axis)
+        s[O].vectorize(oc_block)
+        s[O].parallel(parallel_axis)
+
+    return s
diff --git a/topi/python/topi/x86/tensor_intrin.py b/topi/python/topi/x86/tensor_intrin.py
new file mode 100644
index 0000000000000000000000000000000000000000..28e57f1c10f81e0136346aab60b1c17670d55b98
--- /dev/null
+++ b/topi/python/topi/x86/tensor_intrin.py
@@ -0,0 +1,84 @@
+"""Core kernel of dot product of 4 Int8 operations"""
+#pylint: disable=invalid-name
+import tvm
+
+
+def dot_16x1x16_int8_int8_int32():
+    """
+    Int8 dot product over groups of 4 elements using AVX512 Skylake instructions.
+    This function takes a uint8 data[4] array and an int8 kernel[16][4]
+    array and computes the dot product of data[4] with each 4-element
+    row of kernel, producing an int32 output[16].
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void dot_16x1x16_int8_int8_int32(int8 data[4], int8 kernel[16][4],
+                int32 output[16]){
+            for (int i = 0; i < 16; i++){
+                out[i] = 0;
+                for (int k = 0; k < 4; k++){
+                    out[i] += data[k] * kernel[i][k]
+                }
+            }
+        }
+
+    Physically, the kernel array sits in an AVX512 vector register and
+    the data[4] is broadcast to another AVX512 vector register. This
+    function returns a TensorIntrin that can be used to tensorize
+    a schedule.
+
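+    Example
+    -------
+    A minimal tensorize sketch (stage and axis names follow the x86
+    schedules in this patch; the tensorized axis must cover exactly 16
+    output lanes with a 4-element reduction):
+
+    .. code-block:: python
+
+        pc = dot_16x1x16_int8_int8_int32()
+        s[CC].tensorize(oc_s_inner, pc)
+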
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Skylake int8 TensorIntrin that can be used in tensorizing schedule
+    """
+
+    int32_lanes = 16 # 16 int32 lanes in AVX512
+    num_int8_elements = 4 # 4 int8 elements in int32
+    data = tvm.placeholder((num_int8_elements,), dtype='uint8', name='data')
+    kernel = tvm.placeholder((int32_lanes, num_int8_elements), dtype='int8', name='kernel')
+    k = tvm.reduce_axis((0, num_int8_elements), name='k')
+    C = tvm.compute((int32_lanes,),
+                    lambda i: tvm.sum(data[k].astype('int32') *
+                                      kernel[i, k].astype('int32'),
+                                      axis=k),
+                    name="C")
+
+    a_buffer = tvm.decl_buffer(data.shape, dtype='uint8', name="a_buffer",
+                               offset_factor=1,
+                               strides=[1])
+    b_buffer = tvm.decl_buffer(kernel.shape, dtype='int8', name="b_buffer",
+                               offset_factor=1,
+                               strides=[tvm.var('ldw'), 1])
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.ir_builder.create()
+            if index == 1:
+                ib.emit(outs[0].vstore(0, tvm.const(0, 'int32x16')))
+                return ib.get()
+
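+            # Reinterpret the four uint8 values as one int32, broadcast it
+            # across all 16 lanes of an int32x16 vector, then view that
+            # vector as 64 int8 lanes for the multiplies below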
+            a_int8 = ins[0].vload([0], "uint8x4")
+            re_int32 = tvm.call_pure_intrin('int32', 'reinterpret', a_int8)
+            vec_ai32 = re_int32.astype('int32x16')
+            vec_a = tvm.call_pure_intrin('int8x64', 'reinterpret', vec_ai32)
+            vec_b = ins[1].vload([0, 0], "int8x64")
+            vec_one = tvm.const(1, "int16x32")
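+            # Two-step reduction: pmaddubs.w multiplies unsigned/signed int8
+            # pairs and adds adjacent products into int16 lanes; pmaddw.d
+            # then multiplies those lanes by 1 and adds adjacent int16 pairs,
+            # yielding the 4-element int32 dot product in each of 16 lanes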
+            pair_reduction = tvm.call_llvm_intrin('int16x32',
+                                                  'llvm.x86.avx512.pmaddubs.w.512',
+                                                  tvm.const(0, 'uint32'),
+                                                  vec_a, vec_b)
+            quad_reduction = tvm.call_llvm_intrin('int32x16',
+                                                  'llvm.x86.avx512.pmaddw.d.512',
+                                                  tvm.const(0, 'uint32'),
+                                                  pair_reduction, vec_one)
+            if index == 0:
+                ib.emit(outs[0].vstore(0, quad_reduction))
+            else:
+                ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], 'int32x16')))
+            return ib.get()
+
+        # body, reset, update
+        return _instr(0), _instr(1), _instr(2)
+
+    with tvm.build_config(offset_factor=1, partition_const_loop=True):
+        return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data: a_buffer, kernel: b_buffer})
diff --git a/topi/recipe/conv/test_conv_int8_intel.py b/topi/recipe/conv/test_conv_int8_intel.py
new file mode 100644
index 0000000000000000000000000000000000000000..863b3a6a41ab9e8e8c686f09a9cbfb969e865956
--- /dev/null
+++ b/topi/recipe/conv/test_conv_int8_intel.py
@@ -0,0 +1,149 @@
+# pylint: disable=too-many-arguments, too-many-locals, assignment-from-no-return
+""" Conv Int8 functional and performance testing"""
+import sys
+import logging
+import numpy as np
+import tvm
+import topi
+
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+LOGGER = logging.getLogger('test_conv_int8_intel')
+LOGGER.disabled = False
+
+# All the WORKLOADS from Resnet except first layer
+# Workload is ['height', 'width', 'in_filter', 'out_filter',
+#              'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']
+WORKLOADS = [(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
+             (56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
+             (56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
+             (56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
+             (28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
+             (28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
+             (28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
+             (14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
+             (14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
+             (14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
+             (7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
+             (56, 56, 64, 256, 1, 1, 0, 0, 1, 1),
+             (56, 56, 256, 64, 1, 1, 0, 0, 1, 1),
+             (56, 56, 256, 128, 1, 1, 0, 0, 2, 2),
+             (28, 28, 128, 512, 1, 1, 0, 0, 1, 1),
+             (56, 56, 256, 512, 1, 1, 0, 0, 2, 2),
+             (28, 28, 512, 128, 1, 1, 0, 0, 1, 1),
+             (28, 28, 512, 256, 1, 1, 0, 0, 2, 2),
+             (14, 14, 256, 1024, 1, 1, 0, 0, 1, 1),
+             (28, 28, 512, 1024, 1, 1, 0, 0, 2, 2),
+             (14, 14, 1024, 256, 1, 1, 0, 0, 1, 1),
+             (14, 14, 1024, 512, 1, 1, 0, 0, 2, 2),
+             (7, 7, 512, 2048, 1, 1, 0, 0, 1, 1),
+             (14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2),
+             (7, 7, 2048, 512, 1, 1, 0, 0, 1, 1)
+            ]
+
+
+TARGET_NAME = 'llvm -mcpu=skylake-avx512'
+NUM_VEC_LANES = 16
+CTX = tvm.context(TARGET_NAME, 0)
+
+def get_shape(im_height, im_width, in_filter, out_filter, k_h, k_w, hpad, wpad,
+              hstride, wstride, out_dtype):
+    """
+    Finds out the shape of all data structures
+    """
+    # Find shapes
+    data_shape = (1, in_filter//NUM_VEC_LANES, im_height, im_width, NUM_VEC_LANES)
+
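+    # For the int8 path (int32 output) the kernel is packed so that groups
+    # of 4 input channels lie innermost, matching the layout consumed by the
+    # AVX-512 int8 dot-product sequence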
+    if out_dtype == 'int32':
+        if k_h != 1:
+            kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
+                            NUM_VEC_LANES//4, NUM_VEC_LANES, 4)
+        else:
+            kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, NUM_VEC_LANES//4,
+                            NUM_VEC_LANES, 4, k_h, k_w)
+    elif out_dtype == 'float32':
+        if k_h != 1:
+            kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
+                            NUM_VEC_LANES, NUM_VEC_LANES)
+        else:
+            kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, NUM_VEC_LANES,
+                            NUM_VEC_LANES, k_h, k_w)
+    out_height = (im_height + 2 * hpad - k_h) // hstride + 1
+    out_width = (im_width + 2 * wpad - k_w) // wstride + 1
+    o_shape = (1, out_filter//NUM_VEC_LANES, out_height, out_width, NUM_VEC_LANES)
+    return (data_shape, kernel_shape, o_shape)
+
+
+def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_filter,
+                  out_filter, k_h, k_w, hpad, wpad, hstride, wstride):
+    """
+    Runs the inference and checks the functional correctness between
+    compute and schedule outputs
+    """
+    (data_shape, kernel_shape, o_shape) = get_shape(im_height, im_width, in_filter,
+                                                    out_filter, k_h, k_w, hpad, wpad,
+                                                    hstride, wstride, out_dtype)
+
+    # Create TVM placeholders
+    data = tvm.placeholder(data_shape, name='data', dtype=data_dtype)
+    kernel = tvm.placeholder(kernel_shape, name='kernel', dtype=kernel_dtype)
+
+    # Create the numpy arrays to be used for executing conv models
+    if data_dtype == 'float32':
+        data_array = tvm.nd.array(np.random.rand(*data_shape).astype(dtype=data_dtype), CTX)
+        kernel_array = tvm.nd.array(np.random.rand(*kernel_shape).astype(dtype=kernel_dtype), CTX)
+    else:
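+        # Values in [0, 100) keep each pair-sum of the int8 path well within
+        # the int16 range used by the pmaddubs.w step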
+        data_array = tvm.nd.array(np.random.randint(100, size=data_shape).astype(data_dtype))
+        kernel_array = tvm.nd.array(np.random.randint(100, size=kernel_shape).astype(kernel_dtype))
+
+    # c_orig will be used for declaration output
+    # c_sch will be used for scheduled computation output
+    c_orig = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX)
+    c_sch = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX)
+
+
+    with tvm.target.create(TARGET_NAME):
+        conv = topi.nn.conv2d_NCHWc(data, kernel, num_filter=out_filter,
+                                    kernel_size=(k_h, k_w), stride=hstride,
+                                    padding=hpad, layout='NCHWc',
+                                    out_layout='NCHWc', out_dtype=out_dtype)
+        out = topi.nn.relu(conv)
+        sch = tvm.create_schedule(out.op)
+        func = tvm.build(sch, [data, kernel, out], target=TARGET_NAME, name='out')
+        func(data_array, kernel_array, c_orig)
+        LOGGER.debug(tvm.lower(sch, [data, kernel], simple_mode=True))
+
+        # Generate and run the optimized schedule
+        sconv = topi.generic.nn.schedule_conv2d_NCHWc(num_filter=out_filter,
+                                                      kernel_size=(k_h, k_w),
+                                                      stride=hstride,
+                                                      padding=hpad,
+                                                      layout='NCHWc',
+                                                      out_layout='NCHWc',
+                                                      outs=[out])
+        func = tvm.build(sconv, [data, kernel, out], target=TARGET_NAME, name='conv')
+        func(data_array, kernel_array, c_sch)
+
+        # Functional check
+        if data_dtype == 'uint8':
+            np.testing.assert_equal(c_orig.asnumpy(), c_sch.asnumpy())
+        else:
+            assert np.allclose(c_orig.asnumpy(), c_sch.asnumpy())
+
+        evaluator = func.time_evaluator(func.entry_name, CTX, number=1000)
+        LOGGER.debug(tvm.lower(sconv, [data, kernel], simple_mode=True))
+        return evaluator(data_array, kernel_array, c_sch).mean
+
+if __name__ == "__main__":
+    LOGGER.info("Workload, Kernel_size, FP32_time, INT8_time, Speedup")
+    SPEEDUP_ARRAY = []
+    for i, wkl in enumerate(WORKLOADS):
+        fp32_time = run_inference('float32', 'float32', 'float32', *wkl)
+        int8_time = run_inference('uint8', 'int8', 'int32', *wkl)
+        kernel_h = wkl[4]
+        kernel_w = wkl[5]
+        LOGGER.info("Workload#" + str(i) + ", " + str(kernel_h) + "x" + str(kernel_w) + ", "
+                    + str(fp32_time) + ", " + str(int8_time) + ", " + str(fp32_time/int8_time))
+
+        SPEEDUP_ARRAY.append(fp32_time/int8_time)
+    LOGGER.info("Average speedup --> %s" % str(sum(SPEEDUP_ARRAY)/float(len(SPEEDUP_ARRAY))))