Unverified commit d56829ea by Matthew Brookhart, committed by GitHub

Conv3D ONNX support and conv3d_ncdhw x86 schedules (#4949)

* Support 3d Convolution with the ONNX frontend

* add unit tests for conv3d in onnx frontend

respond to PR formatting requests

add x86 schedules to conv3d ncdhw test

fix a doc string format issue

refactor for changed upstream API

* first attempt at conv3d autotuning

add default schedule for conv3d_ncdhw

fill in autotvm integration

add a fallback for invalid schedules

fix fallback

fix reduction order to get SIMD working correctly
parent a6cb4b8d
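
Note (not part of the commit): the frontend change below means an ONNX Conv node with a 3-element kernel_shape now imports directly into Relay. A minimal end-to-end sketch, with illustrative tensor names and shapes:

    import onnx
    from onnx import helper, TensorProto
    from tvm import relay

    # Illustrative 5-D NCDHW input and a 3x3x3 kernel.
    x_shape, w_shape = (1, 1, 5, 5, 5), (1, 1, 3, 3, 3)
    node = helper.make_node("Conv", inputs=["x", "w"], outputs=["y"],
                            kernel_shape=[3, 3, 3], pads=[1, 1, 1, 1, 1, 1],
                            strides=[1, 1, 1], dilations=[1, 1, 1])
    graph = helper.make_graph(
        [node], "conv3d_example",
        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape),
                helper.make_tensor_value_info("w", TensorProto.FLOAT, w_shape)],
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, x_shape)])
    model = helper.make_model(graph)

    # After this commit the converter maps the 3-D kernel_shape to relay's nn.conv3d.
    mod, params = relay.frontend.from_onnx(model, shape={"x": x_shape, "w": w_shape})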
@@ -91,16 +91,18 @@ def get_numpy(tensor_proto):
     return to_array(tensor_proto)

-def dimension_picker(prefix, surfix=''):
+def dimension_picker(prefix, suffix=''):
     """Check that dimensions are supported."""
     def _impl(attr):
         kernel = attr['kernel_shape']
         if len(kernel) == 1:
-            return prefix + '1d' + surfix
+            return prefix + '1d' + suffix
         if len(kernel) == 2:
-            return prefix + '2d' + surfix
-        msg = 'Only 1D and 2D kernels are supported for operator {}.'
-        op_name = prefix + '1d/2d'
+            return prefix + '2d' + suffix
+        if len(kernel) == 3:
+            return prefix + '3d' + suffix
+        msg = 'Only 1D, 2D, and 3D kernels are supported for operator {}.'
+        op_name = prefix + '1d/2d/3d'
         raise tvm.error.OpAttributeInvalid(msg.format(op_name))
     return _impl

@@ -155,11 +157,11 @@ def onnx_storage_order2layout(storage_order, dims=2):
 def dimension_constraint():
     def _dim_check(attrs):
-        if len(attrs['kernel_shape']) == 2 or len(attrs['kernel_shape']) == 1:
+        if len(attrs['kernel_shape']) in [1, 2, 3]:
             return True
         return False
-    return _dim_check, "Only 1d and 2d kernel supported."
+    return _dim_check, "Only 1d, 2d and 3d kernel supported."

 class OnnxOpConverter(object):
...
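
For context, here is how the updated helpers behave, using them exactly as defined in the hunk above (the attribute dicts below are stand-ins for what the converter receives from an ONNX node):

    picker = dimension_picker('conv')
    assert picker({'kernel_shape': [3]}) == 'conv1d'
    assert picker({'kernel_shape': [3, 3]}) == 'conv2d'
    assert picker({'kernel_shape': [3, 3, 3]}) == 'conv3d'   # newly supported

    check, msg = dimension_constraint()
    assert check({'kernel_shape': [3, 3, 3]})  # 3-D kernels now satisfy the constraint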
@@ -188,10 +188,9 @@ def conv3d_strategy_cpu(attrs, inputs, out_type, target):
     strategy = _op.OpStrategy()
     layout = attrs.data_layout
     if layout == "NCDHW":
-        logger.warning("conv3d with layout NCDHW is not optimized for x86.")
-        strategy.add_implementation(wrap_compute_conv3d(topi.nn.conv3d_ncdhw),
-                                    wrap_topi_schedule(topi.generic.schedule_conv3d_ncdhw),
-                                    name="conv3d_ncdhw.generic")
+        strategy.add_implementation(wrap_compute_conv3d(topi.x86.conv3d_ncdhw),
+                                    wrap_topi_schedule(topi.x86.schedule_conv3d_ncdhw),
+                                    name="conv3d_ncdhw.x86")
     elif layout == "NDHWC":
         strategy.add_implementation(wrap_compute_conv3d(topi.x86.conv3d_ndhwc),
                                     wrap_topi_schedule(topi.x86.schedule_conv3d_ndhwc),
...
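
Since conv3d_ncdhw is now registered as the AutoTVM template "conv3d_ncdhw.x86", it can be tuned with the usual x86 flow. A rough sketch, assuming a Relay module `mod` and `params` already exist (for example from the ONNX import sketch near the top); the target string, trial count, and log file name are arbitrary:

    import tvm
    from tvm import autotvm, relay

    target = "llvm -mcpu=core-avx2"  # illustrative x86 target

    # conv3d_ncdhw workloads are extracted as "conv3d_ncdhw.x86" tasks.
    tasks = autotvm.task.extract_from_program(mod["main"], target=target, params=params)

    for task in tasks:
        tuner = autotvm.tuner.RandomTuner(task)
        tuner.tune(n_trial=32,
                   measure_option=autotvm.measure_option(
                       builder=autotvm.LocalBuilder(),
                       runner=autotvm.LocalRunner(number=5)),
                   callbacks=[autotvm.callback.log_to_file("conv3d_ncdhw.log")])

    # Untuned workloads fall back to the default config added in this commit.
    with autotvm.apply_history_best("conv3d_ncdhw.log"):
        compiled = relay.build(mod, target=target, params=params)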
@@ -1794,37 +1794,51 @@ def verify_conv(x_shape, w_shape, y_shape, padding, kernel_shape, strides, dilat
 def test_conv():
+    def repeat(N, D):
+        return tuple([N for _ in range(D)])
+    for D in [1, 2, 3]:
-    # Convolution with padding
-    # Conv2D
-    verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 5, 5), [1, 1, 1, 1], [3, 3], [1, 1], [1, 1])
-    # Conv1D
-    verify_conv((1, 1, 5), (1, 1, 3), (1, 1, 5), [1, 1], [3], [1], [1])
+        # Convolution with padding
+        verify_conv((1, 1) + repeat(5, D),
+                    (1, 1) + repeat(3, D),
+                    (1, 1) + repeat(5, D),
+                    2 * repeat(1, D),
+                    repeat(3, D),
+                    repeat(1, D),
+                    repeat(1, D))
-    # Convolution without padding
-    # Conv2D
-    verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 3, 3), [0, 0, 0, 0], [3, 3], [1, 1], [1, 1])
-    # Conv1D
-    verify_conv((1, 1, 5), (1, 1, 3), (1, 1, 3), [0, 0], [3], [1], [1])
+        # Convolution without padding
+        verify_conv((1, 1) + repeat(5, D),
+                    (1, 1) + repeat(3, D),
+                    (1, 1) + repeat(3, D),
+                    2 * repeat(0, D),
+                    repeat(3, D),
+                    repeat(1, D),
+                    repeat(1, D))
-    # Convolution with autopadding
-    verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 5, 5),
-                None, [3, 3], [1, 1], [1, 1],
-                auto_pad="SAME_UPPER")
-    # Conv1D
-    verify_conv((1, 1, 5), (1, 1, 3), (1, 1, 5), None, [3], [1], [1], auto_pad="SAME_UPPER")
+        # Convolution with autopadding
+        verify_conv((1, 1) + repeat(5, D),
+                    (1, 1) + repeat(3, D),
+                    (1, 1) + repeat(5, D),
+                    None,
+                    repeat(3, D),
+                    repeat(1, D),
+                    repeat(1, D),
+                    auto_pad="SAME_UPPER")
-    # Convolution with non uniform stride
-    verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 3, 3),
-                None, [3, 3], [2, 2], [1, 1],
-                auto_pad="SAME_UPPER")
-    # Conv1D
-    verify_conv((1, 1, 5), (1, 1, 3), (1, 1, 3), None, [3], [2], [1], auto_pad="SAME_UPPER")
+        # Convolution with non uniform stride
+        verify_conv((1, 1) + repeat(5, D),
+                    (1, 1) + repeat(3, D),
+                    (1, 1) + repeat(3, D),
+                    None,
+                    repeat(3, D),
+                    repeat(2, D),
+                    repeat(1, D),
+                    auto_pad="SAME_UPPER")
-    # Convolution with dilation
-    verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 5, 5), [2, 2, 2, 2], [3, 3], [1, 1], [2, 2])
-    # Conv1D
-    verify_conv((1, 1, 5), (1, 1, 3), (1, 1, 5), [2, 2], [3], [1], [2])
+        # Convolution with dilation
+        verify_conv((1, 1) + repeat(5, D),
+                    (1, 1) + repeat(3, D),
+                    (1, 1) + repeat(5, D),
+                    2 * repeat(2, D),
+                    repeat(3, D),
+                    repeat(1, D),
+                    repeat(2, D))

 def verify_convtranspose(x_shape, w_shape, y_shape, p):
     node = onnx.helper.make_node("ConvTranspose",
...
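
To make the parameterized test above concrete, this is what the `repeat` helper expands to for the 3-D case (a small sketch, not part of the test file):

    def repeat(N, D):
        return tuple([N for _ in range(D)])

    # For D = 3, the "convolution with padding" case becomes:
    #   x_shape      = (1, 1) + repeat(5, 3)  ->  (1, 1, 5, 5, 5)
    #   w_shape      = (1, 1) + repeat(3, 3)  ->  (1, 1, 3, 3, 3)
    #   padding      = 2 * repeat(1, 3)       ->  (1, 1, 1, 1, 1, 1)
    #   kernel_shape = repeat(3, 3)           ->  (3, 3, 3)
    assert 2 * repeat(1, 3) == (1, 1, 1, 1, 1, 1)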
@@ -42,7 +42,6 @@ def conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype):
     ----------
     input : tvm.te.Tensor
         5-D input data with shapes:
-        [batch, in_channel, in_depth, in_height, in_width] for NCDHW layout
         [batch, in_depth, in_height, in_width, in_channel] for NDHWC layout
     filter : tvm.te.Tensor
@@ -61,7 +60,6 @@ def conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype):
     -------
     output : tvm.te.Tensor
         5-D with shape [batch, out_depth, out_height, out_width, out_channel] for NDHWC layout
-        5-D with shape [batch, out_channel, out_depth, out_height, out_width] for NCDHW layout
     """
     layout = "NDHWC"
     out_dtype = data.dtype if out_dtype is None else out_dtype
@@ -74,14 +72,53 @@ def conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype):
     return _conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype)
+
+@autotvm.register_topi_compute("conv3d_ncdhw.x86")
+def conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """3D convolution forward operator.
+
+    Parameters
+    ----------
+    input : tvm.Tensor
+        5-D input data with shapes:
+        [batch, in_channel, in_depth, in_height, in_width] for NCDHW layout
+
+    filter : tvm.Tensor
+        5-D filter with shape [out_channels, in_channels, kernel_depth, kernel_height, kernel_width]
+
+    strides : int or a list/tuple of three ints
+        stride size, or [stride_depth, stride_height, stride_width]
+
+    padding : int or a list/tuple of three ints
+        padding size, or [pad_depth, pad_height, pad_width]
+
+    dilation: int or a list/tuple of three ints
+        dilation size, or [dilation_depth, dilation_height, dilation_width]
+
+    Returns
+    -------
+    output : tvm.Tensor
+        5-D with shape [batch, out_channel, out_depth, out_height, out_width] for NCDHW layout
+    """
+    layout = "NCDHW"
+    out_dtype = data.dtype if out_dtype is None else out_dtype
+    strides = strides if isinstance(strides, (tuple, list)) else (strides, strides, strides)
+    dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation, dilation)
+
+    _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout)
+    if cfg.is_fallback:
+        _get_default_config(cfg, data, kernel, strides, padding, out_dtype, layout)
+    return _conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
+
 @autotvm.register_topi_schedule("conv3d_ndhwc.x86")
 def schedule_conv3d_ndhwc(cfg, outs):
     """TOPI schedule callback for conv3d

     Parameters
     ----------
     outs: Array of Tensor
         The computation graph description of conv3d
         in the format of an array of tensors.

     Returns
     -------
     s: Schedule
@@ -111,6 +148,45 @@ def schedule_conv3d_ndhwc(cfg, outs):
     traverse_inline(s, outs[0].op, _traverse)
     return s
+
+@autotvm.register_topi_schedule("conv3d_ncdhw.x86")
+def schedule_conv3d_ncdhw(cfg, outs):
+    """TOPI schedule callback for conv3d
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of conv3d
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for conv3d.
+    """
+    s = te.create_schedule([x.op for x in outs])
+
+    def _traverse(op):
+        if 'conv3d_ncdhw' in op.tag:
+            output = op.output(0)
+            conv_out = op.input_tensors[0]
+            kernel_vec = conv_out.op.input_tensors[1]
+            kernel = kernel_vec.op.input_tensors[0]
+            if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
+                s[kernel].compute_inline()
+            data_vec = conv_out.op.input_tensors[0]
+            data = data_vec.op.input_tensors[0]
+            data_pad = None
+            if isinstance(data.op, tvm.te.ComputeOp) and "pad" in data.op.tag:
+                data_pad = data
+                data = data_pad.op.input_tensors[0]
+
+            kd, kh, kw, i, o = get_const_tuple(kernel.shape)
+            args = [s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, outs[0]]
+            _schedule_conv3d_ncdhw(*args)
+
+    traverse_inline(s, outs[0].op, _traverse)
+    return s
+
 def _conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype):
     out_dtype = data.dtype if out_dtype is None else out_dtype
@@ -198,6 +274,93 @@ def _conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype):
                                tag='conv3d_ndhwc')
     return conv_unpacked
+
+def _conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
+    out_dtype = data.dtype if out_dtype is None else out_dtype
+
+    assert isinstance(dilation, int) or len(dilation) == 3
+    if isinstance(dilation, int):
+        dilation_d, dilation_h, dilation_w = (dilation, dilation, dilation)
+    else:
+        dilation_d, dilation_h, dilation_w = dilation
+
+    DSTR, HSTR, WSTR = strides
+    batch_size, in_channel, in_depth, in_height, in_width = get_const_tuple(data.shape)
+    num_filter, _, kernel_depth, kernel_height, kernel_width = get_const_tuple(kernel.shape)
+
+    dilated_kernel_d = (kernel_depth - 1) * dilation_d + 1
+    dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
+    dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
+
+    pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
+        padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w))
+    pad_d = pad_front + pad_back
+    pad_h = pad_top + pad_down
+    pad_w = pad_left + pad_right
+
+    pad_depth = in_depth + pad_d
+    pad_height = in_height + pad_h
+    pad_width = in_width + pad_w
+
+    out_depth = simplify((in_depth + pad_d - dilated_kernel_d) // DSTR + 1)
+    out_height = simplify((in_height + pad_h - dilated_kernel_h) // HSTR + 1)
+    out_width = simplify((in_width + pad_w - dilated_kernel_w) // WSTR + 1)
+
+    # pack data
+    DOPAD = (pad_d != 0 or pad_h != 0 or pad_w != 0)
+    if DOPAD:
+        data_pad = pad(data, (0, 0, pad_front, pad_top, pad_left),
+                       (0, 0, pad_back, pad_down, pad_right), name="data_pad")
+    else:
+        data_pad = data
+
+    # fetch schedule
+    ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
+
+    shape = (batch_size, in_channel // ic_bn, pad_depth, pad_height, ic_bn, pad_width)
+    data_vec = te.compute(shape,
+                          lambda n, C, d, h, c, w: data_pad[n, C * ic_bn + c, d, h, w],
+                          name='data_vec')
+
+    # pack kernel
+    shape = (num_filter//oc_bn, in_channel//ic_bn,
+             kernel_depth, kernel_height, kernel_width, ic_bn, oc_bn)
+    kernel_vec = te.compute(shape,
+                            lambda CO, CI, d, h, w, ci, co:
+                            kernel[CO * oc_bn + co, CI * ic_bn + ci, d, h, w],
+                            name='kernel_vec')
+
+    # convolution
+    oshape = (batch_size, num_filter//oc_bn,
+              out_depth, out_height, out_width, oc_bn)
+    unpack_shape = (batch_size, num_filter, out_depth, out_height, out_width)
+
+    ic = te.reduce_axis((0, in_channel), name='ic')
+    kh = te.reduce_axis((0, kernel_height), name='kh')
+    kw = te.reduce_axis((0, kernel_width), name='kw')
+    kd = te.reduce_axis((0, kernel_depth), name='kd')
+    idxmod = tvm.tir.indexmod
+    idxdiv = tvm.tir.indexdiv
+
+    conv = te.compute(oshape, lambda n, oc_chunk, od, oh, ow, oc_block:
+                      te.sum(data_vec[n,
+                                      idxdiv(ic, ic_bn),
+                                      od*DSTR+kd*dilation_d,
+                                      oh*HSTR+kh*dilation_h,
+                                      idxmod(ic, ic_bn),
+                                      ow*WSTR+kw*dilation_w].astype(out_dtype) *
+                             kernel_vec[oc_chunk, idxdiv(ic, ic_bn), kd, kh, kw,
+                                        idxmod(ic, ic_bn),
+                                        oc_block].astype(out_dtype),
+                             axis=[ic, kd, kh, kw]), name='conv')
+    conv_unpacked = te.compute(unpack_shape,
+                               lambda n, c, d, h, w: conv[n, idxdiv(c, oc_bn),
+                                                          d, h, w,
+                                                          idxmod(c, oc_bn)]
+                               .astype(out_dtype),
+                               name='output_unpack',
+                               tag='conv3d_ncdhw')
+    return conv_unpacked
+
 def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
     """Create schedule configuration from input arguments"""
@@ -206,6 +369,9 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
     if layout == 'NDHWC':
         n, d, h, w, ic = dshape
         kd, kh, kw, _, oc = kshape
+    elif layout == 'NCDHW':
+        n, ic, d, h, w = dshape
+        oc, _, kd, kh, kw = kshape
     else:
         raise ValueError("Not support this layout {} with "
                          "schedule template.".format(layout))

@@ -227,7 +393,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, layout):
     """
     Get default schedule config for the workload
     """
-    if layout != 'NDHWC':
+    if layout not in ['NDHWC', 'NCDHW']:
         raise ValueError("Layout {} is not supported".format(layout))

     static_data_shape = []
@@ -244,7 +410,7 @@ def _get_conv3d_workload(data, kernel, stride, padding, out_dtype, data_layout='
     """ Get the workload structure. """
     if data_layout == 'NCDHW':
         _, CI, ID, IH, IW = get_const_tuple(data.shape)
-        CIG, CO, KD, KH, KW = get_const_tuple(kernel.shape)
+        CO, CIG, KD, KH, KW = get_const_tuple(kernel.shape)
     elif data_layout == 'NDHWC':
         _, ID, IH, IW, CI = get_const_tuple(data.shape)
         KD, KH, KW, CIG, CO = get_const_tuple(kernel.shape)

@@ -365,3 +531,73 @@ def _schedule_conv3d_ndhwc(s, cfg, data, data_pad, data_vec, kernel_vec, conv_ou
     s[O].vectorize(oc_block)
     s[O].parallel(parallel_axis)
     return s
+
+def _schedule_conv3d_ncdhw(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last):
+    # fetch schedule
+    ic_bn, oc_bn, reg_n, unroll_kw = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
+                                      cfg["tile_ow"].size[-1], cfg["unroll_kw"].val)
+
+    # get padding size
+    padding = infer_pad3d(data, data_pad, "NCDHW")
+    DPAD, HPAD, WPAD = padding
+    DOPAD = (DPAD != 0 or HPAD != 0 or WPAD != 0)
+
+    A, W = data, kernel_vec
+    A0, A1 = data_pad, data_vec
+
+    # schedule data
+    if DOPAD:
+        s[A0].compute_inline()
+    batch, ic_chunk, idd, ih, ic_block, iw = s[A1].op.axis
+    parallel_axis = s[A1].fuse(batch, ic_chunk, idd, ih)
+    s[A1].parallel(parallel_axis)
+
+    # schedule kernel pack
+    oc_chunk, ic_chunk, od, oh, ow, ic_block, oc_block = s[W].op.axis
+    s[W].reorder(oc_chunk, od, oh, ic_chunk, ow, ic_block, oc_block)
+    if oc_bn > 1:
+        s[W].vectorize(oc_block)
+    parallel_axis = s[W].fuse(oc_chunk, od, oh)
+    s[W].parallel(parallel_axis)
+
+    # schedule conv
+    C, O0, O = conv_out, output, last
+    CC = s.cache_write(C, 'global')
+
+    _, oc_chunk, od, oh, ow, oc_block = s[C].op.axis
+    ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
+    s[C].reorder(oc_chunk, od, oh, ow_chunk, ow_block, oc_block)
+    s[C].fuse(oc_chunk, od, oh)
+    s[C].vectorize(oc_block)
+
+    s[CC].compute_at(s[C], ow_chunk)
+    _, oc_chunk, od, oh, ow, oc_block = s[CC].op.axis
+    ic, kd, kh, kw = s[CC].op.reduce_axis
+
+    ow_chunk, ow_block = s[CC].split(ow, factor=reg_n)
+    ic_chunk, ic_block = s[CC].split(ic, factor=ic_bn)
+
+    if unroll_kw:
+        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kd, kh, ic_block, kw, ow_block, oc_block)
+        s[CC].unroll(kw)
+    else:
+        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kd, kh, kw, ic_block, ow_block, oc_block)
+
+    s[CC].fuse(oc_chunk, od, oh)
+    s[CC].vectorize(oc_block)
+    s[CC].unroll(ow_block)
+
+    if O0 != O:
+        s[O0].compute_inline()
+
+    # unpacking
+    batch, oc, od, oh, ow = s[O].op.axis
+    ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
+    oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
+    s[O].reorder(oc_chunk, od, oh, ow_chunk, ow_block, oc_block)
+    parallel_axis = s[O].fuse(batch, oc_chunk, od, oh)
+    s[C].compute_at(s[O], parallel_axis)
+    s[O].vectorize(oc_block)
+    s[O].parallel(parallel_axis)
+    return s
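
The schedule above reads four AutoTVM knobs (tile_ic, tile_oc, tile_ow, unroll_kw) that the shared `_create_tuning_space` helper defines; the knob definitions themselves fall outside the hunks shown here. A rough sketch of what such definitions usually look like in the x86 templates, where the split policy and the `<= 64` filter are assumptions rather than the committed code:

    def _create_tuning_space_sketch(cfg, in_channel, num_filter, out_width):
        # cfg is an autotvm ConfigSpace; the extents are static ints taken from the shapes.
        cfg.define_split("tile_ic", in_channel, num_outputs=2)   # -> ic_bn, input-channel block
        cfg.define_split("tile_oc", num_filter, num_outputs=2)   # -> oc_bn, vectorized output block
        cfg.define_split("tile_ow", out_width, num_outputs=2,
                         filter=lambda y: y.size[-1] <= 64)      # -> reg_n, register-blocked width
        cfg.define_knob("unroll_kw", [True, False])              # whether to unroll the kw reduction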
@@ -30,6 +30,7 @@ from common import get_all_backend
 _conv3d_ncdhw_implement = {
     "generic": (topi.nn.conv3d_ncdhw, topi.generic.schedule_conv3d_ncdhw),
+    "cpu": (topi.x86.conv3d_ncdhw, topi.x86.schedule_conv3d_ncdhw),
     "gpu": (topi.cuda.conv3d_ncdhw, topi.cuda.schedule_conv3d_ncdhw),
 }
...
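
For reference, a minimal sketch of how a test could pick the compute/schedule pair from this table for an x86 target; `pick_impl` is a hypothetical stand-in for the dispatch logic in the topi test harness, not code from this commit:

    import tvm

    def pick_impl(target_str, impl_table):
        """Match a target key against the table, falling back to 'generic'."""
        target = tvm.target.create(target_str)
        for key in target.keys:          # an llvm target carries the "cpu" key
            if key in impl_table:
                return impl_table[key]
        return impl_table["generic"]

    fcompute, fschedule = pick_impl("llvm", _conv3d_ncdhw_implement)
    # With the new "cpu" entry, an llvm target now resolves to topi.x86.conv3d_ncdhw
    # and topi.x86.schedule_conv3d_ncdhw instead of the generic fallback.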