onnx.py 53.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18 19 20 21
# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines
"""ONNX: Open Neural Network Exchange frontend for Relay."""
from __future__ import absolute_import as _abs

import numpy as np
Zhi committed
22
import tvm
23
from ... import nd as _nd
Zhi committed
24
from .. import analysis
25
from .. import expr as _expr
26
from .. import module as _module
27 28
from .. import op as _op
from .common import AttrCvt, Renamer
29 30
from .common import get_relay_op, new_var, infer_shape, infer_channels
from .common import infer_type, infer_value, infer_value_simulated, get_name
31 32 33

__all__ = ['from_onnx']

34 35 36 37 38 39 40 41 42 43 44

def get_numpy(tensor_proto):
    """Convert an ONNX TensorProto into a numpy array.

    The onnx package is imported lazily; if it is missing, an ImportError
    carrying the original import error text is raised.
    """
    try:
        from onnx.numpy_helper import to_array as _to_array
    except ImportError as err:
        raise ImportError(
            "Unable to import onnx which is required {}".format(err))
    return _to_array(tensor_proto)


45
def dimension_picker(prefix, surfix=''):
    """Build an op-name chooser keyed on kernel rank.

    The returned callable reads ``attr['kernel_shape']`` and yields
    ``prefix + '1d' + surfix`` or ``prefix + '2d' + surfix``. Any other
    kernel rank raises ``tvm.error.OpAttributeInvalid``.
    """
    rank_suffix = {1: '1d', 2: '2d'}

    def _impl(attr):
        rank = len(attr['kernel_shape'])
        if rank in rank_suffix:
            return prefix + rank_suffix[rank] + surfix
        msg = 'Only 1D and 2D kernels are supported for operator {}.'
        raise tvm.error.OpAttributeInvalid(msg.format(prefix + '1d/2d'))

    return _impl

59

60 61 62 63 64 65 66
def revert_caffe2_pad(pads):
    """Undo Caffe2's doubled padding convention.

    A 4-element ``pads`` is cut down to its first two entries. A 2-element
    ``pads`` is returned as-is. Any other length raises
    ``tvm.error.OpAttributeInvalid``.
    """
    count = len(pads)
    if count not in (2, 4):
        raise tvm.error.OpAttributeInvalid(
            'Number of pads must be either 2 or 4.')
    return pads[:2] if count == 4 else pads

71

72 73 74 75 76 77 78 79 80 81 82
def get_pad_pair(input1d, kernel1d, stride1d):
    """Compute ``[pad_before, pad_after]`` for 'SAME' padding on one axis.

    The total pad is ``kernel - stride`` when the input length is a multiple
    of the stride, otherwise ``kernel - (input % stride)``, clamped at 0.
    An odd total puts the extra element on the trailing side.
    """
    remainder = input1d % stride1d
    covered = stride1d if remainder == 0 else remainder
    total = max(kernel1d - covered, 0)
    before = total // 2
    return [before, total - before]


83 84 85 86 87 88 89 90 91 92
def onnx_default_layout(dims):
    """Return the default channel-first layout for ``dims`` spatial dimensions.

    Parameters
    ----------
    dims : int
        Number of spatial dimensions (1 or 2).

    Returns
    -------
    str
        'NCW' for 1 spatial dim, 'NCHW' for 2.

    Raises
    ------
    tvm.error.OpAttributeInvalid
        For any other number of spatial dimensions.
    """
    if dims == 1:
        return 'NCW'
    if dims == 2:
        return 'NCHW'

    # The message has no placeholder; the old `.format(op_name)` referenced an
    # undefined name and turned this error into a NameError.
    msg = "Only 1d and 2d layouts are currently supported"
    raise tvm.error.OpAttributeInvalid(msg)


93
def onnx_storage_order2layout(storage_order, dims=2):
    """Convert an ONNX ``storage_order`` value to a TVM layout string.

    Parameters
    ----------
    storage_order : int
        0 selects the channel-first layout, 1 the channel-last layout.
    dims : int
        Number of spatial dimensions (1 or 2).

    Returns
    -------
    str
        One of 'NCW', 'NWC', 'NCHW', 'NHWC'.

    Raises
    ------
    tvm.error.OpAttributeInvalid
        If ``storage_order`` is not 0/1 or ``dims`` is not 1/2.
    """
    if storage_order not in (0, 1):
        raise tvm.error.OpAttributeInvalid('Mode of storage_order must be either 0 or 1')

    if dims == 1:
        return 'NCW' if storage_order == 0 else 'NWC'
    if dims == 2:
        return 'NCHW' if storage_order == 0 else 'NHWC'

    # No placeholder in the message; the former `.format(op_name)` referenced
    # an undefined name and raised NameError instead of this error.
    msg = "Only 1d and 2d layouts are currently supported"
    raise tvm.error.OpAttributeInvalid(msg)
105 106


107 108
def dimension_constraint():
    """Return a ``(check, message)`` pair usable as AttrCvt's ``custom_check``.

    The check accepts attribute dicts whose ``kernel_shape`` has exactly
    one or two entries.
    """
    def _dim_check(attrs):
        return len(attrs['kernel_shape']) in (1, 2)

    return _dim_check, "Only 1d and 2d kernel supported."
114

115

116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
class OnnxOpConverter(object):
    """ Base class for ONNX op converters.

    Subclasses provide classmethods named ``_impl_v<N>``; ``get_converter``
    selects one of them for a model's opset.
    """

    @classmethod
    def get_converter(cls, opset):
        """ Pick the ``_impl_v<N>`` method to use for ``opset``.

        Parameters
        ----------
        opset: int
            opset from model.

        Returns
        -------
        The ``_impl_v<N>`` classmethod whose N is the largest implemented
        version that is <= opset.

        Raises
        ------
        NotImplementedError
            If the selected version has no matching method (e.g. the class
            defines no ``_impl_v`` methods at all).
        """
        marker = '_impl_v'
        implemented = [int(member.replace(marker, ''))
                       for member in dir(cls) if marker in member]
        candidates = sorted(implemented + [opset])
        # Index of the last occurrence of `opset` in the sorted candidates.
        last_pos = len(candidates) - 1 - candidates[::-1].index(opset)
        # NOTE(review): when opset is below every implemented version,
        # last_pos is 0 and index -1 wraps around to the newest
        # implementation instead of failing.
        version = candidates[last_pos - 1]
        impl_name = '_impl_v{}'.format(version)
        if not hasattr(cls, impl_name):
            raise NotImplementedError(
                'opset version {} of {} not implemented'.format(
                    version, cls.__name__))
        return getattr(cls, impl_name)


147 148 149 150 151 152 153 154 155 156 157 158 159
class Unary(OnnxOpConverter):
    """ Base converter for single-input math ops.

    Subclasses set ``name``; the relay op of that name is looked up with
    ``get_relay_op`` and applied to the single input.
    """
    name = ''

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        count = len(inputs)
        assert count == 1, "Unary math op {} takes 1 input, {} given".format(
            cls.name, count)
        relay_op = get_relay_op(cls.name)
        return relay_op(*inputs)


160 161 162 163 164 165 166
class Elemwise(OnnxOpConverter):
    """ Base converter for two-input element-wise math ops.

    Subclasses set ``name`` to the relay op name. When the legacy
    ``broadcast`` attribute is set and the left operand's text mentions a
    conv op, the right operand gets two trailing axes via ``expand_dims``
    at ``axis`` (default 0).
    """
    name = ''

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        count = len(inputs)
        assert count == 2, "Math op {} take 2 inputs, {} given".format(
            cls.name, count)
        if attr.get('broadcast', 0):
            lhs_text = str(inputs[0])
            if any(token in lhs_text for token in ("conv2d", "conv2d_transpose")):
                # TODO(zhreshold): remove hard coded infershape
                inputs[1] = _op.expand_dims(inputs[1],
                                            axis=int(attr.get('axis', 0)),
                                            num_newaxis=2)
        relay_op = get_relay_op(cls.name)
        return relay_op(*inputs)

177

178 179 180 181 182 183 184
class Pool(OnnxOpConverter):
    """ A helper class for pool op converters.

    Subclasses set ``name`` to the relay op prefix (e.g. 'max_pool');
    ``dimension_picker`` appends '1d' or '2d' from ``kernel_shape``.
    """
    name = ''

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        input_shape = infer_shape(inputs[0])
        # Spatial rank: axes 0 and 1 are skipped when indexing input_shape below.
        ndim = len(input_shape) - 2
        if 'auto_pad' in attr:
            attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
            if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
                # 'strides' is optional in ONNX and defaults to 1 per axis.
                strides = attr.get('strides', [1] * ndim)
                pad_pairs = []
                for axis in range(ndim):
                    pad = get_pad_pair(input_shape[2 + axis],
                                       attr['kernel_shape'][axis],
                                       strides[axis])
                    if attr['auto_pad'] == 'SAME_LOWER':
                        # SAME_LOWER puts the odd extra pad at the beginning.
                        pad = pad[::-1]
                    pad_pairs.append(pad)
                # Flatten to (before_0, ..., before_n, after_0, ..., after_n).
                attr['pads'] = tuple(val for side in zip(*pad_pairs) for val in side)
            elif attr['auto_pad'] == 'VALID':
                attr['pads'] = 0
            elif attr['auto_pad'] == 'NOTSET':
                pass
            else:
                msg = 'Value {} in attribute "auto_pad" of operator {} is invalid.'
                raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad'], cls.name))
            attr.pop("auto_pad")

        if 'storage_order' in attr:
            attr['layout'] = onnx_storage_order2layout(attr['storage_order'], dims=ndim)
        else:
            attr['layout'] = onnx_default_layout(dims=ndim)

        return AttrCvt(
            op_name=dimension_picker(cls.name),
            transforms={
                'kernel_shape': 'pool_size',
                'pads': ('padding', 0)
            },
            # storage_order was consumed above to choose 'layout'; it must not
            # be forwarded to the relay pooling op as an attribute.
            ignores=['dilations', 'storage_order'],
            custom_check=dimension_constraint())(inputs, attr, params)


223
class Absolute(Unary):
    """ Converts ONNX Abs to the relay ``abs`` op.
    """
    name = 'abs'
227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254


class Add(Elemwise):
    """ Converts ONNX Add to the relay ``add`` op.
    """
    name = 'add'


class AveragePool(Pool):
    """ Converts ONNX AveragePool to relay ``avg_pool1d``/``avg_pool2d``.
    """
    name = 'avg_pool'


class BatchNorm(OnnxOpConverter):
    """ Converts ONNX BatchNormalization to relay ``batch_norm``.

    Training-only attributes are dropped, and only the first element of the
    converted result is returned.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # TODO(zhreshold): 'spatial' is not properly handled here.
        dropped = ['spatial', 'is_test', 'consumed_inputs', 'momentum']
        converter = AttrCvt(op_name='batch_norm', ignores=dropped)
        return converter(inputs, attr, params)[0]


255 256 257 258 259 260 261 262 263
class InstanceNorm(OnnxOpConverter):
    """ Operator converter for InstanceNormalization.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Attributes are passed through AttrCvt to relay instance_norm unchanged.
        return AttrCvt(op_name='instance_norm')(inputs, attr, params)


264 265 266 267 268 269
class Conv(OnnxOpConverter):
    """ Operator converter for Conv.

    Resolves ``auto_pad`` into explicit pads, converts via AttrCvt to relay
    conv1d/conv2d and adds the optional third input as a bias.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Use shape of input to determine convolution type.
        input_shape = infer_shape(inputs[0])
        # Spatial rank: axes 0 and 1 are skipped when indexing input_shape below.
        ndim = len(input_shape) - 2

        if 'auto_pad' in attr:
            attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
            if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
                # strides/dilations are optional in ONNX and default to 1,
                # matching the ('dilation', 1) default used below.
                strides = attr.get('strides', [1] * ndim)
                dilations = attr.get('dilations', [1] * ndim)
                pad_pairs = []
                for axis in range(ndim):
                    kernel = attr['kernel_shape'][axis]
                    dilated_kernel = (kernel - 1) * dilations[axis] + 1
                    pad = get_pad_pair(input_shape[2 + axis], dilated_kernel, strides[axis])
                    if attr['auto_pad'] == 'SAME_LOWER':
                        # SAME_LOWER puts the odd extra pad at the beginning.
                        pad = pad[::-1]
                    pad_pairs.append(pad)
                # Flatten to (before_0, ..., before_n, after_0, ..., after_n).
                attr['pads'] = tuple(val for side in zip(*pad_pairs) for val in side)
            elif attr['auto_pad'] == 'VALID':
                attr['pads'] = (0,) * ndim
            elif attr['auto_pad'] == 'NOTSET':
                pass
            else:
                msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.'
                raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad']))
            attr.pop('auto_pad')

        out = AttrCvt(
            op_name=dimension_picker('conv'),
            transforms={
                'kernel_shape': 'kernel_size',
                'dilations': ('dilation', 1),
                'pads': ('padding', 0),
                'group': ('groups', 1)
            },
            custom_check=dimension_constraint())(inputs[:2], attr, params)

        # Optional third input is added with bias_add.
        if len(inputs) == 3:
            out = _op.nn.bias_add(out, inputs[2])
        return out


class ConvTranspose(OnnxOpConverter):
    """ Operator converter for ConvTranspose.

    Only 2-D spatial inputs are handled by the SAME auto_pad path (it
    unpacks H and W explicitly).
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # get number of channels
        channels = infer_channels(inputs[1], True)
        attr['channels'] = channels
        # ONNX 'group' is optional and defaults to 1.
        attr['groups'] = attr.pop('group', 1)

        # infer pads for auto_pad
        if 'auto_pad' in attr:
            attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
            if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
                input_shape = infer_shape(inputs[0])
                in_h, in_w = input_shape[2], input_shape[3]
                # strides/dilations are optional and default to 1.
                stride_h, stride_w = attr.get('strides', (1, 1))
                kernel_h, kernel_w = attr['kernel_shape']
                dilation_h, dilation_w = attr.get('dilations', (1, 1))
                dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
                dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
                pad_y = get_pad_pair(in_h, dilated_kernel_h, stride_h)
                pad_x = get_pad_pair(in_w, dilated_kernel_w, stride_w)
                attr['pads'] = (pad_y[0], pad_x[0], pad_y[1], pad_x[1])
            elif attr['auto_pad'] == 'VALID':
                attr['pads'] = (0, 0)
            elif attr['auto_pad'] == 'NOTSET':
                pass
            else:
                msg = 'Value {} in attribute "auto_pad" of operator ConvTranspose is invalid.'
                raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad']))
            attr.pop('auto_pad')

        out = AttrCvt(
            op_name=dimension_picker('conv', '_transpose'),
            transforms={
                'kernel_shape': 'kernel_size',
                # A missing 'dilations' means no dilation. A default of 0 would
                # shrink the dilated kernel extent (k - 1) * d + 1 to 1.
                'dilations': ('dilation', (1, 1)),
                'pads': ('padding', (0, 0), revert_caffe2_pad)
            },
            disables=['output_shape'],
            custom_check=dimension_constraint())(inputs[:2], attr, params)
        # Optional third input is added with bias_add.
        if len(inputs) == 3:
            out = _op.nn.bias_add(out, inputs[2])
        return out


class Div(Elemwise):
    """ Converts ONNX Div to the relay ``divide`` op.
    """
    name = 'divide'


class Elu(OnnxOpConverter):
    """ Converts ONNX Elu using relu/exp primitives.

    Computes ``-alpha * relu(1 - exp(x)) + relu(x)``, which equals
    ``alpha * (exp(x) - 1)`` for negative x and ``x`` otherwise.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        x = inputs[0]
        alpha = float(attr.get('alpha', 1.0))
        negative_part = _op.nn.relu(_expr.const(1.) - _op.exp(x))
        positive_part = _op.nn.relu(x)
        return _expr.const(-alpha) * negative_part + positive_part


class Gemm(OnnxOpConverter):
    """ Converts ONNX Gemm (Y = alpha * A * B + beta * C) to relay dense + bias_add.

    Note: ``inputs`` is updated in place with the transposed/flattened operands.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 3, "Gemm op take 3 inputs, {} given".format(
            len(inputs))
        alpha, beta = (float(attr.get(key, 1.0)) for key in ('alpha', 'beta'))
        trans_a = int(attr.get('transA', 0))
        trans_b = int(attr.get('transB', 0))
        # Output units are read from B before any transpose is applied.
        units = infer_channels(inputs[1], not trans_b)
        if trans_a:
            inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
        if not trans_b:
            inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
        inputs[0] = _op.nn.batch_flatten(inputs[0])
        scaled_a = _expr.const(alpha) * inputs[0]
        product = _op.nn.dense(scaled_a, inputs[1], units=units)
        return _op.nn.bias_add(product, _expr.const(beta) * inputs[2])

401

402 403 404 405 406 407 408
class MatMul(OnnxOpConverter):
    """ Converts ONNX MatMul.

    A 2-D left operand becomes a single dense op. Higher-rank left operands
    are folded to 3-D, batch-multiplied and reshaped back.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs))
        a_shape = infer_shape(inputs[0])
        if len(a_shape) <= 2:
            # 2-D case: one dense op, with B transposed first.
            return _op.nn.dense(inputs[0], _op.transpose(inputs[1], axes=(1, 0)))

        b_shape = infer_shape(inputs[1])
        # Fold every leading dimension into a single batch axis.
        lhs = _op.reshape(inputs[0], [-1, a_shape[-2], a_shape[-1]])
        rhs = _op.reshape(inputs[1], [-1, b_shape[-2], b_shape[-1]])
        rhs_shape = list(infer_shape(rhs))
        lhs_batch = infer_shape(lhs)[0]
        # Broadcast the right operand up to the left batch size when smaller.
        if lhs_batch > rhs_shape[0]:
            rhs_shape[0] = lhs_batch
            rhs = _op.broadcast_to(rhs, rhs_shape)
        product = _op.nn.batch_matmul(lhs, _op.transpose(rhs, [0, 2, 1]))
        return _op.reshape(product, list(a_shape[:-2]) + [a_shape[-2], b_shape[-1]])

433

434
class MaxPool(Pool):
    """ Converts ONNX MaxPool to relay ``max_pool1d``/``max_pool2d``.
    """
    name = 'max_pool'


class Mul(Elemwise):
    """ Converts ONNX Mul to the relay ``multiply`` op.
    """
    name = 'multiply'


class Pad(OnnxOpConverter):
    """ Operator converter for Pad.

    Opset 1 stores the padding in the ``paddings`` attribute; opset 2 uses
    ``pads``. Both list all begin pads followed by all end pads.
    """

    @classmethod
    def _common(cls, inputs, attr, params, pads_key):
        """Shared conversion; ``pads_key`` names the attribute holding the pads."""
        pads = attr.pop(pads_key)
        dims = len(pads) // 2
        # Pair each axis' begin pad with its end pad, which sits `dims` later.
        attr['pad_width'] = [(pads[i], pads[i + dims]) for i in range(dims)]

        # ONNX string attributes arrive as bytes. The default must be bytes
        # too, otherwise decoding fails whenever 'mode' is omitted.
        pad_mode = attr.pop('mode', b'constant')
        if isinstance(pad_mode, bytes):
            pad_mode = pad_mode.decode('utf-8')
        if pad_mode not in ('constant', 'edge', 'reflect'):
            raise tvm.error.OpAttributeInvalid(
                'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.')
        attr['pad_mode'] = pad_mode

        # AttrCvt calls a non-string op_name with the attrs to obtain the op
        # name (dimension_picker relies on this), so the relay op is named by
        # string here rather than by passing _op.nn.pad itself.
        return AttrCvt(
            'pad',
            transforms={
                'value': 'pad_value',
            },
        )(inputs, attr, params)

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return cls._common(inputs, attr, params, 'paddings')

    @classmethod
    def _impl_v2(cls, inputs, attr, params):
        return cls._common(inputs, attr, params, 'pads')
495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525


class ParametricSoftPlus(OnnxOpConverter):
    """ Converts ParametricSoftPlus: ``alpha * log(exp(beta * x) + 1)``.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        x = inputs[0]
        alpha = _expr.const(float(attr.get('alpha', 1.0)))
        beta = _expr.const(float(attr.get('beta', 1.0)))
        softplus = _op.log(_op.exp(beta * x) + _expr.const(1.))
        return softplus * alpha


class Prelu(OnnxOpConverter):
    """ Converts ONNX PRelu (data, slope) to relay ``nn.prelu``.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        count = len(inputs)
        assert count == 2, "Prelu need 2 inputs, {} given".format(count)
        data, slope = inputs[0], inputs[1]
        return _op.nn.prelu(data, slope)


class Reciprocal(OnnxOpConverter):
    """ Converts ONNX Reciprocal as ``1.0 / x``.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        one = _expr.const(1.0)
        return one / inputs[0]

526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542

class Flatten(OnnxOpConverter):
    """ Operator converter for Flatten.

    Produces a 2-D tensor (prod(shape[:axis]), prod(shape[axis:])).
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        axis = attr.get('axis', 1)
        if axis == 1:
            # Common case; avoids shape inference.
            return _op.nn.batch_flatten(inputs[0])
        ishape = infer_shape(inputs[0])
        if axis < 0:
            axis += len(ishape)
        if axis == 1:
            return _op.nn.batch_flatten(inputs[0])
        # The old `[0] * axis + [-1]` reshape kept every leading dim, producing
        # rank axis + 1 instead of a 2-D result. np.prod of an empty slice is 1,
        # so axis == 0 yields (1, total).
        outer = int(np.prod(ishape[:axis]))
        return _op.reshape(inputs[0], [outer, -1])


543 544 545 546 547 548
class Reshape(OnnxOpConverter):
    """ Operator converter for Reshape.

    Opset 1 reads the target shape from the ``shape`` attribute. Opset 5+
    takes it as a second input, which is either a constant in ``params`` or
    is evaluated with ``infer_value_simulated``.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _op.reshape(inputs[0], attr['shape'])

    @classmethod
    def _impl_v5(cls, inputs, attr, params):
        # Use one name for both the membership test and the pop, so the two
        # lookups cannot disagree.
        shape_name = get_name(inputs[1])
        if shape_name in params:
            # pop shape out of parameters since it wont be needed later.
            shape = tuple(params.pop(shape_name).asnumpy())
            return _op.reshape(inputs[0], shape)

        data, shape = inputs
        static_shape = infer_value_simulated(shape, params)
        return _op.reshape(data, newshape=tuple(
            static_shape.asnumpy().astype('int32')))

564 565 566 567 568 569 570 571 572 573

class DepthToSpace(OnnxOpConverter):
    """ Operator converter for DepthToSpace.
    """

    @classmethod
    def _impl_v11(cls, inputs, attr, params):
        block_size = int(attr['blocksize'])
        # String attributes from the model arrive as bytes (other converters
        # here decode them as well), so decode before handing them to relay.
        mode = attr.get('mode', b'DCR')
        if isinstance(mode, bytes):
            mode = mode.decode('utf-8')
        return _op.nn.depth_to_space(inputs[0], block_size, mode=mode)
575 576 577 578 579 580 581 582 583 584


class SpaceToDepth(OnnxOpConverter):
    """ Converts ONNX SpaceToDepth using the integer ``blocksize`` attribute.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _op.nn.space_to_depth(inputs[0], int(attr['blocksize']))
586 587


588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648
class Concat(OnnxOpConverter):
    """ Converts ONNX Concat to relay ``concatenate``.
    """

    @classmethod
    def _impl_v1(cls, inputs, args, params):
        converter = AttrCvt(op_name='concatenate')
        # The whole input list is passed as a single positional argument.
        return converter((inputs,), args)

class Scale(OnnxOpConverter):
    """ Multiplies the input by the constant ``scale`` attribute (default 1.0).
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        factor = _expr.const(float(attr.get('scale', 1.0)))
        return inputs[0] * factor


class Selu(OnnxOpConverter):
    """ Operator converter for Selu.

    Computes ``gamma * (-alpha * relu(1 - exp(x)) + relu(x))``, i.e.
    ``gamma * alpha * (exp(x) - 1)`` for negative x and ``gamma * x`` otherwise.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Defaults are the exact constants from the ONNX spec; the former
        # 4-digit truncations introduced a small (~1e-5) relative error.
        alpha = float(attr.get('alpha', 1.67326319217681884765625))
        gamma = float(attr.get('gamma', 1.05070102214813232421875))
        return _expr.const(gamma) * (_expr.const(-alpha) *
                                     _op.nn.relu(_expr.const(1.) - _op.exp(inputs[0])) +
                                     _op.nn.relu(inputs[0]))


class ScaledTanh(OnnxOpConverter):
    """ Converts ScaledTanh: ``alpha * tanh(beta * x)``.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = _expr.const(float(attr.get('alpha', 1.0)))
        beta = _expr.const(float(attr.get('beta', 1.0)))
        activated = _op.tanh(beta * inputs[0])
        return activated * alpha


class SoftPlus(OnnxOpConverter):
    """ Converts ONNX Softplus: ``log(exp(x) + 1)``.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        one = _expr.const(1.)
        return _op.log(_op.exp(inputs[0]) + one)


class Softsign(OnnxOpConverter):
    """ Converts ONNX Softsign: ``x / (1 + |x|)``.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        abs_converter = Absolute.get_converter(1)
        denominator = _expr.const(1.) + abs_converter(inputs, attr, params)
        return inputs[0] / denominator


class Sub(Elemwise):
    """ Converts ONNX Sub to the relay ``subtract`` op.
    """
    name = 'subtract'


class Sum(OnnxOpConverter):
    """ Converts ONNX Sum by chaining pairwise adds.

    Each slot of ``inputs`` (from the second on) is overwritten in place
    with the running sum; the last slot is returned.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        for idx in range(1, len(inputs)):
            inputs[idx] = _op.add(inputs[idx - 1], inputs[idx])
        return inputs[-1]


class ThresholdedRelu(OnnxOpConverter):
    """ Operator converter for ThresholdedRelu: ``x if x > alpha else 0``.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # The ONNX spec default for alpha is 1.0 (previously 0.0 here).
        alpha = float(attr.get('alpha', 1.0))
        alpha_tensor = _op.full_like(inputs[0], fill_value=_expr.const(alpha))
        # NOTE(review): the mask is cast to float32; non-float32 inputs would
        # need a matching dtype.
        mask = _op.greater(inputs[0], alpha_tensor).astype("float32")
        return inputs[0] * mask


def _broadcast_constraint():

    def _broadcast_check(attrs):
        if attrs.get('axis', None):
            return False
        return True

    return _broadcast_check, "Specifying broadcast axis not allowed."


def _fully_connected(opset):
    """Build a converter that maps a fully-connected layer to relay ``dense``.

    ``opset`` is accepted for interface compatibility and is not used.
    """

    def _impl(inputs, attr, params):
        # infer_channels' second argument is the boolean transpose flag (see
        # its uses in ConvTranspose and Gemm). Passing `params` there made a
        # non-empty dict act as transpose=True. False reads shape[0] of the
        # (units, in_dim) weight.
        attr['units'] = infer_channels(inputs[1], False)
        return AttrCvt('dense', ignores=['axis', 'axis_w'])(inputs, attr)

    return _impl


class Upsample(OnnxOpConverter):
    """ Operator converter for Upsample (nearest or linear mode, NCHW only).
    """

    @classmethod
    def _impl_v9(cls, inputs, attr, params):
        scales = attr.get('scales')
        if not scales:
            # Newer opsets pass scales as a constant second input rather than
            # an attribute.
            assert len(inputs) == 2, "Upsample op take 2 inputs, {} given".format(len(inputs))
            scales = params[get_name(inputs[1])].asnumpy()
            inputs = inputs[:1]
        assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0

        # ONNX makes 'mode' optional with default "nearest"; model attributes
        # arrive as bytes.
        mode = attr.get('mode', b'nearest')
        if isinstance(mode, bytes):
            mode = mode.decode('utf-8')
        methods = {'nearest': 'nearest_neighbor', 'linear': 'bilinear'}
        if mode not in methods:
            raise tvm.error.OpAttributeInvalid(
                'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode))

        attr = {'scale_h': scales[-2], 'scale_w': scales[-1], 'method': methods[mode],
                'layout': 'NCHW', 'align_corners': True}
        return AttrCvt('upsampling')(inputs, attr)


class Shape(OnnxOpConverter):
    """ Converts ONNX Shape to relay ``shape_of`` with an int64 result.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        data = inputs[0]
        return _op.shape_of(data, "int64")
733 734 735 736 737 738 739 740 741 742 743 744 745

class Cast(OnnxOpConverter):
    """ Converts ONNX Cast to relay ``cast``.

    Opset 1 forwards ``to`` directly as ``dtype``. Opset 5+ first maps the
    ONNX tensor-type code through ``onnx.mapping.TENSOR_TYPE_TO_NP_TYPE``
    and stringifies the result.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        cast = AttrCvt(op_name='cast', transforms={'to': 'dtype'})
        return cast(inputs, attr)

    @classmethod
    def _impl_v5(cls, inputs, attr, params):
        try:
            from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
        except ImportError as e:
            raise ImportError(
                "Unable to import onnx.mapping which is required {}".format(e))
        attr['to'] = str(TENSOR_TYPE_TO_NP_TYPE[attr['to']])
        cast = AttrCvt(op_name='cast', transforms={'to': 'dtype'})
        return cast(inputs, attr)


class Unsqueeze(OnnxOpConverter):
    """ Operator converter for Unsqueeze.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # ONNX axes index the *output* tensor, so they must be inserted in
        # ascending order. Inserting [2, 0] in the given order would put the
        # second new axis in the wrong place. Negative axes count from the
        # end of the output rank.
        axes = list(attr['axes'])
        if any(axis < 0 for axis in axes):
            out_rank = len(infer_shape(inputs[0])) + len(axes)
            axes = [axis + out_rank if axis < 0 else axis for axis in axes]
        data = inputs[0]
        for axis in sorted(axes):
            data = _op.expand_dims(data, axis=axis, num_newaxis=1)
        return data


class Split(OnnxOpConverter):
    """ Operator converter for Split.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        splits = attr.get('split', False)
        if splits:
            # Convert ONNX section sizes into cumulative split indices; the
            # final boundary is implicit.
            indices = []
            index = 0
            for size in splits[:-1]:
                index += size
                indices.append(index)
            attr['indices_or_sections'] = indices
        else:
            # When splits isn't specified, divide evenly over the axis.
            # 'axis' is optional in ONNX and defaults to 0.
            # NOTE(review): this produces one section per element along the
            # axis; ONNX splits into as many parts as the node has outputs.
            in_shape = infer_shape(inputs[0])
            attr['indices_or_sections'] = in_shape[attr.get('axis', 0)]
        return AttrCvt(
            'split',
            ignores=['split'])(inputs, attr, params)


class Slice(OnnxOpConverter):
    """ Operator converter for Slice.
    """

    @classmethod
    def _common(cls, starts, ends, axes):
        """Expand per-axis bounds into dense lists covering axes 0..max(axes).

        ``starts[k]``/``ends[k]`` belong to ``axes[k]``, so ``axes`` may be in
        any order. Axes not listed are sliced in full (start 0, end INT32_MAX).

        Returns (new_starts, new_ends, new_axes).
        """
        bounds = {int(axis): (start, end) for axis, start, end in zip(axes, starts, ends)}
        full = (0, np.iinfo(np.int32).max)
        new_axes = list(range(max(bounds) + 1))
        new_starts = [bounds.get(i, full)[0] for i in new_axes]
        new_ends = [bounds.get(i, full)[1] for i in new_axes]
        return new_starts, new_ends, new_axes

    @staticmethod
    def _axes_need_expansion(axes):
        """True unless ``axes`` is exactly [0, 1, ..., len(axes) - 1].

        The old check (max + 1 != len) missed permutations such as [1, 0],
        which then applied each start/end to the wrong axis.
        """
        return [int(axis) for axis in axes] != list(range(len(axes)))

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if isinstance(attr['starts'], int):
            attr['starts'] = (attr['starts'],)
            attr['ends'] = (attr['ends'],)

        # 'axes' is optional; without it starts/ends map onto axes 0..n-1.
        if 'axes' in attr:
            if isinstance(attr['axes'], int):
                attr['axes'] = (attr['axes'],)
            if cls._axes_need_expansion(attr['axes']):
                new_starts, new_ends, new_axes = cls._common(
                    attr['starts'], attr['ends'], attr['axes'])
                attr['axes'] = new_axes
                attr['starts'] = new_starts
                attr['ends'] = new_ends

        return AttrCvt('strided_slice',
                       transforms={'starts': 'begin',
                                   'ends': 'end'},
                       ignores=['axes'])(inputs, attr)

    @classmethod
    def _impl_v10(cls, inputs, attr, params):
        starts = params[get_name(inputs[1])].asnumpy()
        ends = params[get_name(inputs[2])].asnumpy()

        # Update the starts and ends according to axes if required.
        # NOTE(review): the optional 'steps' input (inputs[4]) is not handled.
        if len(inputs) >= 4:
            axes = params[get_name(inputs[3])].asnumpy()
            if cls._axes_need_expansion(axes):
                starts, ends, _ = cls._common(starts, ends, axes)
        return _op.strided_slice(inputs[0], begin=starts, end=ends)


849 850 851 852 853 854 855
class Gather(OnnxOpConverter):
    """ Converts ONNX Gather to relay ``take`` along ``axis`` (default 0).
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        take = AttrCvt('take', extras={'axis': attr.get('axis', 0)})
        return take(inputs, {})
857

858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874

class Greater(OnnxOpConverter):
    """ Converts ONNX Greater to relay ``greater`` (lhs > rhs).
    """
    @classmethod
    def _impl_v7(cls, inputs, attr, params):
        lhs, rhs = inputs[0], inputs[1]
        return _op.greater(lhs, rhs)


class Less(OnnxOpConverter):
    """ Converts ONNX Less to relay ``less`` (lhs < rhs).
    """
    @classmethod
    def _impl_v7(cls, inputs, attr, params):
        lhs, rhs = inputs[0], inputs[1]
        return _op.less(lhs, rhs)


875 876 877 878 879 880 881 882 883 884 885 886 887
class LRN(OnnxOpConverter):
    """ Operator converter for Local Response Normalization.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        """LRN support only NCHW format
        https://github.com/onnx/onnx/blob/master/docs/Operators.md#LRN
        """
        lrn_attrs = {
            'size': attr.get('size'),
            # Fixed to 1: only the NCHW layout is supported (see docstring).
            'axis': 1,
            'alpha': attr.get('alpha', 0.0001),
            'beta': attr.get('beta', 0.75),
            'bias': attr.get('bias', 1.0),
        }
        return AttrCvt('lrn')(inputs, lrn_attrs)

class Maximum(OnnxOpConverter):
    """ Operator converter for Maximum.

    ONNX ``Max`` is variadic and accepts one or more inputs. The inputs are
    folded left to right with ``AttrCvt('maximum')``; a single input is
    returned unchanged.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if not isinstance(inputs, list) or not inputs:
            raise ValueError("Expect minimum 1 input")
        _max = inputs[0]
        for rhs in inputs[1:]:
            _max = AttrCvt('maximum')([_max, rhs], {})
        return _max

class Minimum(OnnxOpConverter):
    """ Operator converter for Minimum.

    ONNX ``Min`` is variadic and accepts one or more inputs. The inputs are
    folded left to right with ``AttrCvt('minimum')``; a single input is
    returned unchanged.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if not isinstance(inputs, list) or not inputs:
            raise ValueError("Expect minimum 1 input")
        _min = inputs[0]
        for rhs in inputs[1:]:
            _min = AttrCvt('minimum')([_min, rhs], {})
        return _min

class Mean(OnnxOpConverter):
    """ Operator converter for Mean.

    ONNX ``Mean`` is variadic and accepts one or more inputs. The inputs are
    stacked along a new leading axis and averaged over it; as written this
    assumes every input has the same shape (no broadcasting is applied).
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if not isinstance(inputs, list) or not inputs:
            raise ValueError("Expect minimum 1 input")
        stacked = _op.concatenate([_op.expand_dims(x, axis=0) for x in inputs], axis=0)
        return _op.mean(stacked, axis=0, keepdims=False)

class HardSigmoid(OnnxOpConverter):
    """ Operator converter for HardSigmoid.

    Computes ``clip(alpha * x + beta, 0, 1)`` with ONNX defaults
    ``alpha=0.2`` and ``beta=0.5``.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        scale = _expr.const(attr.get('alpha', 0.2))
        shift = _expr.const(attr.get('beta', 0.5))
        affine = (inputs[0] * scale) + shift
        return AttrCvt('clip')([affine], {'a_min': 0, 'a_max': 1})

class Reduce(OnnxOpConverter):
    """ Operator converter for reduce ops.

    Subclasses set ``name`` to the Relay reduction op to emit. Without an
    ``axes`` attribute every axis of the input is reduced; ``keepdims``
    defaults to True.
    """
    name = ''
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if 'axes' not in attr:
            # No explicit axes: reduce over every dimension of the input.
            axis = list(range(len(infer_shape(inputs[0]))))
        else:
            axis = attr['axes']
        reduce_attrs = {'axis': axis, 'keepdims': attr.get('keepdims', True)}
        return AttrCvt(cls.name)(inputs, reduce_attrs)

class ReduceMax(Reduce):
    """Operator converter for ReduceMax (Relay op name ``max``, via ``Reduce``)."""
    name = 'max'

class ReduceMin(Reduce):
    """Operator converter for ReduceMin (Relay op name ``min``, via ``Reduce``)."""
    name = 'min'

class ReduceSum(Reduce):
    """Operator converter for ReduceSum (Relay op name ``sum``, via ``Reduce``)."""
    name = 'sum'

class ReduceMean(Reduce):
    """Operator converter for ReduceMean (Relay op name ``mean``, via ``Reduce``)."""
    name = 'mean'

971
class ReduceProd(Reduce):
    """Operator converter for ReduceProd (Relay op name ``prod``, via ``Reduce``)."""
    name = 'prod'

976 977 978 979 980 981 982
class ArgMax(OnnxOpConverter):
    """ Operator converter for ArgMax.

    Defaults: ``axis=0``, ``keepdims=True``.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        argmax_attrs = dict(axis=attr.get('axis', 0),
                            keepdims=attr.get('keepdims', True))
        return AttrCvt('argmax')(inputs, argmax_attrs)

class ArgMin(OnnxOpConverter):
    """ Operator converter for ArgMin.

    Defaults: ``axis=0``, ``keepdims=True``.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        argmin_attrs = dict(axis=attr.get('axis', 0),
                            keepdims=attr.get('keepdims', True))
        return AttrCvt('argmin')(inputs, argmin_attrs)

class Softmax(OnnxOpConverter):
    """ Operator converter for Softmax.

    Up to opset 12, ONNX Softmax coerces its input to 2-D as
    ``[prod(shape[:axis]), prod(shape[axis:])]`` and normalizes over the
    second dimension, i.e. jointly over every dimension from ``axis`` onward
    (default ``axis=1``). Relay's softmax normalizes along a single axis, so
    the input is flattened first whenever ``axis`` is not the last dimension.
    Opset 13 redefined Softmax as a single-axis op with default ``axis=-1``.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        data = inputs[0]
        # ONNX's default axis is 1, unlike Relay's -1.
        axis = int(attr.get('axis', 1))
        in_shape = infer_shape(data)
        ndim = len(in_shape)
        if axis < 0:
            axis += ndim
        if axis >= ndim - 1:
            # At most one trailing dimension: ONNX and Relay semantics agree.
            return _op.nn.softmax(data, axis=axis)
        # Collapse dims [axis, ndim) into one, normalize, then restore the shape.
        inner = int(np.prod(in_shape[axis:]))
        flat = _op.reshape(data, newshape=(-1, inner))
        out = _op.nn.softmax(flat, axis=1)
        return _op.reshape_like(out, data)

    @classmethod
    def _impl_v13(cls, inputs, attr, params):
        # Opset 13: plain single-axis softmax, default axis -1.
        return _op.nn.softmax(inputs[0], axis=int(attr.get('axis', -1)))

1006 1007 1008

class OneHot(OnnxOpConverter):
    """ Operator converter for OneHot.

    ``values`` holds ``[off_value, on_value]``; the output dtype follows
    ``on_value``. ``depth`` must be computable at conversion time.
    """
    @classmethod
    def _impl_v9(cls, inputs, attr, params):
        # Extract relay one_hot inputs.
        indices, depth, values = inputs
        # Split onnx on off values into two separate expressions.
        off_value = _op.take(values, _op.const(0))
        on_value = _op.take(values, _op.const(1))
        # Extract the datatype of the output from on_value.
        dtype = infer_type(on_value).checked_type.dtype
        # ONNX allows `depth` to be a scalar or a one-element 1-D tensor.
        # Flatten before indexing, since `[0]` on a 0-d array raises IndexError.
        depth = int(infer_value(depth, params).asnumpy().reshape(-1)[0])
        # Default axis -1 when the model does not set it; `attr` is left untouched.
        axis = int(attr.get('axis', -1))
        return _op.one_hot(indices,
                           on_value,
                           off_value,
                           depth,
                           axis,
                           dtype=dtype)


class ConstantOfShape(OnnxOpConverter):
    """ Operator converter for ConstantOfShape.

    The output shape is the first input, evaluated at conversion time with
    ``infer_value_simulated``. Fill value and dtype come from the optional
    ``value`` tensor attribute; without it the output is float32 zeros.
    """
    @classmethod
    def _impl_v9(cls, inputs, attr, params):
        if 'value' not in attr:
            value, dtype = _expr.const(0), 'float32'
        else:
            np_value = get_numpy(attr.pop('value'))[0]
            value, dtype = _expr.const(np_value), np_value.dtype.name
        out_shape = infer_value_simulated(inputs[0], params).asnumpy().astype('int32')
        return _op.full(value, shape=tuple(out_shape), dtype=dtype)
1048 1049


1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061
class Sign(OnnxOpConverter):
    """Operator converter for Sign (elementwise Relay ``sign``)."""
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        data = inputs[0]
        return _op.sign(data)

class Equal(Elemwise):
    """Operator converter for Equal.

    ``name`` is the op name handed to the ``Elemwise`` base converter.
    """
    name = 'equal'

1062 1063 1064 1065 1066 1067 1068 1069 1070

class Not(Elemwise):
    """Operator converter for Not (Relay ``logical_not``)."""
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        operand = inputs[0]
        return _op.logical_not(operand)


1071 1072 1073 1074 1075 1076 1077 1078
class And(Elemwise):
    """Operator converter for And (Relay ``logical_and``)."""
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        lhs, rhs = inputs[0], inputs[1]
        return _op.logical_and(lhs, rhs)


1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089
class Tile(Elemwise):
    """Operator converter for Tile

    ``_impl_v1`` takes the repeat counts from the ``repeats`` attribute;
    ``_impl_v6`` takes them from the second input, evaluated at conversion
    time.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if 'repeats' in attr:
            # The number of times repeating the tensor data.
            reps = attr.pop('repeats')
            return _op.tile(inputs[0], reps)
        raise tvm.error.OpAttributeInvalid('Attribute "repeats" should be set '
                                           'for operator Tile.')

    @classmethod
    def _impl_v6(cls, inputs, attr, params):
        reps_value = infer_value_simulated(inputs[1], params)
        reps = tuple(reps_value.asnumpy().astype('int32'))
        return _op.tile(inputs[0], reps)

1096 1097 1098 1099 1100 1101 1102
class Erf(OnnxOpConverter):
    """Operator converter for Erf (elementwise Relay ``erf``)."""
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        data = inputs[0]
        return _op.erf(data)

1103 1104 1105 1106 1107
class Where(OnnxOpConverter):
    """Operator converter for Where

    ONNX Where broadcasts ``condition``, ``X`` and ``Y`` against one another
    (numpy-style multidirectional broadcasting). Every input whose static
    shape differs from the common broadcast shape is expanded with
    ``broadcast_to`` before Relay ``where`` is applied.
    """
    @classmethod
    def _impl_v9(cls, inputs, attr, params):
        operands = list(inputs[:3])
        shapes = [tuple(infer_shape(x)) for x in operands]
        out_shape = cls._broadcast_shape(shapes)
        operands = [x if shape == out_shape else _op.broadcast_to(x, out_shape)
                    for x, shape in zip(operands, shapes)]
        return _op.where(*operands)

    @staticmethod
    def _broadcast_shape(shapes):
        """Return the numpy-style broadcast of the static ``shapes`` as a tuple.

        Shapes are right-aligned and left-padded with 1s. Raises ValueError
        when two extents other than 1 disagree on some axis.
        """
        ndim = max(len(shape) for shape in shapes)
        padded = [(1,) * (ndim - len(shape)) + tuple(shape) for shape in shapes]
        out_shape = []
        for extents in zip(*padded):
            distinct = set(e for e in extents if e != 1)
            if len(distinct) > 1:
                raise ValueError(
                    "Where inputs are not broadcastable: {}".format(shapes))
            out_shape.append(distinct.pop() if distinct else 1)
        return tuple(out_shape)

1118 1119 1120 1121 1122 1123
class Or(Elemwise):
    """Operator converter for Or (Relay ``logical_or``)."""
    @classmethod
    def _impl_v7(cls, inputs, attr, params):
        lhs, rhs = inputs[0], inputs[1]
        return _op.logical_or(lhs, rhs)
1124

1125

1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171
class Expand(OnnxOpConverter):
    """ Operator converter for Expand.

    ONNX Expand broadcasts the input to the target ``shape`` using
    multidirectional (numpy-style) broadcasting: the shapes are right-aligned,
    either one may have the higher rank, and a target extent of 1 keeps the
    input's extent. Relay ``broadcast_to`` needs the final output shape, so
    that shape is computed here first.
    https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
    """
    @classmethod
    def _impl_v8(cls, inputs, attr, params):
        in_shape = [int(dim) for dim in infer_shape(inputs[0])]
        shape_name = get_name(inputs[1])
        if shape_name in params:
            shape = params[shape_name].asnumpy().astype('int32')
        else:
            shape = infer_value_simulated(inputs[1], params).asnumpy().astype('int32')
        out_shape = cls._expand_shape(in_shape, [int(dim) for dim in shape])
        return _op.broadcast_to(inputs[0], shape=tuple(out_shape))

    @staticmethod
    def _expand_shape(in_shape, shape):
        """Return the output shape of expanding ``in_shape`` to ``shape``.

        Both shapes are right-aligned and the shorter one is left-padded with
        1s, so input dims beyond the target's rank are kept. For each aligned
        pair a target extent of 1 takes the input's extent; otherwise the
        target extent is used.
        """
        ndim = max(len(in_shape), len(shape))
        in_dims = [1] * (ndim - len(in_shape)) + list(in_shape)
        out_dims = [1] * (ndim - len(shape)) + list(shape)
        return [in_dim if out_dim == 1 else out_dim
                for in_dim, out_dim in zip(in_dims, out_dims)]
1172

1173

1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190
# ONNX operator names that need no conversion at all: GraphProto's
# _convert_operator calls get_relay_op(op_name) on them directly.
# Currently none.
_identity_list = list()


# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
def _get_convert_map(opset):
    """Return the table mapping ONNX op names to converters for ``opset``.

    Versioned converters are resolved with ``get_converter(opset)`` each time
    the table is built.
    """
    return {
        # defs/experimental
        'Identity': Renamer('copy'),
        # 'Affine'
        'ThresholdedRelu': ThresholdedRelu.get_converter(opset),
        'ScaledTanh': ScaledTanh.get_converter(opset),
        'ParametricSoftplus': ParametricSoftPlus.get_converter(opset),
        'ConstantOfShape': ConstantOfShape.get_converter(opset),
        # 'GivenTensorFill'
        'FC': AttrCvt('dense', ignores=['axis', 'axis_w']),
        'Scale': Scale.get_converter(opset),
        # 'GRUUnit'
        # 'ATen'
        # 'ImageScaler'
        # 'MeanVarianceNormalization'
        # 'Crop'
        # 'Embedding'
        'Upsample': Upsample.get_converter(opset),
        'SpatialBN': BatchNorm.get_converter(opset),

        # defs/generator
        # 'Constant' is handled directly in GraphProto.from_onnx
        # 'RandomUniform'
        # 'RandomNormal'
        # 'RandomUniformLike'
        # 'RandomNormalLike'

        # defs/logical
        'Greater': Greater.get_converter(opset),
        'Less': Less.get_converter(opset),
        'Equal': Equal.get_converter(opset),
        'Not': Not.get_converter(opset),
        'And': And.get_converter(opset),
        'Or': Or.get_converter(opset),

        # defs/math
        'Add': Add.get_converter(opset),
        'Sub': Sub.get_converter(opset),
        'Mul': Mul.get_converter(opset),
        'Div': Div.get_converter(opset),
        'Neg': Renamer('negative'),
        'Abs': Absolute.get_converter(opset),
        'Reciprocal': Reciprocal.get_converter(opset),
        'Floor': Renamer('floor'),
        'Ceil': Renamer('ceil'),
        'Sqrt': Renamer('sqrt'),
        'Relu': Renamer('relu'),
        'LeakyRelu': Renamer('leaky_relu'),
        'Selu': Selu.get_converter(opset),
        'Elu': Elu.get_converter(opset),
        'Exp': Renamer('exp'),
        'Log': Renamer('log'),
        'Tanh': Renamer('tanh'),
        'Pow': Renamer('power'),
        'PRelu': Prelu.get_converter(opset),
        'Sigmoid': Renamer('sigmoid'),
        'HardSigmoid': HardSigmoid.get_converter(opset),
        'Max': Maximum.get_converter(opset),
        'Min': Minimum.get_converter(opset),
        'Sum': Sum.get_converter(opset),
        'Mean': Mean.get_converter(opset),
        'Clip': AttrCvt('clip', transforms={'min': 'a_min', 'max': 'a_max'}),
        # softmax default axis is different in onnx
        'Softmax': Softmax.get_converter(opset),
        'LogSoftmax': AttrCvt('log_softmax', {'axis': ('axis', 1)}),
        'OneHot': OneHot.get_converter(opset),
        # 'Hardmax'
        'Softsign': Softsign.get_converter(opset),
        # The ONNX operator is spelled 'Softplus'; 'SoftPlus' is kept for
        # backward compatibility with graphs that used the old key.
        'Softplus': SoftPlus.get_converter(opset),
        'SoftPlus': SoftPlus.get_converter(opset),
        'Gemm': Gemm.get_converter(opset),
        'MatMul': MatMul.get_converter(opset),

        # defs/nn
        'AveragePool': AveragePool.get_converter(opset),
        'MaxPool': MaxPool.get_converter(opset),
        'Conv': Conv.get_converter(opset),
        'ConvTranspose': ConvTranspose.get_converter(opset),
        'GlobalAveragePool': Renamer('global_avg_pool2d'),
        'GlobalMaxPool': Renamer('global_max_pool2d'),
        'BatchNormalization': BatchNorm.get_converter(opset),
        'InstanceNormalization': InstanceNorm.get_converter(opset),
        # 'LpNormalization'
        'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
        'Flatten': Flatten.get_converter(opset),
        'LRN': LRN.get_converter(opset),

        # defs/reduction
        'ReduceMax': ReduceMax.get_converter(opset),
        'ReduceMin': ReduceMin.get_converter(opset),
        'ReduceSum': ReduceSum.get_converter(opset),
        'ReduceMean': ReduceMean.get_converter(opset),
        'ReduceProd': ReduceProd.get_converter(opset),
        # 'ReduceLogSumExp'
        'ArgMax': ArgMax.get_converter(opset),
        'ArgMin': ArgMin.get_converter(opset),

        # defs/tensor
        'Cast': Cast.get_converter(opset),
        'Reshape': Reshape.get_converter(opset),
        'Expand': Expand.get_converter(opset),
        'Concat': Concat.get_converter(opset),
        'Split': Split.get_converter(opset),
        'Slice': Slice.get_converter(opset),
        'Transpose': AttrCvt('transpose', {'perm': 'axes'}),
        'DepthToSpace': DepthToSpace.get_converter(opset),
        'SpaceToDepth': SpaceToDepth.get_converter(opset),
        'Gather': Gather.get_converter(opset),
        'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
        'Unsqueeze': Unsqueeze.get_converter(opset),
        'Pad': Pad.get_converter(opset),
        'Shape': Shape.get_converter(opset),
        'Sign': Sign.get_converter(opset),
        'Tile': Tile.get_converter(opset),
        'Erf': Erf.get_converter(opset),
        'Where': Where.get_converter(opset),
    }


class GraphProto(object):
    """A helper class for handling Relay expression copying from pb2.GraphProto.
    Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto

    Parameters
    ----------
    shape : dict of str to tuple, optional
        The input shape to the graph. Every graph input that is not an
        initializer must have an entry here.

    dtype : str or dict of str to str
        The input types to the graph. Only a dict is consulted (keyed by
        input name). Inputs missing from the dict, and all inputs when a str
        is given, use the element type declared in the model (float32 if
        that declaration cannot be read).
    """

    def __init__(self, shape, dtype):
        self._nodes = {}    # ONNX value name -> Relay expression
        self._params = {}   # initializer / Constant name -> tvm.nd.array
        self._renames = {}  # name aliases consulted when resolving node inputs
        self._num_input = 0
        self._num_param = 0
        self._shape = shape if shape else {}
        self._dtype = dtype

    def from_onnx(self, graph, opset):
        """Construct Relay expression from ONNX graph.

        Onnx graph is a python protobuf object. Initializers (and Constant
        nodes) become Relay variables whose values are returned in
        ``params``; graph inputs that are not initializers become free input
        variables. The original ONNX names are kept for both.

        Parameters
        ----------
        graph : onnx protobuf object
            The loaded onnx graph

        opset : opset version

        Returns
        -------
        mod : tvm.relay.Module
            The returned relay module

        params : dict
            A dict of name: tvm.nd.array pairs, used as pretrained weights

        Raises
        ------
        ValueError
            If an initializer has no name or a real input has no shape.
        tvm.error.OpNotImplemented
            If the graph uses operators without a converter.
        """
        # parse network inputs to relay, aka parameters
        for init_tensor in graph.initializer:
            if not init_tensor.name.strip():
                raise ValueError("Tensor's name is required.")
            array = self._parse_array(init_tensor)
            self._params[init_tensor.name] = array
            self._nodes[init_tensor.name] = new_var(init_tensor.name,
                                                    shape=array.shape,
                                                    dtype=array.dtype)
        for i in graph.input:
            # from onnx v0.2, GraphProto.input has type ValueInfoProto,
            #  and the name is 'i.name'
            i_name = self._parse_value_proto(i)
            d_type = self._parse_dtype(i, 'float32')
            if i_name in self._params:
                # i is a param instead of input
                self._num_param += 1
                # NOTE(review): pop-and-reinsert only moves the key to the end
                # of the dict; the value is unchanged.
                self._params[i_name] = self._params.pop(i_name)
                self._nodes[i_name] = new_var(i_name,
                                              shape=self._params[i_name].shape,
                                              dtype=self._params[i_name].dtype)
            else:
                self._num_input += 1
                if i_name in self._shape:
                    tshape = self._shape[i_name]
                else:
                    raise ValueError("Must provide an input shape for `{0}`.".format(i_name))
                if isinstance(self._dtype, dict):
                    dtype = self._dtype[i_name] if i_name in self._dtype else d_type
                else:
                    dtype = d_type
                self._nodes[i_name] = new_var(i_name, shape=tshape, dtype=dtype)
        # get list of unsupported ops
        convert_map = _get_convert_map(opset)
        unsupported_ops = set()
        for node in graph.node:
            op_name = node.op_type
            if op_name not in convert_map and \
               op_name != 'Constant' and \
               op_name not in _identity_list:
                unsupported_ops.add(op_name)
        if unsupported_ops:
            msg = 'The following operators are not supported for frontend ONNX: '
            msg += ', '.join(unsupported_ops)
            raise tvm.error.OpNotImplemented(msg)
        # construct nodes, nodes are stored as directed acyclic graph
        for node in graph.node:
            op_name = node.op_type
            attr = self._parse_attr(node.attribute)
            inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
            if op_name == "Constant":
                # Constants are stored like initializers: the value goes to
                # params and a Relay var stands in for it in the graph.
                t_proto = attr["value"]
                self._num_param += 1
                array = self._parse_array(t_proto)
                self._params[node.output[0]] = array
                self._nodes[node.output[0]] = new_var(
                    node.output[0],
                    shape=list(t_proto.dims),
                    dtype=array.dtype)
            else:
                i_name = self._parse_value_proto(node)
                attr['tvm_custom'] = {}
                attr['tvm_custom']['name'] = i_name

                op = self._convert_operator(op_name, inputs, attr, opset)
                node_output = self._fix_outputs(op_name, node.output)
                if not isinstance(op, _expr.TupleWrapper):
                    outputs_num = 1
                else:
                    outputs_num = len(op)
                assert len(node_output) == outputs_num, (
                    "Number of output mismatch {} vs {} in {}.".format(
                        len(node_output), outputs_num, op_name))
                if outputs_num == 1:
                    self._nodes[node_output[0]] = op
                else:
                    for idx, out_name in enumerate(node_output):
                        self._nodes[out_name] = op[idx]

        # now return the outputs
        outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
        outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
        func = _expr.Function(analysis.free_vars(outputs), outputs)
        return _module.Module.from_expr(func), self._params

    def _parse_value_proto(self, value_proto):
        """Return ``value_proto.name``, or ``value_proto`` itself if it has no name."""
        try:
            name = value_proto.name
        except AttributeError:
            name = value_proto
        return name

    def _parse_dtype(self, value_proto, dtype):
        """Return the numpy dtype name declared for a ValueInfoProto.

        Falls back to ``dtype`` when the proto has no tensor type information.
        """
        try:
            from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
            return TENSOR_TYPE_TO_NP_TYPE[value_proto.type.tensor_type.elem_type].name
        except AttributeError:
            return dtype

    def _parse_array(self, tensor_proto):
        """Convert a TensorProto into a tvm.nd.array with the proto's dims."""
        np_array = get_numpy(tensor_proto).reshape(tuple(tensor_proto.dims))
        return _nd.array(np_array)

    def _parse_attr(self, attr_proto):
        """Convert a list of AttributeProto to a dict, with names as keys.

        Scalars ('f', 'i', 's') and tensors ('t') are stored as-is; repeated
        fields become tuples. Graph-valued attributes are rejected.
        """
        attrs = {}
        for a in attr_proto:
            for f in ['f', 'i', 's']:
                if a.HasField(f):
                    attrs[a.name] = getattr(a, f)
            for f in ['floats', 'ints', 'strings']:
                if list(getattr(a, f)):
                    assert a.name not in attrs, "Only one type of attr is allowed"
                    attrs[a.name] = tuple(getattr(a, f))
            for f in ['t']:
                if a.HasField(f):
                    attrs[a.name] = getattr(a, f)
            for f in ['tensors']:
                if list(getattr(a, f)):
                    assert a.name not in attrs, "Only one type of attr is allowed"
                    attrs[a.name] = tuple(getattr(a, f))
            for f in ['g']:
                if a.HasField(f):
                    raise NotImplementedError(
                        "Field {} is not supported in relay.".format(f))
            for f in ['graphs']:
                if list(getattr(a, f)):
                    raise NotImplementedError(
                        "Field {} is not supported in relay.".format(f))
            if a.name not in attrs:
                raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
        return attrs

    def _convert_operator(self,
                          op_name,
                          inputs,
                          attrs,
                          opset):
        """Convert ONNX operator into a Relay operator.

        The converter must specify conversions explicitly for incompatible name, and
        apply handlers to operator attributes.

        Parameters
        ----------
        op_name : str
            Operator name, such as Convolution, FullyConnected
        inputs : list of tvm.relay.Expr
            List of inputs.
        attrs : dict
            Dict of operator attributes
        opset : int
            Opset version

        Returns
        -------
        sym : tvm.relay.Expr
            Converted relay expression
        """
        # NOTE: the table is rebuilt per call on purpose; some converters
        # (e.g. AttrCvt instances) may carry state across calls.
        convert_map = _get_convert_map(opset)
        if op_name in _identity_list:
            sym = get_relay_op(op_name)(*inputs, **attrs)
        elif op_name in convert_map:
            sym = convert_map[op_name](inputs, attrs, self._params)
        else:
            raise NotImplementedError(
                "Operator {} not implemented.".format(op_name))
        return sym

    def _fix_outputs(self, op_name, outputs):
        """A hack to handle dropout or similar operator that have more than one out
        in ONNX.

        For Dropout with more than one output the trailing (mask) output is
        dropped; every other op's outputs are returned unchanged.
        """
        if op_name == 'Dropout':
            if len(outputs) == 1:
                return outputs
            # TODO(zhreshold): support dropout mask?
            outputs = outputs[:-1]
        return outputs

def _get_model_opset(model):
    """Return the opset version of the default ONNX domain declared by ``model``.

    ``model.opset_import`` may list several operator-set domains (e.g.
    ``ai.onnx.ml`` next to the default domain) in any order; the default
    domain is spelled ``''`` or ``'ai.onnx'``. If no default-domain entry
    exists, the first entry's version is used; if the model declares no
    opset at all (or has no ``opset_import``), 1 is returned.
    """
    try:
        opset_imports = list(model.opset_import)
        if not opset_imports:
            return 1
        for opset_id in opset_imports:
            if getattr(opset_id, 'domain', '') in ('', 'ai.onnx'):
                return opset_id.version
        return opset_imports[0].version
    except AttributeError:
        return 1


def from_onnx(model,
              shape=None,
              dtype="float32",
              opset=None):
    """Convert a ONNX model into an equivalent Relay Function.

    ONNX graphs are represented as Python Protobuf objects.
    The companion parameters will be handled automatically: initializers are
    returned in ``params`` and the remaining graph inputs become free inputs
    of the Relay function, keeping their original ONNX names.

    Parameters
    ----------
    model : protobuf object
        ONNX ModelProto after ONNX v1.1.0

    shape : dict of str to tuple, optional
        The input shape to the graph

    dtype : str or dict of str to str
        The input types to the graph

    opset : int, optional
        Override to autodetected opset.
        This can be helpful for some testing.

    Returns
    -------
    mod : tvm.relay.Module
        The relay module for compilation

    params : dict of str to tvm.NDArray
        The parameter dict to be used by relay
    """
    try:
        import onnx
        if hasattr(onnx.checker, 'check_model'):
            # try use onnx's own model checker before converting any model
            try:
                onnx.checker.check_model(model)
            except onnx.onnx_cpp2py_export.checker.ValidationError as e:
                import warnings
                # the checker is a bit violent about errors, so simply print warnings here
                warnings.warn(str(e))
    except ImportError:
        pass
    g = GraphProto(shape, dtype)
    graph = model.graph
    if opset is None:
        # Use the default-domain opset, not whichever domain is listed first.
        opset = _get_model_opset(model)
    mod, params = g.from_onnx(graph, opset)
    return mod, params