Commit 2baf310e by optima2005 Committed by Yizhi Liu

[Relay][Frontend][Tensorflow]Add conv2d_transpose (#4300)

* [Relay][Frontend][Tensorflow]Add conv2d_transpose

* add transformation from NHWC to NCHW to be compatible with the TVM conv2d_transpose implementation

* remove 'dilations' parameter to be compatible with TF 1.3
parent 9955602d
...@@ -195,10 +195,24 @@ def _conv(opname): ...@@ -195,10 +195,24 @@ def _conv(opname):
attr['data_format'] = attr['data_format'].decode("utf-8") attr['data_format'] = attr['data_format'].decode("utf-8")
flip_layout = False flip_layout = False
if opname == 'conv_transpose' and attr['data_format'] == 'NHWC':
# transform to NCHW for TVM backend compatible and set 'flip_layout'
# to have output flip back to NHWC
tmp_shape = attr['_input_shapes'][inputs[2]]
tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
inputs[2] = _op.transpose(inputs[2], axes=(0, 3, 1, 2))
attr['_input_shapes'][inputs[2]] = tmp_shape
attr['strides'][1], attr['strides'][2], attr['strides'][3] = \
attr['strides'][3], attr['strides'][1], attr['strides'][2]
attr['data_format'] = 'NCHW'
flip_layout = True
inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2]
# NCHW Layout require weights transpose # NCHW Layout require weights transpose
if attr['data_format'] == 'NCHW': if attr['data_format'] == 'NCHW':
tmp_shape = attr['_input_shapes'][inputs[1]] tmp_shape = attr['_input_shapes'][inputs[1]]
if opname == 'conv': if opname in ['conv', 'conv_transpose']:
tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)] tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)]
inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1)) inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1))
else: else:
...@@ -206,13 +220,13 @@ def _conv(opname): ...@@ -206,13 +220,13 @@ def _conv(opname):
inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1)) inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1))
attr['_input_shapes'][inputs[1]] = tmp_shape attr['_input_shapes'][inputs[1]] = tmp_shape
input_shape = attr['_input_shapes'][inputs[0]] input_shape = attr['_input_shapes'][inputs_data]
weights_shape = attr['_input_shapes'][inputs[1]] weights_shape = attr['_input_shapes'][inputs[1]]
if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2)) inputs_data = _op.transpose(inputs_data, axes=(0, 3, 1, 2))
if opname == 'conv': if opname in ['conv', 'conv_transpose']:
weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)] weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)]
inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1)) inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1))
else: else:
...@@ -228,6 +242,8 @@ def _conv(opname): ...@@ -228,6 +242,8 @@ def _conv(opname):
attr['kernel_shape'] = (weights_shape[0], weights_shape[1]) attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
if opname == 'conv': if opname == 'conv':
attr['channels'] = weights_shape[3] attr['channels'] = weights_shape[3]
elif opname == 'conv_transpose':
attr['channels'] = weights_shape[2]
else: else:
attr['channels'] = input_shape[3] * depth_mult attr['channels'] = input_shape[3] * depth_mult
...@@ -239,6 +255,8 @@ def _conv(opname): ...@@ -239,6 +255,8 @@ def _conv(opname):
attr['kernel_shape'] = (weights_shape[2], weights_shape[3]) attr['kernel_shape'] = (weights_shape[2], weights_shape[3])
if opname == 'conv': if opname == 'conv':
attr['channels'] = weights_shape[0] attr['channels'] = weights_shape[0]
elif opname == 'conv_transpose':
attr['channels'] = weights_shape[1]
else: else:
attr['channels'] = input_shape[1] * depth_mult attr['channels'] = input_shape[1] * depth_mult
if attr['channels'] < 0: if attr['channels'] < 0:
...@@ -279,17 +297,17 @@ def _conv(opname): ...@@ -279,17 +297,17 @@ def _conv(opname):
if attr['data_format'] == 'NHWC': if attr['data_format'] == 'NHWC':
inputs[0] = _op.nn.pad(data=inputs[0], inputs_data = _op.nn.pad(data=inputs_data,
pad_width=((0, 0), pad_width=((0, 0),
(pad_v[0], pad_v[1]), (pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1]), (pad_h[0], pad_h[1]),
(0, 0))) (0, 0)))
else: else:
inputs[0] = _op.nn.pad(data=inputs[0], inputs_data = _op.nn.pad(data=inputs_data,
pad_width=((0, 0), pad_width=((0, 0),
(0, 0), (0, 0),
(pad_v[0], pad_v[1]), (pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1]))) (pad_h[0], pad_h[1])))
attr['padding'] = [0, 0] attr['padding'] = [0, 0]
...@@ -299,27 +317,30 @@ def _conv(opname): ...@@ -299,27 +317,30 @@ def _conv(opname):
raise tvm.error.OpAttributeInvalid(msg.format(attr['padding'])) raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))
if 'kernel_layout' not in attr: if 'kernel_layout' not in attr:
if opname == 'conv': if opname in ['conv', 'conv_transpose']:
attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW' attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'
else: else:
attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW' attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
use_bias = len(inputs) == 3 use_bias = len(inputs) == (3 if opname != 'conv_transpose' else 4)
channel_axis = 1 if attr['data_format'] == "NCHW" else 3 channel_axis = 1 if attr['data_format'] == "NCHW" else 3
# Ignore the new attributes from TF2.0, for now. # Ignore the new attributes from TF2.0, for now.
out = AttrCvt( out = AttrCvt(
op_name=_dimension_picker('conv'), op_name=_dimension_picker('conv', \
surfix="_transpose" if opname == 'conv_transpose' else ""),
ignores=['explicit_paddings'], ignores=['explicit_paddings'],
transforms={ transforms={
'kernel_shape': 'kernel_size', 'kernel_shape': 'kernel_size',
'data_format': 'data_layout', 'data_format': 'data_layout',
'dilations': ('dilation', (0, 0)), 'dilations': ('dilation', (0, 0)),
'group': ('groups', 1)}, 'group': ('groups', 1)},
custom_check=_dimension_constraint())([inputs[0], inputs[1]], attr) custom_check=_dimension_constraint())([inputs_data, inputs[1]], attr)
if use_bias: if use_bias:
out = _op.nn.bias_add(out, inputs[2], axis=channel_axis) out = _op.nn.bias_add(out,
inputs[2] if opname != 'conv_transpose' else inputs[3],
axis=channel_axis)
if flip_layout: if flip_layout:
out = _op.transpose(out, axes=(0, 2, 3, 1)) out = _op.transpose(out, axes=(0, 2, 3, 1))
...@@ -1403,6 +1424,7 @@ _convert_map = { ...@@ -1403,6 +1424,7 @@ _convert_map = {
'Concat' : _concat(), 'Concat' : _concat(),
'ConcatV2' : _concatV2(), 'ConcatV2' : _concatV2(),
'Conv2D' : _conv('conv'), 'Conv2D' : _conv('conv'),
'Conv2DBackpropInput' : _conv('conv_transpose'),
'CropAndResize' : _crop_and_resize(), 'CropAndResize' : _crop_and_resize(),
'DecodeJpeg' : _decode_image(), 'DecodeJpeg' : _decode_image(),
'DepthwiseConv2dNative' : _conv('depthwise'), 'DepthwiseConv2dNative' : _conv('depthwise'),
......
...@@ -295,7 +295,8 @@ def test_forward_pooling(): ...@@ -295,7 +295,8 @@ def test_forward_pooling():
def _test_convolution(opname, tensor_in_sizes, filter_in_sizes, def _test_convolution(opname, tensor_in_sizes, filter_in_sizes,
dilations, strides, padding, data_format): dilations, strides, padding, data_format,
deconv_output_shape=[]):
""" One iteration of convolution with given shapes and attributes """ """ One iteration of convolution with given shapes and attributes """
total_size_1 = np.prod(tensor_in_sizes) total_size_1 = np.prod(tensor_in_sizes)
...@@ -326,6 +327,16 @@ def _test_convolution(opname, tensor_in_sizes, filter_in_sizes, ...@@ -326,6 +327,16 @@ def _test_convolution(opname, tensor_in_sizes, filter_in_sizes,
compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'), compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
'Placeholder:0', 'Conv2D:0') 'Placeholder:0', 'Conv2D:0')
elif opname == 'conv_transpose':
nn_ops.conv2d_transpose(in_data,
in_filter,
output_shape=deconv_output_shape,
strides=strides,
padding=padding,
data_format=data_format)
compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
'Placeholder:0', 'conv2d_transpose:0')
else: else:
nn_ops.depthwise_conv2d_native(in_data, nn_ops.depthwise_conv2d_native(in_data,
in_filter, in_filter,
...@@ -349,6 +360,14 @@ def test_forward_convolution(): ...@@ -349,6 +360,14 @@ def test_forward_convolution():
_test_convolution('depthwise', [4, 124, 17, 17], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NCHW') _test_convolution('depthwise', [4, 124, 17, 17], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NCHW')
_test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NCHW') _test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NCHW')
_test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NCHW') _test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NCHW')
_test_convolution('conv_transpose', [4, 32, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 176, 8, 8])
_test_convolution('conv_transpose', [4, 19, 8, 8], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID',
'NCHW', [4, 19, 17, 17])
_test_convolution('conv_transpose', [4, 19, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 124, 17, 17])
_test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
'NCHW', [4, 12, 17, 17])
_test_convolution('conv', [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC') _test_convolution('conv', [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC')
_test_convolution('conv', [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC') _test_convolution('conv', [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')
...@@ -359,6 +378,15 @@ def test_forward_convolution(): ...@@ -359,6 +378,15 @@ def test_forward_convolution():
_test_convolution('depthwise', [4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC') _test_convolution('depthwise', [4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC')
_test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC') _test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC')
_test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC') _test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC')
_test_convolution('conv_transpose', [4, 8, 8, 32], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 8, 8, 176])
_test_convolution('conv_transpose', [4, 8, 8, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID',
'NHWC', [4, 17, 17, 19])
_test_convolution('conv_transpose', [4, 17, 17, 19], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 17, 17, 124])
_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
'NHWC', [4, 17, 17, 12])
####################################################################### #######################################################################
# BiasAdd # BiasAdd
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment