Commit ab890d6e by Alexey Romanov Committed by Siva

Support SpaceToBatchND/BatchToSpaceND in Tensorflow frontend (#2943)

Thanks @alexeyr . This is now merged.
parent 5a27632e
...@@ -984,6 +984,91 @@ def _logical(name): ...@@ -984,6 +984,91 @@ def _logical(name):
return AttrCvt(op_name=name)(inputs, attr) return AttrCvt(op_name=name)(inputs, attr)
return _impl return _impl
def _space_to_batch_nd():
def _impl(inputs, attr, params):
input_node = inputs[0]
input_shape = attr['_input_shapes'][input_node]
block_shape = params.pop(inputs[1].name_hint).asnumpy().tolist()
paddings = params.pop(inputs[2].name_hint).asnumpy().tolist()
N = len(input_shape)
M = len(block_shape)
batch = input_shape[0]
remaining_shape_length = N - M - 1
paddings = [(0, 0)] + paddings + [(0, 0)] * remaining_shape_length
# From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/space-to-batch-n-d:
# Zero-pad the start and end of dimensions [1, ..., M] of the input according to paddings
# to produce padded of shape padded_shape.
padded = tvm.relay.nn.pad(input_node, pad_width=paddings)
# Reshape padded to reshaped_padded of shape:
# [batch] + [padded_shape[1] / block_shape[0], block_shape[0], ...,
# padded_shape[M] / block_shape[M-1], block_shape[M-1]] + remaining_shape
shape1 = [batch] + [item for i in range(M) for item in [-4, -1, block_shape[i]]] + [-2]
reshaped_padded = tvm.relay.reshape(padded, newshape=shape1)
# Permute dimensions of reshaped_padded to produce permuted_reshaped_padded of shape:
# block_shape + [batch] + [padded_shape[1] / block_shape[0], ...,
# padded_shape[M] / block_shape[M-1]] + remaining_shape
axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \
list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length))
permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes)
permuted_reshaped_padded_shape = _infer_out_shapes(permuted_reshaped_padded, params)[0]
# Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
# producing an output tensor of shape:
# [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ...,
# padded_shape[M] / block_shape[M-1]] + remaining_shape
shape2 = [batch * np.prod(block_shape)] + list(permuted_reshaped_padded_shape)[M + 1:]
reshaped_permuted_reshaped_padded = tvm.relay.reshape(permuted_reshaped_padded,
newshape=shape2)
return reshaped_permuted_reshaped_padded
return _impl
def _batch_to_space_nd():
def _impl(inputs, attr, params):
input_node = inputs[0]
input_shape = attr['_input_shapes'][input_node]
block_shape = params.pop(inputs[1].name_hint).asnumpy().tolist()
crops = params.pop(inputs[2].name_hint).asnumpy().tolist()
M = len(block_shape)
batch = input_shape[0]
# From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d:
# Reshape input to reshaped of shape:
# [block_shape[0], ..., block_shape[M-1], batch / prod(block_shape),
# input_shape[1], ..., input_shape[N-1]]
shape1 = block_shape + [batch // np.prod(block_shape)] + input_shape[1:]
reshaped = tvm.relay.reshape(input_node, newshape=shape1)
# Permute dimensions of reshaped to produce permuted of shape
# [batch / prod(block_shape), input_shape[1], block_shape[0], ...,
# input_shape[M], block_shape[M-1], input_shape[M+1], ..., input_shape[N-1]]
axes = [M] + [axis for i in range(M) for axis in [M + i + 1, i]] + \
list(range(2 * M + 1, len(shape1)))
permuted = tvm.relay.transpose(reshaped, axes=axes)
# Reshape permuted to produce reshaped_permuted of shape
# [batch / prod(block_shape), input_shape[1] * block_shape[0], ...,
# input_shape[M] * block_shape[M-1], input_shape[M+1], ..., input_shape[N-1]]
shape2 = [0] + [-3] * M + [-2]
reshaped_permuted = tvm.relay.reshape(permuted, newshape=shape2)
# Crop the start and end of dimensions [1, ..., M] of reshaped_permuted according to crops
# to produce the output of shape:
# [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
# ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
# input_shape[M+1], ..., input_shape[N-1]]
reshaped_permuted_shape = _infer_out_shapes(reshaped_permuted, params)[0]
cropped = reshaped_permuted
for axis in range(1, M+1):
crop = crops[axis - 1]
if crop != [0, 0]:
indices = tvm.relay.arange(
crop[0],
reshaped_permuted_shape[axis] - crop[1],
dtype='int32'
)
cropped = tvm.relay.take(cropped, indices=indices, axis=axis)
return cropped
return _impl
# compatible operators that do NOT require any conversion. # compatible operators that do NOT require any conversion.
_identity_list = [] _identity_list = []
...@@ -1060,6 +1145,8 @@ _convert_map = { ...@@ -1060,6 +1145,8 @@ _convert_map = {
'Split' : _split(False), 'Split' : _split(False),
'SplitV' : _split(True), 'SplitV' : _split(True),
'Unpack' : _unpack(), 'Unpack' : _unpack(),
'SpaceToBatchND' : _space_to_batch_nd(),
'BatchToSpaceND' : _batch_to_space_nd(),
} }
def _LSTMBlockCell(): def _LSTMBlockCell():
......
...@@ -161,6 +161,7 @@ def is_gpu_available(): ...@@ -161,6 +161,7 @@ def is_gpu_available():
else: else:
return False return False
####################################################################### #######################################################################
# Pooling # Pooling
# ------- # -------
...@@ -221,6 +222,19 @@ def test_forward_pooling(): ...@@ -221,6 +222,19 @@ def test_forward_pooling():
dilation_rate=[1, 1], dilation_rate=[1, 1],
strides=[2, 1]) strides=[2, 1])
# Tests involving SpaceToBatchND
_test_pooling(input_shape=[1, 1, 2, 1],
window_shape=[1, 1],
padding='VALID',
pooling_type=pool_type,
dilation_rate=[1, 2])
_test_pooling(input_shape=[1, 2, 1],
window_shape=[1],
padding='VALID',
pooling_type=pool_type,
dilation_rate=[2])
####################################################################### #######################################################################
# Convolution # Convolution
# ----------- # -----------
...@@ -229,12 +243,8 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes, ...@@ -229,12 +243,8 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
dilations, strides, padding, data_format): dilations, strides, padding, data_format):
""" One iteration of convolution with given shapes and attributes """ """ One iteration of convolution with given shapes and attributes """
total_size_1 = 1 total_size_1 = np.prod(tensor_in_sizes)
total_size_2 = 1 total_size_2 = np.prod(filter_in_sizes)
for s in tensor_in_sizes:
total_size_1 *= s
for s in filter_in_sizes:
total_size_2 *= s
# Initializes the input tensor with array containing incrementing # Initializes the input tensor with array containing incrementing
# numbers from 1. # numbers from 1.
data_array = [f * 1.0 for f in range(1, total_size_1 + 1)] data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
...@@ -253,6 +263,7 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes, ...@@ -253,6 +263,7 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
nn_ops.conv2d(in_data, nn_ops.conv2d(in_data,
in_filter, in_filter,
strides=strides, strides=strides,
dilations=dilations,
padding=padding, padding=padding,
data_format=data_format) data_format=data_format)
...@@ -272,6 +283,116 @@ def test_forward_convolution(): ...@@ -272,6 +283,116 @@ def test_forward_convolution():
_test_convolution([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC') _test_convolution([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC')
####################################################################### #######################################################################
# SpaceToBatchND
# --------------
def _test_space_to_batch_nd(input_shape, block_shape, paddings, dtype='int32'):
data = np.random.uniform(0, 5, size=input_shape).astype(dtype)
with tf.Graph().as_default():
in_data = tf.placeholder(shape=input_shape, dtype=dtype)
out = tf.space_to_batch_nd(in_data, block_shape, paddings)
compare_tf_with_tvm(data, in_data.name, out.name)
def test_forward_space_to_batch_nd():
# test cases: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/space-to-batch-n-d
_test_space_to_batch_nd(
input_shape=[1, 2, 2, 1],
block_shape=[2, 2],
paddings=[[0, 0], [0, 0]]
)
_test_space_to_batch_nd(
input_shape=[1, 2, 2, 3],
block_shape=[2, 2],
paddings=[[0, 0], [0, 0]]
)
_test_space_to_batch_nd(
input_shape=[1, 4, 4, 1],
block_shape=[2, 2],
paddings=[[0, 0], [0, 0]]
)
_test_space_to_batch_nd(
input_shape=[2, 2, 4, 1],
block_shape=[2, 2],
paddings=[[0, 0], [2, 0]],
dtype='int64'
)
# pylint: disable=line-too-long
# https://github.com/tensorflow/tensorflow/blob/24f578/tensorflow/python/kernel_tests/spacetobatch_op_test.py
_test_space_to_batch_nd(
input_shape=[2, 3],
block_shape=[2],
paddings=[[1, 0]],
dtype='float32'
)
_test_space_to_batch_nd(
input_shape=[2, 3, 2],
block_shape=[2],
paddings=[[1, 0]],
dtype='float64'
)
#######################################################################
# BatchToSpaceND
# --------------
def _test_batch_to_space_nd(input_shape, block_shape, crops, dtype='int32'):
data = np.random.uniform(0, 5, size=input_shape).astype(dtype)
with tf.Graph().as_default():
in_data = tf.placeholder(shape=input_shape, dtype=dtype)
out = tf.batch_to_space_nd(in_data, block_shape, crops)
compare_tf_with_tvm(data, in_data.name, out.name)
def test_forward_batch_to_space_nd():
# test cases: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d
_test_batch_to_space_nd(
input_shape=[4, 1, 1, 1],
block_shape=[2, 2],
crops=[[0, 0], [0, 0]]
)
_test_batch_to_space_nd(
input_shape=[4, 1, 1, 3],
block_shape=[2, 2],
crops=[[0, 0], [0, 0]]
)
_test_batch_to_space_nd(
input_shape=[4, 2, 2, 1],
block_shape=[2, 2],
crops=[[0, 0], [0, 0]]
)
_test_batch_to_space_nd(
input_shape=[8, 1, 3, 1],
block_shape=[2, 2],
crops=[[0, 0], [2, 0]],
dtype='int64'
)
# pylint: disable=line-too-long
# https://github.com/tensorflow/tensorflow/blob/24f578/tensorflow/python/kernel_tests/batchtospace_op_test.py
_test_batch_to_space_nd(
input_shape=[18, 2, 1, 2],
block_shape=[2, 3],
crops=[[1, 1], [0, 0]],
dtype='float32'
)
_test_batch_to_space_nd(
input_shape=[20, 5, 8, 7],
block_shape=[2, 2],
crops=[[1, 1], [1, 1]],
dtype='float64'
)
#######################################################################
# Reshape # Reshape
# ------- # -------
...@@ -1312,6 +1433,8 @@ if __name__ == '__main__': ...@@ -1312,6 +1433,8 @@ if __name__ == '__main__':
_test_forward_concat_v2() _test_forward_concat_v2()
test_forward_lrn() test_forward_lrn()
test_forward_l2_normalize() test_forward_l2_normalize()
test_forward_space_to_batch_nd()
test_forward_batch_to_space_nd()
# End to End # End to End
test_forward_inception_v3() test_forward_inception_v3()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment