Commit c9e96d9f by Hua Committed by Tianqi Chen

[Relay] Add Elemwise operator Sub, Divide, Power, Max, Min to tflite frontend. (#3357)

parent a698ad7f
......@@ -64,7 +64,12 @@ class OperatorConverter(object):
'MAX_POOL_2D': self.convert_max_pool2d,
'CONCATENATION': self.convert_concatenation,
'ADD': self.convert_add,
'SUB': self.convert_sub,
'MUL': self.convert_mul,
'DIV': self.convert_div,
'POW': self.convert_pow,
'MAXIMUM': self.convert_maximum,
'MINIMUM': self.convert_minimum,
'FULLY_CONNECTED': self.convert_fully_connected,
'PAD': self.convert_pad,
'LOGISTIC': self.convert_logistic,
......@@ -320,10 +325,27 @@ class OperatorConverter(object):
"""Convert TFLite ADD"""
return self._convert_elemwise(_op.add, op)
def convert_sub(self, op):
"""Convert TFLite SUB"""
return self._convert_elemwise(_op.subtract, op)
def convert_mul(self, op):
"""Convert TFLite MUL"""
return self._convert_elemwise(_op.multiply, op)
def convert_div(self, op):
"""Convert TFLite DIV"""
return self._convert_elemwise(_op.divide, op)
def convert_pow(self, op):
return self._convert_elemwise(_op.power, op)
def convert_maximum(self, op):
return self._convert_elemwise(_op.maximum, op)
def convert_minimum(self, op):
return self._convert_elemwise(_op.minimum, op)
def convert_fully_connected(self, op):
"""Convert TFLite fully connected"""
try:
......
......@@ -320,7 +320,7 @@ def _test_elemwise(math_op, data):
with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'),
array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')]
out = math_ops.add(in_data[0], in_data[1])
out = math_op(in_data[0], in_data[1])
compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out])
# Test with tensor and constant
......@@ -338,35 +338,66 @@ def _test_add(data):
""" One iteration of add """
return _test_elemwise(math_ops.add, data)
#######################################################################
# Subtract
# ---
def test_forward_add():
""" Add """
_test_add([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3))])
_test_add([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)),
np.arange(6.0, dtype=np.float32).reshape((2, 1, 3))])
_test_add([np.arange(3.0, dtype=np.float32).reshape((1, 3)),
np.arange(3.0, dtype=np.float32).reshape((1, 3))])
def _test_sub(data):
""" One iteration of subtract """
return _test_elemwise(math_ops.subtract, data)
#######################################################################
# Mul
# ---
def _test_mul(data):
""" One iteration of mul """
return _test_elemwise(math_ops.multiply, data)
#######################################################################
# Divide
# ---
def _test_div(data):
""" One iteration of divide """
return _test_elemwise(math_ops.divide, data)
#######################################################################
# Power
# ---
def _test_pow(data):
""" One iteration of power """
return _test_elemwise(math_ops.pow, data)
#######################################################################
# Maximum
# ---
def _test_maximum(data):
""" One iteration of maximum """
return _test_elemwise(math_ops.maximum, data)
#######################################################################
# Minimum
# ---
def _test_minimum(data):
""" One iteration of minimum """
return _test_elemwise(math_ops.minimum, data)
def test_forward_mul():
""" Mul """
_test_mul([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
def _test_forward_elemwise(testop):
""" Elewise"""
testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3))])
_test_mul([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)),
testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)),
np.arange(6.0, dtype=np.float32).reshape((2, 1, 3))])
_test_mul([np.arange(3.0, dtype=np.float32).reshape((1, 3)),
testop([np.arange(3.0, dtype=np.float32).reshape((1, 3)),
np.arange(3.0, dtype=np.float32).reshape((1, 3))])
def test_all_elemwise():
_test_forward_elemwise(_test_add)
_test_forward_elemwise(_test_sub)
_test_forward_elemwise(_test_mul)
_test_forward_elemwise(_test_div)
_test_forward_elemwise(_test_pow)
_test_forward_elemwise(_test_maximum)
_test_forward_elemwise(_test_minimum)
#######################################################################
# Squeeze
......@@ -584,9 +615,8 @@ if __name__ == '__main__':
test_forward_softmax()
test_forward_fully_connected()
# Math
test_forward_add()
test_forward_mul()
# Elemwise
test_all_elemwise()
# End to End
test_forward_mobilenet_v1()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment