# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, invalid-name, unused-argument
"""
TFLite testcases
================
This is a test script that checks TFLite operators converted through Relay.
"""
from __future__ import print_function
from functools import partial
import numpy as np
import tvm
from tvm import te
from tvm import relay
try:
    import tensorflow.compat.v1 as tf
except ImportError:
    import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import variables
try:
    from tensorflow import lite as interpreter_wrapper
except ImportError:
    from tensorflow.contrib import lite as interpreter_wrapper

from tvm.contrib.download import download_testdata
import tvm.relay.testing.tf as tf_testing
from packaging import version as package_version

from PIL import Image
import os

#######################################################################
# Generic run functions for TVM & TFLite
# --------------------------------------
def convert_to_list(x):
    """Return ``x`` itself when it is a list, otherwise a one-element list holding ``x``."""
    return x if isinstance(x, list) else [x]


#######################################################################
# Get a real image for e2e testing
# --------------------------------
def get_real_image(im_height, im_width):
    """Download the elephant test image and return it as a uint8 batch of one.

    Parameters
    ----------
    im_height : int
        Height, in pixels, of the returned image.
    im_width : int
        Width, in pixels, of the returned image.

    Returns
    -------
    numpy.ndarray
        uint8 array reshaped to (1, im_height, im_width, 3).
    """
    repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
    img_name = 'elephant-299.jpg'
    # URLs always use '/' as separator; os.path.join would insert a backslash
    # on Windows, so concatenate directly (repo_base already ends with '/').
    image_url = repo_base + img_name
    img_path = download_testdata(image_url, img_name, module='data')
    # PIL's resize expects (width, height). Passing (height, width) only works
    # for square sizes; otherwise the reshape below would scramble the rows.
    image = Image.open(img_path).resize((im_width, im_height))
    x = np.array(image).astype('uint8')
    data = np.reshape(x, (1, im_height, im_width, 3))
    return data

def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm',
                  out_names=None):
    """Compile a TFLite model buffer with Relay, run it on TVM and return the outputs.

    ``input_data`` and ``input_node`` may each be a single item or a list; the
    i-th array is bound to the i-th node name. The first ``num_output`` outputs
    of the runtime module are returned as a list of numpy arrays. ``out_names``
    is only used to sanity-check ``num_output``.
    """
    try:
        import tflite.Model
    except ImportError:
        raise ImportError("The tflite package must be installed")
    from tvm.contrib import graph_runtime

    # Parse the flatbuffer into a TFLite model object.
    model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)

    input_data = convert_to_list(input_data)
    input_node = convert_to_list(input_node)

    # Describe every input to the frontend by node name.
    shape_dict = {name: input_data[idx].shape for idx, name in enumerate(input_node)}
    dtype_dict = {name: input_data[idx].dtype.name for idx, name in enumerate(input_node)}

    mod, params = relay.frontend.from_tflite(model,
                                             shape_dict=shape_dict,
                                             dtype_dict=dtype_dict)
    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(mod, target, params=params)

    module = graph_runtime.create(graph, lib, tvm.context(target, 0))

    # Feed the inputs and the compiled parameters, then execute.
    for idx, name in enumerate(input_node):
        array = input_data[idx]
        module.set_input(name, tvm.nd.array(array.astype(array.dtype)))
    module.set_input(**params)
    module.run()

    assert out_names is None or num_output == len(out_names), "out_names: {} num_output: {}".format(
        out_names, num_output)
    return [module.get_output(idx).asnumpy() for idx in range(num_output)]


def run_tflite_graph(tflite_model_buf, input_data):
    """Execute a TFLite model buffer with the TFLite interpreter.

    ``input_data`` is a single array or a list of arrays, one per model input
    in the order reported by the interpreter. Returns the list of all output
    tensors.
    """
    input_data = convert_to_list(input_data)

    interpreter = interpreter_wrapper.Interpreter(model_content=tflite_model_buf)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Bind each provided array to the matching interpreter input.
    assert len(input_data) == len(input_details)
    for idx, detail in enumerate(input_details):
        interpreter.set_tensor(detail['index'], input_data[idx])

    interpreter.invoke()

    # Collect every output tensor in declaration order.
    return [interpreter.get_tensor(detail['index']) for detail in output_details]


def compare_tflite_with_tvm(in_data, in_name, input_tensors,
                            output_tensors, init_global_variables=False,
                            out_names=None, quantized=False, input_range=None):
    """Convert the current TF graph to TFLite, run it on TFLite and TVM, and compare outputs.

    Parameters
    ----------
    in_data : numpy.ndarray or list of numpy.ndarray
        Input values, one per input tensor.
    in_name : str or list of str
        Input tensor names; anything after the first ':' is dropped to obtain
        the node name handed to the TVM frontend.
    input_tensors, output_tensors : list
        Tensors passed to ``TFLiteConverter.from_session``.
    init_global_variables : bool
        Run the global variables initializer in the session before converting.
    out_names : str or list of str, optional
        Output names. When given, their count determines how many TVM outputs
        are fetched; when omitted, the number of TFLite outputs is used.
    quantized : bool
        Convert with QUANTIZED_UINT8 inference type and compare with an
        absolute tolerance of 1.
    input_range : dict, optional
        Maps each converter input array name to its (min, max) fp32 range.
        Required when ``quantized`` is True.

    Raises
    ------
    ValueError
        If ``quantized`` is True but ``input_range`` is None.
    ZeroDivisionError
        If min and max of an input range are equal.
    """
    in_data = convert_to_list(in_data)
    in_name = convert_to_list(in_name)
    # Keep None as None: wrapping it would give [None] and pin the number of
    # fetched TVM outputs to one, whatever the model actually produces.
    if out_names is not None:
        out_names = convert_to_list(out_names)
    # "name:0" -> "name"; names without ':' are returned unchanged by split.
    in_node = [name.split(':')[0] for name in in_name]

    if quantized and input_range is None:
        raise ValueError("input_range must be provided when quantized=True")

    with tf.Session() as sess:
        if init_global_variables:
            sess.run(variables.global_variables_initializer())
        # convert to tflite model
        converter = tf.lite.TFLiteConverter.from_session(
            sess, input_tensors, output_tensors)

        if quantized:
            converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
            input_stats = {}
            # calculate the mean and quantization scale for every input tensor,
            # with respect to its fp32 input range, defined in fake_quant.
            # s = 255/(fmax-fmin);  m = -fmin*s (the zero point)
            for name in converter.get_input_arrays():
                fmin = input_range[name][0]
                fmax = input_range[name][1]
                # Explicit check: numpy floats divide to inf instead of raising.
                if fmax == fmin:
                    raise ZeroDivisionError('Min and max of the input range for tensor ' + name + ' can\'t be equal')
                quant_scale = 255 / (fmax - fmin)
                mean = - fmin * quant_scale
                input_stats[name] = (mean, quant_scale)
            converter.quantized_input_stats = input_stats

        tflite_model_buffer = converter.convert()
        tflite_output = run_tflite_graph(tflite_model_buffer, in_data)
        num_output = len(tflite_output) if out_names is None else len(out_names)

        for device in ["llvm"]:
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("Skip because %s is not enabled" % device)
                continue

            tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device,
                                       num_output=num_output, out_names=out_names)

            assert len(tvm_output) == len(tflite_output), \
                "TFLite produced {} outputs but {} were fetched from TVM".format(
                    len(tflite_output), len(tvm_output))

            # WARNING: the results could well be random values clipped to 0 or 255 because of badly tuned output
            # range for the specific operator. While adding test ensure that we aren't getting only clipped values
            # in output tensors that still pass the assertion. For reference see _test_elemwise_qnn_out_range()
            # Quantized results are allowed an absolute tolerance of 1.
            atol = 1 if quantized else 1e-5
            for expected, actual in zip(tflite_output, tvm_output):
                tvm.testing.assert_allclose(expected, actual, atol=atol, rtol=1e-5)


def with_fused_activation_function(input_tensor, fn_name):
    """Apply the activation named ``fn_name`` to ``input_tensor``.

    ``None`` and "NONE" return the tensor untouched. Supported names are
    "RELU", "RELU6", "RELU_N1_TO_1" and "TANH"; any other name raises
    AssertionError.
    """
    if fn_name is None or fn_name == "NONE":
        return input_tensor

    # (name, op) pairs; lambdas keep the op lookup lazy until a name matches.
    activations = (
        ("RELU", lambda t: nn_ops.relu(t)),
        ("RELU6", lambda t: nn_ops.relu6(t)),
        ("RELU_N1_TO_1", lambda t: math_ops.maximum(-1, math_ops.minimum(t, 1))),
        ("TANH", lambda t: math_ops.tanh(t)),
    )
    for name, activation in activations:
        if fn_name == name:
            return activation(input_tensor)
    raise AssertionError("Unknown fused_activation_function {}".format(fn_name))

def _test_split(in_shape, axis, num_Splits, dtype):
    """Split a random tensor of shape ``in_shape`` into ``num_Splits`` pieces along
    ``axis`` and compare the TFLite and TVM results for the given ``dtype``."""
    np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
    out_names = ['out_{}:0'.format(idx) for idx in range(num_Splits)]
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=in_shape, dtype=dtype)
        pieces = array_ops.split(placeholder, num_Splits, axis=axis)
        compare_tflite_with_tvm([np_data], ['Placeholder:0'], [placeholder], pieces,
                                out_names=out_names)

def test_forward_split():
    """Exercise SPLIT on rank 1-4 inputs, including negative axes."""
    # rank 1 and rank 2
    leading_cases = [
        ((3,), 0, 1),
        ((3,), 0, 3),
        ((6,), 0, 3),
        ((6, 2), 0, 3),
        ((2, 6), 1, 6),
    ]
    for in_shape, axis, num_splits in leading_cases:
        _test_split(in_shape, axis, num_splits, 'float32')

    # rank 3; the int32 case only runs on TF >= 1.14.0
    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
        _test_split((6, 2, 4), 0, 2, 'int32')

    trailing_cases = [
        ((2, 6, 4), 1, 3),
        ((2, 4, 6), 2, 1),
        # rank 4
        ((6, 1, 3, 5), 0, 3),
        ((1, 6, 3, 5), 1, 3),
        ((1, 3, 6, 5), 2, 3),
        ((1, 3, 5, 6), 3, 3),
        # split along negative axis
        ((6, 1, 3, 5), -4, 3),
        ((1, 6, 3, 5), -3, 3),
        ((1, 3, 6, 5), -2, 3),
        ((1, 3, 5, 6), -1, 3),
    ]
    for in_shape, axis, num_splits in trailing_cases:
        _test_split(in_shape, axis, num_splits, 'float32')

#######################################################################
# slice
# -----

def _test_slice(data, begin, size):
    """Run SLICE on ``data`` with the given ``begin`` and ``size`` and compare TFLite with TVM."""
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
        sliced = array_ops.slice(placeholder, begin, size)
        compare_tflite_with_tvm(data, 'Placeholder:0', [placeholder], [sliced])

def test_forward_slice():
    """Exercise SLICE with explicit sizes and, on newer TF, with size -1."""
    cases = [
        (np.arange(4, dtype=np.float32).reshape((4, )), [0], [2]),
        (np.arange(18, dtype=np.int32).reshape((3, 2, 3)), [1, 0, 0], [1, 1, 3]),
    ]
    # tflite 1.13 outputs nonsense values if size[i] == -1
    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
        cases += [
            (np.arange(8, dtype=np.int32).reshape((2, 4)), [0, 1], [-1, -1]),
            (np.arange(5, dtype=np.int32).reshape((5, )), [4], [-1]),
        ]
    for data, begin, size in cases:
        _test_slice(data, begin=begin, size=size)

#######################################################################
# Topk
# ----
def _test_topk(in_shape, k=1):
    """Run TOPK with the given ``k`` on random float32 data of shape ``in_shape``.

    Only the first result of ``top_k`` is compared between TFLite and TVM.
    """
    data = np.random.uniform(size=in_shape).astype('float32')
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
        top = nn_ops.top_k(placeholder, k, name='TopK')
        compare_tflite_with_tvm(data, 'Placeholder:0', [placeholder], [top[0]])

def test_forward_topk():
    """Exercise TOPK on rank-1 and rank-3 inputs.

    The rank-3 case is listed twice; each run draws fresh random data.
    """
    for in_shape, k in [((3,), 1), ((3,), 3), ((3, 5, 7), 3), ((3, 5, 7), 3)]:
        _test_topk(in_shape, k)

#######################################################################
# transpose
# ---------


def _test_forward_transpose(ishape, axes=()):
    """Transpose random float32 data of shape ``ishape`` and compare TFLite with TVM.

    An empty ``axes`` uses the default permutation of ``array_ops.transpose``.
    """
    data = np.random.uniform(size=ishape).astype(np.float32)

    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
        transposed = (array_ops.transpose(placeholder, axes) if axes
                      else array_ops.transpose(placeholder))
        compare_tflite_with_tvm(data, 'Placeholder:0', [placeholder], [transposed])


def test_forward_transpose():
    """Exercise TRANSPOSE with default and explicit axis permutations."""
    # default permutation
    for ishape in [(2, 2), (2, 3, 4), (7, 8, 8, 10)]:
        _test_forward_transpose(ishape)

    # explicit permutations (the last one is empty, i.e. default again)
    explicit_cases = [
        ((2, 3, 4), (1, 2, 0)),
        ((2, 3, 4), (0, 1, 2)),
        ((2, 3, 4, 5), (3, 0, 1, 2)),
        ((2, 3, 4, 5), ()),
    ]
    for ishape, axes in explicit_cases:
        _test_forward_transpose(ishape, axes)

#######################################################################
# Cast
# ----

def _test_cast(data, cast_dtype):
    """Cast ``data`` to ``cast_dtype`` and compare the TFLite and TVM results."""
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
        casted = math_ops.cast(placeholder, cast_dtype)
        compare_tflite_with_tvm(data, 'Placeholder:0', [placeholder], [casted])


def test_forward_cast():
    """Exercise CAST for float32 -> int32, float32 -> uint8 and int32 -> int64."""
    conversions = [
        (np.float32, tf.int32),
        (np.float32, tf.uint8),
        (np.int32, tf.int64),
    ]
    for src_dtype, dst_dtype in conversions:
        _test_cast(np.arange(6.0, dtype=src_dtype).reshape((1, 6)), cast_dtype=dst_dtype)

#######################################################################
# Tile
# ----


def _test_forward_tile(in_shape, reps, dtype):
    """Tile random data of shape ``in_shape`` by ``reps`` and compare TFLite with TVM."""
    data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)

    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
        tiled = array_ops.tile(placeholder, reps)
        compare_tflite_with_tvm(data, 'Placeholder:0', [placeholder], [tiled])


def test_forward_tile():
    """Exercise TILE on a rank-1 int32 and a rank-2 float32 input."""
    for in_shape, reps, dtype in [((2, ), (3, ), "int32"), ((2, 2), (2, 3), "float32")]:
        _test_forward_tile(in_shape, reps, dtype)

######################################################################
# BatchToSpaceND
# --------------


def _test_batch_to_space_nd(input_shape, block_shape, crops, dtype='int32'):
    """Run BATCH_TO_SPACE_ND on random data in [0, 5) and compare TFLite with TVM."""
    data = np.random.uniform(0, 5, size=input_shape).astype(dtype)

    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=input_shape, dtype=dtype)
        result = array_ops.batch_to_space_nd(placeholder, block_shape, crops)
        compare_tflite_with_tvm(data, 'Placeholder:0', [placeholder], [result])


def test_forward_batch_to_space_nd():
    """Exercise BATCH_TO_SPACE_ND with 2x2 blocks and no cropping."""
    # test cases: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d
    for input_shape in ([4, 1, 1, 1], [4, 1, 1, 3], [4, 2, 2, 1]):
        _test_batch_to_space_nd(
            input_shape=input_shape,
            block_shape=[2, 2],
            crops=[[0, 0], [0, 0]]
        )

######################################################################
# SpaceToBatchND
# --------------


def _test_space_to_batch_nd(input_shape, block_shape, paddings, dtype='int32'):
    """Run SPACE_TO_BATCH_ND on random data in [0, 5) and compare TFLite with TVM."""
    data = np.random.uniform(0, 5, size=input_shape).astype(dtype)

    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=input_shape, dtype=dtype)
        result = array_ops.space_to_batch_nd(placeholder, block_shape, paddings)
        compare_tflite_with_tvm(data, 'Placeholder:0', [placeholder], [result])


def test_forward_space_to_batch_nd():
    """Exercise SPACE_TO_BATCH_ND with 2x2 blocks, with and without padding."""
    # test cases: https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd
    cases = [
        ([1, 2, 2, 1], [[0, 0], [0, 0]]),
        ([1, 2, 2, 3], [[0, 0], [0, 0]]),
        ([1, 4, 4, 1], [[0, 0], [0, 0]]),
        ([2, 2, 4, 1], [[0, 0], [2, 0]]),
    ]
    for input_shape, paddings in cases:
        _test_space_to_batch_nd(
            input_shape=input_shape,
            block_shape=[2, 2],
            paddings=paddings
        )

#######################################################################
# Pooling
# -------
def _test_pooling_iteration(input_shape, **kwargs):
    """Run ``nn_ops.pool`` once with ``kwargs`` and compare TFLite with TVM.

    The input is a float32 ramp of negative values: -1, -2, ... reshaped to
    ``input_shape``.
    """
    ramp = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape)
    x = -ramp - 1

    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=input_shape, dtype='float32')
        pooled = nn_ops.pool(placeholder, **kwargs)
        compare_tflite_with_tvm(x, 'Placeholder:0', [placeholder], [pooled])


def _test_pooling(input_shape, **kwargs):
    """Run one pooling comparison; forwards everything to ``_test_pooling_iteration``."""
    _test_pooling_iteration(input_shape=input_shape, **kwargs)


def test_forward_pooling():
    """Exercise AVG and MAX pooling with SAME padding over several window/stride shapes."""
    # (input_shape, window_shape, strides)
    cases = (
        ((2, 9, 10, 2), (1, 1), (1, 1)),
        ((2, 10, 9, 2), (1, 1), (1, 1)),
        ((2, 9, 10, 2), (2, 1), (1, 1)),
        ((2, 10, 9, 2), (2, 3), (2, 1)),
    )
    for pool_type in ['AVG', 'MAX']:
        for input_shape, window_shape, strides in cases:
            # Fresh lists per call, as each call site originally built its own.
            _test_pooling(input_shape=list(input_shape),
                          window_shape=list(window_shape),
                          padding='SAME',
                          pooling_type=pool_type,
                          dilation_rate=[1, 1],
                          strides=list(strides))


#######################################################################
# Convolution
# -----------

def _test_convolution(tensor_in_sizes, filter_in_sizes,
                      dilations, strides, padding, data_format,
                      is_depthwise=False):
    """Run one 2D convolution through TFLite and TVM and compare the results.

    Parameters
    ----------
    tensor_in_sizes : list of int
        Shape of the input tensor.
    filter_in_sizes : list of int
        Shape of the constant filter tensor.
    dilations : list of int
        Two spatial dilation factors; expanded to four entries with 1 on
        both ends before being passed to the op.
    strides : list of int
        Two spatial strides; expanded the same way as ``dilations``.
    padding : str
        Padding mode handed to the TensorFlow op.
    data_format : str
        Data layout handed to the TensorFlow op.
    is_depthwise : bool
        Use ``depthwise_conv2d_native`` instead of ``conv2d``.
    """
    total_size_1 = int(np.prod(tensor_in_sizes))
    total_size_2 = int(np.prod(filter_in_sizes))
    # Input and filter are filled with incrementing numbers starting from 1.
    data_array = np.arange(1, total_size_1 + 1, dtype='float64').reshape(
        tensor_in_sizes).astype('float32')
    filter_array = [float(f) for f in range(1, total_size_2 + 1)]

    with tf.Graph().as_default():
        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
        in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')
        strides = [1] + strides + [1]
        dilations = [1] + dilations + [1]

        # dilations was previously computed but never forwarded, so the
        # parameter had no effect on the op under test.
        conv_op = nn_ops.depthwise_conv2d_native if is_depthwise else nn_ops.conv2d
        out = conv_op(in_data,
                      in_filter,
                      strides=strides,
                      padding=padding,
                      data_format=data_format,
                      dilations=dilations)
        compare_tflite_with_tvm(data_array, 'Placeholder:0', [in_data], [out])


def test_forward_convolution():
    """Exercise regular and depthwise NHWC convolutions with SAME and VALID padding."""
    # (tensor_in_sizes, filter_in_sizes, strides, padding)
    regular_cases = [
        ([4, 8, 8, 176], [1, 1, 176, 32], [1, 1], 'SAME'),
        ([4, 17, 17, 19], [3, 3, 19, 19], [2, 2], 'VALID'),
        ([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], 'SAME'),
        ([4, 17, 17, 12], [3, 3, 12, 32], [2, 2], 'VALID'),
    ]
    for tensor_in_sizes, filter_in_sizes, strides, padding in regular_cases:
        _test_convolution(tensor_in_sizes, filter_in_sizes, [1, 1], strides, padding, 'NHWC')

    depthwise_cases = [
        # depthwise convolution
        ([4, 8, 8, 176], [1, 1, 176, 1], [1, 1], 'SAME'),
        ([4, 17, 17, 19], [3, 3, 19, 1], [2, 2], 'VALID'),
        ([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], 'SAME'),
        ([4, 17, 17, 12], [3, 3, 12, 1], [2, 2], 'VALID'),
        ([4, 17, 17, 12], [3, 3, 12, 2], [2, 2], 'VALID'),
        # depthwise convolution with single input channel
        ([1, 76, 64, 1], [9, 5, 1, 96], [1, 1], 'SAME'),
    ]
    for tensor_in_sizes, filter_in_sizes, strides, padding in depthwise_cases:
        _test_convolution(tensor_in_sizes, filter_in_sizes, [1, 1], strides, padding, 'NHWC', True)


#######################################################################
# Transpose Convolution
# ---------------------

def _test_transpose_conv(tensor_in_sizes, filter_in_sizes, output_shape, strides, padding):
    """Run one transpose convolution through TFLite and TVM and compare the results.

    Input and filter are filled with incrementing float values starting from 1.
    ``strides`` holds the two spatial strides and is expanded to four entries.
    """
    input_count = int(np.prod(tensor_in_sizes))
    filter_count = int(np.prod(filter_in_sizes))
    input_values = [float(v) for v in range(1, input_count + 1)]
    filter_values = [float(v) for v in range(1, filter_count + 1)]
    data_array = np.reshape(input_values, tensor_in_sizes).astype('float32')

    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
        # in_filter layout is HWOI
        kernel = constant_op.constant(filter_values, shape=filter_in_sizes, dtype='float32')
        full_strides = [1] + strides + [1]
        result = nn_ops.conv2d_transpose(placeholder,
                                         kernel,
                                         output_shape=output_shape,
                                         strides=full_strides,
                                         padding=padding)
        compare_tflite_with_tvm(data_array, 'Placeholder:0', [placeholder], [result])


def test_forward_transpose_conv():
    """Exercise TRANSPOSE_CONV over several kernel sizes, strides and paddings."""
    # (tensor_in_sizes, filter_in_sizes, output_shape, strides, padding)
    cases = [
        # kernel 3x3, padding VALID
        ([4, 32, 32, 16], [3, 3, 5, 16], [4, 34, 34, 5], [1, 1], 'VALID'),
        ([1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 65, 5], [2, 2], 'VALID'),
        ([1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 34, 5], [2, 1], 'VALID'),
        # kernel 2x2, padding VALID
        ([4, 32, 32, 16], [2, 2, 5, 16], [4, 33, 33, 5], [1, 1], 'VALID'),
        ([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 64, 5], [2, 2], 'VALID'),
        ([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 33, 5], [2, 1], 'VALID'),
        # kernel 1x1, padding VALID
        ([4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], 'VALID'),
        ([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], 'VALID'),
        ([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], 'VALID'),
        # kernel 1x1, padding SAME
        ([4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], 'SAME'),
        ([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], 'SAME'),
        ([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], 'SAME'),
    ]
    for tensor_in_sizes, filter_in_sizes, output_shape, strides, padding in cases:
        _test_transpose_conv(tensor_in_sizes, filter_in_sizes, output_shape, strides, padding)


#######################################################################
# Reshape
# -------

def _test_reshape(data, out_shape):
    """Reshape ``data`` to ``out_shape`` and compare the TFLite and TVM results."""
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
        reshaped = array_ops.reshape(placeholder, out_shape)
        compare_tflite_with_tvm(data, 'Placeholder:0', [placeholder], [reshaped])


def test_forward_reshape():
    """Exercise RESHAPE with fully specified and inferred (-1) dimensions."""
    _test_reshape(np.arange(6.0, dtype=np.float32), [2, 3])
    for out_shape in ([-1, 2], [3, -1], [-1]):
        _test_reshape(np.arange(6), out_shape)


#######################################################################
# Resize
# ------

def _test_resize(tf_resize_op, data, align_corners):
    """Run one resize op and compare TFLite with TVM.

    ``data`` is a pair: ``data[0]`` is fed as the image tensor and ``data[1]``
    is embedded in the graph as the constant target size.
    """
    assert len(data) == 2
    images, target_size = data[0], data[1]

    # Test with tensor and constant
    with tf.Graph().as_default():
        images_tensor = array_ops.placeholder(shape=images.shape, dtype=images.dtype, name='in')
        size_tensor = ops.convert_to_tensor(target_size, dtype=target_size.dtype)
        resized = tf_resize_op(images=images_tensor, size=size_tensor, align_corners=align_corners)
        compare_tflite_with_tvm([images], ['in:0'], [images_tensor], [resized])


def test_all_resize():
    """Exercise RESIZE_BILINEAR and, when the schema has it, RESIZE_NEAREST_NEIGHBOR."""
    from tflite.BuiltinOperator import BuiltinOperator

    data = [np.random.rand(1, 16, 16, 3).astype("float32"), np.array([8, 8], dtype=np.int32)]

    ### RESIZE_BILINEAR
    for align_corners in (False, True):
        _test_resize(tf.image.resize_bilinear, data, align_corners=align_corners)

    ### RESIZE_NEAREST_NEIGHBOR (was added in v1.13)
    # According to topi resize.h
    # Align corners not supported for nearest neighbour
    if 'RESIZE_NEAREST_NEIGHBOR' in dir(BuiltinOperator()):
        _test_resize(tf.image.resize_nearest_neighbor, data, align_corners=False)


#######################################################################
# Concatenation
# -------------

def _test_concatenation(data, axis):
    """Concatenate the arrays in ``data`` along ``axis`` and compare TFLite with TVM.

    Each array gets its own placeholder named ``in_<index>``.
    """
    assert len(data) >= 1

    with tf.Graph().as_default():
        placeholders = []
        tensor_names = []
        for idx, tensor in enumerate(data):
            placeholders.append(
                array_ops.placeholder(shape=tensor.shape, dtype=tensor.dtype,
                                      name="in_{}".format(idx)))
            tensor_names.append("in_{}:0".format(idx))
        joined = array_ops.concat(placeholders, axis=axis)

        compare_tflite_with_tvm(data, tensor_names, placeholders, [joined])


def test_forward_concatenation():
    """Exercise CONCATENATION along axis 1 with two and three identical inputs."""
    # (shape of every input, number of inputs)
    cases = [
        ((1, 2, 1, 3), 2),
        ((3, 2), 2),
        ((2, 1, 1, 3), 3),
    ]
    for shape, count in cases:
        _test_concatenation([np.arange(6).reshape(shape) for _ in range(count)], 1)

#######################################################################
# Unary elemwise
# --------------

def _test_unary_elemwise(math_op, data):
    """Apply the unary ``math_op`` to ``data`` and compare the TFLite and TVM results."""
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name='in')
        result = math_op(placeholder)
        compare_tflite_with_tvm(data, ['in:0'], [placeholder], [result])

#######################################################################
# Abs
# ---

def _test_abs(data):
    """Run the unary element-wise check with ``math_ops.abs`` on ``data``."""
    return _test_unary_elemwise(math_op=math_ops.abs, data=data)
#######################################################################
# Ceil
# ----

def _test_ceil(data):
    """Run the unary element-wise check with ``math_ops.ceil`` on ``data``."""
    return _test_unary_elemwise(math_op=math_ops.ceil, data=data)
#######################################################################
# Floor
# -----

def _test_floor(data):
    """Run the unary element-wise check with ``math_ops.floor`` on ``data``."""
    return _test_unary_elemwise(math_op=math_ops.floor, data=data)

#######################################################################
# Round
# -----

def _test_round(data):
    """Run the unary element-wise check with ``math_ops.round`` on ``data``."""
    return _test_unary_elemwise(math_op=math_ops.round, data=data)

#######################################################################
# Exp
# ---

def _test_exp(data):
    """Run the unary element-wise check with ``math_ops.exp`` on ``data``."""
    return _test_unary_elemwise(math_op=math_ops.exp, data=data)
#######################################################################
# Log
# ---

def _test_log(data):
    """Run the unary element-wise check with ``math_ops.log`` on ``data``."""
    return _test_unary_elemwise(math_op=math_ops.log, data=data)
#######################################################################
# Sin
# ---

def _test_sin(data):
    """Run the unary element-wise check with ``math_ops.sin`` on ``data``."""
    return _test_unary_elemwise(math_op=math_ops.sin, data=data)
#######################################################################
# Cos
# ---

def _test_cos(data):
    """Run the unary element-wise check with ``math_ops.cos`` on ``data``."""
    return _test_unary_elemwise(math_op=math_ops.cos, data=data)
#######################################################################
# Tan
# ---

def _test_tan(data):
    """Run the unary element-wise check with ``math_ops.tan`` on ``data``."""
    return _test_unary_elemwise(math_op=math_ops.tan, data=data)
#######################################################################
# Sqrt
# ----

def _test_sqrt(data):
    """Run the unary element-wise check with ``math_ops.sqrt`` on ``data``."""
    return _test_unary_elemwise(math_op=math_ops.sqrt, data=data)
#######################################################################
# Rsqrt
# -----

def _test_rsqrt(data):
    """Run the unary element-wise check with ``math_ops.rsqrt`` on ``data``."""
    return _test_unary_elemwise(math_op=math_ops.rsqrt, data=data)
#######################################################################
# Neg
# ---

def _test_neg(data):
    """Run the unary element-wise check with ``math_ops.neg`` on ``data``."""
    return _test_unary_elemwise(math_op=math_ops.neg, data=data)
#######################################################################
# Square
# ------

def _test_square(data):
    """Run the unary element-wise check with ``math_ops.square`` on ``data``."""
    return _test_unary_elemwise(math_op=math_ops.square, data=data)

#######################################################################
# Elu
# ---

def _test_elu(data):
    """Run the unary element-wise check with ``nn_ops.elu`` on ``data``."""
    return _test_unary_elemwise(math_op=nn_ops.elu, data=data)

def _test_forward_unary_elemwise(test_op):
    """Call ``test_op`` with float32 data suited to the operator.

    Operators selected by name (log, sqrt, rsqrt) get strictly positive
    values 1..6; every other operator gets random values in [-10, 10).
    """
    # functions that need positive input
    needs_positive = test_op.__name__ in {'_test_log', '_test_sqrt', '_test_rsqrt'}
    if needs_positive:
        data = np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3))
    else:
        data = np.random.uniform(-10, 10, (3, 2)).astype(np.float32)
    test_op(data)

def test_all_unary_elemwise():
    """Run every unary element-wise operator test; newer ops only on TF >= 1.14.0."""
    always_available = (
        _test_abs,
        _test_floor,
        _test_exp,
        _test_log,
        _test_sin,
        _test_sqrt,
        _test_rsqrt,
        _test_neg,
        _test_square,
    )
    for test_op in always_available:
        _test_forward_unary_elemwise(test_op)

    # ceil and cos come with TFLite 1.14.0.post1 fbs schema
    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
        for test_op in (_test_ceil, _test_cos, _test_round, _test_tan, _test_elu):
            _test_forward_unary_elemwise(test_op)

#######################################################################
# Element-wise
# ------------

def _test_elemwise(math_op, data, fused_activation_function=None, quantized=False, qnn_op=None):
    """ One iteration of elemwise

    Builds three separate graphs around `math_op` and hands each one to
    `compare_tflite_with_tvm`:

    1. both operands are placeholders,
    2. the first operand is a placeholder and the second is a constant,
    3. the first operand is a constant and the second is a placeholder.

    Parameters
    ----------
    math_op : callable
        Binary op called with the two operands.
    data : sequence of length 2
        The two operand arrays; their `.shape` (and, in the float constant
        cases, `.dtype`) are read here.
    fused_activation_function : optional
        Forwarded unchanged to `with_fused_activation_function`.
    quantized : bool
        When true, operands are wrapped in fake-quant ops with fixed ranges
        (-100, 100) for the first and (-50, 50) for the second operand, and
        the output is fake-quantized as well.
    qnn_op : optional
        Key passed to `_test_elemwise_qnn_out_range` to obtain the output
        range; only read when `quantized` is true.
    """

    assert len(data) == 2

    # Test with two tensors
    with tf.Graph().as_default():
        in_data = [array_ops.placeholder(shape=data[0].shape, dtype='float32', name='in_0'),
                   array_ops.placeholder(shape=data[1].shape, dtype='float32', name='in_1')]

        if quantized:
            # fake_quant will keep the tensors in float32 until the conversion in the session
            inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-100, max=100, name="inq_0"),
                        tf.quantization.fake_quant_with_min_max_args(in_data[1], min=-50, max=50, name="inq_1")]
            # keys match the fake-quant op names given just above
            input_range = {'inq_0': (-100, 100), 'inq_1': (-50, 50)}
            out = math_op(inq_data[0], inq_data[1])
            out = with_fused_activation_function(out, fused_activation_function)
            # set the fp32 output range with respect to the operation
            out_min, out_max = _test_elemwise_qnn_out_range(qnn_op)
            out = tf.quantization.fake_quant_with_min_max_args(out, min=out_min, max=out_max, name="out")
            compare_tflite_with_tvm(data, ['inq_0:0', 'inq_1:0'], inq_data, [out],
                                    quantized=True, input_range=input_range)
        else:
            out = math_op(in_data[0], in_data[1])
            out = with_fused_activation_function(out, fused_activation_function)
            compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out])

    # Test with tensor and constant
    with tf.Graph().as_default():
        in_data = [array_ops.placeholder(shape=data[0].shape, dtype='float32', name='in_0')]

        if quantized:
            inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-100, max=100, name="inq_0")]
            inq_const = tf.quantization.fake_quant_with_min_max_args(data[1], min=-50, max=50, name="const_tensor")
            input_range = {'inq_0': (-100, 100)}
            # the 2nd tensor is treated as constant and directly added as part of the operation
            # NOTE(review): `inq_data` (a one-element list) is passed here, not `inq_data[0]`
            # as in the two-tensor case above -- presumably TF converts the list into a
            # tensor with an extra leading axis; confirm before changing, since it alters
            # the graph under test.
            out = math_op(inq_data, ops.convert_to_tensor(inq_const, dtype='float32', name='inq_const'))
            out = with_fused_activation_function(out, fused_activation_function)
            out_min, out_max = _test_elemwise_qnn_out_range(qnn_op)
            out = tf.quantization.fake_quant_with_min_max_args(out, min=out_min, max=out_max, name="out")
            compare_tflite_with_tvm(data[0], ['inq_0:0'], inq_data, [out], quantized=True, input_range=input_range)
        else:
            out = math_op(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
            out = with_fused_activation_function(out, fused_activation_function)
            compare_tflite_with_tvm(data[0], ['in_0:0'], in_data, [out])

    # Test with constant and tensor
    with tf.Graph().as_default():
        in_data = [array_ops.placeholder(shape=data[1].shape, dtype='float32', name='in_1')]

        if quantized:
            inq_const = tf.quantization.fake_quant_with_min_max_args(data[0], min=-100, max=100, name="const_tensor")
            inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-50, max=50, name="inq_1")]
            input_range = {'inq_1': (-50, 50)}
            # the 1st tensor is treated as constant and directly added as part of the operation
            # NOTE(review): as in the tensor-and-constant case, the list `inq_data` is passed
            # rather than `inq_data[0]` -- confirm this is intended.
            out = math_op(ops.convert_to_tensor(inq_const, dtype='float32', name='inq_const'), inq_data)
            out = with_fused_activation_function(out, fused_activation_function)
            out_min, out_max = _test_elemwise_qnn_out_range(qnn_op)
            out = tf.quantization.fake_quant_with_min_max_args(out, min=out_min, max=out_max, name="out")
            compare_tflite_with_tvm(data[1], ['inq_1:0'], inq_data, [out], quantized=True, input_range=input_range)
        else:
            out = math_op(ops.convert_to_tensor(data[0], dtype=data[0].dtype), in_data[0])
            out = with_fused_activation_function(out, fused_activation_function)
            compare_tflite_with_tvm(data[1], ['in_1:0'], in_data, [out])

#######################################################################
# Add
# ---

def _test_add(data, fused_activation_function=None, quantized=False, qnn_op=None):
    """ Run a single add check by forwarding `math_ops.add` and all options to `_test_elemwise` """
    forwarded = (data, fused_activation_function, quantized, qnn_op)
    return _test_elemwise(math_ops.add, *forwarded)

#######################################################################
# Subtract
# --------

def _test_sub(data, fused_activation_function=None, quantized=False, qnn_op=None):
    """ Run a single subtract check by forwarding `math_ops.subtract` and all options to `_test_elemwise` """
    forwarded = (data, fused_activation_function, quantized, qnn_op)
    return _test_elemwise(math_ops.subtract, *forwarded)
#######################################################################
# Mul
# ---

def _test_mul(data, fused_activation_function=None, quantized=False, qnn_op=None):
    """ Run a single mul check by forwarding `math_ops.multiply` and all options to `_test_elemwise` """
    forwarded = (data, fused_activation_function, quantized, qnn_op)
    return _test_elemwise(math_ops.multiply, *forwarded)

#######################################################################
# Divide
# ------

def _test_div(data, fused_activation_function=None):
    """ Run a single divide check by forwarding `math_ops.divide` to `_test_elemwise` """
    forwarded = (data, fused_activation_function)
    return _test_elemwise(math_ops.divide, *forwarded)
#######################################################################
# Power
# -----

def _test_pow(data):
    """ Run a single power check by forwarding `math_ops.pow` to `_test_elemwise` """
    pow_op = math_ops.pow
    return _test_elemwise(pow_op, data)
#######################################################################
# Maximum
# -------

def _test_maximum(data):
    """ Run a single maximum check by forwarding `math_ops.maximum` to `_test_elemwise` """
    maximum_op = math_ops.maximum
    return _test_elemwise(maximum_op, data)
#######################################################################
# Minimum
# -------

def _test_minimum(data):
    """ Run a single minimum check by forwarding `math_ops.minimum` to `_test_elemwise` """
    minimum_op = math_ops.minimum
    return _test_elemwise(minimum_op, data)
#######################################################################
# Greater
# -------

def _test_greater(data):
    """ Run a single greater check by forwarding `math_ops.greater` to `_test_elemwise` """
    greater_op = math_ops.greater
    return _test_elemwise(greater_op, data)
#######################################################################
# Greater_equal
# -------------

def _test_greater_equal(data):
    """ Run a single greater_equal check by forwarding `math_ops.greater_equal` to `_test_elemwise` """
    greater_equal_op = math_ops.greater_equal
    return _test_elemwise(greater_equal_op, data)
#######################################################################
# Less
# ----

def _test_less(data):
    """ Run a single less check by forwarding `math_ops.less` to `_test_elemwise` """
    less_op = math_ops.less
    return _test_elemwise(less_op, data)
#######################################################################
# Less_equal
# ----------

def _test_less_equal(data):
    """ Run a single less_equal check by forwarding `math_ops.less_equal` to `_test_elemwise` """
    less_equal_op = math_ops.less_equal
    return _test_elemwise(less_equal_op, data)
#######################################################################
# Equal
# -----

def _test_equal(data):
    """ Run a single equal check by forwarding `math_ops.equal` to `_test_elemwise` """
    equal_op = math_ops.equal
    return _test_elemwise(equal_op, data)
#######################################################################
# Not_equal
# ---------

def _test_not_equal(data):
    """ Run a single not_equal check by forwarding `math_ops.not_equal` to `_test_elemwise` """
    not_equal_op = math_ops.not_equal
    return _test_elemwise(not_equal_op, data)
#######################################################################
# Squared_difference
# ------------------

def _test_squared_difference(data):
    """ Run a single squared-difference check by forwarding `math_ops.squared_difference` to `_test_elemwise` """
    squared_difference_op = math_ops.squared_difference
    return _test_elemwise(squared_difference_op, data)

#######################################################################
# Floor_divide
# ------------

def _test_floor_divide(data):
    """ Run a single floor_div check by forwarding `math_ops.floordiv` to `_test_elemwise` """
    floordiv_op = math_ops.floordiv
    return _test_elemwise(floordiv_op, data)

#######################################################################
# Floor_mod
# ---------

def _test_floor_mod(data):
    """ Run a single floor_mod check by forwarding `math_ops.floormod` to `_test_elemwise` """
    floormod_op = math_ops.floormod
    return _test_elemwise(floormod_op, data)

def _test_forward_elemwise(testop):
    """ Elemwise: call `testop` with a pair of float32 operands for each of three shapes.

    For every shape the first operand counts up from 0.0 and the second
    from 1.0, so the second operand never contains a zero.
    """
    for shape in ((2, 1, 1, 3), (2, 1, 3), (1, 3)):
        count = 1
        for dim in shape:
            count *= dim
        lhs = np.arange(float(count), dtype=np.float32).reshape(shape)
        rhs = np.arange(1.0, float(count) + 1.0, dtype=np.float32).reshape(shape)
        testop([lhs, rhs])

def _test_forward_elemwise_quantized(testop):
    """ Call `testop` once in quantized mode with two random uint8 (3, 6) operands.

    `testop` itself is passed as `qnn_op`.
    """
    operands = [np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8)
                for _ in range(2)]
    testop(operands, quantized=True, qnn_op=testop)

def _test_elemwise_qnn_out_range(qnn_op):
    """ Return the (min, max) fake_quant output range registered for `qnn_op`.

    The ranges are chosen with respect to the float32 ranges of the input
    tensors. Only `_test_add`, `_test_sub` and `_test_mul` are registered;
    any other key raises KeyError.
    """
    additive_range = (-150, 150)
    range_by_op = dict([
        (_test_add, additive_range),
        (_test_sub, additive_range),
        (_test_mul, (-5e+3, 5e+3)),
    ])
    out_range = range_by_op[qnn_op]
    return out_range

def test_all_elemwise():
    """ Drive every binary element-wise op test, including fused-activation and quantized variants """
    fused_activations = ("RELU", "RELU6")

    # add/sub/mul: float, quantized, then each fused activation
    for arith_test in (_test_add, _test_sub, _test_mul):
        _test_forward_elemwise(arith_test)
        _test_forward_elemwise_quantized(arith_test)
        for activation in fused_activations:
            _test_forward_elemwise(partial(arith_test, fused_activation_function=activation))

    # div has no quantized variant here
    _test_forward_elemwise(_test_div)
    for activation in fused_activations:
        _test_forward_elemwise(partial(_test_div, fused_activation_function=activation))

    plain_tests = (_test_pow, _test_maximum, _test_minimum, _test_greater,
                   _test_squared_difference, _test_greater_equal, _test_less,
                   _test_less_equal, _test_equal, _test_not_equal)
    for plain_test in plain_tests:
        _test_forward_elemwise(plain_test)

    if package_version.parse(tf.VERSION) < package_version.parse('1.14.0'):
        return
    for floor_test in (_test_floor_divide, _test_floor_mod):
        _test_forward_elemwise(floor_test)

#######################################################################
# Logical operators
# -----------------

def _test_logical_binary(logical_bin_op, data):
    """ Build a graph applying `logical_bin_op` to two bool placeholders and pass it to `compare_tflite_with_tvm` """
    with tf.Graph().as_default():
        lhs = array_ops.placeholder(shape=data[0].shape, dtype='bool', name='in_0')
        rhs = array_ops.placeholder(shape=data[1].shape, dtype='bool', name='in_1')
        result = logical_bin_op(lhs, rhs, name='out')
        compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], [lhs, rhs], [result])

def _test_forward_logical_and(data):
    """ Run a single logical-and check by forwarding `math_ops.logical_and` to `_test_logical_binary` """
    and_op = math_ops.logical_and
    return _test_logical_binary(and_op, data)

def _test_forward_logical_or(data):
    """ Run a single logical-or check by forwarding `math_ops.logical_or` to `_test_logical_binary` """
    or_op = math_ops.logical_or
    return _test_logical_binary(or_op, data)

def test_all_logical():
    """ Run the logical and/or tests on a pair of random (2, 3, 4) bool arrays """
    operands = [np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool')
                for _ in range(2)]
    # boolean dtype is not supported by older versions than TFLite 1.15.0
    if package_version.parse(tf.VERSION) < package_version.parse('1.15.0'):
        return
    for logical_test in (_test_forward_logical_and, _test_forward_logical_or):
        logical_test(operands)

#######################################################################
# Zeros like
# ----------

def _test_zeros_like(data):
    """ Build a single ZEROS LIKE graph shaped like `data` and pass it to `compare_tflite_with_tvm` """
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(dtype=data.dtype, shape=data.shape)
        zeros = gen_array_ops.zeros_like(placeholder)
        # the placeholder is unnamed, so it is addressed by the name 'Placeholder:0'
        compare_tflite_with_tvm(data, 'Placeholder:0', [placeholder], [zeros])

def test_forward_zeros_like():
    """ ZEROS LIKE on a float32 (1, 6) input """
    sample = np.arange(6.0, dtype=np.float32).reshape((1, 6))
    _test_zeros_like(sample)

#######################################################################
# Reduce
# ------

def _test_reduce(math_op, data, keep_dims=None):
    """ Build one reduce graph: `math_op` over a placeholder for data[0] with data[1] as constant second argument.

    `keep_dims` is passed through to `math_op` as its third positional
    argument. The graph is handed to `compare_tflite_with_tvm`.
    """

    assert len(data) == 2
    tensor = data[0]
    reduce_arg = data[1]

    # Test with tensor and constant
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=tensor.shape, dtype=tensor.dtype, name='in')
        reduced = math_op(placeholder, reduce_arg, keep_dims)
        compare_tflite_with_tvm([tensor], ['in:0'], [placeholder], [reduced])

def _test_reduce_quantize(math_op, data, keep_dims=None):
    """ One iteration of reduce with fake-quantized input and output

    Parameters
    ----------
    math_op : callable
        Reduce op, called as `math_op(input, data[1], keep_dims)`.
    data : sequence of length 2
        data[0] is the input array (its shape sizes the float32
        placeholder); data[1] is passed to `math_op` as a constant.
    keep_dims : optional
        Forwarded to `math_op` as its third positional argument.
    """

    assert len(data) == 2

    # Test with tensor and constant
    with tf.Graph().as_default():
        in_data = [array_ops.placeholder(shape=data[0].shape, dtype="float32", name='in')]
        # input is fake-quantized to the fixed range (-100, 100)
        inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-100, max=100, name="inq_0")]
        input_range = {'inq_0': (-100, 100)}
        # NOTE(review): the one-element list `inq_data` is passed, not `inq_data[0]`.
        # Presumably TF turns the list into a tensor with an extra leading axis, which
        # would be what makes axes [1, 2] usable with the 2-D input built in
        # `_test_forward_reduce_quantized` -- confirm before "fixing" this.
        out = math_op(inq_data, data[1], keep_dims)
        # output is fake-quantized to the fixed range (-200, 200)
        out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out")
        compare_tflite_with_tvm([data[0]], ['inq_0:0'], [inq_data[0]], [out], quantized=True, input_range=input_range)


#######################################################################
# Reduce_min
# ----------

def _test_reduce_min(data, keep_dims=None):
    """ Run a single reduce_min check by forwarding `math_ops.reduce_min` to `_test_reduce` """
    reduce_op = math_ops.reduce_min
    return _test_reduce(reduce_op, data, keep_dims)

#######################################################################
# Reduce_max
# ----------

def _test_reduce_max(data, keep_dims=None):
    """ Run a single reduce_max check by forwarding `math_ops.reduce_max` to `_test_reduce` """
    reduce_op = math_ops.reduce_max
    return _test_reduce(reduce_op, data, keep_dims)

#######################################################################
# Reduce_mean
# -----------

def _test_reduce_mean(data, keep_dims=None, quantized=False):
    """ Run a single reduce_mean check, through the quantized helper when `quantized` is set """
    run_reduce = _test_reduce_quantize if quantized else _test_reduce
    return run_reduce(math_ops.reduce_mean, data, keep_dims)

#######################################################################
# Reduce_prod
# -----------

def _test_reduce_prod(data, keep_dims=None):
    """ Run a single reduce_prod check by forwarding `math_ops.reduce_prod` to `_test_reduce` """
    reduce_op = math_ops.reduce_prod
    return _test_reduce(reduce_op, data, keep_dims)

#######################################################################
# Reduce_sum
# ----------

def _test_reduce_sum(data, keep_dims=None):
    """ Run a single reduce_sum check by forwarding `math_ops.reduce_sum` to `_test_reduce` """
    reduce_op = math_ops.reduce_sum
    return _test_reduce(reduce_op, data, keep_dims)

#######################################################################
# Reduce_any
# ----------

def _test_reduce_any(data, keep_dims=None):
    """ Run a single reduce_any check by forwarding `math_ops.reduce_any` to `_test_reduce` """
    reduce_op = math_ops.reduce_any
    return _test_reduce(reduce_op, data, keep_dims)

def _test_forward_reduce(testop, dtype="float32"):
    """ Reduce: call `testop` on two random (16, 16, 16, 16) inputs with every keep_dims variant.

    The first input is paired with `None` and the second with the int32
    array [1, 2]. Each pair is run three times: without `keep_dims`, with
    `keep_dims=False` and with `keep_dims=True`. For dtype 'bool' the
    inputs are random True/False choices, otherwise uniform random numbers
    cast to `dtype`.
    """
    shape = (16, 16, 16, 16)
    if dtype == 'bool':
        def random_input():
            return np.random.choice(a=[False, True], size=shape).astype(dtype)
    else:
        def random_input():
            return np.random.rand(*shape).astype(dtype)

    # Inputs are drawn in this order, before any testop call, so the random
    # stream is consumed exactly once per input.
    without_axes = [random_input(), None]
    with_axes = [random_input(), np.array([1, 2], dtype=np.int32)]

    for case in (without_axes, with_axes):
        testop(case)
        for keep in (False, True):
            testop(case, keep_dims=keep)

def _test_forward_reduce_quantized(testop):
    """ Call `testop` in quantized mode on one random uint8 (3, 6) input with every keep_dims variant """
    quant_input = np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8)
    reduce_arg = np.array([1, 2], dtype=np.int32)
    case = [quant_input, reduce_arg]
    testop(case, quantized=True)
    for keep in (False, True):
        testop(case, keep_dims=keep, quantized=True)

def test_all_reduce():
    """ Drive all reduce op tests; reduce_mean additionally runs quantized, reduce_any only from TF 1.15.0 """
    for reduce_test in (_test_reduce_min, _test_reduce_max, _test_reduce_mean):
        _test_forward_reduce(reduce_test)
    _test_forward_reduce_quantized(_test_reduce_mean)
    for reduce_test in (_test_reduce_prod, _test_reduce_sum):
        _test_forward_reduce(reduce_test)
    if package_version.parse(tf.VERSION) < package_version.parse('1.15.0'):
        return
    _test_forward_reduce(_test_reduce_any, dtype="bool")


#######################################################################
# Squeeze
# -------

def _test_squeeze(data, squeeze_dims=None):
    """ Build one squeeze graph over `data` and pass it to `compare_tflite_with_tvm`.

    `squeeze_dims` is handed to `array_ops.squeeze` only when it is
    non-empty; `None` or an empty list calls squeeze without dims.
    """
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=data.shape, dtype=data.dtype)

        # None and [] are both falsy, so one truthiness check covers both
        squeeze_args = (placeholder, squeeze_dims) if squeeze_dims else (placeholder,)
        squeezed = array_ops.squeeze(*squeeze_args)

        compare_tflite_with_tvm(data, 'Placeholder:0', [placeholder], [squeezed])


def test_forward_squeeze():
    """ Squeeze the unit axes out of two 4-D inputs """
    cases = (
        ((1, 2, 1, 3), [0, 2]),
        ((2, 1, 3, 1), [1, 3]),
    )
    for shape, dims in cases:
        _test_squeeze(np.arange(6).reshape(shape), dims)


#######################################################################
# Pad
# ---

def _test_pad(data, mode="CONSTANT", quantized=False):
    """ Build one PAD graph and pass it to `compare_tflite_with_tvm`.

    data[0] is the input array and data[1] the paddings, which enter the
    graph as a constant. `mode` is forwarded to `array_ops.pad`. When
    `quantized` is set the input is first fake-quantized to (-100, 100).
    """

    assert len(data) == 2

    # Test with tensor and constant
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=data[0].shape, dtype='float32', name='in')

        if quantized:
            # fake_quant will keep the tensors in float32 until the conversion in the session
            graph_input = tf.quantization.fake_quant_with_min_max_args(
                placeholder, min=-100, max=100, name="inq_0")
            input_name = 'inq_0:0'
            compare_options = dict(quantized=True, input_range={'inq_0': (-100, 100)})
        else:
            graph_input = placeholder
            input_name = 'in:0'
            compare_options = {}

        paddings = ops.convert_to_tensor(data[1], dtype=data[1].dtype)
        padded = array_ops.pad(graph_input, paddings, mode=mode)
        compare_tflite_with_tvm([data[0]], [input_name], [graph_input], [padded], **compare_options)


def test_forward_pad():
    """ Pad: constant mode on several ranks, then REFLECT, SYMMETRIC and a quantized case """
    def ramp(stop, shape):
        # float32 values 1.0 .. stop-1 laid out in `shape`
        return np.arange(1.0, stop, dtype=np.float32).reshape(shape)

    def paddings(pairs):
        return np.array(pairs, dtype=np.int32)

    two_axis_pads = [[1, 1], [2, 2]]

    _test_pad([ramp(7.0, (2, 1, 1, 3)), paddings([[1, 1], [2, 2], [1, 1], [2, 2]])])
    _test_pad([ramp(7.0, (2, 1, 3)), paddings([[2, 2], [1, 1], [1, 1]])])
    _test_pad([ramp(7.0, (2, 3)), paddings(two_axis_pads)])
    _test_pad([ramp(4.0, (1, 3)), paddings(two_axis_pads)])
    for pad_mode in ("REFLECT", "SYMMETRIC"):
        _test_pad([ramp(7.0, (2, 3)), paddings(two_axis_pads)], mode=pad_mode)
    all_bytes = np.arange(0, 256, dtype=np.uint8).reshape((1, 256))
    _test_pad([all_bytes, paddings(two_axis_pads)], quantized=True)


#######################################################################
# Pack
# ----

def _test_pack(data, axis):
    """ Build one pack graph: one placeholder per array in `data`, packed along `axis` """

    assert len(data) >= 1

    with tf.Graph().as_default():
        placeholders = []
        input_names = []
        for idx, tensor in enumerate(data):
            placeholders.append(
                array_ops.placeholder(shape=tensor.shape, dtype=tensor.dtype,
                                      name="in_{}".format(idx)))
            input_names.append("in_{}:0".format(idx))
        packed = array_ops.pack(placeholders, axis=axis)

        compare_tflite_with_tvm(data, input_names, placeholders, [packed])


def test_forward_pack():
    """ Pack two or three identical inputs along axis 1 """
    # (input shape, number of inputs)
    cases = (
        ((1, 2, 1, 3), 2),
        ((3, 2), 2),
        ((2, 1, 1, 3), 3),
    )
    for shape, count in cases:
        inputs = [np.arange(6).reshape(shape) for _ in range(count)]
        _test_pack(inputs, 1)


#######################################################################
# Unpack
# ------

def _test_unpack(data, axis, num_unpacks):
    """ Build one UNPACK graph splitting `data` into `num_unpacks` outputs along `axis` """
    # one output name per unpacked tensor
    out_names = []
    for n in range(num_unpacks):
        out_names.append('out_' + str(n) + ':0')

    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
        pieces = gen_array_ops.unpack(placeholder, num=num_unpacks, axis=axis, name='unpack')
        compare_tflite_with_tvm([data], 'Placeholder:0', [placeholder], pieces, out_names=out_names)

def test_forward_unpack():
    """ UNPACK along positive axes, and along negative axes from TF 1.14.0 on """
    def random_input(shape, dtype):
        return np.array(np.random.uniform(0, 5, shape), dtype=dtype)

    _test_unpack(random_input((3, 1), np.int32), axis=1, num_unpacks=1)
    _test_unpack(random_input((3, 4), np.float32), axis=0, num_unpacks=3)
    # tflite 1.13 doesn't accept negative axis
    if package_version.parse(tf.VERSION) < package_version.parse('1.14.0'):
        return
    _test_unpack(random_input((3, 6), np.int32), axis=-2, num_unpacks=3)
    _test_unpack(random_input((2, 3, 4), np.int32), axis=-3, num_unpacks=2)


#######################################################################
# Local response normalization
# ----------------------------

def _test_local_response_normalization(data, depth_radius, bias, alpha, beta):
    """ Build one LOCAL_RESPONSE_NORMALIZATION graph with the given hyper-parameters """
    lrn_params = dict(depth_radius=depth_radius, bias=bias, alpha=alpha, beta=beta)
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=data.shape, dtype='float32', name='in_0')
        normalized = nn_ops.local_response_normalization(placeholder, **lrn_params)
        compare_tflite_with_tvm(data, 'in_0:0', [placeholder], [normalized])

def test_forward_local_response_normalization():
    """ LOCAL_RESPONSE_NORMALIZATION on a random (1, 6, 4, 3) float32 input """
    sample = np.random.uniform(size=(1, 6, 4, 3)).astype('float32')
    # LOCAL_RESPONSE_NORMALIZATION come with TFLite >= 1.14.0 fbs schema
    if package_version.parse(tf.VERSION) < package_version.parse('1.14.0'):
        return
    _test_local_response_normalization(sample, depth_radius=5, bias=1, alpha=1, beta=0.5)


#######################################################################
# L2 normalization
# ----------------

def _test_l2_normalization(data, axis, fused_activation_function=None):
    """ Build one L2_NORMALIZATION graph along `axis`, followed by the optional fused activation """
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
        normalized = nn_impl.l2_normalize(placeholder, axis)
        activated = with_fused_activation_function(normalized, fused_activation_function)
        compare_tflite_with_tvm(data, 'Placeholder:0', [placeholder], [activated])

def test_forward_l2_normalization():
    """ L2_NORMALIZATION along axis 2, without and with a fused RELU """
    sample = np.random.uniform(size=(3, 6, 4)).astype('float32')
    for activation in (None, "RELU"):
        _test_l2_normalization(sample, axis=2, fused_activation_function=activation)

#######################################################################
# Logistic
# --------

def _test_logistic(data, quantized=False):
    """ Build one LOGISTIC (sigmoid) graph, optionally with fake-quantized input and output.

    In quantized mode the input is fake-quantized to (-5, 5) and the
    sigmoid output to (0, 1).
    """
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=data.shape, dtype='float32', name='in_0')

        if not quantized:
            result = math_ops.sigmoid(placeholder)
            compare_tflite_with_tvm(data, 'in_0:0', [placeholder], [result])
            return

        quant_in = tf.quantization.fake_quant_with_min_max_args(placeholder, min=-5, max=5, name="inq_0")
        ranges = {'inq_0': (-5, 5)}
        result = math_ops.sigmoid(quant_in)
        result = tf.quantization.fake_quant_with_min_max_args(result, min=0, max=1, name="out")
        compare_tflite_with_tvm(data, 'inq_0:0', [quant_in], [result], quantized=True, input_range=ranges)

def test_forward_logistic():
    """ LOGISTIC on a float32 ramp and on random uint8 data in quantized mode """
    float_sample = np.arange(6.0, dtype=np.float32).reshape((1, 6))
    _test_logistic(float_sample)
    quant_sample = np.random.uniform(0, 255, (3, 6)).astype(np.uint8)
    _test_logistic(quant_sample, quantized=True)

#######################################################################
# Softmax
# -------

def _test_softmax(data):
    """ Build one softmax graph over `data` and pass it to `compare_tflite_with_tvm` """
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(dtype=data.dtype, shape=data.shape)
        probabilities = nn_ops.softmax(placeholder)
        compare_tflite_with_tvm(data, 'Placeholder:0', [placeholder], [probabilities])

def test_forward_softmax():
    """ Softmax on a float32 (1, 6) ramp """
    sample = np.arange(6.0, dtype=np.float32).reshape((1, 6))
    _test_softmax(sample)

#######################################################################
# Tanh
# ----

def _test_tanh(data):
    """ Build one TANH graph over `data` and pass it to `compare_tflite_with_tvm` """
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(dtype=data.dtype, shape=data.shape)
        activated = math_ops.tanh(placeholder)
        compare_tflite_with_tvm(data, 'Placeholder:0', [placeholder], [activated])

def test_forward_tanh():
    """ TANH on a float32 (1, 6) ramp """
    sample = np.arange(6.0, dtype=np.float32).reshape((1, 6))
    _test_tanh(sample)

#######################################################################
# ReLU
# ----

def _test_relu(data):
    """ Build one ReLU graph over `data` and pass it to `compare_tflite_with_tvm` """
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(dtype=data.dtype, shape=data.shape)
        activated = nn_ops.relu(placeholder)
        compare_tflite_with_tvm(data, 'Placeholder:0', [placeholder], [activated])

def test_forward_relu():
    """ ReLU on a float32 (1, 6) ramp """
    sample = np.arange(6.0, dtype=np.float32).reshape((1, 6))
    _test_relu(sample)

def _test_prelu(data, alpha):
    """ Build one PReLU graph, expressed as relu(x) + (-alpha * relu(-x)) """
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
        # This specific pattern will be replaced into PRelu by tflite
        positive_part = nn_ops.relu(placeholder)
        negative_part = -alpha * nn_ops.relu(-placeholder)
        prelu_out = positive_part + (negative_part)
        compare_tflite_with_tvm(data, 'Placeholder:0', [placeholder], [prelu_out])

def test_forward_prelu():
    """ PReLU with alpha given as a (3,) and as a (1, 1, 3) array """
    for alpha_shape in ((3,), (1, 1, 3)):
        sample = np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32")
        alpha = np.full(alpha_shape, 0.2, dtype="float32")
        _test_prelu(sample, alpha)

#######################################################################
# DepthToSpace
# ------------

def _test_depthtospace(data, block_size):
    """ Build one depth_to_space graph for `data` with the given block size """
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(dtype=data.dtype, shape=data.shape)
        rearranged = array_ops.depth_to_space(placeholder, block_size)
        compare_tflite_with_tvm(data, 'Placeholder:0', [placeholder], [rearranged])

def test_forward_depthtospace():
    """ DEPTH_TO_SPACE with block sizes 2 and 4 (skipped before TF 1.15.0) """
    # DEPTH_TO_SPACE comes with TFLite >= 1.15.0 fbs schema
    if package_version.parse(tf.VERSION) < package_version.parse('1.15.0'):
        return
    for shape, block_size in (([1, 32, 32, 4], 2), ([1, 16, 8, 32], 4)):
        sample = np.random.normal(size=shape).astype("float32")
        _test_depthtospace(sample, block_size)

#######################################################################
# SpaceToDepth
# ------------

def _test_spacetodepth(data, block_size):
    """ Build one space_to_depth graph for `data` with the given block size """
    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(dtype=data.dtype, shape=data.shape)
        rearranged = array_ops.space_to_depth(placeholder, block_size)
        compare_tflite_with_tvm(data, 'Placeholder:0', [placeholder], [rearranged])

def test_forward_spacetodepth():
    """ SPACE_TO_DEPTH with block sizes 2 and 4 """
    for shape, block_size in (([1, 32, 32, 4], 2), ([1, 16, 8, 32], 4)):
        sample = np.random.normal(size=shape).astype("float32")
        _test_spacetodepth(sample, block_size)

#######################################################################
# Fully Connected
# ---------------

def _test_fully_connected(tensor_in_sizes, filter_in_sizes, bias_in_size=None):
    """ Build one fully connected graph (reshape + matmul, optional bias add).

    Parameters
    ----------
    tensor_in_sizes : list of int
        Shape of the input placeholder; everything after the first
        dimension is flattened before the matmul.
    filter_in_sizes : list of int
        Shape of the constant weight matrix; its first dimension must equal
        the flattened per-sample input size.
    bias_in_size : list of int, optional
        Shape of the constant bias; when given, its first dimension must
        equal the second dimension of the filter.
    """

    input_count = 1
    for dim in tensor_in_sizes:
        input_count *= dim
    weight_count = 1
    for dim in filter_in_sizes:
        weight_count *= dim

    # Input and weights are both filled with incrementing numbers from 1.
    input_values = [v * 1.0 for v in range(1, input_count + 1)]
    weight_values = [v * 1.0 for v in range(1, weight_count + 1)]
    assert int(input_count / tensor_in_sizes[0]) == filter_in_sizes[0], \
        "input size and filter size are mismatched"

    with tf.Graph().as_default():
        placeholder = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
        weights = constant_op.constant(weight_values, shape=filter_in_sizes, dtype='float32')

        # reshape N H W C into N H*W*C
        flattened = array_ops.reshape(placeholder, [tensor_in_sizes[0], -1])

        result = math_ops.mat_mul(flattened, weights)

        # if we have bias
        if bias_in_size:
            assert bias_in_size[0] == filter_in_sizes[1], "bias and filter size are mismatched"
            bias_values = [v * 1.0 for v in range(1, bias_in_size[0] + 1)]
            bias = constant_op.constant(bias_values, shape=bias_in_size, dtype='float32')
            result = nn_ops.bias_add(result, bias)

        input_array = np.reshape(input_values, tensor_in_sizes).astype('float32')
        compare_tflite_with_tvm(input_array, 'Placeholder:0', [placeholder], [result])


def test_forward_fully_connected():
    """ Fully Connected with batch sizes 1 and 5, each without and with bias """
    for batch in (1, 5):
        input_shape = [batch, 1, 1, 150]
        _test_fully_connected(input_shape, [150, 100])
        _test_fully_connected(input_shape, [150, 100], [100])


#######################################################################
# Custom Operators
# ----------------

def test_detection_postprocess():
    """ Convert the SSD MobileNet V2 post-processing subgraph with custom ops enabled and compare TVM to TFLite.

    The valid-detection count (output 3) must match, and the first
    `valid_count` boxes, classes and scores from TVM are compared against
    the TFLite outputs.
    """
    box_input = "raw_outputs/box_encodings"
    class_input = "raw_outputs/class_predictions"
    box_shape = (1, 1917, 4)
    class_shape = (1, 1917, 91)

    frozen_graph = tf_testing.get_workload_official(
        "http://download.tensorflow.org/models/object_detection/"
        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/tflite_graph.pb"
    )
    converter = tf.lite.TFLiteConverter.from_frozen_graph(
        frozen_graph,
        input_arrays=[box_input, class_input],
        output_arrays=[
            "TFLite_Detection_PostProcess",
            "TFLite_Detection_PostProcess:1",
            "TFLite_Detection_PostProcess:2",
            "TFLite_Detection_PostProcess:3"
        ],
        input_shapes={
            box_input: box_shape,
            class_input: class_shape,
        },
    )
    converter.allow_custom_ops = True
    converter.inference_type = tf.lite.constants.FLOAT
    model_buf = converter.convert()

    # fixed seed so both runtimes see the same reproducible inputs
    np.random.seed(0)
    inputs = [np.random.uniform(size=box_shape).astype('float32'),
              np.random.uniform(size=class_shape).astype('float32')]
    tflite_output = run_tflite_graph(model_buf, inputs)
    tvm_output = run_tvm_graph(model_buf, inputs, [box_input, class_input], num_output=4)

    # check valid count is the same
    assert tvm_output[3] == tflite_output[3]
    valid_count = tvm_output[3][0]

    # check the output data is correct: boxes, classes, scores in that order
    for index in range(3):
        tvm_valid = tvm_output[index][0][:valid_count]
        tvm.testing.assert_allclose(np.squeeze(tvm_valid), np.squeeze(tflite_output[index]),
                                    rtol=1e-5, atol=1e-5)


#######################################################################
# Mobilenet
# ---------

def test_forward_mobilenet_v1():
    """Check TVM against the TFLite runtime on the float Mobilenet V1 model.

    The model file is obtained through ``tf_testing.get_workload_official``,
    run on one random float32 input of shape (1, 224, 224, 3) by both
    runtimes, and the first output of each is compared after squeezing.
    """
    model_path = tf_testing.get_workload_official(
        "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
        "mobilenet_v1_1.0_224.tflite")
    with open(model_path, "rb") as model_file:
        model_bytes = model_file.read()

    input_data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')

    # TFLite output is the reference; TVM must match it within tolerance.
    reference = run_tflite_graph(model_bytes, input_data)
    result = run_tvm_graph(model_bytes, input_data, 'input')
    tvm.testing.assert_allclose(
        np.squeeze(result[0]), np.squeeze(reference[0]), rtol=1e-5, atol=1e-5)

def test_forward_mobilenet_v2():
    """Check TVM against the TFLite runtime on the float Mobilenet V2 model.

    The model file is obtained through ``tf_testing.get_workload_official``,
    run on one random float32 input of shape (1, 224, 224, 3) by both
    runtimes, and the first output of each is compared after squeezing.
    """
    model_path = tf_testing.get_workload_official(
        "http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz",
        "mobilenet_v2_1.0_224.tflite")
    with open(model_path, "rb") as model_file:
        model_bytes = model_file.read()

    input_data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')

    # TFLite output is the reference; TVM must match it within tolerance.
    reference = run_tflite_graph(model_bytes, input_data)
    result = run_tvm_graph(model_bytes, input_data, 'input')
    tvm.testing.assert_allclose(
        np.squeeze(result[0]), np.squeeze(reference[0]), rtol=1e-5, atol=1e-5)

#######################################################################
# Mobilenet V3
# ------------

def test_forward_mobilenet_v3():
    """Check TVM against the TFLite runtime on the float Mobilenet V3 model.

    Returns immediately (without running anything) when the installed
    TensorFlow is older than 1.15.0. Otherwise one random float32 input of
    shape (1, 224, 224, 3) is run through both runtimes and the first output
    of each is compared after squeezing.
    """
    # In MobilenetV3, some ops are not supported before the tf 1.15 fbs schema.
    minimum_tf = package_version.parse('1.15.0')
    if package_version.parse(tf.VERSION) < minimum_tf:
        return

    model_path = tf_testing.get_workload_official(
        "https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_float.tgz",
        "v3-large_224_1.0_float/v3-large_224_1.0_float.tflite")
    with open(model_path, "rb") as model_file:
        model_bytes = model_file.read()

    input_data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')

    # TFLite output is the reference; TVM must match it within tolerance.
    reference = run_tflite_graph(model_bytes, input_data)
    result = run_tvm_graph(model_bytes, input_data, 'input')
    tvm.testing.assert_allclose(
        np.squeeze(result[0]), np.squeeze(reference[0]), rtol=1e-5, atol=1e-5)

#######################################################################
# Inception
# ---------

def test_forward_inception_v3_net():
    """Check TVM against the TFLite runtime on the float Inception V3 model.

    The model file is obtained through ``tf_testing.get_workload_official``,
    run on one random float32 input of shape (1, 299, 299, 3) by both
    runtimes, and the first output of each is compared after squeezing.
    """
    model_path = tf_testing.get_workload_official(
        "https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz",
        "inception_v3.tflite")
    with open(model_path, "rb") as model_file:
        model_bytes = model_file.read()

    input_data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')

    # TFLite output is the reference; TVM must match it within tolerance.
    reference = run_tflite_graph(model_bytes, input_data)
    result = run_tvm_graph(model_bytes, input_data, 'input')
    tvm.testing.assert_allclose(
        np.squeeze(result[0]), np.squeeze(reference[0]), rtol=1e-5, atol=1e-5)

def test_forward_inception_v4_net():
    """Check TVM against the TFLite runtime on the float Inception V4 model.

    The model file is obtained through ``tf_testing.get_workload_official``,
    run on one random float32 input of shape (1, 299, 299, 3) by both
    runtimes, and the first output of each is compared after squeezing.
    """
    model_path = tf_testing.get_workload_official(
        "https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz",
        "inception_v4.tflite")
    with open(model_path, "rb") as model_file:
        model_bytes = model_file.read()

    input_data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')

    # TFLite output is the reference; TVM must match it within tolerance.
    reference = run_tflite_graph(model_bytes, input_data)
    result = run_tvm_graph(model_bytes, input_data, 'input')
    tvm.testing.assert_allclose(
        np.squeeze(result[0]), np.squeeze(reference[0]), rtol=1e-5, atol=1e-5)

def test_forward_qnn_inception_v1_net():
    """Check TVM against the TFLite runtime on the quantized Inception V1 model.

    Instead of comparing raw output values, the indices of the three largest
    squeezed predictions (largest first) are compared between the runtimes.
    """
    model_path = tf_testing.get_workload_official(
        "https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_224_quant_20181026.tgz",
        "inception_v1_224_quant.tflite")
    with open(model_path, "rb") as model_file:
        model_bytes = model_file.read()

    # The requantize implementation differs between TFLite and Relay, which
    # causes the final output numbers to mismatch. Accuracy is therefore
    # tested via the predicted labels, using a real image rather than random
    # inputs.
    image = get_real_image(224, 224)

    tflite_scores = np.squeeze(run_tflite_graph(model_bytes, image))
    tflite_top_labels = np.argsort(tflite_scores)[-3:][::-1]

    tvm_scores = np.squeeze(run_tvm_graph(model_bytes, image, 'input'))
    tvm_top_labels = np.argsort(tvm_scores)[-3:][::-1]

    tvm.testing.assert_allclose(tvm_top_labels, tflite_top_labels)

def test_forward_qnn_mobilenet_v1_net():
    """Check TVM against the TFLite runtime on the quantized Mobilenet V1 model.

    Instead of comparing raw output values, the indices of the three largest
    squeezed predictions (largest first) are compared between the runtimes.
    """
    model_path = tf_testing.get_workload_official(
        "https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
        "mobilenet_v1_1.0_224_quant.tflite")
    with open(model_path, "rb") as model_file:
        model_bytes = model_file.read()

    # The requantize implementation differs between TFLite and Relay, which
    # causes the final output numbers to mismatch. Accuracy is therefore
    # tested via the predicted labels, using a real image rather than random
    # inputs.
    image = get_real_image(224, 224)

    tflite_scores = np.squeeze(run_tflite_graph(model_bytes, image))
    tflite_top_labels = np.argsort(tflite_scores)[-3:][::-1]

    tvm_scores = np.squeeze(run_tvm_graph(model_bytes, image, 'input'))
    tvm_top_labels = np.argsort(tvm_scores)[-3:][::-1]

    tvm.testing.assert_allclose(tvm_top_labels, tflite_top_labels)

def test_forward_qnn_mobilenet_v2_net():
    """Check TVM against the TFLite runtime on the quantized Mobilenet V2 model.

    Instead of comparing raw output values, the indices of the three largest
    squeezed predictions (largest first) are compared between the runtimes.
    """
    model_path = tf_testing.get_workload_official(
        "https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz",
        "mobilenet_v2_1.0_224_quant.tflite")
    with open(model_path, "rb") as model_file:
        model_bytes = model_file.read()

    # The requantize implementation differs between TFLite and Relay, which
    # causes the final output numbers to mismatch. Accuracy is therefore
    # tested via the predicted labels, using a real image rather than random
    # inputs.
    image = get_real_image(224, 224)

    tflite_scores = np.squeeze(run_tflite_graph(model_bytes, image))
    tflite_top_labels = np.argsort(tflite_scores)[-3:][::-1]

    tvm_scores = np.squeeze(run_tvm_graph(model_bytes, image, 'input'))
    tvm_top_labels = np.argsort(tvm_scores)[-3:][::-1]

    tvm.testing.assert_allclose(tvm_top_labels, tflite_top_labels)

#######################################################################
# Mobilenet V3 Quantized
# ----------------------

def test_forward_qnn_mobilenet_v3_net():
    """Check TVM against the TFLite runtime on the quantized Mobilenet V3 model.

    Returns immediately (without running anything) when the installed
    TensorFlow is older than 1.15.0. Otherwise the indices of the three
    largest squeezed predictions (largest first) are compared between the
    runtimes instead of the raw output values.
    """
    # In MobilenetV3, some ops are not supported before the tf 1.15 fbs schema.
    minimum_tf = package_version.parse('1.15.0')
    if package_version.parse(tf.VERSION) < minimum_tf:
        return

    model_path = tf_testing.get_workload_official(
        "https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_uint8.tgz",
        "v3-large_224_1.0_uint8/v3-large_224_1.0_uint8.tflite")
    with open(model_path, "rb") as model_file:
        model_bytes = model_file.read()

    # The requantize implementation differs between TFLite and Relay, which
    # causes the final output numbers to mismatch. Accuracy is therefore
    # tested via the predicted labels, using a real image rather than random
    # inputs.
    image = get_real_image(224, 224)

    tflite_scores = np.squeeze(run_tflite_graph(model_bytes, image))
    tflite_top_labels = np.argsort(tflite_scores)[-3:][::-1]

    tvm_scores = np.squeeze(run_tvm_graph(model_bytes, image, 'input'))
    tvm_top_labels = np.argsort(tvm_scores)[-3:][::-1]

    tvm.testing.assert_allclose(tvm_top_labels, tflite_top_labels)

#######################################################################
# SSD Mobilenet
# -------------

def test_forward_ssd_mobilenet_v1():
    """Check TVM against the TFLite runtime on the SSD Mobilenet V1 model.

    NumPy's global RNG is seeded with 0 before the float32 input of shape
    (1, 300, 300, 3) is drawn, and the first two outputs of both runtimes are
    compared after squeezing.
    """
    model_path = tf_testing.get_workload_official(
        "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28_nopp.tgz",
        "ssd_mobilenet_v1_coco_2018_01_28_nopp.tflite")
    with open(model_path, "rb") as model_file:
        model_bytes = model_file.read()

    # Fixed seed so the random input is the same on every run.
    np.random.seed(0)
    input_data = np.random.uniform(size=(1, 300, 300, 3)).astype('float32')

    reference = run_tflite_graph(model_bytes, input_data)
    result = run_tvm_graph(
        model_bytes, input_data, 'normalized_input_image_tensor', num_output=2)

    # Two outputs were requested from TVM; compare each against TFLite.
    for out_idx in (0, 1):
        tvm.testing.assert_allclose(
            np.squeeze(result[out_idx]), np.squeeze(reference[out_idx]),
            rtol=1e-5, atol=2e-5)

#######################################################################
# MediaPipe
# ---------

def test_forward_mediapipe_hand_landmark():
    """Check TVM against the TFLite runtime on MediaPipe's 2D hand landmark model.

    The model file is obtained through ``download_testdata``, run on one
    random float32 input of shape (1, 256, 256, 3) by both runtimes, and the
    first two outputs of each are compared after squeezing.
    """
    model_path = download_testdata(
        "https://github.com/google/mediapipe/raw/master/mediapipe/models/hand_landmark.tflite",
        "hand_landmark.tflite")
    with open(model_path, "rb") as model_file:
        model_bytes = model_file.read()

    input_data = np.random.uniform(size=(1, 256, 256, 3)).astype('float32')

    reference = run_tflite_graph(model_bytes, input_data)
    result = run_tvm_graph(model_bytes, input_data, 'input_1', num_output=2)

    # Two outputs were requested from TVM; compare each against TFLite.
    for out_idx in (0, 1):
        tvm.testing.assert_allclose(
            np.squeeze(result[out_idx]), np.squeeze(reference[out_idx]),
            rtol=1e-5, atol=1e-5)

#######################################################################
# Main
# ----
if __name__ == '__main__':
    # When executed as a script, run the tests listed below one after another,
    # in exactly this order. The first failing test aborts the run.
    for run_test in (
            # BatchToSpaceND
            test_forward_batch_to_space_nd,

            # SpaceToBatchND
            test_forward_space_to_batch_nd,

            # Split
            test_forward_split,

            # Transpose
            test_forward_transpose,

            # Cast
            test_forward_cast,

            # Tile
            test_forward_tile,

            # Transforms
            test_forward_concatenation,
            test_forward_pad,
            test_forward_pack,
            test_forward_unpack,
            test_forward_reshape,
            test_all_resize,
            test_forward_squeeze,
            test_forward_slice,
            test_forward_topk,
            test_forward_depthtospace,
            test_forward_spacetodepth,

            # NN
            test_forward_convolution,
            test_forward_transpose_conv,
            test_forward_logistic,
            test_forward_pooling,
            test_forward_softmax,
            test_forward_tanh,
            test_forward_relu,
            test_forward_prelu,
            test_forward_fully_connected,
            test_forward_l2_normalization,
            test_forward_local_response_normalization,

            # Elemwise
            test_all_elemwise,

            # Unary elemwise
            test_all_unary_elemwise,

            # Zeros Like
            test_forward_zeros_like,

            # Reduce
            test_all_reduce,

            # Logical
            test_all_logical,

            # Detection_PostProcess
            test_detection_postprocess,

            # End to End
            test_forward_mobilenet_v1,
            test_forward_mobilenet_v2,
            test_forward_mobilenet_v3,
            test_forward_inception_v3_net,
            test_forward_inception_v4_net,
            test_forward_ssd_mobilenet_v1,
            test_forward_mediapipe_hand_landmark,

            # End to End quantized
            test_forward_qnn_inception_v1_net,
            test_forward_qnn_mobilenet_v1_net,
            test_forward_qnn_mobilenet_v2_net,
            test_forward_qnn_mobilenet_v3_net,
    ):
        run_test()