Unverified Commit 8d945872 by Alex Gladkov Committed by GitHub

Optimize x86 conv3d_ndhwc using data packing approach. (#4866)

Add tuneable conv3d_ndhwc schedule
parent 70c63829
......@@ -133,6 +133,7 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,
tvm.relay.op.nn.batch_matmul: [topi.nn.batch_matmul],
tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
tvm.relay.op.nn.conv1d_transpose: [topi.nn.conv1d_transpose_ncw],
tvm.relay.op.nn.conv3d: [topi.nn.conv3d],
}
topi_funcs = []
......
......@@ -94,6 +94,7 @@ class TaskExtractEnv:
topi.nn.bitserial_dense: "topi_nn_bitserial_dense",
topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw",
topi.nn.conv1d_transpose_ncw: "topi_nn_conv1d_transpose_ncw",
topi.nn.conv3d: "topi_nn_conv3d",
}
self.topi_to_schedule = {
......@@ -112,6 +113,7 @@ class TaskExtractEnv:
topi.nn.bitserial_dense: [topi.generic.schedule_bitserial_dense],
topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
topi.nn.conv1d_transpose_ncw: [topi.generic.schedule_conv1d_transpose_ncw],
topi.nn.conv3d: [topi.generic.schedule_conv3d_ndhwc],
}
# function reflection for tracing
......@@ -129,6 +131,7 @@ class TaskExtractEnv:
topi.nn.bitserial_dense: lambda x: setattr(topi.nn, 'bitserial_dense', x),
topi.nn.deformable_conv2d_nchw: lambda x: setattr(topi.nn, 'deformable_conv2d_nchw', x),
topi.nn.conv1d_transpose_ncw: lambda x: setattr(topi.nn, 'conv1d_transpose_ncw', x),
topi.nn.conv3d: lambda x: setattr(topi.nn, 'conv3d', x),
}
self.allow_duplicate = allow_duplicate
......@@ -231,6 +234,15 @@ class TaskExtractEnv:
s = topi.generic.schedule_conv1d_transpose_ncw([C])
return s, [A, W, C]
@register("topi_nn_conv3d")
def _topi_nn_conv3d(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
C = topi.nn.conv3d(*args, **kwargs)
s = topi.generic.schedule_conv3d_ndhwc([C])
return s, [A, W, C]
@register("topi_nn_dense")
def _topi_nn_dense(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
......
......@@ -47,6 +47,42 @@ def infer_pad(data, data_pad):
wpad = (TW - IW) // 2
return get_const_int(hpad), get_const_int(wpad)
def infer_pad3d(data, data_pad, layout):
"""Infer the padding from stages in reverse.
Parameters
----------
data : Tensor
data stage.
data_pad : Tensor
pad stage.
Returns
-------
dpad : int
padding depth
hpad : int
padding height
wpad : int
padding width
"""
if data_pad is None:
return 0, 0, 0
if layout == "NDHWC":
_, ID, IH, IW, _ = data.shape
_, TD, TH, TW, _ = data_pad.shape
elif layout == "NCDHW":
_, _, ID, IH, IW = data.shape
_, _, TD, TH, TW = data_pad.shape
else:
raise ValueError("Layout {} is not supported".format(layout))
dpad = (TD - ID)
hpad = (TH - IH)
wpad = (TW - IW)
return get_const_int(dpad), get_const_int(hpad), get_const_int(wpad)
def infer_stride(data, kernel, out):
"""Infer the stride from stages in reverse.
......
......@@ -17,66 +17,359 @@
# pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=unused-argument, redefined-builtin, no-else-return
"""Conv3D operators"""
from collections import namedtuple
import tvm
from .. import generic, tag
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from .. import generic
from ..util import traverse_inline
from ..nn.conv3d import conv3d, conv3d_ncdhw
from ..nn.util import get_pad_tuple3d, infer_pad3d
from ..nn.pad import pad
from ..util import get_const_tuple, simplify, get_const_int
from .util import get_fp32_len
@generic.schedule_conv3d_ndhwc.register("cpu")
def schedule_conv3d_ndhwc(outs):
"""TOPI schedule callback for conv3d
Workload3D = namedtuple('Workload',
['in_dtype', 'out_dtype', 'depth', 'height', 'width',
'in_filter', 'groups', 'out_filter', 'dkernel',
'hkernel', 'wkernel', 'dpad', 'hpad', 'wpad',
'dstride', 'hstride', 'wstride'])
@autotvm.register_topi_compute(conv3d, 'cpu', ['direct'])
def _declaration_conv3d(cfg, data, kernel, strides, padding, dilation,
layout, out_dtype):
"""3D convolution forward operator.
Parameters
----------
input : tvm.Tensor
5-D input data with shapes:
[batch, in_channel, in_depth, in_height, in_width] for NCDHW layout
[batch, in_depth, in_height, in_width, in_channel] for NDHWC layout
filter : tvm.Tensor
5-D filter with shape [kernel_depth, kernel_height, kernel_width, in_channels, out_channels]
strides : int or a list/tuple of three ints
stride size, or [stride_depth, stride_height, stride_width]
padding : int or a list/tuple of three ints
padding size, or [pad_depth, pad_height, pad_width]
dilation: int or a list/tuple of three ints
dilation size, or [dilation_depth, dilation_height, dilation_width]
layout : str
layout of data
Returns
-------
output : tvm.Tensor
5-D with shape [batch, out_depth, out_height, out_width, out_channel] for NDHWC layout
5-D with shape [batch, out_channel, out_depth, out_height, out_width] for NCDHW layout
"""
out_dtype = data.dtype if out_dtype is None else out_dtype
strides = strides if isinstance(strides, (tuple, list)) else (strides, strides, strides)
dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation, dilation)
if layout == 'NDHWC':
_create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout)
if cfg.is_fallback:
_get_default_config(cfg, data, kernel, strides, padding, out_dtype, layout)
return _conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
elif layout == 'NCDHW':
return conv3d_ncdhw(data, kernel, strides, padding, dilation, out_dtype)
raise ValueError("Layout {} is not supported".format(layout))
@autotvm.register_topi_schedule(generic.schedule_conv3d_ndhwc, 'cpu', ['direct'])
def schedule_conv3d_ndhwc(cfg, outs):
"""TOPI schedule callback for conv3d
Parameters
----------
outs: Array of Tensor
The computation graph description of conv3d
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for conv3d.
"""
s = tvm.create_schedule([x.op for x in outs])
output_op = outs[0].op
def _traverse(op):
"""Traverse operators from computation graph"""
if op in s.outputs and tag.is_broadcast(op.tag) and len(op.axis) == 5:
# schedule bias + bn + relu
n, d, h, w, c = op.axis
fused = s[op].fuse(n, d, h, w)
s[op].parallel(fused)
s[op].vectorize(c)
if 'conv3d_ndhwc' in op.tag:
conv = op.output(0)
kernel = op.input_tensors[1]
# dilation stage
output = op.output(0)
conv_out = op.input_tensors[0]
kernel_vec = conv_out.op.input_tensors[1]
kernel = kernel_vec.op.input_tensors[0]
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
# padding stage
data = op.input_tensors[0]
data_vec = conv_out.op.input_tensors[0]
data = data_vec.op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
# fuse pad h and w
data_pad = data
data = data_pad.op.input_tensors[0]
_, _, h_pad, w_pad, _ = data_pad.op.axis
pad_fused = s[data_pad].fuse(h_pad, w_pad)
s[data_pad].parallel(pad_fused)
# compute conv
C = conv
n, d, h, w, c = s[C].op.axis
s[C].vectorize(c)
if op != output_op: # fuse bias + bn + activation
_, _, _, _, c_out = output_op.axis
s[C].compute_at(s[output_op], c_out)
kd, kh, kw, i, o = get_const_tuple(kernel.shape)
args = [s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, outs[0]]
_schedule_conv3d_ndhwc(*args)
traverse_inline(s, outs[0].op, _traverse)
return s
def _conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
out_dtype = data.dtype if out_dtype is None else out_dtype
assert isinstance(dilation, int) or len(dilation) == 3
if isinstance(dilation, int):
dilation_d, dilation_h, dilation_w = (dilation, dilation, dilation)
else:
dilation_d, dilation_h, dilation_w = dilation
DSTR, HSTR, WSTR = strides
batch_size, in_depth, in_height, in_width, in_channel = get_const_tuple(data.shape)
kernel_depth, kernel_height, kernel_width, _, num_filter = get_const_tuple(kernel.shape)
dilated_kernel_d = (kernel_depth - 1) * dilation_d + 1
dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w))
pad_d = pad_front + pad_back
pad_h = pad_top + pad_down
pad_w = pad_left + pad_right
pad_depth = in_depth + pad_d
pad_height = in_height + pad_h
pad_width = in_width + pad_w
out_depth = simplify((in_depth + pad_d - dilated_kernel_d) // DSTR + 1)
out_height = simplify((in_height + pad_h - dilated_kernel_h) // HSTR + 1)
out_width = simplify((in_width + pad_w - dilated_kernel_w) // WSTR + 1)
# pack data
DOPAD = (pad_d != 0 or pad_h != 0 or pad_w != 0)
if DOPAD:
data_pad = pad(data, (0, pad_front, pad_top, pad_left, 0),
(0, pad_back, pad_down, pad_right, 0), name="data_pad")
else:
# fuse batch, depth, height axes
fused = s[C].fuse(n, d, h)
s[C].parallel(fused)
data_pad = data
# fetch schedule
ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
shape = (batch_size, in_channel // ic_bn, pad_depth, pad_height, ic_bn, pad_width)
data_vec = tvm.compute(shape,
lambda n, C, d, h, c, w: data_pad[n, d, h, w, C * ic_bn + c],
name='data_vec')
# pack kernel
shape = (num_filter//oc_bn, in_channel//ic_bn,
kernel_depth, kernel_height, kernel_width, ic_bn, oc_bn)
kernel_vec = tvm.compute(shape,
lambda CO, CI, d, h, w, ci, co:
kernel[d, h, w, CI * ic_bn + ci, CO * oc_bn + co],
name='kernel_vec')
# convolution
oshape = (batch_size, num_filter//oc_bn, out_depth, out_height, out_width, oc_bn)
unpack_shape = (batch_size, out_depth, out_height, out_width, num_filter)
ic = tvm.reduce_axis((0, in_channel), name='ic')
kh = tvm.reduce_axis((0, kernel_height), name='kh')
kw = tvm.reduce_axis((0, kernel_width), name='kw')
kd = tvm.reduce_axis((0, kernel_depth), name='kd')
idxmod = tvm.indexmod
idxdiv = tvm.indexdiv
conv = tvm.compute(oshape, lambda n, oc_chunk, od, oh, ow, oc_block:
tvm.sum(data_vec[n,
idxdiv(ic, ic_bn),
od*DSTR+kd*dilation_d,
oh*HSTR+kh*dilation_h,
idxmod(ic, ic_bn),
ow*WSTR+kw*dilation_w].astype(out_dtype) *
kernel_vec[oc_chunk, idxdiv(ic, ic_bn), kd, kh, kw,
idxmod(ic, ic_bn),
oc_block].astype(out_dtype),
axis=[kd, kh, kw, ic]), name='conv')
conv_unpacked = tvm.compute(unpack_shape,
lambda n, d, h, w, c: conv[n, idxdiv(c, oc_bn),
d, h, w,
idxmod(c, oc_bn)]
.astype(out_dtype),
name='output_unpack',
tag='conv3d_ndhwc')
return conv_unpacked
def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
"""Create schedule configuration from input arguments"""
dshape = get_const_tuple(data.shape)
kshape = get_const_tuple(kernel.shape)
if layout == 'NDHWC':
n, d, h, w, ic = dshape
kd, kh, kw, _, oc = kshape
else:
raise ValueError("Not support this layout {} with "
"schedule template.".format(layout))
# pad_front, pad_top, pad_left, pad_back, pad_down(bottom), pad_right
pf, pt, pl, pb, pd, pr = get_pad_tuple3d(padding, (kd, kh, kw))
sd, sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides, strides)
od = (d - kd + pf + pb) // sd + 1
oh = (h - kh + pt + pd) // sh + 1
ow = (w - kw + pl + pr) // sw + 1
# Create schedule config
cfg.define_split("tile_ic", ic, num_outputs=2)
cfg.define_split("tile_oc", oc, num_outputs=2)
cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 8)
cfg.define_knob("unroll_kw", [True, False])
def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, layout):
"""
Get default schedule config for the workload
"""
if layout != 'NDHWC':
raise ValueError("Layout {} is not supported".format(layout))
static_data_shape = []
for dim in get_const_tuple(data.shape):
if isinstance(dim, tvm.expr.Var):
static_data_shape.append(1)
else:
static_data_shape.append(dim)
data = tvm.placeholder(static_data_shape, dtype=data.dtype)
wkl = _get_conv3d_workload(data, kernel, strides, padding, out_dtype, layout)
_fallback_schedule(cfg, wkl)
def _get_conv3d_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'):
""" Get the workload structure. """
if data_layout == 'NCDHW':
_, CI, ID, IH, IW = get_const_tuple(data.shape)
CIG, CO, KD, KH, KW = get_const_tuple(kernel.shape)
elif data_layout == 'NDHWC':
_, ID, IH, IW, CI = get_const_tuple(data.shape)
KD, KH, KW, CIG, CO = get_const_tuple(kernel.shape)
else:
raise ValueError("not support this layout {} yet".format(data_layout))
pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
padding, (get_const_int(KD), get_const_int(KH), get_const_int(KW)))
DPAD = pad_front + pad_back
HPAD = pad_top + pad_down
WPAD = pad_left + pad_right
GRPS = CI // CIG
if isinstance(stride, (tuple, list)):
DSTR, HSTR, WSTR = stride
else:
DSTR, HSTR, WSTR = stride, stride, stride
assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
"Do not support inputs with different data types now. ' \
'{} vs. {}".format(data.dtype, kernel.dtype)
return Workload3D(data.dtype, out_dtype, ID, IH, IW, CI, GRPS, CO, KD, KH, KW,
DPAD, HPAD, WPAD, DSTR, HSTR, WSTR)
def _fallback_schedule(cfg, wkl):
simd_width = get_fp32_len()
DPAD, HPAD, WPAD = wkl.dpad, wkl.hpad, wkl.wpad
DSTR, HSTR, WSTR = wkl.dstride, wkl.hstride, wkl.wstride
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
oc_bn = 1
for bn in range(simd_width, 0, -1):
if wkl.out_filter % bn == 0:
oc_bn = bn
break
ic_bn = 1
for bn in range(oc_bn, 0, -1):
if wkl.in_filter % bn == 0:
ic_bn = bn
break
reg_n = 1
for n in range(7, 0, -1):
if out_width % n == 0:
reg_n = n
break
cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
cfg["unroll_kw"] = OtherOptionEntity(False)
def _schedule_conv3d_ndhwc(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last):
# fetch schedule
ic_bn, oc_bn, reg_n, unroll_kw = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
cfg["tile_ow"].size[-1], cfg["unroll_kw"].val)
# get padding size
padding = infer_pad3d(data, data_pad, "NDHWC")
DPAD, HPAD, WPAD = padding
DOPAD = (DPAD != 0 or HPAD != 0 or WPAD != 0)
A, W = data, kernel_vec
A0, A1 = data_pad, data_vec
# schedule data
if DOPAD:
s[A0].compute_inline()
batch, ic_chunk, idd, ih, ic_block, iw = s[A1].op.axis
parallel_axis = s[A1].fuse(batch, ic_chunk, idd, ih)
s[A1].parallel(parallel_axis)
# schedule kernel pack
oc_chunk, ic_chunk, od, oh, ow, ic_block, oc_block = s[W].op.axis
s[W].reorder(oc_chunk, od, oh, ic_chunk, ow, ic_block, oc_block)
if oc_bn > 1:
s[W].vectorize(oc_block)
parallel_axis = s[W].fuse(oc_chunk, od, oh)
s[W].parallel(parallel_axis)
# schedule conv
C, O0, O = conv_out, output, last
CC = s.cache_write(C, 'global')
_, oc_chunk, od, oh, ow, oc_block = s[C].op.axis
ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
s[C].reorder(oc_chunk, od, oh, ow_chunk, ow_block, oc_block)
s[C].fuse(oc_chunk, od, oh)
s[C].vectorize(oc_block)
s[CC].compute_at(s[C], ow_chunk)
_, oc_chunk, od, oh, ow, oc_block = s[CC].op.axis
kd, kh, kw, ic = s[CC].op.reduce_axis
ow_chunk, ow_block = s[CC].split(ow, factor=reg_n)
ic_chunk, ic_block = s[CC].split(ic, factor=ic_bn)
if unroll_kw:
s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kd, kh, ic_block, kw, ow_block, oc_block)
s[CC].unroll(kw)
else:
s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kd, kh, kw, ic_block, ow_block, oc_block)
s[CC].fuse(oc_chunk, od, oh)
s[CC].vectorize(oc_block)
s[CC].unroll(ow_block)
if O0 != O:
s[O0].compute_inline()
traverse_inline(s, output_op, _traverse)
# unpacking
batch, od, oh, ow, oc = s[O].op.axis
ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
s[O].reorder(oc_chunk, od, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[O].fuse(batch, oc_chunk, od, oh)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
return s
......@@ -36,7 +36,6 @@ def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride,
A = tvm.placeholder((batch, in_depth, in_height, in_width, in_channel), name='A')
W = tvm.placeholder((kernel_depth, kernel_height, kernel_width, in_channel, num_filter), name='W')
B = topi.nn.conv3d_ndhwc(A, W, stride, padding, dilation)
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
......@@ -57,6 +56,7 @@ def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride,
return
print("Running on target: %s" % device)
with tvm.target.create(device):
B = topi.nn.conv3d(A, W, stride, padding, dilation, layout="NDHWC")
s = topi.generic.schedule_conv3d_ndhwc([B])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment