Unverified Commit 095f565f by Samuel Committed by GitHub

[FRONTEND][TFLITE]Logical not op support (#5475)

parent 90b08f5e
......@@ -94,6 +94,7 @@ class OperatorConverter(object):
'LOCAL_RESPONSE_NORMALIZATION': self.convert_lrn,
'LOG': self.convert_log,
'LOGICAL_AND': self.convert_logical_and,
'LOGICAL_NOT': self.convert_logical_not,
'LOGICAL_OR': self.convert_logical_or,
'LOGISTIC': self.convert_logistic,
'MAX_POOL_2D': self.convert_max_pool2d,
......@@ -992,6 +993,16 @@ class OperatorConverter(object):
"""Convert tflite LOGICAL_OR"""
return self._convert_logical_binary(_op.logical_or, op)
def convert_logical_not(self, op):
"""Convert tflite LOGICAL_NOT"""
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
data = self.get_expr(input_tensors[0].tensor_idx)
out = _op.logical_not(data)
return out
def convert_gather(self, op):
"""Method to Convert TFLite GATHER operator"""
try:
......
......@@ -1183,7 +1183,12 @@ def _test_logical_binary(logical_bin_op, data):
with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[0].shape, dtype='bool', name='in_0'),
array_ops.placeholder(shape=data[1].shape, dtype='bool', name='in_1')]
out = logical_bin_op(in_data[0], in_data[1], name='out')
if logical_bin_op == math_ops.logical_not:
out = math_ops.logical_or(in_data[0], in_data[1], name='out1')
out = logical_bin_op(out, name='out')
else:
out = logical_bin_op(in_data[0], in_data[1], name='out')
compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out])
def _test_forward_logical_and(data):
......@@ -1194,6 +1199,10 @@ def _test_forward_logical_or(data):
""" One iteration of logical or """
return _test_logical_binary(math_ops.logical_or, data)
def _test_forward_logical_not(data):
""" One iteration of logical not """
return _test_logical_binary(math_ops.logical_not, data)
def test_all_logical():
data = [np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool'),
np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool')]
......@@ -1201,6 +1210,7 @@ def test_all_logical():
if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
_test_forward_logical_and(data)
_test_forward_logical_or(data)
_test_forward_logical_not(data)
#######################################################################
# Zeros like
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment