Commit 3edf5260 by Animesh Jain Committed by Yizhi Liu

[TOPI] Setting up AutoTVM template for Intel Int8 conv2D (#3955)

parent c846d17c
...@@ -85,6 +85,7 @@ class TaskExtractEnv: ...@@ -85,6 +85,7 @@ class TaskExtractEnv:
topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw", topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw", topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
topi.nn.conv2d_NCHWc: "topi_x86_conv2d_NCHWc", topi.nn.conv2d_NCHWc: "topi_x86_conv2d_NCHWc",
topi.nn.conv2d_NCHWc_int8: "topi_x86_conv2d_NCHWc_int8",
topi.nn.dense: "topi_nn_dense", topi.nn.dense: "topi_nn_dense",
topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw", topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc", topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
...@@ -100,6 +101,7 @@ class TaskExtractEnv: ...@@ -100,6 +101,7 @@ class TaskExtractEnv:
topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw], topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw], topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
topi.nn.conv2d_NCHWc: [topi.generic.schedule_conv2d_NCHWc], topi.nn.conv2d_NCHWc: [topi.generic.schedule_conv2d_NCHWc],
topi.nn.conv2d_NCHWc_int8: [topi.generic.schedule_conv2d_NCHWc_int8],
topi.nn.dense: [topi.generic.schedule_dense], topi.nn.dense: [topi.generic.schedule_dense],
topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw], topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc], topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
...@@ -111,6 +113,7 @@ class TaskExtractEnv: ...@@ -111,6 +113,7 @@ class TaskExtractEnv:
self.func_to_reflection = { self.func_to_reflection = {
topi.nn.conv2d: lambda x: setattr(topi.nn, 'conv2d', x), topi.nn.conv2d: lambda x: setattr(topi.nn, 'conv2d', x),
topi.nn.conv2d_NCHWc: lambda x: setattr(topi.nn, 'conv2d_NCHWc', x), topi.nn.conv2d_NCHWc: lambda x: setattr(topi.nn, 'conv2d_NCHWc', x),
topi.nn.conv2d_NCHWc_int8: lambda x: setattr(topi.nn, 'conv2d_NCHWc_int8', x),
topi.nn.depthwise_conv2d_nchw: lambda x: setattr(topi.nn, 'depthwise_conv2d_nchw', x), topi.nn.depthwise_conv2d_nchw: lambda x: setattr(topi.nn, 'depthwise_conv2d_nchw', x),
topi.nn.group_conv2d_nchw: lambda x: setattr(topi.nn, 'group_conv2d_nchw', x), topi.nn.group_conv2d_nchw: lambda x: setattr(topi.nn, 'group_conv2d_nchw', x),
topi.nn.conv2d_transpose_nchw: lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x), topi.nn.conv2d_transpose_nchw: lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x),
......
...@@ -669,7 +669,6 @@ def conv2d_NCHWc_int8_compute(data, kernel, strides, padding, dilation, layout, ...@@ -669,7 +669,6 @@ def conv2d_NCHWc_int8_compute(data, kernel, strides, padding, dilation, layout,
name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8") name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
def conv2d_winograd_weight_transform(kernel, tile_size): def conv2d_winograd_weight_transform(kernel, tile_size):
"""Weight transformation for winograd """Weight transformation for winograd
......
...@@ -27,7 +27,7 @@ from tvm.autotvm.task import get_config ...@@ -27,7 +27,7 @@ from tvm.autotvm.task import get_config
from .. import generic, tag from .. import generic, tag
from .. import nn from .. import nn
from ..util import get_const_tuple, get_shape from ..util import get_const_tuple, get_shape
from ..nn.conv2d import conv2d, conv2d_NCHWc, \ from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_NCHWc_int8, \
conv2d_alter_layout, conv2d_infer_layout, _get_workload as _get_conv2d_workload conv2d_alter_layout, conv2d_infer_layout, _get_workload as _get_conv2d_workload
from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw
...@@ -77,7 +77,6 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth ...@@ -77,7 +77,6 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth
else: else:
conv2d_avx_common._fallback_schedule(cfg, wkl) conv2d_avx_common._fallback_schedule(cfg, wkl)
def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
"""Create schedule configuration from input arguments""" """Create schedule configuration from input arguments"""
dshape = get_const_tuple(data.shape) dshape = get_const_tuple(data.shape)
...@@ -92,11 +91,6 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): ...@@ -92,11 +91,6 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
elif pat.match(layout) is not None: elif pat.match(layout) is not None:
n, ic_chunk, h, w, ic_bn = dshape n, ic_chunk, h, w, ic_bn = dshape
target = tvm.target.current_target(allow_none=False) target = tvm.target.current_target(allow_none=False)
if _is_int8_hw_support(data.dtype, kernel.dtype, target):
oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape
ic = ic_chunk*ic_bn
assert ic == k_ic*k_ic_f*k_ic_s
else:
oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape
assert ic_chunk == k_ic_chunk assert ic_chunk == k_ic_chunk
assert ic_bn == k_ic_bn assert ic_bn == k_ic_bn
...@@ -105,6 +99,7 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): ...@@ -105,6 +99,7 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
else: else:
raise ValueError("Not support this layout {} with " raise ValueError("Not support this layout {} with "
"schedule template.".format(layout)) "schedule template.".format(layout))
is_kernel_1x1 = kh == 1 and kw == 1 is_kernel_1x1 = kh == 1 and kw == 1
ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding) ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding)
sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides) sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
...@@ -444,14 +439,25 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F): ...@@ -444,14 +439,25 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
in_channel//ic_bn, ic_bn//n_elems, n_elems)) in_channel//ic_bn, ic_bn//n_elems, n_elems))
kernel_OIHWioe = F.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6)) kernel_OIHWioe = F.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6))
copy_inputs = [data_expr, kernel_OIHWioe] copy_inputs = [data_expr, kernel_OIHWioe]
# Store altered operator's config
new_kernel = tvm.placeholder((out_channel//oc_bn, kh, kw, oc_bn, # Store altered operator's config. New kernel layout OIHWio4
in_channel//ic_bn, ic_bn//n_elems, new_kernel = tvm.placeholder((out_channel // oc_bn,
n_elems)) in_channel // ic_bn,
new_workload = autotvm.task.args_to_workload( kh,
[new_data, new_kernel, strides, padding, dilation, kw,
new_attrs[layout_name], new_attrs['out_layout'], out_dtype], ic_bn // n_elems,
conv2d_NCHWc) oc_bn,
n_elems), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload([new_data,
new_kernel,
strides,
padding,
dilation,
new_attrs[layout_name],
new_attrs['out_layout'],
out_dtype],
conv2d_NCHWc_int8)
dispatch_ctx.update(target, new_workload, cfg) dispatch_ctx.update(target, new_workload, cfg)
if F.__name__ == 'nnvm.symbol': if F.__name__ == 'nnvm.symbol':
logging.warning("Use native layout for int8 convolution on NNVM.") logging.warning("Use native layout for int8 convolution on NNVM.")
......
...@@ -17,30 +17,108 @@ ...@@ -17,30 +17,108 @@
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member # pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Conv2D int8 schedule on x86""" """Conv2D int8 schedule on x86"""
import re
import tvm import tvm
from tvm import autotvm from tvm import autotvm
from tvm.autotvm.task import get_config
from tvm.autotvm.task.topi_integration import deserialize_args
from .. import generic, tag from .. import generic, tag
from ..util import get_const_tuple from ..util import get_const_tuple
from ..nn.conv2d import conv2d_NCHWc_int8 from ..nn.conv2d import conv2d_NCHWc_int8
from .. import nn from .. import nn
from .conv2d import _get_default_config
from . import conv2d_avx_1x1, conv2d_avx_common from . import conv2d_avx_1x1, conv2d_avx_common
def _create_tuning_space_int8(cfg, data, kernel, strides, padding, dilation, layout):
"""Create schedule configuration from input arguments"""
dshape = get_const_tuple(data.shape)
kshape = get_const_tuple(kernel.shape)
pat = re.compile(r'NCHW.+(\d+)c')
if layout == 'NCHW':
n, ic, h, w = dshape
oc, _, kh, kw = kshape
elif layout == 'NHWC':
n, h, w, ic = dshape
kh, kw, oc, _ = kshape
elif pat.match(layout) is not None:
n, ic_chunk, h, w, ic_bn = dshape
target = tvm.target.current_target(allow_none=False)
oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape
ic = ic_chunk * ic_bn
assert ic == k_ic * k_ic_f * k_ic_s
oc = oc_chunk*oc_bn
else:
raise ValueError("Not support this layout {} with "
"schedule template.".format(layout))
is_kernel_1x1 = kh == 1 and kw == 1
ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding)
sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
oh = (h - kh + 2 * ph) // sh + 1
ow = (w - kw + 2 * pw) // sw + 1
# Create schedule config
cfg.define_split('tile_ic', ic, num_outputs=2, filter=lambda y: y.size[-1] % 4 == 0)
cfg.define_split('tile_oc', oc, num_outputs=2, filter=lambda y: y.size[-1] % 16 == 0)
cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 64)
if is_kernel_1x1:
cfg.define_knob("tile_oh", [1, 2] if oh > 1 else [1])
else:
cfg.define_knob("unroll_kw", [True, False])
# Define template function for autotvm task
# We define schedule template in this function instead of
# declaration function since actual input arguments need
# to be altered by the schedule selected.
@autotvm.task.register("topi_x86_conv2d_NCHWc_int8")
def _topi_nn_conv2d_NCHWc_int8(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
if len(args) == 7:
data, kernel, strides, padding, dilation, origin_layout, dtype = args
else:
assert len(args) == 8
data, kernel, strides, padding, dilation, origin_layout, out_layout, dtype = args
raw_data_shape = get_const_tuple(data.shape)
raw_kernel_shape = get_const_tuple(kernel.shape)
# get config here
cfg = get_config()
_create_tuning_space_int8(cfg, data, kernel, strides, padding, dilation, origin_layout)
# change shape with the value in config
ic_bn, oc_bn, ow_bn = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
cfg["tile_ow"].size[-1])
data_layout = "NCHW%dc" % ic_bn
out_layout = "NCHW%dc" % oc_bn
# Set up the new shape for data and kernel
new_data_shape = (raw_data_shape[0], raw_data_shape[1] // ic_bn,
raw_data_shape[2], raw_data_shape[3], ic_bn)
n_elems = 4
new_kernel_shape = (raw_kernel_shape[0] // oc_bn,
raw_kernel_shape[1] // ic_bn,
raw_kernel_shape[2],
raw_kernel_shape[3],
ic_bn // n_elems,
oc_bn,
n_elems)
new_data = tvm.placeholder(new_data_shape, data.dtype)
new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)
C = _declaration_conv_NCHWc_int8(cfg, new_data, new_kernel, strides, padding, dilation,
data_layout, out_layout, dtype)
s = _schedule_conv2d_NCHWc_int8(cfg, [C])
return s, [new_data, new_kernel, C]
@autotvm.register_topi_compute(conv2d_NCHWc_int8, 'cpu', 'direct') @autotvm.register_topi_compute(conv2d_NCHWc_int8, 'cpu', 'direct')
def _declaration_conv_NCHWc_int8(cfg, data, kernel, strides, def _declaration_conv_NCHWc_int8(cfg, data, kernel, strides,
padding, dilation, layout, out_layout, out_dtype): padding, dilation, layout, out_layout, out_dtype):
n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
in_channel = ic_chunk * ic_bn
oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = \
get_const_tuple(kernel.shape)
num_filter = oc_chunk * oc_bn
# If config is not set, we can reuse the default config for NCHW.
if cfg.is_fallback:
_get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width),
dtype=kernel.dtype),
strides, padding, out_dtype)
return nn.conv2d_NCHWc_int8_compute(data, return nn.conv2d_NCHWc_int8_compute(data,
kernel, kernel,
strides, strides,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment