Commit d53affdc by Yizhi Liu, committed by Tianqi Chen

AVX schedule for conv_NCHW[x]c (#1143)

* add conv2d_NCHWc compute template

* add conv_NCHWc compute decl and schedules

* allow default avx schedule

* fix lint

* remove unused schedule

* remove import nnvm.reg

* pass schedule object to compute and schedule
parent 3d74b48f
@@ -112,6 +112,23 @@ def conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype=None):
        raise ValueError("not support this layout {} yet".format(layout))
@tvm.target.generic_func
def conv2d_alter_layout(attrs, inputs, tinfos):
    """Change Conv2D layout.

    Parameters
    ----------
    attrs : nnvm.top.AttrDict
        Attributes of current convolution
    inputs : nnvm.symbol
        Grouped input symbols
    tinfos : list
        Input shape and dtype
    """
    # do not change the layout by default
    return None
def _get_workload(data, kernel, stride, padding, out_dtype):
    """ Get the workload structure. """
    _, CI, IH, IW = [x.value for x in data.shape]
@@ -425,6 +442,44 @@ def conv2d_nhwc(Input, Filter, stride, padding, out_dtype='float32'):
        name="Conv2dOutput", tag="conv2d_nhwc")
    return Output
@tvm.target.generic_func
def conv2d_NCHWc(data, kernel, num_filter, kernel_size, stride, padding, out_dtype='float32'):
    """Conv2D operator for nChw[x]c layout.

    Parameters
    ----------
    data : tvm.Tensor
        5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
    kernel : tvm.Tensor
        6-D with shape
        [num_filter_chunk, in_channel_chunk, filter_height, filter_width,
         in_channel_block, num_filter_block]
    num_filter : int
        number of filters, i.e., output channel size
    kernel_size : tuple of two ints
        [kernel_height, kernel_width]
    stride : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]
    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]
    out_dtype : str
        output data type

    Returns
    -------
    output : tvm.Tensor
        5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
    """
    # search platform specific declaration first
    # default declaration
    raise ValueError("missing register for topi.nn.conv2d_NCHWc")
# map from schedule type to declaration function
_SCH_TO_DECL_FUNC = {
    SpatialPack: _spatial_pack,
...
@@ -4,27 +4,50 @@ import tvm
from .. import generic, tag
from .. import nn
from ..nn.util import infer_pad, infer_stride
from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, \
    _get_workload, _get_schedule, Workload
from . import conv2d_avx_1x1, conv2d_avx_common
from .conv2d_avx_common import AVXConvCommonFwd
from .conv2d_avx_1x1 import AVXConv1x1Fwd
@_get_schedule.register("cpu")
def _get_schedule_conv(wkl):
    _WORKLOADS_AVX = [
        # workloads of resnet18_v1 on imagenet
        Workload('float32', 'float32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
        Workload('float32', 'float32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
        Workload('float32', 'float32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
        Workload('float32', 'float32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
        Workload('float32', 'float32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
        Workload('float32', 'float32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
        Workload('float32', 'float32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
        Workload('float32', 'float32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
        Workload('float32', 'float32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
        Workload('float32', 'float32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
        Workload('float32', 'float32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
        Workload('float32', 'float32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
        # workloads of resnet34_v1 on imagenet, no extra workload required
        # workloads of resnet50_v1 on imagenet
        Workload('float32', 'float32', 56, 56, 64, 256, 1, 1, 0, 0, 1, 1),
        Workload('float32', 'float32', 56, 56, 256, 64, 1, 1, 0, 0, 1, 1),
        Workload('float32', 'float32', 56, 56, 256, 128, 1, 1, 0, 0, 2, 2),
        Workload('float32', 'float32', 28, 28, 128, 512, 1, 1, 0, 0, 1, 1),
        Workload('float32', 'float32', 56, 56, 256, 512, 1, 1, 0, 0, 2, 2),
        Workload('float32', 'float32', 28, 28, 512, 128, 1, 1, 0, 0, 1, 1),
        Workload('float32', 'float32', 28, 28, 512, 256, 1, 1, 0, 0, 2, 2),
        Workload('float32', 'float32', 14, 14, 256, 1024, 1, 1, 0, 0, 1, 1),
        Workload('float32', 'float32', 28, 28, 512, 1024, 1, 1, 0, 0, 2, 2),
        Workload('float32', 'float32', 14, 14, 1024, 256, 1, 1, 0, 0, 1, 1),
        Workload('float32', 'float32', 14, 14, 1024, 512, 1, 1, 0, 0, 2, 2),
        Workload('float32', 'float32', 7, 7, 512, 2048, 1, 1, 0, 0, 1, 1),
        Workload('float32', 'float32', 14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2),
        Workload('float32', 'float32', 7, 7, 2048, 512, 1, 1, 0, 0, 1, 1),
        # workloads of resnet101_v1 on imagenet, no extra workload required
        # workloads of resnet152_v1 on imagenet, no extra workload required
        # workloads of resnet18_v2 on imagenet, no extra workload required
        # workloads of resnet34_v2 on imagenet, no extra workload required
    ]
    fp32_vec_len = 8
    target = tvm.target.current_target(allow_none=False)
@@ -32,43 +55,61 @@ def _get_schedule_conv(wkl):
        if opt == '-mcpu=skylake-avx512':
            fp32_vec_len = 16
    _SCHEDULES_AVX = [
        # workloads of resnet18_v1 on imagenet
        AVXConvCommonFwd(3, fp32_vec_len, 28, False),
        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 28),
        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 28),
        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 14, False),
        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 14, True),
        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 7, True),
        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 7),
        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 7, True),
        # workloads of resnet34_v1 on imagenet, no extra workload required
        # workloads of resnet50_v1 on imagenet
        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
        # workloads of resnet101_v1 on imagenet, no extra workload required
        # workloads of resnet152_v1 on imagenet, no extra workload required
        # workloads of resnet18_v2 on imagenet, no extra workload required
        # workloads of resnet34_v2 on imagenet, no extra workload required
    ]

    if wkl not in _WORKLOADS_AVX:
        if wkl.hkernel == 1 and wkl.wkernel == 1:
            return conv2d_avx_1x1._get_default_schedule(wkl, fp32_vec_len)
        return conv2d_avx_common._get_default_schedule(wkl, fp32_vec_len)
    idx = _WORKLOADS_AVX.index(wkl)
    sch = _SCHEDULES_AVX[idx]
    return sch
@conv2d.register("cpu")
def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
    _AVX_SCH_TO_DECL_FUNC = {
        AVXConvCommonFwd: conv2d_avx_common._declaration_conv,
        AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv
    }
    out_dtype = data.dtype if out_dtype is None else out_dtype
    target = tvm.target.current_target(allow_none=False)
    wkl = _get_workload(data, kernel, stride, padding, out_dtype)
    if 'avx' in str(target) and layout == 'NCHW':
        sch = _get_schedule(wkl)
        return _AVX_SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, layout, out_dtype)
    elif layout == 'NCHW':
@@ -81,9 +122,63 @@ def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
        raise ValueError("not support this layout {} yet".format(layout))
@conv2d_alter_layout.register("cpu")
def _alter_conv2d_layout(attrs, inputs, tinfos):
    import nnvm.symbol as sym
    copy_inputs = [s for s in inputs]
    new_attrs = {k: attrs[k] for k in attrs.keys()}
    # only optimize for NCHW, groups=1 conv
    if attrs['layout'] != 'NCHW' or attrs.get_int("groups") != 1:
        return None

    data = tinfos[0]
    kernel = tinfos[1]

    import ast
    padding = ast.literal_eval(attrs['padding'])
    stride = ast.literal_eval(attrs['strides'])

    wkl = _get_workload(data, kernel, stride, padding, data.dtype)
    sch = _get_schedule_conv(wkl)
    is_kernel_1x1 = isinstance(sch, AVXConv1x1Fwd)
    ic_bn, oc_bn = sch.ic_bn, sch.oc_bn

    new_attrs['layout'] = 'NCHW%dc' % ic_bn
    new_attrs['out_layout'] = 'NCHW%dc' % oc_bn

    if is_kernel_1x1:
        # (oc, ic, h, w) -> (OC, IC, ic, oc, h, w)
        new_attrs['kernel_layout'] = 'OI%di%doHW' % (ic_bn, oc_bn)
    else:
        # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
        new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)

    return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
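Editor's note: to see what the new kernel_layout string implies, here is a hedged NumPy sketch (not from the commit; the filter shape and block sizes are assumed for illustration) of the 'OIHW%di%do' transform applied to a plain OIHW weight:

import numpy as np

oc, ic, kh, kw = 64, 64, 3, 3      # hypothetical filter shape
ic_bn, oc_bn = 8, 8                # hypothetical block sizes
w = np.random.rand(oc, ic, kh, kw).astype('float32')

# 'OIHW8i8o': (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
w_blocked = w.reshape(oc // oc_bn, oc_bn, ic // ic_bn, ic_bn, kh, kw)
w_blocked = w_blocked.transpose(0, 2, 4, 5, 3, 1)
print(w_blocked.shape)  # (8, 8, 3, 3, 8, 8)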
@conv2d_NCHWc.register("cpu")
def _declaration_conv_NCHWc(data, kernel, num_filter, kernel_size, stride, padding, out_dtype):
    _AVX_SCH_TO_DECL_FUNC = {
        AVXConvCommonFwd: conv2d_avx_common._declaration_conv_NCHWc,
        AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv_NCHWc
    }
    n, ic_chunk, h, w, ic_block = [x.value for x in data.shape]
    ic = ic_chunk * ic_block
    kh, kw = kernel_size
    wkl = _get_workload(tvm.placeholder((n, ic, h, w), dtype=out_dtype),
                        tvm.placeholder((num_filter, ic, kh, kw), dtype=out_dtype),
                        stride, padding, out_dtype)
    sch = _get_schedule(wkl)
    return _AVX_SCH_TO_DECL_FUNC[type(sch)](wkl, sch, data, kernel)
@generic.schedule_conv2d_nchw.register(["cpu"])
def schedule_conv2d(outs):
    """Create schedule for tensors"""
    _AVX_SCH_TO_SCH_FUNC = {
        AVXConvCommonFwd: conv2d_avx_common._schedule_conv,
        AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv
    }
    s = tvm.create_schedule([x.op for x in outs])
    target = tvm.target.current_target(allow_none=False)
@@ -213,3 +308,49 @@ def schedule_conv2d_nhwc(outs):
    traverse(output_op)
    return s
@generic.schedule_conv2d_NCHWc.register(["cpu"])
def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, outs):
    """Create schedule for tensors"""
    _AVX_SCH_TO_SCH_FUNC = {
        AVXConvCommonFwd: conv2d_avx_common._schedule_conv_NCHWc,
        AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc
    }
    s = tvm.create_schedule([x.op for x in outs])

    def traverse(op):
        """Traverse operators from computation graph"""
        # inline all one-to-one-mapping operators except the last stage (output)
        if tag.is_broadcast(op.tag):
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
                if tensor.op.input_tensors:
                    traverse(tensor.op)

        if 'conv2d_NCHWc' in op.tag:
            conv_out = op.output(0)
            kernel = conv_out.op.input_tensors[1]
            data_vec = conv_out.op.input_tensors[0]
            data = data_vec.op.input_tensors[0] \
                if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \
                else data_vec
            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
                data_pad = data
                data = data_pad.op.input_tensors[0]

            n, ic_chunk, h, w, ic_block = [x.value for x in data.shape]
            ic = ic_chunk * ic_block
            original_data = tvm.placeholder((n, ic, h, w), dtype=conv_out.dtype)

            kh, kw = kernel_size
            original_kernel = tvm.placeholder((num_filter, ic, kh, kw), dtype=conv_out.dtype)

            wkl = _get_workload(original_data, original_kernel, stride, padding, conv_out.dtype)
            sch = _get_schedule(wkl)
            _AVX_SCH_TO_SCH_FUNC[type(sch)](s, wkl, sch, data_vec,
                                            kernel, conv_out, outs[0])

    traverse(outs[0].op)
    return s
# pylint: disable=invalid-name,unused-variable,unused-argument
"""1x1 Conv2D schedule for Intel CPU"""
from __future__ import absolute_import as _abs
from collections import namedtuple
@@ -11,6 +11,34 @@ from ..nn.pad import pad
AVXConv1x1Fwd = namedtuple('AVXConv1x1Fwd', ['ic_bn', 'oc_bn', 'oh_factor', 'ow_factor'])
def _get_default_schedule(wkl, simd_width):
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride
    out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1

    oc_bn = 1
    for bn in range(simd_width, 0, -1):
        if wkl.out_filter % bn == 0:
            oc_bn = bn
            break

    ic_bn = 1
    for bn in range(oc_bn, 0, -1):
        if wkl.in_filter % bn == 0:
            ic_bn = bn
            break

    for ow_factor in range(out_width, 0, -1):
        if out_width % ow_factor == 0:
            for oh_factor in range(out_height, 0, -1):
                if out_height % oh_factor == 0 and ow_factor * oh_factor < 32:
                    return AVXConv1x1Fwd(ic_bn, oc_bn, oh_factor, ow_factor)
    raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
    assert layout == 'NCHW', "only support NCHW convolution for AVX"
    wkl = _get_workload(data, kernel, stride, padding, out_dtype)
@@ -124,3 +152,80 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou
    s[O].parallel(parallel_axis)
    return s
def _declaration_conv_NCHWc(wkl, sch, data, kernel):
    out_dtype = wkl.out_dtype
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride

    batch_size = data.shape[0]
    out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1

    DOPAD = (HPAD != 0 and WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
    else:
        data_pad = data

    oshape = (batch_size, wkl.out_filter // sch.oc_bn, out_height, out_width, sch.oc_bn)

    ic = tvm.reduce_axis((0, wkl.in_filter), name='ic')
    conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                       tvm.sum(data_pad[n, ic // sch.ic_bn, oh * HSTR, ow * WSTR,
                                        ic % sch.ic_bn].astype(out_dtype) *
                               kernel[oc_chunk, ic // sch.ic_bn, ic % sch.ic_bn, oc_block, 0, 0],
                               axis=[ic]),
                       name='conv2d_NCHWc', tag='conv2d_NCHWc')
    return conv
def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last):
    # schedule data
    A = data
    if isinstance(s[A].op, tvm.tensor.ComputeOp):
        batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
        parallel_axis = s[A].fuse(ic_chunk, ih)
        s[A].parallel(parallel_axis)

    C, O = conv_out, last
    CC = s.cache_write(C, 'global')

    batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
    oh_outer, oh_inner = s[C].split(oh, factor=sch.oh_factor)
    ow_outer, ow_inner = s[C].split(ow, factor=sch.ow_factor)
    s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
    s[C].vectorize(oc_block)

    parallel_axis = s[C].fuse(oc_chunk, oh_outer)
    s[CC].compute_at(s[C], parallel_axis)
    if C == O:
        s[C].parallel(parallel_axis)

    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
    ic, = s[CC].op.reduce_axis

    ic_chunk, ic_block = s[CC].split(ic, factor=sch.ic_bn)

    oh_outer, oh_inner = s[CC].split(oh, factor=sch.oh_factor)
    ow_outer, ow_inner = s[CC].split(ow, factor=sch.ow_factor)

    s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_chunk, ic_block, oh_inner, ow_inner, oc_block)
    s[CC].fuse(oc_chunk, oh_outer)

    s[CC].vectorize(oc_block)

    s[CC].unroll(ow_inner)
    s[CC].unroll(oh_inner)

    if C != O:
        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
        oh_outer, oh_inner = s[O].split(oh, factor=sch.oh_factor)
        ow_outer, ow_inner = s[O].split(ow, factor=sch.ow_factor)
        s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)

        parallel_axis = s[O].fuse(oc_chunk, oh_outer)
        s[C].compute_at(s[O], parallel_axis)
        s[O].vectorize(oc_block)
        s[O].parallel(parallel_axis)

    return s
# pylint: disable=invalid-name,unused-variable,unused-argument
"""Conv2D schedule for Intel CPU"""
from __future__ import absolute_import as _abs
from collections import namedtuple
@@ -11,6 +11,34 @@ from ..nn.pad import pad
AVXConvCommonFwd = namedtuple('AVXConvCommonFwd', ['ic_bn', 'oc_bn', 'reg_n', 'unroll_kw'])
def _get_default_schedule(wkl, simd_width):
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride
    out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1

    oc_bn = 1
    for bn in range(simd_width, 0, -1):
        if wkl.out_filter % bn == 0:
            oc_bn = bn
            break

    ic_bn = 1
    for bn in range(oc_bn, 0, -1):
        if wkl.in_filter % bn == 0:
            ic_bn = bn
            break

    reg_n = 1
    for n in range(31, 0, -1):
        if out_width % n == 0:
            reg_n = n
            break

    return AVXConvCommonFwd(ic_bn, oc_bn, reg_n, False)
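Editor's note: a quick trace of this fallback (hedged; same Workload field-order assumption as the 1x1 example above, with hypothetical sizes):

# Hypothetical 3x3 workload: 56x56 input, 48 -> 96 channels, pad 1, stride 1.
wkl = Workload('float32', 'float32', 56, 56, 48, 96, 3, 3, 1, 1, 1, 1)
# out_width = (56 + 2 - 3) // 1 + 1 = 56; the largest n <= 31 dividing 56 is 28,
# so reg_n = 28, while oc_bn = ic_bn = 8 as before.
print(_get_default_schedule(wkl, simd_width=8))
# AVXConvCommonFwd(ic_bn=8, oc_bn=8, reg_n=28, unroll_kw=False)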
def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
    out_dtype = data.dtype if out_dtype is None else out_dtype
    assert layout == 'NCHW', "only support NCHW convolution for AVX"
@@ -141,3 +169,83 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou
    s[O].parallel(parallel_axis)
    return s
def _declaration_conv_NCHWc(wkl, sch, data, kernel):
    out_dtype = wkl.out_dtype
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride

    batch_size = data.shape[0]
    out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1

    # pack data
    DOPAD = (HPAD != 0 and WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
    else:
        data_pad = data

    # convolution
    oshape = (batch_size, wkl.out_filter // sch.oc_bn, out_height, out_width, sch.oc_bn)

    ic = tvm.reduce_axis((0, wkl.in_filter), name='ic')
    kh = tvm.reduce_axis((0, wkl.hkernel), name='kh')
    kw = tvm.reduce_axis((0, wkl.wkernel), name='kw')

    conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                       tvm.sum(data_pad[n, ic // sch.ic_bn, oh * HSTR + kh, ow * WSTR + kw,
                                        ic % sch.ic_bn].astype(out_dtype) *
                               kernel[oc_chunk, ic // sch.ic_bn, kh, kw, ic % sch.ic_bn, oc_block],
                               axis=[ic, kh, kw]),
                       name='conv2d_NCHWc', tag="conv2d_NCHWc")
    return conv
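Editor's note: to make the compute rule above concrete, here is a hedged NumPy reference for NCHW[x]c convolution with the 6-D OIHW{ic}i{oc}o kernel layout (a verification sketch only, not part of the commit; conv2d_nchwc_ref is an illustrative name, and padding is assumed to be applied to data beforehand):

import numpy as np

def conv2d_nchwc_ref(data, kernel, hstr, wstr):
    # data: (n, IC_chunk, H, W, ic_block)
    # kernel: (OC_chunk, IC_chunk, kh, kw, ic_block, oc_block)
    n, _, ih, iw, _ = data.shape
    occ, _, kh, kw, _, ocb = kernel.shape
    oh = (ih - kh) // hstr + 1
    ow = (iw - kw) // wstr + 1
    out = np.zeros((n, occ, oh, ow, ocb), dtype=data.dtype)
    for b in range(n):
        for oc in range(occ):
            for y in range(oh):
                for x in range(ow):
                    for r in range(kh):
                        for s in range(kw):
                            patch = data[b, :, y * hstr + r, x * wstr + s, :]  # (IC_chunk, ic_block)
                            w = kernel[oc, :, r, s, :, :]                      # (IC_chunk, ic_block, oc_block)
                            # contract over the full input channel = chunk x block,
                            # matching the flattened `ic` reduce axis above
                            out[b, oc, y, x, :] += np.einsum('ci,cio->o', patch, w)
    return out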
def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last):
    # schedule data
    A = data
    if isinstance(s[A].op, tvm.tensor.ComputeOp):
        batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
        parallel_axis = s[A].fuse(ic_chunk, ih)
        s[A].parallel(parallel_axis)

    # schedule 5-D NCHW[x]c conv
    C, O = conv_out, last
    CC = s.cache_write(C, 'global')

    _, oc_chunk, oh, ow, oc_block = s[C].op.axis
    ow_chunk, ow_block = s[C].split(ow, factor=sch.reg_n)
    s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
    parallel_axis = s[C].fuse(oc_chunk, oh)
    s[C].vectorize(oc_block)
    if C == O:
        s[C].parallel(parallel_axis)

    s[CC].compute_at(s[C], ow_chunk)
    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
    ic, kh, kw = s[CC].op.reduce_axis

    ow_chunk, ow_block = s[CC].split(ow, factor=sch.reg_n)
    ic_chunk, ic_block = s[CC].split(ic, factor=sch.ic_bn)

    if sch.unroll_kw:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, ic_block, kw, ow_block, oc_block)
        s[CC].unroll(kw)
    else:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, kw, ic_block, ow_block, oc_block)

    s[CC].vectorize(oc_block)
    s[CC].unroll(ow_block)

    if C != O:
        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
        ow_chunk, ow_block = s[O].split(ow, factor=sch.reg_n)
        s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
        parallel_axis = s[O].fuse(oc_chunk, oh)
        s[C].compute_at(s[O], parallel_axis)
        s[O].vectorize(oc_block)
        s[O].parallel(parallel_axis)

    return s