Commit 7c1c97d2 by Alexander Pivovarov Committed by Yao Wang

Add LOGISTIC operator to relay tflite frontend (#3313)

parent c4245e3d
...@@ -67,6 +67,7 @@ class OperatorConverter(object): ...@@ -67,6 +67,7 @@ class OperatorConverter(object):
'MUL': self.convert_mul, 'MUL': self.convert_mul,
'FULLY_CONNECTED': self.convert_fully_connected, 'FULLY_CONNECTED': self.convert_fully_connected,
'PAD': self.convert_pad, 'PAD': self.convert_pad,
'LOGISTIC': self.convert_logistic,
} }
def check_unsupported_ops(self): def check_unsupported_ops(self):
...@@ -218,6 +219,23 @@ class OperatorConverter(object): ...@@ -218,6 +219,23 @@ class OperatorConverter(object):
return out return out
def convert_logistic(self, op):
"""Convert TFLite LOGISTIC"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
out = _op.sigmoid(in_expr)
return out
def convert_softmax(self, op): def convert_softmax(self, op):
"""Convert TFLite softmax""" """Convert TFLite softmax"""
try: try:
......
...@@ -424,6 +424,22 @@ def test_forward_pad(): ...@@ -424,6 +424,22 @@ def test_forward_pad():
####################################################################### #######################################################################
# Logistic
# --------
def _test_logistic(data):
""" One iteration of LOGISTIC """
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = math_ops.sigmoid(in_data)
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
def test_forward_logistic():
""" LOGISTIC """
_test_logistic(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
#######################################################################
# Softmax # Softmax
# ------- # -------
...@@ -563,6 +579,7 @@ if __name__ == '__main__': ...@@ -563,6 +579,7 @@ if __name__ == '__main__':
# NN # NN
test_forward_convolution() test_forward_convolution()
test_forward_logistic()
test_forward_pooling() test_forward_pooling()
test_forward_softmax() test_forward_softmax()
test_forward_fully_connected() test_forward_fully_connected()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment