# pylint: disable=import-self, invalid-name, unused-argument
"""
Tensorflow testcases
====================
This article is a test script to test tensorflow operator with NNVM.
"""
from __future__ import print_function
import numpy as np
import nnvm.compiler
import tvm
import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import graph_util
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops import init_ops
from tensorflow.core.framework import graph_pb2

import nnvm.testing.tf

#######################################################################
# Generic run functions for TVM & tensorflow
# ------------------------------------------
def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype):
    """ Generic function to compile on nnvm and execute on tvm """

    sym, params = nnvm.frontend.from_tensorflow(graph_def)
    target = 'llvm'
    if isinstance(input_data, list):
        shape_dict = {}
        dtype_dict = {}
        for i, e in enumerate(input_node):
            shape_dict[e] = input_data[i].shape
            dtype_dict[e] = input_data[i].dtype
    else:
        shape_dict = {input_node: input_data.shape}
        dtype_dict = {input_node: input_data.dtype}

    graph, lib, params = nnvm.compiler.build(sym, target, shape_dict,
                                             dtype=dtype_dict, params=params)

    ctx = tvm.cpu(0)
    from tvm.contrib import graph_runtime
    m = graph_runtime.create(graph, lib, ctx)
    # set inputs
    if isinstance(input_data, list):
        for i, e in enumerate(input_node):
            m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype)))
    else:
        m.set_input(input_node, tvm.nd.array(input_data.astype(input_data.dtype)))

    m.set_input(**params)
    # execute
    m.run()
    # get outputs
    if isinstance(output_shape, list) and isinstance(output_dtype, list):
        tvm_output_list = []
        for i, s in enumerate(output_shape):
            tvm_output = m.get_output(i, tvm.nd.empty(s, output_dtype[i]))
            tvm_output_list.append(tvm_output.asnumpy())
        return tvm_output_list
    else:
        tvm_output = m.get_output(0, tvm.nd.empty(output_shape, output_dtype))
        return tvm_output.asnumpy()

def run_tf_graph(sess, input_data, input_node, output_node):
    """ Generic function to execute tensorflow """

    tensor = sess.graph.get_tensor_by_name(output_node)

    if isinstance(input_data, list):
        input_dict = {}
        for i, e in enumerate(input_node):
            input_dict[e] = input_data[i]
    else:
        input_dict = {input_node: input_data}

    output_data = sess.run(tensor, input_dict)
    return output_data
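
#######################################################################
# Runner usage example
# --------------------
# A minimal sketch (not one of the original testcases) showing how the two
# runners above fit together: build a graph, freeze it, run it through both
# tensorflow and TVM, and compare. The node names 'in_data' and 'add' are
# picked here purely for illustration.
def _example_forward_add():
    """ Hypothetical end-to-end usage of run_tf_graph/run_tvm_graph """
    data = np.arange(6, dtype='float32').reshape((2, 3))
    with tf.Graph().as_default():
        in_data = tf.placeholder(tf.float32, shape=data.shape, name='in_data')
        tf.add(in_data, in_data, name='add')
        with tf.Session() as sess:
            graph_def = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph.as_graph_def(add_shapes=True), ['add'])
            tf_output = run_tf_graph(sess, data, 'in_data:0', 'add:0')
            tvm_output = run_tvm_graph(graph_def, data, 'in_data',
                                       tf_output.shape, 'float32')
            np.testing.assert_allclose(tf_output, tvm_output)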

#######################################################################
# Pooling
# -------
def _test_pooling(input_shape, **kwargs):
    """ One iteration of pool operation with given shapes and attributes """

    x = -np.arange(
        np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1

    with tf.Graph().as_default():
        in_data = constant_op.constant(x, shape=input_shape, dtype='float32')
        # pylint: disable=unused-variable
        pool = nn_ops.pool(in_data, **kwargs)
        # pylint: enable=unused-variable

        if kwargs['pooling_type'] == 'MAX':
            out_node = 'max_pool'
            out_name = 'max_pool:0'
        else:
            out_node = 'avg_pool'
            out_name = 'avg_pool:0'

        with tf.Session() as sess:
            graph_def = tf.graph_util.convert_variables_to_constants(
                sess,
                sess.graph.as_graph_def(add_shapes=True),
                [out_node],
                )

            tf_output = run_tf_graph(sess, x, 'Const:0', out_name)
            tvm_output = run_tvm_graph(graph_def, x.astype('float32'),
                                       "Const", tf_output.shape, 'float32')
            np.testing.assert_allclose(tf_output, tvm_output, atol=1e-3, rtol=1e-3)

            sess.close()

def test_forward_pooling():
    """ Pooling """

    _test_pooling(input_shape=[2, 9, 10, 2],
                  window_shape=[1, 1],
                  padding='SAME',
                  pooling_type='MAX',
                  dilation_rate=[1, 1],
                  strides=[1, 1])
    _test_pooling(input_shape=[2, 9, 10, 2],
                  window_shape=[1, 1],
                  padding='SAME',
                  pooling_type='AVG',
                  dilation_rate=[1, 1],
                  strides=[1, 1])

    _test_pooling(input_shape=[2, 10, 9, 2],
                  window_shape=[1, 1],
                  padding='SAME',
                  pooling_type='MAX',
                  dilation_rate=[1, 1],
                  strides=[1, 1])
    _test_pooling(input_shape=[2, 10, 9, 2],
                  window_shape=[1, 1],
                  padding='SAME',
                  pooling_type='AVG',
                  dilation_rate=[1, 1],
                  strides=[1, 1])

    _test_pooling(input_shape=[2, 9, 10, 2],
                  window_shape=[2, 1],
                  padding='SAME',
                  pooling_type='MAX',
                  dilation_rate=[1, 1],
                  strides=[1, 1])
    _test_pooling(input_shape=[2, 9, 10, 2],
                  window_shape=[2, 1],
                  padding='SAME',
                  pooling_type='AVG',
                  dilation_rate=[1, 1],
                  strides=[2, 1])

    _test_pooling(input_shape=[2, 10, 9, 2],
                  window_shape=[2, 3],
                  padding='SAME',
                  pooling_type='MAX',
                  dilation_rate=[1, 1],
                  strides=[2, 1])
    _test_pooling(input_shape=[2, 10, 9, 2],
                  window_shape=[2, 3],
                  padding='SAME',
                  pooling_type='AVG',
                  dilation_rate=[1, 1],
                  strides=[1, 2])


#######################################################################
# Convolution
# -----------

def _test_convolution(tensor_in_sizes, filter_in_sizes,
                      dilations, strides, padding, data_format):
    """ One iteration of convolution with given shapes and attributes """

    total_size_1 = 1
    total_size_2 = 1
    for s in tensor_in_sizes:
        total_size_1 *= s
    for s in filter_in_sizes:
        total_size_2 *= s
    # Initialize the data and filter tensors with numbers incrementing from 1.
    data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
    filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]

    with tf.Graph().as_default():
        in_data = constant_op.constant(data_array, shape=tensor_in_sizes, dtype='float32')
        in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')
        strides = [1] + strides + [1]
        dilations = [1] + dilations + [1]

        # pylint: disable=unused-variable
        conv = nn_ops.conv2d(in_data,
                             in_filter,
                             strides=strides,
                             padding=padding,
                             data_format=data_format)
        # pylint: enable=unused-variable

        with tf.Session() as sess:
            graph_def = tf.graph_util.convert_variables_to_constants(
                sess,
                sess.graph.as_graph_def(add_shapes=True),
                ['Conv2D'],
                )

            tf_output = run_tf_graph(sess, np.reshape(data_array, tensor_in_sizes),
                                     'Const:0', 'Conv2D:0')
            tvm_output = run_tvm_graph(graph_def,
                                       np.reshape(data_array, tensor_in_sizes).astype('float32'),
                                       "Const", tf_output.shape, 'float32')

            np.testing.assert_allclose(tf_output, tvm_output, atol=1e-3, rtol=1e-3)

            sess.close()

def test_forward_convolution():
    _test_convolution([4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC')
    _test_convolution([4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')
    _test_convolution([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC')
    _test_convolution([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC')

#######################################################################
# Reshape
# -------

def _test_reshape(data, out_shape):
    """ One iteration of reshape operation with given data and out shape """

    with tf.Graph().as_default():
        in_data = constant_op.constant(data, shape=data.shape, dtype=data.dtype)

        # pylint: disable=unused-variable
        reshape_out = array_ops.reshape(in_data, out_shape)
        # pylint: enable=unused-variable

        with tf.Session() as sess:
            graph_def = tf.graph_util.convert_variables_to_constants(
                sess,
                sess.graph.as_graph_def(add_shapes=True),
                ['Reshape'],
                )

            tf_output = run_tf_graph(sess, data,
                                     'Const:0', 'Reshape:0')
            tvm_output = run_tvm_graph(graph_def,
                                       data,
                                       "Const", tf_output.shape, data.dtype)

            np.testing.assert_allclose(tf_output, tvm_output)

            sess.close()

def test_forward_reshape():
    _test_reshape(np.arange(6.0), [2, 3])
    _test_reshape(np.arange(6), [-1, 2])
    _test_reshape(np.arange(6), [3, -1])
    _test_reshape(np.arange(6), [-1])

#######################################################################
# Squeeze
# -------

def _test_squeeze(data, squeeze_dims=None):
    """ One iteration of squeeze """

    if squeeze_dims is None:
        squeeze_dims = []

    with tf.Graph().as_default():
        in_data = constant_op.constant(data, shape=data.shape, dtype=data.dtype)

        # pylint: disable=unused-variable
        if squeeze_dims:
            squeeze_out = array_ops.squeeze(in_data, squeeze_dims)
        else:
            squeeze_out = array_ops.squeeze(in_data)
        # pylint: enable=unused-variable

        with tf.Session() as sess:
            graph_def = tf.graph_util.convert_variables_to_constants(
                sess,
                sess.graph.as_graph_def(add_shapes=True),
                ['Squeeze'],
                )

            tf_output = run_tf_graph(sess, data,
                                     'Const:0', 'Squeeze:0')
            tvm_output = run_tvm_graph(graph_def,
                                       data,
                                       "Const", tf_output.shape, data.dtype)

            np.testing.assert_allclose(tf_output, tvm_output)

            sess.close()

def test_forward_squeeze():
    """ Squeeze """

    # Nothing to squeeze.
    _test_squeeze(np.arange(2).reshape((2)))
    _test_squeeze(np.arange(6).reshape((2, 3)))

    # Squeeze the middle element away.
    _test_squeeze(np.arange(4).reshape((2, 1, 2)))

    # Squeeze on both ends.
    _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)))

    # Positive squeeze dim index.
    _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [0])
    _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [2, 4])
    _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [0, 4, 2])

    # Negative squeeze dim index.
    _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-1])
    _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5])
    _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1])

#######################################################################
# ConcatV2
# --------

def _test_concat_v2(data, dim):
    """ One iteration of ConcatV2 """

    with tf.Graph().as_default():

        # pylint: disable=unused-variable
        concat_out = gen_array_ops._concat_v2(data, dim)
        # pylint: enable=unused-variable

        with tf.Session() as sess:
            graph_def = tf.graph_util.convert_variables_to_constants(
                sess,
                sess.graph.as_graph_def(add_shapes=True),
                ['ConcatV2'],
                )

            tf_output = run_tf_graph(sess, data,
                                     ['ConcatV2/values_0:0', 'ConcatV2/values_1:0'], 'ConcatV2:0')
            tvm_output = run_tvm_graph(graph_def,
                                       data,
                                       ["ConcatV2/values_0", 'ConcatV2/values_1'],
                                       tf_output.shape, tf_output.dtype)

            np.testing.assert_allclose(tf_output, tvm_output)

            sess.close()

def _test_forward_concat_v2():
    t1 = np.array([])
    t2 = np.array([])
    _test_concat_v2([t1, t2], 0)

    t1 = np.array([[1, 2, 3], [4, 5, 6]])
    t2 = np.array([[7, 8, 9], [10, 11, 12]])

    _test_concat_v2([t1, t2], 1)

#######################################################################
# Sigmoid
# -------

def _test_sigmoid(data):
    """ One iteration of sigmoid """

    with tf.Graph().as_default():
        in_data = constant_op.constant(data, shape=data.shape, dtype=data.dtype)

        # pylint: disable=unused-variable
        sigmoid_out = math_ops.sigmoid(in_data)
        # pylint: enable=unused-variable

        with tf.Session() as sess:
            graph_def = tf.graph_util.convert_variables_to_constants(
                sess,
                sess.graph.as_graph_def(add_shapes=True),
                ['Sigmoid'],
                )

            tf_output = run_tf_graph(sess, data,
                                     'Const:0', 'Sigmoid:0')
            tvm_output = run_tvm_graph(graph_def,
                                       data,
                                       "Const", tf_output.shape, data.dtype)

            np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)

            sess.close()

def test_forward_sigmoid():
    """ Sigmoid """

    _test_sigmoid(np.random.uniform(size=(3, 4, 4, 3)).astype('float32'))

#######################################################################
# Argmin/Argmax
# -------------

def _test_argx(func, data, **kwargs):
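    """ One iteration of argmin/argmax with the given function and attributes """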

    with tf.Graph().as_default():
        inp = constant_op.constant(data, shape=data.shape, dtype=data.dtype, name="c0")

        # pylint: disable=unused-variable
        out = func(inp, name="argx0", **kwargs)
        # pylint: enable=unused-variable

        with tf.Session() as sess:
            graph_def = tf.graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=sess.graph.as_graph_def(add_shapes=True),
                output_node_names=["argx0"])

            tf_output = run_tf_graph(sess, data, input_node="c0:0", output_node="argx0:0")
            tvm_output = run_tvm_graph(graph_def, data, "c0", tf_output.shape, output_dtype='int32')

            np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)

            sess.close()

def test_argmin_argmax():
    for axis in [None, 0, 1, 2]:
        data = np.random.uniform(size=(8, 4, 9)).astype('float32')
        _test_argx(tf.argmax, data=data, axis=axis)
        _test_argx(tf.argmin, data=data, axis=axis)

#######################################################################
# Variable
# --------

def _test_variable(data):
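    """ One iteration of a variable-backed MatMul graph """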
    tf.reset_default_graph()
    input_op = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
    input_tensor = array_ops.reshape(input_op, data.shape)

    size = input_tensor.shape.dims[1]
    with variable_scope.variable_scope("linear", reuse=None):
        w = variable_scope.get_variable(
            "w", shape=[size, size], dtype=input_tensor.dtype)
    # pylint: disable=unused-variable
    output_op = math_ops.matmul(input_tensor, w)
    # pylint: enable=unused-variable

    with tf.Session() as sess:
        sess.run(variables.global_variables_initializer())
        final_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph.as_graph_def(add_shapes=True),
            ['MatMul'],
            )

        tf_output = run_tf_graph(sess, data, 'Placeholder:0', 'MatMul:0')
        tvm_output = run_tvm_graph(final_graph_def, data,
                                   "Placeholder", tf_output.shape, data.dtype)

        np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
        sess.close()

def test_forward_variable():
    """Variable type op test"""
    _test_variable(np.random.uniform(size=(32, 100)).astype('float32'))


#######################################################################
# LSTM
# ----
def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
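    """ One iteration of an LSTM block cell with given shapes and attributes """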
    tf.reset_default_graph()
    input_size = num_hidden
    input_data = np.full((batch_size, input_size), 1., dtype=dtype)
    in_state_c = np.full((num_layers, batch_size, num_hidden), 0.1, dtype=dtype)
    in_state_h = np.full((num_layers, batch_size, num_hidden), 0.1, dtype=dtype)

    def _get_tensorflow_output():
        with tf.Session() as sess:
            with variable_scope.variable_scope(
                "root", initializer=init_ops.constant_initializer(0.5)):
                m0 = array_ops.zeros([batch_size, num_hidden])
                m1 = array_ops.zeros([batch_size, num_hidden])
                x = tf.placeholder(shape=(batch_size, input_size), dtype=dtype)
                g, (out_m0, out_m1) = \
                     tf.contrib.rnn.LSTMBlockCell(num_hidden,
                                                  forget_bias=forget_bias)(x, (m0, m1))
                sess.run([variables.global_variables_initializer()])
                res = sess.run([g, out_m0, out_m1], {
                    x.name: np.array([[1., 1.]]),
                    m0.name: 0.1 * np.ones([batch_size, num_hidden]),
                    m1.name: 0.1 * np.ones([batch_size, num_hidden]),
                })
            graph_def = sess.graph.as_graph_def(add_shapes=True)
            final_graph_def = graph_util.convert_variables_to_constants(
                sess,
                graph_def,
                ['root/lstm_cell/LSTMBlockCell'])
            return final_graph_def, res

    graph_def, tf_out = _get_tensorflow_output()
    tvm_output = run_tvm_graph(graph_def, [input_data, in_state_c, in_state_h],
                               ['root/Placeholder', 'root/lstm_cell/LSTMBlockCell_c',
                                'root/lstm_cell/LSTMBlockCell_h'],
                               [tf_out[0].shape, (2, batch_size, num_hidden)],
                               [tf_out[0].dtype, tf_out[1].dtype])

    if isinstance(tvm_output, list):
        out = tvm_output[0]
        out_state = tvm_output[1]
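        # TVM packs the (c, h) states along axis 0 of the second output,
        # giving shape (2, batch_size, num_hidden); split them back apart.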
        out_state_tup = np.split(out_state, indices_or_sections=2, axis=0)
        out_state_c = np.reshape(out_state_tup[0], (batch_size, num_hidden))
        out_state_h = np.reshape(out_state_tup[1], (batch_size, num_hidden))
        tvm_out = [out, out_state_c, out_state_h]
        np.testing.assert_allclose(tf_out, tvm_out, rtol=1e-3, atol=1e-3)

def test_forward_lstm():
    '''test LSTM block cell'''
    _test_lstm_cell(1, 2, 1, 0.0, 'float32')


#######################################################################
# StridedSlice
# ------------

def _test_stridedslice(ip_shape, begin, end, stride, dtype,
                       begin_mask=0, end_mask=0, new_axis_mask=0,
                       shrink_axis_mask=0, ellipsis_mask=0):
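    """ One iteration of strided_slice with given shapes and mask attributes.

    Each *_mask argument is a tf bit mask over the slice dimensions: bit i of
    begin_mask/end_mask tells tf to ignore begin[i]/end[i], new_axis_mask
    inserts a length-1 axis, shrink_axis_mask drops the axis, and
    ellipsis_mask marks where the implicit '...' goes.
    """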
    tf.reset_default_graph()
    in_data = tf.placeholder(dtype, ip_shape, name="in_data")
    tf.strided_slice(in_data, begin, end, stride, begin_mask=begin_mask,
                     end_mask=end_mask, new_axis_mask=new_axis_mask,
                     shrink_axis_mask=shrink_axis_mask,
                     ellipsis_mask=ellipsis_mask, name="strided_slice")
    np_data = np.random.uniform(size=ip_shape).astype(dtype)

    with tf.Session() as sess:
        final_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph.as_graph_def(add_shapes=True),
            ['strided_slice'])
        tf_output = run_tf_graph(sess, np_data,
                                 'in_data:0', 'strided_slice:0')
        tvm_output = run_tvm_graph(final_graph_def, np_data,
                                   "in_data", tf_output.shape, np_data.dtype)
        np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
        sess.close()

def test_forward_stridedslice():
    '''test StridedSlice'''
    _test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32')
    _test_stridedslice((3, 4, 3), [1, 0], [4, 3], [2, 1], 'float32', ellipsis_mask=8)
    _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 2], [2, 1, 1], 'float32', new_axis_mask=5)
    _test_stridedslice((3, 4, 3), [1, 1, 1], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=4)
    _test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=4, new_axis_mask=2)
    _test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=3)
    _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=3)
    _test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=2)
    _test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=2)
    _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=2, new_axis_mask=2)
    _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=1, new_axis_mask=2)
    _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=2, new_axis_mask=1)
    _test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0], [2, 3], [1, 1], 'float32', shrink_axis_mask=5, new_axis_mask=1)
    _test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1],
                       'float32', shrink_axis_mask=5, new_axis_mask=1, ellipsis_mask=2, begin_mask=8, end_mask=8)
    _test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1],
                       'float32', shrink_axis_mask=8, new_axis_mask=1, ellipsis_mask=2, begin_mask=5, end_mask=5)
    _test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1],
                       'float32', shrink_axis_mask=16, new_axis_mask=1, ellipsis_mask=2, begin_mask=5, end_mask=5)
    _test_stridedslice((3, 4, 5, 4, 5, 6), [1, 2, 0, -3], [4, 5, 3, 3], [2, 2, 1, 1],
                       'float32', shrink_axis_mask=8, new_axis_mask=1, ellipsis_mask=2, begin_mask=5,
                       end_mask=8)


#######################################################################
# Gather
# ------

def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
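    """ One iteration of gather with given indices and axis """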
    tf.reset_default_graph()
    in_data = tf.placeholder(dtype, ip_shape, name="in_data")
    indices = tf.placeholder("int32", indice_shape, name="indices")
    tf.gather(in_data, indices, axis=axis)
    np_data = np.random.uniform(size=ip_shape).astype(dtype)

    def _fill_indices(indice_value):
        if isinstance(indice_value, int):
            indices = np.array([indice_value], dtype='int32')
        else:
            indices = np.asarray(indice_value, dtype='int32')
        return indices
    np_indices = _fill_indices(indice_value)

    with tf.Session() as sess:
        final_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph.as_graph_def(add_shapes=True),
            ['GatherV2'])
        tf_output = run_tf_graph(sess, [np_data, np_indices], ['in_data:0',
                                 'indices:0'], 'GatherV2:0')
        tvm_output = run_tvm_graph(final_graph_def, [np_data, np_indices],
                                   ['in_data', 'indices'], tf_output.shape, dtype)
        np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
        sess.close()

def test_forward_gather():
    '''test gather layer'''
    _test_gather((4,), (1,), 1, 0, 'int32')
    _test_gather((4,), (1,), 1, 0, 'float32')
    _test_gather((1, 4), (1,), [0], 0, 'int32')
    _test_gather((4,), (1, 2, 2), [[[1, 0], [0, 1]]], 0, 'float32')
    _test_gather((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 0, 'int32')
    _test_gather((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 1, 'int32')
    _test_gather((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 0, 'float32')
    _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 0, 'int32')
    _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 2, 'int32')
    _test_gather((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 0, 'float32')


#######################################################################
# Multi Input to graph
# --------------------

def test_forward_multi_input():
    with tf.Graph().as_default():
        in1 = tf.placeholder(tf.int32, shape=[3, 3], name='in1')
        in2 = tf.placeholder(tf.int32, shape=[3, 3], name='in2')
        in3 = tf.placeholder(tf.int32, shape=[3, 3], name='in3')
        in4 = tf.placeholder(tf.int32, shape=[3, 3], name='in4')

        out1 = tf.add(in1, in2, name='out1')
        out2 = tf.subtract(in3, in4, name='out2')

        out = tf.multiply(out1, out2, name='out')

        with tf.Session() as sess:
            graph_def = tf.graph_util.convert_variables_to_constants(
                sess,
                sess.graph.as_graph_def(add_shapes=True),
                ['out'],
                )

            in_data = np.arange(9, dtype='int32').reshape([3, 3])

            tf_output = run_tf_graph(sess, [in_data, in_data, in_data, in_data],
                                     ['in1:0', 'in2:0', 'in3:0', 'in4:0'], 'out:0')
            tvm_output = run_tvm_graph(graph_def,
                                       [in_data, in_data, in_data, in_data],
                                       ['in1', 'in2', 'in3', 'in4'],
                                       tf_output.shape, tf_output.dtype)

            np.testing.assert_allclose(tf_output, tvm_output)

            sess.close()

#######################################################################
# Resize Bilinear
# ---------------

def _test_resize_bilinear(in_shape, to_shape, align_corners):
    """ One iteration of resize bilinear """

    data = np.random.uniform(size=in_shape).astype('float32')
    shape_data = np.array(to_shape).astype('int32')

    with tf.Graph().as_default():
        in_data = constant_op.constant(data, shape=data.shape, dtype=data.dtype)
        shape_data = constant_op.constant(shape_data, shape=shape_data.shape, dtype=shape_data.dtype)

        # pylint: disable=unused-variable
        resize_out = tf.image.resize_bilinear(in_data, shape_data, align_corners=align_corners)
        # pylint: enable=unused-variable

        with tf.Session() as sess:
            graph_def = tf.graph_util.convert_variables_to_constants(
                sess,
                sess.graph.as_graph_def(add_shapes=True),
                ['ResizeBilinear'],
                )

            tf_output = run_tf_graph(sess, data,
                                     'Const:0', 'ResizeBilinear:0')

            tvm_output = run_tvm_graph(graph_def,
                                       data,
                                       "Const", tf_output.shape, data.dtype)

            np.testing.assert_allclose(tf_output, tvm_output, atol=1e-3, rtol=1e-3)

            sess.close()

def test_forward_resize_bilinear():
    """ Resize Bilinear """

    _test_resize_bilinear((4, 16, 32, 32), [50, 50], False)
    _test_resize_bilinear((6, 32, 64, 64), [20, 20], True)


#######################################################################
# Inception V3
# ------------
def test_forward_inception_v3():
    '''test inception V3 model'''
    with tf.Graph().as_default():
        graph_def = nnvm.testing.tf.get_workload('InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb')
        # Call the utility to import the graph definition into default graph.
        graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)

        data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')

        with tf.Session() as sess:
            tf_output = run_tf_graph(sess, data, 'input:0', 'InceptionV3/Predictions/Reshape_1:0')
            tvm_output = run_tvm_graph(graph_def, data, 'input', tf_output.shape, 'float32')
            np.testing.assert_allclose(tf_output, tvm_output, rtol=1e-5, atol=1e-5)

#######################################################################
# Inception V1
# ------------
def test_forward_inception_v1():
    '''test inception V1 model'''
    with tf.Graph().as_default():
        graph_def = nnvm.testing.tf.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb")
        # Call the utility to import the graph definition into default graph.
        graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)

        # Build an image from random data.
        from PIL import Image
        from tvm.contrib import util

        img_array = np.random.uniform(size=(1, 600, 600, 3)).astype("uint8")
        img = Image.frombuffer('RGB', (600, 600), img_array.tostring(), 'raw', 'RGB', 0, 1)
        temp = util.tempdir()
        img_path = temp.relpath("tf-test.jpg")
        img.save(img_path)

        if not tf.gfile.Exists(img_path):
            tf.logging.fatal('File does not exist %s', img_path)
        data = tf.gfile.FastGFile(img_path, 'rb').read()

        temp.remove()

        # Extract tensorflow decoded image frame for tvm input
        with tf.Session() as sess:
            tvm_data = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'DecodeJpeg:0')

        with tf.Session() as sess:
            tf_output = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'softmax:0')
            tvm_output = run_tvm_graph(graph_def, tvm_data, 'DecodeJpeg/contents', tf_output.shape, 'float32')
            np.testing.assert_allclose(tf_output, tvm_output, rtol=1e-5, atol=1e-5)

#######################################################################
# Mobilenet
# ---------
def test_forward_mobilenet():
    '''test mobilenet model'''
    with tf.Graph().as_default():
        graph_def = nnvm.testing.tf.get_workload("MobilenetV1/mobilenet_v1_1.0_224_frozen-with-shapes.pb")
        # Call the utility to import the graph definition into default graph.
        graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)

        data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
        out_node = 'MobilenetV1/Predictions/Reshape_1'

        with tf.Session() as sess:
            tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')
            tvm_output = run_tvm_graph(graph_def, data, 'input', tf_output.shape, 'float32')
            np.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5)

#######################################################################
# PTB
# ---
# Referencing tf.contrib forces its lazy loading so that the contrib ops
# (e.g. LSTMBlockCell used by the PTB model) get registered.
dir(tf.contrib)
def test_forward_ptb():
    '''test ptb model'''
    config = nnvm.testing.tf.get_config()
    num_steps = config.num_steps
    num_hidden = config.hidden_size
    num_layers = config.num_layers
    batch_size = config.batch_size
    vocab_size = config.vocab_size
    out_sample_shape = (batch_size, vocab_size)
    out_state_shape = (num_layers, 2, batch_size, num_hidden)
    #Sample input
    inpt = "we have no useful information on"
    cnt_sample = 20

    def _pretty_print(items, is_char_model, id2word):
        if not is_char_model:
            return ' '.join([id2word[x] for x in items])
        else:
            return ''.join([id2word[x] for x in items]).replace('_', ' ')

    def _get_tvm_graph_module(graph_def):
        sym, params = nnvm.frontend.from_tensorflow(graph_def)

        #Cell inputs 'c' and 'h' hold the values for all layers
        shape_dict = {'Model/Placeholder': (batch_size, num_steps),
                      'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':(num_layers, batch_size, num_hidden),
                      'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':(num_layers, batch_size, num_hidden)}
        dtype_dict = {'Model/Placeholder': 'int32',
                      'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':'float32',
                      'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':'float32'}
        target = 'llvm'
        graph, lib, params = nnvm.compiler.build(sym, target, shape_dict,
                                                 dtype=dtype_dict, params=params)
        from tvm.contrib import graph_runtime
        ctx = tvm.cpu(0)
        return params, graph_runtime.create(graph, lib, ctx)

    def _do_tvm_sample(model, data, in_states, params, num_samples):
        """Sampled from the model"""
        samples = []
        state = in_states
        sample = None
        def _get_sample(data, state):
            input_data = np.full((batch_size, num_steps), data, dtype="int32")
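            # 'state' stacks the per-layer (c, h) pairs along axis 1, shape
            # (num_layers, 2, batch_size, num_hidden); split them apart here.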
            in_state_tup = np.split(state, indices_or_sections=2, axis=1)
            in_state_c = np.reshape(in_state_tup[0], (num_layers, batch_size, num_hidden))
            in_state_h = np.reshape(in_state_tup[1], (num_layers, batch_size, num_hidden))

            model.set_input('Model/Placeholder', tvm.nd.array(input_data.astype("int32")))
            model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c',
                        tvm.nd.array(in_state_c.astype("float32")))
            model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h',
                        tvm.nd.array(in_state_h.astype("float32")))
            model.set_input(**params)
            model.run()
            tvm_output = model.get_output(0, tvm.nd.empty(out_sample_shape,
                                                      "float32")).asnumpy()
            state_output = model.get_output(1, tvm.nd.empty(out_state_shape,
                                                        "float32")).asnumpy()
            sample = nnvm.testing.tf.pick_from_weight(tvm_output[0])
            return sample, state_output

        for x in data:
            sample, state = _get_sample(x, state)

        if sample is not None:
            samples.append(sample)
        else:
            samples.append(0)

        k = 1
        while k < num_samples:
            sample, state = _get_sample(samples[-1], state)
            samples.append(sample)
            k += 1
        return samples, state

    with tf.Graph().as_default():
        word_to_id, id_to_word, graph_def = nnvm.testing.tf.get_workload_ptb()
        vocab_size = len(word_to_id)
        # Call the utility to import the graph definition into default graph.
        graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
        sess = tf.Session()

    #TVM graph module creation
    params, m = _get_tvm_graph_module(graph_def)

    # Create 10 predicted statements of 20 words
    cnt_stm = 0
    while cnt_stm < 10:
        cnt_stm += 1
        in_state = np.full((num_layers, 2, batch_size, num_hidden), 0, dtype="float32")
        seed_for_sample = inpt.split()
        tvm_samples, tvm_state = _do_tvm_sample(m, [word_to_id[word] \
                                                    for word in seed_for_sample],
                                                in_state, params, cnt_sample)
        tvm_sample_str = _pretty_print(tvm_samples, False, id_to_word)
        tf_samples, tf_state = nnvm.testing.tf.do_tf_sample(sess,
                                [word_to_id[word] for word in seed_for_sample],
                                in_state, cnt_sample)
        tf_sample_str = _pretty_print(tf_samples, False, id_to_word)
        inpt = tvm_sample_str
        np.testing.assert_allclose(tf_samples, tvm_samples, rtol=1e-5, atol=1e-5)
        assert tvm_sample_str == tf_sample_str

#######################################################################
# LRN (Local Response Normalization)
# ----------------------------------

def _test_lrn(ishape, size, axis, bias, alpha, beta):
    """ One iteration of local response normalization """
    lrn_depth_radius = size // 2  # tf expects an integer depth_radius

    inp_array = np.random.uniform(size=ishape).astype(np.float32)

    with tf.Graph().as_default():
        in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype, name="lrn0_data")
        nn_ops.local_response_normalization(in1,
                                            name="lrn",
                                            depth_radius=lrn_depth_radius,
                                            bias=bias,
                                            alpha=alpha,
                                            beta=beta)

        with tf.Session() as sess:
            graph_def = tf.graph_util.convert_variables_to_constants(
                sess,
                sess.graph.as_graph_def(add_shapes=True),
                ['lrn'],)

            tf_output = run_tf_graph(sess, inp_array, 'lrn0_data:0', 'lrn:0')
            tvm_output = run_tvm_graph(graph_def,
                                       inp_array,
                                       "lrn0_data", tf_output.shape, tf_output.dtype)
            np.testing.assert_allclose(tf_output, tvm_output, atol=1e-3, rtol=1e-3)
            sess.close()

def test_forward_lrn():
    _test_lrn((1, 3, 20, 20), 3, 1, 1.0, 1.0, 0.5)

# Main
# ----
if __name__ == '__main__':
    test_forward_convolution()
    test_forward_pooling()
    test_forward_reshape()
    test_forward_squeeze()
    test_forward_sigmoid()
    if tf.__version__ == '1.4.1':
        _test_forward_concat_v2()
    test_forward_multi_input()
    test_forward_inception_v3()
    test_forward_inception_v1()
    test_forward_mobilenet()
    test_forward_variable()
    test_forward_resize_bilinear()
    test_forward_lstm()
    test_forward_stridedslice()
    test_forward_gather()
    test_argmin_argmax()
    test_forward_ptb()
    test_forward_lrn()