Commit 19eb829e by Ramana Radhakrishnan, committed by Yao Wang

Add support for Tflite operator SPLIT (#3520)

* [RFC] Initial support for Tflite operator SPLIT

This patch adds initial support for the tflite operator split. However,
I am not yet sure how to handle the axis parameter for the split
operator and support it in the test infrastructure. Putting this up for
an initial review and comment.

According to
https://www.tensorflow.org/lite/guide/ops_compatibility
the split operator in tflite appears to take num_or_size_split as a 0D tensor.
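
For intuition, here is what an even split of that form means in plain NumPy (an illustrative sketch, not part of the patch):

import numpy as np

# Splitting a (6,) tensor into 3 equal pieces along axis 0.
x = np.arange(6)
pieces = np.split(x, 3, axis=0)
# pieces == [array([0, 1]), array([2, 3]), array([4, 5])]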

I also note that tflite.split is one of the few operators that return
multiple outputs, and thus the helper routines in the tests needed some
massaging to make this work.
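
For reference, a minimal sketch of that multiple-output behaviour on the relay side (assuming the tvm.relay Python API; illustrative only, not part of the patch):

from tvm import relay

x = relay.var("x", shape=(6, 2), dtype="float32")
out = relay.split(x, indices_or_sections=3, axis=0)
# relay.split returns a TupleWrapper; each piece has to be indexed out.
first = out[0]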

@apivarov, could you please review this?

Thanks,
Ramana

* Fix the axis parameter

Add more tests

* Address review comments

* Try out frozen_gene's suggestion

* Handle split of 1 element

* int32 is only supported in tflite 1.14, so let's check that version here.

* Keep this at python3.5

* Add packaging as a python package to be installed
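
The packaging dependency is what gives the tests PEP 440 version comparisons for gating on the installed TensorFlow version, roughly like this (a sketch mirroring the test change below):

from packaging import version

# True: 1.14.0 sorts after 1.13.1 under PEP 440 ordering.
print(version.parse('1.14.0') > version.parse('1.13.1'))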
parent 443b5b46
@@ -21,5 +21,5 @@ set -u
 set -o pipefail
 # install libraries for python package on ubuntu
-pip2 install nose pylint==1.9.4 six numpy nose-timer cython decorator scipy tornado typing antlr4-python2-runtime attrs
-pip3 install nose pylint==1.9.4 six numpy nose-timer cython decorator scipy tornado typed_ast pytest mypy orderedset antlr4-python3-runtime attrs requests Pillow
+pip2 install nose pylint==1.9.4 six numpy nose-timer cython decorator scipy tornado typing antlr4-python2-runtime attrs packaging
+pip3 install nose pylint==1.9.4 six numpy nose-timer cython decorator scipy tornado typed_ast pytest mypy orderedset antlr4-python3-runtime attrs requests Pillow packaging
@@ -81,6 +81,7 @@ class OperatorConverter(object):
             'PAD': self.convert_pad,
             'PACK': self.convert_pack,
             'LOGISTIC': self.convert_logistic,
+            'SPLIT': self.convert_split
         }

     def check_unsupported_ops(self):
@@ -705,6 +706,43 @@ class OperatorConverter(object):
         return out

+    def convert_split(self, op):
+        """split implementation."""
+        try:
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.Operator import Operator
+            from tflite.SplitOptions import SplitOptions
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be == 2"
+
+        axis_tensor = input_tensors[0]
+        split_axis = self.get_tensor_value(axis_tensor)
+        input_tensor = input_tensors[1]
+        input_tensor_idx = input_tensor.tensor_idx
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.SplitOptions
+        op_options = op.BuiltinOptions()
+        split_options = SplitOptions()
+        split_options.Init(op_options.Bytes, op_options.Pos)
+        num_splits = split_options.NumSplits()
+
+        in_expr = self.get_expr(input_tensor_idx)
+        out = _op.split(in_expr, num_splits, axis=int(split_axis))
+        # Relay does not like a TupleWrapper of 1 element, further this
+        # only shows up with tf1.13 if we use a split with num_splits==1.
+        # In tf 1.14 this doesn't appear as it is automatically a reshape
+        # operation.
+        if isinstance(out, _expr.TupleWrapper):
+            if out.size == 1:
+                out = out[0]
+
+        return out
+
     def convert_pool2d(self, op, pool_type):
         """pool2d implementation."""
         try:
@@ -38,6 +38,7 @@ except ImportError:
     from tensorflow.contrib import lite as interpreter_wrapper

 import tvm.relay.testing.tf as tf_testing
+from packaging import version as package_version

 #######################################################################
 # Generic run functions for TVM & TFLite
@@ -120,10 +121,11 @@ def run_tflite_graph(tflite_model_buf, input_data):

 def compare_tflite_with_tvm(in_data, in_name, input_tensors,
-                            output_tensors, init_global_variables=False):
+                            output_tensors, init_global_variables=False, out_names=None):
     """Generic function to generate and compare TFLite and TVM output"""
     in_data = convert_to_list(in_data)
     in_name = convert_to_list(in_name)
+    out_names = convert_to_list(out_names)
     in_node = [0] * len(in_name)
     for i in range(len(in_name)):
         in_node[i] = in_name[i].split(':')[0] if ":" in in_name[i] else in_name[i]
@@ -143,7 +145,8 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
             print("Skip because %s is not enabled" % device)
             continue

-        tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device)
+        tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device,
+                                   num_output=len(out_names), out_names=out_names)
         for i in range(len(tflite_output)):
             tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
@@ -161,6 +164,42 @@ def with_fused_activation_function(input_tensor, fn_name):
         return math_ops.tanh(input_tensor)
     raise AssertionError("Unknown fused_activation_function {}".format(fn_name))

+def _test_split(in_shape, axis, num_Splits, dtype):
+    '''internal split tester taking as parameters in_shape, the axis to split along,
+    the number of tensors to split into and dtype (data type)'''
+    np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=in_shape, dtype=dtype)
+        out = array_ops.split(in_data, num_Splits, axis=axis)
+        out_names = ['out_' + str(n) + ':0' for n in range(num_Splits)]
+        compare_tflite_with_tvm([np_data], ['Placeholder:0'], [in_data], out,
+                                out_names=out_names)
+
+def test_forward_split():
+    '''test split layer'''
+    # rank 1
+    _test_split((3,), 0, 1, 'float32')
+    _test_split((3,), 0, 3, 'float32')
+    _test_split((6,), 0, 3, 'float32')
+    # rank 2
+    _test_split((6, 2), 0, 3, 'float32')
+    _test_split((2, 6), 1, 6, 'float32')
+    # rank 3
+    # int32 inputs to split are only supported from tflite 1.14 onwards
+    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+        _test_split((6, 2, 4), 0, 2, 'int32')
+    _test_split((2, 6, 4), 1, 3, 'float32')
+    _test_split((2, 4, 6), 2, 1, 'float32')
+    # rank 4
+    _test_split((6, 1, 3, 5), 0, 3, 'float32')
+    _test_split((1, 6, 3, 5), 1, 3, 'float32')
+    _test_split((1, 3, 6, 5), 2, 3, 'float32')
+    _test_split((1, 3, 5, 6), 3, 3, 'float32')
+    # split along negative axis
+    _test_split((6, 1, 3, 5), -4, 3, 'float32')
+    _test_split((1, 6, 3, 5), -3, 3, 'float32')
+    _test_split((1, 3, 6, 5), -2, 3, 'float32')
+    _test_split((1, 3, 5, 6), -1, 3, 'float32')
+
 #######################################################################
 # Pooling
@@ -782,6 +821,8 @@ def test_forward_ssd_mobilenet_v1():
 # Main
 # ----
 if __name__ == '__main__':
+    # Split
+    test_forward_split()
     # Transforms
     test_forward_concatenation()
     test_forward_pad()