Commit 18592c8d by Ina Dobreva Committed by Zhi

[Relay][Frontend][TFlite] Add parsing support for the UNPACK tflite operator (#4447)

* use SPLIT & SQUEEZE = UNPACK as implemented in tensorflow parser
  Relay doesn't support UNPACK
* tflite 1.13: UNPACK doesn't work as expected -> it copies the values from
  the 1st unpacked tensor to the other unpacked tensors
* tflite 1.13: doesn't accept negative axis
parent 6ab15806
......@@ -86,6 +86,7 @@ class OperatorConverter(object):
'FULLY_CONNECTED': self.convert_fully_connected,
'PAD': self.convert_pad,
'PACK': self.convert_pack,
'UNPACK': self.convert_unpack,
'LOGISTIC': self.convert_logistic,
'TANH':self.convert_tanh,
'RELU':self.convert_relu,
......@@ -1239,6 +1240,50 @@ class OperatorConverter(object):
out = _op.concatenate(in_exprs_reshaped, pack_axis)
return out
def convert_unpack(self, op):
    """Convert a TFLite UNPACK operator to Relay.

    Relay has no direct 'unpack', so the op is lowered to a 'split'
    along the unpack axis followed by a 'squeeze' of that axis on
    each resulting piece.
    """
    try:
        from tflite.BuiltinOptions import BuiltinOptions
        from tflite.Operator import Operator
        from tflite.UnpackOptions import UnpackOptions
    except ImportError:
        raise ImportError("The tflite package must be installed")

    assert isinstance(op, Operator)
    input_tensors = self.get_input_tensors(op)
    assert len(input_tensors) == 1, "input tensors length should be 1"
    in_expr = self.get_expr(input_tensors[0].tensor_idx)

    assert op.BuiltinOptionsType() == BuiltinOptions.UnpackOptions
    unpack_options = UnpackOptions()
    raw_options = op.BuiltinOptions()
    unpack_options.Init(raw_options.Bytes, raw_options.Pos)
    num_unpacks = unpack_options.Num()
    unpack_axis = unpack_options.Axis()

    # Relay's squeeze expects the axis argument to be either None or a
    # list, so wrap the split axis accordingly.
    squeeze_axis = [unpack_axis] if unpack_axis != 0 else None

    if num_unpacks == 1:
        # Splitting into one piece is the identity, and a TupleWrapper
        # holding a single element confuses downstream handling (for
        # reference see convert_split), so skip 'split' and only
        # squeeze the axis of dim == 1.
        squeezed = _op.squeeze(in_expr, axis=squeeze_axis)
        if isinstance(squeezed, _expr.TupleWrapper):
            squeezed = squeezed[0]
        return squeezed

    split_pieces = _op.split(in_expr,
                             indices_or_sections=num_unpacks,
                             axis=unpack_axis)
    squeezed_pieces = [_op.squeeze(piece, axis=squeeze_axis)
                       for piece in split_pieces]
    return _expr.TupleWrapper(_expr.Tuple(squeezed_pieces),
                              len(split_pieces))
def convert_batch_to_space_nd(self, op):
"""batch_to_space_nd implementation."""
try:
......
......@@ -971,6 +971,27 @@ def test_forward_pack():
#######################################################################
# Unpack
# ------
def _test_unpack(data, axis, num_unpacks):
    """ One iteration of UNPACK: build the TF graph and compare TFLite vs TVM. """
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
        unpacked = gen_array_ops.unpack(placeholder, num=num_unpacks,
                                        axis=axis, name='unpack')
        # One named output tensor per unpacked piece.
        output_names = []
        for idx in range(num_unpacks):
            output_names.append('out_' + str(idx) + ':0')
        compare_tflite_with_tvm([data], 'Placeholder:0', [placeholder],
                                unpacked, out_names=output_names)
def test_forward_unpack():
    """ UNPACK """
    # Each case: (input shape, dtype, unpack axis, number of unpacked pieces).
    cases = [
        ((3, 1), np.int32, 1, 1),
        ((3, 4), np.float32, 0, 3),
    ]
    # tflite 1.13 doesn't accept negative axis
    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
        cases.append(((3, 6), np.int32, -2, 3))
        cases.append(((2, 3, 4), np.int32, -3, 2))
    for shape, dtype, axis, num in cases:
        _test_unpack(np.array(np.random.uniform(0, 5, shape), dtype=dtype),
                     axis=axis, num_unpacks=num)
#######################################################################
# Logistic
# --------
......@@ -1280,6 +1301,7 @@ if __name__ == '__main__':
test_forward_concatenation()
test_forward_pad()
test_forward_pack()
test_forward_unpack()
test_forward_reshape()
test_all_resize()
test_forward_squeeze()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment