Unverified Commit 2355caa8 by Ina Dobreva Committed by GitHub

[Frontend][TFLite] Add parser support for l2_normalization (#4966)

* [Frontend][TFLite] Add parser support for l2_normalization

* TF doesn't provide uint8 support
* TFL does the normalization only if it's over the last axis
* TFL uses only the default value for expilon

* Change error message
parent a449d8b1
......@@ -122,6 +122,7 @@ class OperatorConverter(object):
'LOGICAL_OR': self.convert_logical_or,
'DETECTION_POSTPROCESS': self.convert_detection_postprocess,
'SQUARE': self.convert_square,
'L2_NORMALIZATION': self.convert_l2_normalization,
}
def check_unsupported_ops(self):
......@@ -405,6 +406,52 @@ class OperatorConverter(object):
"""Convert TFLite RESIZE_NEAREST_NEIGHBOR"""
return self._convert_resize("nearest_neighbor", op)
def convert_l2_normalization(self, op):
"""Convert TFLite L2_NORMALIZATION """
try:
from tflite.Operator import Operator
from tflite.BuiltinOptions import BuiltinOptions
from tflite.L2NormOptions import L2NormOptions
from tflite.ActivationFunctionType import ActivationFunctionType
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]
assert op.BuiltinOptionsType() == BuiltinOptions.L2NormOptions
op_options = op.BuiltinOptions()
l2_norm_options = L2NormOptions()
l2_norm_options.Init(op_options.Bytes, op_options.Pos)
fused_activation_fn = l2_norm_options.FusedActivationFunction()
# TFLite supports normalization only over the last dim
input_tensor_rank = len(input_tensor.tensor.ShapeAsNumpy())
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFLite quantized L2_NORMALIZATION operator is not supported yet.')
# TFL uses only the default epsilon value
out = _op.nn.l2_normalize(in_expr, eps=1e-12, axis=[input_tensor_rank - 1])
# if we have fused activation fn
if fused_activation_fn != ActivationFunctionType.NONE:
if not output_tensor.qnn_params:
out = self.convert_fused_activation_function(out, fused_activation_fn)
else:
raise tvm.error.OpNotImplemented(
'TFLite quantized L2_NORMALIZATION operator\
with fused activation function is not supported yet.')
return out
def convert_logistic(self, op):
"""Convert TFLite LOGISTIC"""
try:
......
......@@ -33,6 +33,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import variables
try:
from tensorflow import lite as interpreter_wrapper
......@@ -1264,6 +1265,24 @@ def test_forward_unpack():
_test_unpack(np.array(np.random.uniform(0, 5, (2, 3, 4)), dtype=np.int32), axis=-3, num_unpacks=2)
#######################################################################
# L2 normalization
# ----------------
def _test_l2_normalization(data, axis, fused_activation_function=None):
""" One iteration of L2_NORMALIZATION """
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = nn_impl.l2_normalize(in_data, axis)
out = with_fused_activation_function(out, fused_activation_function)
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
def test_forward_l2_normalization():
""" L2_NORMALIZATION """
data = np.random.uniform(size=(3, 6, 4)).astype('float32')
_test_l2_normalization(data, axis=2)
_test_l2_normalization(data, axis=2, fused_activation_function="RELU")
#######################################################################
# Logistic
# --------
......@@ -1649,6 +1668,7 @@ if __name__ == '__main__':
test_forward_relu()
test_forward_prelu()
test_forward_fully_connected()
test_forward_l2_normalization()
# Elemwise
test_all_elemwise()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment