Commit 3ada7c0e by Animesh Jain Committed by Tianqi Chen

Checking the correct dtypes for choosing the Intel int8 instructions. (#3516)

parent 9e6a8c0d
......@@ -517,6 +517,65 @@ def test_upsampling():
_test_upsampling("NHWC", "BILINEAR")
def test_conv2d_int8_intrinsics():
def _compile(input_dtype, weight_dtype, output_dtype, target):
n, ic, h, w, oc, ch, cw = 1, 16, 224, 224, 32, 3, 3
x = relay.var("x", relay.TensorType((n, ic, h, w), input_dtype))
w = relay.var("w", relay.TensorType((oc, ic, ch, cw), weight_dtype))
y = relay.nn.conv2d(x, w,
kernel_size=(ch, cw),
channels=oc,
padding=(1, 1),
dilation=(1, 1),
out_dtype=output_dtype)
func = relay.Function([x, w], y)
wdata = np.random.rand(oc, ic, ch, cw) * 10
parameters = {"w": tvm.nd.array(wdata.astype(weight_dtype))}
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=parameters)
assembly = lib.get_source("asm")
return assembly
# compile conv2d for x86 (skylake) and test assembly contains *pmadd* instructions
target = "llvm -mcpu=skylake-avx512"
name = "llvm.x86.avx512.pmaddubs.w.512"
llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name)
if llvm_id != 0:
# Intel Int8 instruction need uint8 data and int8 kernel
asm = _compile(input_dtype="uint8",
weight_dtype="int8",
output_dtype="int32",
target=target)
# Check that intrinisic is present in the assembly.
assert "pmaddubs" in asm
# Ensure that code is generated when datatypes are not HW supported.
asm = _compile(input_dtype="int8",
weight_dtype="int8",
output_dtype="int32",
target=target)
# Check that intrinisic is not present in the assembly.
assert "pmaddubs" not in asm
# Ensure that code is generated when datatypes are not HW supported.
asm = _compile(input_dtype="uint8",
weight_dtype="uint8",
output_dtype="int32",
target=target)
# Check that intrinisic is not present in the assembly.
assert "pmaddubs" not in asm
# Check that a vectorized instruction is generated for older Intel
# generations, because we default to NCHWc layout.
target = "llvm -mcpu=core-avx2"
asm = _compile(input_dtype="int8",
weight_dtype="int8",
output_dtype="int32",
target=target)
# Check that vector int mult and add instructions are generated.
assert "vpmulld" in asm and "vpadd" in asm
if __name__ == "__main__":
test_pool2d()
test_avg_pool2d_no_count_pad()
......@@ -532,3 +591,4 @@ if __name__ == "__main__":
test_conv2d_run()
test_batch_flatten()
test_upsampling()
test_conv2d_int8_intrinsics()
......@@ -391,10 +391,7 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
in_channel = ic_chunk * ic_bn
if data.dtype == 'uint8':
oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape)
else:
oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
num_filter = oc_chunk * oc_bn
# output shape
......@@ -413,26 +410,6 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
kh = tvm.reduce_axis((0, kernel_height), name='kh')
kw = tvm.reduce_axis((0, kernel_width), name='kw')
if data.dtype == 'uint8':
assert out_dtype == "int32", \
"INT8 convolution requires input dtype = uint8 and output dtype=int32"
# Intel performs dot product of 2 "4" Int8 values
# Current implementation requires ic_bn to be a multiple of 4
n_elems = 4
assert ic_bn % n_elems == 0
ic_outer = tvm.reduce_axis((0, in_channel//ic_bn), name='ic_outer')
ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh, ow*WSTR+kw,
ic_f_inner * n_elems + ic_s_inner]
.astype(out_dtype) *
kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner,
oc_block, ic_s_inner].astype(out_dtype),
axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
# else: fp implementation
return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh, ow*WSTR+kw,
ic%ic_bn].astype(out_dtype) *
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,invalid-name,unused-argument
"""Checks different x86 targets for target specific schedules"""
def check_skylake(target):
"""
Checks if the target is skylake
"""
for opt in target.options:
if opt == '-mcpu=skylake-avx512':
return True
return False
......@@ -37,6 +37,29 @@ from . import conv2d_avx_1x1, conv2d_avx_common
logger = logging.getLogger('topi')
def _is_int8_hw_support(data_dtype, kernel_dtype, target):
"""
Checks to ensure that we can use Intel DLBoost instructions
1) The datatypes are correct.
2) LLVM version has support for the instructions.
3) Target is skylake and above.
"""
# 1) Check datatypes
is_dtype_support = data_dtype == 'uint8' and kernel_dtype == 'int8'
# 2) Check LLVM support
llvm_intrin_fast_int8 = "llvm.x86.avx512.pmaddubs.w.512"
llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(llvm_intrin_fast_int8)
is_llvm_support = llvm_id != 0
# 3) Check target
is_target_support = False
for opt in target.options:
if opt == '-mcpu=skylake-avx512':
is_target_support = True
return is_dtype_support and is_llvm_support and is_target_support
def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False,
layout='NCHW'):
"""
......@@ -68,7 +91,8 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
kh, kw, oc, _ = kshape
elif pat.match(layout) is not None:
n, ic_chunk, h, w, ic_bn = dshape
if data.dtype == 'uint8':
target = tvm.target.current_target(allow_none=False)
if _is_int8_hw_support(data.dtype, kernel.dtype, target):
oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape
ic = ic_chunk*ic_bn
assert ic == k_ic*k_ic_f*kic_s
......@@ -276,7 +300,6 @@ def schedule_conv2d_nhwc_pack(cfg, outs):
args = [s, cfg, data_vec, conv_out, outs[0]]
if data.dtype == 'uint8':
# int8 conv kernel is 7-dim
kh, kw, _, _, _ = get_const_tuple(kernel.shape)
if kh == 1 and kw == 1:
conv2d_avx_1x1._schedule_conv_nhwc_pack_int8(*args)
......@@ -453,19 +476,42 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
new_attrs['out_layout'], out_dtype], depthwise_conv2d_NCHWc)
dispatch_ctx.update(target, new_workload, cfg)
else:
out_channel, _, kh, kw = get_const_tuple(kernel.shape)
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
# Store altered operator's config
new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn),
dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
dispatch_ctx.update(target, new_workload, cfg)
if _is_int8_hw_support(data.dtype, kernel.dtype, target):
# Convert kernel data layout from 4D to 7D
n_elems = 4
out_channel, _, kh, kw = get_const_tuple(kernel.shape)
data_expr, kernel_expr = inputs
kernel_IHWO = F.transpose(kernel_expr, axes=(1, 2, 3, 0))
kernel_IHWOo = F.reshape(kernel_IHWO, (in_channel, kh, kw, out_channel//oc_bn, oc_bn))
kernel_OHWoI = F.transpose(kernel_IHWOo, axes=(3, 1, 2, 4, 0))
kernel_OHWoIi = F.reshape(kernel_OHWoI, (out_channel//oc_bn, kh, kw, oc_bn,
in_channel//ic_bn, ic_bn))
kernel_OHWoIie = F.reshape(kernel_OHWoIi, (out_channel//oc_bn, kh, kw, oc_bn,
in_channel//ic_bn, ic_bn//n_elems, n_elems))
kernel_OIHWioe = F.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6))
copy_inputs = [data_expr, kernel_OIHWioe]
# Store altered operator's config
new_kernel = tvm.placeholder((out_channel//oc_bn, kh, kw, oc_bn,
in_channel//ic_bn, ic_bn//n_elems,
n_elems))
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation,
new_attrs[layout_name], new_attrs['out_layout'], out_dtype],
conv2d_NCHWc)
dispatch_ctx.update(target, new_workload, cfg)
else:
out_channel, _, kh, kw = get_const_tuple(kernel.shape)
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
# Store altered operator's config
new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn,
kh, kw, ic_bn, oc_bn), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
dispatch_ctx.update(target, new_workload, cfg)
if is_depthwise:
if F.__name__ == 'nnvm.symbol':
......@@ -505,7 +551,8 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
in_channel = ic_chunk * ic_bn
if data.dtype == 'uint8':
target = tvm.target.current_target(allow_none=False)
if _is_int8_hw_support(data.dtype, kernel.dtype, target):
oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \
get_const_tuple(kernel.shape)
else:
......@@ -539,7 +586,7 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
kh = tvm.reduce_axis((0, kernel_height), name='kh')
kw = tvm.reduce_axis((0, kernel_width), name='kw')
if data.dtype == 'uint8' and groups == 1:
if _is_int8_hw_support(data.dtype, kernel.dtype, target) and groups == 1:
assert out_dtype == "int32", \
"INT8 convolution requires input dtype = uint8 and output dtype=int32"
# Intel performs dot product of 2 "4" Int8 values
......@@ -559,7 +606,8 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
oc_block, ic_s_inner].astype(out_dtype),
axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
if data.dtype == 'uint8':
if _is_int8_hw_support(data.dtype, kernel.dtype, target):
# for int8 group conv support
n_elems = 4
ic_chunk = in_channel//ic_bn
......@@ -615,7 +663,8 @@ def _schedule_conv2d_NCHWc(cfg, outs):
data = data_pad.op.input_tensors[0]
args = [s, cfg, data_vec, conv_out, outs[0]]
if data.dtype == 'uint8':
target = tvm.target.current_target(allow_none=False)
if _is_int8_hw_support(data.dtype, kernel.dtype, target):
# int8 conv kernel is 7-dim
_, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape)
if kh == 1 and kw == 1:
......
......@@ -24,7 +24,6 @@ from ..nn.pad import pad
from ..nn.util import infer_pad, get_pad_tuple
from ..util import get_const_tuple, simplify
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .check_targets import check_skylake
from .util import get_fp32_len
def _fallback_schedule(cfg, wkl):
......@@ -187,13 +186,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
More details - https://software.intel.com/en-us/articles/
lower-numerical-precision-deep-learning-inference-and-training
"""
target = tvm.target.current_target(allow_none=False)
int32_lanes = -1
if check_skylake(target):
int32_lanes = 16
else:
return s
assert int32_lanes != -1
int32_lanes = 16
oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1]
_, _, _, _, ic_bn = get_const_tuple(data.shape)
......@@ -310,13 +303,11 @@ def _schedule_conv_nhwc_pack_int8(s, cfg, data, conv_out, last):
packing of weight to make the address access be friendly to int8
intrinsic
"""
target = tvm.target.current_target(allow_none=False)
int32_lanes = -1
if check_skylake(target):
int32_lanes = 16
else:
return s
assert int32_lanes != -1
# FIXME - https://github.com/dmlc/tvm/issues/3598
# pylint: disable=unreachable
return s
int32_lanes = 16
# assertion to fail the unhandled case
_, _, _, ic_num = get_const_tuple(data.shape)
......
......@@ -23,7 +23,6 @@ from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from ..nn.util import infer_pad
from ..util import get_const_tuple
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .check_targets import check_skylake
from .util import get_fp32_len
def _fallback_schedule(cfg, wkl):
......@@ -186,19 +185,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
More details - https://software.intel.com/en-us/articles/
lower-numerical-precision-deep-learning-inference-and-training
"""
# Currently INT8 operations are supported for only Skylake
# In future the _intrin_reduce4int8 will be updated for VNNI instructions
# In case of unsupported target, the schedule will go to the original
# compute
target = tvm.target.current_target(allow_none=False)
int32_lanes = -1
if check_skylake(target):
int32_lanes = 16
else:
return s
assert int32_lanes != -1
int32_lanes = 16
reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val
_, _, _, _, ic_bn = get_const_tuple(data.shape)
......
......@@ -24,6 +24,7 @@ import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from nose.tools import nottest
from common import get_all_backend
......@@ -97,15 +98,22 @@ def verify_group_conv2d_NCHWc_int8(batch, in_channel, groups, in_size, num_filte
func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-3)
# for device in ["llvm -mcpu=skylake-avx512"]:
for device in ["llvm"]:
# for device in ["llvm"]:
for device in ["llvm -mcpu=skylake-avx512"]:
with autotvm.tophub.context(device): # load tophub pre-tuned parameters
check_device(device)
@nottest
def test_conv2d_NCHWc():
# ResNet50 workloads
verify_group_conv2d_NCHWc_int8(1, 256, 32, 224, 64, 7, 2, 3)
if __name__ == "__main__":
test_conv2d_NCHWc()
# The test requires Skylake and newer Intel machines to generate the correct
# instruction. This test directly calls the topi operator, requiring correct
# kernel shape. For older generation of Intel machines, the kernel needs to
# be 6D. This test tests 7D kernel, that can only work on Skylake+ machines.
# So, disabling the test.
# test_conv2d_NCHWc()
pass
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment