Commit b0ddcff6 by Animesh Jain Committed by Yizhi Liu

[Relay] Legalize and AlterOpLayout for Int8 Intel. (#3961)

parent 92439166
...@@ -541,18 +541,35 @@ def test_upsampling(): ...@@ -541,18 +541,35 @@ def test_upsampling():
def test_conv2d_int8_intrinsics(): def test_conv2d_int8_intrinsics():
def _compile(input_dtype, weight_dtype, output_dtype, target): def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):
n, ic, h, w, oc, ch, cw = 1, 16, 224, 224, 32, 3, 3 input_dtype, weight_dtype, output_dtype = dtypes
x = relay.var("x", relay.TensorType((n, ic, h, w), input_dtype))
w = relay.var("w", relay.TensorType((oc, ic, ch, cw), weight_dtype)) n, h, w, ch, cw = 1, 64, 64, 3, 3
if data_layout == 'NCHW':
x = relay.var("x", relay.TensorType((n, ic, h, w), input_dtype))
elif data_layout == 'NHWC':
x = relay.var("x", relay.TensorType((n, h, w, ic), input_dtype))
else:
raise ValueError('Not supported')
if kernel_layout == 'OIHW':
kernel_shape = (oc, ic, ch, cw)
elif kernel_layout == 'HWIO':
kernel_shape = (ch, cw, ic, oc)
else:
raise ValueError('Not supported')
w = relay.var("w", relay.TensorType(kernel_shape, weight_dtype))
y = relay.nn.conv2d(x, w, y = relay.nn.conv2d(x, w,
kernel_size=(ch, cw), kernel_size=(ch, cw),
channels=oc, channels=oc,
padding=(1, 1), padding=(1, 1),
dilation=(1, 1), dilation=(1, 1),
data_layout=data_layout,
kernel_layout=kernel_layout,
out_dtype=output_dtype) out_dtype=output_dtype)
func = relay.Function([x, w], y) func = relay.Function([x, w], y)
wdata = np.random.rand(oc, ic, ch, cw) * 10 wdata = np.random.rand(*kernel_shape) * 10
parameters = {"w": tvm.nd.array(wdata.astype(weight_dtype))} parameters = {"w": tvm.nd.array(wdata.astype(weight_dtype))}
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=parameters) graph, lib, params = relay.build(func, target, params=parameters)
...@@ -564,37 +581,59 @@ def test_conv2d_int8_intrinsics(): ...@@ -564,37 +581,59 @@ def test_conv2d_int8_intrinsics():
name = "llvm.x86.avx512.pmaddubs.w.512" name = "llvm.x86.avx512.pmaddubs.w.512"
llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name) llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name)
if llvm_id != 0: if llvm_id != 0:
# Intel Int8 instruction need uint8 data and int8 kernel fast_int8_dtypes = ('uint8', 'int8', 'int32')
asm = _compile(input_dtype="uint8", # Sweep the input channels to check int8 robustness
weight_dtype="int8", for ic in range(1, 24):
output_dtype="int32", asm = _compile(ic=ic, oc=32, target=target, data_layout="NCHW", kernel_layout='OIHW',
target=target) dtypes=fast_int8_dtypes)
# Check that intrinisic is present in the assembly. assert "pmaddubs" in asm
for ic in range(1, 24):
asm = _compile(ic=ic, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
dtypes=fast_int8_dtypes)
assert "pmaddubs" in asm
# Sweep the output channels to check int8 robustness
for oc in range(2, 24):
asm = _compile(ic=16, oc=oc, target=target, data_layout="NCHW", kernel_layout='OIHW',
dtypes=fast_int8_dtypes)
assert "pmaddubs" in asm
for oc in range(2, 24):
asm = _compile(ic=16, oc=oc, target=target, data_layout="NHWC", kernel_layout='HWIO',
dtypes=fast_int8_dtypes)
assert "pmaddubs" in asm
# Check that both non-divisible oc and ic work
asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW',
dtypes=fast_int8_dtypes)
assert "pmaddubs" in asm
asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
dtypes=fast_int8_dtypes)
assert "pmaddubs" in asm assert "pmaddubs" in asm
# Ensure that code is generated when datatypes are not HW supported. # Ensure that code is generated when datatypes are not HW supported.
asm = _compile(input_dtype="int8", dtypes = ('int8', 'int8', 'int32')
weight_dtype="int8", asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
output_dtype="int32", dtypes=dtypes)
target=target)
# Check that intrinisic is not present in the assembly. # Check that intrinisic is not present in the assembly.
assert "pmaddubs" not in asm assert "pmaddubs" not in asm
# Ensure that code is generated when datatypes are not HW supported. # Ensure that code is generated when datatypes are not HW supported.
asm = _compile(input_dtype="uint8", dtypes = ('uint8', 'uint8', 'int32')
weight_dtype="uint8", asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
output_dtype="int32", dtypes=dtypes)
target=target)
# Check that intrinisic is not present in the assembly. # Check that intrinisic is not present in the assembly.
assert "pmaddubs" not in asm assert "pmaddubs" not in asm
# Check that a vectorized instruction is generated for older Intel # Check that a vectorized instruction is generated for older Intel
# generations, because we default to NCHWc layout. # generations, because we default to NCHWc layout.
target = "llvm -mcpu=core-avx2" target = "llvm -mcpu=core-avx2"
asm = _compile(input_dtype="int8", fast_int8_dtypes = ('uint8', 'int8', 'int32')
weight_dtype="int8", asm = _compile(ic=16, oc=32, target=target, data_layout="NCHW", kernel_layout='OIHW',
output_dtype="int32", dtypes=fast_int8_dtypes)
target=target)
# Check that vector int mult and add instructions are generated. # Check that vector int mult and add instructions are generated.
assert "vpmulld" in asm and "vpadd" in asm assert "vpmulld" in asm and "vpadd" in asm
......
...@@ -151,7 +151,7 @@ def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'): ...@@ -151,7 +151,7 @@ def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'):
if data_layout == 'NCHW': if data_layout == 'NCHW':
CO, CIG, KH, KW = [x.value for x in kernel.shape] CO, CIG, KH, KW = [x.value for x in kernel.shape]
else: else:
KH, KW, CO, CIG = [x.value for x in kernel.shape] KH, KW, CIG, CO = [x.value for x in kernel.shape]
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
GRPS = CI // CIG GRPS = CI // CIG
......
...@@ -17,3 +17,4 @@ from .batch_matmul import schedule_batch_matmul ...@@ -17,3 +17,4 @@ from .batch_matmul import schedule_batch_matmul
from .roi_align import roi_align_nchw from .roi_align import roi_align_nchw
from .conv2d_transpose import _schedule_conv2d_transpose_nchw from .conv2d_transpose import _schedule_conv2d_transpose_nchw
from .sparse import * from .sparse import *
from .conv2d_alter_op import *
...@@ -26,40 +26,16 @@ from tvm.autotvm.task.topi_integration import deserialize_args ...@@ -26,40 +26,16 @@ from tvm.autotvm.task.topi_integration import deserialize_args
from tvm.autotvm.task import get_config from tvm.autotvm.task import get_config
from .. import generic, tag from .. import generic, tag
from .. import nn from .. import nn
from ..util import get_const_tuple, get_shape from ..nn.conv2d import conv2d, conv2d_NCHWc, \
from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_NCHWc_int8, \ conv2d_infer_layout, _get_workload as _get_conv2d_workload
conv2d_alter_layout, conv2d_infer_layout, _get_workload as _get_conv2d_workload
from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw
from ..nn.pad import pad from ..nn.pad import pad
from ..util import get_const_tuple
from . import conv2d_avx_1x1, conv2d_avx_common from . import conv2d_avx_1x1, conv2d_avx_common
logger = logging.getLogger('topi') logger = logging.getLogger('topi')
def _is_int8_hw_support(data_dtype, kernel_dtype, target):
"""
Checks to ensure that we can use Intel DLBoost instructions
1) The datatypes are correct.
2) LLVM version has support for the instructions.
3) Target is skylake and above.
"""
# 1) Check datatypes
is_dtype_support = data_dtype == 'uint8' and kernel_dtype == 'int8'
# 2) Check LLVM support
llvm_intrin_fast_int8 = "llvm.x86.avx512.pmaddubs.w.512"
llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(llvm_intrin_fast_int8)
is_llvm_support = llvm_id != 0
# 3) Check target
is_target_support = False
for opt in target.options:
if opt == '-mcpu=skylake-avx512':
is_target_support = True
return is_dtype_support and is_llvm_support and is_target_support
def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False, def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False,
layout='NCHW'): layout='NCHW'):
""" """
...@@ -353,133 +329,6 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs): ...@@ -353,133 +329,6 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs):
return s, [new_data, new_kernel, C] return s, [new_data, new_kernel, C]
@conv2d_alter_layout.register("cpu")
def _alter_conv2d_layout(attrs, inputs, tinfo, F):
copy_inputs = [s for s in inputs]
new_attrs = {k : attrs[k] for k in attrs.keys()}
if F.__name__ == 'tvm.relay.op':
# Derive channels for frontends (e.g ONNX) that miss "channel" field.
new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]
data, kernel = tinfo[0], tinfo[1]
batch_size, in_channel, height, width = get_const_tuple(data.shape)
groups = attrs.get_int("groups")
out_channel = attrs.get_int("channels") \
if F.__name__ == 'nnvm.symbol' else new_attrs["channels"]
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
out_dtype = attrs["out_dtype"]
layout_name = 'layout' if F.__name__ == 'nnvm.symbol' else 'data_layout'
layout = attrs[layout_name]
kh, kw = attrs.get_int_tuple("kernel_size")
dtype = data.dtype
out_dtype = dtype if out_dtype in ("same", "") else out_dtype
kshape = get_shape(kernel.shape, attrs["kernel_layout"], "OIHW")
is_depthwise = groups == kshape[0] and kshape[1] == 1
# only optimize for NCHW
if layout != 'NCHW' or attrs["kernel_layout"] != "OIHW":
return None
if groups != 1 and not is_depthwise:
return None
dispatch_ctx = autotvm.task.DispatchContext.current
target = tvm.target.current_target()
# query schedule and fallback if necessary
workload = autotvm.task.args_to_workload(
[data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw) \
if is_depthwise else \
autotvm.task.args_to_workload(
[data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
cfg = dispatch_ctx.query(target, workload)
if cfg.is_fallback:
_get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise)
ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
new_attrs[layout_name] = 'NCHW%dc' % ic_bn
new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
dtype=data.dtype)
if is_depthwise:
new_attrs['kernel_layout'] = 'OIHW1i%do' % oc_bn
# Store altered operator's config
new_kernel = tvm.placeholder((out_channel//oc_bn, 1, kh, kw, 1, oc_bn), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
new_attrs['out_layout'], out_dtype], depthwise_conv2d_NCHWc)
dispatch_ctx.update(target, new_workload, cfg)
if F.__name__ == 'nnvm.symbol':
logging.warning("Use native layout for depthwise convolution on NNVM.")
return None
return F.nn.contrib_depthwise_conv2d_nchwc(*copy_inputs, **new_attrs)
if _is_int8_hw_support(data.dtype, kernel.dtype, target):
# Convert kernel data layout from 4D to 7D
n_elems = 4
out_channel, _, kh, kw = get_const_tuple(kernel.shape)
data_expr, kernel_expr = inputs
kernel_IHWO = F.transpose(kernel_expr, axes=(1, 2, 3, 0))
kernel_IHWOo = F.reshape(kernel_IHWO, (in_channel, kh, kw, out_channel//oc_bn, oc_bn))
kernel_OHWoI = F.transpose(kernel_IHWOo, axes=(3, 1, 2, 4, 0))
kernel_OHWoIi = F.reshape(kernel_OHWoI, (out_channel//oc_bn, kh, kw, oc_bn,
in_channel//ic_bn, ic_bn))
kernel_OHWoIie = F.reshape(kernel_OHWoIi, (out_channel//oc_bn, kh, kw, oc_bn,
in_channel//ic_bn, ic_bn//n_elems, n_elems))
kernel_OIHWioe = F.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6))
copy_inputs = [data_expr, kernel_OIHWioe]
# Store altered operator's config. New kernel layout OIHWio4
new_kernel = tvm.placeholder((out_channel // oc_bn,
in_channel // ic_bn,
kh,
kw,
ic_bn // n_elems,
oc_bn,
n_elems), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload([new_data,
new_kernel,
strides,
padding,
dilation,
new_attrs[layout_name],
new_attrs['out_layout'],
out_dtype],
conv2d_NCHWc_int8)
dispatch_ctx.update(target, new_workload, cfg)
if F.__name__ == 'nnvm.symbol':
logging.warning("Use native layout for int8 convolution on NNVM.")
return None
return F.nn.contrib_conv2d_nchwc_int8(*copy_inputs, **new_attrs)
out_channel, _, kh, kw = get_const_tuple(kernel.shape)
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
# Store altered operator's config
new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn,
kh, kw, ic_bn, oc_bn), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
dispatch_ctx.update(target, new_workload, cfg)
if F.__name__ == 'nnvm.symbol':
return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
@conv2d_infer_layout.register("cpu") @conv2d_infer_layout.register("cpu")
def _conv2d_infer_layout(workload, cfg): def _conv2d_infer_layout(workload, cfg):
_, data, kernel, strides, padding, dilation, layout, dtype = workload _, data, kernel, strides, padding, dilation, layout, dtype = workload
......
...@@ -57,6 +57,36 @@ def _fallback_schedule(cfg, wkl): ...@@ -57,6 +57,36 @@ def _fallback_schedule(cfg, wkl):
raise ValueError("cannot decide default schedule for workload: {}".format(wkl)) raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
def _fallback_schedule_int8(cfg, wkl):
simd_width = get_fp32_len()
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
oc_bn = 16
assert wkl.out_filter % oc_bn == 0
ic_bn = 1
for bn in range(oc_bn, 0, -4):
if wkl.in_filter % bn == 0:
ic_bn = bn
break
assert wkl.in_filter % 4 == 0
for ow_factor in range(out_width, 0, -1):
if out_width % ow_factor == 0:
for oh_factor in range(out_height, 0, -1):
if out_height % oh_factor == 0 and ow_factor * oh_factor < 32:
cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
cfg["tile_oh"] = OtherOptionEntity(oh_factor)
cfg["tile_ow"] = SplitEntity([out_width // ow_factor, ow_factor])
return
raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last): def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last):
# fetch schedule # fetch schedule
ic_bn, oc_bn, oh_factor, ow_factor = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1], ic_bn, oc_bn, oh_factor, ow_factor = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
......
...@@ -55,6 +55,34 @@ def _fallback_schedule(cfg, wkl): ...@@ -55,6 +55,34 @@ def _fallback_schedule(cfg, wkl):
cfg["unroll_kw"] = OtherOptionEntity(False) cfg["unroll_kw"] = OtherOptionEntity(False)
def _fallback_schedule_int8(cfg, wkl):
simd_width = get_fp32_len()
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
oc_bn = 16
assert wkl.out_filter % oc_bn == 0
ic_bn = 1
for bn in range(oc_bn, 0, -4):
if wkl.in_filter % bn == 0:
ic_bn = bn
break
assert wkl.in_filter % 4 == 0
reg_n = 1
for n in range(31, 0, -1):
if out_width % n == 0:
reg_n = n
break
cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
cfg["unroll_kw"] = OtherOptionEntity(False)
def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last): def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last):
# fetch schedule # fetch schedule
ic_bn, oc_bn, reg_n, unroll_kw = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1], ic_bn, oc_bn, reg_n, unroll_kw = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
......
...@@ -22,12 +22,52 @@ import tvm ...@@ -22,12 +22,52 @@ import tvm
from tvm import autotvm from tvm import autotvm
from tvm.autotvm.task import get_config from tvm.autotvm.task import get_config
from tvm.autotvm.task.topi_integration import deserialize_args from tvm.autotvm.task.topi_integration import deserialize_args
from ..nn.conv2d import _get_workload as _get_conv2d_workload
from .. import generic, tag from .. import generic, tag
from ..util import get_const_tuple from ..util import get_const_tuple
from ..nn.conv2d import conv2d_NCHWc_int8 from ..nn.conv2d import conv2d_NCHWc_int8
from .. import nn from .. import nn
from . import conv2d_avx_1x1, conv2d_avx_common from . import conv2d_avx_1x1, conv2d_avx_common
def _get_default_config_int8(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False,
layout='NCHW'):
"""
Get default schedule config for the workload
"""
assert not is_depthwise, "Depthwise Int8 not supported"
wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
if is_kernel_1x1:
conv2d_avx_1x1._fallback_schedule_int8(cfg, wkl)
else:
conv2d_avx_common._fallback_schedule_int8(cfg, wkl)
def _is_int8_hw_support(data_dtype, kernel_dtype):
"""
Checks to ensure that we can use Intel DLBoost instructions
1) The datatypes are correct.
2) LLVM version has support for the instructions.
3) Target is skylake and above.
"""
# 1) Check datatypes
is_dtype_support = data_dtype == 'uint8' and kernel_dtype == 'int8'
# 2) Check LLVM support
llvm_intrin_fast_int8 = "llvm.x86.avx512.pmaddubs.w.512"
llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(llvm_intrin_fast_int8)
is_llvm_support = llvm_id != 0
# 3) Check target
target = tvm.target.current_target()
is_target_support = False
for opt in target.options:
if opt == '-mcpu=skylake-avx512':
is_target_support = True
return is_dtype_support and is_llvm_support and is_target_support
def _create_tuning_space_int8(cfg, data, kernel, strides, padding, dilation, layout): def _create_tuning_space_int8(cfg, data, kernel, strides, padding, dilation, layout):
"""Create schedule configuration from input arguments""" """Create schedule configuration from input arguments"""
dshape = get_const_tuple(data.shape) dshape = get_const_tuple(data.shape)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment