Unverified Commit 1c8e5b93 by Samuel Committed by GitHub

[TFLITE]FLOOR_MOD & FLOOR_DIV support (#4971)

* TFLite Floor_div & floor_mod parsing code

* Review comment updated
parent 51af454a
......@@ -123,6 +123,8 @@ class OperatorConverter(object):
'DETECTION_POSTPROCESS': self.convert_detection_postprocess,
'SQUARE': self.convert_square,
'L2_NORMALIZATION': self.convert_l2_normalization,
'FLOOR_DIV': self.convert_floor_div,
'FLOOR_MOD': self.convert_floor_mod,
}
def check_unsupported_ops(self):
......@@ -1579,6 +1581,20 @@ class OperatorConverter(object):
out = _op.nn.pad(in_expr, pad_width=paddings, pad_value=pad_value)
return out
def convert_floor_div(self, op):
"""Convert TFLite FLOOR_DIV"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized FLOOR DIV operator is not supported yet.')
return self._convert_elemwise(_op.floor_divide, op)
def convert_floor_mod(self, op):
"""Convert TFLite FLOOR_MOD"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized FLOOR MOD operator is not supported yet.')
return self._convert_elemwise(_op.floor_mod, op)
def convert_mirror_pad(self, op):
"""Convert TFLite MIRROR_PAD"""
try:
......
......@@ -943,6 +943,22 @@ def _test_squared_difference(data):
""" One iteration of squared difference """
return _test_elemwise(math_ops.squared_difference, data)
#######################################################################
# Floor_divide
# ------------
def _test_floor_divide(data):
""" One iteration of floor_div"""
return _test_elemwise(math_ops.floordiv, data)
#######################################################################
# Floor_mod
# ---------
def _test_floor_mod(data):
""" One iteration of floor_mod"""
return _test_elemwise(math_ops.floormod, data)
def _test_forward_elemwise(testop):
""" Elewise"""
testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
......@@ -991,6 +1007,9 @@ def test_all_elemwise():
_test_forward_elemwise(_test_less_equal)
_test_forward_elemwise(_test_equal)
_test_forward_elemwise(_test_not_equal)
if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
_test_forward_elemwise(_test_floor_divide)
_test_forward_elemwise(_test_floor_mod)
#######################################################################
# Logical operators
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment