# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines
"""ONNX: Open Neural Network Exchange frontend for Relay."""
from __future__ import absolute_import as _abs

import numpy as np
import tvm
from ... import nd as _nd
from .. import analysis
from .. import expr as _expr
from .. import module as _module
from .. import op as _op
from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels
from .common import infer_type, infer_value, infer_value_simulated, get_name

__all__ = ['from_onnx']


def get_numpy(tensor_proto):
    """Grab data in TensorProto and convert to numpy array."""
    try:
        from onnx.numpy_helper import to_array
    except ImportError as e:
        raise ImportError(
            "Unable to import onnx, which is required: {}".format(e))
    return to_array(tensor_proto)


def dimension_picker(prefix, suffix=''):
    """Check that dimensions are supported."""
    def _impl(attr):
        kernel = attr['kernel_shape']
        if len(kernel) == 1:
            return prefix + '1d' + suffix
        if len(kernel) == 2:
            return prefix + '2d' + suffix
        msg = 'Only 1D and 2D kernels are supported for operator {}.'
        op_name = prefix + '1d/2d'
        raise tvm.error.OpAttributeInvalid(msg.format(op_name))

    return _impl


def revert_caffe2_pad(pads):
    """Caffe2 requires two times the normal padding."""
    if len(pads) == 4:
        pads = pads[:2]
    elif len(pads) == 2:
        pass
    else:
        raise tvm.error.OpAttributeInvalid(
            'Number of pads must be either 2 or 4.')
    return pads


def get_pad_pair(input1d, kernel1d, stride1d):
    """infer pad size"""
    if input1d % stride1d == 0:
        pad = max(kernel1d - stride1d, 0)
    else:
        pad = max(kernel1d - (input1d % stride1d), 0)
    pad_before = pad // 2
    pad_after = pad - pad_before
    return [pad_before, pad_after]


def onnx_default_layout(dims):
    if dims == 1:
        return 'NCW'
    if dims == 2:
        return 'NCHW'

    msg = "Only 1d and 2d layouts are currently supported"
    raise tvm.error.OpAttributeInvalid(msg.format(op_name))


def onnx_storage_order2layout(storage_order, dims=2):
    """converter of onnx storage order parameter to tvm storage order format"""
    if storage_order not in (0, 1):
        raise tvm.error.OpAttributeInvalid('Mode of storage_order must be either 0 or 1')

    if dims == 1:
        return 'NCW' if storage_order == 0 else 'NWC'
    if dims == 2:
        return 'NCHW' if storage_order == 0 else 'NHWC'

    msg = "Only 1d and 2d layouts are currently supported"
    raise tvm.error.OpAttributeInvalid(msg.format(op_name))


def dimension_constraint():
    def _dim_check(attrs):
        if len(attrs['kernel_shape']) == 2 or len(attrs['kernel_shape']) == 1:
            return True
        return False

    return _dim_check, "Only 1d and 2d kernel supported."


class OnnxOpConverter(object):
    """ A helper class for holding onnx op converters.
    """

    @classmethod
    def get_converter(cls, opset):
        """ Get converter matches given opset.

        Parameters
        ----------
        opset: int
            opset from model.

        Returns
        -------
        converter, which should be `_impl_vx`, where x is the largest
            supported version that is smaller than or equal to the given opset.
        """
        versions = [
            int(d.replace('_impl_v', '')) for d in dir(cls) if '_impl_v' in d
        ]
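        # Merge the requested opset into the sorted list of implemented versions
        # and pick the largest implemented version not exceeding the request.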
        versions = sorted(versions + [opset])
        version = versions[
            max([i for i, v in enumerate(versions) if v == opset]) - 1]
        if hasattr(cls, '_impl_v{}'.format(version)):
            return getattr(cls, '_impl_v{}'.format(version))
        raise NotImplementedError(
            'opset version {} of {} not implemented'.format(
                version, cls.__name__))


class Unary(OnnxOpConverter):
    """ A helper class for unary op converters.
    """
    name = ''

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 1, "Unary math op {} takes 1 input, {} given".format(
            cls.name, len(inputs))
        op_name = cls.name
        return get_relay_op(op_name)(*inputs)


class Elemwise(OnnxOpConverter):
    """ A helper class for elemwise op converters.
    """
    name = ''

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 2, "Math op {} take 2 inputs, {} given".format(
            cls.name, len(inputs))
        op_name = cls.name
        conv_ops = ["conv2d", "conv2d_transpose"]
        if attr.get('broadcast', 0) and any(x in str(inputs[0]) for x in conv_ops):
            # TODO(zhreshold): remove hard coded infershape
            axis = int(attr.get('axis', 0))
            inputs[1] = _op.expand_dims(inputs[1], axis=axis, num_newaxis=2)
        return get_relay_op(op_name)(*inputs)


class Pool(OnnxOpConverter):
    """ A helper class for pool op converters.
    """
    name = ''

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        input_shape = infer_shape(inputs[0])
        if 'auto_pad' in attr:
            attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
            if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
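                # SAME padding: pad each spatial axis so the output size
                # equals ceil(input / stride).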
                pad_tuple = []
                for axis in range(len(input_shape) - 2):
                    axis_shape = input_shape[2 + axis]
                    stride = attr['strides'][axis]
                    kernel = attr['kernel_shape'][axis]
                    pad = get_pad_pair(axis_shape, kernel, stride)
                    pad_tuple.append(pad)
                pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair])
                attr['pads'] = pad_tuple
            elif attr['auto_pad'] == 'VALID':
                attr['pads'] = 0
            elif attr['auto_pad'] == 'NOTSET':
                pass
            else:
                msg = 'Value {} in attribute "auto_pad" of operator {} is invalid.'
                raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad'], cls.name))
            attr.pop("auto_pad")

        if 'storage_order' in attr:
            attr['layout'] = onnx_storage_order2layout(attr['storage_order'],
                                                       dims=(len(input_shape) - 2))
        else:
            attr['layout'] = onnx_default_layout(dims=(len(input_shape) - 2))

        return AttrCvt(
            op_name=dimension_picker(cls.name),
            transforms={
                'kernel_shape': 'pool_size',
                'pads': ('padding', 0)
            },
            ignores=['dilations'],
            custom_check=dimension_constraint())(inputs, attr, params)


class Absolute(Unary):
    """ Operator converter for Absolute.
    """
    name = 'abs'


class Add(Elemwise):
    """ Operator converter for Add.
    """
    name = 'add'


class AveragePool(Pool):
    """ Operator converter for AveragePool.
    """
    name = 'avg_pool'


class BatchNorm(OnnxOpConverter):
    """ Operator converter for BatchNorm.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # TODO(zhreshold): 'spatial' is not properly handled here.
        out = AttrCvt(
            op_name='batch_norm',
            ignores=['spatial', 'is_test', 'consumed_inputs', 'momentum'])(inputs, attr,
                                                                           params)
        return out[0]


class InstanceNorm(OnnxOpConverter):
    """ Operator converter for BatchNorm.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return AttrCvt(op_name='instance_norm')(inputs, attr, params)


class Conv(OnnxOpConverter):
    """ Operator converter for Conv.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Use shape of input to determine convolution type.
        input_shape = infer_shape(inputs[0])

        if 'auto_pad' in attr:
            attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
            if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
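                # SAME padding: use the effective (dilated) kernel size so the
                # output size equals ceil(input / stride) per spatial axis.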
                pad_tuple = []
                for axis in range(len(input_shape) - 2):
                    axis_shape = input_shape[2 + axis]
                    stride = attr['strides'][axis]
                    kernel = attr['kernel_shape'][axis]
                    dilation = attr['dilations'][axis]
                    dilated_kernel = (kernel - 1) * dilation + 1
                    pad = get_pad_pair(axis_shape, dilated_kernel, stride)
                    pad_tuple.append(pad)
                pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair])
                attr['pads'] = pad_tuple
            elif attr['auto_pad'] == 'VALID':
                attr['pads'] = tuple([0 for i in range(len(input_shape) - 2)])
            elif attr['auto_pad'] == 'NOTSET':
                pass
            else:
                msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.'
                raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad']))
            attr.pop('auto_pad')

        out = AttrCvt(
            op_name=dimension_picker('conv'),
            transforms={
                'kernel_shape': 'kernel_size',
                'dilations': ('dilation', 1),
                'pads': ('padding', 0),
                'group': ('groups', 1)
            },
            custom_check=dimension_constraint())(inputs[:2], attr, params)

        use_bias = len(inputs) == 3
        if use_bias:
            out = _op.nn.bias_add(out, inputs[2])
        return out


class ConvTranspose(OnnxOpConverter):
    """ Operator converter for ConvTranspose.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # get number of channels
        channels = infer_channels(inputs[1], True)
        attr['channels'] = channels
        groups = attr.pop('group')
        attr['groups'] = groups
        # infer pads for auto_pad
        if 'auto_pad' in attr:
            attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
            if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
                input_shape = infer_shape(inputs[0])
                in_h, in_w = input_shape[2], input_shape[3]
                stride_h, stride_w = attr['strides']
                kernel_h, kernel_w = attr['kernel_shape']
                dilation_h, dilation_w = attr['dilations']
                dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
                dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
                pad_v = get_pad_pair(in_h, dilated_kernel_h, stride_h)
                pad_h = get_pad_pair(in_w, dilated_kernel_w, stride_w)
                attr['pads'] = (pad_v[0], pad_h[0], pad_v[1], pad_h[1])
            elif attr['auto_pad'] == 'VALID':
                attr['pads'] = (0, 0)
            elif attr['auto_pad'] == 'NOTSET':
                pass
            else:
                msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.'
                raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad']))
            attr.pop('auto_pad')

        out = AttrCvt(
            op_name=dimension_picker('conv', '_transpose'),
            transforms={
                'kernel_shape': 'kernel_size',
                'dilations': ('dilation', (0, 0)),
                'pads': ('padding', (0, 0), revert_caffe2_pad)
            },
            disables=['output_shape'],
            custom_check=dimension_constraint())(inputs[:2], attr, params)
        use_bias = len(inputs) == 3
        if use_bias:
            out = _op.nn.bias_add(out, inputs[2])
        return out


class Div(Elemwise):
    """ Operator converter for Divide.
    """
    name = 'divide'


class Elu(OnnxOpConverter):
    """ Operator converter for Elu.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))
        return _expr.const(-alpha) * _op.nn.relu(_expr.const(1.) - _op.exp(inputs[0])) + \
                                     _op.nn.relu(inputs[0])


class Gemm(OnnxOpConverter):
    """ Operator converter for Gemm.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 3, "Gemm op take 3 inputs, {} given".format(
            len(inputs))
        # Y = alpha * A * B + beta * C
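        # Lowered here as nn.dense on a flattened (and optionally transposed) A,
        # followed by bias_add with beta * C.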
        alpha = float(attr.get('alpha', 1.0))
        beta = float(attr.get('beta', 1.0))
        transA = int(attr.get('transA', 0))
        transB = int(attr.get('transB', 0))
        # get number of channels
        channels = infer_channels(inputs[1], not transB)
        if transA:
            inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
        if not transB:
            inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
        inputs[0] = _op.nn.batch_flatten(inputs[0])
        out = _op.nn.dense(_expr.const(alpha) * inputs[0],
                           inputs[1], units=channels)
        return _op.nn.bias_add(out, _expr.const(beta) * inputs[2])


class MatMul(OnnxOpConverter):
    """ Operator converter for MatMul.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs))
        # Need to check input shape as batch matmul must be supported.
        a_shape = infer_shape(inputs[0])
        # When performing a batch matmul, we need to properly handle N-dim shapes.
        if len(a_shape) > 2:
            b_shape = infer_shape(inputs[1])
            # Convert a and b into 3 dimensional tensors.
            a = _op.reshape(inputs[0], [-1, a_shape[-2], a_shape[-1]])
            b = _op.reshape(inputs[1], [-1, b_shape[-2], b_shape[-1]])
            # Broadcast b to match batch size of a
            new_b_shape = list(infer_shape(b))
            new_a_shape = infer_shape(a)
            if new_a_shape[0] > new_b_shape[0]:
                new_b_shape[0] = new_a_shape[0]
                b = _op.broadcast_to(b, new_b_shape)
            # Transpose matrix dimensions of b.
            b = _op.transpose(b, [0, 2, 1])
            # Perform a batch matmul.
            output = _op.nn.batch_matmul(a, b)
            # Reshape output to original dimensions.
            return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
        # Otherwise a simple dense op will get the job done.
        input_1_t = _op.transpose(inputs[1], axes=(1, 0))
        return _op.nn.dense(inputs[0], input_1_t)


class MaxPool(Pool):
    """ Operator converter for MaxPool
    """
    name = 'max_pool'


class Mul(Elemwise):
    """ Operator converter for Multiply.
    """
    name = 'multiply'


class Pad(OnnxOpConverter):
    """ Operator converter for Pad.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        pad_width = []
        pads = attr.pop('paddings')
        dims = int(len(pads) / 2)
        for i in range(dims):
            pad_width.append((pads[i], pads[i+dims]))
        attr['pad_width'] = pad_width
        pad_mode = attr.get('mode', 'constant').decode('utf-8')
        if pad_mode in ['constant', 'edge', 'reflect']:
            attr['pad_mode'] = pad_mode
            attr.pop('mode', None)
        else:
            raise tvm.error.OpAttributeInvalid(
                'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.')

        return AttrCvt(
            _op.nn.pad,
            transforms={
                'value': 'pad_value',
            },
            )(inputs, attr, params)

    @classmethod
    def _impl_v2(cls, inputs, attr, params):
        pad_width = []
        pads = attr.pop('pads')
        dims = int(len(pads) / 2)
        for i in range(dims):
            pad_width.append((pads[i], pads[i+dims]))
        attr['pad_width'] = pad_width
        pad_mode = attr.get('mode', 'constant').decode('utf-8')
        if pad_mode in ['constant', 'edge', 'reflect']:
            attr['pad_mode'] = pad_mode
            attr.pop('mode', None)
        else:
            raise tvm.error.OpAttributeInvalid(
                'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.')

        return AttrCvt(
            'pad',
            transforms={
                'value': 'pad_value',
            },
            )(inputs, attr, params)


class ParametricSoftPlus(OnnxOpConverter):
    """ Operator converter for ParametricSoftPlus.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = _expr.const(float(attr.get('alpha', 1.0)))
        beta = _expr.const(float(attr.get('beta', 1.0)))
        return _op.log(_op.exp(beta * inputs[0]) + _expr.const(1.)) * alpha


class Prelu(OnnxOpConverter):
    """ Operator converter for Prelu.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(len(inputs))
        return _op.nn.prelu(inputs[0], inputs[1])


class Reciprocal(OnnxOpConverter):
    """ Operator converter for Reciprocal.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _expr.const(1.0) / inputs[0]


class Flatten(OnnxOpConverter):
    """ Operator converter for Flatten.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        axis = attr.get('axis', 1)
        if axis == 1:
            out = _op.nn.batch_flatten(inputs[0])
        else:
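            # In relay reshape, an entry of 0 copies the corresponding input dimension,
            # so this keeps the first `axis` dims and collapses the rest into -1.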
            newshape = [0] * (axis + 1)
            newshape[axis] = -1
            out = _op.reshape(inputs[0], list(newshape))
        return out


class Reshape(OnnxOpConverter):
    """ Operator converter for Reshape.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _op.reshape(inputs[0], attr['shape'])

    @classmethod
    def _impl_v5(cls, inputs, attr, params):
        if get_name(inputs[1]) in params:
            # pop shape out of parameters since it won't be needed later.
            shape = tuple(params.pop(inputs[1].name_hint).asnumpy())
            out = _op.reshape(inputs[0], shape)
        else:
            data, shape = inputs
            static_shape = infer_value_simulated(shape, params)
            out = _op.reshape(data, newshape=tuple(
                static_shape.asnumpy().astype('int32')))
        return out


class DepthToSpace(OnnxOpConverter):
    """ Operator converter for DepthToSpace.
    """

    @classmethod
    def _impl_v11(cls, inputs, attr, params):

        block_size = int(attr['blocksize'])
        mode = attr.get("mode", "DCR")
        return _op.nn.depth_to_space(inputs[0], block_size, mode=mode)


class SpaceToDepth(OnnxOpConverter):
    """ Operator converter for SpaceToDepth.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):

        block_size = int(attr['blocksize'])
        return _op.nn.space_to_depth(inputs[0], block_size)


class Concat(OnnxOpConverter):
    """ Operator converter for Concat.
    """

    @classmethod
    def _impl_v1(cls, inputs, args, params):
        return AttrCvt(op_name='concatenate')((inputs,), args)

class Scale(OnnxOpConverter):
    """ Operator converter for Scale.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        scale = float(attr.get('scale', 1.0))
        return inputs[0] * _expr.const(scale)


class Selu(OnnxOpConverter):
    """ Operator converter for Selu.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.6732))
        gamma = float(attr.get('gamma', 1.0507))
        return _expr.const(gamma) * (_expr.const(-alpha) *
                                     _op.nn.relu(_expr.const(1.) - _op.exp(inputs[0])) +
                                     _op.nn.relu(inputs[0]))


class ScaledTanh(OnnxOpConverter):
    """ Operator converter for ScaledTanh.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))
        beta = float(attr.get('beta', 1.0))
        return _op.tanh(_expr.const(beta) * inputs[0]) * _expr.const(alpha)


class SoftPlus(OnnxOpConverter):
    """ Operator converter for SoftPlus.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _op.log(_op.exp(inputs[0]) + _expr.const(1.))


class Softsign(OnnxOpConverter):
    """ Operator converter for Softsign.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return inputs[0] / (_expr.const(1.) + Absolute.get_converter(1)(inputs, attr, params))


class Sub(Elemwise):
    """ Operator converter for Subtract.
    """
    name = 'subtract'


class Sum(OnnxOpConverter):
    """ Operator converter for Sum.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # ONNX Sum is an n-ary add; fold the inputs with a chain of binary adds.
        for in_index in range(len(inputs) - 1):
            inputs[in_index + 1] = _op.add(inputs[in_index], inputs[in_index + 1])

        return inputs[len(inputs) - 1]


class ThresholdedRelu(OnnxOpConverter):
    """ Operator converter for ThresholdedRelu.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))  # ONNX default alpha is 1.0
        alpha_tensor = _op.full_like(inputs[0], fill_value=_expr.const(alpha))
        mask = _op.greater(inputs[0], alpha_tensor).astype("float32")
        return inputs[0] * mask


def _broadcast_constraint():

    def _broadcast_check(attrs):
        if attrs.get('axis', None):
            return False
        return True

    return _broadcast_check, "Specifying broadcast axis not allowed."


def _fully_connected(opset):

    def _impl(inputs, attr, params):
        # get number of channels
        channels = infer_channels(inputs[1])
        attr['units'] = channels
        return AttrCvt('dense', ignores=['axis', 'axis_w'])(inputs, attr)

    return _impl


class Upsample(OnnxOpConverter):
    """ Operator converter for Upsample (nearest mode).
    """

    @classmethod
    def _impl_v9(cls, inputs, attr, params):
        scales = attr.get('scales')
        if not scales:
            # From opset 9 onwards, scales arrive as a second input rather than an attribute.
            assert len(inputs) == 2, "Upsample op takes 2 inputs, {} given".format(len(inputs))
            scales = params[inputs[1].name_hint].asnumpy()
            inputs = inputs[:1]
        assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0
        mode = attr.get('mode')
        if mode == b'nearest':
            method = "nearest_neighbor"
        elif mode == b'linear':
            method = "bilinear"
        else:
            raise tvm.error.OpAttributeInvalid(
                'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode))
        attr = {'scale_h': scales[-2], 'scale_w': scales[-1], 'method': method,
                'layout': 'NCHW', 'align_corners': True}
        return AttrCvt('upsampling')(inputs, attr)


class Shape(OnnxOpConverter):
    """ Operator converter for Shape.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _op.shape_of(inputs[0], "int64")

class Cast(OnnxOpConverter):
    """ Operator converter for Cast.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr)

    @classmethod
    def _impl_v5(cls, inputs, attr, params):
        try:
            from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
            attr['to'] = str(TENSOR_TYPE_TO_NP_TYPE[attr['to']])
        except ImportError as e:
            raise ImportError(
                "Unable to import onnx.mapping which is required {}".format(e))
        return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr)


class Unsqueeze(OnnxOpConverter):
    """ Operator converter for Unsqueeze.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        for axes in attr['axes']:
            inputs[0] = _op.expand_dims(inputs[0], axis=axes, num_newaxis=1)
        return inputs[0]


class Split(OnnxOpConverter):
    """ Operator converter for Split.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        splits = attr.get('split', False)
        if splits:
            attr['indices_or_sections'] = []
            index = 0
            for i in splits[:-1]:
                index += i
                attr['indices_or_sections'].append(index)
        # When split isn't specified, divide evenly over the axis.
        else:
            in_shape = infer_shape(inputs[0])
            attr['indices_or_sections'] = in_shape[attr['axis']]
        return AttrCvt(
            'split',
            ignores=['split'])(inputs, attr, params)


class Slice(OnnxOpConverter):
    """ Operator converter for Slice.
    """

    @classmethod
    def _common(cls, starts, ends, axes):
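        """Normalize starts/ends so that every axis up to max(axes) is covered;
        axes that were not listed get the full range [0, INT32_MAX)."""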
        new_axes = []
        new_starts = []
        new_ends = []
        pop_index = 0
        for i in range(max(axes) + 1):
            if i in axes:
                new_axes.append(i)
                new_starts.append(starts[pop_index])
                new_ends.append(ends[pop_index])
                pop_index += 1
            else:
                new_axes.append(i)
                new_starts.append(0)
                new_ends.append(np.iinfo(np.int32).max)
        return new_starts, new_ends, new_axes

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if isinstance(attr['starts'], int):
            attr['starts'] = (attr['starts'],)
            attr['ends'] = (attr['ends'],)

        try:
            # Update the starts and ends according to axes if required.
            if isinstance(attr['axes'], int):
                attr['axes'] = (attr['axes'],)
            if (max(attr['axes']) + 1) != len(attr['axes']):
                new_starts, new_ends, new_axes = cls._common(
                    attr['starts'], attr['ends'], attr['axes'])
                attr['axes'] = new_axes
                attr['starts'] = new_starts
                attr['ends'] = new_ends
        except KeyError:
            pass

        return AttrCvt('strided_slice',
                       transforms={'starts': 'begin',
                                   'ends': 'end'},
                       ignores=['axes'])(inputs, attr)

    @classmethod
    def _impl_v10(cls, inputs, attr, params):
        starts = params[get_name(inputs[1])].asnumpy()
        ends = params[get_name(inputs[2])].asnumpy()

        # Update the starts and ends according to axes if required.
        if len(inputs) >= 4:
            axes = params[get_name(inputs[3])].asnumpy()

            if max(axes + 1) != len(axes):
                new_starts, new_ends, _ = cls._common(
                    starts, ends, axes)
                starts = new_starts
                ends = new_ends
        return _op.strided_slice(inputs[0], begin=starts, end=ends)


class Gather(OnnxOpConverter):
    """ Operator converter for Gather.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        axis = attr.get('axis', 0)
        return AttrCvt('take',
                       extras={'axis': axis})(inputs, {})


class Greater(OnnxOpConverter):
    """ Operator logical greater.
    """
    @classmethod
    def _impl_v7(cls, inputs, attr, params):
        return _op.greater(inputs[0], inputs[1])


class Less(OnnxOpConverter):
    """ Operator logical less than.
    """
    @classmethod
    def _impl_v7(cls, inputs, attr, params):
        return _op.less(inputs[0], inputs[1])


class LRN(OnnxOpConverter):
    """ Operator converter for Local Response Normalization.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        """LRN support only NCHW format
        https://github.com/onnx/onnx/blob/master/docs/Operators.md#LRN
        """
        axis = 1
        alpha = attr.get('alpha', 0.0001)
        beta = attr.get('beta', 0.75)
        bias = attr.get('bias', 1.0)
        nsize = attr.get('size')
        attr = {'size': nsize, 'axis': axis, 'alpha': alpha, 'beta': beta, 'bias': bias}
        return AttrCvt('lrn')(inputs, attr)

class Maximum(OnnxOpConverter):
    """ Operator converter for Maximum.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if not isinstance(inputs, list) or len(inputs) < 2:
            raise ValueError("Expect minimum 2 inputs")
        _max = inputs[0]
        for i in range(1, len(inputs)):
            _max = AttrCvt('maximum')([_max, inputs[i]], {})
        return _max

class Minimum(OnnxOpConverter):
    """ Operator converter for Minimum.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if not isinstance(inputs, list) or len(inputs) < 2:
            raise ValueError("Expect minimum 2 inputs")
        _min = inputs[0]
        for i in range(1, len(inputs)):
            _min = AttrCvt('minimum')([_min, inputs[i]], {})
        return _min

class Mean(OnnxOpConverter):
    """ Operator converter for Mean.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if not isinstance(inputs, list) or len(inputs) < 2:
            raise ValueError("Expect minimum 2 inputs")
        # avoid overflow
        concat = _op.concatenate([_op.expand_dims(x, axis=0) for x in inputs], axis=0)
        return _op.mean(concat, axis=0, keepdims=False)

class HardSigmoid(OnnxOpConverter):
    """ Operator converter for HardSigmoid.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
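        # HardSigmoid(x) = clip(alpha * x + beta, 0, 1), per the ONNX definition.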
        alpha = attr.get('alpha', 0.2)
        beta = attr.get('beta', 0.5)
        transformX = (inputs[0] * _expr.const(alpha)) + _expr.const(beta)
        attr = {'a_min': 0, 'a_max': 1}
        return AttrCvt('clip')([transformX], attr)

class Reduce(OnnxOpConverter):
    """ Operator converter for reduce ops.
    """
    name = ''
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if 'axes' in attr:
            axis = attr.get('axes', 0)
        else:
            axis_len = len(infer_shape(inputs[0]))
            axis = list(range(axis_len))
        attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)}
        return AttrCvt(cls.name)(inputs, attr)

class ReduceMax(Reduce):
    """ Operator converter for ReduceMax.
    """
    name = 'max'

class ReduceMin(Reduce):
    """ Operator converter for ReduceMin.
    """
    name = 'min'

class ReduceSum(Reduce):
    """ Operator converter for ReduceSum.
    """
    name = 'sum'

class ReduceMean(Reduce):
    """ Operator converter for ReduceMean.
    """
    name = 'mean'

class ReduceProd(Reduce):
    """ Operator converter for ReduceProd.
    """
    name = 'prod'

class ArgMax(OnnxOpConverter):
    """ Operator converter for ArgMax.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        axis = attr.get('axis', 0)
        keepdims = attr.get('keepdims', True)
        attr = {'axis': axis, 'keepdims': keepdims}
        return AttrCvt('argmax')(inputs, attr)

class ArgMin(OnnxOpConverter):
    """ Operator converter for ArgMin.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        axis = attr.get('axis', 0)
        keepdims = attr.get('keepdims', True)
        attr = {'axis': axis, 'keepdims': keepdims}
        return AttrCvt('argmin')(inputs, attr)

class Softmax(OnnxOpConverter):
    """ Operator converter for Softmax.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # set default value when axis is not set in the model
        if 'axis' not in attr:
            attr['axis'] = 1
        return AttrCvt('softmax', transforms={'axis': ('axis', 1)})(inputs, attr, params)


class OneHot(OnnxOpConverter):
    """ Operator converter for OneHot.
    """
    @classmethod
    def _impl_v9(cls, inputs, attr, params):
        # Extract relay one_hot inputs.
        indices, depth, values = inputs
        # Split onnx on off values into two separate expressions.
        off_value, on_value = _op.take(
            values, _op.const(0)), _op.take(values, _op.const(1))
        # Extract the datatype of the output from on_value.
        dtype = infer_type(on_value).checked_type.dtype
        # Convert depth into an integer.
        depth = int(infer_value(depth, params).asnumpy()[0])
        # set default value when axis is not set in the model
        if 'axis' not in attr:
            attr['axis'] = -1
        return _op.one_hot(indices,
                           on_value,
                           off_value,
                           depth,
                           int(attr['axis']),
                           dtype=dtype)


class ConstantOfShape(OnnxOpConverter):
    """ Operator converter for ConstantOfShape.
    """
    @classmethod
    def _impl_v9(cls, inputs, attr, params):
        if 'value' in attr:
            np_value = get_numpy(attr.pop('value'))[0]
            value = _expr.const(np_value)
            dtype = np_value.dtype.name
        else:
            value = _expr.const(0)
            dtype = 'float32'
        static_shape = infer_value_simulated(inputs[0], params)
        output = _op.full(
            value, shape=tuple(static_shape.asnumpy().astype('int32')), dtype=dtype)
        return output


class Sign(OnnxOpConverter):
    """ Operator converter for Sign.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _op.sign(inputs[0])

class Equal(Elemwise):
    """ Operator converter for Equal.
    """
    name = 'equal'


class Not(Elemwise):
    """ Operator converter for Not.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _op.logical_not(inputs[0])


class And(Elemwise):
    """ Operator converter for And.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _op.logical_and(inputs[0], inputs[1])


class Tile(Elemwise):
    """Operator converter for Tile
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if 'repeats' not in attr:
            raise tvm.error.OpAttributeInvalid('Attribute "repeats" should be set '
                                               'for operator Tile.')
        reps = attr.pop('repeats')  # The number of times repeating the tensor data.
        return _op.tile(inputs[0], reps)

    @classmethod
    def _impl_v6(cls, inputs, attr, params):
        reps = tuple(infer_value_simulated(
            inputs[1], params).asnumpy().astype('int32'))
        return _op.tile(inputs[0], reps)

class Erf(OnnxOpConverter):
    """Operator converter for Erf
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _op.erf(inputs[0])

class Where(OnnxOpConverter):
    """Operator converter for Where
    """
    @classmethod
    def _impl_v9(cls, inputs, attr, params):
        # x and y can be broadcasted
        condition_shape = infer_shape(inputs[0])
        x_shape = infer_shape(inputs[1])
        y_shape = infer_shape(inputs[2])
        if len(condition_shape) > len(x_shape):
            inputs[1] = _op.broadcast_to(inputs[1], condition_shape)
        if len(condition_shape) > len(y_shape):
            inputs[2] = _op.broadcast_to(inputs[2], condition_shape)
        return _op.where(inputs[0], inputs[1], inputs[2])

class Or(Elemwise):
    """ Operator converter for Or.
    """
    @classmethod
    def _impl_v7(cls, inputs, attr, params):
        return _op.logical_or(inputs[0], inputs[1])


class Expand(OnnxOpConverter):
    """ Operator converter for Expand.
    """
    @classmethod
    def _impl_v8(cls, inputs, attr, params):
        in_shape = np.array(infer_shape(inputs[0])).astype('int32')
        if get_name(inputs[1]) in params:
            shape = params[inputs[1].name_hint].asnumpy().astype('int32')
        else:
            shape = infer_value_simulated(inputs[1], params).asnumpy().astype('int32')

        # Currently 'op.broadcast_to' expects the rank of the given 'shape'
        # (the 2nd input) to be at least the rank of the given 'input' (the 1st input).
        # However, ONNX Expand supports multi-directional broadcasting, which allows
        # the above pattern, and an extent of 'shape' may also be smaller than the
        # corresponding extent of 'input' (in that case the extent of 'shape' must be 1).
        # https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
        # In those cases we cannot directly apply 'op.broadcast_to',
        # so here we solve the problem by expanding the given 'shape' itself.
        def expand_shape(in_shape, shape):
            """ A function expands the shape when the rank is lower than that of the given
            intput. Also it replaces the extent of the shape with the corresponding extent
            of the intput when it is 1.
            """

            # here we flip the shapes because this can be more simply written
            # when the innermost dimension is located at the index 0.
            in_shape = np.flip(in_shape, axis=0)
            shape = np.flip(shape, axis=0)

            if in_shape.size < shape.size:
                for i in range(shape.size):
                    if i < in_shape.size and in_shape[i] > shape[i]:
                        shape[i] = in_shape[i]
            else:
                for i in range(in_shape.size):
                    if i >= shape.size:
                        shape = np.append(shape, in_shape[i])
                    elif shape[i] == 1:
                        shape[i] = in_shape[i]

            new_shape = np.flip(shape, axis=0)
            return new_shape

        shape = expand_shape(in_shape, shape)
        return _op.broadcast_to(inputs[0], shape=tuple(shape))


# compatible operators that do NOT require any conversion.
_identity_list = []


# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
def _get_convert_map(opset):
    return {
        # defs/experimental
        'Identity': Renamer('copy'),
        # 'Affine'
        'ThresholdedRelu': ThresholdedRelu.get_converter(opset),
        'ScaledTanh': ScaledTanh.get_converter(opset),
        'ParametricSoftplus': ParametricSoftPlus.get_converter(opset),
        'ConstantOfShape': ConstantOfShape.get_converter(opset),
        # 'GivenTensorFill'
        'FC': AttrCvt('dense', ignores=['axis', 'axis_w']),
        'Scale': Scale.get_converter(opset),
        # 'GRUUnit'
        # 'ATen'
        # 'ImageScaler'
        # 'MeanVarianceNormalization'
        # 'Crop'
        # 'Embedding'
        'Upsample': Upsample.get_converter(opset),
        'SpatialBN': BatchNorm.get_converter(opset),

        # defs/generator
        # 'Constant' # Implemented
        # 'RandomUniform'
        # 'RandomNormal'
        # 'RandomUniformLike'
        # 'RandomNormalLike'

        # defs/logical

        # defs/math
        'Add': Add.get_converter(opset),
        'Sub': Sub.get_converter(opset),
        'Mul': Mul.get_converter(opset),
        'Div': Div.get_converter(opset),
        'Neg': Renamer('negative'),
        'Abs': Absolute.get_converter(opset),
        'Reciprocal': Reciprocal.get_converter(opset),
        'Floor': Renamer('floor'),
        'Ceil': Renamer('ceil'),
        'Sqrt': Renamer('sqrt'),
        'Relu': Renamer('relu'),
        'LeakyRelu': Renamer('leaky_relu'),
        'Selu': Selu.get_converter(opset),
        'Elu': Elu.get_converter(opset),
        'Exp': Renamer('exp'),
        'Greater': Greater.get_converter(opset),
        'Less': Less.get_converter(opset),
        'Log': Renamer('log'),
        'Tanh': Renamer('tanh'),
        'Pow': Renamer('power'),
        'PRelu': Prelu.get_converter(opset),
        'Sigmoid': Renamer('sigmoid'),
        'HardSigmoid': HardSigmoid.get_converter(opset),
        'Max': Maximum.get_converter(opset),
        'Min': Minimum.get_converter(opset),
        'Sum': Sum.get_converter(opset),
        'Mean': Mean.get_converter(opset),
        'Clip': AttrCvt('clip', transforms={'min': 'a_min', 'max': 'a_max'}),
        # softmax default axis is different in onnx
        'Softmax': Softmax.get_converter(opset),
        'LogSoftmax': AttrCvt('log_softmax', {'axis': ('axis', 1)}),
        'OneHot': OneHot.get_converter(opset),
        # 'Hardmax'
        'Softsign': Softsign.get_converter(opset),
        'SoftPlus': SoftPlus.get_converter(opset),
        'Gemm': Gemm.get_converter(opset),
        'MatMul': MatMul.get_converter(opset),

        # defs/nn
        'AveragePool': AveragePool.get_converter(opset),
        'MaxPool': MaxPool.get_converter(opset),
        'Conv': Conv.get_converter(opset),
        'ConvTranspose': ConvTranspose.get_converter(opset),
        'GlobalAveragePool': Renamer('global_avg_pool2d'),
        'GlobalMaxPool': Renamer('global_max_pool2d'),
        'BatchNormalization': BatchNorm.get_converter(opset),
        'InstanceNormalization': InstanceNorm.get_converter(opset),
        # 'LpNormalization'
        'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
        'Flatten': Flatten.get_converter(opset),
        'LRN': LRN.get_converter(opset),

        # defs/reduction
        'ReduceMax': ReduceMax.get_converter(opset),
        'ReduceMin': ReduceMin.get_converter(opset),
        'ReduceSum': ReduceSum.get_converter(opset),
        'ReduceMean': ReduceMean.get_converter(opset),
        'ReduceProd': ReduceProd.get_converter(opset),
        # 'ReduceLogSumExp'
        'ArgMax': ArgMax.get_converter(opset),
        'ArgMin': ArgMin.get_converter(opset),

        # defs/tensor
        'Cast': Cast.get_converter(opset),
        'Reshape': Reshape.get_converter(opset),
        'Expand': Expand.get_converter(opset),
        'Concat': Concat.get_converter(opset),
        'Split': Split.get_converter(opset),
        'Slice': Slice.get_converter(opset),
        'Transpose': AttrCvt('transpose', {'perm': 'axes'}),
        'DepthToSpace': DepthToSpace.get_converter(opset),
        'SpaceToDepth': SpaceToDepth.get_converter(opset),
        'Gather': Gather.get_converter(opset),
        'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
        'Unsqueeze': Unsqueeze.get_converter(opset),
        'Pad': Pad.get_converter(opset),
        'Shape': Shape.get_converter(opset),
        'Sign': Sign.get_converter(opset),
        'Equal': Equal.get_converter(opset),
        'Not': Not.get_converter(opset),
        'And': And.get_converter(opset),
        'Tile': Tile.get_converter(opset),
        'Erf': Erf.get_converter(opset),
        'Where': Where.get_converter(opset),
        'Or': Or.get_converter(opset),
    }


class GraphProto(object):
    """A helper class for handling Relay expression copying from pb2.GraphProto.
    Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto

    Parameters
    ----------
    shape : dict of str to tuple, optional
        The input shape to the graph

    dtype : str or dict of str to str
        The input types to the graph
    """

    def __init__(self, shape, dtype):
        self._nodes = {}
        self._params = {}
        self._renames = {}
        self._num_input = 0
        self._num_param = 0
        self._shape = shape if shape else {}
        self._dtype = dtype

    def from_onnx(self, graph, opset):
        """Construct Relay expression from ONNX graph.

        Onnx graph is a python protobuf object.
        The companion parameters will be handled automatically.
        However, the input names of the onnx graph are vague, mixing inputs and
        network weights/bias such as "1", "2"...
        For convenience, we rename the `real` input names to "input_0",
        "input_1"... and rename parameters to "param_0", "param_1"...

        Parameters
        ----------
        graph : onnx protobuf object
            The loaded onnx graph

        opset : int
            The opset version.

        Returns
        -------
        mod : tvm.relay.Module
            The returned relay module

        params : dict
            A dict of name: tvm.nd.array pairs, used as pretrained weights
        """
        # parse network inputs to relay, aka parameters
        for init_tensor in graph.initializer:
            if not init_tensor.name.strip():
                raise ValueError("Tensor's name is required.")
            self._params[init_tensor.name] = self._parse_array(init_tensor)
            self._nodes[init_tensor.name] = new_var(init_tensor.name,
                                                    shape=self._params[init_tensor.name].shape,
                                                    dtype=self._params[init_tensor.name].dtype)
        for i in graph.input:
            # from onnx v0.2, GraphProto.input has type ValueInfoProto,
            #  and the name is 'i.name'
            i_name = self._parse_value_proto(i)
            d_type = self._parse_dtype(i, 'float32')
            if i_name in self._params:
                # i is a param instead of input
                self._num_param += 1
                self._params[i_name] = self._params.pop(i_name)
                self._nodes[i_name] = new_var(i_name,
                                              shape=self._params[i_name].shape,
                                              dtype=self._params[i_name].dtype)
            else:
                self._num_input += 1
                if i_name in self._shape:
                    tshape = self._shape[i_name]
                else:
                    raise ValueError("Must provide an input shape for `{0}`.".format(i_name))
                if isinstance(self._dtype, dict):
                    dtype = self._dtype[i_name] if i_name in self._dtype else d_type
                else:
                    dtype = d_type
                self._nodes[i_name] = new_var(i_name, shape=tshape, dtype=dtype)
        # get list of unsupported ops
        convert_map = _get_convert_map(opset)
        unsupported_ops = set()
        for node in graph.node:
            op_name = node.op_type
            if op_name not in convert_map and \
               op_name != 'Constant' and \
               op_name not in _identity_list:
                unsupported_ops.add(op_name)
        if unsupported_ops:
            msg = 'The following operators are not supported for frontend ONNX: '
            msg += ', '.join(unsupported_ops)
            raise tvm.error.OpNotImplemented(msg)
        # construct nodes, nodes are stored as directed acyclic graph
        for node in graph.node:
            op_name = node.op_type
            attr = self._parse_attr(node.attribute)
            inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
            if op_name == "Constant":
                t_proto = self._parse_attr(node.attribute)["value"]
                self._num_param += 1
                # We should convert scalar integers to int32, to normalize.
                array = self._parse_array(t_proto)
                self._params[node.output[0]] = array
                self._nodes[node.output[0]] = new_var(
                    node.output[0],
                    shape=list(t_proto.dims),
                    dtype=array.dtype)
            else:
                i_name = self._parse_value_proto(node)
                attr['tvm_custom'] = {}
                attr['tvm_custom']['name'] = i_name

                op = self._convert_operator(op_name, inputs, attr, opset)
                node_output = self._fix_outputs(op_name, node.output)
                if not isinstance(op, _expr.TupleWrapper):
                    outputs_num = 1
                else:
                    outputs_num = len(op)
                assert len(node_output) == outputs_num, (
                    "Number of output mismatch {} vs {} in {}.".format(
                        len(node_output), outputs_num, op_name))
                if outputs_num == 1:
                    self._nodes[node_output[0]] = op
                else:
                    for k, i in zip(list(node_output), range(len(node_output))):
                        self._nodes[k] = op[i]

        # now return the outputs
        outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
        outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
        func = _expr.Function(analysis.free_vars(outputs), outputs)
        return _module.Module.from_expr(func), self._params

    def _parse_value_proto(self, value_proto):
        """Parse ValueProto or raw str."""
        try:
            name = value_proto.name
        except AttributeError:
            name = value_proto
        return name

    def _parse_dtype(self, value_proto, dtype):
        """Parse dtype."""
        try:
            from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
            return TENSOR_TYPE_TO_NP_TYPE[value_proto.type.tensor_type.elem_type].name
        except AttributeError:
            return dtype

    def _parse_array(self, tensor_proto):
        np_array = get_numpy(tensor_proto).reshape(tuple(tensor_proto.dims))
        return _nd.array(np_array)

    def _parse_attr(self, attr_proto):
        """Convert a list of AttributeProto to a dict, with names as keys."""
        attrs = {}
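        # Each AttributeProto carries its value in exactly one typed field
        # (f, i, s, t, g or their repeated forms); probe the candidates in turn.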
        for a in attr_proto:
            for f in ['f', 'i', 's']:
                if a.HasField(f):
                    attrs[a.name] = getattr(a, f)
            for f in ['floats', 'ints', 'strings']:
                if list(getattr(a, f)):
                    assert a.name not in attrs, "Only one type of attr is allowed"
                    attrs[a.name] = tuple(getattr(a, f))
            for f in ['t']:
                if a.HasField(f):
                    attrs[a.name] = getattr(a, f)
            for f in ['tensors']:
                if list(getattr(a, f)):
                    assert a.name not in attrs, "Only one type of attr is allowed"
                    attrs[a.name] = tuple(getattr(a, f))
            for f in ['g']:
                if a.HasField(f):
                    raise NotImplementedError(
                        "Filed {} is not supported in relay.".format(f))
            for f in ['graphs']:
                if list(getattr(a, f)):
                    raise NotImplementedError(
                        "Filed {} is not supported in relay.".format(f))
            if a.name not in attrs:
                raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
        return attrs

    def _convert_operator(self,
                          op_name,
                          inputs,
                          attrs,
                          opset):
        """Convert ONNX operator into a Relay operator.
        The converter must specify conversions explicitly for incompatible names, and
        apply handlers to operator attributes.

        Parameters
        ----------
        op_name : str
            Operator name, such as Convolution, FullyConnected
        inputs : list of tvm.relay.expr.Function
            List of inputs.
        attrs : dict
            Dict of operator attributes
        opset : int
            Opset version

        Returns
        -------
        sym : tvm.relay.expr.Function
            Converted relay function
        """
        convert_map = _get_convert_map(opset)
        if op_name in _identity_list:
            sym = get_relay_op(op_name)(*inputs, **attrs)
        elif op_name in convert_map:
            sym = convert_map[op_name](inputs, attrs, self._params)
        else:
            raise NotImplementedError(
                "Operator {} not implemented.".format(op_name))
        return sym

    def _fix_outputs(self, op_name, outputs):
        """A hack to handle dropout or similar operator that have more than one out
        in ONNX.
        """
        if op_name == 'Dropout':
            if len(outputs) == 1:
                return outputs
            # TODO(zhreshold): support dropout mask?
            outputs = outputs[:-1]
        return outputs

def from_onnx(model,
              shape=None,
              dtype="float32",
              opset=None):
    """Convert a ONNX model into an equivalent Relay Function.

    ONNX graphs are represented as Python Protobuf objects.
    The companion parameters will be handled automatically.
    However, the input names of the onnx graph are vague, mixing inputs and
    network weights/bias such as "1", "2"...
    For convenience, we rename the `real` input names to "input_0",
    "input_1"... and rename parameters to "param_0", "param_1"...

    Parameters
    ----------
    model : protobuf object
        ONNX ModelProto after ONNX v1.1.0

    shape : dict of str to tuple, optional
        The input shape to the graph

    dtype : str or dict of str to str
        The input types to the graph

    opset : int, optional
        Override the autodetected opset.
        This can be helpful for some testing.

    Returns
    -------
    mod : tvm.relay.Module
        The relay module for compilation

    params : dict of str to tvm.NDArray
        The parameter dict to be used by relay
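
    Examples
    --------
    A minimal usage sketch; the model file name and the input name/shape below
    are placeholders for those of your own graph::

        import onnx
        from tvm import relay

        onnx_model = onnx.load('model.onnx')
        shape_dict = {'input_0': (1, 3, 224, 224)}
        mod, params = relay.frontend.from_onnx(onnx_model, shape=shape_dict)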
    """
    try:
        import onnx
        if hasattr(onnx.checker, 'check_model'):
            # try use onnx's own model checker before converting any model
            try:
                onnx.checker.check_model(model)
            except onnx.onnx_cpp2py_export.checker.ValidationError as e:
                import warnings
                # the checker can be overly strict about errors, so just emit a warning here
                warnings.warn(str(e))
    except ImportError:
        pass
    g = GraphProto(shape, dtype)
    graph = model.graph
    if opset is None:
        try:
            opset = model.opset_import[0].version if model.opset_import else 1
        except AttributeError:
            opset = 1
    mod, params = g.from_onnx(graph, opset)
    return mod, params