Commit 4f712c79 by Ina Dobreva Committed by Yao Wang

Add parser support for ReLU tflite operator (#4022)

parent 9151d435
...@@ -84,6 +84,7 @@ class OperatorConverter(object): ...@@ -84,6 +84,7 @@ class OperatorConverter(object):
'PACK': self.convert_pack, 'PACK': self.convert_pack,
'LOGISTIC': self.convert_logistic, 'LOGISTIC': self.convert_logistic,
'TANH':self.convert_tanh, 'TANH':self.convert_tanh,
'RELU':self.convert_relu,
'SPLIT': self.convert_split, 'SPLIT': self.convert_split,
'TRANSPOSE': self.convert_transpose, 'TRANSPOSE': self.convert_transpose,
'TILE': self.convert_tile, 'TILE': self.convert_tile,
...@@ -345,6 +346,23 @@ class OperatorConverter(object): ...@@ -345,6 +346,23 @@ class OperatorConverter(object):
return out return out
def convert_relu(self, op):
"""Convert TFLite ReLU"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
out = _op.nn.relu(in_expr)
return out
def convert_concatenation(self, op): def convert_concatenation(self, op):
"""Convert TFLite concatenation""" """Convert TFLite concatenation"""
try: try:
......
...@@ -837,6 +837,21 @@ def test_forward_tanh(): ...@@ -837,6 +837,21 @@ def test_forward_tanh():
_test_tanh(np.arange(6.0, dtype=np.float32).reshape((1, 6))) _test_tanh(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
####################################################################### #######################################################################
# ReLu
# --------
def _test_relu(data):
""" One iteration of ReLU """
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = nn_ops.relu(in_data)
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
def test_forward_relu():
""" ReLU """
_test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
#######################################################################
# Fully Connected # Fully Connected
# ------- # -------
...@@ -999,6 +1014,7 @@ if __name__ == '__main__': ...@@ -999,6 +1014,7 @@ if __name__ == '__main__':
test_forward_pooling() test_forward_pooling()
test_forward_softmax() test_forward_softmax()
test_forward_tanh() test_forward_tanh()
test_forward_relu()
test_forward_fully_connected() test_forward_fully_connected()
# Elemwise # Elemwise
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment