Commit 40d56b5d by Alexander Pivovarov Committed by Yao Wang

Add RESIZE operators to realy TFLite frontend (#3370)

parent 8703d9fb
...@@ -60,6 +60,8 @@ class OperatorConverter(object): ...@@ -60,6 +60,8 @@ class OperatorConverter(object):
'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d, 'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d,
'AVERAGE_POOL_2D': self.convert_average_pool2d, 'AVERAGE_POOL_2D': self.convert_average_pool2d,
'RESHAPE': self.convert_reshape, 'RESHAPE': self.convert_reshape,
'RESIZE_BILINEAR': self.convert_resize_bilinear,
'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor,
'SOFTMAX': self.convert_softmax, 'SOFTMAX': self.convert_softmax,
'SQUEEZE': self.convert_squeeze, 'SQUEEZE': self.convert_squeeze,
'MAX_POOL_2D': self.convert_max_pool2d, 'MAX_POOL_2D': self.convert_max_pool2d,
...@@ -225,6 +227,58 @@ class OperatorConverter(object): ...@@ -225,6 +227,58 @@ class OperatorConverter(object):
return out return out
def _convert_resize(self, method, op):
"""Generic method to Convert TFLite RESIZE operators"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.Operator import Operator
from tflite.ResizeBilinearOptions import ResizeBilinearOptions
# ResizeNearestNeighborOptions was added in tflite v1.13
tflite_ver = 1120
if 'ResizeNearestNeighborOptions' in dir(BuiltinOptions):
from tflite.ResizeNearestNeighborOptions import ResizeNearestNeighborOptions
tflite_ver = 1130
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"
# images, 4-D Tensor with shape NHWC.
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
# size - 1-D int32 Tensor of 2 elements: new_height, new_width
target_size = tuple(self.get_tensor_value(input_tensors[1]))
# Options - align_corners (bool)
resize_options = None
align_corners = False
if method == "BILINEAR":
assert op.BuiltinOptionsType() == BuiltinOptions.ResizeBilinearOptions
resize_options = ResizeBilinearOptions()
elif tflite_ver >= 1130:
assert op.BuiltinOptionsType() == BuiltinOptions.ResizeNearestNeighborOptions
resize_options = ResizeNearestNeighborOptions()
if resize_options is not None:
op_options = op.BuiltinOptions()
resize_options.Init(op_options.Bytes, op_options.Pos)
align_corners = resize_options.AlignCorners()
# Use layout NHWC
out = _op.image.resize(in_expr, target_size, "NHWC", method, align_corners)
return out
def convert_resize_bilinear(self, op):
"""Convert TFLite RESIZE_BILINEAR"""
return self._convert_resize("BILINEAR", op)
def convert_resize_nearest_neighbor(self, op):
"""Convert TFLite RESIZE_NEAREST_NEIGHBOR"""
return self._convert_resize("NEAREST_NEIGHBOR", op)
def convert_logistic(self, op): def convert_logistic(self, op):
"""Convert TFLite LOGISTIC""" """Convert TFLite LOGISTIC"""
try: try:
......
...@@ -290,6 +290,37 @@ def test_forward_reshape(): ...@@ -290,6 +290,37 @@ def test_forward_reshape():
####################################################################### #######################################################################
# Resize
# ------
def _test_resize(tf_resize_op, data, align_corners):
""" One iteration of Resize """
assert len(data) == 2
# Test with tensor and constant
with tf.Graph().as_default():
images_tensor = array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')
size = ops.convert_to_tensor(data[1], dtype=data[1].dtype)
out_tensor = tf_resize_op(images=images_tensor, size=size, align_corners=align_corners)
compare_tflite_with_tvm([data[0]], ['in:0'], [images_tensor], [out_tensor])
def test_all_resize():
""" Resize """
data = [np.random.rand(1, 16, 16, 3).astype("float32"), np.array([8, 8], dtype=np.int32)]
### RESIZE_BILINEAR
_test_resize(tf.image.resize_bilinear, data, align_corners=False)
_test_resize(tf.image.resize_bilinear, data, align_corners=True)
### RESIZE_NEAREST_NEIGHBOR (was added in v1.13)
# According to topi resize.h
# Align corners not supported for nearest neighbour
from tflite.BuiltinOperator import BuiltinOperator
if 'RESIZE_NEAREST_NEIGHBOR' in dir(BuiltinOperator()):
_test_resize(tf.image.resize_nearest_neighbor, data, align_corners=False)
#######################################################################
# Concatenation # Concatenation
# ------------- # -------------
...@@ -651,6 +682,7 @@ if __name__ == '__main__': ...@@ -651,6 +682,7 @@ if __name__ == '__main__':
test_forward_concatenation() test_forward_concatenation()
test_forward_pad() test_forward_pad()
test_forward_reshape() test_forward_reshape()
test_all_resize()
test_forward_squeeze() test_forward_squeeze()
# NN # NN
......
...@@ -384,7 +384,7 @@ inline Tensor resize_bilinear(const Tensor& input, ...@@ -384,7 +384,7 @@ inline Tensor resize_bilinear(const Tensor& input,
* \param shape Output shape to resize to. * \param shape Output shape to resize to.
* \param layout input layout * \param layout input layout
* \param align_corners To preserve centers of 4 corner pixels * \param align_corners To preserve centers of 4 corner pixels
* \param mode Angorithm to use (NEAREST_NEIGHBOR / BILINEAR) * \param mode Algorithm to use (NEAREST_NEIGHBOR / BILINEAR)
* \param name Name of the operation * \param name Name of the operation
* \param tag The tag to mark the operation * \param tag The tag to mark the operation
* *
......
...@@ -43,7 +43,7 @@ using namespace topi::image; ...@@ -43,7 +43,7 @@ using namespace topi::image;
* \param input The input tensor. * \param input The input tensor.
* \param shape Output shape to upsample. * \param shape Output shape to upsample.
* \param layout input layout * \param layout input layout
* \param mode Angorithm to use (NEAREST_NEIGHBOR / BILINEAR) * \param mode Algorithm to use (NEAREST_NEIGHBOR / BILINEAR)
* \param name Name of the operation * \param name Name of the operation
* \param tag The tag to mark the operation * \param tag The tag to mark the operation
* *
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment