Unverified Commit 79ce87f8 by Ina Dobreva Committed by GitHub

[Relay][Frontend][TFLite] Add parser support for logical operators (#4642)

* [Relay][Frontend][TFLite] Add parser support for logical operators

* Add parser support for logical_and, logical_or
* Add boolean dtype as a valid tensor type
* BOOLEAN dtype is supported only from tf 1.15
  so logical ops work only in that and newer versions
* Logical_not is ommited since tflite can't convert it -->
  throws errors for addv2

* Add TFLite vesion check in tests for logical ops

* Check is added because of boolean dtype lack of support
parent 23f3988b
......@@ -117,6 +117,8 @@ class OperatorConverter(object):
'PRELU': self.convert_prelu,
'TRANSPOSE_CONV': self.convert_transpose_conv,
'SQUARED_DIFFERENCE': self.convert_squared_difference,
'LOGICAL_AND': self.convert_logical_and,
'LOGICAL_OR': self.convert_logical_or,
}
def check_unsupported_ops(self):
......@@ -222,6 +224,9 @@ class OperatorConverter(object):
if tensor_wrapper.tensor.Type() == TensorType.INT64:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int64).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
if tensor_wrapper.tensor.Type() == TensorType.BOOL:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.bool_).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
raise NotImplementedError("Tensor type {} is currently not supported"
.format(str(tensor_wrapper.tensor.Type())))
......@@ -240,6 +245,8 @@ class OperatorConverter(object):
return "int32"
if tensor_type == TensorType.INT64:
return "int64"
if tensor_type == TensorType.BOOL:
return "bool"
raise NotImplementedError("Tensor type {} is currently not supported"
.format(str(tensor_type)))
......@@ -792,6 +799,33 @@ class OperatorConverter(object):
'TFlite quantized NOT_EQUAL operator is not supported yet.')
return self._convert_elemwise(_op.not_equal, op)
def _convert_logical_binary(self, relay_op, op):
"""Generic method to convert logical binary ops"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"
lhs_tensor = input_tensors[0]
lhs_expr = self.get_expr(lhs_tensor.tensor_idx)
rhs_tensor = input_tensors[1]
rhs_expr = self.get_expr(rhs_tensor.tensor_idx)
out = relay_op(lhs_expr, rhs_expr)
return out
def convert_logical_and(self, op):
"""Convert tflite LOGICAL_AND"""
return self._convert_logical_binary(_op.logical_and, op)
def convert_logical_or(self, op):
"""Convert tflite LOGICAL_OR"""
return self._convert_logical_binary(_op.logical_or, op)
def convert_zeros_like(self, op):
"""Convert TFLite ZEROS LIKE"""
try:
......
......@@ -966,6 +966,34 @@ def test_all_elemwise():
_test_forward_elemwise(_test_not_equal)
#######################################################################
# Logical operators
# -----------------
def _test_logical_binary(logical_bin_op, data):
with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[0].shape, dtype='bool', name='in_0'),
array_ops.placeholder(shape=data[1].shape, dtype='bool', name='in_1')]
out = logical_bin_op(in_data[0], in_data[1], name='out')
compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out])
def _test_forward_logical_and(data):
""" One iteration of logical and """
return _test_logical_binary(math_ops.logical_and, data)
def _test_forward_logical_or(data):
""" One iteration of logical or """
return _test_logical_binary(math_ops.logical_or, data)
def test_all_logical():
data = [np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool'),
np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool')]
# boolean dtype is not supported by older versions than TFLite 1.15.0
if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
_test_forward_logical_and(data)
_test_forward_logical_or(data)
#######################################################################
# Zeros like
# --------
......@@ -1530,6 +1558,9 @@ if __name__ == '__main__':
# Reduce
test_all_reduce()
# Logical
test_all_logical()
# End to End
test_forward_mobilenet_v1()
test_forward_mobilenet_v2()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment