Commit 72ad9a38 by Animesh Jain Committed by Yizhi Liu

INT8 conv operator implementation with NCHWc data layout for Intel machines (#1680)

* Int8 implementation for convolution operator on Intel Skylake

* Int8 implementation for convolution operator on Intel Skylake

* PR changes

* PR changes

* PR changes

* Fixing an error

* Fixing an error

* Minor typos fix

* Minor typos fix

* Removing the broadcast16 CPP code. Using astype feature instead

* Replacing constant by variable name num_elements_intel

* Name fixes and tensorize update rule updated

* Fixing the bug about checking skylake

* Replacing bitcast with reinterpret

* Isolating INT8 and FP32 schedules to ease out future AutoTVM PR merge

* Putting check_skylake function in the x86 directory

* Added documentation and organizing files to better locations

* Tensor intrin renaming. Avoid code duplication for intrin by kernel reshaping
parent bde53033
......@@ -79,12 +79,27 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
assert data.dtype == kernel.dtype, \
assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
"Do not support inputs with different data types now. ' \
'{} vs. {}".format(data.dtype, kernel.dtype)
return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
def _get_workload_int8(data, kernel, stride, padding, out_dtype):
""" Get the workload structure. """
_, CI, IH, IW = [x.value for x in data.shape]
CO, _, KH, KW = [x.value for x in kernel.shape]
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
"Do not support inputs with different data types now. ' \
'{} vs. {}".format(data.dtype, kernel.dtype)
return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
@tvm.target.generic_func
def _get_alter_layout_schedule(wkl):
# pylint: disable=unreachable
......@@ -118,6 +133,17 @@ def _get_schedule_NCHWc(wkl, layout, out_layout):
return wkl
@tvm.target.generic_func
def _get_schedule_NCHWc_int8(wkl, layout, out_layout):
# pylint: disable=unreachable
""" Get the platform specific schedule. """
target = tvm.target.current_target()
raise RuntimeError(
"No schedule for current target:{}".format(target))
# This return has no use, merely to supress pylint warning
return wkl
def conv2d_nchw(Input, Filter, stride, padding, out_dtype=None):
"""Convolution operator in NCHW layout.
......
# pylint: disable=invalid-name,unused-variable,invalid-name,unused-argument
"""Checks different x86 targets for target specific schedules"""
def check_skylake(target):
"""
Checks if the target is skylake
"""
for opt in target.options:
if opt == '-mcpu=skylake-avx512':
return True
return False
......@@ -5,12 +5,13 @@ from .. import generic, tag
from .. import nn
from ..nn.util import infer_pad, infer_stride
from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, \
_get_workload, _get_schedule, _get_schedule_NCHWc, \
_get_alter_layout_schedule, Workload
_get_workload, _get_workload_int8, _get_schedule, _get_schedule_NCHWc, \
_get_schedule_NCHWc_int8, _get_alter_layout_schedule, Workload
from . import conv2d_avx_1x1, conv2d_avx_common
from .conv2d_avx_common import AVXConvCommonFwd
from .conv2d_avx_1x1 import AVXConv1x1Fwd
from .check_targets import check_skylake
@_get_schedule.register("cpu")
def _get_schedule_conv(wkl):
......@@ -100,10 +101,95 @@ def _get_schedule_conv(wkl):
sch = _SCHEDULES_AVX[idx]
return sch
def _get_schedule_conv_int8(wkl):
_WORKLOADS_AVX = [
## Following are for INT8 kernels
Workload('uint8', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
Workload('uint8', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
Workload('uint8', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
Workload('uint8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
Workload('uint8', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
Workload('uint8', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
Workload('uint8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
Workload('uint8', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
Workload('uint8', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
Workload('uint8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
Workload('uint8', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
# workloads of resnet34_v1 on imagenet, no extra workload required
# workloads of resnet50_v1 on imagenet
Workload('uint8', 'int32', 56, 56, 64, 256, 1, 1, 0, 0, 1, 1),
Workload('uint8', 'int32', 56, 56, 256, 64, 1, 1, 0, 0, 1, 1),
Workload('uint8', 'int32', 56, 56, 256, 128, 1, 1, 0, 0, 2, 2),
Workload('uint8', 'int32', 28, 28, 128, 512, 1, 1, 0, 0, 1, 1),
Workload('uint8', 'int32', 56, 56, 256, 512, 1, 1, 0, 0, 2, 2),
Workload('uint8', 'int32', 28, 28, 512, 128, 1, 1, 0, 0, 1, 1),
Workload('uint8', 'int32', 28, 28, 512, 256, 1, 1, 0, 0, 2, 2),
Workload('uint8', 'int32', 14, 14, 256, 1024, 1, 1, 0, 0, 1, 1),
Workload('uint8', 'int32', 28, 28, 512, 1024, 1, 1, 0, 0, 2, 2),
Workload('uint8', 'int32', 14, 14, 1024, 256, 1, 1, 0, 0, 1, 1),
Workload('uint8', 'int32', 14, 14, 1024, 512, 1, 1, 0, 0, 2, 2),
Workload('uint8', 'int32', 7, 7, 512, 2048, 1, 1, 0, 0, 1, 1),
Workload('uint8', 'int32', 14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2),
Workload('uint8', 'int32', 7, 7, 2048, 512, 1, 1, 0, 0, 1, 1),
]
fp32_vec_len = 8
target = tvm.target.current_target(allow_none=False)
if check_skylake(target):
fp32_vec_len = 16
_SCHEDULES_AVX = [
# Following are for INT8 operations
# workloads of resnet18_v1 on imagenet
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 28),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 28),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 14, False),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 14, True),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 7, True),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 7),
AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 7, True),
# workloads of resnet34_v1 on imagenet, no extra workload required
# workloads of resnet50_v1 on imagenet
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
# workloads of resnet101_v1 on imagenet, no extra workload required
# workloads of resnet152_v1 on imagenet, no extra workload required
# workloads of resnet18_v2 on imagenet, no extra workload required
# workloads of resnet34_v2 on imagenet, no extra workload required
]
if wkl not in _WORKLOADS_AVX:
if wkl.hkernel == 1 and wkl.wkernel == 1:
return conv2d_avx_1x1._get_default_schedule(wkl, fp32_vec_len)
return conv2d_avx_common._get_default_schedule(wkl, fp32_vec_len)
idx = _WORKLOADS_AVX.index(wkl)
sch = _SCHEDULES_AVX[idx]
return sch
@_get_schedule_NCHWc.register("cpu")
def _get_schedule_NCHWc_x86(wkl, layout, out_layout):
return _get_schedule_conv(wkl)
@_get_schedule_NCHWc_int8.register("cpu")
def _get_schedule_NCHWc_x86_int8(wkl, layout, out_layout):
return _get_schedule_conv_int8(wkl)
@_get_alter_layout_schedule.register("cpu")
def _get_alter_layout_schedule_x86(wkl):
return _get_schedule_conv(wkl)
......@@ -162,6 +248,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
@conv2d_NCHWc.register("cpu")
def _declaration_conv_NCHWc(data, kernel, num_filter, kernel_size, stride,
padding, layout, out_layout, out_dtype):
......@@ -169,11 +256,27 @@ def _declaration_conv_NCHWc(data, kernel, num_filter, kernel_size, stride,
AVXConvCommonFwd: conv2d_avx_common._declaration_conv_NCHWc,
AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv_NCHWc
}
# Use int8 schedules if the input data is of int8 dtype
if data.dtype == 'uint8':
_AVX_SCH_TO_DECL_FUNC = {
AVXConvCommonFwd: conv2d_avx_common._declaration_conv_NCHWc_int8,
AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv_NCHWc_int8
}
n, ic_chunk, h, w, ic_block = [x.value for x in data.shape]
ic = ic_chunk * ic_block
kh, kw = kernel_size
wkl = _get_workload(tvm.placeholder((n, ic, h, w), dtype=out_dtype),
tvm.placeholder((num_filter, ic, kh, kw), dtype=out_dtype),
if data.dtype == 'uint8':
wkl = _get_workload_int8(tvm.placeholder((n, ic, h, w), dtype=data.dtype),
tvm.placeholder((num_filter, ic, kh, kw),
dtype=kernel.dtype),
stride, padding, out_dtype)
sch = _get_schedule_NCHWc_int8(wkl, layout, out_layout)
else:
wkl = _get_workload(tvm.placeholder((n, ic, h, w), dtype=data.dtype),
tvm.placeholder((num_filter, ic, kh, kw),
dtype=kernel.dtype),
stride, padding, out_dtype)
sch = _get_schedule_NCHWc(wkl, layout, out_layout)
return _AVX_SCH_TO_DECL_FUNC[type(sch)](wkl, sch, data, kernel)
......@@ -289,10 +392,6 @@ def schedule_conv2d_nhwc(outs):
def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding,
layout, out_layout, outs):
"""Create schedule for tensors"""
_AVX_SCH_TO_SCH_FUNC = {
AVXConvCommonFwd: conv2d_avx_common._schedule_conv_NCHWc,
AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc
}
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
......@@ -317,13 +416,31 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding,
data_pad = data
data = data_pad.op.input_tensors[0]
_AVX_SCH_TO_SCH_FUNC = {
AVXConvCommonFwd: conv2d_avx_common._schedule_conv_NCHWc,
AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc
}
# Use int8 schedules if the input data is of int8 dtype
if data.dtype == 'uint8':
_AVX_SCH_TO_SCH_FUNC = {
AVXConvCommonFwd: conv2d_avx_common._schedule_conv_NCHWc_int8,
AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc_int8
}
n, ic_chunk, h, w, ic_block = [x.value for x in data.shape]
ic = ic_chunk * ic_block
original_data = tvm.placeholder((n, ic, h, w), dtype=conv_out.dtype)
original_data = tvm.placeholder((n, ic, h, w), dtype=data.dtype)
kh, kw = kernel_size
original_kernel = tvm.placeholder((num_filter, ic, kh, kw), dtype=conv_out.dtype)
original_kernel = tvm.placeholder((num_filter, ic, kh, kw),
dtype=kernel.dtype)
if data.dtype == 'uint8':
wkl = _get_workload_int8(original_data, original_kernel,
stride, padding, conv_out.dtype)
sch = _get_schedule_NCHWc_int8(wkl, layout, out_layout)
else:
wkl = _get_workload(original_data, original_kernel, stride, padding, conv_out.dtype)
sch = _get_schedule_NCHWc(wkl, layout, out_layout)
_AVX_SCH_TO_SCH_FUNC[type(sch)](s, wkl, sch, data_vec,
......
......@@ -3,11 +3,14 @@
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
import topi
from ..util import get_const_tuple
from ..nn.conv2d import _get_schedule, _get_workload
from ..nn.util import infer_pad, infer_stride
from ..nn.pad import pad
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .check_targets import check_skylake
AVXConv1x1Fwd = namedtuple('AVXConv1x1Fwd', ['ic_bn', 'oc_bn', 'oh_factor', 'ow_factor'])
......@@ -229,3 +232,117 @@ def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last):
s[O].parallel(parallel_axis)
return s
def _declaration_conv_NCHWc_int8(wkl, sch, data, kernel):
""" Declaration for int8 conv"""
out_dtype = wkl.out_dtype
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
batch_size = data.shape[0]
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
DOPAD = (HPAD != 0 or WPAD != 0)
if DOPAD:
data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
else:
data_pad = data
oshape = (batch_size, wkl.out_filter//sch.oc_bn, out_height, out_width, sch.oc_bn)
# Intel performs dot product of 2 "4" Int8 values
n_elems = 4
assert sch.ic_bn%n_elems == 0
ic_outer = tvm.reduce_axis((0, wkl.in_filter//(sch.ic_bn)), name='ic_outer')
ic_f_inner = tvm.reduce_axis((0, sch.ic_bn//n_elems), name='ic_f_inner')
ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
# Reshaping kernel as the last 2 dimensions are 1x1 (k_h x k_w)
k_shape = kernel.shape
kernel = topi.reshape(kernel, (k_shape[0], k_shape[1], k_shape[2], k_shape[3],
k_shape[4] * k_shape[5] * k_shape[6]))
conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n, ic_outer, oh*HSTR, ow*WSTR,
ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) *
kernel[oc_chunk, ic_outer, ic_f_inner,
oc_block, ic_s_inner].astype(out_dtype),
axis=[ic_outer, ic_f_inner, ic_s_inner]),
name='conv2d_NCHWc_int8',
tag="conv2d_NCHWc_int8")
return conv
def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
"""
Defines the schedule for INT8 for intel machines
Uses the Intel intrinsics to use INT8 operations
More details - https://software.intel.com/en-us/articles/
lower-numerical-precision-deep-learning-inference-and-training
"""
target = tvm.target.current_target(allow_none=False)
int32_lanes = -1
if check_skylake(target):
int32_lanes = 16
else:
return s
assert int32_lanes != -1
# schedule data
A = data
if isinstance(s[A].op, tvm.tensor.ComputeOp):
batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
parallel_axis = s[A].fuse(ic_chunk, ih)
s[A].parallel(parallel_axis)
C, O = conv_out, last
CC = s.cache_write(C, 'global')
batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
oh_outer, oh_inner = s[C].split(oh, factor=sch.oh_factor)
ow_outer, ow_inner = s[C].split(ow, factor=sch.ow_factor)
s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
s[C].vectorize(oc_block)
parallel_axis = s[C].fuse(oc_chunk, oh_outer)
s[CC].compute_at(s[C], parallel_axis)
if C == O:
s[C].parallel(parallel_axis)
_, oc_chunk, oh, ow, oc_block = s[CC].op.axis
ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
# Skylake and future processors have 16 vector lanes
assert sch.oc_bn % int32_lanes == 0
oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
oh_outer, oh_inner = s[CC].split(oh, factor=sch.oh_factor)
ow_outer, ow_inner = s[CC].split(ow, factor=sch.ow_factor)
s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_outer, ic_f_inner, oh_inner,
ow_inner, oc_f_inner, oc_s_inner, ic_s_inner)
s[CC].fuse(oc_chunk, oh_outer)
pc = dot_16x1x16_int8_int8_int32()
s[CC].tensorize(oc_s_inner, pc)
s[CC].unroll(ow_inner)
s[CC].unroll(oh_inner)
if C != O:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
oh_outer, oh_inner = s[O].split(oh, factor=sch.oh_factor)
ow_outer, ow_inner = s[O].split(ow, factor=sch.ow_factor)
s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
parallel_axis = s[O].fuse(oc_chunk, oh_outer)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
return s
......@@ -8,6 +8,8 @@ from ..util import get_const_tuple
from ..nn.conv2d import _get_schedule, _get_workload
from ..nn.util import infer_pad, infer_stride
from ..nn.pad import pad
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .check_targets import check_skylake
AVXConvCommonFwd = namedtuple('AVXConvCommonFwd', ['ic_bn', 'oc_bn', 'reg_n', 'unroll_kw'])
......@@ -252,3 +254,124 @@ def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last):
s[O].parallel(parallel_axis)
return s
def _declaration_conv_NCHWc_int8(wkl, sch, data, kernel):
"""
This function sets up the compute for INT8 conv 2d
Inputs are in INT8 datatype
Output is in INT32 datatype
"""
out_dtype = wkl.out_dtype
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
batch_size = data.shape[0]
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
# pack data
DOPAD = (HPAD != 0 or WPAD != 0)
if DOPAD:
data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
else:
data_pad = data
# convolution
oshape = (batch_size, wkl.out_filter//sch.oc_bn, out_height, out_width, sch.oc_bn)
kh = tvm.reduce_axis((0, wkl.hkernel), name='kh')
kw = tvm.reduce_axis((0, wkl.wkernel), name='kw')
# Intel performs dot product of 2 "4" Int8 values
# Current implementation requires ic_bn to be a multiple of 4
n_elems = 4
assert sch.ic_bn%n_elems == 0
ic_outer = tvm.reduce_axis((0, wkl.in_filter//(sch.ic_bn)), name='ic_outer')
ic_f_inner = tvm.reduce_axis((0, sch.ic_bn//n_elems), name='ic_f_inner')
ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh, ow*WSTR+kw,
ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) *
kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner,
oc_block, ic_s_inner].astype(out_dtype),
axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
name='conv2d_NCHWc_int8',
tag="conv2d_NCHWc_int8")
return conv
def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last):
"""
Defines the schedule for INT8 for intel machines
Uses the Intel intrinsics to use INT8 operations
More details - https://software.intel.com/en-us/articles/
lower-numerical-precision-deep-learning-inference-and-training
"""
# Currently INT8 operations are supported for only Skylake
# In future the _intrin_reduce4int8 will be updated for VNNI instructions
# In case of unsupported target, the schedule will go to the original
# compute
target = tvm.target.current_target(allow_none=False)
int32_lanes = -1
if check_skylake(target):
int32_lanes = 16
else:
return s
assert int32_lanes != -1
A = data
if isinstance(s[A].op, tvm.tensor.ComputeOp):
batch, ic_chunk, ih, iw, _ = s[A].op.axis
parallel_axis = s[A].fuse(ic_chunk, ih)
s[A].parallel(parallel_axis)
# schedule 5-D NCHW[x]c conv
C, O = conv_out, last
CC = s.cache_write(C, 'global')
_, oc_chunk, oh, ow, oc_block = s[C].op.axis
ow_chunk, ow_block = s[C].split(ow, factor=sch.reg_n)
s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[C].fuse(oc_chunk, oh)
s[C].vectorize(oc_block)
if C == O:
s[C].parallel(parallel_axis)
s[CC].compute_at(s[C], ow_chunk)
_, oc_chunk, oh, ow, oc_block = s[CC].op.axis
kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
ow_chunk, ow_block = s[CC].split(ow, factor=sch.reg_n)
# Skylake and future processors have 16 vector lanes
assert sch.oc_bn % int32_lanes == 0
oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
if sch.unroll_kw:
s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, ic_f_inner, kw,
ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
s[CC].unroll(kw)
else:
s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, kw, ic_f_inner,
ow_block, oc_f_inner, oc_s_inner, ic_s_inner)
pc = dot_16x1x16_int8_int8_int32()
s[CC].tensorize(oc_s_inner, pc)
s[CC].unroll(ow_block)
s[CC].unroll(oc_f_inner)
if C != O:
batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
ow_chunk, ow_block = s[O].split(ow, factor=sch.reg_n)
s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[O].fuse(oc_chunk, oh)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
return s
"""Core kernel of dot product of 4 Int8 operations"""
#pylint: disable=invalid-name
import tvm
def dot_16x1x16_int8_int8_int32():
"""
Int8 dot product by every 4 elements using AVX2 Skylake instructions.
This function takes two arrays of int8 datatype -- data[4] and
kernel[16][4] -- and computes a dot product of data[4] with every
4 elements of kernels, resulting in output[16] of int32 datatype.
The pseudo code is as follows.
.. code-block:: c
void dot_16x1x16_int8_int8_int32(int8 data[4], int8 kernel[16][4],
int32 output[16]){
for (int i = 0; i < 16; i++){
out[i] = 0;
for (int k = 0; k < 4; k++){
out[i] += data[k] * kernel[i][k]
}
}
}
Physically, the kernel array sits in an AVX512 vector register and
the data[4] is broadcasted to another AVX512 vector register. This
function returns a TensorIntrin that can be used to tensorize
a schedule.
Returns
-------
intrin : TensorIntrin
The Skylake int8 TensorIntrin that can be used in tensorizing schedule
"""
int32_lanes = 16 # 16 int32 lanes in AVX512
num_int8_elements = 4 # 4 int8 elements in int32
data = tvm.placeholder((num_int8_elements,), dtype='uint8', name='data')
kernel = tvm.placeholder((int32_lanes, num_int8_elements), dtype='int8', name='kernel')
k = tvm.reduce_axis((0, num_int8_elements), name='k')
C = tvm.compute((int32_lanes,),
lambda i: tvm.sum(data[k].astype('int32') *
kernel[i, k].astype('int32'),
axis=k),
name="C")
a_buffer = tvm.decl_buffer(data.shape, dtype='uint8', name="a_buffer",
offset_factor=1,
strides=[1])
b_buffer = tvm.decl_buffer(kernel.shape, dtype='int8', name="b_buffer",
offset_factor=1,
strides=[tvm.var('ldw'), 1])
def _intrin_func(ins, outs):
def _instr(index):
ib = tvm.ir_builder.create()
if index == 1:
ib.emit(outs[0].vstore(0, tvm.const(0, 'int32x16')))
return ib.get()
a_int8 = ins[0].vload([0], "uint8x4")
re_int32 = tvm.call_pure_intrin('int32', 'reinterpret', a_int8)
vec_ai32 = re_int32.astype('int32x16')
vec_a = tvm.call_pure_intrin('int8x64', 'reinterpret', vec_ai32)
vec_b = ins[1].vload([0, 0], "int8x64")
vec_one = tvm.const(1, "int16x32")
pair_reduction = tvm.call_llvm_intrin('int16x32',
'llvm.x86.avx512.pmaddubs.w.512',
tvm.const(0, 'uint32'),
vec_a, vec_b)
quad_reduction = tvm.call_llvm_intrin('int32x16',
'llvm.x86.avx512.pmaddw.d.512',
tvm.const(0, 'uint32'),
pair_reduction, vec_one)
if index == 0:
ib.emit(outs[0].vstore(0, quad_reduction))
else:
ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], 'int32x16')))
return ib.get()
# body, reset, update
return _instr(0), _instr(1), _instr(2)
with tvm.build_config(offset_factor=1, partition_const_loop=True):
return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer})
#pylint: disable-msg=too-many-arguments, too-many-locals, assignment-from-no-return
""" Conv Int8 functional and performance testing"""
import sys
import logging
import numpy as np
import tvm
import topi
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
LOGGER = logging.getLogger('test_conv_int8_intel')
LOGGER.disabled = False
# All the WORKLOADS from Resnet except first layer
# Workload is ['height', 'width', 'in_filter', 'out_filter',
# 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
WORKLOADS = [(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
(56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
(56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
(28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
(28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
(28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
(14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
(56, 56, 64, 256, 1, 1, 0, 0, 1, 1),
(56, 56, 256, 64, 1, 1, 0, 0, 1, 1),
(56, 56, 256, 128, 1, 1, 0, 0, 2, 2),
(28, 28, 128, 512, 1, 1, 0, 0, 1, 1),
(56, 56, 256, 512, 1, 1, 0, 0, 2, 2),
(28, 28, 512, 128, 1, 1, 0, 0, 1, 1),
(28, 28, 512, 256, 1, 1, 0, 0, 2, 2),
(14, 14, 256, 1024, 1, 1, 0, 0, 1, 1),
(28, 28, 512, 1024, 1, 1, 0, 0, 2, 2),
(14, 14, 1024, 256, 1, 1, 0, 0, 1, 1),
(14, 14, 1024, 512, 1, 1, 0, 0, 2, 2),
(7, 7, 512, 2048, 1, 1, 0, 0, 1, 1),
(14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2),
(7, 7, 2048, 512, 1, 1, 0, 0, 1, 1)
]
TARGET_NAME = 'llvm -mcpu=skylake-avx512'
NUM_VEC_LANES = 16
CTX = tvm.context(TARGET_NAME, 0)
def get_shape(im_height, im_width, in_filter, out_filter, k_h, k_w, hpad, wpad,
hstride, wstride, out_dtype):
"""
Finds out the shape of all data structures
"""
## Find shapes
data_shape = (1, in_filter//NUM_VEC_LANES, im_height, im_width, NUM_VEC_LANES)
if out_dtype == 'int32':
if k_h != 1:
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
NUM_VEC_LANES//4, NUM_VEC_LANES, 4)
else:
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, NUM_VEC_LANES//4,
NUM_VEC_LANES, 4, k_h, k_w)
elif out_dtype == 'float32':
if k_h != 1:
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w,
NUM_VEC_LANES, NUM_VEC_LANES)
else:
kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, NUM_VEC_LANES,
NUM_VEC_LANES, k_h, k_w)
out_height = (im_height + 2 * hpad - k_h) // hstride + 1
out_width = (im_width + 2 * wpad - k_w) // wstride + 1
o_shape = (1, out_filter//NUM_VEC_LANES, out_height, out_width, NUM_VEC_LANES)
return (data_shape, kernel_shape, o_shape)
def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_filter,
out_filter, k_h, k_w, hpad, wpad, hstride, wstride):
"""
Runs the inference and checks the functional correctness between
compute and schedule outputs
"""
(data_shape, kernel_shape, o_shape) = get_shape(im_height, im_width, in_filter,
out_filter, k_h, k_w, hpad, wpad,
hstride, wstride, out_dtype)
# Create TVM placeholders
data = tvm.placeholder(data_shape, name='data', dtype=data_dtype)
kernel = tvm.placeholder(kernel_shape, name='kernel', dtype=kernel_dtype)
# Create the numpy arrays to be used for executing conv models
if data_dtype == 'float32':
data_array = tvm.nd.array(np.random.rand(*data_shape).astype(dtype=data_dtype), CTX)
kernel_array = tvm.nd.array(np.random.rand(*kernel_shape).astype(dtype=kernel_dtype), CTX)
else:
data_array = tvm.nd.array(np.random.randint(100, size=data_shape).astype(data_dtype))
kernel_array = tvm.nd.array(np.random.randint(100, size=kernel_shape).astype(kernel_dtype))
# c_orig will be used for declaration ouptut
# c_sch will be used for scheduled computation output
c_orig = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX)
c_sch = tvm.nd.array(np.zeros(o_shape, dtype=out_dtype), CTX)
with tvm.target.create(TARGET_NAME):
conv = topi.nn.conv2d_NCHWc(data, kernel, num_filter=out_filter,
kernel_size=(k_h, k_w), stride=hstride,
padding=hpad, layout='NCHWc',
out_layout='NCHWc', out_dtype=out_dtype)
out = topi.nn.relu(conv)
sch = tvm.create_schedule(out.op)
func = tvm.build(sch, [data, kernel, out], target=TARGET_NAME, name='out')
func(data_array, kernel_array, c_orig)
LOGGER.debug(tvm.lower(sch, [data, kernel], simple_mode=True))
# Generate and run the optimized schedule
sconv = topi.generic.nn.schedule_conv2d_NCHWc(num_filter=out_filter,
kernel_size=(k_h, k_w),
strides=hstride,
padding=hpad,
layout='NCHWc',
out_layout='NCHWc',
outs=[out])
func = tvm.build(sconv, [data, kernel, out], target=TARGET_NAME, name='conv')
func(data_array, kernel_array, c_sch)
# Functional check
if data_dtype == 'uint8':
np.testing.assert_equal(c_orig.asnumpy(), c_sch.asnumpy())
else:
assert np.allclose(c_orig.asnumpy(), c_sch.asnumpy())
evaluator = func.time_evaluator(func.entry_name, CTX, number=1000)
LOGGER.debug(tvm.lower(sconv, [data, kernel], simple_mode=True))
return evaluator(data_array, kernel_array, c_sch).mean
if __name__ == "__main__":
LOGGER.info("Workload, Kernel_size, FP32_time, INT8_time, Speedup")
SPEEDUP_ARRAY = []
for i, wkl in enumerate(WORKLOADS):
fp32_time = run_inference('float32', 'float32', 'float32', *wkl)
int8_time = run_inference('uint8', 'int8', 'int32', *wkl)
kernel_h = wkl[4]
kernel_w = wkl[5]
LOGGER.info("Workload#" + str(i) + ", " + str(kernel_h) + "x" + str(kernel_w) + ", "
+ str(fp32_time) + ", " + str(int8_time) + ", " + str(fp32_time/int8_time))
SPEEDUP_ARRAY.append(fp32_time/int8_time)
LOGGER.info("Average speedup --> %s" % str(sum(SPEEDUP_ARRAY)/float(len(SPEEDUP_ARRAY))))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment