Commit 084e338e by Alexander Pivovarov Committed by Yao Wang

Add MUL operator to relay tflite frontend (#3304)

parent 98a91af9
......@@ -64,6 +64,7 @@ class OperatorConverter(object):
'MAX_POOL_2D': self.convert_max_pool2d,
'CONCATENATION': self.convert_concatenation,
'ADD': self.convert_add,
'MUL': self.convert_mul,
'FULLY_CONNECTED': self.convert_fully_connected,
}
......@@ -267,8 +268,8 @@ class OperatorConverter(object):
out = self.convert_fused_activation_function(out, fused_activation_fn)
return out
def convert_add(self, op):
"""Convert TFLite add"""
def _convert_elemwise(self, relay_op, op):
"""Generic method to Convert TFLite elemwise"""
try:
from tflite.Operator import Operator
except ImportError:
......@@ -283,19 +284,26 @@ class OperatorConverter(object):
rhs_tensor = input_tensors[1]
if self.has_expr(rhs_tensor.tensor_idx):
# In most cases, we can assume that TOCO fuses ADD operators
# In most cases, we can assume that TOCO fuses elemwise operators
# with constants - it means both will be tensors.
rhs_expr = self.get_expr(rhs_tensor.tensor_idx)
else:
# However, in some corner cases, the ADD operator is not fused,
# However, in some corner cases, the elemwise operator is not fused,
# we can receive as constant.
rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type())
rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor),
dtype=rhs_type_str)
out = _op.add(lhs_expr, rhs_expr)
out = relay_op(lhs_expr, rhs_expr)
return out
def convert_add(self, op):
"""Convert TFLite ADD"""
return self._convert_elemwise(_op.add, op)
def convert_mul(self, op):
"""Convert TFLite MUL"""
return self._convert_elemwise(_op.multiply, op)
def convert_fully_connected(self, op):
"""Convert TFLite fully connected"""
try:
......
......@@ -24,7 +24,6 @@ from __future__ import print_function
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import util
import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
......@@ -144,8 +143,6 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
for i in range(len(tflite_output)):
tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
sess.close()
#######################################################################
# Pooling
......@@ -311,10 +308,10 @@ def test_forward_concatenation():
#######################################################################
# Add
# Element-wise
# ---
def _test_add(data):
def _test_elemwise(math_op, data):
""" One iteration of add """
assert len(data) == 2
......@@ -329,10 +326,19 @@ def _test_add(data):
# Test with tensor and constant
with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')]
out = math_ops.add(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
out = math_op(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])
#######################################################################
# Add
# ---
def _test_add(data):
""" One iteration of add """
return _test_elemwise(math_ops.add, data)
def test_forward_add():
""" Add """
_test_add([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
......@@ -344,6 +350,25 @@ def test_forward_add():
#######################################################################
# Mul
# ---
def _test_mul(data):
""" One iteration of mul """
return _test_elemwise(math_ops.multiply, data)
def test_forward_mul():
""" Mul """
_test_mul([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3))])
_test_mul([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)),
np.arange(6.0, dtype=np.float32).reshape((2, 1, 3))])
_test_mul([np.arange(3.0, dtype=np.float32).reshape((1, 3)),
np.arange(3.0, dtype=np.float32).reshape((1, 3))])
#######################################################################
# Squeeze
# -------
......@@ -514,6 +539,7 @@ if __name__ == '__main__':
# Math
test_forward_add()
test_forward_mul()
# End to End
test_forward_mobilenet_v1()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment