tflite.py 81.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18
# pylint: disable=invalid-name, unused-argument, too-many-lines, import-outside-toplevel

19 20 21
"""Tensorflow lite frontend."""
import math
import numpy as np
22
import tvm
23 24
from tvm.ir import IRModule

25
from tvm import relay
Zhi committed
26
from .. import analysis
27 28
from .. import expr as _expr
from .. import op as _op
29
from .. import qnn as _qnn
30
from ..util import get_scalar_from_constant
31 32
from ... import nd as _nd
from .common import ExprTable
33
from .common import infer_shape as _infer_shape
34 35 36 37 38

__all__ = ['from_tflite']

class TensorWrapper(object):
    """Lightweight record pairing a TFLite tensor with its index, buffer and QNN params.

    Attributes
    ----------
    tensor_idx : int
        Index of the tensor in the subgraph's tensor list. The converter wraps
        negative indices with placeholder ``0`` values for ``tensor``/``buffer``.
    tensor :
        The TFLite tensor object (or the ``0`` placeholder).
    buffer :
        The TFLite buffer backing the tensor (or the ``0`` placeholder).
    qnn_params : dict or None
        ``{'scale': ..., 'zero_point': ...}`` when the tensor is quantized,
        otherwise ``None``.
    """

    def __init__(self, tensor_idx, tensor, buffer, qnn_params=None):
        self.tensor_idx, self.tensor, self.buffer = tensor_idx, tensor, buffer
        self.qnn_params = qnn_params
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61

class OperatorConverter(object):
    """Operator Converted for converting TFLite ops to Relay ops"""
    def __init__(self, model, subgraph, exp_tab):
        """Prepare a converter for one TFLite subgraph.

        Parameters
        ----------
        model :
            The parsed TFLite model; its ``OperatorCodes`` and ``Buffers`` are
            read by other methods of this class.
        subgraph :
            The TFLite subgraph whose operators and tensors are converted.
        exp_tab : ExprTable
            Table of named Relay expressions; converted outputs are stored in it.

        Raises
        ------
        ImportError
            If the ``tflite`` package is not installed.
        """
        try:
            from tflite.BuiltinOperator import BuiltinOperator
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.ActivationFunctionType import ActivationFunctionType
        except ImportError:
            raise ImportError("The tflite package must be installed")

        self.model = model
        self.subgraph = subgraph
        self.exp_tab = exp_tab
        # Lookup tables derived from the tflite enum classes. builtin_op_code is
        # indexed by the numeric builtin code (see get_op_code_str).
        self.builtin_op_code = build_str_map(BuiltinOperator())
        self.activation_fn_type = build_str_map(ActivationFunctionType())
        self.builtin_options = build_str_map(BuiltinOptions())

        # Dispatch table: TFLite builtin op name -> bound converter method.
        # To support a new operator, add it to the appropriate group below.
        unary_math = dict(
            ABS=self.convert_abs,
            EXP=self.convert_exp,
            FLOOR=self.convert_floor,
            CEIL=self.convert_ceil,
            LOG=self.convert_log,
            SIN=self.convert_sin,
            COS=self.convert_cos,
            SQRT=self.convert_sqrt,
            RSQRT=self.convert_rsqrt,
            NEG=self.convert_neg,
        )
        nn_ops = dict(
            CONV_2D=self.convert_conv2d,
            DEPTHWISE_CONV_2D=self.convert_depthwise_conv2d,
            TRANSPOSE_CONV=self.convert_transpose_conv,
            AVERAGE_POOL_2D=self.convert_average_pool2d,
            MAX_POOL_2D=self.convert_max_pool2d,
            FULLY_CONNECTED=self.convert_fully_connected,
            SOFTMAX=self.convert_softmax,
            LOGISTIC=self.convert_logistic,
            TANH=self.convert_tanh,
            RELU=self.convert_relu,
            PRELU=self.convert_prelu,
        )
        binary_ops = dict(
            ADD=self.convert_add,
            SUB=self.convert_sub,
            MUL=self.convert_mul,
            DIV=self.convert_div,
            POW=self.convert_pow,
            MAXIMUM=self.convert_maximum,
            MINIMUM=self.convert_minimum,
            SQUARED_DIFFERENCE=self.convert_squared_difference,
            LOGICAL_AND=self.convert_logical_and,
            LOGICAL_OR=self.convert_logical_or,
        )
        comparisons = dict(
            GREATER=self.convert_greater,
            GREATER_EQUAL=self.convert_greater_equal,
            LESS=self.convert_less,
            LESS_EQUAL=self.convert_less_equal,
            EQUAL=self.convert_equal,
            NOT_EQUAL=self.convert_not_equal,
        )
        reductions = dict(
            REDUCE_MIN=self._convert_reduce_min,
            REDUCE_MAX=self._convert_reduce_max,
            MEAN=self._convert_reduce_mean,
            REDUCE_PROD=self._convert_reduce_prod,
            SUM=self._convert_reduce_sum,
        )
        layout_ops = dict(
            RESHAPE=self.convert_reshape,
            RESIZE_BILINEAR=self.convert_resize_bilinear,
            RESIZE_NEAREST_NEIGHBOR=self.convert_resize_nearest_neighbor,
            SQUEEZE=self.convert_squeeze,
            CONCATENATION=self.convert_concatenation,
            ZEROS_LIKE=self.convert_zeros_like,
            PAD=self.convert_pad,
            MIRROR_PAD=self.convert_mirror_pad,
            PACK=self.convert_pack,
            UNPACK=self.convert_unpack,
            SPLIT=self.convert_split,
            SLICE=self.convert_slice,
            TRANSPOSE=self.convert_transpose,
            CAST=self.convert_cast,
            TILE=self.convert_tile,
            BATCH_TO_SPACE_ND=self.convert_batch_to_space_nd,
            SPACE_TO_BATCH_ND=self.convert_space_to_batch_nd,
        )

        self.convert_map = {}
        for group in (unary_math, nn_ops, binary_ops, comparisons, reductions, layout_ops):
            self.convert_map.update(group)

    def check_unsupported_ops(self):
        """Raise if the subgraph contains TFLite ops that have no registered converter.

        Every operator in the subgraph is looked up in ``self.convert_map``. All
        missing op names are collected first, so the user sees the complete list
        in one error instead of failing on the first unsupported op.

        Note: ``get_op_code_str`` raises ``NotImplementedError`` for custom
        operators, which happens before this aggregate check can report them.

        Raises
        ------
        tvm.error.OpNotImplemented
            If at least one operator is unsupported. The names are sorted, so
            the message does not depend on set iteration order.
        """
        unsupported_ops_set = set()
        for op_idx in range(self.subgraph.OperatorsLength()):
            op_code_str = self.get_op_code_str(self.subgraph.Operators(op_idx))
            if op_code_str not in self.convert_map:
                unsupported_ops_set.add(op_code_str)

        if unsupported_ops_set:
            msg = 'The following operators are not supported in frontend ' \
                  'TFLite: {}'
            # Sorted so the message is reproducible across runs.
            ops = str(sorted(unsupported_ops_set)).strip('[,]')
            raise tvm.error.OpNotImplemented(msg.format(ops))
141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170

    def convert_op_to_relay(self):
        """Translate every operator of the subgraph into Relay and record its outputs.

        Each operator is dispatched through ``self.convert_map``. A single-output
        converter returns the expression itself; a multi-output converter
        returns something indexable, with one entry per output tensor. Each
        result is stored in the expression table under its output tensor's name.
        """
        subgraph = self.subgraph
        for position in range(subgraph.OperatorsLength()):
            tflite_op = subgraph.Operators(position)
            op_name = self.get_op_code_str(tflite_op)
            outputs = self.get_output_tensors(tflite_op)

            converted = self.convert_map[op_name](tflite_op)

            if len(outputs) == 1:
                named_results = ((outputs[0], converted),)
            else:
                # Lazy generator: entries are indexed one at a time, as they are stored.
                named_results = ((out, converted[k]) for k, out in enumerate(outputs))

            for out_tensor, expr in named_results:
                self.exp_tab.set_expr(get_tensor_name(subgraph, out_tensor.tensor_idx), expr)

    def get_op_code_str(self, op):
        """Return the builtin operator name (a key of ``convert_map``, e.g. ``'CONV_2D'``).

        Raises
        ------
        ImportError
            If the ``tflite`` package is not installed.
        NotImplementedError
            If ``op`` is a custom (non-builtin) operator.
        """
        try:
            from tflite.BuiltinOperator import BuiltinOperator
        except ImportError:
            raise ImportError("The tflite package must be installed")

        op_code = self.model.OperatorCodes(op.OpcodeIndex())
        builtin_code = op_code.BuiltinCode()
        # Looked up before the CUSTOM check, so an unknown code fails with a KeyError first.
        name = self.builtin_op_code[builtin_code]
        if builtin_code == BuiltinOperator.CUSTOM:
            raise NotImplementedError("Custom operators are currently not supported")
        return name

    def get_input_tensors(self, op):
        """Return a TensorWrapper for each input index of ``op``, in order."""
        return self.get_tensors(op.InputsAsNumpy())

    def get_output_tensors(self, op):
        """Return a TensorWrapper for each output index of ``op``, in order."""
        return self.get_tensors(op.OutputsAsNumpy())

    def get_tensors(self, tensors_idx_list):
        """Wrap each TFLite tensor index in a TensorWrapper.

        Parameters
        ----------
        tensors_idx_list : iterable of int
            Indices into the subgraph's tensor list. A negative index yields a
            placeholder wrapper whose ``tensor`` and ``buffer`` are ``0``.

        Returns
        -------
        list of TensorWrapper
            One wrapper per index, in input order. ``qnn_params`` is a dict of
            Relay constants ``'scale'`` (float32) and ``'zero_point'`` (int32)
            when the tensor has non-zero per-tensor quantization parameters.
            Otherwise it is ``None``.

        Raises
        ------
        tvm.error.OpNotImplemented
            If a tensor carries per-axis (multi-valued) quantization parameters,
            which this converter does not support.
        """
        return_list = list()
        for tensor_idx in tensors_idx_list:
            if tensor_idx < 0:
                return_list.append(TensorWrapper(tensor_idx, 0, 0))
                continue

            tensor = self.subgraph.Tensors(tensor_idx)
            buffer_idx = tensor.Buffer()
            buffer = self.model.Buffers(buffer_idx)

            # Check if the tensors are quantized. Parse if yes.
            qnn_params = None
            tflite_qnn_params = tensor.Quantization()
            if tflite_qnn_params is not None:
                # The accessors may return a plain scalar (presumably when the
                # vector is absent) or a numpy array; flatten both to 1-D so
                # per-tensor, per-axis and empty vectors are handled uniformly.
                scales = np.asarray(tflite_qnn_params.ScaleAsNumpy()).reshape(-1)
                zero_points = np.asarray(tflite_qnn_params.ZeroPointAsNumpy()).reshape(-1)
                if scales.size > 1 or zero_points.size > 1:
                    raise tvm.error.OpNotImplemented(
                        'Per-axis quantization of tensor {} is not supported yet.'
                        .format(tensor_idx))
                # An empty vector is treated like an absent one (no quantization).
                scale = float(scales[0]) if scales.size else 0.0
                zero_point = int(zero_points[0]) if zero_points.size else 0
                # Check that the scale and zero points are valid.
                if scale != 0 or zero_point != 0:
                    qnn_params = dict()
                    qnn_params['scale'] = relay.const(scale, 'float32')
                    qnn_params['zero_point'] = relay.const(zero_point, 'int32')
            return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params))
        return return_list

    def get_tensor_value(self, tensor_wrapper):
        """Decode a tensor's raw buffer into a numpy array with the tensor's shape.

        Supported element types: UINT8, FLOAT32, INT32, INT64 and BOOL.

        Raises
        ------
        ImportError
            If the ``tflite`` package is not installed.
        NotImplementedError
            For any other TFLite tensor type.
        """
        assert isinstance(tensor_wrapper, TensorWrapper)

        try:
            from tflite.TensorType import TensorType
        except ImportError:
            raise ImportError("The tflite package must be installed")

        np_dtype_of = {
            TensorType.UINT8: np.uint8,
            TensorType.FLOAT32: np.float32,
            TensorType.INT32: np.int32,
            TensorType.INT64: np.int64,
            TensorType.BOOL: np.bool_,
        }
        tensor = tensor_wrapper.tensor
        tflite_type = tensor.Type()
        np_dtype = np_dtype_of.get(tflite_type)
        if np_dtype is None:
            raise NotImplementedError("Tensor type {} is currently not supported"
                                      .format(str(tflite_type)))

        flat = np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np_dtype)
        return flat.reshape(tensor.ShapeAsNumpy())
234 235 236 237 238 239 240 241 242 243

    def get_tensor_type_str(self, tensor_type):
        """Map a TFLite ``TensorType`` value to the matching Relay dtype string.

        Raises
        ------
        ImportError
            If the ``tflite`` package is not installed.
        NotImplementedError
            For tensor types other than UINT8, FLOAT32, INT32, INT64 and BOOL.
        """
        try:
            from tflite.TensorType import TensorType
        except ImportError:
            raise ImportError("The tflite package must be installed")

        dtype_names = {
            TensorType.UINT8: "uint8",
            TensorType.FLOAT32: "float32",
            TensorType.INT32: "int32",
            TensorType.INT64: "int64",
            TensorType.BOOL: "bool",
        }
        dtype_name = dtype_names.get(tensor_type)
        if dtype_name is None:
            raise NotImplementedError("Tensor type {} is currently not supported"
                                      .format(str(tensor_type)))
        return dtype_name
254

255
    def has_same_qnn_params(self, lhs_tensor, rhs_tensor):
        """Return whether two quantized tensors share the same scale and zero point.

        Both tensors must have ``qnn_params``. The Relay constants are unpacked
        to Python scalars before comparing.
        """
        def scalar_params(tensor):
            params = tensor.qnn_params
            scale_expr, zp_expr = params['scale'], params['zero_point']
            return get_scalar_from_constant(scale_expr), get_scalar_from_constant(zp_expr)

        lhs_scale, lhs_zero_point = scalar_params(lhs_tensor)
        rhs_scale, rhs_zero_point = scalar_params(rhs_tensor)
        return lhs_scale == rhs_scale and lhs_zero_point == rhs_zero_point
266

267 268 269 270 271 272 273 274 275 276 277 278
    def is_quantized(self, op):
        """Report whether ``op`` is quantized, judged by its first input tensor only."""
        try:
            from tflite.Operator import Operator
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        leading_input = self.get_input_tensors(op)[0]
        return leading_input.qnn_params is not None

279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295
    def quantize(self, expr, tensor_to_quantize):
        """Quantize ``expr`` using the scale, zero point and dtype of ``tensor_to_quantize``."""
        target_dtype = self.get_tensor_type_str(tensor_to_quantize.tensor.Type())
        params = tensor_to_quantize.qnn_params
        return _qnn.op.quantize(data=expr,
                                output_scale=params['scale'],
                                output_zero_point=params['zero_point'],
                                out_dtype=target_dtype)

    def dequantize(self, expr, tensor):
        """Dequantize ``expr`` using the scale and zero point recorded on ``tensor``."""
        params = tensor.qnn_params
        return _qnn.op.dequantize(data=expr,
                                  input_scale=params['scale'],
                                  input_zero_point=params['zero_point'])

296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322
    def convert_conv2d(self, op):
        """Convert TFLite CONV_2D through the shared ``convert_conv`` path."""
        conv_kind = "conv2d"
        return self.convert_conv(op, conv_kind)

    def convert_depthwise_conv2d(self, op):
        """Convert TFLite DEPTHWISE_CONV_2D through the shared ``convert_conv`` path."""
        conv_kind = "depthwise"
        return self.convert_conv(op, conv_kind)

    def convert_average_pool2d(self, op):
        """Convert TFLite AVERAGE_POOL_2D through the shared ``convert_pool2d`` path."""
        pool_kind = "average"
        return self.convert_pool2d(op, pool_kind)

    def convert_max_pool2d(self, op):
        """Convert TFLite MAX_POOL_2D through the shared ``convert_pool2d`` path."""
        pool_kind = "max"
        return self.convert_pool2d(op, pool_kind)

    def convert_reshape(self, op):
        """Convert TFLite RESHAPE.

        The target shape comes from the op's ``ReshapeOptions`` when present.
        Without that options table but with a second input, the second input
        is read as a constant shape tensor instead.

        For quantized inputs the output must share the input's scale and zero
        point, so a plain Relay reshape is emitted on the integer data.

        Raises
        ------
        tvm.error.OpNotImplemented
            If the shape comes from a second input that is not a constant.
        AssertionError
            If there are no inputs, if no shape source is available, or if the
            quantized input and output parameters differ.
        """
        try:
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.Operator import Operator
            from tflite.ReshapeOptions import ReshapeOptions
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert input_tensors, "input tensors should not be empty"
        input_tensor = input_tensors[0]
        input_tensor_idx = input_tensor.tensor_idx

        if op.BuiltinOptionsType() == BuiltinOptions.ReshapeOptions:
            op_options = op.BuiltinOptions()
            reshape_options = ReshapeOptions()
            reshape_options.Init(op_options.Bytes, op_options.Pos)
            target_shape = reshape_options.NewShapeAsNumpy()
        elif len(input_tensors) == 2:
            # No ReshapeOptions: use the second operand as the shape. Constant
            # tensors are not in the expression table (cf. the constant
            # fallback in _convert_elemwise). A tensor that has an expression
            # is therefore a runtime value, which is not supported here.
            shape_tensor = input_tensors[1]
            if self.has_expr(shape_tensor.tensor_idx):
                raise tvm.error.OpNotImplemented(
                    'TFLite reshape with a non-constant shape input is not supported yet.')
            target_shape = self.get_tensor_value(shape_tensor).reshape(-1)
        else:
            # No shape source at all: keep the original assertion failure.
            assert op.BuiltinOptionsType() == BuiltinOptions.ReshapeOptions

        in_expr = self.get_expr(input_tensor_idx)

        # If the tensors are quantized, ensure that input/output qnn params are same.
        if input_tensor.qnn_params:
            output_tensors = self.get_output_tensors(op)
            assert len(output_tensors) == 1, "There should be only 1 output tensor"
            output_tensor = output_tensors[0]
            assert self.has_same_qnn_params(input_tensor, output_tensor), \
                    "TFLite reshape requires input and output scale and zero points to be equal"
        out = _op.reshape(in_expr, newshape=tuple(target_shape))
        return out

345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372
    def _convert_resize(self, method, op):
        """Shared converter for TFLite RESIZE_BILINEAR and RESIZE_NEAREST_NEIGHBOR.

        Parameters
        ----------
        method : str
            Relay resize method, ``"bilinear"`` or ``"nearest_neighbor"``.
        op : tflite.Operator
            Operator with two inputs: the NHWC image tensor and a constant
            2-element size tensor ``(new_height, new_width)``.

        Returns
        -------
        Relay expression from ``_op.image.resize`` in NHWC layout. It uses
        ``align_corners`` coordinate transformation when the op's options
        request it, and ``asymmetric`` otherwise.
        """
        try:
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.Operator import Operator
            from tflite.ResizeBilinearOptions import ResizeBilinearOptions
            # ResizeNearestNeighborOptions was added in tflite v1.13
            tflite_ver = 1120
            if 'ResizeNearestNeighborOptions' in dir(BuiltinOptions):
                from tflite.ResizeNearestNeighborOptions import ResizeNearestNeighborOptions
                tflite_ver = 1130
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 2, "input tensors length should be 2"
        image_tensor, size_tensor = input_tensors

        in_expr = self.get_expr(image_tensor.tensor_idx)
        target_size = tuple(self.get_tensor_value(size_tensor))

        # Choose the options table for this op. Older tflite packages lack
        # ResizeNearestNeighborOptions; in that case align_corners stays False.
        if method == "bilinear":
            assert op.BuiltinOptionsType() == BuiltinOptions.ResizeBilinearOptions
            resize_options = ResizeBilinearOptions()
        elif tflite_ver >= 1130:
            assert op.BuiltinOptionsType() == BuiltinOptions.ResizeNearestNeighborOptions
            resize_options = ResizeNearestNeighborOptions()
        else:
            resize_options = None

        align_corners = False
        if resize_options is not None:
            op_options = op.BuiltinOptions()
            resize_options.Init(op_options.Bytes, op_options.Pos)
            align_corners = resize_options.AlignCorners()

        coord_trans = "align_corners" if align_corners else "asymmetric"
        return _op.image.resize(in_expr, target_size, "NHWC", method,
                                coordinate_transformation_mode=coord_trans)

    def convert_resize_bilinear(self, op):
        """Convert TFLite RESIZE_BILINEAR via the shared ``_convert_resize`` helper."""
        resize_method = "bilinear"
        return self._convert_resize(resize_method, op)

    def convert_resize_nearest_neighbor(self, op):
        """Convert TFLite RESIZE_NEAREST_NEIGHBOR via the shared ``_convert_resize`` helper."""
        resize_method = "nearest_neighbor"
        return self._convert_resize(resize_method, op)
398

399 400 401 402 403 404 405 406 407 408 409 410 411 412
    def convert_logistic(self, op):
        """Convert TFLite LOGISTIC (sigmoid).

        A quantized input is dequantized first, and the sigmoid is computed in
        floating point. If the output tensor is quantized, the result is
        quantized with the output's parameters.
        """
        try:
            from tflite.Operator import Operator
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 1, "input tensors length should be 1"
        source = input_tensors[0]
        data = self.get_expr(source.tensor_idx)

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors length should be 1"
        target = output_tensors[0]

        if source.qnn_params:
            data = self.dequantize(data, source)
        result = _op.sigmoid(data)
        return self.quantize(result, target) if target.qnn_params else result

425 426 427 428 429 430 431 432 433 434 435 436 437
    def convert_softmax(self, op):
        """Convert TFLite SOFTMAX.

        TFLite normalizes over the innermost (last) dimension of its input, so
        the Relay softmax uses ``axis=-1``. The previous fixed ``axis=1`` only
        matched 2-D ``[batch, classes]`` inputs; for higher-rank inputs it
        normalized over the wrong dimension.

        Quantized inputs are dequantized and the softmax is computed in floating
        point. The result is quantized again when the output tensor is quantized.
        """
        try:
            from tflite.Operator import Operator
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 1, "input tensors length should be 1"

        input_tensor = input_tensors[0]
        input_tensor_idx = input_tensor.tensor_idx

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors length should be 1"
        output_tensor = output_tensors[0]

        # Last axis, matching TFLite semantics for every input rank.
        # NOTE(review): the op's options (e.g. a softmax beta) are not read
        # here; confirm that the supported models use the default.
        params = {'axis': -1}
        in_expr = self.get_expr(input_tensor_idx)

        # TODO - Naive softmax int8 implementation leads to bad accuracy. Currently, we can
        # dequantize to FP32 and perform softmax on FP32. We can investigate an integer only softmax
        # implementation in future.
        if input_tensor.qnn_params:
            in_expr = self.dequantize(in_expr, input_tensor)

        out = _op.nn.softmax(in_expr, **params)

        # Go back to integer datatype if the original operator was quantized.
        if output_tensor.qnn_params:
            out = self.quantize(out, output_tensor)

        return out

460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476
    def convert_tanh(self, op):
        """Convert TFLite TANH.

        Float inputs map directly onto ``tanh``. Applying ``tanh`` to raw
        quantized codes would ignore the scale and zero point. So, as in
        ``convert_logistic``, a quantized input is dequantized first and the
        result is quantized with the output tensor's parameters.
        """
        try:
            from tflite.Operator import Operator
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 1, "input tensors length should be 1"

        input_tensor = input_tensors[0]
        in_expr = self.get_expr(input_tensor.tensor_idx)

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors length should be 1"
        output_tensor = output_tensors[0]

        if input_tensor.qnn_params:
            in_expr = self.dequantize(in_expr, input_tensor)
        out = _op.tanh(in_expr)
        if output_tensor.qnn_params:
            out = self.quantize(out, output_tensor)

        return out

477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493
    def convert_relu(self, op):
        """Convert TFLite RELU.

        Float inputs map directly onto ``nn.relu``. On quantized codes,
        ``nn.relu`` would clamp at 0 rather than at the input zero point and
        would ignore the output tensor's own scale and zero point. A quantized
        input is therefore dequantized, rectified in floating point and
        quantized with the output parameters.
        """
        try:
            from tflite.Operator import Operator
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 1, "input tensors length should be 1"

        input_tensor = input_tensors[0]
        in_expr = self.get_expr(input_tensor.tensor_idx)

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors length should be 1"
        output_tensor = output_tensors[0]

        if input_tensor.qnn_params:
            in_expr = self.dequantize(in_expr, input_tensor)
        out = _op.nn.relu(in_expr)
        if output_tensor.qnn_params:
            out = self.quantize(out, output_tensor)

        return out

494
    def convert_concatenation(self, op):
        """Convert TFLite CONCATENATION.

        Float inputs use ``concatenate``. When the first input is quantized,
        ``qnn.op.concatenate`` is used, with every input's scale and zero point
        plus the output's. The axis and any fused activation come from
        ``ConcatenationOptions``. A fused activation is applied for float
        outputs and rejected for quantized ones.

        Raises
        ------
        tvm.error.OpNotImplemented
            For a quantized concatenation with a fused activation function.
        """
        try:
            from tflite.Operator import Operator
            from tflite.ConcatenationOptions import ConcatenationOptions
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.ActivationFunctionType import ActivationFunctionType
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        # Message fixed: the check is ">= 1", not "greater than 1".
        assert len(input_tensors) >= 1, "input tensors length should be at least 1"
        in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors]

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors length should be 1"
        output_tensor = output_tensors[0]

        assert op.BuiltinOptionsType() == BuiltinOptions.ConcatenationOptions
        op_options = op.BuiltinOptions()
        concatenation_options = ConcatenationOptions()
        concatenation_options.Init(op_options.Bytes, op_options.Pos)
        concatenation_axis = concatenation_options.Axis()
        fused_activation_fn = concatenation_options.FusedActivationFunction()

        if not input_tensors[0].qnn_params:
            out = _op.concatenate(in_exprs, axis=concatenation_axis)
        else:
            input_scales = [input_tensor.qnn_params['scale'] for input_tensor in input_tensors]
            input_zero_points = \
                    [input_tensor.qnn_params['zero_point'] for input_tensor in input_tensors]
            out = _qnn.op.concatenate(in_exprs,
                                      input_scales=input_scales,
                                      input_zero_points=input_zero_points,
                                      output_scale=output_tensor.qnn_params['scale'],
                                      output_zero_point=output_tensor.qnn_params['zero_point'],
                                      axis=concatenation_axis)

        # if we have activation fn
        if fused_activation_fn != ActivationFunctionType.NONE:
            if not output_tensor.qnn_params:
                out = self.convert_fused_activation_function(out, fused_activation_fn)
            else:
                raise tvm.error.OpNotImplemented(
                    'Operator {} with fused activation is not supported yet.'
                    .format('qnn.op.concatenate'))
        return out

543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629
    def _convert_unary_elemwise(self, relay_op, op):
        """Apply the Relay unary operator ``relay_op`` to the single input of ``op``.

        No quantization handling is done here; callers reject quantized inputs.
        """
        try:
            from tflite.Operator import Operator
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        operands = self.get_input_tensors(op)
        assert len(operands) == 1, "input tensors length should be 1"
        return relay_op(self.get_expr(operands[0].tensor_idx))

    def _convert_float_only_unary(self, relay_op, op, op_name):
        """Convert a unary elementwise op that has no quantized lowering yet.

        Raises tvm.error.OpNotImplemented when ``op`` is quantized. Otherwise
        defers to ``_convert_unary_elemwise``.
        """
        if self.is_quantized(op):
            raise tvm.error.OpNotImplemented(
                'TFlite quantized {} operator is not supported yet.'.format(op_name))
        return self._convert_unary_elemwise(relay_op, op)

    def convert_abs(self, op):
        """Convert TFLite ABS"""
        return self._convert_float_only_unary(_op.abs, op, 'ABS')

    def convert_ceil(self, op):
        """Convert TFLite CEIL"""
        return self._convert_float_only_unary(_op.ceil, op, 'CEIL')

    def convert_floor(self, op):
        """Convert TFLite FLOOR"""
        return self._convert_float_only_unary(_op.floor, op, 'FLOOR')

    def convert_exp(self, op):
        """Convert TFLite EXP"""
        return self._convert_float_only_unary(_op.exp, op, 'EXP')

    def convert_log(self, op):
        """Convert TFLite LOG"""
        return self._convert_float_only_unary(_op.log, op, 'LOG')

    def convert_sin(self, op):
        """Convert TFLite SIN"""
        return self._convert_float_only_unary(_op.sin, op, 'SIN')

    def convert_cos(self, op):
        """Convert TFLite COS"""
        return self._convert_float_only_unary(_op.cos, op, 'COS')

    def convert_sqrt(self, op):
        """Convert TFLite SQRT"""
        return self._convert_float_only_unary(_op.sqrt, op, 'SQRT')

    def convert_rsqrt(self, op):
        """Convert TFLite RSQRT"""
        return self._convert_float_only_unary(_op.rsqrt, op, 'RSQRT')

    def convert_neg(self, op):
        """Convert TFLite NEG"""
        return self._convert_float_only_unary(_op.negative, op, 'NEG')

630 631
    def _convert_elemwise(self, relay_op, op):
        """Shared converter for binary elementwise TFLite operators.

        Parameters
        ----------
        relay_op : callable
            Relay or QNN binary operator. When the lhs input is quantized it is
            called with keywords ``lhs``, ``rhs`` and the scale/zero point of
            both inputs and the output. Otherwise it is called as
            ``relay_op(lhs, rhs)``.
        op : tflite.Operator
            Operator with exactly two inputs and one output.

        Returns
        -------
        The Relay expression. A fused activation from Add/Sub/Mul/Div options
        is applied for non-quantized outputs.

        Raises
        ------
        tvm.error.OpNotImplemented
            For a quantized output combined with a fused activation.
        """
        try:
            from tflite.Operator import Operator
            from tflite.AddOptions import AddOptions
            from tflite.SubOptions import SubOptions
            from tflite.MulOptions import MulOptions
            from tflite.DivOptions import DivOptions
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.ActivationFunctionType import ActivationFunctionType
        except ImportError:
            raise ImportError("The tflite package must be installed")

        def operand_expr(tensor):
            # Usually TOCO fuses constants, so each operand already has an
            # expression. In unfused corner cases the operand is a constant
            # buffer, which is materialized as a new constant here.
            idx = tensor.tensor_idx
            if self.has_expr(idx):
                return self.get_expr(idx)
            const_dtype = self.get_tensor_type_str(tensor.tensor.Type())
            return self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=const_dtype)

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 2, "input tensors length should be 2"

        lhs_tensor, rhs_tensor = input_tensors
        lhs_expr = operand_expr(lhs_tensor)
        rhs_expr = operand_expr(rhs_tensor)

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors length should be 1"
        output_tensor = output_tensors[0]

        if not lhs_tensor.qnn_params:
            out = relay_op(lhs_expr, rhs_expr)
        else:
            # Quantized: pass every operand's scale and zero point to the QNN op.
            assert rhs_tensor.qnn_params, "Both tensors should be quantized."
            assert output_tensor.qnn_params, "Output tensor should be quantized."
            lhs_q, rhs_q, out_q = (lhs_tensor.qnn_params, rhs_tensor.qnn_params,
                                   output_tensor.qnn_params)
            out = relay_op(lhs=lhs_expr,
                           rhs=rhs_expr,
                           lhs_scale=lhs_q['scale'],
                           lhs_zero_point=lhs_q['zero_point'],
                           rhs_scale=rhs_q['scale'],
                           rhs_zero_point=rhs_q['zero_point'],
                           output_scale=out_q['scale'],
                           output_zero_point=out_q['zero_point'])

        # Only these option tables carry a fused activation function.
        options_class_for = {
            BuiltinOptions.AddOptions: AddOptions,
            BuiltinOptions.SubOptions: SubOptions,
            BuiltinOptions.MulOptions: MulOptions,
            BuiltinOptions.DivOptions: DivOptions,
        }
        options_class = options_class_for.get(op.BuiltinOptionsType())
        if options_class is None:
            return out

        options = options_class()
        op_options = op.BuiltinOptions()
        options.Init(op_options.Bytes, op_options.Pos)
        fused_activation_fn = options.FusedActivationFunction()
        if fused_activation_fn == ActivationFunctionType.NONE:
            return out
        if output_tensor.qnn_params:
            raise tvm.error.OpNotImplemented(
                'Elemwise operators with fused activation are not supported yet.')
        return self.convert_fused_activation_function(out, fused_activation_fn)

714 715
    def convert_add(self, op):
        """Convert TFLite ADD.

        Quantized operators are lowered to ``qnn.op.add``; all others to the
        plain relay ``add``.
        """
        relay_add = _qnn.op.add if self.is_quantized(op) else _op.add
        return self._convert_elemwise(relay_add, op)

    def convert_sub(self, op):
        """Convert TFLite SUB to relay ``subtract``.

        There is no QNN lowering for SUB yet, so a quantized operator is
        rejected with ``tvm.error.OpNotImplemented``.
        """
        if not self.is_quantized(op):
            return self._convert_elemwise(_op.subtract, op)
        raise tvm.error.OpNotImplemented(
            'TFlite quantized SUB operator is not supported yet.')

    def convert_mul(self, op):
        """Convert TFLite MUL.

        Float operators become relay ``multiply``; quantized ones use the
        dedicated ``qnn.op.mul`` lowering.
        """
        if not self.is_quantized(op):
            return self._convert_elemwise(_op.multiply, op)
        return self._convert_elemwise(_qnn.op.mul, op)

    def convert_div(self, op):
        """Convert TFLite DIV to relay ``divide``.

        Quantized DIV is not supported and raises ``OpNotImplemented``.
        """
        if not self.is_quantized(op):
            return self._convert_elemwise(_op.divide, op)
        raise tvm.error.OpNotImplemented(
            'TFlite quantized DIV operator is not supported yet.')

    def convert_pow(self, op):
        """Convert TFLite POW to relay ``power``.

        Quantized POW is not supported and raises ``OpNotImplemented``.
        """
        quantized = self.is_quantized(op)
        if quantized:
            raise tvm.error.OpNotImplemented(
                'TFlite quantized POW operator is not supported yet.')
        return self._convert_elemwise(_op.power, op)

    def convert_maximum(self, op):
        """Convert TFLite MAXIMUM to relay ``maximum``.

        Quantized MAXIMUM is not supported and raises ``OpNotImplemented``.
        """
        if not self.is_quantized(op):
            return self._convert_elemwise(_op.maximum, op)
        raise tvm.error.OpNotImplemented(
            'TFlite quantized MAXIMUM operator is not supported yet.')

    def convert_minimum(self, op):
        """Convert TFLite MINIMUM to relay ``minimum``.

        Quantized MINIMUM is not supported and raises ``OpNotImplemented``.
        """
        if not self.is_quantized(op):
            return self._convert_elemwise(_op.minimum, op)
        raise tvm.error.OpNotImplemented(
            'TFlite quantized MINIMUM operator is not supported yet.')

    def convert_greater(self, op):
        """Convert TFLite GREATER to relay ``greater``.

        Quantized GREATER is not supported and raises ``OpNotImplemented``.
        """
        if not self.is_quantized(op):
            return self._convert_elemwise(_op.greater, op)
        raise tvm.error.OpNotImplemented(
            'TFlite quantized GREATER operator is not supported yet.')

    def convert_squared_difference(self, op):
        """Convert TFLite SQUARED_DIFFERENCE as ``power(lhs - rhs, 2)``.

        The exponent constant takes the dtype of the operator's output tensor.
        Quantized operators raise ``OpNotImplemented``.
        """
        if self.is_quantized(op):
            raise tvm.error.OpNotImplemented(
                'TFlite quantized squared difference operator is not supported yet.')
        diff_expr = self._convert_elemwise(_op.subtract, op)
        # _convert_elemwise already asserted there is exactly one output tensor.
        out_tensor = self.get_output_tensors(op)[0]
        two = relay.const(2, self.get_tensor_type_str(out_tensor.tensor.Type()))
        return _op.power(diff_expr, two)

    def convert_greater_equal(self, op):
        """Convert TFLite GREATER_EQUAL to relay ``greater_equal``.

        Quantized operators raise ``OpNotImplemented``.
        """
        if not self.is_quantized(op):
            return self._convert_elemwise(_op.greater_equal, op)
        raise tvm.error.OpNotImplemented(
            'TFlite quantized GREATER_EQUAL operator is not supported yet.')

    def convert_less(self, op):
        """Convert TFLite LESS to relay ``less``.

        Quantized operators raise ``OpNotImplemented``.
        """
        if not self.is_quantized(op):
            return self._convert_elemwise(_op.less, op)
        raise tvm.error.OpNotImplemented(
            'TFlite quantized LESS operator is not supported yet.')

    def convert_less_equal(self, op):
        """Convert TFLite LESS_EQUAL to relay ``less_equal``.

        Quantized operators raise ``OpNotImplemented``.
        """
        if not self.is_quantized(op):
            return self._convert_elemwise(_op.less_equal, op)
        raise tvm.error.OpNotImplemented(
            'TFlite quantized LESS_EQUAL operator is not supported yet.')

    def convert_equal(self, op):
        """Convert TFLite EQUAL to relay ``equal``.

        Quantized operators raise ``OpNotImplemented``.
        """
        if not self.is_quantized(op):
            return self._convert_elemwise(_op.equal, op)
        raise tvm.error.OpNotImplemented(
            'TFlite quantized EQUAL operator is not supported yet.')

    def convert_not_equal(self, op):
        """Convert TFLite NOT_EQUAL to relay ``not_equal``.

        Quantized operators raise ``OpNotImplemented``.
        """
        if not self.is_quantized(op):
            return self._convert_elemwise(_op.not_equal, op)
        raise tvm.error.OpNotImplemented(
            'TFlite quantized NOT_EQUAL operator is not supported yet.')

    def _convert_logical_binary(self, relay_op, op):
        """Generic method to convert TFLite logical binary ops (LOGICAL_AND/OR).

        Parameters
        ----------
        relay_op : callable
            Relay operator applied as ``relay_op(lhs, rhs)``.
        op : tflite.Operator.Operator
            The TFLite operator; must have exactly two inputs.

        Each operand is taken from the expression table when an earlier
        operator already produced it. Otherwise, e.g. when a converter left a
        constant operand unfused, the operand is materialized as a constant
        from the tensor's buffer, the same way ``_convert_elemwise`` handles
        constants.
        """
        try:
            from tflite.Operator import Operator
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 2, "input tensors length should be 2"

        operand_exprs = []
        for tensor in input_tensors:
            if self.has_expr(tensor.tensor_idx):
                operand_exprs.append(self.get_expr(tensor.tensor_idx))
            else:
                type_str = self.get_tensor_type_str(tensor.tensor.Type())
                operand_exprs.append(
                    self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str))
        lhs_expr, rhs_expr = operand_exprs
        out = relay_op(lhs_expr, rhs_expr)

        return out

    def convert_logical_and(self, op):
        """Convert TFLite LOGICAL_AND to relay ``logical_and``."""
        relay_fn = _op.logical_and
        return self._convert_logical_binary(relay_fn, op)

    def convert_logical_or(self, op):
        """Convert TFLite LOGICAL_OR to relay ``logical_or``."""
        relay_fn = _op.logical_or
        return self._convert_logical_binary(relay_fn, op)

    def convert_zeros_like(self, op):
        """Convert TFLite ZEROS_LIKE to relay ``zeros_like`` of its single input."""
        try:
            from tflite.Operator import Operator
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 1, "input tensors length should be 1"

        source_expr = self.get_expr(input_tensors[0].tensor_idx)
        return _op.zeros_like(source_expr)

    def _convert_reduce(self, relay_op, op):
        """Generic method to convert TFLite reduction operators.

        Shared by the MIN/MAX/MEAN/PROD/SUM converters.

        Parameters
        ----------
        relay_op : callable
            Relay reduction, called as ``relay_op(data, axis, keep_dims)``.
        op : tflite.Operator.Operator
            The TFLite operator: input 0 is the data, input 1 the reduction axes.
            ``ReducerOptions.keep_dims`` controls whether reduced dims are kept.

        Quantized inputs are cast to int32 before reducing. When the output is
        quantized, a requantize from the input's to the output's qnn params is
        appended.
        """
        try:
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.Operator import Operator
            from tflite.ReducerOptions import ReducerOptions
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 2, "input tensors length should be 2"

        # input_tensor
        input_tensor = input_tensors[0]
        in_expr = self.get_expr(input_tensor.tensor_idx)

        # axis: the reduction-indices tensor may be 0-D (a single axis) as well
        # as 1-D. tuple() on a 0-D array raises, so promote it; a 1-D array is
        # returned unchanged by np.atleast_1d.
        axis = tuple(np.atleast_1d(self.get_tensor_value(input_tensors[1])))

        # Options - keep_dims (bool)
        assert op.BuiltinOptionsType() == BuiltinOptions.ReducerOptions
        reduce_options = ReducerOptions()
        op_options = op.BuiltinOptions()
        reduce_options.Init(op_options.Bytes, op_options.Pos)
        keep_dims = reduce_options.KeepDims()

        if input_tensor.qnn_params:
            # Reduce in int32 so intermediate values do not overflow the
            # narrow quantized input type.
            in_expr = _op.cast(in_expr, "int32")

        out = relay_op(in_expr, axis, keep_dims)

        # Finally if the reduce is quantized. Add a requantize at the end.
        # NOTE(review): reusing the input scale/zero point is exact for MIN/MAX
        # and (presumably, up to rounding) MEAN. For SUM/PROD the input zero
        # point enters the result once per reduced element, not once overall.
        # Confirm whether quantized SUM/PROD reach this path.
        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors length should be 1"
        output_tensor = output_tensors[0]
        output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
        if output_tensor.qnn_params:
            out = _qnn.op.requantize(out,
                                     input_scale=input_tensor.qnn_params['scale'],
                                     input_zero_point=input_tensor.qnn_params['zero_point'],
                                     output_scale=output_tensor.qnn_params['scale'],
                                     output_zero_point=output_tensor.qnn_params['zero_point'],
                                     out_dtype=output_tensor_type_str)

        return out

    def _convert_reduce_min(self, op):
        """Convert TFLite REDUCE_MIN via the shared reduction converter."""
        reducer = _op.reduce.min
        return self._convert_reduce(reducer, op)

    def _convert_reduce_max(self, op):
        """Convert TFLite REDUCE_MAX via the shared reduction converter."""
        reducer = _op.reduce.max
        return self._convert_reduce(reducer, op)

    def _convert_reduce_mean(self, op):
        """Convert TFLite MEAN via the shared reduction converter."""
        reducer = _op.reduce.mean
        return self._convert_reduce(reducer, op)

    def _convert_reduce_prod(self, op):
        """Convert TFLite REDUCE_PROD via the shared reduction converter."""
        reducer = _op.reduce.prod
        return self._convert_reduce(reducer, op)

    def _convert_reduce_sum(self, op):
        """Convert TFLite SUM via the shared reduction converter."""
        reducer = _op.reduce.sum
        return self._convert_reduce(reducer, op)

    def convert_fully_connected(self, op):
        """Convert TFLite FULLY_CONNECTED.

        The input is flattened to ``[num_batches, input_size]``, where
        ``input_size`` is the weight tensor's second dimension and
        ``num_batches`` is whatever remains. This is how the TFLite kernel
        treats its input, so any input whose element count is a multiple of
        ``input_size`` is accepted. For the usual ``[N, d1, ..., dk]`` input
        with ``d1*...*dk == input_size`` the result is ``[N, input_size]``.

        Float graphs lower to ``nn.dense``. Quantized graphs lower to
        ``qnn.op.dense`` (int32 accumulation) followed by a requantize to the
        output tensor's qnn params. An optional third input is added as bias.

        Raises ``tvm.error.OpNotImplemented`` for a quantized operator with a
        fused activation.
        """
        try:
            from tflite.Operator import Operator
            from tflite.FullyConnectedOptions import FullyConnectedOptions
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.TensorType import TensorType
            from tflite.ActivationFunctionType import ActivationFunctionType
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) >= 2, "input tensors length should be >= 2"

        input_tensor = input_tensors[0]
        input_tensor_idx = input_tensor.tensor_idx
        weight_tensor = input_tensors[1]

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors length should be 1"
        output_tensor = output_tensors[0]
        output_tensor_type = output_tensor.tensor.Type()
        output_tensor_type_str = self.get_tensor_type_str(output_tensor_type)

        input_tensor_shape = input_tensor.tensor.ShapeAsNumpy()
        weight_tensor_shape = weight_tensor.tensor.ShapeAsNumpy()

        # Flatten the input to [num_batches, input_size] (TFLite semantics).
        # When the trailing input dims multiply to input_size, num_batches is
        # simply the leading dim.
        input_size = int(weight_tensor_shape[1])
        num_elements = int(np.prod(input_tensor_shape))
        assert num_elements % input_size == 0, \
            "input size and weight size are mismatched"
        target_shape = (num_elements // input_size, input_size)
        in_expr = self.get_expr(input_tensor_idx)
        in_expr = _op.reshape(in_expr, target_shape)

        assert op.BuiltinOptionsType() == BuiltinOptions.FullyConnectedOptions
        op_options = op.BuiltinOptions()
        fully_connected_options = FullyConnectedOptions()
        fully_connected_options.Init(op_options.Bytes, op_options.Pos)
        fused_activation_fn = fully_connected_options.FusedActivationFunction()

        # weight tensor type should be UINT8 (quantization) or FLOAT32
        weight_tensor_type = weight_tensor.tensor.Type()
        assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
        weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)

        weight_value = self.get_tensor_value(weight_tensor)
        weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)

        if input_tensor.qnn_params:
            out = _qnn.op.dense(in_expr, weight_expr,
                                input_zero_point=input_tensor.qnn_params['zero_point'],
                                kernel_zero_point=weight_tensor.qnn_params['zero_point'],
                                input_scale=input_tensor.qnn_params['scale'],
                                kernel_scale=weight_tensor.qnn_params['scale'],
                                out_dtype='int32')
        else:
            out = _op.nn.dense(in_expr, weight_expr)

        # if we have bias
        if len(input_tensors) == 3:
            bias_tensor = input_tensors[2]
            bias_tensor_type = bias_tensor.tensor.Type()
            # bias tensor type should be INT32 (quantization) or FLOAT32
            assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
            bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
            bias_expr = self.exp_tab.new_const(self.get_tensor_value(bias_tensor),
                                               dtype=bias_tensor_type_str)
            out = _op.nn.bias_add(out, bias_expr)

        # If we have fused activations
        if fused_activation_fn != ActivationFunctionType.NONE:
            if not output_tensor.qnn_params:
                out = self.convert_fused_activation_function(out, fused_activation_fn)
            else:
                raise tvm.error.OpNotImplemented(
                    'Operator {} with fused activation is not supported yet.'
                    .format('qnn.op.dense'))

        # Finally if the dense is quantized. Add a requantize at the end.
        if output_tensor.qnn_params:
            # The int32 accumulator's scale is input_scale * weight_scale with
            # a zero point of 0.
            data_scale = input_tensor.qnn_params['scale']
            weight_scale = weight_tensor.qnn_params['scale']
            data_scale_val = get_scalar_from_constant(data_scale)
            weight_scale_val = get_scalar_from_constant(weight_scale)
            new_input_scale_val = data_scale_val * weight_scale_val
            new_input_scale = relay.const(new_input_scale_val, 'float32')
            new_input_zero_point = relay.const(0, 'int32')
            out = _qnn.op.requantize(out,
                                     input_scale=new_input_scale,
                                     input_zero_point=new_input_zero_point,
                                     output_scale=output_tensor.qnn_params['scale'],
                                     output_zero_point=output_tensor.qnn_params['zero_point'],
                                     out_dtype=output_tensor_type_str)

        return out

    def convert_squeeze(self, op):
        """Convert TFLite SQUEEZE.

        Removes the size-1 dimensions listed in ``SqueezeOptions.squeeze_dims``.
        When no dims are listed, TFLite squeezes every size-1 dimension. That
        case is expressed by passing ``axis=None`` to relay ``squeeze``.
        """
        try:
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.Operator import Operator
            from tflite.SqueezeOptions import SqueezeOptions
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        output_tensors = self.get_output_tensors(op)
        assert len(input_tensors) == 1, "input tensors length should be 1"
        assert len(output_tensors) == 1, "output tensors length should be 1"
        input_tensor = input_tensors[0]
        input_tensor_idx = input_tensor.tensor_idx

        assert op.BuiltinOptionsType() == BuiltinOptions.SqueezeOptions
        op_options = op.BuiltinOptions()
        squeeze_options = SqueezeOptions()
        squeeze_options.Init(op_options.Bytes, op_options.Pos)
        squeeze_axis = squeeze_options.SqueezeDimsAsNumpy()

        # Flatbuffers-generated ``*AsNumpy`` accessors return the int 0 rather
        # than an array when the vector field is absent, and tuple(0) would
        # raise. Both that and an empty vector mean "squeeze all size-1 dims",
        # which relay spells axis=None. An empty tuple presumably squeezes
        # nothing.
        if isinstance(squeeze_axis, int) or len(squeeze_axis) == 0:
            axis = None
        else:
            axis = tuple(squeeze_axis)

        in_expr = self.get_expr(input_tensor_idx)
        out = _op.squeeze(in_expr, axis=axis)

        return out

    def convert_fused_activation_function(self, in_expr, fused_activation_fn):
        """Apply a TFLite fused activation to ``in_expr``.

        Supports RELU6, RELU, RELU_N1_TO_1 and TANH. Any other activation
        raises ``tvm.error.OpNotImplemented``. ``fused_activation_fn`` must not
        be ``NONE``; callers only invoke this when an activation is present.
        """
        try:
            from tflite.ActivationFunctionType import ActivationFunctionType
        except ImportError:
            raise ImportError("The tflite package must be installed")
        assert fused_activation_fn != ActivationFunctionType.NONE

        activation_builders = {
            ActivationFunctionType.RELU6: lambda expr: _op.clip(expr, a_min=0, a_max=6),
            ActivationFunctionType.RELU: _op.nn.relu,
            ActivationFunctionType.RELU_N1_TO_1: lambda expr: _op.clip(expr, a_min=-1, a_max=1),
            ActivationFunctionType.TANH: _op.tanh,
        }
        builder = activation_builders.get(fused_activation_fn)
        if builder is not None:
            return builder(in_expr)

        fused_activation_fn_str = self.activation_fn_type[fused_activation_fn]
        raise tvm.error.OpNotImplemented(
            'Operator {} is not supported for frontend TFLite.'.format(fused_activation_fn_str))

    def convert_conv(self, op, conv_type):
        """Convert a TFLite CONV_2D or DEPTHWISE_CONV_2D operator.

        Parameters
        ----------
        op : tflite.Operator.Operator
            The TFLite operator. Inputs are data, weight and an optional bias.
        conv_type : str
            ``'conv2d'`` or ``'depthwise'``. Any other value raises
            ``tvm.error.OpNotImplemented``.

        Returns
        -------
        The relay expression. This is an NHWC ``nn.conv2d``, or for quantized
        inputs a ``qnn.op.conv2d`` followed by a requantize. Bias add and any
        fused activation are applied on top.
        """
        try:
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.ActivationFunctionType import ActivationFunctionType
            from tflite.TensorType import TensorType
            from tflite.Operator import Operator
            from tflite.Conv2DOptions import Conv2DOptions
            from tflite.DepthwiseConv2DOptions import DepthwiseConv2DOptions
            from tflite.Padding import Padding
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) >= 2, "input tensors length should be >= 2"

        input_tensor = input_tensors[0]
        input_tensor_idx = input_tensor.tensor_idx
        weight_tensor = input_tensors[1]

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors length should be 1"
        output_tensor = output_tensors[0]
        output_tensor_type = output_tensor.tensor.Type()
        output_tensor_type_str = self.get_tensor_type_str(output_tensor_type)

        is_depthwise_conv = False
        if conv_type == 'conv2d':
            assert op.BuiltinOptionsType() == BuiltinOptions.Conv2DOptions
            op_options = op.BuiltinOptions()
            conv_options = Conv2DOptions()
            conv_options.Init(op_options.Bytes, op_options.Pos)
        elif conv_type == 'depthwise':
            is_depthwise_conv = True
            assert op.BuiltinOptionsType() == BuiltinOptions.DepthwiseConv2DOptions
            op_options = op.BuiltinOptions()
            conv_options = DepthwiseConv2DOptions()
            conv_options.Init(op_options.Bytes, op_options.Pos)
            depth_multiplier = conv_options.DepthMultiplier()
        else:
            raise tvm.error.OpNotImplemented(
                'Operator {} is not supported for frontend TFLite.'.format(conv_type))

        stride_h = conv_options.StrideH()
        stride_w = conv_options.StrideW()
        dilation_h = conv_options.DilationHFactor()
        dilation_w = conv_options.DilationWFactor()
        padding = conv_options.Padding()
        fused_activation_fn = conv_options.FusedActivationFunction()

        _, input_h, input_w, input_c = input_tensor.tensor.ShapeAsNumpy()

        if is_depthwise_conv:
            # TFLite depthwise convolution kernel layout is:
            # 1 KH KW C(input_c * depth_multiplier)
            _, kernel_h, kernel_w, in_channels = weight_tensor.tensor.ShapeAsNumpy()
            assert in_channels == input_c * depth_multiplier
        else:
            output_channels, kernel_h, kernel_w, _ = weight_tensor.tensor.ShapeAsNumpy()

        dilated_kernel_h = dilation_h * (kernel_h - 1) + 1
        dilated_kernel_w = dilation_w * (kernel_w - 1) + 1

        # SAME padding is applied explicitly with nn.pad below, so the
        # convolution itself is always built with zero padding.
        params = {'kernel_size': [kernel_h, kernel_w],
                  'strides': [stride_h, stride_w],
                  'dilation': [dilation_h, dilation_w],
                  'padding': [0, 0],
                  'data_layout': 'NHWC'}

        if is_depthwise_conv:
            # in_channels is the weight's last dim, i.e. the number of output
            # channels (input_c * depth_multiplier).
            params['channels'] = int(in_channels)
            # One group per *input* channel, each producing depth_multiplier
            # outputs. Using in_channels here is only correct when the
            # multiplier is 1.
            params['groups'] = int(input_c)
            # With a single input channel this is an ordinary convolution, and
            # the (KH, KW, 1, depth_multiplier) kernel built below is HWIO.
            params['kernel_layout'] = 'HWIO' if input_c == 1 else 'HWOI'
        else:
            params['channels'] = int(output_channels)
            params['kernel_layout'] = 'HWIO'

        # weight tensor type should be UINT8 (quantization) or FLOAT32
        weight_tensor_type = weight_tensor.tensor.Type()
        assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
        weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)

        in_expr = self.get_expr(input_tensor_idx)
        weight_value = self.get_tensor_value(weight_tensor)

        # TFLite kernel layout:
        # convolution:
        # OC KH KW IC, we require KH KW IC OC (HWIO)
        # depthwise convolution:
        # 1 KH KW C(input_c * depth_multiplier), we require
        # KH KW IC M (depth_multiplier) (HWOI)
        if is_depthwise_conv:
            weight_value = weight_value.reshape(kernel_h, kernel_w, input_c, depth_multiplier)
        else:
            weight_value = weight_value.transpose((1, 2, 3, 0))

        weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)

        if padding == Padding.VALID:
            pass
        elif padding == Padding.SAME:
            pad_top, pad_bottom = get_pad_value(input_h, dilated_kernel_h, stride_h)
            pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w)
            do_pad = not (pad_top == 0 and pad_bottom == 0 and pad_left == 0 and pad_right == 0)
            if do_pad:
                # Quantized inputs are padded with their zero point (the
                # quantized encoding of 0.0) rather than a literal 0.
                pad_value = 0
                if input_tensor.qnn_params:
                    pad_value = get_scalar_from_constant(input_tensor.qnn_params['zero_point'])
                in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0),
                                                              (pad_top, pad_bottom),
                                                              (pad_left, pad_right),
                                                              (0, 0)), pad_value=float(pad_value))

        else:
            raise tvm.error.OpAttributeUnImplemented(
                'Padding format {} is not supported for operator Conv.'.format(padding))

        if input_tensor.qnn_params:
            qnn_conv2d_params = dict(params)
            qnn_conv2d_params['input_zero_point'] = input_tensor.qnn_params['zero_point']
            qnn_conv2d_params['kernel_zero_point'] = weight_tensor.qnn_params['zero_point']
            qnn_conv2d_params['out_dtype'] = 'int32'
            qnn_conv2d_params['input_scale'] = input_tensor.qnn_params['scale']
            qnn_conv2d_params['kernel_scale'] = weight_tensor.qnn_params['scale']
            out = _qnn.op.conv2d(in_expr, weight_expr, **qnn_conv2d_params)
        else:
            out = _op.nn.conv2d(in_expr, weight_expr, **params)

        # if we have bias
        if len(input_tensors) == 3:
            bias_tensor = input_tensors[2]
            bias_tensor_type = bias_tensor.tensor.Type()
            # bias tensor type should be INT32 (quantization) or FLOAT32
            assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
            bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
            bias_expr = self.exp_tab.new_const(self.get_tensor_value(bias_tensor),
                                               dtype=bias_tensor_type_str)
            # Channels are the last axis in NHWC.
            channel_axis = 3
            out = _op.nn.bias_add(out, bias_expr, axis=channel_axis)

        # If we have fused activations
        if fused_activation_fn != ActivationFunctionType.NONE:
            if not output_tensor.qnn_params:
                out = self.convert_fused_activation_function(out, fused_activation_fn)
            else:
                raise tvm.error.OpNotImplemented(
                    'Operator {} with fused activation is not supported yet.'
                    .format('qnn.op.conv2d'))

        # Finally if the conv is quantized. Add a requantize at the end.
        if output_tensor.qnn_params:
            # The int32 accumulator's scale is input_scale * weight_scale with
            # a zero point of 0.
            data_scale = input_tensor.qnn_params['scale']
            weight_scale = weight_tensor.qnn_params['scale']
            data_scale_val = get_scalar_from_constant(data_scale)
            weight_scale_val = get_scalar_from_constant(weight_scale)
            new_input_scale_val = data_scale_val * weight_scale_val
            new_input_scale = relay.const(new_input_scale_val, 'float32')
            new_input_zero_point = relay.const(0, 'int32')
            out = _qnn.op.requantize(out,
                                     input_scale=new_input_scale,
                                     input_zero_point=new_input_zero_point,
                                     output_scale=output_tensor.qnn_params['scale'],
                                     output_zero_point=output_tensor.qnn_params['zero_point'],
                                     out_dtype=output_tensor_type_str)

        return out

    def convert_split(self, op):
        """Convert TFLite SPLIT into relay ``split``.

        Input 0 holds the split axis and input 1 the data. ``NumSplits`` from
        ``SplitOptions`` gives the number of equal pieces. A single-piece
        result is unwrapped from its TupleWrapper.
        """
        try:
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.Operator import Operator
            from tflite.SplitOptions import SplitOptions
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)

        assert len(input_tensors) == 2, "input tensors length should be == 2"

        axis_tensor, data_tensor = input_tensors
        split_axis = self.get_tensor_value(axis_tensor)

        assert op.BuiltinOptionsType() == BuiltinOptions.SplitOptions
        options_table = op.BuiltinOptions()
        split_options = SplitOptions()
        split_options.Init(options_table.Bytes, options_table.Pos)
        num_splits = split_options.NumSplits()

        pieces = _op.split(self.get_expr(data_tensor.tensor_idx), num_splits,
                           axis=int(split_axis))
        # Relay does not like a TupleWrapper of 1 element. This only shows up
        # with tf1.13 when num_splits == 1; tf 1.14 emits a reshape instead.
        if isinstance(pieces, _expr.TupleWrapper) and pieces.size == 1:
            return pieces[0]
        return pieces

    def convert_slice(self, op):
        """Convert TFLite SLICE into relay ``strided_slice``.

        Inputs are data, ``begin`` and ``size``. Relay needs end indices, so
        each end is ``begin + size``. A size of -1 runs to the end of that
        dimension.
        """
        try:
            from tflite.Operator import Operator
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 3, "input tensors length should be == 3"
        data_tensor = input_tensors[0]
        data_expr = self.get_expr(data_tensor.tensor_idx)

        begin = list(self.get_tensor_value(input_tensors[1]))
        size = list(self.get_tensor_value(input_tensors[2]))
        data_shape = data_tensor.tensor.ShapeAsNumpy()

        end = list(size)
        for dim in range(len(data_shape)):
            if size[dim] == -1:
                end[dim] = data_shape[dim]
            else:
                end[dim] = size[dim] + begin[dim]

        return _op.strided_slice(data_expr, begin, end)

    def convert_transpose(self, op):
        """Convert TFLite TRANSPOSE into relay ``transpose``.

        Input 1 holds the permutation. When it is empty, relay's default axis
        order is used (no ``axes`` argument).
        """
        try:
            from tflite.Operator import Operator
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 2, "input tensors length should be 2"

        data_expr = self.get_expr(input_tensors[0].tensor_idx)
        perm = tuple(self.get_tensor_value(input_tensors[1]))

        if perm:
            return _op.transpose(data_expr, perm)
        return _op.transpose(data_expr)

    def convert_cast(self, op):
        """Convert TFLite CAST.

        The target dtype is read from the operator's CastOptions. When the
        operator carries no builtin options at all (NOTE(review): reportedly
        emitted by some converter versions - confirm), the target dtype is taken
        from the operator's single output tensor, whose type is the cast result.

        Parameters
        ----------
        op : tflite.Operator.Operator
            The CAST operator.

        Returns
        -------
        out : relay.Expr
            ``relay.cast`` of the input to the target dtype.
        """
        try:
            from tflite.Operator import Operator
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.CastOptions import CastOptions
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 1, "input tensors length should be 1"
        input_tensor = input_tensors[0]
        in_expr = self.get_expr(input_tensor.tensor_idx)

        if op.BuiltinOptionsType() == BuiltinOptions.NONE:
            # No CastOptions attached: the output tensor type is the cast target.
            output_tensors = self.get_output_tensors(op)
            assert len(output_tensors) == 1, "output tensors length should be 1"
            cast_dtype = output_tensors[0].tensor.Type()
        else:
            assert op.BuiltinOptionsType() == BuiltinOptions.CastOptions
            op_options = op.BuiltinOptions()
            cast_options = CastOptions()
            cast_options.Init(op_options.Bytes, op_options.Pos)
            cast_dtype = cast_options.OutDataType()

        out = _op.cast(in_expr, self.get_tensor_type_str(cast_dtype))

        return out

1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381
    def convert_tile(self, op):
        """Convert TFLite TILE.

        The second input holds, per axis, how many times the data is repeated;
        it is passed to ``relay.tile`` as ``reps``.
        """
        try:
            from tflite.Operator import Operator
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 2, "input tensors length should be 2"

        data_expr = self.get_expr(input_tensors[0].tensor_idx)
        multiples = tuple(self.get_tensor_value(input_tensors[1]))
        return _op.tile(data_expr, multiples)

1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398
    def convert_pool2d(self, op, pool_type):
        """Convert a TFLite 2-D pooling operator to Relay.

        Parameters
        ----------
        op : tflite.Operator.Operator
            The pooling operator; its options are read as Pool2DOptions.
        pool_type : str
            ``"average"`` or ``"max"``.

        Returns
        -------
        out : relay.Expr
            The pooled expression (NHWC layout), with any fused activation applied.

        Raises
        ------
        tvm.error.OpAttributeUnImplemented
            If the padding is neither VALID nor SAME.
        tvm.error.OpNotImplemented
            If ``pool_type`` is unknown, or a fused activation is requested on a
            quantized input.
        """
        try:
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.ActivationFunctionType import ActivationFunctionType
            from tflite.Operator import Operator
            from tflite.Pool2DOptions import Pool2DOptions
            from tflite.Padding import Padding
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 1, "input tensors length should be 1"
        input_tensor = input_tensors[0]
        input_tensor_idx = input_tensor.tensor_idx

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors should be 1"
        output_tensor = output_tensors[0]
        output_tensor_type = output_tensor.tensor.Type()
        output_tensor_type_str = self.get_tensor_type_str(output_tensor_type)

        assert op.BuiltinOptionsType() == BuiltinOptions.Pool2DOptions
        op_options = op.BuiltinOptions()
        pool2d_options = Pool2DOptions()
        pool2d_options.Init(op_options.Bytes, op_options.Pos)
        stride_h = pool2d_options.StrideH()
        stride_w = pool2d_options.StrideW()
        padding = pool2d_options.Padding()
        filter_h = pool2d_options.FilterHeight()
        filter_w = pool2d_options.FilterWidth()
        fused_activation_fn = pool2d_options.FusedActivationFunction()

        params = {'pool_size': (filter_h, filter_w),
                  'strides': (stride_h, stride_w),
                  'padding': [0, 0],
                  'layout': 'NHWC'}

        in_expr = self.get_expr(input_tensor_idx)

        _, input_h, input_w, _ = input_tensor.tensor.ShapeAsNumpy()
        if padding == Padding.VALID:
            pass
        elif padding == Padding.SAME:
            # Explicit [top, left, bottom, right] padding following TF's SAME rule.
            pad_top, pad_bottom = get_pad_value(input_h, filter_h, stride_h)
            pad_left, pad_right = get_pad_value(input_w, filter_w, stride_w)
            params['padding'] = [pad_top, pad_left, pad_bottom, pad_right]
        else:
            raise tvm.error.OpAttributeUnImplemented(
                'Padding format {} for operator Pool2D is not supported.'.format(padding))

        if pool_type == "average":
            if input_tensor.qnn_params:
                assert self.has_same_qnn_params(input_tensor, output_tensor), \
                        'TFLite avg_pool2d requires input and output scale ' \
                        'and zero points to be equal'
                # Pool in int32 and cast back to the output dtype; presumably so the
                # window accumulation does not overflow the narrow quantized dtype.
                out = _op.cast(in_expr, dtype="int32")
                out = _op.nn.avg_pool2d(out, **params)
                out = _op.cast(out, dtype=output_tensor_type_str)
            else:
                out = _op.nn.avg_pool2d(in_expr, **params)
        elif pool_type == "max":
            if input_tensor.qnn_params:
                assert self.has_same_qnn_params(input_tensor, output_tensor), \
                        "qnn.op.max_pool2d requires input and output qnn params to be same"
            out = _op.nn.max_pool2d(in_expr, **params)
        else:
            raise tvm.error.OpNotImplemented(
                'Operator {} is not supported for frontend TFLite.'.format(pool_type + ' pool'))

        # If we have fused activations
        if fused_activation_fn != ActivationFunctionType.NONE:
            if input_tensor.qnn_params:
                raise tvm.error.OpNotImplemented(
                    'Operator {} with fused activation is not supported yet.'
                    .format('qnn.op.pool2d'))
            out = self.convert_fused_activation_function(out, fused_activation_fn)

        return out

1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472
    def convert_pad(self, op):
        """Convert TFLite PAD to ``relay.nn.pad``.

        TFLite PAD only supports CONSTANT mode and has no constant_values input,
        so the pad value is 0 for float tensors. For quantized tensors the pad
        value is the input zero point, and the input and output tensors must
        share the same quantization parameters.

        Parameters
        ----------
        op : tflite.Operator.Operator
            The PAD operator (inputs: data, paddings).

        Returns
        -------
        out : relay.Expr
            The padded expression.
        """
        try:
            from tflite.Operator import Operator
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 2, "input tensors length should be 2"

        # tensor
        input_tensor = input_tensors[0]
        in_expr = self.get_expr(input_tensor.tensor_idx)

        # paddings: one (before, after) row per dimension -> tuple of tuples
        pad_list = self.get_tensor_value(input_tensors[1])
        paddings = tuple(tuple(pair) for pair in pad_list)

        # Set the pad value
        pad_value = 0
        if input_tensor.qnn_params:
            # Check that input and output tensor have same qnn params.
            output_tensors = self.get_output_tensors(op)
            output_tensor = output_tensors[0]
            assert self.has_same_qnn_params(input_tensor, output_tensor), \
                    "TFLite PAD requires input and output scale and zero points to be equal"

            # The pad value for quantized pad is the input zero point.
            pad_value = float(input_tensor.qnn_params['zero_point'].data.asnumpy())

        out = _op.nn.pad(in_expr, pad_width=paddings, pad_value=pad_value)
        return out

1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535
    def convert_mirror_pad(self, op):
        """Convert TFLite MIRROR_PAD to ``relay.nn.mirror_pad``.

        Only non-quantized operators are handled. Mode value 0 is mapped to
        "REFLECT"; any other value is mapped to "SYMMETRIC".
        """
        try:
            from tflite.Operator import Operator
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.MirrorPadOptions import MirrorPadOptions
        except ImportError:
            raise ImportError("The tflite package must be installed")

        # TFLite itself has no quantized MirrorPad kernel yet, so refuse it up front.
        if self.is_quantized(op):
            raise tvm.error.OpNotImplemented(
                'TFlite quantized MIRROR_PAD operator is not supported yet.')

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 2, "input tensors length should be 2"

        data_expr = self.get_expr(input_tensors[0].tensor_idx)
        # One (before, after) row per dimension, turned into a tuple of tuples.
        pad_rows = self.get_tensor_value(input_tensors[1])
        pad_width = tuple(tuple(row) for row in pad_rows)

        assert op.BuiltinOptionsType() == BuiltinOptions.MirrorPadOptions
        raw_options = op.BuiltinOptions()
        options = MirrorPadOptions()
        options.Init(raw_options.Bytes, raw_options.Pos)

        if options.Mode() == 0:
            mode = "REFLECT"
        else:
            mode = "SYMMETRIC"
        return _op.nn.mirror_pad(data_expr, pad_width, mode)

1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550
    def convert_pack(self, op):
        """Convert TFLite PACK.

        Each input gets a new unit dimension inserted at ``Axis``; the results
        are then concatenated along that same axis.

        Parameters
        ----------
        op : tflite.Operator.Operator
            The PACK operator (one or more inputs).

        Returns
        -------
        out : relay.Expr
            The stacked expression.
        """
        try:
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.Operator import Operator
            from tflite.PackOptions import PackOptions
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) >= 1, "input tensors length should be >= 1"
        in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors]

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors length should be 1"

        assert op.BuiltinOptionsType() == BuiltinOptions.PackOptions
        op_options = op.BuiltinOptions()
        pack_options = PackOptions()
        pack_options.Init(op_options.Bytes, op_options.Pos)
        pack_axis = pack_options.Axis()

        in_exprs_reshaped = [_op.expand_dims(expr, axis=pack_axis, num_newaxis=1)
                             for expr in in_exprs]
        out = _op.concatenate(in_exprs_reshaped, pack_axis)
        return out

1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606
    def convert_unpack(self, op):
        """Convert TFLite UNPACK.

        The input is split into ``Num`` pieces along ``Axis``, and that axis is
        squeezed out of every piece.

        Parameters
        ----------
        op : tflite.Operator.Operator
            The UNPACK operator.

        Returns
        -------
        relay.Expr or TupleWrapper
            A single expression when ``Num == 1``; otherwise a TupleWrapper with
            ``Num`` elements.
        """
        try:
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.Operator import Operator
            from tflite.UnpackOptions import UnpackOptions
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 1, "input tensors length should be 1"
        input_tensor = input_tensors[0]
        in_expr = self.get_expr(input_tensor.tensor_idx)
        assert op.BuiltinOptionsType() == BuiltinOptions.UnpackOptions
        op_options = op.BuiltinOptions()
        unpack_options = UnpackOptions()
        unpack_options.Init(op_options.Bytes, op_options.Pos)
        num_unpacks = unpack_options.Num()
        unpack_axis = unpack_options.Axis()

        # Relay doesn't support 'unpack' operator so we use 'split' & 'squeeze' instead.
        # Squeeze exactly the unpacked axis. axis=None would drop *every* unit
        # dimension (e.g. a (1, 1, 3) piece would become (3,) instead of (1, 3)),
        # so an explicit one-element list is used for axis 0 as well.
        squeeze_axis = [unpack_axis]

        # Relay doesn't like TupleWrapper of 1 element so we isolate the case of unpacking
        # a tensor by an axis with len(axis) == 1. For reference see convert_split().
        # Such unpacking will result in the same tensor so we omit 'split' and only squeeze
        # along the axis of dim == 1.
        if num_unpacks == 1:
            return _op.squeeze(in_expr, axis=squeeze_axis)

        splitted = _op.split(in_expr,
                             indices_or_sections=num_unpacks,
                             axis=unpack_axis)
        squeezed = _expr.TupleWrapper(
            _expr.Tuple([_op.squeeze(split_item, axis=squeeze_axis)
                         for split_item in splitted]), len(splitted))

        return squeezed

1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716
    def convert_batch_to_space_nd(self, op):
        """Convert TFLite BATCH_TO_SPACE_ND.

        Follows the reference decomposition from
        https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d:
        reshape, permute, reshape, then crop spatial dims [1, ..., M].

        Parameters
        ----------
        op : tflite.Operator.Operator
            The operator; inputs are (data, block_shape, crops).

        Returns
        -------
        relay.Expr
            The rearranged and cropped expression.
        """
        try:
            from tflite.Operator import Operator
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 3, "input tensors length should be 3"

        input_tensor = input_tensors[0]
        input_tensor_idx = input_tensor.tensor_idx
        in_expr = self.get_expr(input_tensor_idx)

        input_shape = list(input_tensor.tensor.ShapeAsNumpy())
        batch = input_shape[0]

        block_shape = list(self.get_tensor_value(input_tensors[1]))
        M = len(block_shape)

        # One (crop_start, crop_end) row per spatial dimension.
        crops = list(self.get_tensor_value(input_tensors[2]))

        # Reshape input to reshaped of shape
        # [block_shape[0], ..., block_shape[M-1], batch / prod(block_shape), input_shape[1:]...]
        shape1 = block_shape + [batch // np.prod(block_shape)] + input_shape[1:]
        reshaped = _op.reshape(in_expr, newshape=shape1)

        # Permute dimensions of reshaped to produce permuted of shape
        axes = [M] + [axis for i in range(M) for axis in [M + i + 1, i]] + \
            list(range(2 * M + 1, len(shape1)))
        permuted = _op.transpose(reshaped, axes=axes)

        # Reshape permuted to produce reshaped_permuted (-3 merges pairs of dims)
        shape2 = [0] + [-3] * M + [-2]
        reshaped_permuted = _op.reshape(permuted, newshape=shape2)

        # Crop the start and end of dimensions [1, ..., M] of reshaped_permuted
        # according to crops to produce the output.
        reshaped_permuted_shape = _infer_shape(reshaped_permuted)
        cropped = reshaped_permuted
        for axis in range(1, M + 1):
            crop_start, crop_end = (int(value) for value in crops[axis - 1])
            # Crop whenever either side is non-zero; one-sided crops such as
            # [0, 1] must still be applied.
            if crop_start != 0 or crop_end != 0:
                indices = _op.arange(
                    _expr.const(crop_start, dtype='int32'),
                    _expr.const(reshaped_permuted_shape[axis] - crop_end, dtype='int32'),
                    dtype='int32'
                )
                cropped = _op.take(cropped, indices=indices, axis=axis)

        return cropped

    def convert_space_to_batch_nd(self, op):
        """Convert TFLite SPACE_TO_BATCH_ND.

        Follows the decomposition from
        https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd:
        zero-pad spatial dims [1, ..., M], split each into (dim / block, block),
        move the block factors ahead of the batch axis, then fold them into
        the batch dimension.
        """
        try:
            from tflite.Operator import Operator
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 3, "input tensors length should be 3"

        data_tensor = input_tensors[0]
        in_expr = self.get_expr(data_tensor.tensor_idx)

        in_shape = list(data_tensor.tensor.ShapeAsNumpy())
        batch = in_shape[0]
        rank = len(in_shape)

        block_shape = list(self.get_tensor_value(input_tensors[1]))
        num_blocks = len(block_shape)

        pad_rows = list(self.get_tensor_value(input_tensors[2]))

        # Step 1: pad only the block dims; batch and trailing dims get (0, 0).
        num_trailing = rank - num_blocks - 1
        pad_width = tuple(
            row.tolist() if isinstance(row, np.ndarray) else row
            for row in [(0, 0)] + pad_rows + [(0, 0)] * num_trailing
        )
        padded = _op.nn.pad(in_expr, pad_width=pad_width)

        # Step 2: split each padded block dim (reshape code -4) into (-1, block).
        split_shape = [batch]
        for block in block_shape:
            split_shape += [-4, -1, block]
        split_shape.append(-2)
        split = _op.reshape(padded, newshape=split_shape)

        # Step 3: block factors first, then batch, then the reduced spatial dims,
        # then any trailing dims in their original order.
        first_trailing = 1 + 2 * num_blocks
        perm = [2 * i + 2 for i in range(num_blocks)]
        perm.append(0)
        perm += [2 * i + 1 for i in range(num_blocks)]
        perm += list(range(first_trailing, first_trailing + num_trailing))
        permuted = _op.transpose(split, axes=perm)
        permuted_shape = list(_infer_shape(permuted))

        # Step 4: collapse the leading block factors and batch into one dim.
        out_shape = [batch * np.prod(block_shape)] + permuted_shape[num_blocks + 1:]
        return _op.reshape(permuted, newshape=out_shape)

1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731
    def convert_prelu(self, op):
        """Convert TFLite PRELU.

        ``alpha`` is flattened and applied along the input's last axis, which is
        the channel axis in the channels-last layout this frontend uses (axis 3
        for 4-D NHWC inputs). Using the last axis also supports inputs of other
        ranks.

        Parameters
        ----------
        op : tflite.Operator.Operator
            The PRELU operator; inputs are (data, alpha).

        Returns
        -------
        out : relay.Expr
            ``relay.nn.prelu`` of the input.
        """
        try:
            from tflite.Operator import Operator
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 2, "input tensors length should be 2"

        input_tensor = input_tensors[0]
        alpha_tensor = input_tensors[1]
        alpha_tensor_type = alpha_tensor.tensor.Type()
        alpha_tensor_type_str = self.get_tensor_type_str(alpha_tensor_type)
        # NOTE(review): flattening assumes alpha holds one value per channel
        # (e.g. shape (1, 1, C)); a full per-element alpha is not handled here.
        alpha_expr = self.exp_tab.new_const(self.get_tensor_value(alpha_tensor).flatten(),
                                            dtype=alpha_tensor_type_str)
        in_expr = self.get_expr(input_tensor.tensor_idx)
        channel_axis = len(input_tensor.tensor.ShapeAsNumpy()) - 1
        out = _op.nn.prelu(in_expr, alpha_expr, axis=channel_axis)

        return out

1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816
    def convert_transpose_conv(self, op):
        """Convert TFLite TRANSPOSE_CONV to ``relay.nn.conv2d_transpose``.

        TFLite inputs, in order: the requested output shape (NHWC), the filter
        (OHWI) and the data tensor (NHWC). VALID padding is accepted; SAME
        padding is accepted only for 1x1 kernels.
        """
        try:
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.TensorType import TensorType
            from tflite.Operator import Operator
            from tflite.TransposeConvOptions import TransposeConvOptions
            from tflite.Padding import Padding
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 3, "input tensors length should be 3"

        data_tensor = input_tensors[2]
        _, _, _, data_channels = data_tensor.tensor.ShapeAsNumpy()
        weights_tensor = input_tensors[1]
        out_channels, kernel_h, kernel_w, in_channels = weights_tensor.tensor.ShapeAsNumpy()
        assert data_channels == in_channels, \
            "Input channel in the filter should match to channel in the input"
        output_shape_tensor = input_tensors[0]

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors length should be 1"
        out_dtype = self.get_tensor_type_str(output_tensors[0].tensor.Type())

        assert op.BuiltinOptionsType() == BuiltinOptions.TransposeConvOptions
        raw_options = op.BuiltinOptions()
        options = TransposeConvOptions()
        options.Init(raw_options.Bytes, raw_options.Pos)

        padding = options.Padding()
        strides = (options.StrideH(), options.StrideW())
        assert padding in (Padding.VALID, Padding.SAME), \
            'Padding format {} is not supported for operator TRANSPOSE_CONV'.format(padding)

        data_expr = self.get_expr(data_tensor.tensor_idx)

        weights_type = weights_tensor.tensor.Type()
        # Only UINT8 and FLOAT32 filters are accepted.
        # NOTE(review): no qnn parameters are applied for UINT8 filters in this
        # block - confirm the quantized path is handled correctly.
        assert weights_type in (TensorType.UINT8, TensorType.FLOAT32)
        weights_dtype = self.get_tensor_type_str(weights_type)
        # kernel_layout is declared OIHW below, but the weight constant itself
        # is laid out IOHW, so OHWI -> IOHW here.
        weights_iohw = np.transpose(self.get_tensor_value(weights_tensor), (3, 0, 1, 2))
        weights_expr = self.exp_tab.new_const(weights_iohw, dtype=weights_dtype)

        # The filter's output channel count must match the requested output shape.
        requested_shape = self.get_tensor_value(output_shape_tensor)
        assert out_channels == requested_shape[3], \
            "Output channel in the filter should match to channel in the output_shape"

        # Mirror the TF frontend: SAME padding only for 1x1 kernels.
        if padding == Padding.SAME:
            assert (kernel_h, kernel_w) == (1, 1), \
                "SAME padding is supported for kernel (1,1) only"

        return _op.nn.conv2d_transpose(data_expr, weights_expr,
                                       strides=strides,
                                       channels=int(out_channels),
                                       kernel_size=(int(kernel_h), int(kernel_w)),
                                       data_layout="NHWC",
                                       kernel_layout="OIHW",
                                       out_dtype=out_dtype)

1817 1818 1819
    def get_expr(self, input_tensor_idx):
        """Return the expression stored in ``exp_tab`` under the tensor's name."""
        tensor_name = get_tensor_name(self.subgraph, input_tensor_idx)
        return self.exp_tab.get_expr(tensor_name)

1820 1821 1822
    def has_expr(self, input_tensor_idx):
        """Report whether ``exp_tab`` holds an expression for the tensor's name."""
        tensor_name = get_tensor_name(self.subgraph, input_tensor_idx)
        return self.exp_tab.has_expr(tensor_name)

1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862
def build_str_map(obj):
    """Map each public integer attribute value of a TFLite enum class to its name.

    Parameters
    ----------
    obj:
        TFLite class which contains enum int value, such as BuiltInOptions

    Returns
    -------
        dict from int value to attribute name. Names starting with an
        underscore and non-int attributes are skipped. If several names share
        a value, the one that comes last in ``dir(obj)`` order wins.
    """
    named_values = ((name, getattr(obj, name))
                    for name in dir(obj) if not name.startswith('_'))
    return {value: name for name, value in named_values if isinstance(value, int)}

# SAME padding: https://www.tensorflow.org/api_guides/python/nn
def get_pad_value(data, kernel, stride):
    """Compute the (before, after) padding used by TensorFlow's SAME padding.

    Parameters
    ----------
    data:
        1D input data

    kernel:
        1D input kernel

    stride:
        1D input stride

    Returns
    -------
        (pad_before, pad_after); any odd leftover goes to the "after" side.
    """
    out_size = int(math.ceil(float(data) / float(stride)))
    total_pad = max(0, (out_size - 1) * stride + kernel - data)
    before, leftover = divmod(total_pad, 2)
    return before, before + leftover


def get_tensor_name(subgraph, tensor_idx):
    """Get the tensor name.

    Parameters
    ----------
    subgraph:
        tflite.SubGraph.SubGraph

    tensor_idx:
        tensor index in subgraph

    Returns
    -------
        tensor name, decoded from its UTF-8 bytes to ``str``
    """
    return subgraph.Tensors(tensor_idx).Name().decode("utf-8")


def from_tflite(model, shape_dict=None, dtype_dict=None):
    """Convert from tflite model into a Relay module.

    Parameters
    ----------
    model:
        tflite.Model.Model

    shape_dict : dict of str to int list/tuple, optional
        Input shapes of the model. Inputs not listed (or all inputs, when
        omitted) get ``shape=None``.

    dtype_dict : dict of str to str, optional
        Input types of the model. Inputs not listed (or all inputs, when
        omitted) default to ``"float32"``.

    Returns
    -------
    mod : tvm.IRModule
        The relay module for compilation.

    params : dict of str to tvm.nd.NDArray
        The parameter dict to be used by relay
    """
    try:
        import tflite.Model
        import tflite.SubGraph
        import tflite.BuiltinOperator
    except ImportError:
        raise ImportError("The tflite package must be installed")
    assert isinstance(model, tflite.Model.Model)

    # Missing dicts behave like empty ones, so callers may omit them.
    shape_dict = {} if shape_dict is None else shape_dict
    dtype_dict = {} if dtype_dict is None else dtype_dict

    # keep the same as tflite
    assert model.SubgraphsLength() == 1, "only support one subgraph (main subgraph)"
    subgraph = model.Subgraphs(0)

    # model inputs / outputs
    model_inputs = subgraph.InputsAsNumpy()
    model_outputs = subgraph.OutputsAsNumpy()

    exp_tab = ExprTable()
    for model_input in model_inputs:
        model_input_name = get_tensor_name(subgraph, model_input)
        shape = shape_dict.get(model_input_name)
        dtype = dtype_dict.get(model_input_name, "float32")
        exp_tab.set_expr(model_input_name, _expr.var(model_input_name, shape=shape, dtype=dtype))

    # op code in model
    op_converter = OperatorConverter(model, subgraph, exp_tab)
    op_converter.check_unsupported_ops()
    op_converter.convert_op_to_relay()

    # params and outputs
    params = {k: _nd.array(np.array(v)) for k, v in exp_tab.params.items()}
    outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs]
    outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)

    func = _expr.Function(analysis.free_vars(outputs), outputs)
    mod = IRModule.from_expr(func)
    return mod, params