pytorch.py 51.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, too-many-lines, len-as-condition, no-else-return, unused-variable, too-many-nested-blocks
# pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
"""PT: PyTorch frontend."""
21
import itertools
22
import logging
23
import sys
24

25 26 27 28 29 30 31 32
import numpy as np

import tvm
from tvm.ir import module as _module

from .. import analysis as _analysis
from .. import expr as _expr
from .. import op as _op
33
from ..loops import while_loop
34 35
from .common import get_relay_op
from .common import infer_shape as _infer_shape
36
from .common import infer_value as _infer_value
37

38 39
from . import qnn_torch

40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
__all__ = ["from_pytorch"]

# operator implementation
def _elemwise(name):
    """Build a converter for the binary elementwise Relay op called `name`.

    The returned ``_impl`` converts both operands with
    ``_convert_elemwise_input`` and applies ``get_relay_op(name)`` to them.
    """
    def _impl(inputs, input_types):
        lhs, rhs = inputs[0], inputs[1]

        # TODO: Figure out a better way to get typing to work for tensor + scalar
        # Each operand defaults to its own type string, but borrows the other
        # operand's type string when that other operand is a Relay expression.
        lhs_type, rhs_type = input_types[0], input_types[1]
        if isinstance(rhs, _expr.Expr):
            lhs_type = input_types[1]
        if isinstance(lhs, _expr.Expr):
            rhs_type = input_types[0]

        operands = [
            _convert_elemwise_input(lhs, lhs_type),
            _convert_elemwise_input(rhs, rhs_type),
        ]
        return get_relay_op(name)(*operands)
    return _impl

60 61 62 63 64 65 66 67 68 69 70
def _squeeze():
    """Build the converter that lowers a squeeze to ``_op.transform.squeeze``.

    With a single input the axis passed on is ``None``; otherwise it is a
    one-element list holding ``int(inputs[1])``.
    """
    def _impl(inputs, input_types):
        tensor = inputs[0]
        squeeze_axis = None if len(inputs) == 1 else [int(inputs[1])]
        return _op.transform.squeeze(tensor, squeeze_axis)
    return _impl

71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
def _unsqueeze():
    """Build the converter that lowers an unsqueeze to ``expand_dims``."""
    def _impl(inputs, input_types):
        tensor, position = inputs[0], inputs[1]
        # The trailing 1 is passed positionally to expand_dims
        # (presumably the number of new axes -- confirm against the Relay API).
        return _op.transform.expand_dims(tensor, int(position), 1)
    return _impl

def _concatenate():
    """Build the converter that lowers a concatenation to ``_op.tensor.concatenate``."""
    def _impl(inputs, input_types):
        tensors, axis = inputs[0], inputs[1]

        # A lone Relay expression is wrapped so that concatenate always
        # receives a list.
        if isinstance(tensors, _expr.Expr):
            tensors = [tensors]

        return _op.tensor.concatenate(tensors, int(axis))
    return _impl

def _slice():
    """Build the converter that slices one dimension via ``strided_slice``.

    The returned ``_impl`` reads ``inputs`` as ``(data, dim, begin, end,
    stride)``.  Every dimension other than ``dim`` keeps ``begin = 0`` and an
    ``end`` equal to its extent.
    """
    def _impl(inputs, input_types):
        data = inputs[0]

        if isinstance(data, _expr.Expr):
            inferred_shape = _infer_shape(data)
            end = [int(extent) for extent in inferred_shape]
            if isinstance(data, _expr.Var):
                # For variables the inferred extents are used unconverted.
                end = list(inferred_shape)
        else:
            # Copy into a list: `end[dim]` is assigned below, which would fail
            # on an immutable shape object such as a tuple.
            end = list(data.shape)

        begin = [0] * len(end)
        dim = int(inputs[1])
        begin[dim] = int(inputs[2])

        stop = inputs[3]
        if isinstance(stop, str) and stop.isdigit():
            # A numeric string is clamped to the extent of the dimension.
            end[dim] = min(end[dim], int(stop))
        else:
            end[dim] = stop

        strides = [int(inputs[4])]
        return _op.transform.strided_slice(data, begin, end, strides)
    return _impl

119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
def _split():
    """Build the converter that splits a tensor into equally sized pieces.

    The returned ``_impl`` reads ``inputs`` as ``(data, split_size, dim)`` and
    calls ``_op.split`` with the multiples of ``split_size`` that are smaller
    than the extent of ``dim``.

    Raises
    ------
    ValueError
        If ``split_size`` is not positive while the dimension is non-empty;
        the index loop could never terminate in that case.
    """
    def _impl(inputs, input_types):
        data = inputs[0]
        split_size = int(inputs[1])
        dim = int(inputs[2])

        # The extent does not change inside the loop, so infer it only once.
        extent = _infer_shape(data)[dim]

        if split_size <= 0 and split_size < extent:
            raise ValueError(
                "split size must be positive, got {}".format(split_size))

        indices = []
        split_index = split_size
        while split_index < extent:
            indices.append(split_index)
            split_index += split_size

        return _op.split(data, indices, dim)
    return _impl

def _split_with_sizes():
    """Build the converter that splits a tensor into explicitly sized sections.

    The section sizes come from ``_infer_shape(inputs[1])``; the running totals
    of all but the last section are the split indices given to ``_op.split``.
    """
    def _impl(inputs, inputs_types):
        data, dim = inputs[0], int(inputs[2])
        sections = _infer_shape(inputs[1])

        boundaries = []
        running_total = 0
        # The last section needs no boundary: it runs to the end of the axis.
        for length in sections[:-1]:
            running_total += length
            boundaries.append(running_total)

        return _op.split(data, boundaries, dim)
    return _impl

149 150 151 152
def _select():
    """Build the converter that lowers a select to ``_op.transform.take``."""
    def _impl(inputs, input_types):
        tensor = inputs[0]
        axis = int(inputs[1])
        # _wrap_const is defined outside this section; presumably it turns a
        # plain index into a Relay constant -- confirm at its definition.
        position = _wrap_const(inputs[2])
        return _op.transform.take(tensor, position, axis=axis)
    return _impl

def _ones():
    """Build the converter that creates a tensor filled with ones.

    The shape is taken from ``inputs[0]`` (a Relay expression, a list, or a
    torch/numpy tensor) and the dtype from the integer id in ``inputs[1]``.

    Raises
    ------
    TypeError
        If ``inputs[0]`` is none of the supported kinds.
    AssertionError
        If the dtype id is not in the supported map.
    """
    def _impl(inputs, input_types):
        data = inputs[0]

        import torch
        if isinstance(data, _expr.Expr):
            shape = _infer_shape(data)
        elif isinstance(data, list):
            shape = data
        elif isinstance(data, (torch.Tensor, np.ndarray)):
            shape = data.shape
        else:
            # Raise explicitly: asserting on a non-empty string can never
            # fail, and `shape` would be left unbound below.
            raise TypeError(
                "data type {} could not be parsed in ones op".format(type(data)))

        dtype_map = {6: "float32", 3: "int32"}
        dtype_id = inputs[1]
        # %s keeps the message printable for non-integer ids such as None.
        assert dtype_id in dtype_map, "Unsupported dtype %s" % (dtype_id,)
        return _op.full(_expr.const(1), shape, dtype=dtype_map[dtype_id])
    return _impl

def _zeros():
    """Build the converter that creates a tensor filled with zeros.

    The shape is taken from ``inputs[0]`` (a Relay expression, a list, or a
    torch/numpy tensor) and the dtype from the integer id in ``inputs[1]``.

    Raises
    ------
    TypeError
        If ``inputs[0]`` is none of the supported kinds.
    AssertionError
        If the dtype id is not in the supported map.
    """
    def _impl(inputs, input_types):
        data = inputs[0]

        import torch
        if isinstance(data, _expr.Expr):
            shape = _infer_shape(data)
        elif isinstance(data, list):
            shape = data
        elif isinstance(data, (torch.Tensor, np.ndarray)):
            shape = data.shape
        else:
            # Raise explicitly: asserting on a non-empty string can never
            # fail, and `shape` would be left unbound below.
            raise TypeError(
                "data type {} could not be parsed in zeros op".format(type(data)))

        dtype_map = {6: "float32", 3: "int32"}
        dtype_id = inputs[1]
        # %s keeps the message printable for non-integer ids such as None.
        assert dtype_id in dtype_map, "Unsupported dtype %s" % (dtype_id,)
        return _op.full(_expr.const(0), shape, dtype=dtype_map[dtype_id])
    return _impl

def _relu():
    """Build the ReLU converter.

    Inputs typed ``"quint8"`` go through ``qnn_torch.quantized_relu`` with the
    zero point read from ``inputs[2]``; everything else uses ``_op.nn.relu``.
    """
    def _impl(inputs, input_types):
        data = inputs[0]
        if input_types[0] != "quint8":
            return _op.nn.relu(data)

        assert len(inputs) == 3, "Input quant param not found in op inputs"
        zero_point = _expr.const(inputs[2], dtype="int32")
        return qnn_torch.quantized_relu(data, zero_point)
    return _impl

207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233
def _prelu():
    """Build the converter that lowers a PReLU to ``_op.nn.prelu``."""
    def _impl(inputs, input_types):
        tensor, slope = inputs[0], inputs[1]
        return _op.nn.prelu(tensor, slope)
    return _impl

def _leaky_relu():
    """Build the converter that lowers a leaky ReLU to ``_op.nn.leaky_relu``.

    The slope is read from ``inputs[1]``.
    """
    def _impl(inputs, input_types):
        data = inputs[0]
        # The slope is fractional in general (e.g. 0.01); converting it with
        # int() would truncate it to 0 and turn the op into a plain ReLU.
        alpha = float(inputs[1])
        return _op.nn.leaky_relu(data, alpha)
    return _impl

def _elu():
    """Build the ELU converter.

    ELU is ``x`` for ``x > 0`` and ``alpha * (exp(x) - 1)`` otherwise.  It is
    expressed here with ReLU as ``relu(x) - alpha * relu(1 - exp(x))``.
    """
    def _impl(inputs, input_types):
        data = inputs[0]
        # alpha may be fractional, so it must not be truncated with int().
        alpha = _expr.const(float(inputs[1]), dtype='float32')
        one = _expr.const(1.0, dtype='float32')
        # For x > 0: 1 - exp(x) < 0, so the second term is 0 and the result is x.
        # For x <= 0: relu(x) is 0 and the result is alpha * (exp(x) - 1).
        return _op.nn.relu(data) - alpha * _op.nn.relu(one - _op.exp(data))
    return _impl

def _log_sigmoid():
    """Build the converter that computes ``log(sigmoid(x))``."""
    def _impl(inputs, input_types):
        activated = _op.tensor.sigmoid(inputs[0])
        return _op.log(activated)
    return _impl

234
def _adaptive_avg_pool_2d():
    """Build the converter for 2D adaptive average pooling.

    The output size comes from ``_infer_shape(inputs[1])``.  Inputs typed
    ``"quint8"`` are pooled through ``qnn_torch.apply_with_upcast``.
    """
    def _impl(inputs, input_types):
        data = inputs[0]
        output_size = _infer_shape(inputs[1])

        def _pool(tensor):
            return _op.nn.adaptive_avg_pool2d(tensor, output_size=output_size)

        if input_types[0] != "quint8":
            return _pool(data)
        return qnn_torch.apply_with_upcast(data, _pool)

    return _impl

249
def _adaptive_max_pool_2d():
    """Build the converter for 2D adaptive max pooling.

    The returned ``_impl`` yields a ``(pooled, None)`` pair; the ``None`` is a
    placeholder for the indices output.
    """
    def _impl(inputs, input_types):
        tensor = inputs[0]
        target_size = _infer_shape(inputs[1])

        pooled = _op.nn.adaptive_max_pool2d(tensor, output_size=target_size)
        # returns dummy indices too
        return pooled, None
    return _impl

def _adaptive_max_pool_3d():
    """Build the converter for 3D adaptive max pooling.

    The returned ``_impl`` yields a ``(pooled, None)`` pair; the ``None`` is a
    placeholder for the indices output.
    """
    def _impl(inputs, input_types):
        tensor = inputs[0]
        target_size = _infer_shape(inputs[1])

        pooled = _op.nn.adaptive_max_pool3d(tensor, output_size=target_size)
        # returns dummy indices too
        return pooled, None

    return _impl

def _adaptive_avg_pool_3d():
    """Build the converter for 3D adaptive average pooling."""
    def _impl(inputs, input_types):
        tensor = inputs[0]
        target_size = _infer_shape(inputs[1])
        return _op.nn.adaptive_avg_pool3d(tensor, output_size=target_size)

    return _impl

def _maxpool_2d():
    """Build the converter for 2D max pooling.

    The returned ``_impl`` reads ``inputs`` as ``(data, pool_size, strides,
    padding, dilation, ceil_mode)`` and raises ``NotImplementedError`` for any
    dilation other than ``(1, 1)``.
    """
    def _impl(inputs, input_types):
        data = inputs[0]
        pool_size, strides, padding, dilation = [
            _infer_shape(inputs[idx]) for idx in (1, 2, 3, 4)
        ]
        ceil_mode = int(inputs[5])

        if dilation != (1, 1):
            raise NotImplementedError(
                "MaxPool2d with dilation %s is not implemented" % (str(dilation), ))

        return _op.nn.max_pool2d(data, pool_size, strides, padding, "NCHW", ceil_mode)
    return _impl

294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310
def _maxpool_1d():
    """Build the converter for 1D max pooling.

    The returned ``_impl`` reads ``inputs`` as ``(data, pool_size, strides,
    padding, dilation, ceil_mode)`` and raises ``NotImplementedError`` for any
    dilation other than ``(1,)``.
    """
    def _impl(inputs, input_types):
        data = inputs[0]
        pool_size, strides, padding, dilation = [
            _infer_shape(inputs[idx]) for idx in (1, 2, 3, 4)
        ]
        ceil_mode = int(inputs[5])

        if dilation != (1,):
            raise NotImplementedError(
                "MaxPool1d with dilation %s is not implemented" % (str(dilation), ))

        return _op.nn.max_pool1d(data, pool_size, strides, padding, "NCW", ceil_mode)
    return _impl

311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330
def _maxpool_3d():
    """Build the converter for 3D max pooling.

    The returned ``_impl`` reads ``inputs`` as ``(data, pool_size, strides,
    padding, dilation, ceil_mode)`` and raises ``NotImplementedError`` for any
    dilation other than ``(1, 1, 1)``.
    """
    def _impl(inputs, input_types):
        data = inputs[0]
        pool_size, strides, padding, dilation = [
            _infer_shape(inputs[idx]) for idx in (1, 2, 3, 4)
        ]
        ceil_mode = int(inputs[5])

        if dilation != (1, 1, 1):
            raise NotImplementedError(
                "MaxPool3d with dilation %s is not implemented" % (str(dilation), ))

        pool_kwargs = dict(pool_size=pool_size,
                           strides=strides,
                           padding=padding,
                           ceil_mode=ceil_mode)
        return _op.nn.max_pool3d(data, **pool_kwargs)
    return _impl

331 332 333 334 335 336 337 338 339 340 341
def _hardtanh():
    """Build the converter that lowers a hardtanh to ``_op.tensor.clip``."""
    def _impl(inputs, input_types):
        tensor = inputs[0]
        lower, upper = float(inputs[1]), float(inputs[2])
        return _op.tensor.clip(tensor, lower, upper)
    return _impl

def _convolution():
    """Build the converter for a (possibly transposed) convolution.

    The returned ``_impl`` reads ``inputs`` positionally: data, weight, bias,
    strides, padding, dilation, the transposed flag (index 6) and the group
    count (index 8).  It emits ``conv2d``, ``conv2d_transpose`` or ``conv3d``
    depending on the flag and the kernel rank, followed by ``bias_add`` when
    the bias is a Relay expression.

    Raises
    ------
    TypeError
        If the weight is not a Relay expression.
    AssertionError
        If a transposed convolution has a kernel that is not 2D.
    """
    def _impl(inputs, input_types):
        # Use transpose or normal
        use_transpose = bool(inputs[6] == 1)

        data = inputs[0]
        weight = inputs[1]
        bias = inputs[2]
        strides = inputs[3]
        padding = inputs[4]
        dilation = inputs[5]

        if isinstance(weight, _expr.Expr):
            weight_shape = list(_infer_shape(weight))
        else:
            # Raise explicitly: asserting on a non-empty string can never
            # fail, and `weight_shape` would be left unbound below.
            raise TypeError(
                "data type {} could not be parsed in conv op".format(type(weight)))

        # Transposed convolutions have IOHW layout.
        if use_transpose:
            weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0]

        channels = weight_shape[0]
        groups = int(inputs[8])

        # Check if this is depth wise convolution
        # We need to reshape weight so that Relay could recognize this is depth wise
        # weight_shape[1] is always in_channels // groups
        # For depthwise, in_channels == groups, so weight_shape[1] == 1
        # If groups > 1 but weight_shape[1] != 1, this is group convolution
        if groups > 1 and weight_shape[1] == 1:
            channel_multiplier = channels // groups
            new_weight_shape = (groups, channel_multiplier, weight_shape[2], weight_shape[3])
            weight = _op.transform.reshape(weight, new_weight_shape)

        kernel_size = weight_shape[2:]
        use_bias = isinstance(bias, _expr.Expr)

        # Attributes that arrive as Relay expressions are resolved to shapes.
        if isinstance(strides, _expr.Expr):
            strides = _infer_shape(strides)

        if isinstance(padding, _expr.Expr):
            padding = _infer_shape(padding)

        if isinstance(dilation, _expr.Expr):
            dilation = _infer_shape(dilation)

        data_layout = "NCHW"
        kernel_layout = "OIHW"
        conv_op = _op.nn.conv2d

        if use_transpose:
            assert len(kernel_size) == 2, "ConvTranspose 3D not supported"
            conv_op = _op.nn.conv2d_transpose
        if len(kernel_size) == 3:
            conv_op = _op.nn.conv3d
            data_layout = "NCDHW"
            kernel_layout = "OIDHW"

        conv_out = conv_op(data,
                           weight,
                           strides=strides,
                           padding=padding,
                           dilation=dilation,
                           groups=groups,
                           channels=channels,
                           kernel_size=kernel_size,
                           data_layout=data_layout,
                           kernel_layout=kernel_layout,
                           out_layout="",
                           out_dtype="")
        if use_bias:
            return _op.nn.bias_add(conv_out, bias)
        return conv_out
    return _impl

def _softmax():
    """Build the converter that lowers a softmax to ``_op.nn.softmax``.

    An axis supplied as a string is converted with ``int``; any other value is
    passed through unchanged.
    """
    def _impl(inputs, input_types):
        tensor, raw_axis = inputs[0], inputs[1]
        axis = int(raw_axis) if isinstance(raw_axis, str) else raw_axis
        return _op.nn.softmax(tensor, axis=axis)
    return _impl

def _threshold():
    """Build the converter that lowers a threshold op to ``_op.nn.relu``.

    Only ``inputs[0]`` is read; any further inputs are ignored.
    """
    def _impl(inputs, input_types):
        return _op.nn.relu(inputs[0])
    return _impl

def _contiguous():
    """Build the converter that lowers a contiguous call to ``_op.tensor.copy``."""
    def _impl(inputs, input_types):
        return _op.tensor.copy(inputs[0])
    return _impl

def _batch_norm():
    """Build the batch-normalisation converter.

    When ``inputs[1]`` and ``inputs[2]`` are both Relay expressions they are
    used as gamma and beta and ``scale``/``center`` are enabled.  Otherwise
    constant ones/zeros sized by dimension 1 of the data stand in and both
    flags are disabled.  The first output of ``_op.nn.batch_norm`` is returned.
    """
    def _impl(inputs, input_types):
        data = inputs[0]
        data_type = input_types[0]
        data_shape = _infer_shape(data)

        affine = (isinstance(inputs[1], _expr.Expr)
                  and isinstance(inputs[2], _expr.Expr))
        if affine:
            gamma, beta = inputs[1], inputs[2]
        else:
            gamma = _create_typed_const(np.ones([int(data_shape[1])]), data_type)
            beta = _create_typed_const(np.zeros([int(data_shape[1])]), data_type)

        moving_mean, moving_var = inputs[3], inputs[4]
        epsilon = float(inputs[7])

        outputs = _op.nn.batch_norm(data,
                                    gamma,
                                    beta,
                                    moving_mean,
                                    moving_var,
                                    axis=1,
                                    epsilon=epsilon,
                                    center=affine,
                                    scale=affine)
        return outputs[0]
    return _impl

476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505
def _instance_norm():
    """Build the instance-normalisation converter.

    When ``inputs[1]`` and ``inputs[2]`` are both Relay expressions they are
    used as gamma and beta and ``scale``/``center`` are enabled.  Otherwise
    constant ones/zeros sized by dimension 1 of the data stand in and both
    flags are disabled.
    """
    def _impl(inputs, input_types):
        data = inputs[0]
        data_type = input_types[0]
        data_shape = _infer_shape(data)

        affine = (isinstance(inputs[1], _expr.Expr)
                  and isinstance(inputs[2], _expr.Expr))
        if affine:
            gamma, beta = inputs[1], inputs[2]
        else:
            gamma = _create_typed_const(np.ones([int(data_shape[1])]), data_type)
            beta = _create_typed_const(np.zeros([int(data_shape[1])]), data_type)

        epsilon = float(inputs[7])
        norm_kwargs = dict(axis=1, epsilon=epsilon, center=affine, scale=affine)
        return _op.nn.instance_norm(data, gamma, beta, **norm_kwargs)
    return _impl

506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582
def _transpose():
    """Build the converter that lowers transpose-like calls to ``transpose``.

    The number of inputs selects the behaviour:

    * one input: swap the last two axes (non-Relay data is first wrapped in a
      Relay constant);
    * three inputs: swap the two axes given by ``inputs[1]`` and ``inputs[2]``
      (negative values count from the end);
    * otherwise: use ``inputs[1]`` as the full axis permutation.

    Raises
    ------
    TypeError
        If the rank of ``inputs[0]`` cannot be determined from its type.
    """
    def _impl(inputs, input_types):
        data = inputs[0]

        import torch
        # `ndims` must be an integer rank because it is fed to range() below.
        if isinstance(data, _expr.Expr):
            ndims = len(_infer_shape(data))
        elif isinstance(data, (torch.Tensor, np.ndarray, tvm.runtime.NDArray)):
            ndims = len(data.shape)
        elif isinstance(data, list):
            # NOTE(review): treats a list as array-like data and uses its
            # nesting depth as the rank -- confirm this matches the callers.
            ndims = np.ndim(data)
        else:
            raise TypeError(
                "data type {} could not be parsed in transpose op".format(type(data)))

        axes = list(range(ndims))
        num_inputs = len(inputs)

        if num_inputs == 1:
            if ndims >= 2:
                axes[-1] = ndims - 2
                axes[-2] = ndims - 1
            if not isinstance(data, _expr.Expr):
                data = _expr.const(data)

        elif num_inputs == 3:
            def _normalize(axis):
                # Map a negative axis onto its non-negative equivalent.
                return axis + ndims if axis < 0 else axis

            src = _normalize(int(inputs[1]))
            dst = _normalize(int(inputs[2]))
            axes[src] = dst
            axes[dst] = src
        else:
            axes = inputs[1]
        return _op.transform.transpose(data, axes)
    return _impl

def _flatten():
    """Build the converter that lowers a flatten to ``_op.nn.batch_flatten``."""
    def _impl(inputs, input_types):
        return _op.nn.batch_flatten(inputs[0])
    return _impl

def _dense():
    """Build the converter that lowers a biased matrix product to ``dense``.

    The returned ``_impl`` reads ``inputs`` as ``(bias, data, weight, beta,
    alpha)``.  Non-Relay ``alpha``/``beta`` values are turned into typed
    constants and multiplied into ``data`` and ``weight`` respectively; the
    bias is added only when it is a Relay expression.
    """
    def _impl(inputs, input_types):
        bias = inputs[0]
        use_bias = isinstance(bias, _expr.Expr)

        data, weight = inputs[1], inputs[2]
        data_type = input_types[1]
        beta, alpha = inputs[3], inputs[4]

        if not isinstance(alpha, _expr.Expr):
            alpha = _create_typed_const(alpha, data_type)
            data *= alpha

        # NOTE(review): beta scales the weight here, not the bias -- confirm
        # this is the intended semantics for the op being converted.
        if not isinstance(beta, _expr.Expr):
            beta = _create_typed_const(beta, data_type)
            weight *= beta

        weight_t = _op.transform.transpose(weight, axes=[1, 0])
        units = _infer_shape(weight_t)[0]
        dense_out = _op.nn.dense(data, weight_t, units=units)

        if not use_bias:
            return dense_out
        return _op.nn.bias_add(dense_out, bias)
    return _impl

def _size():
    """Build the converter that returns an inferred shape.

    With more than one input, only the extent at ``int(inputs[1])`` is
    returned; otherwise the whole inferred shape.
    """
    def _impl(inputs, input_types):
        full_shape = _infer_shape(inputs[0])
        if len(inputs) <= 1:
            return full_shape
        return full_shape[int(inputs[1])]
    return _impl

def _numtotensor():
    """Build the converter that wraps a number in a 0-d numpy array.

    A ``tvm.tir.IntImm`` is unwrapped to a Python ``int`` first; for any other
    value the array dtype is the value's own Python type.
    """
    def _impl(inputs, input_types):
        number = inputs[0]

        if isinstance(number, tvm.tir.IntImm):
            number, scalar_type = int(number), int
        else:
            scalar_type = type(number)

        return number * np.ones([]).astype(scalar_type)
    return _impl

def _view():
    """Build the converter that lowers a view to ``_op.transform.reshape``.

    With three inputs the new shape is ``[inputs[1], first extent of
    inputs[2]]``; otherwise it is ``inputs[1]`` itself when that is a list, or
    its inferred shape when it is not.
    """
    def _impl(inputs, input_types):
        data = inputs[0]

        if len(inputs) == 3:
            target_shape = [inputs[1], _infer_shape(inputs[2])[0]]
        elif isinstance(inputs[1], list):
            target_shape = inputs[1]
        else:
            target_shape = _infer_shape(inputs[1])

        return _op.transform.reshape(data, target_shape)
    return _impl

def _clone():
    """Build the converter that lowers a clone to ``_op.tensor.copy``."""
    def _impl(inputs, input_types):
        return _op.tensor.copy(inputs[0])
    return _impl

def _log_softmax():
    """Build the converter that lowers a log-softmax to ``_op.nn.log_softmax``."""
    def _impl(inputs, input_types):
        tensor = inputs[0]
        return _op.nn.log_softmax(tensor, int(inputs[1]))
    return _impl

def _sigmoid():
    """Build the converter that lowers a sigmoid to ``_op.tensor.sigmoid``."""
    def _impl(inputs, input_types):
        return _op.tensor.sigmoid(inputs[0])
    return _impl

def _avg_pool2d():
    """Build the converter for 2D average pooling.

    The returned ``_impl`` reads ``inputs`` as ``(data, pool_size, strides,
    padding, ceil_mode, count_include_pad)``.  A falsy ``strides`` input falls
    back to the pool size.  Inputs typed ``"quint8"`` are pooled through
    ``qnn_torch.apply_with_upcast``.
    """
    def _impl(inputs, input_types):
        data = inputs[0]

        pool_size = _infer_shape(inputs[1])
        strides = _infer_shape(inputs[2]) if inputs[2] else pool_size
        padding = _infer_shape(inputs[3])

        ceil_mode = int(inputs[4])
        count_include_pad = int(inputs[5])

        def _pool(tensor):
            return _op.nn.avg_pool2d(tensor,
                                     pool_size=pool_size,
                                     strides=strides,
                                     padding=padding,
                                     ceil_mode=ceil_mode,
                                     count_include_pad=count_include_pad)

        if input_types[0] != "quint8":
            return _pool(data)
        return qnn_torch.apply_with_upcast(data, _pool)

    return _impl

665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685
def _avg_pool3d():
    """Build the converter for 3D average pooling.

    The returned ``_impl`` reads ``inputs`` as ``(data, pool_size, strides,
    padding, ceil_mode, count_include_pad)``.  A falsy ``strides`` input falls
    back to the pool size.
    """
    def _impl(inputs, input_types):
        data = inputs[0]

        pool_size = _infer_shape(inputs[1])
        strides = _infer_shape(inputs[2]) if inputs[2] else pool_size
        padding = _infer_shape(inputs[3])

        pool_kwargs = dict(pool_size=pool_size,
                           strides=strides,
                           padding=padding,
                           ceil_mode=int(inputs[4]),
                           count_include_pad=int(inputs[5]))
        return _op.nn.avg_pool3d(data, **pool_kwargs)
    return _impl
686

687 688 689 690 691 692 693 694 695
def _dropout():
    """Build the converter that lowers a dropout to ``_op.nn.dropout``."""
    def _impl(inputs, input_types):
        tensor = inputs[0]
        drop_rate = float(inputs[1])
        return _op.nn.dropout(tensor, drop_rate)
    return _impl

def _reduce(name):
    """Build a converter that applies the Relay op called `name` to ``inputs[0]``."""
    def _impl(inputs, input_types):
        tensor = inputs[0]
        relay_op = get_relay_op(name)
        return relay_op(tensor)
    return _impl

def _mean():
    """Build the converter that lowers a mean reduction to ``_op.mean``.

    ``inputs[1]`` supplies the axes (``None`` when falsy); ``inputs[2]`` and
    ``inputs[3]``, when present and truthy, supply ``keepdims`` and
    ``exclude``.  Inputs typed ``"quint8"`` go through
    ``qnn_torch.quantized_mean`` with scale and zero point read from
    ``inputs[4]`` and ``inputs[5]``.
    """
    def _impl(inputs, input_types):
        data = inputs[0]

        axis = _infer_shape(inputs[1]) if inputs[1] else None
        keepdims = int(inputs[2]) if len(inputs) > 2 and inputs[2] else False
        exclude = int(inputs[3]) if len(inputs) > 3 and inputs[3] else False

        def _reduce_mean(tensor):
            return _op.mean(tensor, axis, keepdims, exclude)

        if input_types[0] != "quint8":
            return _reduce_mean(data)

        assert len(inputs) == 6, "Input quant param not found in op inputs"
        input_scale = _expr.const(inputs[4])
        input_zero_point = _expr.const(inputs[5])
        return qnn_torch.quantized_mean(data, input_scale,
                                        input_zero_point, _reduce_mean)

    return _impl

def _chunk():
    """Build the converter that splits a tensor into chunks along one axis.

    The returned ``_impl`` reads ``inputs`` as ``(data, num_chunks, axis)`` and
    returns a list of ``strided_slice`` expressions.  Each chunk spans
    ``ceil(extent / num_chunks)`` elements along ``axis``, except that the last
    one may be shorter; fewer than ``num_chunks`` chunks are produced when the
    extent is exhausted early.

    Raises
    ------
    TypeError
        If ``data`` is not a Relay expression (its shape cannot be inferred).
    """
    def _impl(inputs, input_types):
        data = inputs[0]

        num_chunks = int(inputs[1])
        axis = int(inputs[2])

        if not isinstance(data, _expr.Expr):
            raise TypeError(
                "data type {} could not be parsed in chunk op".format(type(data)))

        shape = list(_infer_shape(data))
        rank = len(shape)
        dim = int(shape[axis])

        # Ceiling division without going through floating point.
        chunk_size = -(-dim // num_chunks)

        chunks = []
        for start in range(0, dim, chunk_size):
            begin = [0] * rank
            end = shape[:]
            begin[axis] = start
            # Clamp so the final chunk never reaches past the extent.
            end[axis] = min(start + chunk_size, dim)
            stride = [1] * rank

            chunks.append(_op.transform.strided_slice(data, begin, end, stride))

        return chunks
    return _impl

def _matmul():
    """Build the converter that lowers a matrix product to ``_op.nn.dense``.

    The second operand is transposed with axes ``(1, 0)`` before being handed
    to ``dense``.
    """
    def _impl(inputs, input_types):
        lhs, rhs = inputs[0], inputs[1]
        rhs_transposed = _op.transpose(rhs, axes=(1, 0))
        return _op.nn.dense(lhs, rhs_transposed)
    return _impl

def _expand():
    """Build the converter that expands a tensor by repeating it along axes.

    The target sizes come from ``_infer_shape(inputs[1])``.  An axis whose
    target is ``-1`` or already equals the current extent is left alone; every
    other axis is built by concatenating ``target`` copies of the running
    result.  The input itself is returned when no axis needs expanding.

    Raises
    ------
    TypeError
        If ``inputs[0]`` is not a Relay expression (its shape cannot be
        inferred).
    """
    def _impl(inputs, input_types):
        data_in = inputs[0]
        if not isinstance(data_in, _expr.Expr):
            raise TypeError(
                "data type {} could not be parsed in expand op".format(type(data_in)))

        shape = _infer_shape(data_in)
        sizes = _infer_shape(inputs[1])

        # NOTE(review): only the leading len(shape) target sizes are read, so
        # targets with more dimensions than the input are not handled here.
        out = data_in
        for axis, extent in enumerate(shape):
            target = sizes[axis]
            if target in {-1, extent}:
                continue
            # Accumulate into `out` so expansions on several axes compose.
            # Repeating yields extent * target elements, which equals `target`
            # only when the source extent is 1.
            out = _op.tensor.concatenate([out] * target, axis)

        return out
    return _impl

def _int():
    """Build the converter that turns ``inputs[0]`` into a Python ``int``.

    Relay expressions are passed through unchanged.
    """
    def _impl(inputs, input_types):
        value = inputs[0]
        return value if isinstance(value, _expr.Expr) else int(value)
    return _impl

def _identity():
    def _impl(inputs, input_types):
        return inputs[0]
    return _impl

def _none():
    def _impl(inputs, input_types):
        return None
    return _impl

def _pad():
    """Build the converter that lowers a padding op to ``_op.nn.pad``.

    The returned ``_impl`` reads ``inputs`` as ``(data, padding, pad_value)``.
    """
    def _impl(inputs, input_types):
        tensor, padding = inputs[0], inputs[1]
        # Zipping the sequence with itself pairs every entry p as (p, p).
        # NOTE(review): confirm this pairing matches the padding layout the
        # caller supplies (e.g. separate before/after amounts).
        pad_width = list(zip(padding, padding))
        fill_value = inputs[2]
        return _op.nn.pad(tensor, pad_width, fill_value)
    return _impl

def _sqrt():
    """Build the converter that lowers a square root to ``_op.tensor.sqrt``."""
    def _impl(inputs, input_types):
        return _op.tensor.sqrt(inputs[0])
    return _impl

840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890
def _floor():
    """Build the converter that lowers a floor to ``_op.floor``."""
    def _impl(inputs, input_types):
        return _op.floor(inputs[0])
    return _impl

def _to():
    def _impl(inputs, input_types):
        data = inputs[0]
        if inputs[3] in ["cpu", "cuda"]:
            return data
        # special handling for aten::to(data, 6, _, _, _) case
        # 6 means dtype = float
        # this happens when converting upsampling with scale factor
        cast_func = {
            6: float,
            3: int,
        }
        cast_func_expr = {
            6: lambda x: _op.cast(x, "float32"),
            3: lambda x: _op.cast(x, "int32"),
        }
        if inputs[1] in cast_func and not isinstance(data, _expr.Expr):
            return cast_func[inputs[1]](data)
        elif inputs[1] in cast_func and isinstance(data, _expr.Expr):
            return cast_func_expr[inputs[1]](data)
        return data

    return _impl

def _upsample(method):
    """Build a converter that lowers an upsampling op to ``_op.image.resize``.

    Parameters
    ----------
    method : str
        Interpolation method forwarded to ``_op.image.resize``.

    The output size comes from ``inputs[1]``: the inferred shape of a Relay
    variable, or a list whose entries are evaluated with ``_infer_value``.
    ``inputs[2]``, when present, selects ``"align_corners"`` over
    ``"half_pixel"`` coordinate transformation.  Inputs typed ``"quint8"`` go
    through ``qnn_torch.quantized_upsample`` with scale and zero point taken
    from the last two inputs.

    Raises
    ------
    TypeError
        If ``inputs[1]`` is neither a Relay variable nor a list.
    """
    def _impl(inputs, input_types):
        if isinstance(inputs[1], _expr.Var):
            out_size = _infer_shape(inputs[1])
        elif isinstance(inputs[1], list):
            infer_res = [_infer_value(size, {}) for size in inputs[1]]
            # ndarray.item() replaces np.asscalar and the builtin int replaces
            # np.int; both old names have been removed from NumPy.
            out_size = [res.asnumpy().astype(int).item()
                        for res in infer_res]
        else:
            # Fail here with a clear message instead of hitting an unbound
            # `out_size` inside func() below.
            raise TypeError(
                "output size of type {} could not be parsed in upsample op"
                .format(type(inputs[1])))

        data = inputs[0]

        if len(inputs) > 2:
            align_corners = inputs[2]
        else:
            align_corners = False

        if align_corners:
            coord_trans = "align_corners"
        else:
            coord_trans = "half_pixel"

        def func(x):
            return _op.image.resize(x, out_size, "NCHW", method, coord_trans)

        if input_types[0] == "quint8":
            import torch
            from packaging import version

            # Torch version > 1.4 changed upsampling API
            if version.parse(torch.__version__) > version.parse("1.4.0"):
                num_inputs = 7
            else:
                num_inputs = 5

            assert len(inputs) == num_inputs, "Input quant param not found in op inputs"

            input_scale = _expr.const(inputs[-2])
            input_zero_point = _expr.const(inputs[-1])
            return qnn_torch.quantized_upsample(data, input_scale,
                                                input_zero_point, func)
        return func(data)

    return _impl

914 915 916 917 918 919 920 921 922
def _expand_as():
    def _impl(inputs, input_types):
        # TODO: maybe fix this
        # This assumes expand_as can be removed because TVM has broadcast op
        msg = "aten::expand_as(...) found, assume it is part of broadcast op"
        logging.warning(msg)
        return inputs[0]
    return _impl

923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945
def _neg():
    """Build the converter that lowers a negation to ``_op.tensor.negative``."""
    def _impl(inputs, input_types):
        return _op.tensor.negative(inputs[0])
    return _impl

def _tanh():
    """Build the converter that lowers a tanh to ``_op.tensor.tanh``."""
    def _impl(inputs, input_types):
        return _op.tensor.tanh(inputs[0])
    return _impl

def _Bool():
    def _impl(inputs, input_types):
        assert len(inputs) == 1
        return inputs[0]
    return _impl

def _Float():
    """Build the converter that casts its single input to ``"float32"``.

    The returned ``_impl`` asserts that exactly one input is supplied.
    """
    def _impl(inputs, input_types):
        assert len(inputs) == 1
        value = inputs[0]
        return _op.cast(value, "float32")
    return _impl
946

947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997
# Helper functions for operator implementation

def _convert_data_type(input_type):
    if input_type in ["double", "torch.float64"]:
        return "float64"
    elif input_type in ["float", "torch.float32"]:
        return "float32"
    elif input_type in ["half", "torch.float16"]:
        return "float16"
    elif input_type in ["long", "torch.int64"]:
        return "int64"
    elif input_type in ["int", "torch.int32"]:
        return "int32"
    elif input_type in ["short", "torch.int16"]:
        return "int16"
    elif input_type in ["char", "torch.int8"]:
        return "int8"
    elif input_type in ["byte", "torch.uint8"]:
        return "uint8"
    else:
        raise NotImplementedError("input_type {} is not handled yet" % (input_type))
    return "float32"

def _create_typed_const(data, data_type):
    """Create a Relay constant from data with the dtype named by data_type.

    Parameters
    ----------
    data : scalar
        Value to wrap; it is first converted to the NumPy scalar type that
        corresponds to the resolved dtype.

    data_type : str
        Torch type name, resolved through _convert_data_type.

    Returns
    -------
    typed_data : relay constant
        Result of _expr.const for the converted value.

    Raises
    ------
    NotImplementedError
        If the resolved dtype is not one of the supported dtypes.
    """
    dtype = _convert_data_type(data_type)

    supported = ("float64", "float32", "float16",
                 "int64", "int32", "int16", "int8", "uint8")
    if dtype not in supported:
        # "{}" placeholders need str.format; the "%" operator would raise
        # TypeError here instead of producing the message.
        raise NotImplementedError(
            "input_type {} is not handled yet".format(data_type))

    # Every supported dtype string is also the name of the NumPy scalar type
    # (np.float64, np.int8, ...), so the constructor can be looked up by name.
    np_scalar_type = getattr(np, dtype)
    return _expr.const(np_scalar_type(data), dtype=dtype)

def _convert_elemwise_input(data, input_type):
    """Make sure an element-wise operand is a Relay expression.

    Parameters
    ----------
    data : torch.Tensor, relay expression or Python scalar
        The operand. A torch.Tensor is reduced to a Python scalar with
        .item(); a Relay expression is returned untouched; anything else is
        wrapped as-is.

    input_type : str
        Torch type name used to pick the dtype of a newly created constant.

    Returns
    -------
    expr : relay expression
        Either data itself or a constant built from it.
    """
    # torch is imported lazily so the module can be loaded without it
    import torch

    if isinstance(data, torch.Tensor):
        data = data.item()
    elif isinstance(data, _expr.Expr):
        return data

    return _expr.const(data, dtype=_convert_data_type(input_type))

1002 1003 1004 1005 1006
def _wrap_const(c):
    """Wrap c in a Relay constant unless it is already an expression or a list."""
    if isinstance(c, (_expr.Expr, list)):
        return c
    return _expr.const(c)

1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023
# Operator mappings

# Maps a Torch IR operator name to the converter that produces its Relay
# equivalent. Each converter is called as converter(inputs, input_types).
_convert_map = {
    "aten::device"                          : _none(),
    "aten::add"                             : _elemwise("add"),
    "aten::add_"                            : _elemwise("add"),
    "aten::sub"                             : _elemwise("subtract"),
    "aten::sub_"                            : _elemwise("subtract"),
    "aten::max"                             : _elemwise("maximum"),
    "aten::min"                             : _elemwise("minimum"),
    "aten::mul"                             : _elemwise("multiply"),
    "aten::mul_"                            : _elemwise("multiply"),
    "aten::pow"                             : _elemwise("power"),
    "aten::div"                             : _elemwise("divide"),
    "aten::div_"                            : _elemwise("divide"),
    "aten::ones"                            : _ones(),
    "aten::zeros"                           : _zeros(),
    "aten::to"                              : _to(),
    "aten::squeeze"                         : _squeeze(),
    "aten::unsqueeze"                       : _unsqueeze(),
    "aten::cat"                             : _concatenate(),
    "aten::slice"                           : _slice(),
    "aten::split"                           : _split(),
    "aten::split_with_sizes"                : _split_with_sizes(),
    "aten::select"                          : _select(),
    "aten::relu"                            : _relu(),
    "aten::relu_"                           : _relu(),
    "aten::prelu"                           : _prelu(),
    "aten::leaky_relu"                      : _leaky_relu(),
    "aten::elu"                             : _elu(),
    "aten::log_sigmoid"                     : _log_sigmoid(),
    "aten::adaptive_avg_pool2d"             : _adaptive_avg_pool_2d(),
    "aten::adaptive_max_pool2d"             : _adaptive_max_pool_2d(),
    "aten::max_pool2d"                      : _maxpool_2d(),
    "aten::max_pool2d_with_indices"         : _maxpool_2d(),
    "aten::max_pool1d"                      : _maxpool_1d(),
    "aten::max_pool3d"                      : _maxpool_3d(),
    "aten::hardtanh"                        : _hardtanh(),
    "aten::hardtanh_"                       : _hardtanh(),
    "aten::_convolution"                    : _convolution(),
    "aten::softmax"                         : _softmax(),
    "aten::threshold"                       : _threshold(),
    "aten::threshold_"                      : _threshold(),
    "aten::contiguous"                      : _contiguous(),
    "aten::batch_norm"                      : _batch_norm(),
    "aten::instance_norm"                   : _instance_norm(),
    "aten::transpose"                       : _transpose(),
    "aten::transpose_"                      : _transpose(),
    "aten::t"                               : _transpose(),
    "aten::flatten"                         : _flatten(),
    "aten::addmm"                           : _dense(),
    "aten::size"                            : _size(),
    "aten::view"                            : _view(),
    "aten::clone"                           : _clone(),
    "aten::log_softmax"                     : _log_softmax(),
    "aten::sigmoid"                         : _sigmoid(),
    "aten::avg_pool2d"                      : _avg_pool2d(),
    "aten::avg_pool3d"                      : _avg_pool3d(),
    "aten::dropout"                         : _dropout(),
    "aten::dropout_"                        : _dropout(),
    "aten::feature_dropout"                 : _dropout(),
    "aten::alpha_dropout"                   : _dropout(),
    "aten::mean"                            : _mean(),
    "aten::chunk"                           : _chunk(),
    "aten::matmul"                          : _matmul(),
    "aten::expand"                          : _expand(),
    "aten::Int"                             : _int(),
    "prim::NumToTensor"                     : _numtotensor(),
    "prim::ListUnpack"                      : _identity(),
    "aten::constant_pad_nd"                 : _pad(),
    "aten::permute"                         : _transpose(),
    "aten::sum"                             : _reduce("sum"),
    "aten::prod"                            : _reduce("prod"),
    "aten::sqrt"                            : _sqrt(),
    "aten::floor"                           : _floor(),
    "aten::detach"                          : _identity(),
    "aten::upsample_bilinear2d"             : _upsample("bilinear"),
    "aten::upsample_nearest2d"              : _upsample("nearest_neighbor"),
    "aten::expand_as"                       : _expand_as(),
    "aten::lt"                              : _elemwise("less"),
    "aten::gt"                              : _elemwise("greater"),
    "aten::le"                              : _elemwise("less_equal"),
    "aten::ge"                              : _elemwise("greater_equal"),
    "aten::ne"                              : _elemwise("not_equal"),
    "aten::Bool"                            : _Bool(),
    "aten::Float"                           : _Float(),
    "aten::neg"                             : _neg(),
    "aten::tanh"                            : _tanh(),
    "aten::adaptive_avg_pool3d"             : _adaptive_avg_pool_3d(),
    "aten::adaptive_max_pool3d"             : _adaptive_max_pool_3d(),
}


1100 1101 1102
def _run_jit_passes(graph):
    """ Run the Torch JIT inline pass on graph, in place.

    The inline pass is necessary to unwrap prim::CallMethod
    """
    # torch is imported lazily so the module can be loaded without it
    import torch

    torch._C._jit_pass_inline(graph)
1104 1105


1106 1107
def _is_int_seq(seq):
    return len(seq) > 0 and all([isinstance(i, int) for i in seq])
1108

1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128

def _get_tensor_and_var(torch_tensor, name):
    """Convert a Torch tensor to a TVM NDArray and create a matching Relay var.

    The tensor is moved to the CPU and converted through NumPy. The returned
    var is called name and carries the shape of the converted array.
    """
    np_data = torch_tensor.cpu().numpy()
    tensor = tvm.nd.array(np_data)
    return tensor, _expr.var(name, shape=tensor.shape)


def _get_output_name(node):
    assert node.outputsSize() == 1
    return node.output().debugName()


def _get_output_names(node):
    return [output.debugName() for output in node.outputs()]


def _get_input_names(node_or_graph):
    return [inp.debugName() for inp in node_or_graph.inputs()]


1129 1130
def _get_op_inputs(op_node, outputs):
    """Look up the already-converted value of every input of op_node.

    outputs maps a Torch IR value name to its converted value; a missing
    name raises KeyError.
    """
    input_names = _get_input_names(op_node)
    return [outputs[input_name] for input_name in input_names]
1131 1132 1133 1134 1135 1136


def _report_missing_conversion(op_names):
    """ Check if all ops in an input graph are supported by TVM

    Parameters
    ----------
    op_names : iterable of str
        Operator names found in the graph.

    Raises
    ------
    NotImplementedError
        If any name is neither one of the prim ops handled directly during
        conversion nor a key of _convert_map or qnn_torch.convert_map.
    """
    # A set keeps each membership test O(1) regardless of how many
    # converters are registered.
    known_ops = {"prim::Constant", "prim::GetAttr",
                 "prim::ListConstruct", "prim::ListUnpack",
                 "prim::TupleConstruct", "prim::TupleUnpack",
                 "prim::If", "prim::Loop"}
    known_ops.update(_convert_map.keys())
    known_ops.update(qnn_torch.convert_map.keys())

    missing = [op_name for op_name in op_names
               if op_name not in known_ops]

    if missing:
        msg = "The following operators are not implemented: {}".format(missing)
        raise NotImplementedError(msg)

1149

1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174
def _check_inputs(graph, input_shapes):
    """
    Check the graph inputs match the expected number of inputs
    and are in the correct format

    Parameters
    ----------
    graph : Torch IR graph
        Graph whose input names are compared against input_shapes.

    input_shapes : list of (str, shape) tuples
        User supplied input names and shapes.

    Raises
    ------
    RuntimeError
        If input_shapes is not a list, has fewer entries than the graph has
        inputs, or contains an entry that is not a ('name', shape) tuple.
        Entries beyond the number of graph inputs only trigger a warning.
    """
    ir_inputs = _get_graph_input_names(graph)

    if not isinstance(input_shapes, list):
        raise RuntimeError("Graph inputs input_shapes should be list")

    missing_inputs = len(ir_inputs) - len(input_shapes)
    if missing_inputs > 0:
        msg = "Missing {} graph input(s) in input_shapes".format(missing_inputs)
        raise RuntimeError(msg)

    num_expected = len(ir_inputs)
    for num, inp in enumerate(input_shapes):
        if num >= num_expected:
            # Surplus entries are tolerated, but reported
            msg = "Unused graph input {} in input_shapes".format(inp)
            logging.warning(msg)
            continue

        if not isinstance(inp, tuple):
            msg = "Graph input {} is not a tuple".format(num)
            raise RuntimeError(msg)

        if len(inp) != 2 or not isinstance(inp[0], str):
            msg = "Graph input {} is not valid, expected ('name', shape)".format(inp)
            raise RuntimeError(msg)

1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196

def _getattr_attr_name(node):
    attribute_names = node.attributeNames()
    assert len(attribute_names) == 1
    attr_name = node.s(attribute_names[0])
    return attr_name


def _getattr_full_name(getattrs):
    """Join the attribute names of a chain of nodes with dots, e.g. 'a.b.c'."""
    attr_names = (_getattr_attr_name(node) for node in getattrs)
    return ".".join(attr_names)


def _get_input_types(op_node):
    """ Returns a torch type for each input nodes """
    input_list_types = []
    for input_node in op_node.inputs():
        in_ty = input_node.type()
        input_node_kind = in_ty.kind()
        if input_node_kind == 'TensorType':
            if in_ty.scalarType() is None:
1197 1198 1199 1200
                # Tensor's type can be unknown if we use torch.jit.script(...)
                # Defaults to float for now
                logging.warning("Untyped Tensor found, assume it is float")
                input_list_types.append("float")
1201
            else:
1202
                input_list_types.append(in_ty.scalarType().lower())
1203

1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251
        elif input_node_kind == 'ListType':
            input_list_types.append(str(in_ty.getElementType()).lower())
        elif input_node_kind in ['IntType', 'FloatType', 'BoolType',
                                 'StringType', 'OptionalType']:
            input_list_types.append(str(in_ty).lower())
        else:
            input_list_types.append('UnsupportedType')

    if op_node.kind() in ['aten::ones', 'aten::zeros']:
        node_type = op_node.output().type()
        scalar_type = node_type.scalarType()
        if scalar_type:
            input_list_types[0] = scalar_type.lower()

    return input_list_types


def _get_constant(node):
    """ Retrieve a constant associated with this prim::Constant node """
    attribute_names = node.attributeNames()
    num_attributes = len(attribute_names)

    if num_attributes == 1:
        attr_name = attribute_names[0]
        ty = node.output().type().kind()

        if ty in ["IntType", "BoolType"]:
            return node.i(attr_name)
        elif ty in ["FloatType", "LongType"]:
            return node.f(attr_name)
        elif ty in ["TensorType", "CompleteTensorType"]:
            tensor = node.t(attr_name)
            if len(tensor.shape) == 0:  # tensor(0.1)
                return float(tensor)
            return tensor
        elif ty == "DeviceObjType":
            return node.s(attr_name)
        elif ty == "FunctionType":
            return None
        else:
            raise NotImplementedError("Unsupported type: %s" % ty)
    else:
        assert num_attributes == 0
        return None


def _get_operator_nodes(nodes):
    """ Returns torch IR nodes that need conversion to Relay

    Parameters
    ----------
    nodes : iterable of Torch IR nodes

    Returns
    -------
    ops : list of (str, node) tuples
        Every node except prim::GetAttr nodes, paired with a name. The name
        is the output debug name for single-output nodes, and the output
        debug names joined with "_" for nodes with more than one output.
    """
    ops = []
    for node in nodes:
        if node.outputsSize() > 1:
            node_name = "_".join(_get_output_names(node))
        else:
            node_name = _get_output_name(node)

        # prim::GetAttr nodes are left out of the operator list
        if node.kind() != "prim::GetAttr":
            ops.append((node_name, node))

    return ops


1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278
def _get_relay_input_vars(graph, input_shapes):
    """
    Create one Relay var per user supplied (name, shape) pair and key it by
    the corresponding graph input name, so that references to graph inputs
    resolve to vars carrying the user chosen names.
    """
    ir_inputs = _get_graph_input_names(graph)
    # Pairing stops at the shorter of the two sequences
    return {ir_input: _expr.var(name, shape=shape)
            for ir_input, (name, shape) in zip(ir_inputs, input_shapes)}
1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320


def get_use_chains(root_node, terminate=lambda _: False):
    """
    Follow the users of root_node forward and return every chain of nodes
    reached that way, each chain starting with root_node.
    A chain ends at a node without users, or as soon as terminate(users)
    is true for the users of the current node.
    See get_attr_chains below for its usage
    """
    def users_of(node):
        return [use.user
                for value in node.outputs()
                for use in value.uses()]

    def walk(current, chain_so_far):
        users = users_of(current)

        if not users or terminate(users):
            return [chain_so_far]

        sub_chains = [walk(user, chain_so_far + [user]) for user in users]
        return itertools.chain.from_iterable(sub_chains)

    return walk(root_node, [root_node])


def get_attr_chains(root_getattr_node):
    """ Returns chains of attribute access starting from root_getattr_node

    A chain is extended for as long as at least one user of the current
    node is itself a prim::GetAttr node.

    For example, given attribute "block", as in "self.block" when "self"
    points to the top level torch.nn.Module, the result is a list of
    attribute "chains" such as
    ['block', '2'], ['block', '1'], ['block', '0', '_packed_params']

    Each chain forms a full attribute accessor: "self.block.1" and
    "self.block.2" refer to the second and third submodule, while
    "self.block.0._packed_params" refers to the parameters of the first
    submodule.
    """
    def terminate(users):
        has_getattr_user = any(user.kind() == "prim::GetAttr"
                               for user in users)
        return not has_getattr_user

    return get_use_chains(root_getattr_node, terminate)


1321
def convert_params(graph, state_dict):
    """
    Return Relay vars and TVM NDArrays for input parameters
    A chain of prim::GetAttr nodes is processed one at a time

    Parameters
    ----------
    graph : Torch IR graph
        Graph searched (recursively) for prim::GetAttr nodes.

    state_dict : dict of str to torch tensor
        Parameters keyed by their dotted attribute path.

    Returns
    -------
    params : dict of str to relay var
        Keyed by the output name of the last node in each attribute chain.

    param_tensors : dict of str to tvm NDArray
        Same keys as params, holding the converted tensors.

    packed_param_map : dict of str to str
        For attribute paths ending in "_packed_params", maps the output
        name of the last node in the chain to the full attribute path.
    """
    getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True)
    params = {}
    param_tensors = {}
    packed_param_map = {}
    seen = set()

    for node in getattr_nodes:
        # Nodes that were part of an earlier chain are already handled
        if _get_output_name(node) in seen:
            continue

        for getattrs in get_attr_chains(node):
            seen.update(map(_get_output_name, getattrs))

            full_attr = _getattr_full_name(getattrs)
            full_attr_node_name = _get_output_name(getattrs[-1])

            if full_attr.endswith("_packed_params"):  # for quantized models
                err_msg = "parameter %s not found in state dict" % full_attr
                assert full_attr in state_dict, err_msg
                packed_param_map[full_attr_node_name] = full_attr
            elif full_attr in state_dict:
                torch_tensor = state_dict[full_attr]
                tensor, var = _get_tensor_and_var(torch_tensor,
                                                  full_attr_node_name)
                param_tensors[full_attr_node_name] = tensor
                params[full_attr_node_name] = var

    return params, param_tensors, packed_param_map
1354 1355


1356
def convert_block(block, outputs):
    """ Translate Torch "Block", used for prim::If and prim::Loop

    The nodes of block are converted with convert_operators, using the
    inputs of the block's return node as the names of the values to return.
    outputs is the shared name-to-value map and is updated by the
    conversion.
    """
    ops = _get_operator_nodes(block.nodes())
    ret_names = _get_input_names(block.returnNode())
    return convert_operators(ops, outputs, ret_names)
1361 1362


1363
def convert_if(if_node, outputs):
    """ Translate Torch prim::If to Relay If

    The condition is the already-converted value of the node's first input.
    The first block of if_node becomes the true branch and the second block
    the false branch; each branch must yield exactly one value.
    """
    cond = outputs[if_node.inputsAt(0).debugName()]
    blocks = list(if_node.blocks())
    true_branch = convert_block(blocks[0], outputs)
    false_branch = convert_block(blocks[1], outputs)
    assert len(true_branch) == 1 and len(false_branch) == 1
    return _expr.If(cond, true_branch[0], false_branch[0])


1373
def convert_loop(loop_node, outputs):
    """ Translate Torch prim::Loop to Relay while_loop

    Parameters
    ----------
    loop_node : Torch IR node
        The prim::Loop node to translate.

    outputs : dict
        Shared map from Torch IR value name to converted value. It is
        updated with the loop body's input names while converting.

    Returns
    -------
    loop_outputs : list of relay expressions
        One TupleGetItem per loop variable. The leading loop counter (or
        boolean condition) of the Relay loop result is left out.
    """
    def get_input(index):
        ivalue = loop_node.inputsAt(index)
        inode = ivalue.node()
        if inode.kind() == "prim::Constant":
            return _expr.const(_get_constant(inode))
        var_name = ivalue.debugName()
        assert var_name in outputs
        return _wrap_const(outputs[var_name])

    # Refer to the spec for prim::Loop below
    # https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
    # The first input: %max_trip_count
    # The second input: %initial_condition
    # The rest of input: loop variables
    max_loop_count = get_input(0)
    init_cond = get_input(1)
    num_loop_var = len(list(loop_node.inputs())) - 2
    init_vals = [get_input(i + 2) for i in range(num_loop_var)]

    # while loop has always max_loop_count being int64 max
    # max_loop_count.data (tvm.runtime.NDArray) is -1, so _get_constant again
    is_while_loop = (isinstance(max_loop_count, _expr.Constant) and
                     _get_constant(loop_node.inputsAt(0).node()) == sys.maxsize)

    body_block = list(loop_node.blocks())[0]
    block_input_names = _get_input_names(body_block)

    def cond(*current_vals):
        i = current_vals[0]

        if is_while_loop:
            return _op.equal(i, _expr.const(True, 'bool'))

        return _op.less(i, max_loop_count)

    def body(*current_vals):
        # Update loop variables using the prev iteration outputs
        assert len(current_vals) == len(block_input_names)
        outputs.update(zip(block_input_names, current_vals))

        block_outputs = convert_block(body_block, outputs)

        if not is_while_loop:
            # iter var increment implicit in torch, so do it manually
            # for while loop, block_outputs[0] is already a boolean,
            # the result of termination check
            incr = _expr.const(1, dtype="int32")
            block_outputs[0] = current_vals[0] + incr

        return block_outputs

    def get_var(name, val):
        # Only constants expose a concrete shape and dtype to copy
        if isinstance(val, _expr.Constant):
            return _expr.var(name, shape=val.data.shape, dtype=val.data.dtype)
        return _expr.var(name)

    if is_while_loop:
        loop_iter_dtype = "bool"
        # while loop with non input dependent condition such as while i < 10:
        # init_cond is int, need to cast to bool to type check
        if isinstance(init_cond, _expr.Constant):
            init_cond = _op.cast(init_cond, "bool")
        init_loop_iter_val = init_cond
    else:
        loop_iter_dtype = "int32"
        # always count from 0
        init_loop_iter_val = _expr.const(0, dtype="int32")

    name_val_pairs = list(zip(block_input_names,
                              [init_loop_iter_val] + init_vals))
    outputs.update(name_val_pairs)

    loop_iter_var = _expr.var(block_input_names[0], shape=(),
                              dtype=loop_iter_dtype)
    loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]]
    loop = while_loop(cond, [loop_iter_var] + loop_vars, body)
    loop_val = loop(init_loop_iter_val, *init_vals)

    # The first element is a loop counter or boolean condition, ignore it
    return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)]


1458
def convert_operators(operators, outputs, ret_names):
    """ Convert each Torch IR operators to Relay equivalent

    Parameters
    ----------
    operators : iterable of (str, node) tuples
        Node name and Torch IR node pairs, processed in order.

    outputs : dict
        Map from Torch IR value name to converted value. Inputs of each
        operator are read from it and results are written back into it.

    ret_names : iterable of str
        Names of the values to return once every operator is converted.

    Returns
    -------
    ret : list
        The values stored under ret_names, each passed through _wrap_const.
    """
    for node_name, op_node in operators:
        operator = op_node.kind()
        inputs = _get_op_inputs(op_node, outputs)

        if operator == "prim::Constant":
            outputs[node_name] = _get_constant(op_node)
        elif operator == 'prim::ListConstruct' and _is_int_seq(inputs):
            # A non-empty list of ints is used as the shape of a new var
            outputs[node_name] = _expr.var(node_name, shape=inputs)
        elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']:
            outputs[node_name] = inputs
        elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']:
            assert len(inputs) == 1
            unpacked_names = _get_output_names(op_node)
            outputs.update(zip(unpacked_names, inputs[0]))
        elif operator == "prim::If":
            if_out = convert_if(op_node, outputs)
            outputs[node_name] = if_out
        elif operator == "prim::Loop":
            loop_out = convert_loop(op_node, outputs)
            unpacked_names = _get_output_names(op_node)
            assert len(loop_out) == len(unpacked_names)
            outputs.update(zip(unpacked_names, loop_out))
        else:
            relay_op = _convert_map[operator]
            relay_out = relay_op(inputs, _get_input_types(op_node))

            if isinstance(relay_out, tuple):
                # This is for torch operators that return multiple outputs
                # See _adaptive_max_2d above for example
                out_names = _get_output_names(op_node)
                outputs.update(zip(out_names, relay_out))
            else:
                outputs[node_name] = relay_out

    return [_wrap_const(outputs[ret_name])
            for ret_name in ret_names]
1496 1497 1498 1499


def get_all_op_names(graph):
    """ Return all operator names in the input graph

    Besides the top-level nodes of graph, the nodes inside the blocks of
    every prim::If and prim::Loop node (found recursively) are included.

    Returns
    -------
    op_names : set of str
        The distinct node kinds encountered.
    """
    nodes = list(graph.nodes())
    for prim in ("prim::If", "prim::Loop"):
        for prim_node in graph.findAllNodes(prim, recurse=True):
            for block in prim_node.blocks():
                nodes.extend(block.nodes())
    return {node.kind() for node in nodes}
1508 1509


1510 1511 1512 1513 1514
def _get_graph_input_names(graph):
    """ Get the graph input names (use after graph copy and run jit passes)

    The first graph input is dropped from the result because it is "self".
    """
    # Variable names could change the first time a copy is made and after
    # _run_jit_passes is called, expected that those functions already invoked
    ir_inputs = _get_input_names(graph)
    return ir_inputs[1:]  # remove self at the 0th arg

1517

1518
def from_pytorch(script_module, input_shapes, custom_convert_map=None):
    """ Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
    The companion parameters will be handled automatically.

    Parameters
    ----------
    script_module : TopLevelTracedModule object
        TorchScripted PyTorch graph
        Note: We currently only support traces (ie: torch.jit.trace(model, input))

    input_shapes : List of tuples of input name and input dimensions
        Graph level input shape list
        The same input names need to be used for deployment, so choose easy to
        remember names (such as: input0, input1)

    custom_convert_map: Dictionary of str to Relay op
        A custom op conversion map in the same format as _convert_map above.
        Its entries are merged into the module-level _convert_map, so they
        stay registered for later calls as well.

    Returns
    -------
    mod : tvm.relay.Module
        The module that optimizations will be performed on.

    params : dict of str to tvm.runtime.NDArray
        Dict of converted parameters stored in tvm.runtime.ndarray format
    """
    # Work on a copy so the passes do not modify the caller's graph
    graph = script_module.graph.copy()
    _run_jit_passes(graph)

    if custom_convert_map:
        _convert_map.update(custom_convert_map)

    op_names = get_all_op_names(graph)
    _report_missing_conversion(op_names)
    _check_inputs(graph, input_shapes)

    params = script_module.state_dict()
    outputs = _get_relay_input_vars(graph, input_shapes)
    param_vars, tensors, packed_param_map = convert_params(graph, params)
    tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}

    outputs.update(param_vars)
    ret_name = _get_input_names(graph.return_node())

    # For quantized models
    if "aten::quantize_per_tensor" in op_names:
        weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
        qnn_torch.add_input_quant_params_to_op_inputs(graph)
        qnn_torch.add_quant_params_to_outputs(outputs,
                                              packed_param_map,
                                              weight_quant_params)
        qnn_torch.add_quant_params(tvm_params, weight_quant_params)
        _convert_map.update(qnn_torch.convert_map)

    ret = convert_operators(_get_operator_nodes(graph.nodes()),
                            outputs, ret_name)

    # A list result is packed into a single Relay tuple
    if isinstance(ret[0], list):
        ret[0] = _expr.Tuple(ret[0])

    func = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])

    return _module.IRModule.from_expr(func), tvm_params