Commit 227c7af4 by optima2005 Committed by Tianqi Chen

[FRONTEND][TF] conv2d_transpose 'SAME' support kernel more than 1x1 (#4484)

* [FRONTEND][TF] conv2d_transpose 'SAME' support kernel more than 1x1

* revised per review comments

* add more fallback workarounds to make all tests pass
parent 32276146
......@@ -269,6 +269,12 @@ def _conv(opname):
attr['strides'][1], attr['strides'][2], attr['strides'][3] = \
attr['strides'][3], attr['strides'][1], attr['strides'][2]
attr['data_format'] = 'NCHW'
if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0:
tmp_shape = attr['_output_shapes'][0]
tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
attr['_output_shapes'][0] = tmp_shape
flip_layout = True
inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2]
......@@ -345,12 +351,17 @@ def _conv(opname):
elif attr['padding'] == 'SAME':
stride_h, stride_w = attr['strides']
kernel_h, kernel_w = attr['kernel_shape']
pdata_shape = input_shape
if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0:
pdata_shape = attr['_output_shapes'][0]
if attr['data_format'] == 'NHWC':
in_h = input_shape[1]
in_w = input_shape[2]
in_h = pdata_shape[1]
in_w = pdata_shape[2]
else:
in_h = input_shape[2]
in_w = input_shape[3]
in_h = pdata_shape[2]
in_w = pdata_shape[3]
dilation_h = attr['dilations'][0]
dilation_w = attr['dilations'][1]
......@@ -359,21 +370,23 @@ def _conv(opname):
pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
if opname != 'conv_transpose':
if attr['data_format'] == 'NHWC':
inputs_data = _op.nn.pad(data=inputs_data,
pad_width=((0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1]),
(0, 0)))
else:
inputs_data = _op.nn.pad(data=inputs_data,
pad_width=((0, 0),
(0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1])))
if attr['data_format'] == 'NHWC':
inputs_data = _op.nn.pad(data=inputs_data,
pad_width=((0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1]),
(0, 0)))
attr['padding'] = [0, 0]
else:
inputs_data = _op.nn.pad(data=inputs_data,
pad_width=((0, 0),
(0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1])))
attr['padding'] = [0, 0]
attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
else:
msg = 'Value {} in attribute "padding" of operator Conv is not ' \
......
......@@ -249,10 +249,22 @@ bool Conv2DTransposeRel(const Array<Type>& types,
}
// dilation
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
auto pad_h = param->padding[0];
auto pad_w = param->padding[1];
if (param->padding.size() == 2) {
pad_h *= 2;
pad_w *= 2;
} else if (param->padding.size() == 4) {
pad_h += param->padding[2];
pad_w += param->padding[3];
} else {
CHECK_EQ(param->padding.size(), 4) << " Padding should be 2 or 4, but got "
<< param->padding.size();
}
oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
2 * param->padding[0] + param->output_padding[0]));
pad_h + param->output_padding[0]));
oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
2 * param->padding[1] + param->output_padding[1]));
pad_w + param->output_padding[1]));
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
......
......@@ -403,10 +403,22 @@ def test_forward_convolution():
_test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NCHW')
_test_convolution('conv_transpose', [4, 32, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 176, 8, 8])
_test_convolution('conv_transpose', [4, 32, 8, 8], [2, 2, 176, 32], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 176, 8, 8])
_test_convolution('conv_transpose', [4, 32, 8, 8], [2, 2, 176, 32], [1, 1], [2, 2], 'SAME',
'NCHW', [4, 176, 15, 15])
_test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 176, 8, 8])
_test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
'NCHW', [4, 176, 15, 15])
_test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
'NCHW', [4, 176, 16, 16])
_test_convolution('conv_transpose', [4, 19, 8, 8], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID',
'NCHW', [4, 19, 17, 17])
_test_convolution('conv_transpose', [4, 19, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 124, 17, 17])
_test_convolution('conv_transpose', [4, 19, 17, 17], [3, 3, 124, 19], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 124, 17, 17])
_test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
'NCHW', [4, 12, 17, 17])
# kernel 2x2, strides (2,2)
......@@ -429,10 +441,22 @@ def test_forward_convolution():
_test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC')
_test_convolution('conv_transpose', [4, 8, 8, 32], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 8, 8, 176])
_test_convolution('conv_transpose', [4, 8, 8, 32], [2, 2, 176, 32], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 8, 8, 176])
_test_convolution('conv_transpose', [4, 8, 8, 32], [2, 2, 176, 32], [1, 1], [2, 2], 'SAME',
'NHWC', [4, 15, 15, 176])
_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 8, 8, 176])
_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
'NHWC', [4, 15, 15, 176])
_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
'NHWC', [4, 16, 16, 176])
_test_convolution('conv_transpose', [4, 8, 8, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID',
'NHWC', [4, 17, 17, 19])
_test_convolution('conv_transpose', [4, 17, 17, 19], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 17, 17, 124])
_test_convolution('conv_transpose', [4, 17, 17, 19], [3, 3, 124, 19], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 17, 17, 124])
_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
'NHWC', [4, 17, 17, 12])
# kernel 2x2, strides (2,2)
......
......@@ -197,6 +197,8 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
do_fallback = False
elif (kh, kw) == (1, 1):
do_fallback = True
elif (stride_h, stride_w) == (2, 2):
do_fallback = False
elif (kh, kw) == (stride_h, stride_w):
do_fallback = False
......
......@@ -103,8 +103,13 @@ def get_pad_tuple(padding, kernel):
"""
# compute the padding size
if isinstance(padding, (tuple, list)):
pad_h = padding[0] * 2
pad_w = padding[1] * 2
if len(padding) == 2:
pad_h = padding[0] * 2
pad_w = padding[1] * 2
elif len(padding) == 4:
return padding[0], padding[1], padding[2], padding[3]
else:
raise ValueError("Size of padding can only be 2 or 4")
elif isinstance(padding, int):
pad_h = pad_w = padding * 2
elif padding == "VALID":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment