Commit dee52466 by Neo Chien Committed by Jared Roesch

Implementation of tile for TFLite (#3814)

parent eef35a57
...@@ -82,7 +82,8 @@ class OperatorConverter(object): ...@@ -82,7 +82,8 @@ class OperatorConverter(object):
'PACK': self.convert_pack, 'PACK': self.convert_pack,
'LOGISTIC': self.convert_logistic, 'LOGISTIC': self.convert_logistic,
'SPLIT': self.convert_split, 'SPLIT': self.convert_split,
'TRANSPOSE': self.convert_transpose 'TRANSPOSE': self.convert_transpose,
'TILE': self.convert_tile
} }
def check_unsupported_ops(self): def check_unsupported_ops(self):
...@@ -769,6 +770,28 @@ class OperatorConverter(object): ...@@ -769,6 +770,28 @@ class OperatorConverter(object):
return out return out
def convert_tile(self, op):
"""tile implementation."""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"
input_tensor = input_tensors[0]
input_tensor_idx = input_tensor.tensor_idx
in_expr = self.get_expr(input_tensor_idx)
# reps (tuple of int) – The number of times repeating the tensor data.
reps = tuple(self.get_tensor_value(input_tensors[1]))
out = _op.tile(in_expr, reps)
return out
def convert_pool2d(self, op, pool_type): def convert_pool2d(self, op, pool_type):
"""pool2d implementation.""" """pool2d implementation."""
try: try:
......
...@@ -229,6 +229,26 @@ def test_forward_transpose(): ...@@ -229,6 +229,26 @@ def test_forward_transpose():
_test_forward_transpose((2, 3, 4, 5), (3, 0, 1, 2)) _test_forward_transpose((2, 3, 4, 5), (3, 0, 1, 2))
_test_forward_transpose((2, 3, 4, 5), ()) _test_forward_transpose((2, 3, 4, 5), ())
#######################################################################
# tile
# ---------
def _test_forward_tile(in_shape, reps, dtype):
data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = array_ops.tile(in_data, reps)
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
def test_forward_tile():
_test_forward_tile((2, ), (3, ), "int32")
_test_forward_tile((2, 2), (2, 3), "float32")
####################################################################### #######################################################################
# Pooling # Pooling
...@@ -856,6 +876,9 @@ if __name__ == '__main__': ...@@ -856,6 +876,9 @@ if __name__ == '__main__':
# Transpose # Transpose
test_forward_transpose() test_forward_transpose()
# Tile
test_forward_tile()
# Transforms # Transforms
test_forward_concatenation() test_forward_concatenation()
test_forward_pad() test_forward_pad()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment