Commit b074314b by tristan-arm Committed by Tianqi Chen

Add Pack operator to TFLite (#3521)

parent 273c0280
......@@ -79,6 +79,7 @@ class OperatorConverter(object):
'REDUCE_PROD': self._convert_reduce_prod,
'FULLY_CONNECTED': self.convert_fully_connected,
'PAD': self.convert_pad,
'PACK': self.convert_pack,
'LOGISTIC': self.convert_logistic,
}
......@@ -789,6 +790,33 @@ class OperatorConverter(object):
out = _op.nn.pad(in_expr, paddings)
return out
def convert_pack(self, op):
"""Convert TFLite pack"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.Operator import Operator
from tflite.PackOptions import PackOptions
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) >= 1, "input tensors should greater than 1"
in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors]
output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors should be 1"
assert op.BuiltinOptionsType() == BuiltinOptions.PackOptions
op_options = op.BuiltinOptions()
pack_options = PackOptions()
pack_options.Init(op_options.Bytes, op_options.Pos)
pack_axis = pack_options.Axis()
in_exprs_reshaped = [_op.expand_dims(i, axis=pack_axis, num_newaxis=1) for i in in_exprs]
out = _op.concatenate(in_exprs_reshaped, pack_axis)
return out
def get_expr(self, input_tensor_idx):
return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
......
......@@ -582,6 +582,41 @@ def test_forward_pad():
#######################################################################
# Pack
# -------------
def _test_pack(data, axis):
""" One iteration of pack """
assert len(data) >= 1
with tf.Graph().as_default():
in_data = [
array_ops.placeholder(shape=tensor.shape, dtype=tensor.dtype, name="in_{}".format(idx))
for idx, tensor in enumerate(data)]
out = array_ops.pack(in_data, axis=axis)
name = ["in_{}:0".format(idx) for idx in range(len(data))]
compare_tflite_with_tvm(data, name, in_data, [out])
def test_forward_pack():
""" Pack """
_test_pack(
[np.arange(6).reshape((1, 2, 1, 3)),
np.arange(6).reshape((1, 2, 1, 3))], 1)
_test_pack(
[np.arange(6).reshape((3, 2)),
np.arange(6).reshape((3, 2))], 1)
_test_pack(
[np.arange(6).reshape((2, 1, 1, 3)),
np.arange(6).reshape((2, 1, 1, 3)),
np.arange(6).reshape((2, 1, 1, 3))], 1)
#######################################################################
# Logistic
# --------
......@@ -750,6 +785,7 @@ if __name__ == '__main__':
# Transforms
test_forward_concatenation()
test_forward_pad()
test_forward_pack()
test_forward_reshape()
test_all_resize()
test_forward_squeeze()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment