diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py
index e0d2c403d4b4d3e9f39e6062027d990027e7d12e..3e06f6f6fed56e51837087c3737ec2070b87a6e5 100644
--- a/topi/python/topi/nn/conv2d.py
+++ b/topi/python/topi/nn/conv2d.py
@@ -79,12 +79,27 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
         HSTR, WSTR = stride
     else:
         HSTR, WSTR = stride, stride
-    assert data.dtype == kernel.dtype, \
+    assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
+        "Do not support inputs with different data types now. " \
+        "{} vs. {}".format(data.dtype, kernel.dtype)
+    return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
+
+
+def _get_workload_int8(data, kernel, stride, padding, out_dtype):
+    """ Get the workload structure. """
+    _, CI, IH, IW = [x.value for x in data.shape]
+    CO, _, KH, KW = [x.value for x in kernel.shape]
+    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
+    if isinstance(stride, (tuple, list)):
+        HSTR, WSTR = stride
+    else:
+        HSTR, WSTR = stride, stride
+    assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
         "Do not support inputs with different data types now. ' \
         '{} vs. {}".format(data.dtype, kernel.dtype)
     return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
 
+
 @tvm.target.generic_func
 def _get_alter_layout_schedule(wkl):
     # pylint: disable=unreachable
@@ -118,6 +133,17 @@ def _get_schedule_NCHWc(wkl, layout, out_layout):
     return wkl
 
 
+@tvm.target.generic_func
+def _get_schedule_NCHWc_int8(wkl, layout, out_layout):
+    # pylint: disable=unreachable
+    """ Get the platform specific schedule. """
+    target = tvm.target.current_target()
+    raise RuntimeError(
+        "No schedule for current target:{}".format(target))
+    # This return has no use, merely to suppress pylint warning
+    return wkl
+
+
 def conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
     """Convolution operator in NCHW layout.
diff --git a/topi/python/topi/x86/check_targets.py b/topi/python/topi/x86/check_targets.py
new file mode 100644
index 0000000000000000000000000000000000000000..fad74eaf582aa5bf4ffa7c911e1029a7f78ba90b
--- /dev/null
+++ b/topi/python/topi/x86/check_targets.py
@@ -0,0 +1,12 @@
+# pylint: disable=invalid-name,unused-variable,unused-argument
+"""Checks different x86 targets for target-specific schedules"""
+
+def check_skylake(target):
+    """
+    Checks whether the target is Skylake.
+    """
+
+    for opt in target.options:
+        if opt == '-mcpu=skylake-avx512':
+            return True
+    return False
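
Note (illustrative sketch, not part of the patch): check_skylake only inspects target.options, so detection is driven entirely by the exact -mcpu flag. A minimal usage sketch, assuming the tvm.target.create API used elsewhere in this patch:

    import tvm
    from topi.x86.check_targets import check_skylake

    # Only an exact '-mcpu=skylake-avx512' option is recognized; other CPU targets return False.
    skylake = tvm.target.create('llvm -mcpu=skylake-avx512')
    avx2 = tvm.target.create('llvm -mcpu=core-avx2')
    assert check_skylake(skylake)
    assert not check_skylake(avx2)
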
diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py
index 721c7c169d99e2b9275fcd6a8ec0d0a4456d47c0..6fe59a9095107b5fd1b6d31bbc14b7025256b707 100644
--- a/topi/python/topi/x86/conv2d.py
+++ b/topi/python/topi/x86/conv2d.py
@@ -5,12 +5,13 @@ from .. import generic, tag
 from .. import nn
 from ..nn.util import infer_pad, infer_stride
 from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, \
-    _get_workload, _get_schedule, _get_schedule_NCHWc, \
-    _get_alter_layout_schedule, Workload
+    _get_workload, _get_workload_int8, _get_schedule, _get_schedule_NCHWc, \
+    _get_schedule_NCHWc_int8, _get_alter_layout_schedule, Workload
 from . import conv2d_avx_1x1, conv2d_avx_common
 from .conv2d_avx_common import AVXConvCommonFwd
 from .conv2d_avx_1x1 import AVXConv1x1Fwd
+from .check_targets import check_skylake
 
 @_get_schedule.register("cpu")
 def _get_schedule_conv(wkl):
@@ -100,10 +101,95 @@ def _get_schedule_conv(wkl):
     sch = _SCHEDULES_AVX[idx]
     return sch
 
+def _get_schedule_conv_int8(wkl):
+    _WORKLOADS_AVX = [
+        ## Following are for INT8 kernels
+        Workload('uint8', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
+        Workload('uint8', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
+        Workload('uint8', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
+        Workload('uint8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
+        Workload('uint8', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
+        Workload('uint8', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
+        Workload('uint8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
+        Workload('uint8', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
+        Workload('uint8', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
+        Workload('uint8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
+        Workload('uint8', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
+        # workloads of resnet34_v1 on imagenet, no extra workload required
+        # workloads of resnet50_v1 on imagenet
+        Workload('uint8', 'int32', 56, 56, 64, 256, 1, 1, 0, 0, 1, 1),
+        Workload('uint8', 'int32', 56, 56, 256, 64, 1, 1, 0, 0, 1, 1),
+        Workload('uint8', 'int32', 56, 56, 256, 128, 1, 1, 0, 0, 2, 2),
+        Workload('uint8', 'int32', 28, 28, 128, 512, 1, 1, 0, 0, 1, 1),
+        Workload('uint8', 'int32', 56, 56, 256, 512, 1, 1, 0, 0, 2, 2),
+        Workload('uint8', 'int32', 28, 28, 512, 128, 1, 1, 0, 0, 1, 1),
+        Workload('uint8', 'int32', 28, 28, 512, 256, 1, 1, 0, 0, 2, 2),
+        Workload('uint8', 'int32', 14, 14, 256, 1024, 1, 1, 0, 0, 1, 1),
+        Workload('uint8', 'int32', 28, 28, 512, 1024, 1, 1, 0, 0, 2, 2),
+        Workload('uint8', 'int32', 14, 14, 1024, 256, 1, 1, 0, 0, 1, 1),
+        Workload('uint8', 'int32', 14, 14, 1024, 512, 1, 1, 0, 0, 2, 2),
+        Workload('uint8', 'int32', 7, 7, 512, 2048, 1, 1, 0, 0, 1, 1),
+        Workload('uint8', 'int32', 14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2),
+        Workload('uint8', 'int32', 7, 7, 2048, 512, 1, 1, 0, 0, 1, 1),
+    ]
+
+    fp32_vec_len = 8
+    target = tvm.target.current_target(allow_none=False)
+    if check_skylake(target):
+        fp32_vec_len = 16
+
+    _SCHEDULES_AVX = [
+        # Following are for INT8 operations
+        # workloads of resnet18_v1 on imagenet
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 28),
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 28),
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 14, False),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 14, True),
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 7, True),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 7),
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 7, True),
+        # workloads of resnet34_v1 on imagenet, no extra workload required
+        # workloads of resnet50_v1 on imagenet
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
+        # workloads of resnet101_v1 on imagenet, no extra workload required
+        # workloads of resnet152_v1 on imagenet, no extra workload required
+        # workloads of resnet18_v2 on imagenet, no extra workload required
+        # workloads of resnet34_v2 on imagenet, no extra workload required
+    ]
+
+    if wkl not in _WORKLOADS_AVX:
+        if wkl.hkernel == 1 and wkl.wkernel == 1:
+            return conv2d_avx_1x1._get_default_schedule(wkl, fp32_vec_len)
+        return conv2d_avx_common._get_default_schedule(wkl, fp32_vec_len)
+    idx = _WORKLOADS_AVX.index(wkl)
+    sch = _SCHEDULES_AVX[idx]
+    return sch
+
 @_get_schedule_NCHWc.register("cpu")
 def _get_schedule_NCHWc_x86(wkl, layout, out_layout):
     return _get_schedule_conv(wkl)
 
+@_get_schedule_NCHWc_int8.register("cpu")
+def _get_schedule_NCHWc_x86_int8(wkl, layout, out_layout):
+    return _get_schedule_conv_int8(wkl)
+
 @_get_alter_layout_schedule.register("cpu")
 def _get_alter_layout_schedule_x86(wkl):
     return _get_schedule_conv(wkl)
@@ -162,6 +248,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
 
     return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
 
+
 @conv2d_NCHWc.register("cpu")
 def _declaration_conv_NCHWc(data, kernel, num_filter, kernel_size, stride,
                             padding, layout, out_layout, out_dtype):
@@ -169,13 +256,29 @@ def _declaration_conv_NCHWc(data, kernel, num_filter, kernel_size, stride,
         AVXConvCommonFwd: conv2d_avx_common._declaration_conv_NCHWc,
         AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv_NCHWc
     }
+
+    # Use the int8 declarations if the input data is of uint8 dtype
+    if data.dtype == 'uint8':
+        _AVX_SCH_TO_DECL_FUNC = {
+            AVXConvCommonFwd: conv2d_avx_common._declaration_conv_NCHWc_int8,
+            AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv_NCHWc_int8
+        }
+
     n, ic_chunk, h, w, ic_block = [x.value for x in data.shape]
     ic = ic_chunk * ic_block
     kh, kw = kernel_size
-    wkl = _get_workload(tvm.placeholder((n, ic, h, w), dtype=out_dtype),
-                        tvm.placeholder((num_filter, ic, kh, kw), dtype=out_dtype),
-                        stride, padding, out_dtype)
-    sch = _get_schedule_NCHWc(wkl, layout, out_layout)
+    if data.dtype == 'uint8':
+        wkl = _get_workload_int8(tvm.placeholder((n, ic, h, w), dtype=data.dtype),
+                                 tvm.placeholder((num_filter, ic, kh, kw),
+                                                 dtype=kernel.dtype),
+                                 stride, padding, out_dtype)
+        sch = _get_schedule_NCHWc_int8(wkl, layout, out_layout)
+    else:
+        wkl = _get_workload(tvm.placeholder((n, ic, h, w), dtype=data.dtype),
+                            tvm.placeholder((num_filter, ic, kh, kw),
+                                            dtype=kernel.dtype),
+                            stride, padding, out_dtype)
+        sch = _get_schedule_NCHWc(wkl, layout, out_layout)
     return _AVX_SCH_TO_DECL_FUNC[type(sch)](wkl, sch, data, kernel)
@@ -289,10 +392,6 @@ def schedule_conv2d_nhwc(outs):
 def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding,
                           layout, out_layout, outs):
     """Create schedule for tensors"""
-    _AVX_SCH_TO_SCH_FUNC = {
-        AVXConvCommonFwd: conv2d_avx_common._schedule_conv_NCHWc,
-        AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc
-    }
     s = tvm.create_schedule([x.op for x in outs])
     scheduled_ops = []
@@ -317,15 +416,33 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding,
                 data_pad = data
                 data = data_pad.op.input_tensors[0]
 
+            _AVX_SCH_TO_SCH_FUNC = {
+                AVXConvCommonFwd: conv2d_avx_common._schedule_conv_NCHWc,
+                AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc
+            }
+
+            # Use the int8 schedules if the input data is of uint8 dtype
+            if data.dtype == 'uint8':
+                _AVX_SCH_TO_SCH_FUNC = {
+                    AVXConvCommonFwd: conv2d_avx_common._schedule_conv_NCHWc_int8,
+                    AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc_int8
+                }
+
             n, ic_chunk, h, w, ic_block = [x.value for x in data.shape]
             ic = ic_chunk * ic_block
-            original_data = tvm.placeholder((n, ic, h, w), dtype=conv_out.dtype)
+            original_data = tvm.placeholder((n, ic, h, w), dtype=data.dtype)
 
             kh, kw = kernel_size
-            original_kernel = tvm.placeholder((num_filter, ic, kh, kw), dtype=conv_out.dtype)
+            original_kernel = tvm.placeholder((num_filter, ic, kh, kw),
+                                              dtype=kernel.dtype)
 
-            wkl = _get_workload(original_data, original_kernel, stride, padding, conv_out.dtype)
-            sch = _get_schedule_NCHWc(wkl, layout, out_layout)
+            if data.dtype == 'uint8':
+                wkl = _get_workload_int8(original_data, original_kernel,
+                                         stride, padding, conv_out.dtype)
+                sch = _get_schedule_NCHWc_int8(wkl, layout, out_layout)
+            else:
+                wkl = _get_workload(original_data, original_kernel, stride, padding, conv_out.dtype)
+                sch = _get_schedule_NCHWc(wkl, layout, out_layout)
             _AVX_SCH_TO_SCH_FUNC[type(sch)](s, wkl, sch,
                                             data_vec, kernel, conv_out, outs[0])
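
Note (illustrative sketch, not part of the patch): on the uint8 path the data keeps the 5-D NCHW[x]c layout, while the kernel is expected pre-packed so that the innermost dimension groups the 4 int8 values consumed by the dot-product intrinsic. The helper below mirrors get_shape from the new recipe at the end of this patch; ic_bn, oc_bn and n_elems are assumed parameters:

    def int8_nchwc_shapes(batch, in_c, out_c, height, width, k_h, k_w,
                          ic_bn=16, oc_bn=16, n_elems=4):
        """Shapes fed to conv2d_NCHWc on the uint8 path (mirrors the recipe's get_shape)."""
        data_shape = (batch, in_c // ic_bn, height, width, ic_bn)
        if k_h != 1:
            kernel_shape = (out_c // oc_bn, in_c // ic_bn, k_h, k_w,
                            ic_bn // n_elems, oc_bn, n_elems)
        else:
            kernel_shape = (out_c // oc_bn, in_c // ic_bn, ic_bn // n_elems,
                            oc_bn, n_elems, k_h, k_w)
        return data_shape, kernel_shape

    # A 3x3, 64->64 ResNet-18 layer at 56x56:
    # data (1, 4, 56, 56, 16), kernel (4, 4, 3, 3, 4, 16, 4)
    print(int8_nchwc_shapes(1, 64, 64, 56, 56, 3, 3))
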
tag="conv2d_NCHWc_int8") + + + return conv + + +def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last): + """ + Defines the schedule for INT8 for intel machines + Uses the Intel intrinsics to use INT8 operations + More details - https://software.intel.com/en-us/articles/ + lower-numerical-precision-deep-learning-inference-and-training + """ + + target = tvm.target.current_target(allow_none=False) + int32_lanes = -1 + if check_skylake(target): + int32_lanes = 16 + else: + return s + assert int32_lanes != -1 + + # schedule data + A = data + if isinstance(s[A].op, tvm.tensor.ComputeOp): + batch, ic_chunk, ih, iw, ic_block = s[A].op.axis + parallel_axis = s[A].fuse(ic_chunk, ih) + s[A].parallel(parallel_axis) + + C, O = conv_out, last + CC = s.cache_write(C, 'global') + + batch, oc_chunk, oh, ow, oc_block = s[C].op.axis + oh_outer, oh_inner = s[C].split(oh, factor=sch.oh_factor) + ow_outer, ow_inner = s[C].split(ow, factor=sch.ow_factor) + s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) + s[C].vectorize(oc_block) + + parallel_axis = s[C].fuse(oc_chunk, oh_outer) + s[CC].compute_at(s[C], parallel_axis) + if C == O: + s[C].parallel(parallel_axis) + + _, oc_chunk, oh, ow, oc_block = s[CC].op.axis + ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis + + # Skylake and future processors have 16 vector lanes + assert sch.oc_bn % int32_lanes == 0 + + oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes) + + oh_outer, oh_inner = s[CC].split(oh, factor=sch.oh_factor) + ow_outer, ow_inner = s[CC].split(ow, factor=sch.ow_factor) + + s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_outer, ic_f_inner, oh_inner, + ow_inner, oc_f_inner, oc_s_inner, ic_s_inner) + s[CC].fuse(oc_chunk, oh_outer) + + pc = dot_16x1x16_int8_int8_int32() + s[CC].tensorize(oc_s_inner, pc) + s[CC].unroll(ow_inner) + s[CC].unroll(oh_inner) + + if C != O: + batch, oc_chunk, oh, ow, oc_block = s[O].op.axis + oh_outer, oh_inner = s[O].split(oh, factor=sch.oh_factor) + ow_outer, ow_inner = s[O].split(ow, factor=sch.ow_factor) + s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) + + parallel_axis = s[O].fuse(oc_chunk, oh_outer) + s[C].compute_at(s[O], parallel_axis) + s[O].vectorize(oc_block) + s[O].parallel(parallel_axis) + + return s diff --git a/topi/python/topi/x86/conv2d_avx_common.py b/topi/python/topi/x86/conv2d_avx_common.py index 8f8086fdebb4c989ac8a1aee9d87b579321510b4..0d7aba23d236d42330550cddf6cc68d5318ce3ff 100644 --- a/topi/python/topi/x86/conv2d_avx_common.py +++ b/topi/python/topi/x86/conv2d_avx_common.py @@ -8,6 +8,8 @@ from ..util import get_const_tuple from ..nn.conv2d import _get_schedule, _get_workload from ..nn.util import infer_pad, infer_stride from ..nn.pad import pad +from .tensor_intrin import dot_16x1x16_int8_int8_int32 +from .check_targets import check_skylake AVXConvCommonFwd = namedtuple('AVXConvCommonFwd', ['ic_bn', 'oc_bn', 'reg_n', 'unroll_kw']) @@ -252,3 +254,124 @@ def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last): s[O].parallel(parallel_axis) return s + + +def _declaration_conv_NCHWc_int8(wkl, sch, data, kernel): + """ + This function sets up the compute for INT8 conv 2d + Inputs are in INT8 datatype + Output is in INT32 datatype + """ + + out_dtype = wkl.out_dtype + HPAD, WPAD = wkl.hpad, wkl.wpad + HSTR, WSTR = wkl.hstride, wkl.wstride + + batch_size = data.shape[0] + out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1 + out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 + + # pack 
diff --git a/topi/python/topi/x86/conv2d_avx_common.py b/topi/python/topi/x86/conv2d_avx_common.py
index 8f8086fdebb4c989ac8a1aee9d87b579321510b4..0d7aba23d236d42330550cddf6cc68d5318ce3ff 100644
--- a/topi/python/topi/x86/conv2d_avx_common.py
+++ b/topi/python/topi/x86/conv2d_avx_common.py
@@ -8,6 +8,8 @@ from ..util import get_const_tuple
 from ..nn.conv2d import _get_schedule, _get_workload
 from ..nn.util import infer_pad, infer_stride
 from ..nn.pad import pad
+from .tensor_intrin import dot_16x1x16_int8_int8_int32
+from .check_targets import check_skylake
 
 AVXConvCommonFwd = namedtuple('AVXConvCommonFwd', ['ic_bn', 'oc_bn', 'reg_n', 'unroll_kw'])
@@ -252,3 +254,124 @@ def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last):
         s[O].parallel(parallel_axis)
 
     return s
+
+
+def _declaration_conv_NCHWc_int8(wkl, sch, data, kernel):
+    """
+    Sets up the compute for INT8 conv2d.
+    Inputs are in INT8 datatype (uint8 data, int8 kernel).
+    Output is in INT32 datatype.
+    """
+
+    out_dtype = wkl.out_dtype
+    HPAD, WPAD = wkl.hpad, wkl.wpad
+    HSTR, WSTR = wkl.hstride, wkl.wstride
+
+    batch_size = data.shape[0]
+    out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
+    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
+
+    # pad data
+    DOPAD = (HPAD != 0 or WPAD != 0)
+    if DOPAD:
+        data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
+    else:
+        data_pad = data
+
+    # convolution
+    oshape = (batch_size, wkl.out_filter//sch.oc_bn, out_height, out_width, sch.oc_bn)
+    kh = tvm.reduce_axis((0, wkl.hkernel), name='kh')
+    kw = tvm.reduce_axis((0, wkl.wkernel), name='kw')
+
+    # The Intel int8 instruction computes a dot product of two 4-element int8 vectors.
+    # The current implementation requires ic_bn to be a multiple of 4.
+    n_elems = 4
+    assert sch.ic_bn % n_elems == 0
+
+    ic_outer = tvm.reduce_axis((0, wkl.in_filter//(sch.ic_bn)), name='ic_outer')
+    ic_f_inner = tvm.reduce_axis((0, sch.ic_bn//n_elems), name='ic_f_inner')
+    ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
+    conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
+                       tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh, ow*WSTR+kw,
+                                        ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) *
+                               kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner,
+                                      oc_block, ic_s_inner].astype(out_dtype),
+                               axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
+                       name='conv2d_NCHWc_int8',
+                       tag="conv2d_NCHWc_int8")
+    return conv
+
+def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
+    """
+    Defines the INT8 schedule for Intel machines.
+    Uses the Intel intrinsics to perform INT8 operations.
+    More details - https://software.intel.com/en-us/articles/
+    lower-numerical-precision-deep-learning-inference-and-training
+    """
+
+    # Currently INT8 operations are supported only on Skylake.
+    # In the future the _intrin_reduce4int8 will be updated for VNNI instructions.
+    # For an unsupported target, the schedule is returned without the int8
+    # specialization (the original compute is used).
+
+    target = tvm.target.current_target(allow_none=False)
+    int32_lanes = -1
+    if check_skylake(target):
+        int32_lanes = 16
+    else:
+        return s
+    assert int32_lanes != -1
+
+    A = data
+    if isinstance(s[A].op, tvm.tensor.ComputeOp):
+        batch, ic_chunk, ih, iw, _ = s[A].op.axis
+        parallel_axis = s[A].fuse(ic_chunk, ih)
+        s[A].parallel(parallel_axis)
+
+    # schedule 5-D NCHW[x]c conv
+    C, O = conv_out, last
+    CC = s.cache_write(C, 'global')
+
+    _, oc_chunk, oh, ow, oc_block = s[C].op.axis
+    ow_chunk, ow_block = s[C].split(ow, factor=sch.reg_n)
+    s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+    parallel_axis = s[C].fuse(oc_chunk, oh)
+    s[C].vectorize(oc_block)
+    if C == O:
+        s[C].parallel(parallel_axis)
+
+    s[CC].compute_at(s[C], ow_chunk)
+    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
+    kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
+
+    ow_chunk, ow_block = s[CC].split(ow, factor=sch.reg_n)
+
+    # Skylake and future processors have 16 vector lanes
+    assert sch.oc_bn % int32_lanes == 0
+
+    oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
+
+    if sch.unroll_kw:
+        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, ic_f_inner, kw,
+                      ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
+        s[CC].unroll(kw)
+    else:
+        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, kw, ic_f_inner,
+                      ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
+
+    pc = dot_16x1x16_int8_int8_int32()
+    s[CC].tensorize(oc_s_inner, pc)
+    s[CC].unroll(ow_block)
+    s[CC].unroll(oc_f_inner)
+
+    if C != O:
+        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
+        ow_chunk, ow_block = s[O].split(ow, factor=sch.reg_n)
+        s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+        parallel_axis = s[O].fuse(oc_chunk, oh)
+        s[C].compute_at(s[O], parallel_axis)
+        s[O].vectorize(oc_block)
+        s[O].parallel(parallel_axis)
+
+    return s
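
Note (illustrative, not part of the patch): a slow plain-NumPy reference of the compute declared above, assuming the data is already padded and packed into NCHW[x]c and the kernel into the 7-D int8 layout; useful for spot-checking small shapes:

    import numpy as np

    def conv2d_nchwc_int8_ref(data, kernel, hstr, wstr):
        # data:   (N, IC//ic_bn, H, W, ic_bn) uint8, already padded
        # kernel: (OC//oc_bn, IC//ic_bn, KH, KW, ic_bn//4, oc_bn, 4) int8
        n_, icc, h, w, ic_bn = data.shape
        occ, _, kh, kw, _, oc_bn, n_elems = kernel.shape
        oh, ow = (h - kh) // hstr + 1, (w - kw) // wstr + 1
        out = np.zeros((n_, occ, oh, ow, oc_bn), dtype=np.int32)
        for n in range(n_):
            for oc in range(occ):
                for y in range(oh):
                    for x in range(ow):
                        for ob in range(oc_bn):
                            acc = 0
                            for r in range(kh):
                                for s in range(kw):
                                    for ico in range(icc):
                                        for icf in range(ic_bn // n_elems):
                                            for ics in range(n_elems):
                                                acc += int(data[n, ico, y * hstr + r, x * wstr + s,
                                                                icf * n_elems + ics]) * \
                                                       int(kernel[oc, ico, r, s, icf, ob, ics])
                            out[n, oc, y, x, ob] = acc
        return out
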
diff --git a/topi/python/topi/x86/tensor_intrin.py b/topi/python/topi/x86/tensor_intrin.py
new file mode 100644
index 0000000000000000000000000000000000000000..28e57f1c10f81e0136346aab60b1c17670d55b98
--- /dev/null
+++ b/topi/python/topi/x86/tensor_intrin.py
@@ -0,0 +1,84 @@
+"""Core kernel of dot product of 4 Int8 operations"""
+#pylint: disable=invalid-name
+import tvm
+
+
+def dot_16x1x16_int8_int8_int32():
+    """
+    Int8 dot product by every 4 elements using AVX512 Skylake instructions.
+    This function takes two arrays -- uint8 data[4] and
+    int8 kernel[16][4] -- and computes a dot product of data[4] with every
+    4 elements of kernel, resulting in output[16] of int32 datatype.
+    The pseudo code is as follows.
+    .. code-block:: c
+        void dot_16x1x16_int8_int8_int32(int8 data[4], int8 kernel[16][4],
+                int32 output[16]){
+            for (int i = 0; i < 16; i++){
+                out[i] = 0;
+                for (int k = 0; k < 4; k++){
+                    out[i] += data[k] * kernel[i][k]
+                }
+            }
+        }
+
+    Physically, the kernel array sits in an AVX512 vector register and
+    the data[4] is broadcasted to another AVX512 vector register. This
+    function returns a TensorIntrin that can be used to tensorize
+    a schedule.
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Skylake int8 TensorIntrin that can be used in tensorizing schedule
+    """
+
+    int32_lanes = 16       # 16 int32 lanes in AVX512
+    num_int8_elements = 4  # 4 int8 elements in int32
+    data = tvm.placeholder((num_int8_elements,), dtype='uint8', name='data')
+    kernel = tvm.placeholder((int32_lanes, num_int8_elements), dtype='int8', name='kernel')
+    k = tvm.reduce_axis((0, num_int8_elements), name='k')
+    C = tvm.compute((int32_lanes,),
+                    lambda i: tvm.sum(data[k].astype('int32') *
+                                      kernel[i, k].astype('int32'),
+                                      axis=k),
+                    name="C")
+
+    a_buffer = tvm.decl_buffer(data.shape, dtype='uint8', name="a_buffer",
+                               offset_factor=1,
+                               strides=[1])
+    b_buffer = tvm.decl_buffer(kernel.shape, dtype='int8', name="b_buffer",
+                               offset_factor=1,
+                               strides=[tvm.var('ldw'), 1])
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.ir_builder.create()
+            if index == 1:
+                ib.emit(outs[0].vstore(0, tvm.const(0, 'int32x16')))
+                return ib.get()
+
+            a_int8 = ins[0].vload([0], "uint8x4")
+            re_int32 = tvm.call_pure_intrin('int32', 'reinterpret', a_int8)
+            vec_ai32 = re_int32.astype('int32x16')
+            vec_a = tvm.call_pure_intrin('int8x64', 'reinterpret', vec_ai32)
+            vec_b = ins[1].vload([0, 0], "int8x64")
+            vec_one = tvm.const(1, "int16x32")
+            pair_reduction = tvm.call_llvm_intrin('int16x32',
+                                                  'llvm.x86.avx512.pmaddubs.w.512',
+                                                  tvm.const(0, 'uint32'),
+                                                  vec_a, vec_b)
+            quad_reduction = tvm.call_llvm_intrin('int32x16',
+                                                  'llvm.x86.avx512.pmaddw.d.512',
+                                                  tvm.const(0, 'uint32'),
+                                                  pair_reduction, vec_one)
+            if index == 0:
+                ib.emit(outs[0].vstore(0, quad_reduction))
+            else:
+                ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], 'int32x16')))
+            return ib.get()
+
+        # body, reset, update
+        return _instr(0), _instr(1), _instr(2)
+
+    with tvm.build_config(offset_factor=1, partition_const_loop=True):
+        return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data: a_buffer, kernel: b_buffer})
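
Note (illustrative, not part of the patch): the vpmaddubsw/vpmaddwd pair above implements a 4-way uint8 x int8 dot product widened to int32 in each of the 16 lanes. A NumPy sketch of what one tensorized call computes; note the real vpmaddubsw saturates its intermediate int16 sums, so extreme input pairs can deviate from this reference:

    import numpy as np

    def dot_16x1x16_ref(data_u8, kernel_i8):
        # out[i] = sum_k data[k] * kernel[i, k]; data: (4,) uint8, kernel: (16, 4) int8 -> (16,) int32
        return (kernel_i8.astype(np.int32) * data_u8.astype(np.int32)).sum(axis=1)

    data = np.random.randint(0, 100, size=(4,), dtype=np.uint8)
    kernel = np.random.randint(-100, 100, size=(16, 4), dtype=np.int8)
    print(dot_16x1x16_ref(data, kernel))  # shape (16,), dtype int32
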
tvm +import topi + +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +LOGGER = logging.getLogger('test_conv_int8_intel') +LOGGER.disabled = False + +# All the WORKLOADS from Resnet except first layer +# Workload is ['height', 'width', 'in_filter', 'out_filter', +# 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) +WORKLOADS = [(56, 56, 64, 64, 3, 3, 1, 1, 1, 1), + (56, 56, 64, 64, 1, 1, 0, 0, 1, 1), + (56, 56, 64, 128, 3, 3, 1, 1, 2, 2), + (56, 56, 64, 128, 1, 1, 0, 0, 2, 2), + (28, 28, 128, 128, 3, 3, 1, 1, 1, 1), + (28, 28, 128, 256, 3, 3, 1, 1, 2, 2), + (28, 28, 128, 256, 1, 1, 0, 0, 2, 2), + (14, 14, 256, 256, 3, 3, 1, 1, 1, 1), + (14, 14, 256, 512, 3, 3, 1, 1, 2, 2), + (14, 14, 256, 512, 1, 1, 0, 0, 2, 2), + (7, 7, 512, 512, 3, 3, 1, 1, 1, 1), + (56, 56, 64, 256, 1, 1, 0, 0, 1, 1), + (56, 56, 256, 64, 1, 1, 0, 0, 1, 1), + (56, 56, 256, 128, 1, 1, 0, 0, 2, 2), + (28, 28, 128, 512, 1, 1, 0, 0, 1, 1), + (56, 56, 256, 512, 1, 1, 0, 0, 2, 2), + (28, 28, 512, 128, 1, 1, 0, 0, 1, 1), + (28, 28, 512, 256, 1, 1, 0, 0, 2, 2), + (14, 14, 256, 1024, 1, 1, 0, 0, 1, 1), + (28, 28, 512, 1024, 1, 1, 0, 0, 2, 2), + (14, 14, 1024, 256, 1, 1, 0, 0, 1, 1), + (14, 14, 1024, 512, 1, 1, 0, 0, 2, 2), + (7, 7, 512, 2048, 1, 1, 0, 0, 1, 1), + (14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2), + (7, 7, 2048, 512, 1, 1, 0, 0, 1, 1) + ] + + +TARGET_NAME = 'llvm -mcpu=skylake-avx512' +NUM_VEC_LANES = 16 +CTX = tvm.context(TARGET_NAME, 0) + +def get_shape(im_height, im_width, in_filter, out_filter, k_h, k_w, hpad, wpad, + hstride, wstride, out_dtype): + """ + Finds out the shape of all data structures + """ + ## Find shapes + data_shape = (1, in_filter//NUM_VEC_LANES, im_height, im_width, NUM_VEC_LANES) + + if out_dtype == 'int32': + if k_h != 1: + kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w, + NUM_VEC_LANES//4, NUM_VEC_LANES, 4) + else: + kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, NUM_VEC_LANES//4, + NUM_VEC_LANES, 4, k_h, k_w) + elif out_dtype == 'float32': + if k_h != 1: + kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w, + NUM_VEC_LANES, NUM_VEC_LANES) + else: + kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, NUM_VEC_LANES, + NUM_VEC_LANES, k_h, k_w) + out_height = (im_height + 2 * hpad - k_h) // hstride + 1 + out_width = (im_width + 2 * wpad - k_w) // wstride + 1 + o_shape = (1, out_filter//NUM_VEC_LANES, out_height, out_width, NUM_VEC_LANES) + return (data_shape, kernel_shape, o_shape) + + + +def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_filter, + out_filter, k_h, k_w, hpad, wpad, hstride, wstride): + """ + Runs the inference and checks the functional correctness between + compute and schedule outputs + """ + (data_shape, kernel_shape, o_shape) = get_shape(im_height, im_width, in_filter, + out_filter, k_h, k_w, hpad, wpad, + hstride, wstride, out_dtype) + + # Create TVM placeholders + data = tvm.placeholder(data_shape, name='data', dtype=data_dtype) + kernel = tvm.placeholder(kernel_shape, name='kernel', dtype=kernel_dtype) + + # Create the numpy arrays to be used for executing conv models + if data_dtype == 'float32': + data_array = tvm.nd.array(np.random.rand(*data_shape).astype(dtype=data_dtype), CTX) + kernel_array = tvm.nd.array(np.random.rand(*kernel_shape).astype(dtype=kernel_dtype), CTX) + else: + data_array = tvm.nd.array(np.random.randint(100, size=data_shape).astype(data_dtype)) + kernel_array = tvm.nd.array(np.random.randint(100, 
size=kernel_shape).astype(kernel_dtype)) + + # c_orig will be used for declaration ouptut + # c_sch will be used for scheduled computation output + c_orig = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX) + c_sch = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX) + + + with tvm.target.create(TARGET_NAME): + conv = topi.nn.conv2d_NCHWc(data, kernel, num_filter=out_filter, + kernel_size=(k_h, k_w), stride=hstride, + padding=hpad, layout='NCHWc', + out_layout='NCHWc', out_dtype=out_dtype) + out = topi.nn.relu(conv) + sch = tvm.create_schedule(out.op) + func = tvm.build(sch, [data, kernel, out], target=TARGET_NAME, name='out') + func(data_array, kernel_array, c_orig) + LOGGER.debug(tvm.lower(sch, [data, kernel], simple_mode=True)) + + # Generate and run the optimized schedule + sconv = topi.generic.nn.schedule_conv2d_NCHWc(num_filter=out_filter, + kernel_size=(k_h, k_w), + strides=hstride, + padding=hpad, + layout='NCHWc', + out_layout='NCHWc', + outs=[out]) + func = tvm.build(sconv, [data, kernel, out], target=TARGET_NAME, name='conv') + func(data_array, kernel_array, c_sch) + + # Functional check + if data_dtype == 'uint8': + np.testing.assert_equal(c_orig.asnumpy(), c_sch.asnumpy()) + else: + assert np.allclose(c_orig.asnumpy(), c_sch.asnumpy()) + + evaluator = func.time_evaluator(func.entry_name, CTX, number=1000) + LOGGER.debug(tvm.lower(sconv, [data, kernel], simple_mode=True)) + return evaluator(data_array, kernel_array, c_sch).mean + +if __name__ == "__main__": + LOGGER.info("Workload, Kernel_size, FP32_time, INT8_time, Speedup") + SPEEDUP_ARRAY = [] + for i, wkl in enumerate(WORKLOADS): + fp32_time = run_inference('float32', 'float32', 'float32', *wkl) + int8_time = run_inference('uint8', 'int8', 'int32', *wkl) + kernel_h = wkl[4] + kernel_w = wkl[5] + LOGGER.info("Workload#" + str(i) + ", " + str(kernel_h) + "x" + str(kernel_w) + ", " + + str(fp32_time) + ", " + str(int8_time) + ", " + str(fp32_time/int8_time)) + + SPEEDUP_ARRAY.append(fp32_time/int8_time) + LOGGER.info("Average speedup --> %s" % str(sum(SPEEDUP_ARRAY)/float(len(SPEEDUP_ARRAY))))
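
Note (illustrative, not part of the patch): a single workload from the table can also be checked in isolation on a Skylake machine; the call below simply mirrors one iteration of the __main__ loop and assumes the script's directory is on the import path:

    # FP32 vs. uint8 x int8 for one ResNet workload: 56x56, 64->64, 3x3, pad 1, stride 1.
    from test_conv_int8_intel import run_inference

    fp32_time = run_inference('float32', 'float32', 'float32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)
    int8_time = run_inference('uint8', 'int8', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)
    print("speedup:", fp32_time / int8_time)
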