Unverified Commit 3e3ccce1 by Ramana Radhakrishnan Committed by GitHub

Factor out import of common tflite.Operator in tflite frontend. (#5355)

* Restructure imports in tflite frontend.

These python modules are needed for every tflite file parsed.
Factorize out imports of the common most ones.

Now that the import of operator is common, asserts can be commonized.

Loses 473 lines of duplication.

* Only restrict to tflite.Operator
parent 24f68653
...@@ -159,7 +159,12 @@ class OperatorConverter(object): ...@@ -159,7 +159,12 @@ class OperatorConverter(object):
op = self.subgraph.Operators(op_idx) op = self.subgraph.Operators(op_idx)
op_code_str = self.get_op_code_str(op) op_code_str = self.get_op_code_str(op)
output_tensors = self.get_output_tensors(op) output_tensors = self.get_output_tensors(op)
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
ret = self.convert_map[op_code_str](op) ret = self.convert_map[op_code_str](op)
if len(output_tensors) == 1: if len(output_tensors) == 1:
...@@ -288,12 +293,6 @@ class OperatorConverter(object): ...@@ -288,12 +293,6 @@ class OperatorConverter(object):
def is_quantized(self, op): def is_quantized(self, op):
"""Check if an input tensor is quantized.""" """Check if an input tensor is quantized."""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
first_tensor = input_tensors[0] first_tensor = input_tensors[0]
return first_tensor.qnn_params is not None return first_tensor.qnn_params is not None
...@@ -335,12 +334,10 @@ class OperatorConverter(object): ...@@ -335,12 +334,10 @@ class OperatorConverter(object):
"""Convert TFLite reshape""" """Convert TFLite reshape"""
try: try:
from tflite.BuiltinOptions import BuiltinOptions from tflite.BuiltinOptions import BuiltinOptions
from tflite.Operator import Operator
from tflite.ReshapeOptions import ReshapeOptions from tflite.ReshapeOptions import ReshapeOptions
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert input_tensors, "input tensors should not be empty" assert input_tensors, "input tensors should not be empty"
input_tensor = input_tensors[0] input_tensor = input_tensors[0]
...@@ -368,7 +365,6 @@ class OperatorConverter(object): ...@@ -368,7 +365,6 @@ class OperatorConverter(object):
"""Generic method to Convert TFLite RESIZE operators""" """Generic method to Convert TFLite RESIZE operators"""
try: try:
from tflite.BuiltinOptions import BuiltinOptions from tflite.BuiltinOptions import BuiltinOptions
from tflite.Operator import Operator
from tflite.ResizeBilinearOptions import ResizeBilinearOptions from tflite.ResizeBilinearOptions import ResizeBilinearOptions
# ResizeNearestNeighborOptions was added in tflite v1.13 # ResizeNearestNeighborOptions was added in tflite v1.13
tflite_ver = 1120 tflite_ver = 1120
...@@ -378,7 +374,6 @@ class OperatorConverter(object): ...@@ -378,7 +374,6 @@ class OperatorConverter(object):
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2" assert len(input_tensors) == 2, "input tensors length should be 2"
...@@ -421,14 +416,12 @@ class OperatorConverter(object): ...@@ -421,14 +416,12 @@ class OperatorConverter(object):
def convert_l2_normalization(self, op): def convert_l2_normalization(self, op):
"""Convert TFLite L2_NORMALIZATION """ """Convert TFLite L2_NORMALIZATION """
try: try:
from tflite.Operator import Operator
from tflite.BuiltinOptions import BuiltinOptions from tflite.BuiltinOptions import BuiltinOptions
from tflite.L2NormOptions import L2NormOptions from tflite.L2NormOptions import L2NormOptions
from tflite.ActivationFunctionType import ActivationFunctionType from tflite.ActivationFunctionType import ActivationFunctionType
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1" assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0] input_tensor = input_tensors[0]
...@@ -467,13 +460,11 @@ class OperatorConverter(object): ...@@ -467,13 +460,11 @@ class OperatorConverter(object):
def convert_lrn(self, op): def convert_lrn(self, op):
"""Convert TFLite LOCAL_RESPONSE_NORMALIZATION """ """Convert TFLite LOCAL_RESPONSE_NORMALIZATION """
try: try:
from tflite.Operator import Operator
from tflite.BuiltinOptions import BuiltinOptions from tflite.BuiltinOptions import BuiltinOptions
from tflite.LocalResponseNormalizationOptions import LocalResponseNormalizationOptions from tflite.LocalResponseNormalizationOptions import LocalResponseNormalizationOptions
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
if self.is_quantized(op): if self.is_quantized(op):
raise tvm.error.OpNotImplemented( raise tvm.error.OpNotImplemented(
'TFlite quantized LRN operator is not supported yet.') 'TFlite quantized LRN operator is not supported yet.')
...@@ -503,12 +494,6 @@ class OperatorConverter(object): ...@@ -503,12 +494,6 @@ class OperatorConverter(object):
def convert_logistic(self, op): def convert_logistic(self, op):
"""Convert TFLite LOGISTIC""" """Convert TFLite LOGISTIC"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1" assert len(input_tensors) == 1, "input tensors length should be 1"
...@@ -529,12 +514,6 @@ class OperatorConverter(object): ...@@ -529,12 +514,6 @@ class OperatorConverter(object):
def convert_softmax(self, op): def convert_softmax(self, op):
"""Convert TFLite softmax""" """Convert TFLite softmax"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1" assert len(input_tensors) == 1, "input tensors length should be 1"
...@@ -564,12 +543,6 @@ class OperatorConverter(object): ...@@ -564,12 +543,6 @@ class OperatorConverter(object):
def convert_tanh(self, op): def convert_tanh(self, op):
"""Convert TFLite TANH""" """Convert TFLite TANH"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1" assert len(input_tensors) == 1, "input tensors length should be 1"
...@@ -581,12 +554,6 @@ class OperatorConverter(object): ...@@ -581,12 +554,6 @@ class OperatorConverter(object):
def convert_relu(self, op): def convert_relu(self, op):
"""Convert TFLite ReLU""" """Convert TFLite ReLU"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1" assert len(input_tensors) == 1, "input tensors length should be 1"
...@@ -598,12 +565,6 @@ class OperatorConverter(object): ...@@ -598,12 +565,6 @@ class OperatorConverter(object):
def convert_hard_swish(self, op): def convert_hard_swish(self, op):
"""Convert TFLite Hard swish""" """Convert TFLite Hard swish"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1" assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0] input_tensor = input_tensors[0]
...@@ -635,14 +596,12 @@ class OperatorConverter(object): ...@@ -635,14 +596,12 @@ class OperatorConverter(object):
def convert_concatenation(self, op): def convert_concatenation(self, op):
"""Convert TFLite concatenation""" """Convert TFLite concatenation"""
try: try:
from tflite.Operator import Operator
from tflite.ConcatenationOptions import ConcatenationOptions from tflite.ConcatenationOptions import ConcatenationOptions
from tflite.BuiltinOptions import BuiltinOptions from tflite.BuiltinOptions import BuiltinOptions
from tflite.ActivationFunctionType import ActivationFunctionType from tflite.ActivationFunctionType import ActivationFunctionType
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) >= 1, "input tensors should greater than 1" assert len(input_tensors) >= 1, "input tensors should greater than 1"
in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors] in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors]
...@@ -683,12 +642,6 @@ class OperatorConverter(object): ...@@ -683,12 +642,6 @@ class OperatorConverter(object):
def _convert_unary_elemwise(self, relay_op, op): def _convert_unary_elemwise(self, relay_op, op):
"""Generic method to convert TFLite unary elemwise functions""" """Generic method to convert TFLite unary elemwise functions"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1" assert len(input_tensors) == 1, "input tensors length should be 1"
...@@ -784,12 +737,6 @@ class OperatorConverter(object): ...@@ -784,12 +737,6 @@ class OperatorConverter(object):
def convert_elu(self, op): def convert_elu(self, op):
"""Convert TFLite ELU""" """Convert TFLite ELU"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
if self.is_quantized(op): if self.is_quantized(op):
raise tvm.error.OpNotImplemented( raise tvm.error.OpNotImplemented(
'TFlite quantized ELU operator is not supported yet.') 'TFlite quantized ELU operator is not supported yet.')
...@@ -807,12 +754,6 @@ class OperatorConverter(object): ...@@ -807,12 +754,6 @@ class OperatorConverter(object):
def convert_square(self, op): def convert_square(self, op):
"""Convert TFLite SQUARE""" """Convert TFLite SQUARE"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1" assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0] input_tensor = input_tensors[0]
...@@ -834,7 +775,6 @@ class OperatorConverter(object): ...@@ -834,7 +775,6 @@ class OperatorConverter(object):
def _convert_elemwise(self, relay_op, op): def _convert_elemwise(self, relay_op, op):
"""Generic method to Convert TFLite elemwise""" """Generic method to Convert TFLite elemwise"""
try: try:
from tflite.Operator import Operator
from tflite.AddOptions import AddOptions from tflite.AddOptions import AddOptions
from tflite.SubOptions import SubOptions from tflite.SubOptions import SubOptions
from tflite.MulOptions import MulOptions from tflite.MulOptions import MulOptions
...@@ -844,7 +784,6 @@ class OperatorConverter(object): ...@@ -844,7 +784,6 @@ class OperatorConverter(object):
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2" assert len(input_tensors) == 2, "input tensors length should be 2"
...@@ -1025,12 +964,6 @@ class OperatorConverter(object): ...@@ -1025,12 +964,6 @@ class OperatorConverter(object):
def _convert_logical_binary(self, relay_op, op): def _convert_logical_binary(self, relay_op, op):
"""Generic method to convert logical binary ops""" """Generic method to convert logical binary ops"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2" assert len(input_tensors) == 2, "input tensors length should be 2"
...@@ -1052,12 +985,6 @@ class OperatorConverter(object): ...@@ -1052,12 +985,6 @@ class OperatorConverter(object):
def convert_zeros_like(self, op): def convert_zeros_like(self, op):
"""Convert TFLite ZEROS LIKE""" """Convert TFLite ZEROS LIKE"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1" assert len(input_tensors) == 1, "input tensors length should be 1"
...@@ -1071,12 +998,10 @@ class OperatorConverter(object): ...@@ -1071,12 +998,10 @@ class OperatorConverter(object):
"""Generic method to Convert TFLite MEAN operators""" """Generic method to Convert TFLite MEAN operators"""
try: try:
from tflite.BuiltinOptions import BuiltinOptions from tflite.BuiltinOptions import BuiltinOptions
from tflite.Operator import Operator
from tflite.ReducerOptions import ReducerOptions from tflite.ReducerOptions import ReducerOptions
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2" assert len(input_tensors) == 2, "input tensors length should be 2"
...@@ -1135,7 +1060,6 @@ class OperatorConverter(object): ...@@ -1135,7 +1060,6 @@ class OperatorConverter(object):
def convert_fully_connected(self, op): def convert_fully_connected(self, op):
"""Convert TFLite fully connected""" """Convert TFLite fully connected"""
try: try:
from tflite.Operator import Operator
from tflite.FullyConnectedOptions import FullyConnectedOptions from tflite.FullyConnectedOptions import FullyConnectedOptions
from tflite.BuiltinOptions import BuiltinOptions from tflite.BuiltinOptions import BuiltinOptions
from tflite.TensorType import TensorType from tflite.TensorType import TensorType
...@@ -1143,7 +1067,6 @@ class OperatorConverter(object): ...@@ -1143,7 +1067,6 @@ class OperatorConverter(object):
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) >= 2, "input tensors length should be >= 2" assert len(input_tensors) >= 2, "input tensors length should be >= 2"
...@@ -1238,12 +1161,10 @@ class OperatorConverter(object): ...@@ -1238,12 +1161,10 @@ class OperatorConverter(object):
"""Convert TFLite squeeze""" """Convert TFLite squeeze"""
try: try:
from tflite.BuiltinOptions import BuiltinOptions from tflite.BuiltinOptions import BuiltinOptions
from tflite.Operator import Operator
from tflite.SqueezeOptions import SqueezeOptions from tflite.SqueezeOptions import SqueezeOptions
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
output_tensors = self.get_output_tensors(op) output_tensors = self.get_output_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1" assert len(input_tensors) == 1, "input tensors length should be 1"
...@@ -1287,14 +1208,12 @@ class OperatorConverter(object): ...@@ -1287,14 +1208,12 @@ class OperatorConverter(object):
from tflite.BuiltinOptions import BuiltinOptions from tflite.BuiltinOptions import BuiltinOptions
from tflite.ActivationFunctionType import ActivationFunctionType from tflite.ActivationFunctionType import ActivationFunctionType
from tflite.TensorType import TensorType from tflite.TensorType import TensorType
from tflite.Operator import Operator
from tflite.Conv2DOptions import Conv2DOptions from tflite.Conv2DOptions import Conv2DOptions
from tflite.DepthwiseConv2DOptions import DepthwiseConv2DOptions from tflite.DepthwiseConv2DOptions import DepthwiseConv2DOptions
from tflite.Padding import Padding from tflite.Padding import Padding
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) >= 2, "input tensors length should be >= 2" assert len(input_tensors) >= 2, "input tensors length should be >= 2"
...@@ -1455,12 +1374,10 @@ class OperatorConverter(object): ...@@ -1455,12 +1374,10 @@ class OperatorConverter(object):
"""split implementation.""" """split implementation."""
try: try:
from tflite.BuiltinOptions import BuiltinOptions from tflite.BuiltinOptions import BuiltinOptions
from tflite.Operator import Operator
from tflite.SplitOptions import SplitOptions from tflite.SplitOptions import SplitOptions
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be == 2" assert len(input_tensors) == 2, "input tensors length should be == 2"
...@@ -1490,12 +1407,6 @@ class OperatorConverter(object): ...@@ -1490,12 +1407,6 @@ class OperatorConverter(object):
def convert_slice(self, op): def convert_slice(self, op):
"""Convert TFLite SLICE""" """Convert TFLite SLICE"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 3, "input tensors length should be == 3" assert len(input_tensors) == 3, "input tensors length should be == 3"
input_tensor = input_tensors[0] input_tensor = input_tensors[0]
...@@ -1519,12 +1430,6 @@ class OperatorConverter(object): ...@@ -1519,12 +1430,6 @@ class OperatorConverter(object):
def convert_transpose(self, op): def convert_transpose(self, op):
"""transpose implementation.""" """transpose implementation."""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2" assert len(input_tensors) == 2, "input tensors length should be 2"
input_tensor = input_tensors[0] input_tensor = input_tensors[0]
...@@ -1545,13 +1450,11 @@ class OperatorConverter(object): ...@@ -1545,13 +1450,11 @@ class OperatorConverter(object):
def convert_cast(self, op): def convert_cast(self, op):
"""Convert TFLite CAST""" """Convert TFLite CAST"""
try: try:
from tflite.Operator import Operator
from tflite.BuiltinOptions import BuiltinOptions from tflite.BuiltinOptions import BuiltinOptions
from tflite.CastOptions import CastOptions from tflite.CastOptions import CastOptions
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1" assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0] input_tensor = input_tensors[0]
...@@ -1569,12 +1472,6 @@ class OperatorConverter(object): ...@@ -1569,12 +1472,6 @@ class OperatorConverter(object):
def convert_tile(self, op): def convert_tile(self, op):
"""tile implementation.""" """tile implementation."""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2" assert len(input_tensors) == 2, "input tensors length should be 2"
input_tensor = input_tensors[0] input_tensor = input_tensors[0]
...@@ -1591,12 +1488,6 @@ class OperatorConverter(object): ...@@ -1591,12 +1488,6 @@ class OperatorConverter(object):
def convert_topk_v2(self, op): def convert_topk_v2(self, op):
""" Convert TFLite TOPK_v2 """ """ Convert TFLite TOPK_v2 """
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2" assert len(input_tensors) == 2, "input tensors length should be 2"
input_tensor = input_tensors[0] input_tensor = input_tensors[0]
...@@ -1612,13 +1503,11 @@ class OperatorConverter(object): ...@@ -1612,13 +1503,11 @@ class OperatorConverter(object):
try: try:
from tflite.BuiltinOptions import BuiltinOptions from tflite.BuiltinOptions import BuiltinOptions
from tflite.ActivationFunctionType import ActivationFunctionType from tflite.ActivationFunctionType import ActivationFunctionType
from tflite.Operator import Operator
from tflite.Pool2DOptions import Pool2DOptions from tflite.Pool2DOptions import Pool2DOptions
from tflite.Padding import Padding from tflite.Padding import Padding
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1" assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0] input_tensor = input_tensors[0]
...@@ -1689,12 +1578,6 @@ class OperatorConverter(object): ...@@ -1689,12 +1578,6 @@ class OperatorConverter(object):
def convert_pad(self, op): def convert_pad(self, op):
"""Convert TFLite PAD""" """Convert TFLite PAD"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2" assert len(input_tensors) == 2, "input tensors length should be 2"
...@@ -1740,7 +1623,6 @@ class OperatorConverter(object): ...@@ -1740,7 +1623,6 @@ class OperatorConverter(object):
def convert_mirror_pad(self, op): def convert_mirror_pad(self, op):
"""Convert TFLite MIRROR_PAD""" """Convert TFLite MIRROR_PAD"""
try: try:
from tflite.Operator import Operator
from tflite.BuiltinOptions import BuiltinOptions from tflite.BuiltinOptions import BuiltinOptions
from tflite.MirrorPadOptions import MirrorPadOptions from tflite.MirrorPadOptions import MirrorPadOptions
except ImportError: except ImportError:
...@@ -1751,7 +1633,6 @@ class OperatorConverter(object): ...@@ -1751,7 +1633,6 @@ class OperatorConverter(object):
raise tvm.error.OpNotImplemented( raise tvm.error.OpNotImplemented(
'TFlite quantized MIRROR_PAD operator is not supported yet.') 'TFlite quantized MIRROR_PAD operator is not supported yet.')
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2" assert len(input_tensors) == 2, "input tensors length should be 2"
...@@ -1779,12 +1660,10 @@ class OperatorConverter(object): ...@@ -1779,12 +1660,10 @@ class OperatorConverter(object):
"""Convert TFLite pack""" """Convert TFLite pack"""
try: try:
from tflite.BuiltinOptions import BuiltinOptions from tflite.BuiltinOptions import BuiltinOptions
from tflite.Operator import Operator
from tflite.PackOptions import PackOptions from tflite.PackOptions import PackOptions
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) >= 1, "input tensors should greater than 1" assert len(input_tensors) >= 1, "input tensors should greater than 1"
in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors] in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors]
...@@ -1806,12 +1685,10 @@ class OperatorConverter(object): ...@@ -1806,12 +1685,10 @@ class OperatorConverter(object):
"""Convert TFLite unpack""" """Convert TFLite unpack"""
try: try:
from tflite.BuiltinOptions import BuiltinOptions from tflite.BuiltinOptions import BuiltinOptions
from tflite.Operator import Operator
from tflite.UnpackOptions import UnpackOptions from tflite.UnpackOptions import UnpackOptions
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1" assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0] input_tensor = input_tensors[0]
...@@ -1848,12 +1725,7 @@ class OperatorConverter(object): ...@@ -1848,12 +1725,7 @@ class OperatorConverter(object):
def convert_batch_to_space_nd(self, op): def convert_batch_to_space_nd(self, op):
"""batch_to_space_nd implementation.""" """batch_to_space_nd implementation."""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 3, "input tensors length should be 3" assert len(input_tensors) == 3, "input tensors length should be 3"
...@@ -1901,12 +1773,6 @@ class OperatorConverter(object): ...@@ -1901,12 +1773,6 @@ class OperatorConverter(object):
def convert_space_to_batch_nd(self, op): def convert_space_to_batch_nd(self, op):
"""space_to_batch_nd implementation.""" """space_to_batch_nd implementation."""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 3, "input tensors length should be 3" assert len(input_tensors) == 3, "input tensors length should be 3"
...@@ -1960,12 +1826,10 @@ class OperatorConverter(object): ...@@ -1960,12 +1826,10 @@ class OperatorConverter(object):
"""Convert TFLite DEPTH_TO_SPACE""" """Convert TFLite DEPTH_TO_SPACE"""
try: try:
from tflite.BuiltinOptions import BuiltinOptions from tflite.BuiltinOptions import BuiltinOptions
from tflite.Operator import Operator
from tflite.DepthToSpaceOptions import DepthToSpaceOptions from tflite.DepthToSpaceOptions import DepthToSpaceOptions
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1" assert len(input_tensors) == 1, "input tensors length should be 1"
...@@ -1985,12 +1849,10 @@ class OperatorConverter(object): ...@@ -1985,12 +1849,10 @@ class OperatorConverter(object):
"""Convert TFLite SPACE_TO_DEPTH""" """Convert TFLite SPACE_TO_DEPTH"""
try: try:
from tflite.BuiltinOptions import BuiltinOptions from tflite.BuiltinOptions import BuiltinOptions
from tflite.Operator import Operator
from tflite.SpaceToDepthOptions import SpaceToDepthOptions from tflite.SpaceToDepthOptions import SpaceToDepthOptions
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1" assert len(input_tensors) == 1, "input tensors length should be 1"
...@@ -2008,12 +1870,6 @@ class OperatorConverter(object): ...@@ -2008,12 +1870,6 @@ class OperatorConverter(object):
def convert_prelu(self, op): def convert_prelu(self, op):
"""Convert TFLite PReLU""" """Convert TFLite PReLU"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2" assert len(input_tensors) == 2, "input tensors length should be 2"
...@@ -2033,13 +1889,11 @@ class OperatorConverter(object): ...@@ -2033,13 +1889,11 @@ class OperatorConverter(object):
try: try:
from tflite.BuiltinOptions import BuiltinOptions from tflite.BuiltinOptions import BuiltinOptions
from tflite.TensorType import TensorType from tflite.TensorType import TensorType
from tflite.Operator import Operator
from tflite.TransposeConvOptions import TransposeConvOptions from tflite.TransposeConvOptions import TransposeConvOptions
from tflite.Padding import Padding from tflite.Padding import Padding
except ImportError: except ImportError:
raise ImportError("The tflite package must be installed") raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op) input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 3, "input tensors length should be 3" assert len(input_tensors) == 3, "input tensors length should be 3"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment