Commit 2327bb9f by Ina Dobreva Committed by Tianqi Chen

[Relay][Frontend][TFlite] Add parses support for SLICE (#4502)

* [Relay][Frontend][TFlite] Add parses support for SLICE

* TFlite 1.13: convertor gives nonsense output when size[i]==-1
* TF parser: SLICE need fixing for size[i]==-1 -> gives wrong output
  bcs of indices

* Set end[i] = input_tensor_shape[i] as suggested in PR review

* Add another test to cover size=-1 case
parent 74d5cf46
......@@ -103,6 +103,7 @@ class OperatorConverter(object):
'TANH':self.convert_tanh,
'RELU':self.convert_relu,
'SPLIT': self.convert_split,
'SLICE': self.convert_slice,
'TRANSPOSE': self.convert_transpose,
'CAST': self.convert_cast,
'TILE': self.convert_tile,
......@@ -1152,6 +1153,35 @@ class OperatorConverter(object):
return out
def convert_slice(self, op):
"""Convert TFLite SLICE"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 3, "input tensors length should be == 3"
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
begin = list(self.get_tensor_value(input_tensors[1]))
size = list(self.get_tensor_value(input_tensors[2]))
# strided_slice(Relay) needs the slice's end indices, not the size
end = size
input_tensor_shape = input_tensor.tensor.ShapeAsNumpy()
input_tensor_rank = len(input_tensor_shape)
for i in range(input_tensor_rank):
if size[i] == -1:
end[i] = input_tensor_shape[i]
else:
end[i] += begin[i]
out = _op.strided_slice(in_expr, begin, end)
return out
def convert_transpose(self, op):
"""transpose implementation."""
try:
......
......@@ -225,6 +225,26 @@ def test_forward_split():
_test_split((1, 3, 5, 6), -1, 3, 'float32')
#######################################################################
# slice
# -----
def _test_slice(data, begin, size):
""" One iteration of SLICE """
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = array_ops.slice(in_data, begin, size)
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
def test_forward_slice():
""" SLICE """
_test_slice(np.arange(4, dtype=np.float32).reshape((4, )), begin=[0], size=[2])
_test_slice(np.arange(18, dtype=np.int32).reshape((3, 2, 3)), begin=[1, 0, 0], size=[1, 1, 3])
# tflite 1.13 outputs nonsense values if size[i] == -1
if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
_test_slice(np.arange(8, dtype=np.int32).reshape((2, 4)), begin=[0, 1], size=[-1, -1])
_test_slice(np.arange(5, dtype=np.int32).reshape((5, )), begin=[4], size=[-1])
#######################################################################
# transpose
# ---------
......@@ -1408,6 +1428,7 @@ if __name__ == '__main__':
test_forward_reshape()
test_all_resize()
test_forward_squeeze()
test_forward_slice()
# NN
test_forward_convolution()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment