Unverified Commit 5eab475d by Samuel Committed by GitHub

[TFLITE]DepthToSpace and SpaceToDepth support (#5041)

* [TFLITE]DepthToSpace and SpaceToDepth op parser support

* DepthToSpace and SpaceToDepth testcases

* Review comments fixed
parent 06bb17ec
......@@ -70,6 +70,7 @@ class OperatorConverter(object):
'CONCATENATION': self.convert_concatenation,
'CONV_2D': self.convert_conv2d,
'COS': self.convert_cos,
'DEPTH_TO_SPACE': self.convert_depth_to_space,
'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d,
'DETECTION_POSTPROCESS': self.convert_detection_postprocess,
'DIV': self.convert_div,
......@@ -116,6 +117,7 @@ class OperatorConverter(object):
'SLICE': self.convert_slice,
'SOFTMAX': self.convert_softmax,
'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
'SPACE_TO_DEPTH': self.convert_space_to_depth,
'SPLIT': self.convert_split,
'SQRT': self.convert_sqrt,
'SQUARE': self.convert_square,
......@@ -1896,6 +1898,56 @@ class OperatorConverter(object):
return reshaped_permuted_reshaped_padded
def convert_depth_to_space(self, op):
"""Convert TFLite DEPTH_TO_SPACE"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.Operator import Operator
from tflite.DepthToSpaceOptions import DepthToSpaceOptions
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
assert op.BuiltinOptionsType() == BuiltinOptions.DepthToSpaceOptions
op_options = op.BuiltinOptions()
depth_to_space_options = DepthToSpaceOptions()
depth_to_space_options.Init(op_options.Bytes, op_options.Pos)
block_size = depth_to_space_options.BlockSize()
out = _op.nn.depth_to_space(in_expr, block_size, layout='NHWC')
return out
def convert_space_to_depth(self, op):
"""Convert TFLite SPACE_TO_DEPTH"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.Operator import Operator
from tflite.SpaceToDepthOptions import SpaceToDepthOptions
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
assert op.BuiltinOptionsType() == BuiltinOptions.SpaceToDepthOptions
op_options = op.BuiltinOptions()
space_to_depth_options = SpaceToDepthOptions()
space_to_depth_options.Init(op_options.Bytes, op_options.Pos)
block_size = space_to_depth_options.BlockSize()
out = _op.nn.space_to_depth(in_expr, block_size, layout='NHWC')
return out
def convert_prelu(self, op):
"""Convert TFLite PReLU"""
try:
......
......@@ -1449,6 +1449,40 @@ def test_forward_prelu():
_test_prelu(np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32"), np.full((1, 1, 3), 0.2, dtype="float32"))
#######################################################################
# DepthToSpace
# ------------
def _test_depthtospace(data, block_size):
""" One iteration of depth_to_space operation with given data and block size """
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = array_ops.depth_to_space(in_data, block_size)
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
def test_forward_depthtospace():
# DEPTH_TO_SPACE comes with TFLite >= 1.15.0 fbs schema
if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
_test_depthtospace(np.random.normal(size=[1, 32, 32, 4]).astype("float32"), 2)
_test_depthtospace(np.random.normal(size=[1, 16, 8, 32]).astype("float32"), 4)
#######################################################################
# SpaceToDepth
# ------------
def _test_spacetodepth(data, block_size):
""" One iteration of space_to_depth operation with given data and block size """
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = array_ops.space_to_depth(in_data, block_size)
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
def test_forward_spacetodepth():
_test_spacetodepth(np.random.normal(size=[1, 32, 32, 4]).astype("float32"), 2)
_test_spacetodepth(np.random.normal(size=[1, 16, 8, 32]).astype("float32"), 4)
#######################################################################
# Fully Connected
# ---------------
......@@ -1741,6 +1775,8 @@ if __name__ == '__main__':
test_all_resize()
test_forward_squeeze()
test_forward_slice()
test_forward_depthtospace()
test_forward_spacetodepth()
# NN
test_forward_convolution()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment