onnx.py 63.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17
# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines
18
# pylint: disable=import-outside-toplevel
19 20
"""ONNX: Open Neural Network Exchange frontend for Relay."""
import numpy as np
Zhi committed
21
import tvm
22 23
from tvm.ir import IRModule

24
from ... import nd as _nd
Zhi committed
25
from .. import analysis
26
from .. import expr as _expr
Zhi committed
27
from .. import function as _function
28 29
from .. import op as _op
from .common import AttrCvt, Renamer
30 31
from .common import get_relay_op, new_var, infer_shape, infer_channels
from .common import infer_type, infer_value, infer_value_simulated, get_name
32 33 34

# Public names exported by ``from <this module> import *``.
__all__ = ['from_onnx']

35

36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
class onnx_input():
    """ Dual purpose list or dictionary access object.

    Values are stored by name in ``input_dict`` while ``input_keys`` records
    insertion order, so an entry can be read by string name, by integer
    position, or by a slice of positions (which returns a list of values).
    """

    def __init__(self):
        self.input_keys = []
        self.input_dict = {}

    def __getitem__(self, item):
        """Look up by position (int), name (str) or positions (slice).

        An unknown string name returns ``None`` instead of raising; an
        out-of-range integer position raises ``IndexError``.
        """
        if isinstance(item, int):
            return self.input_dict[self.input_keys[item]]
        if isinstance(item, str):
            # __setitem__ keeps input_dict and input_keys in sync, so the O(1)
            # dict membership test is equivalent to scanning the key list.
            if item not in self.input_dict:
                return None
            return self.input_dict[item]
        if isinstance(item, slice):
            keys = self.input_keys[item]
            return [self.input_dict[key] for key in keys]

        raise ValueError("Only integer, string, and slice accesses allowed.")

    def __setitem__(self, item, value):
        """Overwrite by position (int) or insert/overwrite by name (str)."""
        if isinstance(item, int):
            self.input_dict[self.input_keys[item]] = value
        elif isinstance(item, str):
            if item not in self.input_dict:
                self.input_keys.append(item)
            self.input_dict[item] = value
        else:
            raise ValueError("Only integer and string indexed writes allowed.")

    def keys(self):
        """Return the list of names in insertion order."""
        return self.input_keys

    def __len__(self):
        return len(self.input_keys)

    def __iter__(self):
        """Iterate over stored values in insertion order.

        Each call returns a fresh, independent iterator, so nested or repeated
        iteration over the same object works. The legacy cursor ``self.n`` is
        still reset for code that calls ``__next__`` directly.
        """
        self.n = 0
        return (self.input_dict[key] for key in self.input_keys)

    def __next__(self):
        # Legacy cursor-based protocol, kept for backward compatibility.
        if self.n < len(self.input_keys):
            output = self.input_dict[self.input_keys[self.n]]
            self.n += 1
            return output

        raise StopIteration


85 86 87 88 89 90 91 92 93 94
def get_numpy(tensor_proto):
    """Grab data in TensorProto and convert to numpy array.

    Raises
    ------
    ImportError
        If ``onnx`` is not installed; the original import error is chained as
        the ``__cause__``.
    """
    try:
        from onnx.numpy_helper import to_array
    except ImportError as e:
        raise ImportError(
            "Unable to import onnx which is required {}".format(e)) from e
    return to_array(tensor_proto)


95
def dimension_picker(prefix, suffix=''):
    """Check that dimensions are supported and build the matching op name.

    Parameters
    ----------
    prefix : str
        Base relay op name, e.g. ``'conv'`` or ``'max_pool'``.
    suffix : str
        Text appended after the dimensionality, e.g. ``'_transpose'``.

    Returns
    -------
    callable
        ``_impl(attr)`` returning ``prefix + '<n>d' + suffix`` where ``n`` is
        ``len(attr['kernel_shape'])``. It raises ``tvm.error.OpAttributeInvalid``
        when ``n`` is not 1, 2 or 3.
    """
    def _impl(attr):
        ndim = len(attr['kernel_shape'])
        if ndim in (1, 2, 3):
            return '{}{}d{}'.format(prefix, ndim, suffix)
        msg = 'Only 1D, 2D, and 3D kernels are supported for operator {}.'
        op_name = prefix + '1d/2d/3d' + suffix
        raise tvm.error.OpAttributeInvalid(msg.format(op_name))

    return _impl

111

112 113 114 115 116 117 118
def revert_caffe2_pad(pads):
    """Caffe2 requires two times the normal padding.

    A 4-element ``pads`` is truncated to its first two entries; a 2-element
    ``pads`` is returned unchanged.

    Raises
    ------
    tvm.error.OpAttributeInvalid
        If ``pads`` has any other length.
    """
    if len(pads) == 4:
        return pads[:2]
    if len(pads) == 2:
        return pads
    raise tvm.error.OpAttributeInvalid(
        'Number of pads must be either 2 or 4.')

123

124 125 126 127 128 129 130 131 132 133 134
def get_pad_pair(input1d, kernel1d, stride1d):
    """Infer the padding for one spatial axis.

    Returns ``[pad_before, pad_after]``. When the total padding is odd, the
    extra element goes after.
    """
    remainder = input1d % stride1d
    covered = stride1d if remainder == 0 else remainder
    total = max(kernel1d - covered, 0)
    before = total // 2
    return [before, total - before]


135 136 137 138 139 140 141 142 143 144
def onnx_default_layout(dims):
    """Return the default channels-first data layout for ``dims`` spatial dims.

    Raises
    ------
    tvm.error.OpAttributeInvalid
        If ``dims`` is not 1, 2 or 3.
    """
    layouts = {1: 'NCW', 2: 'NCHW', 3: 'NCDHW'}
    if dims in layouts:
        return layouts[dims]
    # The previous message was formatted with an undefined name, so this path
    # raised NameError instead of the intended error.
    msg = "Only 1d, 2d and 3d layouts are currently supported"
    raise tvm.error.OpAttributeInvalid(msg)


145
def onnx_storage_order2layout(storage_order, dims=2):
    """converter of onnx storage order parameter to tvm storage order format

    ``storage_order`` 0 selects the channels-first layout and 1 selects the
    channels-last layout for ``dims`` spatial dimensions (1, 2 or 3).

    Raises
    ------
    tvm.error.OpAttributeInvalid
        If ``storage_order`` is not 0/1 or ``dims`` is not 1, 2 or 3.
    """
    if storage_order not in (0, 1):
        raise tvm.error.OpAttributeInvalid('Mode of storage_order must be either 0 or 1')

    layouts = {1: ('NCW', 'NWC'), 2: ('NCHW', 'NHWC'), 3: ('NCDHW', 'NDHWC')}
    if dims in layouts:
        return layouts[dims][storage_order]
    # The previous message was formatted with an undefined name, so this path
    # raised NameError instead of the intended error.
    msg = "Only 1d, 2d and 3d layouts are currently supported"
    raise tvm.error.OpAttributeInvalid(msg)
157 158


159 160
def dimension_constraint():
    """Return an ``AttrCvt`` custom check that accepts only 1D/2D/3D kernels.

    Returns
    -------
    tuple
        ``(check, message)``, where ``check(attrs)`` is True iff
        ``attrs['kernel_shape']`` has 1, 2 or 3 entries.
    """
    def _dim_check(attrs):
        return len(attrs['kernel_shape']) in [1, 2, 3]

    return _dim_check, "Only 1d, 2d and 3d kernel supported."
166

167

168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198
class OnnxOpConverter(object):
    """ A helper class for holding onnx op converters.

    Subclasses define one classmethod ``_impl_v<N>`` per implemented opset
    version ``N``; :meth:`get_converter` picks one of them for a model opset.
    """

    @classmethod
    def get_converter(cls, opset):
        """ Get converter matches given opset.

        Parameters
        ----------
        opset: int
            opset from model.

        Returns
        -------
        converter, which should be `_impl_vx`. Number x is the biggest
            number smaller than or equal to opset belongs to all support versions.
            If every implemented version is newer than ``opset``, the newest
            implementation is returned instead (wrap-around fallback).

        Raises
        ------
        NotImplementedError
            If no ``_impl_v<N>`` method exists for the selected version.
        """
        prefix = '_impl_v'
        implemented = sorted(
            int(attr_name.replace(prefix, '')) for attr_name in dir(cls) if prefix in attr_name)
        not_newer = [ver for ver in implemented if ver <= opset]
        if not_newer:
            version = not_newer[-1]
        elif implemented:
            # Nothing is <= opset: fall back to the newest implementation.
            version = implemented[-1]
        else:
            version = opset
        method_name = '{}{}'.format(prefix, version)
        if not hasattr(cls, method_name):
            raise NotImplementedError(
                'opset version {} of {} not implemented'.format(
                    version, cls.__name__))
        return getattr(cls, method_name)


199 200 201 202 203 204 205 206 207 208 209 210 211
class Unary(OnnxOpConverter):
    """ A helper class for unary op converters.

    Subclasses set ``name`` to the relay op applied to the single input.
    """
    name = ''

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        count = len(inputs)
        assert count == 1, "Unary math op {} takes 1 input, {} given".format(cls.name, count)
        relay_op = get_relay_op(cls.name)
        return relay_op(*inputs)


212 213 214 215 216 217 218
class Elemwise(OnnxOpConverter):
    """ A helper class for elemwise op converters.

    Subclasses set ``name`` to the binary relay op applied to the two inputs.
    """
    name = ''

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 2, "Math op {} take 2 inputs, {} given".format(
            cls.name, len(inputs))
        op_name = cls.name
        conv_ops = ["conv2d", "conv2d_transpose"]
        # With the 'broadcast' attribute set and a first operand whose printed
        # expression mentions a conv op, two new axes are inserted into the
        # second operand at position ``axis``. Presumably this lets a per-channel
        # operand broadcast over the conv output's spatial dims.
        if attr.get('broadcast', 0) and any(x in str(inputs[0]) for x in conv_ops):
            # TODO(zhreshold): remove hard coded infershape
            axis = int(attr.get('axis', 0))
            inputs[1] = _op.expand_dims(inputs[1], axis=axis, num_newaxis=2)
        return get_relay_op(op_name)(*inputs)

229

230 231 232 233 234 235 236
class Pool(OnnxOpConverter):
    """ A helper class for pool op converters.

    Subclasses set ``name`` to the relay op prefix (e.g. ``'max_pool'``). The
    1d/2d/3d variant is chosen from ``kernel_shape``, ``auto_pad`` is resolved
    into explicit ``pads``, and ``storage_order`` is folded into ``layout``.
    """
    name = ''

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        input_shape = infer_shape(inputs[0])
        num_spatial = len(input_shape) - 2
        if 'auto_pad' in attr:
            attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
            if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
                # 'strides' is optional in ONNX (default 1 along each axis).
                strides = attr.get('strides', [1] * num_spatial)
                pad_pairs = [get_pad_pair(input_shape[2 + axis],
                                          attr['kernel_shape'][axis],
                                          strides[axis])
                             for axis in range(num_spatial)]
                # Reorder [[b0, a0], [b1, a1], ...] into (b0, b1, ..., a0, a1, ...).
                attr['pads'] = tuple([val for pair in zip(*pad_pairs) for val in pair])
            elif attr['auto_pad'] == 'VALID':
                attr['pads'] = 0
            elif attr['auto_pad'] == 'NOTSET':
                pass
            else:
                msg = 'Value {} in attribute "auto_pad" of operator {} is invalid.'
                raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad'], cls.name))
            attr.pop("auto_pad")

        if 'storage_order' in attr:
            attr['layout'] = onnx_storage_order2layout(attr['storage_order'],
                                                       dims=num_spatial)
        else:
            attr['layout'] = onnx_default_layout(dims=num_spatial)

        return AttrCvt(
            op_name=dimension_picker(cls.name),
            transforms={
                'kernel_shape': 'pool_size',
                'pads': ('padding', 0)
            },
            # storage_order has already been consumed into 'layout' above.
            # AttrCvt would otherwise forward it verbatim as a keyword to the
            # relay pool op.
            ignores=['dilations', 'storage_order'],
            custom_check=dimension_constraint())(inputs, attr, params)


275
class Absolute(Unary):
    """ Operator converter for Absolute.

    Lowered through :class:`Unary` to the relay ``abs`` op.
    """
    name = 'abs'
279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306


class Add(Elemwise):
    """ Operator converter for Add.

    Lowered through :class:`Elemwise` to the relay ``add`` op.
    """
    name = 'add'


class AveragePool(Pool):
    """ Operator converter for AveragePool.

    Lowered through :class:`Pool` to relay ``avg_pool1d/2d/3d``.
    """
    name = 'avg_pool'


class BatchNorm(OnnxOpConverter):
    """ Operator converter for BatchNorm.

    Lowered to relay ``batch_norm``; only element 0 of its result is kept.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # TODO(zhreshold): 'spatial' is not properly handled here.
        ignored = ['spatial', 'is_test', 'consumed_inputs', 'momentum']
        converter = AttrCvt(op_name='batch_norm', ignores=ignored)
        result = converter(inputs, attr, params)
        return result[0]


307 308 309 310 311 312 313 314 315
class InstanceNorm(OnnxOpConverter):
    """ Operator converter for InstanceNorm.

    Lowered to relay ``instance_norm``. All ONNX attributes are handed to
    ``AttrCvt`` without renaming.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return AttrCvt(op_name='instance_norm')(inputs, attr, params)


316 317 318 319 320 321
class Conv(OnnxOpConverter):
    """ Operator converter for Conv.

    Resolves ``auto_pad`` into explicit ``pads``, dispatches to the 1D/2D/3D
    relay conv op by kernel rank, and applies ``bias_add`` when a third (bias)
    input is present.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Use shape of input to determine convolution type.
        input_shape = infer_shape(inputs[0])
        num_spatial = len(input_shape) - 2

        if 'auto_pad' in attr:
            attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
            if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
                # 'strides' and 'dilations' are optional in ONNX (default 1 per axis).
                strides = attr.get('strides', [1] * num_spatial)
                dilations = attr.get('dilations', [1] * num_spatial)
                pad_pairs = []
                for axis in range(num_spatial):
                    dilated_kernel = (attr['kernel_shape'][axis] - 1) * dilations[axis] + 1
                    pad_pairs.append(
                        get_pad_pair(input_shape[2 + axis], dilated_kernel, strides[axis]))
                # Reorder [[b0, a0], [b1, a1], ...] into (b0, b1, ..., a0, a1, ...).
                attr['pads'] = tuple([val for pair in zip(*pad_pairs) for val in pair])
            elif attr['auto_pad'] == 'VALID':
                attr['pads'] = (0,) * num_spatial
            elif attr['auto_pad'] == 'NOTSET':
                pass
            else:
                msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.'
                raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad']))
            attr.pop('auto_pad')
        elif len(attr['kernel_shape']) == 2 and 'pads' in attr:
            # ONNX pads are ordered (x1_begin, x2_begin, ..., x1_end, x2_end, ...).
            # Only when every axis has begin == end can they collapse to one
            # value per axis. The old check compared neighbouring entries
            # (x1_begin vs x2_begin), which mis-collapsed pads like [a, a, b, b].
            padding = attr['pads']
            half = len(padding) // 2
            if tuple(padding[:half]) == tuple(padding[half:]):
                attr['pads'] = padding[:half]

        out = AttrCvt(
            op_name=dimension_picker('conv'),
            transforms={
                'kernel_shape': 'kernel_size',
                'dilations': ('dilation', 1),
                'pads': ('padding', 0),
                'group': ('groups', 1)
            },
            custom_check=dimension_constraint())(inputs[:2], attr, params)

        use_bias = len(inputs) == 3
        if use_bias:
            out = _op.nn.bias_add(out, inputs[2])
        return out


class ConvTranspose(OnnxOpConverter):
    """ Operator converter for ConvTranspose.

    Output channels are taken from the weight's second dimension. 2D
    ``auto_pad`` is resolved into explicit ``pads``, and ``bias_add`` is
    applied when a third (bias) input is present.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # get number of channels
        channels = infer_channels(inputs[1], True)
        attr['channels'] = channels
        # 'group' is optional in ONNX and defaults to 1.
        attr['groups'] = attr.pop('group', 1)

        # infer pads for auto_pad
        if 'auto_pad' in attr:
            attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
            if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
                input_shape = infer_shape(inputs[0])
                in_h, in_w = input_shape[2], input_shape[3]
                # 'strides' and 'dilations' are optional in ONNX (default 1).
                stride_h, stride_w = attr.get('strides', (1, 1))
                kernel_h, kernel_w = attr['kernel_shape']
                dilation_h, dilation_w = attr.get('dilations', (1, 1))
                dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
                dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
                pad_v = get_pad_pair(in_h, dilated_kernel_h, stride_h)
                pad_h = get_pad_pair(in_w, dilated_kernel_w, stride_w)
                attr['pads'] = (pad_v[0], pad_h[0], pad_v[1], pad_h[1])
            elif attr['auto_pad'] == 'VALID':
                attr['pads'] = (0, 0)
            elif attr['auto_pad'] == 'NOTSET':
                pass
            else:
                msg = 'Value {} in attribute "auto_pad" of operator ConvTranspose is invalid.'
                raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad']))
            attr.pop('auto_pad')

        out = AttrCvt(
            op_name=dimension_picker('conv', '_transpose'),
            transforms={
                'kernel_shape': 'kernel_size',
                'dilations': ('dilation', (0, 0)),
                # A 4-element pads is reduced to its leading two entries.
                'pads': ('padding', (0, 0), revert_caffe2_pad)
            },
            disables=['output_shape'],
            custom_check=dimension_constraint())(inputs[:2], attr, params)
        use_bias = len(inputs) == 3
        if use_bias:
            out = _op.nn.bias_add(out, inputs[2])
        return out


class Div(Elemwise):
    """ Operator converter for Divide.

    Lowered through :class:`Elemwise` to the relay ``divide`` op.
    """
    name = 'divide'


class Elu(OnnxOpConverter):
    """ Operator converter for Elu.

    Built as ``relu(x) - alpha * relu(1 - exp(x))``, which equals ``x`` for
    positive inputs and ``alpha * (exp(x) - 1)`` otherwise.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        data = inputs[0]
        alpha = float(attr.get('alpha', 1.0))
        negative_part = _expr.const(-alpha) * _op.nn.relu(_expr.const(1.) - _op.exp(data))
        return negative_part + _op.nn.relu(data)


class Gemm(OnnxOpConverter):
    """ Operator converter for Gemm.

    Computes ``Y = alpha * A' * B' + beta * C`` with ``relay.nn.dense``, where
    ``A'``/``B'`` are the inputs transposed per ``transA``/``transB``. ``C`` is
    added with ``bias_add``. Note that ``inputs[0]`` and ``inputs[1]`` are
    rewritten in place.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 3, "Gemm op take 3 inputs, {} given".format(
            len(inputs))
        alpha = float(attr.get('alpha', 1.0))
        beta = float(attr.get('beta', 1.0))
        trans_a = int(attr.get('transA', 0))
        trans_b = int(attr.get('transB', 0))
        # get number of channels
        units = infer_channels(inputs[1], not trans_b)
        if trans_a:
            inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
        # B is transposed unless the model already requested it via transB.
        if not trans_b:
            inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
        inputs[0] = _op.nn.batch_flatten(inputs[0])
        scaled_a = _expr.const(alpha) * inputs[0]
        product = _op.nn.dense(scaled_a, inputs[1], units=units)
        scaled_c = _expr.const(beta) * inputs[2]
        return _op.nn.bias_add(product, scaled_c)

461

462 463 464 465 466 467 468
class MatMul(OnnxOpConverter):
    """ Operator converter for MatMul.

    When the left input has rank > 2, both inputs are reshaped to 3D, B is
    broadcast to A's batch size and transposed, ``batch_matmul`` is applied,
    and the result is reshaped back. Otherwise ``dense`` against the
    transposed right input is used.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs))
        # Need to check input shape as batch matmul must be supported.
        a_shape = infer_shape(inputs[0])
        # When performing a batch matmul, we need to properly handle N-dim shapes.
        if len(a_shape) > 2:
            b_shape = infer_shape(inputs[1])
            # Convert a and b into 3 dimensional tensors.
            a = _op.reshape(inputs[0], [-1, a_shape[-2], a_shape[-1]])
            b = _op.reshape(inputs[1], [-1, b_shape[-2], b_shape[-1]])
            # Broadcast b to match batch size of a
            new_b_shape = list(infer_shape(b))
            new_a_shape = infer_shape(a)
            if new_a_shape[0] > new_b_shape[0]:
                new_b_shape[0] = new_a_shape[0]
                b = _op.broadcast_to(b, new_b_shape)
            # Transpose matrix dimensions of b.
            b = _op.transpose(b, [0, 2, 1])
            # Perform a batch matmul.
            output = _op.nn.batch_matmul(a, b)
            # Reshape output to original dimensions.
            return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
        # Otherwise a simple dense op will get the job done.
        input_1_t = _op.transpose(inputs[1], axes=(1, 0))
        return _op.nn.dense(inputs[0], input_1_t)

493

494
class MaxPool(Pool):
    """ Operator converter for MaxPool

    Lowered through :class:`Pool` to relay ``max_pool1d/2d/3d``.
    """
    name = 'max_pool'


class Mul(Elemwise):
    """ Operator converter for Multiply.

    Lowered through :class:`Elemwise` to the relay ``multiply`` op.
    """
    name = 'multiply'


class Pad(OnnxOpConverter):
    """ Operator converter for Pad.

    Opset 1 supplies the pad amounts in the ``paddings`` attribute; later
    opsets use ``pads``. Both are lowered by :meth:`_convert`.
    """

    @classmethod
    def _convert(cls, pads, inputs, attr, params):
        """Shared lowering for every Pad version.

        ``pads`` is the flat list ``[x1_begin, ..., xn_begin, x1_end, ..., xn_end]``.
        It is regrouped into ``pad_width`` as ``[(x1_begin, x1_end), ...]``.
        """
        dims = len(pads) // 2
        attr['pad_width'] = [(pads[i], pads[i + dims]) for i in range(dims)]

        pad_mode = attr.get('mode', b'constant').decode('utf-8')
        if pad_mode in ['constant', 'edge', 'reflect']:
            attr['pad_mode'] = pad_mode
            attr.pop('mode', None)
        else:
            raise tvm.error.OpAttributeInvalid(
                'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.')

        # op_name must be the string 'pad': AttrCvt treats a callable op_name
        # as a function that *computes* the op name from attrs, so passing
        # _op.nn.pad (as _impl_v1 previously did) invoked it incorrectly.
        return AttrCvt(
            'pad',
            transforms={
                'value': 'pad_value',
            },
            )(inputs, attr, params)

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return cls._convert(attr.pop('paddings'), inputs, attr, params)

    @classmethod
    def _impl_v2(cls, inputs, attr, params):
        return cls._convert(attr.pop('pads'), inputs, attr, params)
555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585


class ParametricSoftPlus(OnnxOpConverter):
    """ Operator converter for ParametricSoftPlus.

    Computes ``alpha * log(exp(beta * x) + 1)``.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        data = inputs[0]
        alpha = _expr.const(float(attr.get('alpha', 1.0)))
        beta = _expr.const(float(attr.get('beta', 1.0)))
        softplus = _op.log(_op.exp(beta * data) + _expr.const(1.))
        return softplus * alpha


class Prelu(OnnxOpConverter):
    """ Operator converter for Prelu.

    Expects ``(data, slope)`` and lowers to ``relay.nn.prelu``.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        num_inputs = len(inputs)
        assert num_inputs == 2, "Prelu need 2 inputs, {} given".format(num_inputs)
        data, slope = inputs[0], inputs[1]
        return _op.nn.prelu(data, slope)


class Reciprocal(OnnxOpConverter):
    """ Operator converter for Reciprocal.

    Computes ``1 / x``.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        one = _expr.const(1.0)
        return one / inputs[0]

586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602

class Flatten(OnnxOpConverter):
    """ Operator converter for Flatten.

    Per the ONNX spec the output is 2D: dimensions before ``axis`` are
    collapsed into the first output dimension and the rest into the second.
    A negative ``axis`` counts from the end.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        axis = attr.get('axis', 1)
        if axis == 1:
            return _op.nn.batch_flatten(inputs[0])
        ishape = infer_shape(inputs[0])
        if axis < 0:
            axis += len(ishape)
        if axis == 1:
            return _op.nn.batch_flatten(inputs[0])
        # The old reshape produced a 1D result for axis == 0 and kept dims
        # 0..axis-1 separate for axis >= 2. Compute the leading extent
        # explicitly instead (an empty product, axis == 0, is 1).
        outer = int(np.prod(ishape[:axis]))
        return _op.reshape(inputs[0], [outer, -1])


603 604 605 606 607 608
class Reshape(OnnxOpConverter):
    """ Operator converter for Reshape.

    ``_impl_v1`` reads the target shape from the ``shape`` attribute.
    ``_impl_v5`` reads it from the second input, either a known parameter or
    an expression evaluated with ``infer_value_simulated``.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _op.reshape(inputs[0], attr['shape'])

    @classmethod
    def _impl_v5(cls, inputs, attr, params):
        shape_name = get_name(inputs[1])
        if shape_name in params:
            # pop shape out of parameters since it wont be needed later.
            shape = tuple(params.pop(shape_name).asnumpy())
            out = _op.reshape(inputs[0], shape)
        else:
            data, shape = inputs
            static_shape = infer_value_simulated(shape, params)
            out = _op.reshape(data, newshape=tuple(
                static_shape.asnumpy().astype('int32')))
        return out

624 625 626 627 628 629 630 631 632

class DepthToSpace(OnnxOpConverter):
    """ Operator converter for DepthToSpace.

    Only an opset-11 implementation is defined here. ``mode`` defaults to
    ``'DCR'`` and is passed through to ``relay.nn.depth_to_space``.
    """

    @classmethod
    def _impl_v11(cls, inputs, attr, params):
        block_size = int(attr['blocksize'])
        mode = attr.get('mode', b'DCR').decode('utf-8')
        return _op.nn.depth_to_space(inputs[0], block_size, mode=mode)
635 636 637 638 639 640 641 642 643 644


class SpaceToDepth(OnnxOpConverter):
    """ Operator converter for SpaceToDepth.

    Lowers to ``relay.nn.space_to_depth`` with the integer ``blocksize``.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        block_size = int(attr['blocksize'])
        return _op.nn.space_to_depth(inputs[0], block_size)
646 647


648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708
class Concat(OnnxOpConverter):
    """ Operator converter for Concat.

    All inputs are packed into a single tuple argument for relay
    ``concatenate``. ONNX attributes are passed through ``AttrCvt`` unchanged.
    """

    @classmethod
    def _impl_v1(cls, inputs, args, params):
        packed = (inputs,)
        converter = AttrCvt(op_name='concatenate')
        return converter(packed, args)

class Scale(OnnxOpConverter):
    """ Operator converter for Scale.

    Multiplies the input by the constant ``scale`` (default 1.0).
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        factor = _expr.const(float(attr.get('scale', 1.0)))
        return inputs[0] * factor


class Selu(OnnxOpConverter):
    """ Operator converter for Selu.

    Computes ``gamma * (relu(x) - alpha * relu(1 - exp(x)))``.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Use the full-precision ONNX defaults rather than the truncated
        # 1.6732 / 1.0507, which shifted results of models relying on defaults.
        alpha = float(attr.get('alpha', 1.67326319217681884765625))
        gamma = float(attr.get('gamma', 1.05070102214813232421875))
        return _expr.const(gamma) * (_expr.const(-alpha) *
                                     _op.nn.relu(_expr.const(1.) - _op.exp(inputs[0])) +
                                     _op.nn.relu(inputs[0]))


class ScaledTanh(OnnxOpConverter):
    """ Operator converter for ScaledTanh.

    Computes ``alpha * tanh(beta * x)``.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = _expr.const(float(attr.get('alpha', 1.0)))
        beta = _expr.const(float(attr.get('beta', 1.0)))
        return _op.tanh(beta * inputs[0]) * alpha


class SoftPlus(OnnxOpConverter):
    """ Operator converter for SoftPlus.

    Computes ``log(exp(x) + 1)``.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        one = _expr.const(1.)
        return _op.log(_op.exp(inputs[0]) + one)


class Softsign(OnnxOpConverter):
    """ Operator converter for Softsign.

    Computes ``x / (1 + |x|)``, reusing the :class:`Absolute` converter.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        magnitude = Absolute.get_converter(1)(inputs, attr, params)
        return inputs[0] / (_expr.const(1.) + magnitude)


class Sub(Elemwise):
    """ Operator converter for Subtract.

    Lowered through :class:`Elemwise` to the relay ``subtract`` op.
    """
    name = 'subtract'


class Sum(OnnxOpConverter):
    """ Operator converter for Sum.

    Folds the variadic ONNX Sum into a left-to-right chain of relay ``add``.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Accumulate locally rather than writing partial sums back into
        # ``inputs``, which belongs to the caller.
        result = inputs[0]
        for in_index in range(1, len(inputs)):
            result = _op.add(result, inputs[in_index])
        return result


727 728 729 730 731 732 733 734 735 736 737
class Affine(OnnxOpConverter):
    """ Operator converter for Affine transformation.

    Computes ``alpha * x + beta`` (defaults 1.0 and 0.0).
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        scale = _expr.const(attr.get('alpha', 1.0))
        shift = _expr.const(attr.get('beta', 0.0))
        scaled = scale * inputs[0]
        return scaled + shift


738 739 740 741 742 743
class ThresholdedRelu(OnnxOpConverter):
    """ Operator converter for ThresholdedRelu.

    Keeps ``x`` where ``x > alpha`` and zeroes it elsewhere.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))
        alpha_tensor = _op.full_like(inputs[0], fill_value=_expr.const(alpha))
        # NOTE(review): the 0/1 mask is always cast to float32, whatever the
        # input dtype is.
        mask = _op.greater(inputs[0], alpha_tensor).astype("float32")
        return inputs[0] * mask


def _broadcast_constraint():

    def _broadcast_check(attrs):
        if attrs.get('axis', None):
            return False
        return True

    return _broadcast_check, "Specifying broadcast axis not allowed."


def _fully_connected(opset):

    def _impl(inputs, attr, params):
        # get number of channels
        channels = infer_channels(inputs[1], params)
        attr['units'] = channels
        return AttrCvt('dense', ignores=['axis', 'axis_w'])(inputs, attr)

    return _impl


class Upsample(OnnxOpConverter):
    """ Operator converter for Upsample (nearest mode).

    Supports ``mode`` 'nearest' (relay ``nearest_neighbor``, the default) and
    'linear' (relay ``bilinear``). Inputs must be 4D with unit scale on the
    first two axes; the output uses ``layout='NCHW'``.
    """

    @classmethod
    def _impl_v9(cls, inputs, attr, params):
        scales = attr.get('scales')
        if not scales:
            # Here we are going to higher OPSET version: 'scales' arrives as a
            # second input, which must be a known parameter.
            assert len(inputs) == 2, "Upsample op take 2 inputs, {} given".format(len(inputs))
            scales = params[get_name(inputs[1])].asnumpy()
            inputs = inputs[:1]
        assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0

        # ONNX 'mode' is optional and defaults to 'nearest'; previously a
        # missing mode was rejected as invalid.
        mode = attr.get('mode', b'nearest')
        if mode == b'nearest':
            method = "nearest_neighbor"
        elif mode == b'linear':
            method = "bilinear"
        else:
            raise tvm.error.OpAttributeInvalid(
                'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode))

        attr = {'scale_h': scales[-2], 'scale_w': scales[-1], 'method': method,
                'layout': 'NCHW', 'align_corners': True}
        return AttrCvt('upsampling')(inputs, attr)


class Shape(OnnxOpConverter):
    """ Operator converter for Shape.

    Emits relay ``shape_of`` with an ``int64`` result dtype.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _op.shape_of(inputs[0], "int64")
804 805 806 807 808 809 810 811 812 813 814 815 816

class Cast(OnnxOpConverter):
    """ Operator converter for Cast.

    ``_impl_v1`` passes ``to`` straight through as ``dtype``. ``_impl_v5``
    first maps the ONNX tensor-type enum in ``to`` to a numpy dtype string.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr)

    @classmethod
    def _impl_v5(cls, inputs, attr, params):
        # Keep the try body to the import alone, so a bad 'to' value surfaces
        # as its own error rather than being mixed into import handling.
        try:
            from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
        except ImportError as e:
            raise ImportError(
                "Unable to import onnx.mapping which is required {}".format(e)) from e
        attr['to'] = str(TENSOR_TYPE_TO_NP_TYPE[attr['to']])
        return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr)


class Unsqueeze(OnnxOpConverter):
    """ Operator converter for Unsqueeze.

    Each entry of ``axes`` names a position in the *output* tensor where a
    size-1 dimension is inserted.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        axes = list(attr['axes'])
        if any(axis < 0 for axis in axes):
            # Negative axes index the output, whose rank is the input rank
            # plus the number of inserted axes.
            out_rank = len(infer_shape(inputs[0])) + len(axes)
            axes = [axis + out_rank if axis < 0 else axis for axis in axes]
        # Insert in ascending order so every axis lands at its final output
        # position, however the model ordered them. Build the result locally
        # instead of overwriting the caller's inputs[0].
        out = inputs[0]
        for axis in sorted(axes):
            out = _op.expand_dims(out, axis=axis, num_newaxis=1)
        return out


class Split(OnnxOpConverter):
    """ Operator converter for Split.

    Explicit ``split`` sizes are converted to relay cut points. Without
    ``split``, ``in_shape[axis]`` sections are requested.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        splits = attr.get('split', False)
        if splits:
            # relay takes cut indices: the running total of every size but the last.
            attr['indices_or_sections'] = []
            index = 0
            for size in splits[:-1]:
                index += size
                attr['indices_or_sections'].append(index)
        # When splits isnt specified divide evenly over axis.
        else:
            # 'axis' is optional in ONNX and defaults to 0.
            # NOTE(review): this asks for one section per element along the axis,
            # which only matches ONNX when the node has that many outputs.
            in_shape = infer_shape(inputs[0])
            attr['indices_or_sections'] = in_shape[attr.get('axis', 0)]

        return AttrCvt(
            'split',
            ignores=['split'])(inputs, attr, params)


class Slice(OnnxOpConverter):
    """ Operator converter for Slice.
    """

    @classmethod
    def _common(cls, starts, ends, axes):
        """Expand a per-axis slice spec so that it covers axes 0..max(axes).

        ``starts[k]`` and ``ends[k]`` belong to ``axes[k]``; ``axes`` does not
        need to be sorted. Axes that are not listed get the full range
        (0 .. int32 max).

        Returns ``(new_starts, new_ends, new_axes)`` with
        ``new_axes == [0, 1, ..., max(axes)]``.

        NOTE(review): negative axes are not normalized here.
        """
        bounds = {int(axis): (start, end) for axis, start, end in zip(axes, starts, ends)}
        full_end = np.iinfo(np.int32).max
        new_axes = list(range(max(bounds) + 1))
        new_starts = [bounds[axis][0] if axis in bounds else 0 for axis in new_axes]
        new_ends = [bounds[axis][1] if axis in bounds else full_end for axis in new_axes]
        return new_starts, new_ends, new_axes

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Before opset 10, starts/ends/axes are node attributes.
        if isinstance(attr['starts'], int):
            attr['starts'] = (attr['starts'],)
            attr['ends'] = (attr['ends'],)

        if 'axes' in attr:
            if isinstance(attr['axes'], int):
                attr['axes'] = (attr['axes'],)
            # Always normalize. Even when axes is a permutation of 0..n-1,
            # starts/ends must be reordered into positional axis order for
            # strided_slice.
            new_starts, new_ends, new_axes = cls._common(
                attr['starts'], attr['ends'], attr['axes'])
            attr['axes'] = new_axes
            attr['starts'] = new_starts
            attr['ends'] = new_ends

        return AttrCvt('strided_slice',
                       transforms={'starts': 'begin',
                                   'ends': 'end'},
                       ignores=['axes'])(inputs, attr)

    @classmethod
    def _impl_v10(cls, inputs, attr, params):
        # From opset 10, starts/ends/axes are inputs; they are read from params.
        starts = params[get_name(inputs[1])].asnumpy()
        ends = params[get_name(inputs[2])].asnumpy()

        # Reorder/expand starts and ends into positional axis order when axes are given.
        if len(inputs) >= 4 and inputs[3] is not None:
            axes = params[get_name(inputs[3])].asnumpy()
            starts, ends, _ = cls._common(starts, ends, axes)
        # NOTE(review): the optional 'steps' input (inputs[4]) is not used.
        return _op.strided_slice(inputs[0], begin=starts, end=ends)


class Gather(OnnxOpConverter):
    """Operator converter for Gather, lowered to relay ``take``.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Only the axis (default 0) is forwarded; other attributes are dropped.
        extras = {'axis': attr.get('axis', 0)}
        converter = AttrCvt('take', extras=extras)
        return converter(inputs, {})

class Greater(OnnxOpConverter):
    """Operator converter for elementwise ``inputs[0] > inputs[1]``.
    """
    @classmethod
    def _impl_v7(cls, inputs, attr, params):
        lhs, rhs = inputs[0], inputs[1]
        return _op.greater(lhs, rhs)


class Less(OnnxOpConverter):
    """Operator converter for elementwise ``inputs[0] < inputs[1]``.
    """
    @classmethod
    def _impl_v7(cls, inputs, attr, params):
        lhs, rhs = inputs[0], inputs[1]
        return _op.less(lhs, rhs)


class LRN(OnnxOpConverter):
    """Operator converter for Local Response Normalization.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        """Convert ONNX LRN; only the NCHW layout (channel axis 1) is supported.

        Spec: https://github.com/onnx/onnx/blob/master/docs/Operators.md#LRN
        """
        lrn_attr = {
            'size': attr.get('size'),
            'axis': 1,
            'alpha': attr.get('alpha', 0.0001),
            'beta': attr.get('beta', 0.75),
            'bias': attr.get('bias', 1.0),
        }
        return AttrCvt('lrn')(inputs, lrn_attr)

class Maximum(OnnxOpConverter):
    """Operator converter for the variadic ONNX Max.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Needs at least two operands; they are folded pairwise with relay maximum.
        if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
            raise ValueError("Expect minimum 2 inputs")
        result = inputs[0]
        for position in range(1, len(inputs)):
            operand = inputs[position]
            result = AttrCvt('maximum')([result, operand], {})
        return result

class Minimum(OnnxOpConverter):
    """Operator converter for the variadic ONNX Min.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Needs at least two operands; they are folded pairwise with relay minimum.
        if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
            raise ValueError("Expect minimum 2 inputs")
        result = inputs[0]
        for position in range(1, len(inputs)):
            operand = inputs[position]
            result = AttrCvt('minimum')([result, operand], {})
        return result

class Mean(OnnxOpConverter):
    """Operator converter for the variadic ONNX Mean.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
            raise ValueError("Expect minimum 2 inputs")
        # Stack the operands along a new leading axis and average that axis
        # away, instead of summing pairwise (the original author's stated
        # reason is to avoid overflow).
        stacked = [_op.expand_dims(tensor, axis=0) for tensor in inputs]
        joined = _op.concatenate(stacked, axis=0)
        return _op.mean(joined, axis=0, keepdims=False)

class HardSigmoid(OnnxOpConverter):
    """Operator converter for HardSigmoid: ``clip(alpha * x + beta, 0, 1)``.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = _expr.const(attr.get('alpha', 0.2))
        beta = _expr.const(attr.get('beta', 0.5))
        scaled = (inputs[0] * alpha) + beta
        return AttrCvt('clip')([scaled], {'a_min': 0, 'a_max': 1})

class Reduce(OnnxOpConverter):
    """Shared converter for ONNX reductions; subclasses set ``name`` to the relay op.
    """
    name = ''
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # With no 'axes' attribute, reduce over every dimension of the input.
        if 'axes' not in attr:
            rank = len(infer_shape(inputs[0]))
            axis = list(range(rank))
        else:
            axis = attr['axes']
        reduce_attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)}
        return AttrCvt(cls.name)(inputs, reduce_attr)

class ReduceMax(Reduce):
    """Operator converter for ReduceMax; reduces with relay ``max``.
    """
    name = 'max'

class ReduceMin(Reduce):
    """Operator converter for ReduceMin; reduces with relay ``min``.
    """
    name = 'min'

class ReduceSum(Reduce):
    """Operator converter for ReduceSum; reduces with relay ``sum``.
    """
    name = 'sum'

class ReduceMean(Reduce):
    """Operator converter for ReduceMean; reduces with relay ``mean``.
    """
    name = 'mean'

class ReduceProd(Reduce):
    """Operator converter for ReduceProd; reduces with relay ``prod``.
    """
    name = 'prod'

class ArgMax(OnnxOpConverter):
    """Operator converter for ArgMax (defaults: axis 0, keepdims True).
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        argmax_attr = {
            'axis': attr.get('axis', 0),
            'keepdims': attr.get('keepdims', True),
        }
        return AttrCvt('argmax')(inputs, argmax_attr)

class ArgMin(OnnxOpConverter):
    """Operator converter for ArgMin (defaults: axis 0, keepdims True).
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        argmin_attr = {
            'axis': attr.get('axis', 0),
            'keepdims': attr.get('keepdims', True),
        }
        return AttrCvt('argmin')(inputs, argmin_attr)

class Softmax(OnnxOpConverter):
    """Operator converter for Softmax; a missing axis is filled in as 1.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        attr.setdefault('axis', 1)
        converter = AttrCvt('softmax', transforms={'axis': ('axis', 1)})
        return converter(inputs, attr, params)

class OneHot(OnnxOpConverter):
    """ Operator converter for OneHot.
    """
    @classmethod
    def _impl_v9(cls, inputs, attr, params):
        """Convert ONNX OneHot(indices, depth, values) to relay one_hot.

        ``values`` holds ``[off_value, on_value]``. ``depth`` must be
        evaluable at conversion time (via infer_value). It may be a scalar or
        a one-element 1-D tensor.
        """
        indices, depth, values = inputs
        # Split the ONNX [off, on] pair into two separate expressions.
        off_value = _op.take(values, _op.const(0))
        on_value = _op.take(values, _op.const(1))
        # The output dtype follows on_value.
        dtype = infer_type(on_value).checked_type.dtype
        # Flatten before indexing so both a 0-d scalar and a 1-element
        # vector work (plain [0] fails on a 0-d array).
        depth = int(infer_value(depth, params).asnumpy().reshape(-1)[0])
        # ONNX default axis is -1 when the attribute is absent.
        axis = int(attr.get('axis', -1))
        return _op.one_hot(indices,
                           on_value,
                           off_value,
                           depth,
                           axis,
                           dtype=dtype)


class ConstantOfShape(OnnxOpConverter):
    """Operator converter for ConstantOfShape.

    Fills a tensor whose shape is given by the input. The fill value comes
    from the 'value' attribute; without it, the fill is float32 zero.
    """
    @classmethod
    def _impl_v9(cls, inputs, attr, params):
        if 'value' not in attr:
            value = _expr.const(0)
            dtype = 'float32'
        else:
            np_value = get_numpy(attr.pop('value'))[0]
            value = _expr.const(np_value)
            dtype = np_value.dtype.name
        static_shape = infer_value_simulated(inputs[0], params)
        target_shape = tuple(static_shape.asnumpy().astype('int32'))
        return _op.full(value, shape=target_shape, dtype=dtype)


class Sign(OnnxOpConverter):
    """Operator converter for Sign (elementwise sign of the input).
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        data = inputs[0]
        return _op.sign(data)

class Equal(Elemwise):
    """Operator converter for Equal.

    ``name`` is presumably the relay op consumed by the Elemwise base class.
    """
    name = 'equal'

class Not(Elemwise):
    """Operator converter for Not (elementwise logical negation).
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        operand = inputs[0]
        return _op.logical_not(operand)


class And(Elemwise):
    """Operator converter for And (elementwise logical AND).
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        lhs, rhs = inputs[0], inputs[1]
        return _op.logical_and(lhs, rhs)


class Tile(Elemwise):
    """Operator converter for Tile.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Before opset 6 the repeat counts are the 'repeats' attribute,
        # which is popped from attr here.
        if 'repeats' in attr:
            reps = attr.pop('repeats')
            return _op.tile(inputs[0], reps)
        raise tvm.error.OpAttributeInvalid('Attribute "repeats" should be set '
                                           'for operator Tile.')

    @classmethod
    def _impl_v6(cls, inputs, attr, params):
        # From opset 6 the repeat counts are the second input, evaluated with
        # infer_value_simulated and cast to int32.
        repeats = infer_value_simulated(inputs[1], params).asnumpy()
        return _op.tile(inputs[0], tuple(repeats.astype('int32')))

class Erf(OnnxOpConverter):
    """Operator converter for Erf (elementwise error function).
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        data = inputs[0]
        return _op.erf(data)

class Where(OnnxOpConverter):
    """Operator converter for Where
    """
    @staticmethod
    def _broadcast_shape(shapes):
        """Return the numpy-style (right-aligned) broadcast of ``shapes`` as a tuple.

        For each output position, the first non-1 extent among the operands
        that reach it wins; if all are 1 (or none reach it), the extent is 1.
        Incompatible extents are not diagnosed here; broadcast_to will reject them.
        """
        rank = max(len(shape) for shape in shapes)
        result = []
        for offset in range(rank, 0, -1):
            dims = [shape[-offset] for shape in shapes if len(shape) >= offset]
            non_unit = [dim for dim in dims if dim != 1]
            result.append(non_unit[0] if non_unit else 1)
        return tuple(result)

    @classmethod
    def _impl_v9(cls, inputs, attr, params):
        """Convert ONNX Where(condition, x, y).

        condition, x and y may broadcast against one another. Each operand is
        explicitly broadcast to the common shape before calling relay where.
        The full broadcast shape is computed rather than copied from one
        operand, because no single operand need carry it (e.g. (1, 3) with
        (2, 1) -> (2, 3)).
        """
        operands = [inputs[0], inputs[1], inputs[2]]
        shapes = [tuple(infer_shape(operand)) for operand in operands]
        broadcast_shape = cls._broadcast_shape(shapes)
        operands = [
            operand if shape == broadcast_shape
            else _op.broadcast_to(operand, broadcast_shape)
            for operand, shape in zip(operands, shapes)
        ]
        return _op.where(operands[0], operands[1], operands[2])

class Or(Elemwise):
    """Operator converter for Or (elementwise logical OR).
    """
    @classmethod
    def _impl_v7(cls, inputs, attr, params):
        lhs, rhs = inputs[0], inputs[1]
        return _op.logical_or(lhs, rhs)

class Expand(OnnxOpConverter):
    """ Operator converter for Expand.
    """
    @classmethod
    def _impl_v8(cls, inputs, attr, params):
        """Convert ONNX Expand (multidirectional broadcast) to relay broadcast_to.

        The target shape (second input) comes from params when it is a
        stored initializer. Otherwise it is evaluated with
        infer_value_simulated.
        """
        in_shape = np.array(infer_shape(inputs[0])).astype('int32')
        shape_name = get_name(inputs[1])
        if shape_name in params:
            shape = params[shape_name].asnumpy().astype('int32')
        else:
            shape = infer_value_simulated(inputs[1], params).asnumpy().astype('int32')

        # 'op.broadcast_to' expects 'shape' to be the complete output shape.
        # ONNX Expand uses multidirectional broadcasting: 'shape' may have a
        # lower rank than the input, and some extents of 'shape' may be 1
        # where the input's extent is larger.
        # https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
        # So 'shape' is first expanded into the full output shape.
        def expand_shape(in_shape, shape):
            """Expand ``shape`` to the full broadcast output shape.

            Missing leading dimensions are taken from ``in_shape``, and extents
            of 1 are replaced by the corresponding input extent.
            """
            # Work innermost-first (right-aligned), as numpy broadcasting does.
            in_dims = [int(dim) for dim in list(in_shape)[::-1]]
            out_dims = [int(dim) for dim in list(shape)[::-1]]

            if len(in_dims) < len(out_dims):
                for i, in_dim in enumerate(in_dims):
                    if in_dim > out_dims[i]:
                        out_dims[i] = in_dim
            else:
                for i, in_dim in enumerate(in_dims):
                    if i >= len(out_dims):
                        # Previously np.append's result was discarded here, so
                        # lower-rank target shapes were never extended.
                        out_dims.append(in_dim)
                    elif out_dims[i] == 1:
                        out_dims[i] = in_dim

            return out_dims[::-1]

        shape = expand_shape(in_shape, shape)
        return _op.broadcast_to(inputs[0], shape=tuple(shape))

class LSTM(OnnxOpConverter):
    """ Operator converter for LSTM.
    """

    @classmethod
    def _activation_helper(cls, activation, alpha, beta):
        """Return a unary callable applying the ONNX activation named by bytes ``activation``.

        The activation is looked up in the opset-1 convert map. ``alpha`` and
        ``beta`` are forwarded as attributes only when not None.
        """
        convert_map = _get_convert_map(1)
        attrs = {}
        if alpha is not None:
            attrs['alpha'] = alpha
        if beta is not None:
            attrs['beta'] = beta
        return lambda x: convert_map[activation.decode("utf-8")]([x], attrs, {})

    @classmethod
    def _activation_needs_alpha(cls, activation):
        """Whether ``activation`` consumes an entry of ``activation_alpha``."""
        needs_alpha = [
            "Affine",
            "LeakyRelu",
            "ThresholdedRelu",
            "ScaledTanh",
            "HardSigmoid",
            "Elu",
        ]
        return activation.decode("utf-8") in needs_alpha

    @classmethod
    def _activation_needs_beta(cls, activation):
        """Whether ``activation`` consumes an entry of ``activation_beta``."""
        needs_beta = [
            "Affine",
            "ScaledTanh",
            "HardSigmoid",
        ]
        return activation.decode("utf-8") in needs_beta

    @classmethod
    def _impl_v7(cls, inputs, attr, params):
        """Unroll a unidirectional ONNX LSTM over the sequence axis (axis 0 of X).

        Returns a 3-tuple: all hidden states (direction axis re-inserted at
        axis 1), the final hidden state and the final cell state.

        Raises
        ------
        NotImplementedError
            For bidirectional LSTMs, or when 'activations' is not length 3.
        """
        # Unpack inputs, note that if optional and not provided then value will be None.
        X = inputs[0]
        W = inputs[1]
        R = inputs[2]
        B = inputs['B']
        # Sequence length currently unused as it can be inferred from shapes.
        # sequence_lens = inputs['sequence_lens']
        h_0 = inputs['initial_h']
        c_0 = inputs['initial_c']
        P = inputs['P']

        num_directions = infer_shape(W)[0]
        # Use checked_type (as OneHot does) instead of type_annotation, which
        # only exists when W is a Var.
        W_dtype = infer_type(W).checked_type.dtype

        if num_directions != 1:
            raise NotImplementedError("Bidirectional LSTMs not yet supported.")
        # Remove num_directions axis from weights.
        W = _op.squeeze(W, axis=[0])
        R = _op.squeeze(R, axis=[0])
        if B is not None:
            B = _op.squeeze(B, axis=[0])

        X_shape = infer_shape(X)
        hidden_size = infer_shape(R)[-1]
        batch_size = X_shape[1]

        # Initialize state if not provided.
        # Otherwise remove bidirectional axis.
        if h_0 is None:
            h_0 = _op.zeros((batch_size, hidden_size), W_dtype)
        else:
            h_0 = _op.squeeze(h_0, axis=[0])
        if c_0 is None:
            c_0 = _op.zeros((batch_size, hidden_size), W_dtype)
        else:
            c_0 = _op.squeeze(c_0, axis=[0])

        if P is not None:
            P = _op.squeeze(P, axis=[0])
            p_i, p_o, p_f = _op.split(P, 3)
        H_t = h_0
        C_t = c_0
        h_list = []

        if 'activations' in attr:
            activations = attr['activations']
            if len(activations) != 3:
                raise NotImplementedError("LSTM assumes 3 activation functions are provided")
            alpha_loc = 0
            alphas = attr.get('activation_alpha', [])
            if isinstance(alphas, float):
                alphas = [alphas]
            beta_loc = 0
            betas = attr.get('activation_beta', [])
            if isinstance(betas, float):
                betas = [betas]
            acts = []
            for i in range(3):
                alpha = None
                beta = None
                activation = activations[i]
                if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc:
                    alpha = alphas[alpha_loc]
                    alpha_loc += 1
                if cls._activation_needs_beta(activation) and len(betas) > beta_loc:
                    beta = betas[beta_loc]
                    beta_loc += 1
                acts.append(cls._activation_helper(activation, alpha, beta))
            f_act, g_act, h_act = acts
        else:
            f_act = _op.sigmoid
            g_act = _op.tanh
            h_act = _op.tanh

        # B holds the input and recurrent biases back to back. Their sum is
        # the same for every time step, so build it once, outside the loop.
        if B is not None:
            WB, RB = _op.split(B, 2)
            bias = WB + RB
        else:
            bias = None

        X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0)
        for step in X_steps:
            step = _op.squeeze(step, axis=[0])
            gates = _op.nn.dense(step, W) + _op.nn.dense(H_t, R)
            if bias is not None:
                gates += bias
            i, o, f, c = _op.split(gates, 4, axis=-1)
            if P is not None:
                i = f_act(i + p_i * C_t)
                f = f_act(f + p_f * C_t)
            else:
                i = f_act(i)
                f = f_act(f)
            c = g_act(c)
            C = f * C_t + i * c
            if P is not None:
                o = f_act(o + p_o * C)
            else:
                o = f_act(o)
            H = o * h_act(C)
            H_t = H
            C_t = C
            h_list.append(_op.expand_dims(H, axis=0))
        # Concatenate outputs and add back in direction axis.
        concatenated = _op.concatenate(h_list, 0)
        output = _op.expand_dims(concatenated, axis=1)
        H_t = _op.expand_dims(H_t, axis=0)
        C_t = _op.expand_dims(C_t, axis=0)

        return _expr.TupleWrapper(_expr.Tuple((output, H_t, C_t)), 3)


class Resize(OnnxOpConverter):
    """Operator converter for Resize (opset 11), lowered to relay image.resize in NCHW.
    """
    @classmethod
    def _impl_v11(cls, inputs, attr, params):
        # ONNX interpolation mode -> relay resize method.
        mode = attr.get('mode')
        supported_modes = {b'nearest': "nearest_neighbor", b'linear': "bilinear"}
        if mode not in supported_modes:
            raise tvm.error.OpAttributeInvalid(
                'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode))
        method = supported_modes[mode]

        # The output size comes from 'sizes' (inputs[3]) when given; otherwise
        # 'scales' (inputs[2]) is applied to the input shape.
        in_size = np.array(infer_shape(inputs[0]))
        scale = infer_value_simulated(inputs[2], params).asnumpy()
        if len(inputs) != 4:
            assert len(scale) != 0, "One of scale or size should be passed."
            size = (in_size * scale).astype(np.int32)
        else:
            assert len(scale) == 0, "One of scale or size should be passed, not both."
            size = infer_value_simulated(inputs[3], params).asnumpy().astype(np.int32)

        transform = attr.get('coordinate_transformation_mode')
        if transform in (b'pytorch_half_pixel', b'half_pixel'):
            coord_trans = "half_pixel"
        elif transform == b'align_corners':
            coord_trans = "align_corners"
        elif transform == b'asymmetric' or method == "nearest_neighbor":
            # Nearest-neighbour falls back to asymmetric for any other mode.
            coord_trans = "asymmetric"
        else:
            raise tvm.error.OpAttributeInvalid(
                'Unsupported coordinate_transformation_mode: {}'.format(transform))
        # ONNX assumes NCHW layout, so the spatial output extents are dims 2 and 3.
        spatial_size = (size[2], size[3])
        return _op.image.resize(inputs[0], spatial_size, "NCHW", method, coord_trans)


class NonZero(OnnxOpConverter):
    """Operator converter for NonZero, built on relay argwhere.
    """
    @classmethod
    def _impl_v9(cls, inputs, attr, params):
        if len(inputs) > 1:
            raise ValueError("Expect 1 input only")
        positions = AttrCvt(op_name='argwhere')(inputs, attr, params)
        # Swap the two axes of the argwhere result to match the ONNX output layout.
        return _op.transpose(positions, axes=(1, 0))


# compatible operators that do NOT require any conversion.
_identity_list = []


# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
def _get_convert_map(opset):
    """Return the ONNX op-name -> converter mapping for ``opset``.

    Converter classes are resolved with ``get_converter(opset)``. Ops that
    differ only in name or attribute spelling use ``Renamer`` / ``AttrCvt``
    directly. Commented-out names are ONNX ops without a converter here.
    """
    return {
        # defs/experimental
        'Identity': Renamer('copy'),
        'Affine': Affine.get_converter(opset),
        'ThresholdedRelu': ThresholdedRelu.get_converter(opset),
        'ScaledTanh': ScaledTanh.get_converter(opset),
        'ParametricSoftplus': ParametricSoftPlus.get_converter(opset),
        'ConstantOfShape': ConstantOfShape.get_converter(opset),
        # 'GivenTensorFill'
        'FC': AttrCvt('dense', ignores=['axis', 'axis_w']),
        'Scale': Scale.get_converter(opset),
        # 'GRUUnit'
        # 'ATen'
        # 'ImageScaler'
        # 'MeanVarianceNormalization'
        # 'Crop'
        # 'Embedding'
        'Upsample': Upsample.get_converter(opset),
        'SpatialBN': BatchNorm.get_converter(opset),

        # defs/generator
        # 'Constant' # Implemented
        # 'RandomUniform'
        # 'RandomNormal'
        # 'RandomUniformLike'
        # 'RandomNormalLike'

        # defs/logical

        # defs/math
        'Add': Add.get_converter(opset),
        'Sub': Sub.get_converter(opset),
        'Mul': Mul.get_converter(opset),
        'Div': Div.get_converter(opset),
        'Neg': Renamer('negative'),
        'Abs': Absolute.get_converter(opset),
        'Reciprocal': Reciprocal.get_converter(opset),
        'Floor': Renamer('floor'),
        'Ceil': Renamer('ceil'),
        'Sqrt': Renamer('sqrt'),
        'Relu': Renamer('relu'),
        'LeakyRelu': Renamer('leaky_relu'),
        'Selu': Selu.get_converter(opset),
        'Elu': Elu.get_converter(opset),
        'Exp': Renamer('exp'),
        'Greater': Greater.get_converter(opset),
        'Less': Less.get_converter(opset),
        'Log': Renamer('log'),
        'Tanh': Renamer('tanh'),
        'Pow': Renamer('power'),
        'PRelu': Prelu.get_converter(opset),
        'Sigmoid': Renamer('sigmoid'),
        'HardSigmoid': HardSigmoid.get_converter(opset),
        'Max': Maximum.get_converter(opset),
        'Min': Minimum.get_converter(opset),
        'Sum': Sum.get_converter(opset),
        'Mean': Mean.get_converter(opset),
        'Clip': AttrCvt('clip', transforms={'min': 'a_min', 'max': 'a_max'}),
        # softmax default axis is different in onnx
        'Softmax': Softmax.get_converter(opset),
        'LogSoftmax': AttrCvt('log_softmax', {'axis': ('axis', 1)}),
        'OneHot': OneHot.get_converter(opset),
        # 'Hardmax'
        'Softsign': Softsign.get_converter(opset),
        'SoftPlus': SoftPlus.get_converter(opset),
        'Gemm': Gemm.get_converter(opset),
        'MatMul': MatMul.get_converter(opset),

        # defs/nn
        'AveragePool': AveragePool.get_converter(opset),
        'MaxPool': MaxPool.get_converter(opset),
        'Conv': Conv.get_converter(opset),
        'ConvTranspose': ConvTranspose.get_converter(opset),
        'GlobalAveragePool': Renamer('global_avg_pool2d'),
        'GlobalMaxPool': Renamer('global_max_pool2d'),
        'BatchNormalization': BatchNorm.get_converter(opset),
        'InstanceNormalization': InstanceNorm.get_converter(opset),
        # 'LpNormalization'
        'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
        'Flatten': Flatten.get_converter(opset),
        'LRN': LRN.get_converter(opset),
        # Recurrent Layers
        'LSTM': LSTM.get_converter(opset),

        # defs/reduction
        'ReduceMax': ReduceMax.get_converter(opset),
        'ReduceMin': ReduceMin.get_converter(opset),
        'ReduceSum': ReduceSum.get_converter(opset),
        'ReduceMean': ReduceMean.get_converter(opset),
        'ReduceProd': ReduceProd.get_converter(opset),
        # 'ReduceLogSumExp'
        'ArgMax': ArgMax.get_converter(opset),
        'ArgMin': ArgMin.get_converter(opset),

        # defs/tensor
        'Cast': Cast.get_converter(opset),
        'Reshape': Reshape.get_converter(opset),
        'Expand': Expand.get_converter(opset),
        'Concat': Concat.get_converter(opset),
        'Split': Split.get_converter(opset),
        'Slice': Slice.get_converter(opset),
        'Transpose': AttrCvt('transpose', {'perm': 'axes'}),
        'DepthToSpace': DepthToSpace.get_converter(opset),
        'SpaceToDepth': SpaceToDepth.get_converter(opset),
        'Gather': Gather.get_converter(opset),
        'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
        'Unsqueeze': Unsqueeze.get_converter(opset),
        'Pad': Pad.get_converter(opset),
        'Shape': Shape.get_converter(opset),
        'Sign': Sign.get_converter(opset),
        'Equal': Equal.get_converter(opset),
        'Not': Not.get_converter(opset),
        'And': And.get_converter(opset),
        'Tile': Tile.get_converter(opset),
        'Erf': Erf.get_converter(opset),
        'Where': Where.get_converter(opset),
        'Or': Or.get_converter(opset),
        'Resize': Resize.get_converter(opset),
        'NonZero': NonZero.get_converter(opset),
    }


class GraphProto(object):
    """A helper class for handling Relay expression copying from pb2.GraphProto.
    Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto

    Parameters
    ----------
    shape : dict of str to tuple, optional
        The input shape to the graph

    dtype : str or dict of str to str
        The input types to the graph. Only a dict is consulted (per input
        name); when a plain string is given, each input's dtype is taken from
        the model's own type information instead.
    """

    def __init__(self, shape, dtype):
        # ONNX tensor name -> Relay expression built so far.
        self._nodes = {}
        # Initializer / Constant name -> tvm.nd.NDArray value.
        self._params = {}
        # Optional ONNX-name -> node-name aliases consulted when wiring inputs.
        self._renames = {}
        self._num_input = 0
        self._num_param = 0
        self._shape = shape if shape else {}
        self._dtype = dtype

    def from_onnx(self, graph, opset):
        """Construct Relay expression from ONNX graph.

        Onnx graph is a python protobuf object.
        The companion parameters will be handled automatically.
        Graph inputs that are backed by an initializer are treated as
        parameters; every other graph input becomes a free Relay variable
        and must have an entry in the ``shape`` dict given to the
        constructor. ONNX tensor names are used as-is for the Relay vars.

        Parameters
        ----------
        graph : onnx protobuf object
            The loaded onnx graph

        opset : opset version

        Returns
        -------
        mod : tvm.IRModule
            The returned relay module

        params : dict
            A dict of name: tvm.nd.array pairs, used as pretrained weights

        Raises
        ------
        ValueError
            If an initializer has an empty name, or a non-parameter input has
            no entry in ``shape``.
        tvm.error.OpNotImplemented
            If the graph uses operators that have no converter.
        """
        # Parse initializers into parameters and bind each to a Relay var.
        for init_tensor in graph.initializer:
            if not init_tensor.name.strip():
                raise ValueError("Tensor's name is required.")
            array = self._parse_array(init_tensor)
            self._params[init_tensor.name] = array
            self._nodes[init_tensor.name] = new_var(init_tensor.name,
                                                    shape=array.shape,
                                                    dtype=array.dtype)

        for i in graph.input:
            # from onnx v0.2, GraphProto.input has type ValueInfoProto,
            #  and the name is 'i.name'
            i_name = self._parse_value_proto(i)
            if i_name in self._params:
                # Listed as a graph input but backed by an initializer: it is
                # a param whose Relay var was already created above.
                self._num_param += 1
                continue
            self._num_input += 1
            if i_name not in self._shape:
                raise ValueError("Must provide an input shape for `{0}`.".format(i_name))
            tshape = self._shape[i_name]
            d_type = self._parse_dtype(i, 'float32')
            if isinstance(self._dtype, dict):
                dtype = self._dtype.get(i_name, d_type)
            else:
                dtype = d_type
            self._nodes[i_name] = new_var(i_name, shape=tshape, dtype=dtype)

        # Fail early, listing every unsupported operator at once.
        convert_map = _get_convert_map(opset)
        unsupported_ops = {
            node.op_type for node in graph.node
            if node.op_type != 'Constant'
            and node.op_type not in convert_map
            and node.op_type not in _identity_list
        }
        if unsupported_ops:
            msg = 'The following operators are not supported for frontend ONNX: '
            msg += ', '.join(sorted(unsupported_ops))
            raise tvm.error.OpNotImplemented(msg)

        # construct nodes, nodes are stored as directed acyclic graph
        for node in graph.node:
            op_name = node.op_type
            attr = self._parse_attr(node.attribute)
            # Create and populate onnx input object; empty names mark
            # omitted optional inputs.
            inputs = onnx_input()
            for i in node.input:
                if i:
                    inputs[i] = self._nodes[self._renames.get(i, i)]

            if op_name == "Constant":
                t_proto = attr["value"]
                self._num_param += 1
                # NOTE(review): no dtype normalization (e.g. int32 for scalar
                # integers) is performed here; the array keeps whatever dtype
                # _parse_array produces.
                array = self._parse_array(t_proto)
                self._params[node.output[0]] = array
                self._nodes[node.output[0]] = new_var(
                    node.output[0],
                    shape=list(t_proto.dims),
                    dtype=array.dtype)
                continue

            attr['tvm_custom'] = {'name': self._parse_value_proto(node)}
            op = self._convert_operator(op_name, inputs, attr, opset)
            node_output = self._fix_outputs(op_name, node.output)
            outputs_num = len(op) if isinstance(op, _expr.TupleWrapper) else 1
            assert len(node_output) == outputs_num, (
                "Number of output mismatch {} vs {} in {}.".format(
                    len(node_output), outputs_num, op_name))
            if outputs_num == 1:
                self._nodes[node_output[0]] = op
            else:
                for idx, out_name in enumerate(node_output):
                    self._nodes[out_name] = op[idx]

        # now return the outputs
        outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
        outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
        func = _function.Function(analysis.free_vars(outputs), outputs)
        return IRModule.from_expr(func), self._params

    def _parse_value_proto(self, value_proto):
        """Parse ValueProto or raw str: return its ``name``, or itself if it has none."""
        try:
            name = value_proto.name
        except AttributeError:
            name = value_proto
        return name

    def _parse_dtype(self, value_proto, dtype):
        """Parse the numpy dtype name of a ValueInfoProto.

        Falls back to ``dtype`` when ``value_proto`` carries no tensor type,
        or when its element type is not a key of onnx's
        ``TENSOR_TYPE_TO_NP_TYPE`` mapping.
        """
        try:
            from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
            return TENSOR_TYPE_TO_NP_TYPE[value_proto.type.tensor_type.elem_type].name
        except (AttributeError, KeyError):
            return dtype

    def _parse_array(self, tensor_proto):
        """Convert an ONNX TensorProto to a tvm.nd.NDArray reshaped to its dims."""
        np_array = get_numpy(tensor_proto).reshape(tuple(tensor_proto.dims))
        return _nd.array(np_array)

    def _parse_attr(self, attr_proto):
        """Convert a list of AttributeProto to a dict, with names as keys.

        Scalar fields (f/i/s/t) are stored as-is; repeated fields are stored as
        tuples. Graph-valued attributes are rejected.

        Raises
        ------
        NotImplementedError
            For graph (``g``) or graphs attributes.
        ValueError
            If an attribute carries no value that can be parsed.
        """
        attrs = {}
        for a in attr_proto:
            for f in ('f', 'i', 's'):
                if a.HasField(f):
                    attrs[a.name] = getattr(a, f)
            for f in ('floats', 'ints', 'strings'):
                if list(getattr(a, f)):
                    assert a.name not in attrs, "Only one type of attr is allowed"
                    attrs[a.name] = tuple(getattr(a, f))
            if a.HasField('t'):
                attrs[a.name] = getattr(a, 't')
            if list(getattr(a, 'tensors')):
                assert a.name not in attrs, "Only one type of attr is allowed"
                attrs[a.name] = tuple(getattr(a, 'tensors'))
            if a.HasField('g'):
                raise NotImplementedError(
                    "Field {} is not supported in relay.".format('g'))
            if list(getattr(a, 'graphs')):
                raise NotImplementedError(
                    "Field {} is not supported in relay.".format('graphs'))
            if a.name not in attrs:
                raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
        return attrs

    def _convert_operator(self,
                          op_name,
                          inputs,
                          attrs,
                          opset):
        """Convert ONNX operator into a Relay operator.

        The converter must specify conversions explicitly for incompatible name, and
        apply handlers to operator attributes.

        Parameters
        ----------
        op_name : str
            Operator name, such as Convolution, FullyConnected

        inputs : onnx_input
            The operator's input Relay expressions, indexable by position.

        attrs : dict
            Dict of operator attributes

        opset : int
            Opset version

        Returns
        -------
        sym : tvm.relay.expr.Expr or TupleWrapper
            Converted relay expression

        Raises
        ------
        NotImplementedError
            If no converter exists for ``op_name``.
        """
        # Built per call on purpose: converter objects may keep state.
        convert_map = _get_convert_map(opset)
        if op_name in _identity_list:
            sym = get_relay_op(op_name)(*inputs, **attrs)
        elif op_name in convert_map:
            sym = convert_map[op_name](inputs, attrs, self._params)
        else:
            raise NotImplementedError(
                "Operator {} not implemented.".format(op_name))
        return sym

    def _fix_outputs(self, op_name, outputs):
        """A hack to handle dropout or similar operator that have more than one out
        in ONNX: drop Dropout's trailing mask output when present.
        """
        if op_name == 'Dropout':
            if len(outputs) == 1:
                return outputs
            # TODO(zhreshold): support dropout mask?
            outputs = outputs[:-1]
        return outputs

def _get_default_opset(model):
    """Return the opset version declared for the default ONNX operator domain.

    Prefers the ``opset_import`` entry whose domain is ``''`` or ``'ai.onnx'``,
    since a model may also import other domains (e.g. ``ai.onnx.ml``) whose
    versions are unrelated. If no entry matches, the first entry's version is
    used; if the model declares no opset (or lacks the field), 1 is returned.
    """
    try:
        opset_imports = list(model.opset_import)
        for opset_id in opset_imports:
            if getattr(opset_id, 'domain', '') in ('', 'ai.onnx'):
                return opset_id.version
        return opset_imports[0].version if opset_imports else 1
    except AttributeError:
        return 1


def from_onnx(model,
              shape=None,
              dtype="float32",
              opset=None):
    """Convert a ONNX model into an equivalent Relay Function.

    ONNX graphs are represented as Python Protobuf objects.
    The companion parameters will be handled automatically.
    Initializers are returned in ``params``; every other graph input becomes
    a free Relay variable named after its ONNX tensor name and must have an
    entry in ``shape``.

    Parameters
    ----------
    model : protobuf object
        ONNX ModelProto after ONNX v1.1.0

    shape : dict of str to tuple, optional
        The input shape to the graph

    dtype : str or dict of str to str
        The input types to the graph. A dict overrides the dtype per input
        name; a plain string is not applied, and each input keeps the dtype
        recorded in the model.

    opset : int, optional
        Override to autodetected opset.
        This can be helpful for some testing.
        When omitted, the version of the default ONNX domain is used.

    Returns
    -------
    mod : tvm.IRModule
        The relay module for compilation

    params : dict of str to tvm.nd.NDArray
        The parameter dict to be used by relay
    """
    try:
        import onnx
        if hasattr(onnx.checker, 'check_model'):
            # try use onnx's own model checker before converting any model
            try:
                onnx.checker.check_model(model)
            except onnx.onnx_cpp2py_export.checker.ValidationError as e:
                import warnings
                # the checker is a bit violent about errors, so simply print warnings here
                warnings.warn(str(e))
    except ImportError:
        pass
    g = GraphProto(shape, dtype)
    graph = model.graph
    if opset is None:
        opset = _get_default_opset(model)
    mod, params = g.from_onnx(graph, opset)
    return mod, params