Commit 227c7af4 by optima2005 Committed by Tianqi Chen

[FRONTEND][TF] conv2d_transpose 'SAME' support kernel more than 1x1 (#4484)

* [FRONTEND][TF] conv2d_transpose 'SAME' support kernel more than 1x1

* revised as per review comments

* add more fallback workarounds to make all tests pass
parent 32276146
...@@ -269,6 +269,12 @@ def _conv(opname): ...@@ -269,6 +269,12 @@ def _conv(opname):
attr['strides'][1], attr['strides'][2], attr['strides'][3] = \ attr['strides'][1], attr['strides'][2], attr['strides'][3] = \
attr['strides'][3], attr['strides'][1], attr['strides'][2] attr['strides'][3], attr['strides'][1], attr['strides'][2]
attr['data_format'] = 'NCHW' attr['data_format'] = 'NCHW'
if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0:
tmp_shape = attr['_output_shapes'][0]
tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
attr['_output_shapes'][0] = tmp_shape
flip_layout = True flip_layout = True
inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2] inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2]
...@@ -345,12 +351,17 @@ def _conv(opname): ...@@ -345,12 +351,17 @@ def _conv(opname):
elif attr['padding'] == 'SAME': elif attr['padding'] == 'SAME':
stride_h, stride_w = attr['strides'] stride_h, stride_w = attr['strides']
kernel_h, kernel_w = attr['kernel_shape'] kernel_h, kernel_w = attr['kernel_shape']
pdata_shape = input_shape
if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0:
pdata_shape = attr['_output_shapes'][0]
if attr['data_format'] == 'NHWC': if attr['data_format'] == 'NHWC':
in_h = input_shape[1] in_h = pdata_shape[1]
in_w = input_shape[2] in_w = pdata_shape[2]
else: else:
in_h = input_shape[2] in_h = pdata_shape[2]
in_w = input_shape[3] in_w = pdata_shape[3]
dilation_h = attr['dilations'][0] dilation_h = attr['dilations'][0]
dilation_w = attr['dilations'][1] dilation_w = attr['dilations'][1]
...@@ -359,7 +370,7 @@ def _conv(opname): ...@@ -359,7 +370,7 @@ def _conv(opname):
pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h) pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w) pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
if opname != 'conv_transpose':
if attr['data_format'] == 'NHWC': if attr['data_format'] == 'NHWC':
inputs_data = _op.nn.pad(data=inputs_data, inputs_data = _op.nn.pad(data=inputs_data,
pad_width=((0, 0), pad_width=((0, 0),
...@@ -374,6 +385,8 @@ def _conv(opname): ...@@ -374,6 +385,8 @@ def _conv(opname):
(pad_h[0], pad_h[1]))) (pad_h[0], pad_h[1])))
attr['padding'] = [0, 0] attr['padding'] = [0, 0]
else:
attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
else: else:
msg = 'Value {} in attribute "padding" of operator Conv is not ' \ msg = 'Value {} in attribute "padding" of operator Conv is not ' \
......
...@@ -249,10 +249,22 @@ bool Conv2DTransposeRel(const Array<Type>& types, ...@@ -249,10 +249,22 @@ bool Conv2DTransposeRel(const Array<Type>& types,
} }
// dilation // dilation
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0}); Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
auto pad_h = param->padding[0];
auto pad_w = param->padding[1];
if (param->padding.size() == 2) {
pad_h *= 2;
pad_w *= 2;
} else if (param->padding.size() == 4) {
pad_h += param->padding[2];
pad_w += param->padding[3];
} else {
CHECK_EQ(param->padding.size(), 4) << " Padding should be 2 or 4, but got "
<< param->padding.size();
}
oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
2 * param->padding[0] + param->output_padding[0])); pad_h + param->output_padding[0]));
oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
2 * param->padding[1] + param->output_padding[1])); pad_w + param->output_padding[1]));
DataType out_dtype = param->out_dtype; DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) { if (out_dtype.bits() == 0) {
......
...@@ -403,10 +403,22 @@ def test_forward_convolution(): ...@@ -403,10 +403,22 @@ def test_forward_convolution():
_test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NCHW') _test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NCHW')
_test_convolution('conv_transpose', [4, 32, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', _test_convolution('conv_transpose', [4, 32, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 176, 8, 8]) 'NCHW', [4, 176, 8, 8])
_test_convolution('conv_transpose', [4, 32, 8, 8], [2, 2, 176, 32], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 176, 8, 8])
_test_convolution('conv_transpose', [4, 32, 8, 8], [2, 2, 176, 32], [1, 1], [2, 2], 'SAME',
'NCHW', [4, 176, 15, 15])
_test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 176, 8, 8])
_test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
'NCHW', [4, 176, 15, 15])
_test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
'NCHW', [4, 176, 16, 16])
_test_convolution('conv_transpose', [4, 19, 8, 8], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', _test_convolution('conv_transpose', [4, 19, 8, 8], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID',
'NCHW', [4, 19, 17, 17]) 'NCHW', [4, 19, 17, 17])
_test_convolution('conv_transpose', [4, 19, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', _test_convolution('conv_transpose', [4, 19, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 124, 17, 17]) 'NCHW', [4, 124, 17, 17])
_test_convolution('conv_transpose', [4, 19, 17, 17], [3, 3, 124, 19], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 124, 17, 17])
_test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
'NCHW', [4, 12, 17, 17]) 'NCHW', [4, 12, 17, 17])
# kernel 2x2, strides (2,2) # kernel 2x2, strides (2,2)
...@@ -429,10 +441,22 @@ def test_forward_convolution(): ...@@ -429,10 +441,22 @@ def test_forward_convolution():
_test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC') _test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC')
_test_convolution('conv_transpose', [4, 8, 8, 32], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', _test_convolution('conv_transpose', [4, 8, 8, 32], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 8, 8, 176]) 'NHWC', [4, 8, 8, 176])
_test_convolution('conv_transpose', [4, 8, 8, 32], [2, 2, 176, 32], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 8, 8, 176])
_test_convolution('conv_transpose', [4, 8, 8, 32], [2, 2, 176, 32], [1, 1], [2, 2], 'SAME',
'NHWC', [4, 15, 15, 176])
_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 8, 8, 176])
_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
'NHWC', [4, 15, 15, 176])
_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
'NHWC', [4, 16, 16, 176])
_test_convolution('conv_transpose', [4, 8, 8, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', _test_convolution('conv_transpose', [4, 8, 8, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID',
'NHWC', [4, 17, 17, 19]) 'NHWC', [4, 17, 17, 19])
_test_convolution('conv_transpose', [4, 17, 17, 19], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', _test_convolution('conv_transpose', [4, 17, 17, 19], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 17, 17, 124]) 'NHWC', [4, 17, 17, 124])
_test_convolution('conv_transpose', [4, 17, 17, 19], [3, 3, 124, 19], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 17, 17, 124])
_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', _test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
'NHWC', [4, 17, 17, 12]) 'NHWC', [4, 17, 17, 12])
# kernel 2x2, strides (2,2) # kernel 2x2, strides (2,2)
......
...@@ -197,6 +197,8 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs): ...@@ -197,6 +197,8 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
do_fallback = False do_fallback = False
elif (kh, kw) == (1, 1): elif (kh, kw) == (1, 1):
do_fallback = True do_fallback = True
elif (stride_h, stride_w) == (2, 2):
do_fallback = False
elif (kh, kw) == (stride_h, stride_w): elif (kh, kw) == (stride_h, stride_w):
do_fallback = False do_fallback = False
......
...@@ -103,8 +103,13 @@ def get_pad_tuple(padding, kernel): ...@@ -103,8 +103,13 @@ def get_pad_tuple(padding, kernel):
""" """
# compute the padding size # compute the padding size
if isinstance(padding, (tuple, list)): if isinstance(padding, (tuple, list)):
if len(padding) == 2:
pad_h = padding[0] * 2 pad_h = padding[0] * 2
pad_w = padding[1] * 2 pad_w = padding[1] * 2
elif len(padding) == 4:
return padding[0], padding[1], padding[2], padding[3]
else:
raise ValueError("Size of padding can only be 2 or 4")
elif isinstance(padding, int): elif isinstance(padding, int):
pad_h = pad_w = padding * 2 pad_h = pad_w = padding * 2
elif padding == "VALID": elif padding == "VALID":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment