Commit b1d93ecc by Wang Yucheng Committed by Zhao Wu

[Relay][Frontend][TFLite] Add parser support for squared difference (#4652)

* [Relay][Frontend][TFLite] Add parser support for squared difference

* fix some error

* fix exp_type

* add comment
parent e3016371
...@@ -111,6 +111,7 @@ class OperatorConverter(object): ...@@ -111,6 +111,7 @@ class OperatorConverter(object):
'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd, 'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
'PRELU': self.convert_prelu, 'PRELU': self.convert_prelu,
'TRANSPOSE_CONV': self.convert_transpose_conv, 'TRANSPOSE_CONV': self.convert_transpose_conv,
'SQUARED_DIFFERENCE': self.convert_squared_difference,
} }
def check_unsupported_ops(self): def check_unsupported_ops(self):
...@@ -735,6 +736,17 @@ class OperatorConverter(object): ...@@ -735,6 +736,17 @@ class OperatorConverter(object):
'TFlite quantized greater operator is not supported yet.') 'TFlite quantized greater operator is not supported yet.')
return self._convert_elemwise(_op.greater, op) return self._convert_elemwise(_op.greater, op)
def convert_squared_difference(self, op):
# Check if the input tensor is quantized, call QNN op
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized squared difference operator is not supported yet.')
difference = self._convert_elemwise(_op.subtract, op)
# _convert_elemwise has guaranteed only have one output tensor
exp_type = self.get_tensor_type_str(self.get_output_tensors(op)[0].tensor.Type())
out = _op.power(difference, relay.const(2, exp_type))
return out
def convert_zeros_like(self, op): def convert_zeros_like(self, op):
"""Convert TFLite ZEROS LIKE""" """Convert TFLite ZEROS LIKE"""
try: try:
......
...@@ -864,6 +864,14 @@ def _test_greater(data): ...@@ -864,6 +864,14 @@ def _test_greater(data):
""" One iteration of greater """ """ One iteration of greater """
return _test_elemwise(math_ops.greater, data) return _test_elemwise(math_ops.greater, data)
#######################################################################
# Squared_difference
# ------------------
def _test_squared_difference(data):
""" One iteration of squared difference """
return _test_elemwise(math_ops.squared_difference, data)
def _test_forward_elemwise(testop): def _test_forward_elemwise(testop):
""" Elewise""" """ Elewise"""
testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)), testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
...@@ -906,6 +914,7 @@ def test_all_elemwise(): ...@@ -906,6 +914,7 @@ def test_all_elemwise():
_test_forward_elemwise(_test_maximum) _test_forward_elemwise(_test_maximum)
_test_forward_elemwise(_test_minimum) _test_forward_elemwise(_test_minimum)
_test_forward_elemwise(_test_greater) _test_forward_elemwise(_test_greater)
_test_forward_elemwise(_test_squared_difference)
####################################################################### #######################################################################
# Zeros like # Zeros like
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment