Unverified Commit a9505365 by Samuel Committed by GitHub

[TFLITE][FRONTEND]Reduce_any op parsing support (#4926)

* [TFLITE][FRONTEND]Reduce_any op parsing support

* Testcase check added to run in tf version above 1.14.0 & review comments

* Review comment, checked updated to 1.15
parent 5e7fbaef
......@@ -102,6 +102,7 @@ class OperatorConverter(object):
'PAD': self.convert_pad,
'POW': self.convert_pow,
'PRELU': self.convert_prelu,
'REDUCE_ANY': self._convert_reduce_any,
'REDUCE_MAX': self._convert_reduce_max,
'REDUCE_MIN': self._convert_reduce_min,
'REDUCE_PROD': self._convert_reduce_prod,
......@@ -1088,6 +1089,9 @@ class OperatorConverter(object):
def _convert_reduce_sum(self, op):
return self._convert_reduce(_op.reduce.sum, op)
def _convert_reduce_any(self, op):
return self._convert_reduce(_op.reduce.any, op)
def convert_fully_connected(self, op):
"""Convert TFLite fully connected"""
try:
......
......@@ -1154,11 +1154,24 @@ def _test_reduce_sum(data, keep_dims=None):
""" One iteration of reduce_sum """
return _test_reduce(math_ops.reduce_sum, data, keep_dims)
#######################################################################
# Reduce_any
# ----------
def _test_reduce_any(data, keep_dims=None):
""" One iteration of reduce_any """
return _test_reduce(math_ops.reduce_any, data, keep_dims)
def _test_forward_reduce(testop):
def _test_forward_reduce(testop, dtype="float32"):
""" Reduce """
data0 = [np.random.rand(16, 16, 16, 16).astype("float32"), None]
data1 = [np.random.rand(16, 16, 16, 16).astype("float32"), np.array([1, 2], dtype=np.int32)]
if dtype == 'bool':
data0 = [np.random.choice(a=[False, True], size=(16, 16, 16, 16)).astype(dtype),
None]
data1 = [np.random.choice(a=[False, True], size=(16, 16, 16, 16)).astype(dtype),
np.array([1, 2], dtype=np.int32)]
else:
data0 = [np.random.rand(16, 16, 16, 16).astype(dtype), None]
data1 = [np.random.rand(16, 16, 16, 16).astype(dtype), np.array([1, 2], dtype=np.int32)]
testop(data0)
testop(data0, keep_dims=False)
testop(data0, keep_dims=True)
......@@ -1179,6 +1192,8 @@ def test_all_reduce():
_test_forward_reduce_quantized(_test_reduce_mean)
_test_forward_reduce(_test_reduce_prod)
_test_forward_reduce(_test_reduce_sum)
if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
_test_forward_reduce(_test_reduce_any, dtype="bool")
#######################################################################
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment