Commit 084e338e by Alexander Pivovarov Committed by Yao Wang

Add MUL operator to relay tflite frontend (#3304)

parent 98a91af9
...@@ -64,6 +64,7 @@ class OperatorConverter(object): ...@@ -64,6 +64,7 @@ class OperatorConverter(object):
'MAX_POOL_2D': self.convert_max_pool2d, 'MAX_POOL_2D': self.convert_max_pool2d,
'CONCATENATION': self.convert_concatenation, 'CONCATENATION': self.convert_concatenation,
'ADD': self.convert_add, 'ADD': self.convert_add,
'MUL': self.convert_mul,
'FULLY_CONNECTED': self.convert_fully_connected, 'FULLY_CONNECTED': self.convert_fully_connected,
} }
...@@ -267,8 +268,8 @@ class OperatorConverter(object): ...@@ -267,8 +268,8 @@ class OperatorConverter(object):
out = self.convert_fused_activation_function(out, fused_activation_fn) out = self.convert_fused_activation_function(out, fused_activation_fn)
return out return out
def convert_add(self, op): def _convert_elemwise(self, relay_op, op):
"""Convert TFLite add""" """Generic method to Convert TFLite elemwise"""
try: try:
from tflite.Operator import Operator from tflite.Operator import Operator
except ImportError: except ImportError:
...@@ -283,19 +284,26 @@ class OperatorConverter(object): ...@@ -283,19 +284,26 @@ class OperatorConverter(object):
rhs_tensor = input_tensors[1] rhs_tensor = input_tensors[1]
if self.has_expr(rhs_tensor.tensor_idx): if self.has_expr(rhs_tensor.tensor_idx):
# In most cases, we can assume that TOCO fuses ADD operators # In most cases, we can assume that TOCO fuses elemwise operators
# with constants - it means both will be tensors. # with constants - it means both will be tensors.
rhs_expr = self.get_expr(rhs_tensor.tensor_idx) rhs_expr = self.get_expr(rhs_tensor.tensor_idx)
else: else:
# However, in some corner cases, the ADD operator is not fused, # However, in some corner cases, the elemwise operator is not fused,
# we can receive as constant. # we can receive as constant.
rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type()) rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type())
rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor), rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor),
dtype=rhs_type_str) dtype=rhs_type_str)
out = relay_op(lhs_expr, rhs_expr)
out = _op.add(lhs_expr, rhs_expr)
return out return out
def convert_add(self, op):
"""Convert TFLite ADD"""
return self._convert_elemwise(_op.add, op)
def convert_mul(self, op):
"""Convert TFLite MUL"""
return self._convert_elemwise(_op.multiply, op)
def convert_fully_connected(self, op): def convert_fully_connected(self, op):
"""Convert TFLite fully connected""" """Convert TFLite fully connected"""
try: try:
......
...@@ -24,7 +24,6 @@ from __future__ import print_function ...@@ -24,7 +24,6 @@ from __future__ import print_function
import numpy as np import numpy as np
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.contrib import util
import tensorflow as tf import tensorflow as tf
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
...@@ -144,8 +143,6 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors, ...@@ -144,8 +143,6 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
for i in range(len(tflite_output)): for i in range(len(tflite_output)):
tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5) tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
sess.close()
####################################################################### #######################################################################
# Pooling # Pooling
...@@ -311,10 +308,10 @@ def test_forward_concatenation(): ...@@ -311,10 +308,10 @@ def test_forward_concatenation():
####################################################################### #######################################################################
# Add # Element-wise
# --- # ---
def _test_add(data): def _test_elemwise(math_op, data):
""" One iteration of add """ """ One iteration of add """
assert len(data) == 2 assert len(data) == 2
...@@ -329,10 +326,19 @@ def _test_add(data): ...@@ -329,10 +326,19 @@ def _test_add(data):
# Test with tensor and constant # Test with tensor and constant
with tf.Graph().as_default(): with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')] in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')]
out = math_ops.add(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype)) out = math_op(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out]) compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])
#######################################################################
# Add
# ---
def _test_add(data):
""" One iteration of add """
return _test_elemwise(math_ops.add, data)
def test_forward_add(): def test_forward_add():
""" Add """ """ Add """
_test_add([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)), _test_add([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
...@@ -344,6 +350,25 @@ def test_forward_add(): ...@@ -344,6 +350,25 @@ def test_forward_add():
####################################################################### #######################################################################
# Mul
# ---
def _test_mul(data):
""" One iteration of mul """
return _test_elemwise(math_ops.multiply, data)
def test_forward_mul():
""" Mul """
_test_mul([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3))])
_test_mul([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)),
np.arange(6.0, dtype=np.float32).reshape((2, 1, 3))])
_test_mul([np.arange(3.0, dtype=np.float32).reshape((1, 3)),
np.arange(3.0, dtype=np.float32).reshape((1, 3))])
#######################################################################
# Squeeze # Squeeze
# ------- # -------
...@@ -514,6 +539,7 @@ if __name__ == '__main__': ...@@ -514,6 +539,7 @@ if __name__ == '__main__':
# Math # Math
test_forward_add() test_forward_add()
test_forward_mul()
# End to End # End to End
test_forward_mobilenet_v1() test_forward_mobilenet_v1()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment