Commit eb220d92 by Animesh Jain, committed by Yao Wang

Refactoring x86 conv2d_NCHWc (#3944)

parent 195973c0
...@@ -548,6 +548,34 @@ reg.register_pattern("nn.contrib_conv2d_NCHWc",
OpPattern.OUT_ELEMWISE_FUSABLE)
@reg.register_compute("nn.contrib_conv2d_NCHWc_int8")
def compute_contrib_conv2d_NCHWc_int8(attrs, inputs, out_dtype, target):
"""Compute definition of conv2d NCHWc"""
# pylint: disable=assignment-from-no-return
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
data_layout = attrs.get_str("data_layout")
out_layout = attrs.get_str("out_layout")
out_dtype = attrs.get_str("out_dtype")
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
out = topi.nn.conv2d_NCHWc_int8(inputs[0], inputs[1], strides, padding, dilation,
data_layout, out_layout, out_dtype)
return [out]
@reg.register_schedule("nn.contrib_conv2d_NCHWc_int8")
def schedule_contrib_conv2d_NCHWc_int8(attrs, outs, target):
"""Schedule definition of contrib_conv2d_NCHWc_int8"""
with target:
return topi.generic.schedule_conv2d_NCHWc_int8(outs)
reg.register_pattern("nn.contrib_conv2d_NCHWc_int8",
OpPattern.OUT_ELEMWISE_FUSABLE)
@reg.register_compute("nn.contrib_depthwise_conv2d_NCHWc") @reg.register_compute("nn.contrib_depthwise_conv2d_NCHWc")
def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target): def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target):
"""Compute definition of depthwise conv2d NCHWc""" """Compute definition of depthwise conv2d NCHWc"""
......
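With the compute, schedule, and pattern registered above, the new operator can be exercised from Relay through the contrib_conv2d_nchwc_int8 wrapper added later in this commit. Below is a minimal, illustrative sketch; the tensor shapes, block sizes, and the kernel_layout string are assumptions chosen to satisfy the packed-layout constraints (ic_bn a multiple of 4, oc_bn a multiple of 16), not values taken from this patch.

from tvm import relay

# Hypothetical packing: 32 input channels as 4 chunks of 8 (NCHW8c),
# 64 output channels as 4 chunks of 16, 3x3 kernel, int8 lanes of 4.
data = relay.var("data", shape=(1, 4, 56, 56, 8), dtype="uint8")
kernel = relay.var("kernel", shape=(4, 4, 3, 3, 2, 16, 4), dtype="int8")

conv = relay.nn.contrib_conv2d_nchwc_int8(
    data, kernel,
    strides=(1, 1), padding=(1, 1), dilation=(1, 1),
    channels=64, kernel_size=(3, 3),
    data_layout="NCHW8c", kernel_layout="OIHW2i16o4i",  # layout string is illustrative
    out_layout="NCHW16c", out_dtype="int32")

func = relay.Function([data, kernel], conv)
print(func)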
...@@ -1340,6 +1340,72 @@ def contrib_depthwise_conv2d_nchwc(data,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
def contrib_conv2d_nchwc_int8(data,
kernel,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCHW8c",
kernel_layout="OIHW",
out_layout="",
out_dtype=""):
r"""Variant of 2D convolution. It deals with only int8 inputs.
This operator takes the weight as the convolution kernel
and convolves it with data to produce an output, following a specialized
NCHWc data layout.
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
kernel : tvm.relay.Expr
The kernel expressions.
strides : tuple of int, optional
The strides of convolution.
padding : tuple of int, optional
The padding of convolution on both sides of inputs before convolution.
dilation : tuple of int, optional
Specifies the dilation rate to be used for dilated convolution.
groups : int, optional
Number of groups for grouped convolution.
channels : int, optional
Number of output channels of this convolution.
kernel_size : tuple of int, optional
The spatial dimensions of the convolution kernel.
data_layout : str, optional
Layout of the input.
kernel_layout : str, optional
Layout of the weight.
out_layout : str, optional
Layout of the output. By default, out_layout is the same as data_layout.
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.contrib_conv2d_NCHWc_int8(data, kernel, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
def contrib_conv2d_winograd_weight_transform(weight,
tile_size):
r"""Weight Transformation part for 2D convolution with winograd algorithm.
......
...@@ -570,6 +570,54 @@ weight transformation in advance.
.set_support_level(10)
.add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel);
// Positional relay function to create the conv2d NCHWc int8 operator
// used by frontend FFI.
Expr MakeConv2DNCHWcInt8(Expr data,
Expr kernel,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_node<Conv2DAttrs>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->channels = channels;
attrs->kernel_size = std::move(kernel_size);
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.contrib_conv2d_NCHWc_int8");
return CallNode::make(op, {data, kernel}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_NCHWc_int8")
.set_body_typed(MakeConv2DNCHWcInt8);
RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc_int8")
.describe(R"code(Compute conv2d with NCHWc data layout with int8 inputs.
- **data**: Input is 5D packed tensor.
- **weight**: 7D packed tensor.
- **out**: Output is 5D packed tensor
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.Conv2D")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv2DNCHWcInt8", Conv2DWinogradRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DAttrs>);
// Positional relay function to create conv2d NCHWc operator
// used by frontend FFI.
......
...@@ -108,6 +108,25 @@ def schedule_conv2d_NCHWc(outs):
@tvm.target.generic_func
def schedule_conv2d_NCHWc_int8(outs):
"""Schedule for conv2d_NCHW[x]c_int8
Parameters
----------
outs : Array of Tensor
The computation graph description of conv2d_NCHWc_int8
in the format of an array of tensors.
Returns
-------
sch : Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_conv2d_winograd_weight_transform(outs):
"""Schedule for weight transformation of winograd
......
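schedule_conv2d_NCHWc_int8 above follows the usual TOPI pattern: tvm.target.generic_func installs a default implementation, and per-target overrides (the x86 one is registered later in this commit via autotvm.register_topi_schedule) are selected by the keys of the current target. A small, self-contained sketch of that dispatch mechanism, with made-up function names:

import tvm

@tvm.target.generic_func
def demo_schedule(outs):              # hypothetical generic function
    return "default schedule"

@demo_schedule.register("cpu")
def _demo_schedule_cpu(outs):         # override used for targets carrying the 'cpu' key
    return "x86-specialized schedule"

print(demo_schedule(None))            # no target context -> default
with tvm.target.create("llvm"):
    print(demo_schedule(None))        # 'llvm' carries the 'cpu' key -> specialized override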
...@@ -398,27 +398,75 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
output : tvm.Tensor
5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
"""
# search platform specific declaration first
# default declaration
return conv2d_NCHWc_compute(data,
kernel,
stride,
padding,
dilation,
layout,
out_layout,
out_dtype)
def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_layout, out_dtype):
"""Conv2D operator compute for nChw[x]c layout.
Parameters
----------
data : tvm.Tensor
5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
kernel : tvm.Tensor
6-D with shape
[num_filter_chunk, in_channel_chunk, filter_height, filter_width,
in_channel_block, num_filter_block]
strides : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two ints
padding size, or [pad_height, pad_width]
dilation: int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]
layout : str
Input data layout
out_layout : str
Output data layout
out_dtype : str
output data type
Returns
-------
output : tvm.Tensor
5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
"""
# layout and out_layout are not used here,
# we keep them for debug convenience when dumping autotvm workload
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding,
(dilated_kernel_h,
dilated_kernel_w))
HPAD = pad_top + pad_down
WPAD = pad_left + pad_right
HSTR, WSTR = stride if isinstance(stride, (tuple, list)) else (stride, stride)
dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
assert (dh, dw) == (1, 1), "Does not support dilation"
HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding)
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \
else (dilation, dilation)
n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
in_channel = ic_chunk * ic_bn
oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
target = tvm.target.current_target(allow_none=False)
oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \
get_const_tuple(kernel.shape)
num_filter = oc_chunk * oc_bn
groups = ic_chunk // ic_chunk_group
dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
# output shape
out_height = (ih + 2 * HPAD - kernel_height) // HSTR + 1
out_width = (iw + 2 * WPAD - kernel_width) // WSTR + 1
out_height = (ih + 2 * HPAD - dilated_kernel_h) // HSTR + 1
out_width = (iw + 2 * WPAD - dilated_kernel_w) // WSTR + 1
oshape = (n, oc_chunk, out_height, out_width, oc_bn)
# DOPAD
...@@ -433,13 +481,194 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
kw = tvm.reduce_axis((0, kernel_width), name='kw')
return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh, ow*WSTR+kw,
ic%ic_bn].astype(out_dtype) *
kernel[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn, oc_block],
tvm.sum(data_pad[n,
ic // ic_bn,
oh * HSTR + kh * dilation_h,
ow * WSTR + kw * dilation_w,
ic % ic_bn].astype(out_dtype)
* kernel[oc_chunk,
ic // ic_bn,
kh,
kw,
ic % ic_bn,
oc_block],
axis=[ic, kh, kw]),
name='conv2d_NCHWc', tag="conv2d_NCHWc")
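The 5-D data layout consumed by conv2d_NCHWc_compute is plain NCHW with the channel axis split into [chunk, block]. A small numpy sketch of the repacking (the helper name and the ic_bn value are illustrative, not part of this patch):

import numpy as np

def pack_nchw_to_nchwc(x, ic_bn=8):
    # NCHW -> NCHW[ic_bn]c: split C into (C // ic_bn, ic_bn) and move the
    # block to the innermost position.
    n, c, h, w = x.shape
    assert c % ic_bn == 0
    return x.reshape(n, c // ic_bn, ic_bn, h, w).transpose(0, 1, 3, 4, 2)

x = np.arange(1 * 32 * 4 * 4, dtype="float32").reshape(1, 32, 4, 4)
print(pack_nchw_to_nchwc(x).shape)  # (1, 4, 4, 4, 8)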
@tvm.target.generic_func
def conv2d_NCHWc_int8(data, kernel, strides, padding, dilation, layout, out_layout,
out_dtype='int32'):
"""Conv2D operator for nChw[x]c layout.
Parameters
----------
data : tvm.Tensor
5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
kernel : tvm.Tensor
7-D with shape
[num_filter_chunk, in_channel_chunk, filter_height, filter_width, in_channel_block/4,
num_filter_block, 4]
strides : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two ints
padding size, or [pad_height, pad_width]
dilation: int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]
layout : str
Input data layout
out_layout : str
Output data layout
out_dtype : str
output data type
Returns
-------
output : tvm.Tensor
5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
"""
return conv2d_NCHWc_int8_compute(data,
kernel,
strides,
padding,
dilation,
layout,
out_layout,
out_dtype)
def conv2d_NCHWc_int8_compute(data, kernel, strides, padding, dilation, layout, out_layout,
out_dtype='int32'):
"""Conv2D operator for nChw[x]c layout.
Parameters
----------
data : tvm.Tensor
5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
kernel : tvm.Tensor
7-D with shape
[num_filter_chunk, in_channel_chunk, filter_height, filter_width, in_channel_block/4,
num_filter_block, 4]
strides : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two ints
padding size, or [pad_height, pad_width]
dilation: int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]
layout : str
Input data layout
out_layout : str
Output data layout
out_dtype : str
output data type
Returns
-------
output : tvm.Tensor
5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
"""
# layout and out_layout are not used here,
# we keep them for debug convenience when dumping autotvm workload
HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding)
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \
else (dilation, dilation)
n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
in_channel = ic_chunk * ic_bn
target = tvm.target.current_target(allow_none=False)
oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \
get_const_tuple(kernel.shape)
num_filter = oc_chunk * oc_bn
groups = ic_chunk // ic_chunk_group
# Since the weight is 7-D and its innermost dimension has length 4, ic_bn
# has to be a multiple of 4.
# Similarly, oc_bn has to be a multiple of 16 (see the asserts below).
assert ic_bn % 4 == 0
assert oc_bn % 16 == 0
dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
# output shape
out_height = (ih + 2 * HPAD - dilated_kernel_h) // HSTR + 1
out_width = (iw + 2 * WPAD - dilated_kernel_w) // WSTR + 1
oshape = (n, oc_chunk, out_height, out_width, oc_bn)
# DOPAD
DOPAD = (HPAD != 0 or WPAD != 0)
if DOPAD:
data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
else:
data_pad = data
ic = tvm.reduce_axis((0, in_channel), name='ic')
kh = tvm.reduce_axis((0, kernel_height), name='kh')
kw = tvm.reduce_axis((0, kernel_width), name='kw')
if groups == 1:
n_elems = 4
ic_outer = tvm.reduce_axis((0, in_channel//ic_bn), name='ic_outer')
ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n,
ic_outer,
oh * HSTR + kh * dilation_h,
ow * WSTR + kw * dilation_w,
ic_f_inner * n_elems + ic_s_inner].astype(out_dtype)
* kernel[oc_chunk,
ic_outer,
kh,
kw,
ic_f_inner,
oc_block,
ic_s_inner].astype(out_dtype),
axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
# for int8 group conv support
n_elems = 4
ic_chunk = in_channel//ic_bn
ic_outer = tvm.reduce_axis((0, ic_chunk//groups), name='ic_outer')
ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
oshape = (n, oc_chunk, out_height, out_width, oc_bn)
return tvm.compute(oshape, lambda n, occ, oh, ow, oc_block:
tvm.sum(data_pad[n,
(occ * oc_bn // (oc_chunk * oc_bn // groups))
* (ic_chunk // groups) + ic_outer,
oh * HSTR + kh,
ow * WSTR + kw,
ic_f_inner * n_elems + ic_s_inner].astype(out_dtype)
* kernel[occ,
ic_outer,
kh,
kw,
ic_f_inner,
oc_block,
ic_s_inner].astype(out_dtype),
axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
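The 7-D kernel layout documented above additionally splits the input-channel block into lanes of n_elems = 4 so that four int8 values can be reduced together. A numpy sketch of how an OIHW kernel could be repacked into that layout (helper name and block sizes are illustrative):

import numpy as np

def pack_oihw_to_7d(w_oihw, ic_bn=8, oc_bn=16, n_elems=4):
    # OIHW -> (OC//oc_bn, IC//ic_bn, KH, KW, ic_bn//n_elems, oc_bn, n_elems)
    oc, ic, kh, kw = w_oihw.shape
    assert oc % oc_bn == 0 and ic % ic_bn == 0 and ic_bn % n_elems == 0
    w = w_oihw.reshape(oc // oc_bn, oc_bn,
                       ic // ic_bn, ic_bn // n_elems, n_elems,
                       kh, kw)
    return w.transpose(0, 2, 5, 6, 3, 1, 4)

w = np.random.randint(-8, 8, size=(64, 32, 3, 3)).astype("int8")
print(pack_oihw_to_7d(w).shape)  # (4, 4, 3, 3, 2, 16, 4)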
def conv2d_winograd_weight_transform(kernel, tile_size):
"""Weight transformation for winograd
......
...@@ -6,6 +6,7 @@ from .conv2d import schedule_conv2d, schedule_conv2d_nhwc
from .binarize_pack import schedule_binarize_pack
from .binary_dense import schedule_binary_dense
from .nn import *
from .conv2d_int8 import *
from .injective import *
from .pooling import schedule_pool, schedule_adaptive_pool
from .bitserial_conv2d import schedule_bitserial_conv2d
......
...@@ -263,58 +263,6 @@ def schedule_conv2d(cfg, outs):
traverse(outs[0].op)
return s
@autotvm.register_topi_schedule(generic.schedule_conv2d_nhwc_pack, 'cpu', ['direct'])
def schedule_conv2d_nhwc_pack(cfg, outs):
"""Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs])
output_op = outs[0].op
scheduled_ops = []
def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
else: # inject custom schedule
if len(op.axis) == 4: # schedule bias + bn + relu
n, h, w, c = op.axis
fused = s[op].fuse(n, h, w)
s[op].parallel(fused)
s[op].vectorize(c)
for tensor in op.input_tensors:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'conv2d_nhwc_pack_int8' in op.tag:
conv_out = op.output(0)
kernel = conv_out.op.input_tensors[1]
data_vec = conv_out.op.input_tensors[0]
data = data_vec.op.input_tensors[0] \
if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \
else data_vec
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
args = [s, cfg, data_vec, conv_out, outs[0]]
if data.dtype == 'uint8':
kh, kw, _, _, _ = get_const_tuple(kernel.shape)
if kh == 1 and kw == 1:
conv2d_avx_1x1._schedule_conv_nhwc_pack_int8(*args)
else:
raise ValueError("Only support 1x1 kernel with "
"schedule_conv2d_nhwc_pack.")
else:
raise ValueError("Not support this data type {} with "
"schedule_conv2d_nhwc_pack. Only support int8".format(data.dtype))
scheduled_ops.append(op)
traverse(output_op)
return s
@generic.schedule_conv2d_nhwc.register("cpu") @generic.schedule_conv2d_nhwc.register("cpu")
def schedule_conv2d_nhwc(outs): def schedule_conv2d_nhwc(outs):
"""Create schedule for tensors""" """Create schedule for tensors"""
...@@ -477,7 +425,11 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
[new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
new_attrs['out_layout'], out_dtype], depthwise_conv2d_NCHWc)
dispatch_ctx.update(target, new_workload, cfg)
else:
if F.__name__ == 'nnvm.symbol':
logging.warning("Use native layout for depthwise convolution on NNVM.")
return None
return F.nn.contrib_depthwise_conv2d_nchwc(*copy_inputs, **new_attrs)
if _is_int8_hw_support(data.dtype, kernel.dtype, target):
# Convert kernel data layout from 4D to 7D
n_elems = 4
...@@ -501,7 +453,11 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
new_attrs[layout_name], new_attrs['out_layout'], out_dtype],
conv2d_NCHWc)
dispatch_ctx.update(target, new_workload, cfg)
else:
if F.__name__ == 'nnvm.symbol':
logging.warning("Use native layout for int8 convolution on NNVM.")
return None
return F.nn.contrib_conv2d_nchwc_int8(*copy_inputs, **new_attrs)
out_channel, _, kh, kw = get_const_tuple(kernel.shape)
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
...@@ -513,12 +469,6 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
dispatch_ctx.update(target, new_workload, cfg)
if is_depthwise:
if F.__name__ == 'nnvm.symbol':
logging.warning("Use native layout for depthwise convolution on NNVM.")
return None
return F.nn.contrib_depthwise_conv2d_nchwc(*copy_inputs, **new_attrs)
else:
if F.__name__ == 'nnvm.symbol':
return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
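For the FP32 path above, the kernel layout string 'OIHW%di%do' % (ic_bn, oc_bn) describes the same (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) repacking noted in the comment. A numpy sketch of that 6-D packing, with illustrative block sizes:

import numpy as np

def pack_oihw_to_6d(w_oihw, ic_bn=8, oc_bn=16):
    # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc), i.e. kernel_layout 'OIHW8i16o'
    oc, ic, kh, kw = w_oihw.shape
    return w_oihw.reshape(oc // oc_bn, oc_bn, ic // ic_bn, ic_bn, kh, kw) \
                 .transpose(0, 2, 4, 5, 3, 1)

w = np.zeros((64, 32, 3, 3), dtype="float32")
print(pack_oihw_to_6d(w).shape)  # (4, 4, 3, 3, 8, 16)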
...@@ -544,95 +494,27 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
padding, dilation, layout, out_layout, out_dtype):
# layout and out_layout are not used here,
# we keep them for debug convenience when dumping autotvm workload
HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding)
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \
else (dilation, dilation)
n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
in_channel = ic_chunk * ic_bn
target = tvm.target.current_target(allow_none=False)
if _is_int8_hw_support(data.dtype, kernel.dtype, target):
oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \
get_const_tuple(kernel.shape)
else:
oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \
get_const_tuple(kernel.shape)
num_filter = oc_chunk * oc_bn
groups = ic_chunk // ic_chunk_group
dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
# If no config was set, we can fallback to NCHW config.
if cfg.is_fallback:
_get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width),
dtype=kernel.dtype),
strides, padding, out_dtype)
return nn.conv2d_NCHWc_compute(data,
kernel,
strides,
padding,
dilation,
layout,
out_layout,
out_dtype)
# output shape
out_height = (ih + 2 * HPAD - dilated_kernel_h) // HSTR + 1
out_width = (iw + 2 * WPAD - dilated_kernel_w) // WSTR + 1
oshape = (n, oc_chunk, out_height, out_width, oc_bn)
# DOPAD
DOPAD = (HPAD != 0 or WPAD != 0)
if DOPAD:
data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
else:
data_pad = data
ic = tvm.reduce_axis((0, in_channel), name='ic')
kh = tvm.reduce_axis((0, kernel_height), name='kh')
kw = tvm.reduce_axis((0, kernel_width), name='kw')
if _is_int8_hw_support(data.dtype, kernel.dtype, target) and groups == 1:
assert out_dtype == "int32", \
"INT8 convolution requires input dtype = uint8 and output dtype=int32"
# Intel performs dot product of 2 "4" Int8 values
# Current implementation requires ic_bn to be a multiple of 4
n_elems = 4
assert ic_bn % n_elems == 0
ic_outer = tvm.reduce_axis((0, in_channel//ic_bn), name='ic_outer')
ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh*dilation_h,
ow*WSTR+kw*dilation_w,
ic_f_inner * n_elems + ic_s_inner]
.astype(out_dtype) *
kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner,
oc_block, ic_s_inner].astype(out_dtype),
axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
if _is_int8_hw_support(data.dtype, kernel.dtype, target):
# for int8 group conv support
n_elems = 4
ic_chunk = in_channel//ic_bn
ic_outer = tvm.reduce_axis((0, ic_chunk//groups), name='ic_outer')
ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
oshape = (n, oc_chunk, out_height, out_width, oc_bn)
return tvm.compute(oshape, lambda n, occ, oh, ow, oc_block:
tvm.sum(data_pad[n, (occ*oc_bn//(oc_chunk*oc_bn//groups))*\
(ic_chunk//groups)+ic_outer,
oh*HSTR+kh, ow*WSTR+kw,
ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) *
kernel[occ, ic_outer, kh, kw, ic_f_inner,
oc_block, ic_s_inner].astype(out_dtype),
axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
# else: fp implementation
return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh*dilation_h,
ow*WSTR+kw*dilation_w,
ic%ic_bn].astype(out_dtype) *
kernel[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn, oc_block],
axis=[ic, kh, kw]),
name='conv2d_NCHWc', tag="conv2d_NCHWc")
@autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc, 'cpu', ['direct'])
...@@ -664,14 +546,6 @@ def _schedule_conv2d_NCHWc(cfg, outs):
args = [s, cfg, data_vec, conv_out, outs[0]]
target = tvm.target.current_target(allow_none=False)
if _is_int8_hw_support(data.dtype, kernel.dtype, target):
# int8 conv kernel is 7-dim
_, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape)
if kh == 1 and kw == 1:
conv2d_avx_1x1._schedule_conv_NCHWc_int8(*args)
else:
conv2d_avx_common._schedule_conv_NCHWc_int8(*args)
else:
_, _, kh, kw, _, _, = get_const_tuple(kernel.shape)
if kh == 1 and kw == 1:
conv2d_avx_1x1._schedule_conv_NCHWc(*args)
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Conv2D int8 schedule on x86"""
import tvm
from tvm import autotvm
from .. import generic, tag
from ..util import get_const_tuple
from ..nn.conv2d import conv2d_NCHWc_int8
from .. import nn
from .conv2d import _get_default_config
from . import conv2d_avx_1x1, conv2d_avx_common
@autotvm.register_topi_compute(conv2d_NCHWc_int8, 'cpu', 'direct')
def _declaration_conv_NCHWc_int8(cfg, data, kernel, strides,
padding, dilation, layout, out_layout, out_dtype):
n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
in_channel = ic_chunk * ic_bn
oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = \
get_const_tuple(kernel.shape)
num_filter = oc_chunk * oc_bn
# If config is not set, we can reuse the default config for NCHW.
if cfg.is_fallback:
_get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width),
dtype=kernel.dtype),
strides, padding, out_dtype)
return nn.conv2d_NCHWc_int8_compute(data,
kernel,
strides,
padding,
dilation,
layout,
out_layout,
out_dtype)
@autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc_int8, 'cpu', ['direct'])
def _schedule_conv2d_NCHWc_int8(cfg, outs):
"""Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'conv2d_NCHWc_int8' in op.tag:
conv_out = op.output(0)
kernel = conv_out.op.input_tensors[1]
data_vec = conv_out.op.input_tensors[0]
data = data_vec.op.input_tensors[0] \
if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \
else data_vec
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
args = [s, cfg, data_vec, conv_out, outs[0]]
target = tvm.target.current_target(allow_none=False)
# int8 conv kernel is 7-dim
_, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape)
if kh == 1 and kw == 1:
conv2d_avx_1x1._schedule_conv_NCHWc_int8(*args)
else:
conv2d_avx_common._schedule_conv_NCHWc_int8(*args)
scheduled_ops.append(op)
traverse(outs[0].op)
return s
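Taken together, the compute declaration and the schedule registered in this new file can be driven directly at the TOPI level. A rough sketch of that flow (the target string, shapes, and layouts are illustrative, and it relies on the AutoTVM fallback configuration when no tuning logs are present):

import tvm
import topi

data = tvm.placeholder((1, 4, 56, 56, 8), name="data", dtype="uint8")
kernel = tvm.placeholder((4, 4, 3, 3, 2, 16, 4), name="kernel", dtype="int8")

with tvm.target.create("llvm -mcpu=skylake-avx512"):
    conv = topi.nn.conv2d_NCHWc_int8(data, kernel, (1, 1), (1, 1), (1, 1),
                                     "NCHW8c", "NCHW16c", out_dtype="int32")
    s = topi.generic.schedule_conv2d_NCHWc_int8([conv])

print(tvm.lower(s, [data, kernel, conv], simple_mode=True))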
@autotvm.register_topi_schedule(generic.schedule_conv2d_nhwc_pack, 'cpu', ['direct'])
def schedule_conv2d_nhwc_pack(cfg, outs):
"""Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs])
output_op = outs[0].op
scheduled_ops = []
def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
else: # inject custom schedule
if len(op.axis) == 4: # schedule bias + bn + relu
n, h, w, c = op.axis
fused = s[op].fuse(n, h, w)
s[op].parallel(fused)
s[op].vectorize(c)
for tensor in op.input_tensors:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'conv2d_nhwc_pack_int8' in op.tag:
conv_out = op.output(0)
kernel = conv_out.op.input_tensors[1]
data_vec = conv_out.op.input_tensors[0]
data = data_vec.op.input_tensors[0] \
if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \
else data_vec
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
args = [s, cfg, data_vec, conv_out, outs[0]]
if data.dtype == 'uint8':
kh, kw, _, _, _ = get_const_tuple(kernel.shape)
if kh == 1 and kw == 1:
conv2d_avx_1x1._schedule_conv_nhwc_pack_int8(*args)
else:
raise ValueError("Only support 1x1 kernel with "
"schedule_conv2d_nhwc_pack.")
else:
raise ValueError("Not support this data type {} with "
"schedule_conv2d_nhwc_pack. Only support int8".format(data.dtype))
scheduled_ops.append(op)
traverse(output_op)
return s