Commit 493c98d3 by Animesh Jain, committed by Zhi

[TOPI][x86] Legalize - Support int8xint8 convolution to use VNNI instructions. (#4196)

parent decccd6c
@@ -546,9 +546,11 @@ def test_conv2d_int8_intrinsics():
n, h, w, ch, cw = 1, 64, 64, 3, 3
if data_layout == 'NCHW':
x = relay.var("x", relay.TensorType((n, ic, h, w), input_dtype))
data_shape = (n, ic, h, w)
x = relay.var("x", relay.TensorType(data_shape, input_dtype))
elif data_layout == 'NHWC':
x = relay.var("x", relay.TensorType((n, h, w, ic), input_dtype))
data_shape = (n, h, w, ic)
x = relay.var("x", relay.TensorType(data_shape, input_dtype))
else:
raise ValueError('Not supported')
@@ -559,8 +561,8 @@ def test_conv2d_int8_intrinsics():
else:
raise ValueError('Not supported')
w = relay.var("w", relay.TensorType(kernel_shape, weight_dtype))
y = relay.nn.conv2d(x, w,
weight = relay.var("weight", relay.TensorType(kernel_shape, weight_dtype))
y = relay.nn.conv2d(x, weight,
kernel_size=(ch, cw),
channels=oc,
padding=(1, 1),
@@ -568,11 +570,13 @@ def test_conv2d_int8_intrinsics():
data_layout=data_layout,
kernel_layout=kernel_layout,
out_dtype=output_dtype)
func = relay.Function([x, w], y)
func = relay.Function([x, weight], y)
wdata = np.random.rand(*kernel_shape) * 10
parameters = {"w": tvm.nd.array(wdata.astype(weight_dtype))}
parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))}
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=parameters)
assembly = lib.get_source("asm")
return assembly
@@ -589,58 +593,63 @@ def test_conv2d_int8_intrinsics():
llvm_version = tvm.codegen.llvm_version_major()
for target in targets:
if llvm_version >= 8:
fast_int8_dtypes = ('uint8', 'int8', 'int32')
dtypes = ('uint8', 'int8', 'int32')
# Sweep the input channels to check int8 robustness
# Input channels should be a multiple of 4 internally.
for ic in [1, 4, 6]:
asm = _compile(ic=ic, oc=32, target=target, data_layout="NCHW",
asm = _compile(ic=ic, oc=16, target=target, data_layout="NCHW",
kernel_layout='OIHW',
dtypes=fast_int8_dtypes)
dtypes=dtypes)
assert _has_fast_int8_instructions(asm, target)
for ic in [1, 4, 6]:
asm = _compile(ic=ic, oc=32, target=target, data_layout="NHWC",
asm = _compile(ic=ic, oc=16, target=target, data_layout="NHWC",
kernel_layout='HWIO',
dtypes=fast_int8_dtypes)
dtypes=dtypes)
assert _has_fast_int8_instructions(asm, target)
# Sweep the output channels to check int8 robustness
# Output channels should be a multiple of 16 internally.
for oc in [4, 16, 20]:
asm = _compile(ic=16, oc=oc, target=target, data_layout="NCHW",
asm = _compile(ic=8, oc=oc, target=target, data_layout="NCHW",
kernel_layout='OIHW',
dtypes=fast_int8_dtypes)
dtypes=dtypes)
assert _has_fast_int8_instructions(asm, target)
for oc in [4, 16, 20]:
asm = _compile(ic=16, oc=oc, target=target, data_layout="NHWC",
asm = _compile(ic=8, oc=oc, target=target, data_layout="NHWC",
kernel_layout='HWIO',
dtypes=fast_int8_dtypes)
dtypes=dtypes)
assert _has_fast_int8_instructions(asm, target)
# Check that both non-divisible oc and ic work
asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW',
dtypes=fast_int8_dtypes)
dtypes=dtypes)
assert _has_fast_int8_instructions(asm, target)
asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
dtypes=fast_int8_dtypes)
dtypes=dtypes)
assert _has_fast_int8_instructions(asm, target)
# Ensure that code is generated when datatypes are not HW supported.
dtypes = ('int8', 'int8', 'int32')
asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
# Check that int8 x int8 goes through legalization so that fast instructions can be picked up.
for target in targets:
if llvm_version >= 8:
dtypes = (('int8', 'int8', 'int32'))
# Check that both non-divisible oc and ic work
asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW',
dtypes=dtypes)
# Check that the intrinsic is not present in the assembly.
assert not _has_fast_int8_instructions(asm, target)
assert _has_fast_int8_instructions(asm, target)
# Ensure that code is generated when datatypes are not HW supported.
dtypes = ('uint8', 'uint8', 'int32')
asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
dtypes=dtypes)
# Check that the intrinsic is not present in the assembly.
assert not _has_fast_int8_instructions(asm, target)
assert _has_fast_int8_instructions(asm, target)
# Ensure that code is generated when datatypes are not HW supported.
dtypes = ('uint8', 'uint8', 'int32')
asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
dtypes=dtypes)
# Check that the intrinsic is not present in the assembly.
assert not _has_fast_int8_instructions(asm, target)
# Check that a vectorized instruction is generated for older Intel
# generations, because we default to NCHWc layout.
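For context on the assertions in this test: _has_fast_int8_instructions is a helper defined elsewhere in the test file and is not part of this hunk. A minimal sketch of what such a check might look like, assuming it simply searches the generated assembly for the target-specific fast int8 instruction (a pmaddubs-based sequence on Skylake-X, the AVX512-VNNI vpdpbusd instruction on Cascade Lake):

def _has_fast_int8_instructions(asm, target):
    """Sketch: return True if the assembly uses a fast int8 dot-product sequence.

    Assumes Skylake-X lowers int8 conv2d to a pmaddubsw/pmaddwd pair, while
    Cascade Lake (with VNNI) emits vpdpbusd directly.
    """
    if 'skylake-avx512' in target:
        return 'pmaddubs' in asm
    elif 'cascadelake' in target:
        return 'vpdpbusd' in asm
    return False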
@@ -192,24 +192,72 @@ def _conv2d_legalize(attrs, inputs, arg_types):
The legalized expr
"""
# Dilation not supported yet. Return None if dilation is not (1, 1)
dilation = attrs.get_int_tuple("dilation")
if not (dilation[0] == 1 and dilation[1] == 1):
return None
# Collect the input tensors.
data_tensor, kernel_tensor = arg_types[0], arg_types[1]
data_dtype = data_tensor.dtype
kernel_dtype = kernel_tensor.dtype
# Collect the output tensor.
output_tensor = arg_types[2]
# Collect the input exprs.
data, kernel = inputs
# Get the conv attrs
new_attrs = {k: attrs[k] for k in attrs.keys()}
is_int8_inputs = False
# If both the inputs are int8, we can add 128 to make the input dtype uint8, and then adjust
# the output. This helps in picking up Intel VNNI instructions.
# Original --> C = A (conv) B, where A and B are int8.
# Rewrite  --> C = (A' - 128) (conv) B, where A' = A + 128 is uint8.
#          --> C = (A' conv B) - (128 conv B)
# Here, (128 conv B) is simply a reduction of the weights over the CRS (input channel, kernel
# height, kernel width) axes, scaled by 128.
if data_tensor.dtype == 'int8' and kernel_tensor.dtype == 'int8':
is_int8_inputs = True
padding = attrs.get_int_tuple("padding")
if attrs['data_layout'] == 'NHWC' and attrs['kernel_layout'] == 'HWIO':
adjust_shift = relay.sum(relay.cast(kernel, dtype='int32'), axis=(0, 1, 2))
pad_width = ((0, 0), (padding[0], padding[0]), (padding[1], padding[1]), (0, 0))
elif attrs['data_layout'] == 'NCHW' and attrs['kernel_layout'] == 'OIHW':
pad_width = ((0, 0), (0, 0), (padding[0], padding[0]), (padding[1], padding[1]))
adjust_shift = relay.sum(relay.cast(kernel, dtype='int32'), axis=(1, 2, 3))
adjust_shift = relay.expand_dims(adjust_shift, axis=1, num_newaxis=2)
else:
return None
data = relay.cast(data, 'int32')
data = relay.add(data, relay.const(128, 'int32'))
data = relay.cast(data, 'uint8')
# Do the padding outside the conv (relay.nn.pad), as the pad value has to be 128 rather than 0.
if not (padding[0] == 0 and padding[1] == 0):
data = relay.nn.pad(data, pad_width=pad_width, pad_value=128)
new_attrs['padding'] = (0, 0)
# The data type is now shifted to uint8
data_dtype = 'uint8'
# Multiply the weight reduction by 128 to form the adjustment term.
adjust_shift = relay.multiply(adjust_shift, relay.const(128, 'int32'))
# Legalize if the datatypes are suitable for fast int8 instructions. Int8 instructions require
# the input channels to be a multiple of 4 and the output channels to be a multiple of 16. For
# input channels, we pad both the data and the weight along the input-channel axis. For output
# channels, we pad the weight and strided_slice the output back to its original shape.
if _is_int8_hw_support(data_tensor.dtype, kernel_tensor.dtype):
if _is_int8_hw_support(data_dtype, kernel_dtype):
# Flags to remember if the expr is modified
ic_modified = False
oc_modified = False
# Collect the input exprs.
data, kernel = inputs
# Find the value of input and output channel.
in_channel = -1
out_channel = -1
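As an aside on the +128 rewrite implemented in this hunk: the identity can be checked numerically. Below is a small NumPy sketch (hypothetical shapes, NHWC data with a 1x1 HWIO kernel so the convolution reduces to a matmul over the input-channel axis) showing that conv(A, B) == conv(A + 128, B) - 128 * sum(B) over the reduction axes, which is exactly the adjust_shift term computed above:

import numpy as np

# Hypothetical small shapes: NHWC data with a 1x1 HWIO kernel, so the
# convolution is just a matmul over the input-channel axis.
n, h, w, ic, oc = 1, 4, 4, 3, 8
a = np.random.randint(-128, 128, size=(n, h, w, ic)).astype('int32')   # int8-range data
b = np.random.randint(-128, 128, size=(1, 1, ic, oc)).astype('int32')  # int8-range kernel

def conv1x1(data, kernel):
    # 1x1 convolution in NHWC/HWIO layout == contraction over the input channels.
    return np.tensordot(data, kernel[0, 0], axes=([3], [0]))

original = conv1x1(a, b)

# Legalized form: shift the data into the uint8 range, then subtract
# 128 * sum(kernel over the reduction axes) -- the adjust_shift term.
shifted = conv1x1(a + 128, b)
adjust = 128 * b.sum(axis=(0, 1, 2))        # shape (oc,), broadcasts over NHWC
np.testing.assert_array_equal(original, shifted - adjust)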
@@ -250,16 +298,16 @@ def _conv2d_legalize(attrs, inputs, arg_types):
else:
return None
if not (ic_modified or oc_modified):
return None
if ic_modified and not oc_modified:
return relay.nn.conv2d(data, kernel, **attrs)
if oc_modified:
new_attrs = {k: attrs[k] for k in attrs.keys()}
new_attrs['channels'] = new_out_channel
out = tvm.relay.nn.conv2d(data, kernel, **new_attrs)
original_out_shape = [x.value for x in output_tensor.shape]
return relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape)
out = relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape)
else:
out = relay.nn.conv2d(data, kernel, **new_attrs)
if is_int8_inputs:
out = relay.subtract(out, adjust_shift)
return out
return None
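The channel padding that this second hunk finishes (pad input channels to a multiple of 4, pad output channels to a multiple of 16, then strided_slice the result back to the original shape) is value-preserving because the padded channels are zero. A NumPy sketch of the idea, again with a 1x1 NHWC/HWIO convolution and hypothetical, non-divisible channel counts:

import numpy as np

n, h, w, ic, oc = 1, 4, 4, 17, 29           # non-divisible channel counts
ic_pad = (4 - ic % 4) % 4                   # pad input channels up to a multiple of 4
oc_pad = (16 - oc % 16) % 16                # pad output channels up to a multiple of 16

a = np.random.randint(0, 256, size=(n, h, w, ic)).astype('int32')
b = np.random.randint(-128, 128, size=(1, 1, ic, oc)).astype('int32')

def conv1x1(data, kernel):
    return np.tensordot(data, kernel[0, 0], axes=([3], [0]))

# Zero-pad data and kernel on the channel axes, convolve, then slice the
# output back to the original number of output channels.
a_p = np.pad(a, ((0, 0), (0, 0), (0, 0), (0, ic_pad)), mode='constant')
b_p = np.pad(b, ((0, 0), (0, 0), (0, ic_pad), (0, oc_pad)), mode='constant')
out = conv1x1(a_p, b_p)[..., :oc]           # "strided_slice" back to `oc` channels

np.testing.assert_array_equal(out, conv1x1(a, b))

This mirrors the Relay graph built by the legalization: relay.nn.pad on the data and kernel along the channel axes, conv2d with the padded channel count, and relay.strided_slice to recover the original output shape.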