# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=invalid-name, too-many-lines
"""Neural network operations."""
from __future__ import absolute_import as _abs
from ...expr import TupleWrapper
from . import _make
from .util import get_pad_tuple2d, get_pad_tuple3d


def conv1d(data,
           weight,
           strides=1,
           padding=0,
           dilation=1,
           groups=1,
           channels=None,
           kernel_size=None,
           data_layout="NCW",
           kernel_layout="OIW",
           out_layout="",
           out_dtype=""):
    r"""1D convolution.

    This operator takes the weight as the convolution kernel
    and convolves it with data to produce an output.


    In the default case, where the data_layout is `NCW`
    and kernel_layout is `OIW`, conv1d takes in
    a data Tensor with shape `(batch_size, in_channels, width)`,
    and a weight Tensor with shape `(channels, in_channels, kernel_size)`
    to produce an output Tensor with the following rule:

    .. math::

        \mbox{out}[b, c, w] = \sum_{dw, k}
           \mbox{data}[b, k, \mbox{strides}[0] * w + dw] *
           \mbox{weight}[c, k, dw]

    Padding and dilation are applied to data and weight respectively before the computation.
    This operator accepts data layout specification.
    Semantically, the operator will convert the layout to the canonical layout
    (`NCW` for data and `OIW` for weight), perform the computation,
    then convert to the out_layout.


    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    weight : tvm.relay.Expr
        The weight expressions.

    strides : Optional[int, Tuple[int]]
        The strides of convolution.

    padding : Optional[int, Tuple[int]]
        The padding of convolution on both sides of the input before convolution.

    dilation : Optional[int, Tuple[int]]
        Specifies the dilation rate to be used for dilated convolution.

    groups : Optional[int]
        Currently unused for 1D convolution.

    channels : Optional[int]
        Number of output channels of this convolution.

    kernel_size : Optional[int, Tuple[int]]
        The spatial dimension of the convolution kernel.

    data_layout : Optional[str]
        Layout of the input.

    kernel_layout : Optional[str]
        Layout of the weight.

    out_layout : Optional[str]
        Layout of the output, by default, out_layout is the same as data_layout

    out_dtype : Optional[str]
        Specifies the output data type for mixed precision conv1d.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
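
    For illustration, a minimal usage sketch (shapes are arbitrary and it is
    assumed that ``relay`` refers to ``tvm.relay``):

    .. code-block:: python

        x = relay.var("x", shape=(1, 16, 100))   # NCW input
        w = relay.var("w", shape=(32, 16, 3))    # OIW kernel
        y = relay.nn.conv1d(x, w, channels=32, kernel_size=3, padding=1)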
    """
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, )
    if isinstance(strides, int):
        strides = (strides, )
    if isinstance(dilation, int):
        dilation = (dilation, )
    if isinstance(padding, int):
        padding = (padding, padding)
    return _make.conv1d(data, weight, strides, padding, dilation,
                        groups, channels, kernel_size, data_layout,
                        kernel_layout, out_layout, out_dtype)


def conv2d(data,
           weight,
           strides=(1, 1),
           padding=(0, 0),
           dilation=(1, 1),
           groups=1,
           channels=None,
           kernel_size=None,
           data_layout="NCHW",
           kernel_layout="OIHW",
           out_layout="",
           out_dtype=""):
    r"""2D convolution.

    This operator takes the weight as the convolution kernel
    and convolves it with data to produce an output.


    In the default case, where the data_layout is `NCHW`
    and kernel_layout is `OIHW`, conv2d takes in
    a data Tensor with shape `(batch_size, in_channels, height, width)`,
    and a weight Tensor with shape `(channels, in_channels, kernel_size[0], kernel_size[1])`
    to produce an output Tensor with the following rule:

    .. math::

        \mbox{out}[b, c, y, x] = \sum_{dy, dx, k}
           \mbox{data}[b, k, \mbox{strides}[0] * y  + dy, \mbox{strides}[1] * x + dx] *
           \mbox{weight}[c, k, dy, dx]

    Padding and dilation are applied to data and weight respectively before the computation.
    This operator accepts data layout specification.
    Semantically, the operator will convert the layout to the canonical layout
    (`NCHW` for data and `OIHW` for weight), perform the computation,
    then convert to the out_layout.


    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    weight : tvm.relay.Expr
        The weight expressions.

    strides : Optional[int, Tuple[int]]
        The strides of convolution.

    padding : Optional[int, Tuple[int]]
        The padding of convolution on both sides of inputs before convolution.

    dilation : Optional[int, Tuple[int]]
        Specifies the dilation rate to be used for dilated convolution.

    groups : Optional[int]
        Number of groups for grouped convolution.

    channels : Optional[int]
        Number of output channels of this convolution.

    kernel_size : Optional[int, Tuple[int]]
        The spatial dimensions of the convolution kernel.

    data_layout : Optional[str]
        Layout of the input.

    kernel_layout : Optional[str]
        Layout of the weight.

    out_layout : Optional[str]
        Layout of the output, by default, out_layout is the same as data_layout

    out_dtype : Optional[str]
        Specifies the output data type for mixed precision conv2d.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
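
    For illustration, a minimal usage sketch (shapes are arbitrary and it is
    assumed that ``relay`` refers to ``tvm.relay``):

    .. code-block:: python

        x = relay.var("x", shape=(1, 3, 224, 224))   # NCHW input
        w = relay.var("w", shape=(16, 3, 3, 3))      # OIHW kernel
        y = relay.nn.conv2d(x, w, channels=16, kernel_size=(3, 3),
                            strides=(1, 1), padding=(1, 1))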
    """
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(strides, int):
        strides = (strides, strides)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)
    # TODO enforce 4-way padding in topi/nn/conv2d after #4644 merged
    # convert 2-way padding to 4-way padding
    padding = get_pad_tuple2d(padding)
    return _make.conv2d(data, weight, strides, padding, dilation,
                        groups, channels, kernel_size, data_layout,
                        kernel_layout, out_layout, out_dtype)


def conv3d(data,
           weight,
           strides=(1, 1, 1),
           padding=(0, 0, 0),
           dilation=(1, 1, 1),
           groups=1,
           channels=None,
           kernel_size=None,
           data_layout="NCDHW",
           kernel_layout="OIDHW",
           out_layout="",
           out_dtype=""):
    r"""3D convolution.

    This operator takes the weight as the convolution kernel
    and convolves it with data to produce an output.


    In the default case, where the data_layout is `NCDHW`
    and kernel_layout is `OIDHW`, conv3d takes in
    a data Tensor with shape `(batch_size, in_channels, depth, height, width)`,
    and a weight Tensor with shape `(channels, in_channels, kernel_size[0], kernel_size[1],
    kernel_size[2])` to produce an output Tensor with the following rule:

    .. math::

        \mbox{out}[b, c, z, y, x] = \sum_{dz, dy, dx, k}
           \mbox{data}[b, k, \mbox{strides}[0] * z  + dz, \mbox{strides}[1] * y  + dy,
           \mbox{strides}[2] * x + dx] * \mbox{weight}[c, k, dz, dy, dx]

    Padding and dilation are applied to data and weight respectively before the computation.
    This operator accepts data layout specification.
    Semantically, the operator will convert the layout to the canonical layout
    (`NCDHW` for data and `OIDHW` for weight), perform the computation,
    then convert to the out_layout.


    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    weight : tvm.relay.Expr
        The weight expressions.

    strides : Optional[Tuple[int]]
        The strides of convolution.

    padding : Optional[int, Tuple[int]]
        The padding of convolution on both sides of inputs before convolution.

    dilation : Optional[int, Tuple[int]]
        Specifies the dilation rate to be used for dilated convolution.

    groups : Optional[int]
        Number of groups for grouped convolution.

    channels : Optional[int]
        Number of output channels of this convolution.

    kernel_size : Optional[int, Tuple[int]]
        The spatial dimensions of the convolution kernel.

    data_layout : Optional[str]
        Layout of the input.

    kernel_layout : Optional[str]
        Layout of the weight.

    out_layout : Optional[str]
        Layout of the output, by default, out_layout is the same as data_layout

    out_dtype : Optional[str]
        Specifies the output data type for mixed precision conv3d.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
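
    For illustration, a minimal usage sketch (shapes are arbitrary and it is
    assumed that ``relay`` refers to ``tvm.relay``):

    .. code-block:: python

        x = relay.var("x", shape=(1, 4, 16, 32, 32))   # NCDHW input
        w = relay.var("w", shape=(8, 4, 3, 3, 3))      # OIDHW kernel
        y = relay.nn.conv3d(x, w, channels=8, kernel_size=(3, 3, 3),
                            padding=(1, 1, 1))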
    """
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(strides, int):
        strides = (strides, strides, strides)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)
    padding = get_pad_tuple3d(padding)
    return _make.conv3d(data, weight, strides, padding, dilation,
                        groups, channels, kernel_size, data_layout,
                        kernel_layout, out_layout, out_dtype)


def contrib_conv3d_winograd_without_weight_transform(data,
                                                     weight,
                                                     tile_size,
                                                     strides=(1, 1, 1),
                                                     padding=(0, 0, 0),
                                                     dilation=(1, 1, 1),
                                                     groups=1,
                                                     channels=None,
                                                     kernel_size=None,
                                                     data_layout="NCDHW",
                                                     kernel_layout="OIDHW",
                                                     out_layout="",
                                                     out_dtype=""):
    r"""3D convolution with winograd algorithm.

    The basic parameters are the same as the ones in vanilla conv3d.
    It assumes the weight is pre-transformed by nn.contrib_conv3d_winograd_weight_transform

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    weight : tvm.relay.Expr
        The weight expressions.

    tile_size : int
        The Tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3)

    strides : tuple of int, optional
        The strides of convolution.

    padding : tuple of int, optional
        The padding of convolution on both sides of inputs before convolution.

    dilation : tuple of int, optional
        Specifies the dilation rate to be used for dilated convolution.

    groups : int, optional
        Number of groups for grouped convolution.

    channels : int, optional
        Number of output channels of this convolution.

    kernel_size : tuple of int, optional
        The spatial dimensions of the convolution kernel.

    data_layout : str, optional
        Layout of the input.

    kernel_layout : str, optional
        Layout of the weight.

    out_layout : str, optional
        Layout of the output, by default, out_layout is the same as data_layout

    out_dtype : str, optional
        Specifies the output data type for mixed precision conv3d.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    # convert 3-way padding to 6-way padding
    padding = get_pad_tuple3d(padding)
    return _make.contrib_conv3d_winograd_without_weight_transform(
        data, weight, tile_size, strides, padding, dilation,
        groups, channels, kernel_size, data_layout,
        kernel_layout, out_layout, out_dtype)


def conv2d_transpose(data,
                     weight,
                     strides=(1, 1),
                     padding=(0, 0),
                     dilation=(1, 1),
                     groups=1,
                     channels=None,
                     kernel_size=None,
                     data_layout="NCHW",
                     kernel_layout="OIHW",
                     out_layout="",
                     output_padding=(0, 0),
                     out_dtype=""):
    """Two dimensional transposed convolution operator.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    weight : tvm.relay.Expr
        The weight expressions.

    strides : Tuple[int], optional
        The strides of convolution.

    padding : Tuple[int], optional
        The padding of convolution on both sides of inputs.

    dilation : Tuple[int], optional
        Specifies the dilation rate to be used for dilated convolution.

    channels : int, optional
        Number of output channels of this convolution.

    kernel_size : tuple of int, optional
        The spatial dimensions of the convolution kernel.

    groups : int, optional
        Number of groups for grouped convolution.

    data_layout : str, optional
        Layout of the input.

    kernel_layout : str, optional
        Layout of the weight.

    out_layout : Optional[str]
        Layout of the output, by default, out_layout is the same as data_layout

    output_padding : Tuple[int], optional
        Additional zero-padding to be added to one side of the output.

    out_dtype : str, optional
        Specifies the output data type for mixed precision conv2d.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
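
    For illustration, a minimal usage sketch (shapes are arbitrary, the weight is
    laid out as `(in_channels, channels, kernel_size[0], kernel_size[1])`, and it
    is assumed that ``relay`` refers to ``tvm.relay``):

    .. code-block:: python

        x = relay.var("x", shape=(1, 16, 32, 32))   # NCHW input
        w = relay.var("w", shape=(16, 8, 4, 4))     # transposed-conv kernel
        y = relay.nn.conv2d_transpose(x, w, channels=8, kernel_size=(4, 4),
                                      strides=(2, 2), padding=(1, 1))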
    """
    # convert 2-way padding to 4-way padding
    padding = get_pad_tuple2d(padding)
    return _make.conv2d_transpose(data, weight, strides, padding, dilation,
                                  groups, channels, kernel_size, data_layout,
                                  kernel_layout, out_layout, output_padding, out_dtype)


def conv1d_transpose(data,
                     weight,
                     strides=(1,),
                     padding=(0,),
                     dilation=(1,),
                     groups=1,
                     channels=None,
                     kernel_size=None,
                     data_layout="NCW",
                     kernel_layout="OIW",
                     out_layout="",
                     output_padding=(0,),
                     out_dtype=""):
    """One dimensional transposed convolution operator.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    weight : tvm.relay.Expr
        The weight expressions.

    strides : Tuple[int], optional
        The strides of convolution.

    padding : Tuple[int], optional
        The padding of convolution on both sides of inputs.

    dilation : Tuple[int], optional
        Specifies the dilation rate to be used for dilated convolution.

    channels : int, optional
        Number of output channels of this convolution.

    kernel_size : tuple of int, optional
        The spatial dimensions of the convolution kernel.

    groups : int, optional
        Number of groups for grouped convolution.

    data_layout : str, optional
        Layout of the input.

    kernel_layout : str, optional
        Layout of the weight.

    out_layout : Optional[str]
        Layout of the output, by default, out_layout is the same as data_layout

    output_padding : Tuple[int], optional
        Additional zero-padding to be added to one side of the output.

    out_dtype : str, optional
        Specifies the output data type for mixed precision conv1d_transpose.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.conv1d_transpose(data, weight, strides, padding, dilation,
                                  groups, channels, kernel_size, data_layout,
                                  kernel_layout, out_layout, output_padding, out_dtype)


def softmax(data, axis=-1):
    r"""Computes softmax.

    .. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}

    .. note::
        This operator can be optimized away for inference.

    Parameters
    ----------
    data: tvm.relay.Expr
        The input data to the operator.

    axis: int, optional
        The axis to sum over when computing softmax

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
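
    For illustration, a minimal usage sketch (it is assumed that ``relay``
    refers to ``tvm.relay``):

    .. code-block:: python

        x = relay.var("x", shape=(4, 10))
        y = relay.nn.softmax(x, axis=-1)   # normalize over the last axis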
    """
    return _make.softmax(data, axis)


def log_softmax(data, axis=-1):
    r"""Computes log softmax.

    .. math::

        \text{log_softmax}(x)_i = \log \frac{exp(x_i)}{\sum_j exp(x_j)}

    .. note::
        This operator can be optimized away for inference.

    Parameters
    ----------
    data: tvm.relay.Expr
        The input data to the operator.

    axis: int, optional
        The axis to sum over when computing log softmax

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.log_softmax(data, axis)


def max_pool1d(data,
               pool_size=(1,),
               strides=(1,),
               padding=(0,),
               layout="NCW",
               ceil_mode=False):
    r"""1D maximum pooling operator.

    This operator takes data as input and does 1D max value calculation
    within a pool_size sized window, striding as defined by stride.

    In the default case, where the data_layout is `NCW`,
    it takes a data Tensor with shape `(batch_size, channels, width)`
    and produces an output Tensor.

    The ceil_mode is used to take ceil or floor while computing out shape.
    This operator accepts data layout specification.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    pool_size : int or tuple of int, optional
        The size of window for pooling.

    strides : int or tuple of int, optional
        The strides of pooling.

    padding : int or tuple of int, optional
        The padding for pooling.

    layout : str, optional
        Layout of the input.

    ceil_mode : bool, optional
        To enable or disable ceil while pooling.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    if isinstance(strides, int):
        strides = (strides,)
    if isinstance(padding, int):
        padding = (padding,)
    return _make.max_pool1d(data, pool_size, strides, padding,
                            layout, ceil_mode)


def max_pool2d(data,
               pool_size=(1, 1),
               strides=(1, 1),
               padding=(0, 0),
               layout="NCHW",
               ceil_mode=False):
    r"""2D maximum pooling operator.

    This operator takes data as input and does 2D max value calculation
    within a pool_size sized window, striding as defined by stride.


    In the default case, where the data_layout is `NCHW`,
    it takes a data Tensor with shape `(batch_size, in_channels, height, width)`
    and produces an output Tensor with the following rule:

    with data of shape (b, c, h, w) and pool_size (kh, kw)

    .. math::

        \mbox{out}(b, c, y, x)  = \max_{m=0, \ldots, kh-1} \max_{n=0, \ldots, kw-1}
             \mbox{data}(b, c, \mbox{stride}[0] * y + m, \mbox{stride}[1] * x + n)

    Padding is applied to data before the computation.
    ceil_mode is used to take ceil or floor while computing out shape.
    This operator accepts data layout specification.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    pool_size : int or tuple of int, optional
        The size of window for pooling.

    strides : tuple of int, optional
        The strides of pooling.

    padding : tuple of int, optional
        The padding for pooling.

    layout : str, optional
        Layout of the input.

    ceil_mode : bool, optional
        To enable or disable ceil while pooling.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
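
    For illustration, a minimal usage sketch (shapes are arbitrary and it is
    assumed that ``relay`` refers to ``tvm.relay``):

    .. code-block:: python

        x = relay.var("x", shape=(1, 16, 32, 32))   # NCHW input
        y = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2))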
    """
    return _make.max_pool2d(data, pool_size, strides, padding,
                            layout, ceil_mode)

def max_pool3d(data,
               pool_size=(1, 1, 1),
               strides=(1, 1, 1),
               padding=(0, 0, 0),
               layout="NCDHW",
               ceil_mode=False):
    r"""3D maximum pooling operator.

    This operator takes data as input and does 3D max value calculation
    within a pool_size sized window, striding as defined by stride.


    In the default case, where the data_layout is `NCDHW`,
    it takes a data Tensor with shape `(batch_size, channels, depth, height, width)`
    and produces an output Tensor.

    The ceil_mode is used to take ceil or floor while computing out shape.
    This operator accepts data layout specification.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    pool_size : int or tuple of int, optional
        The size of window for pooling.

    strides : tuple of int, optional
        The strides of pooling.

    padding : tuple of int, optional
        The padding for pooling.

    layout : str, optional
        Layout of the input.

    ceil_mode : bool, optional
        To enable or disable ceil while pooling.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.max_pool3d(data, pool_size, strides, padding,
                            layout, ceil_mode)


def avg_pool1d(data,
               pool_size=(1,),
               strides=(1,),
               padding=(0,),
               layout="NCW",
               ceil_mode=False,
               count_include_pad=False):
    r"""1D average pooling operator.

    This operator takes data as input and does 1D average value calculation
    within a pool_size sized window, striding as defined by stride.

    In the default case, where the data_layout is `NCW`,
    it takes a data Tensor with shape `(batch_size, channels, width)`
    and produces an output Tensor.

    The ceil_mode is used to take ceil or floor while computing out shape.
    count_include_pad indicates including or excluding padded input values in computation.
    This operator accepts data layout specification.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    pool_size : int or tuple of int, optional
        The size of window for pooling.

    strides : int or tuple of int, optional
        The strides of pooling.

    padding : int or tuple of int, optional
        The padding for pooling.

    layout : str, optional
        Layout of the input.

    ceil_mode : bool, optional
        To enable or disable ceil while pooling.

    count_include_pad : bool, optional
        To include padding to compute the average.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    if isinstance(strides, int):
        strides = (strides,)
    if isinstance(padding, int):
        padding = (padding,)
    return _make.avg_pool1d(data, pool_size, strides, padding,
                            layout, ceil_mode, count_include_pad)


def avg_pool2d(data,
               pool_size=(1, 1),
               strides=(1, 1),
               padding=(0, 0),
               layout="NCHW",
               ceil_mode=False,
               count_include_pad=False):
    r"""2D average pooling operator.

    This operator takes data as input and does 2D average value calculation
    within a pool_size sized window, striding as defined by stride.


    In the default case, where the data_layout is `NCHW`,
    it takes a data Tensor with shape `(batch_size, in_channels, height, width)`
    and produces an output Tensor with the following rule:

    with data of shape (b, c, h, w), pool_size (kh, kw)

    .. math::

        \mbox{out}(b, c, y, x)  = \frac{1}{kh * kw} \sum_{m=0}^{kh-1} \sum_{n=0}^{kw-1}
             \mbox{data}(b, c, \mbox{stride}[0] * y + m, \mbox{stride}[1] * x + n)

    Padding is applied to data before the computation.
    ceil_mode is used to take ceil or floor while computing out shape.
    count_include_pad indicates including or excluding padded input values in computation.
    This operator accepts data layout specification.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    pool_size : int or tuple of int, optional
        The size of window for pooling.

    strides : tuple of int, optional
        The strides of pooling.

    padding : tuple of int, optional
        The padding for pooling.

    layout : str, optional
        Layout of the input.

    ceil_mode : bool, optional
        To enable or disable ceil while pooling.

    count_include_pad : bool, optional
        To include padding to compute the average.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
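
    For illustration, a minimal usage sketch (shapes are arbitrary and it is
    assumed that ``relay`` refers to ``tvm.relay``):

    .. code-block:: python

        x = relay.var("x", shape=(1, 16, 32, 32))   # NCHW input
        y = relay.nn.avg_pool2d(x, pool_size=(2, 2), strides=(2, 2),
                                count_include_pad=False)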
    """
    return _make.avg_pool2d(data, pool_size, strides, padding,
                            layout, ceil_mode, count_include_pad)

def avg_pool3d(data,
               pool_size=(1, 1, 1),
               strides=(1, 1, 1),
               padding=(0, 0, 0),
               layout="NCDHW",
               ceil_mode=False,
               count_include_pad=False):
    r"""3D average pooling operator.

    This operator takes data as input and does 3D average value calculation
    within a pool_size sized window, striding as defined by stride.


    In the default case, where the data_layout is `NCDHW`,
    it takes a data Tensor with shape `(batch_size, channels, depth, height, width)`
    and produces an output Tensor.

    The ceil_mode is used to take ceil or floor while computing out shape.
    count_include_pad indicates including or excluding padded input values in computation.
    This operator accepts data layout specification.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    pool_size : int or tuple of int, optional
        The size of window for pooling.

    strides : tuple of int, optional
        The strides of pooling.

    padding : tuple of int, optional
        The padding for pooling.

    layout : str, optional
        Layout of the input.

    ceil_mode : bool, optional
        To enable or disable ceil while pooling.

    count_include_pad : bool, optional
        To include padding to compute the average.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.avg_pool3d(data, pool_size, strides, padding,
                            layout, ceil_mode, count_include_pad)

def max_pool2d_grad(out_grad,
                    data,
                    pool_size=(1, 1),
                    strides=(1, 1),
                    padding=(0, 0),
                    layout="NCHW",
                    ceil_mode=False):
    r"""Gradient of 2D maximum pooling operator.

    This operator takes out_grad and data as input and calculates gradient of max_pool2d.

    Parameters
    ----------
    out_grad : tvm.relay.Expr
        The output gradient

    data : tvm.relay.Expr
        The input data to the operator.

    pool_size : int or tuple of int, optional
        The size of window for pooling.

    strides : tuple of int, optional
        The strides of pooling.

    padding : tuple of int, optional
        The padding for pooling.

    layout : str, optional
        Layout of the input.

    ceil_mode : bool, optional
        To enable or disable ceil while pooling.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.max_pool2d_grad(out_grad, data, pool_size, strides, padding,
                                 layout, ceil_mode)

def avg_pool2d_grad(out_grad,
                    data,
                    pool_size=(1, 1),
                    strides=(1, 1),
                    padding=(0, 0),
                    layout="NCHW",
                    ceil_mode=False,
                    count_include_pad=False):
    r"""Gradient of 2D average pooling operator.

    This operator takes out_grad and data as input and calculates gradient of avg_pool2d.

    Parameters
    ----------
    out_grad : tvm.relay.Expr
        The output gradient

    data : tvm.relay.Expr
        The input data to the operator.

    pool_size : int or tuple of int, optional
        The size of window for pooling.

    strides : tuple of int, optional
        The strides of pooling.

    padding : tuple of int, optional
        The padding for pooling.

    layout : str, optional
        Layout of the input.

    ceil_mode : bool, optional
        To enable or disable ceil while pooling.

    count_include_pad : bool, optional
        To include padding to compute the average.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.avg_pool2d_grad(out_grad, data, pool_size, strides, padding,
                                 layout, ceil_mode, count_include_pad)

def global_max_pool2d(data,
                      layout="NCHW"):
    r"""2D global maximum pooling operator.

    This operator takes data as input and does 2D max value calculation
    across each window represented by WxH.


    In the default case, where the data_layout is `NCHW`,
    it takes a data Tensor with shape `(batch_size, in_channels, height, width)`
    and produces an output Tensor with the following rule:

    with data of shape (b, c, h, w)

    .. math::

        \mbox{out}(b, c, 1, 1)  = \max_{m=0, \ldots, h} \max_{n=0, \ldots, w}
             \mbox{data}(b, c, m, n)

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    layout : str, optional
        Layout of the input.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.global_max_pool2d(data, layout)

def global_avg_pool2d(data,
                      layout="NCHW"):
    r"""2D global average pooling operator.

    This operator takes data as input and does 2D average value calculation
    across each window represented by WxH.


    In the default case, where the data_layout is `NCHW`,
    it takes a data Tensor with shape `(batch_size, in_channels, height, width)`
    and produces an output Tensor with the following rule:

    with data of shape (b, c, h, w)

    .. math::

        \mbox{out}(b, c, 1, 1)  = \frac{1}{h * w} \sum_{m=0}^{h-1} \sum_{n=0}^{w-1}
             \mbox{data}(b, c, m, n)

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    layout : str, optional
        Layout of the input.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.global_avg_pool2d(data, layout)


def upsampling(data,
               scale_h=1,
               scale_w=1,
               layout="NCHW",
               method="nearest_neighbor",
               align_corners=False):
    """Upsampling.

    This operator takes data as input and does 2D scaling to the given scale factor.
    In the default case, where the data_layout is `NCHW`
    with data of shape (n, c, h, w)
    out will have a shape (n, c, h*scale_h, w*scale_w)

    method indicates the algorithm to be used while calculating the out value
    and method can be one of ("bilinear", "nearest_neighbor", "bicubic")

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    scale_h : tvm.relay.Expr
        The scale factor for height upsampling.

    scale_w : tvm.relay.Expr
        The scale factor for width upsampling.

    layout : str, optional
        Layout of the input.

    method : str, optional
        Scale method to be used [nearest_neighbor, bilinear, bicubic].

    align_corners : bool, optional
        Whether to keep corners in proper place.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
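
    For illustration, a minimal usage sketch (shapes are arbitrary and it is
    assumed that ``relay`` refers to ``tvm.relay``):

    .. code-block:: python

        x = relay.var("x", shape=(1, 8, 16, 16))   # NCHW input
        y = relay.nn.upsampling(x, scale_h=2, scale_w=2,
                                method="nearest_neighbor")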
    """
    return _make.upsampling(data, scale_h, scale_w, layout, method, align_corners)


def upsampling3d(data,
                 scale_d=1,
                 scale_h=1,
                 scale_w=1,
                 layout="NCDHW",
                 method="nearest_neighbor",
                 coordinate_transformation_mode="half_pixel"):
    """3D Upsampling.

    This operator takes data as input and does 3D scaling to the given scale factor.
    In the default case, where the data_layout is `NCDHW`
    with data of shape (n, c, d, h, w)
    out will have a shape (n, c, d*scale_d, h*scale_h, w*scale_w)

    method indicates the algorithm to be used while calculating the out value
    and method can be one of ("trilinear", "nearest_neighbor")

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    scale_d : tvm.relay.Expr
        The scale factor for depth upsampling.

    scale_h : tvm.relay.Expr
        The scale factor for height upsampling.

    scale_w : tvm.relay.Expr
        The scale factor for width upsampling.

    layout : str, optional
        Layout of the input.

    method : str, optional
        Scale method to be used [nearest_neighbor, trilinear].

    coordinate_transformation_mode: string, optional
        Describes how to transform the coordinate in the resized tensor
        to the coordinate in the original tensor.
        Refer to the ONNX Resize operator specification for details.
        Available options are "half_pixel", "align_corners" and "asymmetric".

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.upsampling3d(data, scale_d, scale_h, scale_w, layout, method,
                              coordinate_transformation_mode)


def batch_flatten(data):
    """BatchFlatten.

    This operator flattens all the dimensions except for the batch dimension,
    which results in a 2D output.

    For data with shape ``(d1, d2, ..., dk)``
    batch_flatten(data) returns reshaped output of shape ``(d1, d2*...*dk)``.


    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    Returns
    -------
    result : tvm.relay.Expr
        The Flattened result.
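
    For illustration, a minimal usage sketch (the shape is arbitrary and it is
    assumed that ``relay`` refers to ``tvm.relay``):

    .. code-block:: python

        x = relay.var("x", shape=(2, 3, 4, 5))
        y = relay.nn.batch_flatten(x)   # result has shape (2, 60)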
    """
    return _make.batch_flatten(data)


def bias_add(data, bias, axis=1):
    """add_bias operator.

    Add 1D bias to the axis of data.
    This function is a special case of add which allows
    inference of shape of the bias from data.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    bias : tvm.relay.Expr
        The bias to be added.

    axis : int, optional
        The axis to add the bias.

    Returns
    -------
    result : tvm.relay.Expr
        The final result.
    """
    return _make.bias_add(data, bias, axis)


def dense(data, weight, units=None, out_dtype=""):
    """Dense operator.
    Applies a linear transformation

    .. math::

        Y = X * W

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    weight : tvm.relay.Expr
        The weight expressions.

    units : int, optional
        Number of hidden units of the dense transformation.

    out_dtype : str, optional
        Specifies the output data type for mixed precision dense.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
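
    For illustration, a minimal usage sketch (shapes are arbitrary, the weight is
    laid out as `(units, input_dim)`, and it is assumed that ``relay`` refers to
    ``tvm.relay``):

    .. code-block:: python

        x = relay.var("x", shape=(8, 128))
        w = relay.var("w", shape=(64, 128))
        y = relay.nn.dense(x, w, units=64)   # result has shape (8, 64)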
    """
    return _make.dense(data, weight, units, out_dtype)


def fifo_buffer(data, buffer, axis):
    """FIFO buffer to enable computation reuse in CNNs with sliding window input

    Compute equivalent of

    .. code-block:: python

        concat(buffer, data, axis=axis)
        .slice_axis(axis=axis,
                    begin=data.shape[axis],
                    end=data.shape[axis]+buffer.shape[axis])

    Useful for

    * Encoding explicit re-use of computation in convolution ops operated on a sliding window input
    * Implementing a FIFO queue to cache intermediate results, e.g. as in Fast WaveNet.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data
    buffer : tvm.relay.Expr
        Previous value of the FIFO buffer
    axis : int
        Specify which axis should be used for buffering

    Returns
    -------
    result : tvm.relay.Expr
        Updated value for the buffer
    """
    return _make.fifo_buffer(data, buffer, axis)


def relu(data):
    """Rectified linear unit.

    .. math::
       out = max(x, 0)

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.relu(data)


def leaky_relu(data, alpha):
    """This operator takes data as input and does Leaky version
    of a Rectified Linear Unit.

    .. math::

        y = \begin{cases} x & \text{if } x > 0 \\ \alpha * x & \text{otherwise} \end{cases}

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    alpha : float
        Slope coefficient for the negative half axis.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.leaky_relu(data, alpha)


def prelu(data, alpha, axis=1):
    """This operator takes data as input and does Leaky version
    of a Rectified Linear Unit.

    .. math::

        y = \begin{cases} x & \text{if } x > 0 \\ \alpha * x & \text{otherwise} \end{cases}

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    alpha : tvm.relay.Expr
        Slope coefficient for the negative half axis.

    axis : int, optional
        Specify along which shape axis the channel is specified.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.prelu(data, alpha, axis)


def pad(data,
        pad_width,
        pad_value=0.0,
        pad_mode='constant'):
    r"""Padding

    This operator takes in a tensor and pads each axis by the specified
    widths using the specified value.

    Parameters
    ----------
    data: tvm.relay.Expr
        The input data to the operator
    pad_width: tuple of <tuple of <int>>, required
        Number of values padded to the edges of each axis, in the format
        of ((before_1, after_1), ..., (before_N, after_N))
    pad_value: float, optional, default=0.0
        The value used for padding
    pad_mode: 'constant', 'edge', 'reflect'
        'constant' pads with constant_value pad_value
        'edge' pads using the edge values of the input array
        'reflect' pads by reflecting values with respect to the edge
    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
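
    For illustration, a minimal usage sketch (the shape is arbitrary and it is
    assumed that ``relay`` refers to ``tvm.relay``):

    .. code-block:: python

        x = relay.var("x", shape=(1, 3, 32, 32))
        # pad only the two spatial axes by one element on each side
        y = relay.nn.pad(x, pad_width=((0, 0), (0, 0), (1, 1), (1, 1)))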
    """
    return _make.pad(data, pad_width, pad_value, pad_mode)


def mirror_pad(data,
               pad_width,
               mode="SYMMETRIC"):
    r"""MirrorPadding

    This operator takes in a tensor and pads each axis by the specified
    widths using mirroring of the border pixels.

    Parameters
    ----------
    data: tvm.relay.Expr
        The input data to the operator
    pad_width: tuple of <tuple of <int>>, required
        Number of values padded to the edges of each axis, in the format
        of ((before_1, after_1), ..., (before_N, after_N))
    mode: string, optional, default='SYMMETRIC'
        What type of mirroring to use, must be SYMMETRIC or REFLECT.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.mirror_pad(data, pad_width, mode)


def lrn(data, size=5, axis=1, bias=2, alpha=.00001, beta=0.75):
    """This operator takes data as input and does local response normalization.

    Normalize the input in a local region across or within feature maps.
    Each input value is divided by (bias + (alpha * sum_data^2 / size))^beta,
    where size is the size of each local region, and the sum is taken over the
    region centered at that value (zero padding is added where necessary).

    .. math::
        \mbox{out} = \mbox{data} / (\mbox{bias} + (\alpha * \mbox{sum\_data}^2 / \mbox{size}))^{\beta}

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    size : int, optional
        The size of the local region to be considered for normalization.

    axis : int, optional
        Input data layout channel axis. Default value is 1 for NCHW format

    bias : float, optional
        The offset parameter to avoid dividing by 0.

    alpha : float, optional
        The scaling parameter.

    beta : float, optional
        The exponent parameter.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.lrn(data, size, axis, alpha, beta, bias)


def l2_normalize(data, eps, axis=None):
    """Perform L2 normalization on the input data

    .. math::
        y(i, j) = x(i, j) / sqrt(max(sum(x^2), eps))

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    eps : float
        epsilon value

    axis : list of int, optional
        axis over which the normalization is applied

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.l2_normalize(data, eps, axis)


def dropout(data, rate=0.5):
    """Applies the dropout operation to the input array.

    During training, each element of the input is set to zero with
    probability ``p``. The whole array is rescaled by ``1/(1-p)``
    to keep the expected sum of the input unchanged.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    rate : float, optional (default=0.5)
        The probability for an element to be reset to 0.

    Returns
    -------
    result : tvm.relay.Expr
        The result of dropout
    """
    return TupleWrapper(dropout_raw(data, rate), 2)[0]


def dropout_raw(data, rate=0.5):
    """Applies the dropout operation to the input array.

    During training, each element of the input is set to zero with
    probability ``p``. The whole array is rescaled by ``1/(1-p)``
    to keep the expected sum of the input unchanged.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    rate : float, optional (default=0.5)
        The probability for an element to be reset to 0.

    Returns
    -------
    result : tvm.relay.Expr
        The result of dropout
    """
    return _make.dropout(data, rate)


def batch_norm(data,
               gamma,
               beta,
               moving_mean,
               moving_var,
               axis=1,
               epsilon=1e-5,
               center=True,
               scale=True):
    r"""
    Batch normalization layer (Ioffe and Szegedy, 2014).
    Normalizes the input at each batch, i.e. applies a transformation
    that maintains the mean activation close to 0 and the activation
    standard deviation close to 1.

    .. math::

        data\_mean[i] = mean(data[:,i,:,...]) \\
        data\_var[i] = var(data[:,i,:,...])

    Then compute the normalized output, which has the same shape as input, as following:

    .. math::

        out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}}
            * gamma[i] + beta[i]

    Both *mean* and *var* return a scalar by treating the input as a vector.

    Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
    have shape *(k,)*.

    Besides the inputs and the outputs, this operator accepts two auxiliary
    states, ``moving_mean`` and ``moving_var``, which are *k*-length
    vectors. They are global statistics for the whole dataset, which are updated by

    .. code:: python

        moving_mean = moving_mean * momentum + data_mean * (1 - momentum)
        moving_var = moving_var * momentum + data_var * (1 - momentum)

    The parameter ``axis`` specifies which axis of the input shape denotes
    the 'channel' (separately normalized groups).  The default is 1.
    Specifying -1 sets the channel axis to be the last item in the input shape.

    .. note::

        This operator can be optimized away for inference.

    Parameters
    ----------
    data : tvm.relay.Expr
        Input to which batch_norm will be applied.

    gamma : tvm.relay.Expr
        The gamma scale factor.

    beta : tvm.relay.Expr
        The beta offset factor.

    moving_mean : tvm.relay.Expr
        Running mean of input.

    moving_var : tvm.relay.Expr
        Running variance of input.

    axis : int, optional, default=1
        Specify along which shape axis the channel is specified.

    epsilon : double, optional, default=1e-5
        Small float added to variance to avoid dividing by zero.

    center : boolean, optional, default=True
        If True, add offset of beta to normalized tensor. If False,
        beta is ignored.

    scale : boolean, optional, default=True
        If True, multiply by gamma. If False, gamma is not used.
        When the next layer is piecewise linear (also e.g. nn.relu),
        this can be disabled since the scaling will be done by the next layer.

    Returns
    -------
    result : relay.Tuple([tvm.relay.Expr, tvm.relay.Expr, tvm.relay.Expr])
        Tuple of normed data (same shape as input),
        new running mean (k-length vector),
        and new running variance (k-length vector)
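
    For illustration, a minimal usage sketch (shapes are arbitrary and it is
    assumed that ``relay`` refers to ``tvm.relay``):

    .. code-block:: python

        x = relay.var("x", shape=(1, 16, 32, 32))
        gamma = relay.var("gamma", shape=(16,))
        beta = relay.var("beta", shape=(16,))
        mean = relay.var("mean", shape=(16,))
        var = relay.var("var", shape=(16,))
        bn = relay.nn.batch_norm(x, gamma, beta, mean, var, axis=1)
        out = bn[0]   # normalized data; bn[1] and bn[2] are the updated stats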
    """
    result = _make.batch_norm(data,
                              gamma,
                              beta,
                              moving_mean,
                              moving_var,
                              axis,
                              epsilon,
                              center,
                              scale)
    return TupleWrapper(result, 3)


def instance_norm(data,
                  gamma,
                  beta,
                  axis=1,
                  epsilon=1e-5,
                  center=True,
                  scale=True):
    r"""
    Instance Normalization (Ulyanov and et al., 2016)
    Applies instance normalization to the n-dimensional input array.

    .. math::

        out = \frac{data - mean(data)}{\sqrt{var(data)+\epsilon}}
            * gamma + beta

    The instance normalization is similar to batch normalization, but unlike
    batch normalization, the mean and var are calculated per-dimension
    separately for each object (instance) in a mini-batch, not over a batch.
    And the same normalization is applied both at test and train time.

    Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
    have shape *(k,)*.

    The parameter ``axis`` specifies which axis of the input shape denotes
    the 'channel'.  The default is 1. Specifying -1 sets the channel axis
    to be the last item in the input shape.

    .. note::

        This operator can be optimized away for inference.

    Parameters
    ----------
    data : tvm.relay.Expr
        Input to which instance_norm will be applied.

    gamma : tvm.relay.Expr
        The gamma scale factor.

    beta : tvm.relay.Expr
        The beta offset factor.

    axis : int, optional, default=1
        Specify along which shape axis the channel is specified.

    epsilon : double, optional, default=1e-5
        Small float added to variance to avoid dividing by zero.

    center : boolean, optional, default=True
        If True, add offset of beta to normalized tensor, If False,
        beta is ignored.

    scale : boolean, optional, default=True
        If True, multiply by gamma. If False, gamma is not used.

    Returns
    -------
    result : tvm.relay.Expr
        The normalized data.

    .. _`Instance Normalization: The Missing Ingredient for Fast Stylization`:
        https://arxiv.org/abs/1607.08022
    """
    return _make.instance_norm(data, gamma, beta, axis, epsilon, center, scale)


def layer_norm(data,
               gamma,
               beta,
               axis=-1,
               epsilon=1e-5,
               center=True,
               scale=True):
    r"""
    Layer normalization (Lei Ba and et al., 2016).
    Applies layer normalization to the n-dimensional input array.
    This operator takes an n-dimensional input array and normalizes
    the input using the given axis:

    .. math::

        out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}}
            * gamma + beta

    Unlike batch normalization, the mean and var are computed along the channel dimension.

    Assume the input has size k on axis 1, then both gamma and beta have shape (k,).

    .. note::

        This operator can be optimized away for inference.

    Parameters
    ----------
    data : tvm.relay.Expr
        Input to which layer_norm will be applied.

    gamma : tvm.relay.Expr
        The gamma scale factor.

    beta : tvm.relay.Expr
        The beta offset factor.

    axis : int, optional, default=-1
        The axis that should be normalized, typically the axis of the channels.

    epsilon : double, optional, default=1e-5
        Small float added to variance to avoid dividing by zero.

    center : boolean, optional, default=True
        If True, add the beta offset to the normalized tensor. If False,
        beta is ignored.

    scale : boolean, optional, default=True
        If True, multiply by gamma. If False, gamma is not used.

    Returns
    -------
    result : tvm.relay.Expr
        The normalized data.
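
    Examples
    --------
    A minimal usage sketch; the shapes below are illustrative only:

    .. code-block:: python

        from tvm import relay

        # Normalizing over the last axis, so gamma and beta have shape (64,).
        data = relay.var("data", shape=(8, 128, 64), dtype="float32")
        gamma = relay.var("gamma", shape=(64,), dtype="float32")
        beta = relay.var("beta", shape=(64,), dtype="float32")
        out = relay.nn.layer_norm(data, gamma, beta, axis=-1)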
    """
    return _make.layer_norm(data, gamma, beta, axis, epsilon, center, scale)


def batch_matmul(x, y):
    r"""
    Computes batch matrix multiplication of `x` and `y`, where `x` and `y` are
    3-D tensors holding batches of matrices and the second operand is
    transposed, as shown below.

    .. math::

        \mbox{batch_matmul}(x, y)[i, :, :] = \mbox{matmul}(x[i, :, :], y[i, :, :]^T)

    Parameters
    ----------
    x : tvm.relay.Expr
        The first input.

    y : tvm.relay.Expr
        The second input.

    Returns
    -------
    result: tvm.relay.Expr
        The computed result.
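
    Examples
    --------
    A minimal usage sketch; shapes are illustrative. Since the second operand
    is transposed, `(4, 16, 32)` and `(4, 20, 32)` inputs yield `(4, 16, 20)`:

    .. code-block:: python

        from tvm import relay

        x = relay.var("x", shape=(4, 16, 32), dtype="float32")
        y = relay.var("y", shape=(4, 20, 32), dtype="float32")
        out = relay.nn.batch_matmul(x, y)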
    """
    return _make.batch_matmul(x, y)

def sparse_dense(data, weight):
    r"""
    Computes the matrix multiplication of `data` and `weight`, where `data` is
    a dense matrix and `weight` is a sparse (either BSR or CSR) namedtuple with
    fields `data`, `indices`, and `indptr`.

    .. math::

        \mbox{sparse_dense}(data, weight)[m, n]
            = \mbox{matmul}(data, \mbox{as_dense}(weight)^T)[m, n]

    where `as_dense` returns the dense equivalent of the given sparse matrix.

    See
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html
    and
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.bsr_matrix.html
    for more detail on the sparse matrix representation.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data for the matrix multiplication

    weight : namedtuple.
        The sparse weight matrix for the matrix multiplication.

    Returns
    -------
    result: tvm.relay.Expr
        The computed result.
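
    Examples
    --------
    A minimal sketch assuming a SciPy CSR weight; the namedtuple name and the
    shapes are illustrative, and indices are cast to int32 as a precaution:

    .. code-block:: python

        import collections
        import scipy.sparse as sp
        from tvm import relay

        CSRWeight = collections.namedtuple("CSRWeight", ["data", "indices", "indptr"])

        # Sparse weight with logical shape (N, K) = (64, 32).
        w = sp.random(64, 32, density=0.2, format="csr", dtype="float32")
        weight = CSRWeight(relay.const(w.data),
                           relay.const(w.indices.astype("int32")),
                           relay.const(w.indptr.astype("int32")))

        # Dense input (M, K) = (8, 32); the result has shape (M, N) = (8, 64).
        x = relay.var("x", shape=(8, 32), dtype="float32")
        out = relay.nn.sparse_dense(x, weight)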
    """
    return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)

def sparse_transpose(x):
    r"""
    Computes the fast matrix transpose of x,
    where x is a sparse tensor in CSR format (represented as a namedtuple
    with fields `data`, `indices`, and `indptr`).

    **Currently, only square matrices are supported.**

    .. math::

        \mbox{sparse_transpose}(x)[n, n] = (x^T)[n, n]

    Please refer to https://github.com/scipy/scipy/blob/v1.3.0/scipy/sparse/csr.py
    for the algorithm implemented in this operator.

    Parameters
    ----------
    x : namedtuple.
        The sparse weight matrix for the fast matrix transpose.

    Returns
    -------
    result : relay.Tuple([tvm.relay.Expr, tvm.relay.Expr, tvm.relay.Expr])
        Tuple representing the output sparse tensor (same shape and format as
        the input), i.e. for CSR the output is the (data, indices, indptr) triple.
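
    Examples
    --------
    A minimal sketch assuming a square SciPy CSR matrix; the namedtuple name
    and shapes are illustrative:

    .. code-block:: python

        import collections
        import scipy.sparse as sp
        from tvm import relay

        CSRTensor = collections.namedtuple("CSRTensor", ["data", "indices", "indptr"])

        m = sp.random(32, 32, density=0.2, format="csr", dtype="float32")
        x = CSRTensor(relay.const(m.data),
                      relay.const(m.indices.astype("int32")),
                      relay.const(m.indptr.astype("int32")))

        # The 3-tuple result holds (data, indices, indptr) of the transpose.
        t_data, t_indices, t_indptr = relay.nn.sparse_transpose(x)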
    """
    return TupleWrapper(_make.sparse_transpose(x.data, x.indices, x.indptr), 3)

def contrib_conv2d_winograd_without_weight_transform(data,
                                                     weight,
                                                     tile_size,
                                                     strides=(1, 1),
                                                     padding=(0, 0),
                                                     dilation=(1, 1),
                                                     groups=1,
                                                     channels=None,
                                                     kernel_size=None,
                                                     data_layout="NCHW",
                                                     kernel_layout="OIHW",
                                                     out_layout="",
                                                     out_dtype=""):
    r"""2D convolution with winograd algorithm.

    The basic parameters are the same as the ones in vanilla conv2d.
    It assumes the weight is pre-transformed by nn.contrib_conv2d_winograd_weight_transform.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    weight : tvm.relay.Expr
        The weight expressions.

    tile_size : int
        The Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)

    strides : tuple of int, optional
        The strides of convolution.

    padding : tuple of int, optional
        The padding of convolution on both sides of inputs before convolution.

    dilation : tuple of int, optional
        Specifies the dilation rate to be used for dilated convolution.

    groups : int, optional
        Number of groups for grouped convolution.

    channels : int, optional
        Number of output channels of this convolution.

    kernel_size : tuple of int, optional
        The spatial dimensions of the convolution kernel.

    data_layout : str, optional
        Layout of the input.

    kernel_layout : str, optional
        Layout of the weight.

    out_layout : str, optional
        Layout of the output, by default, out_layout is the same as data_layout

    out_dtype : str, optional
        Specifies the output data type for mixed precision conv2d.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
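
    Examples
    --------
    A minimal sketch pairing this op with the weight transform; the shapes and
    the tile size are illustrative assumptions for a 3x3 NCHW kernel:

    .. code-block:: python

        from tvm import relay

        data = relay.var("data", shape=(1, 64, 56, 56), dtype="float32")
        weight = relay.var("weight", shape=(64, 64, 3, 3), dtype="float32")

        # Transform the kernel once (e.g. ahead of time), then reuse it.
        tw = relay.nn.contrib_conv2d_winograd_weight_transform(weight, tile_size=4)
        out = relay.nn.contrib_conv2d_winograd_without_weight_transform(
            data, tw, tile_size=4, channels=64, kernel_size=(3, 3), padding=(1, 1))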
    """
    # convert 2-way padding to 4-way padding
    padding = get_pad_tuple2d(padding)
    return _make.contrib_conv2d_winograd_without_weight_transform(
        data, weight, tile_size, strides, padding, dilation,
        groups, channels, kernel_size, data_layout,
        kernel_layout, out_layout, out_dtype)


def contrib_conv2d_nchwc(data,
                         kernel,
                         strides=(1, 1),
                         padding=(0, 0),
                         dilation=(1, 1),
                         groups=1,
                         channels=None,
                         kernel_size=None,
                         data_layout="NCHW8c",
                         kernel_layout="OIHW",
                         out_layout="",
                         out_dtype=""):
    r"""Variant of 2D convolution.

    This operator takes the weight as the convolution kernel
    and convolves it with data to produce an output, following a specialized
    NCHWc data layout.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    kernel : tvm.relay.Expr
        The kernel expressions.

    strides : tuple of int, optional
        The strides of convolution.

    padding : tuple of int, optional
        The padding of convolution on both sides of inputs before convolution.

    dilation : tuple of int, optional
        Specifies the dilation rate to be used for dilated convolution.

    groups : int, optional
        Number of groups for grouped convolution.

    channels : int, optional
        Number of output channels of this convolution.

    kernel_size : tuple of int, optional
        The spatial dimensions of the convolution kernel.

    data_layout : str, optional
        Layout of the input.

    kernel_layout : str, optional
        Layout of the weight.

    out_layout : str, optional
        Layout of the output, by default, out_layout is the same as data_layout

    out_dtype : str, optional
        Specifies the output data type for mixed precision conv2d.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    # convert 2-way padding to 4-way padding
    padding = get_pad_tuple2d(padding)
    return _make.contrib_conv2d_NCHWc(data, kernel, strides, padding, dilation,
                                      groups, channels, kernel_size, data_layout,
                                      kernel_layout, out_layout, out_dtype)

def contrib_depthwise_conv2d_nchwc(data,
                                   kernel,
                                   strides=(1, 1),
                                   padding=(0, 0),
                                   dilation=(1, 1),
                                   groups=1,
                                   channels=None,
                                   kernel_size=None,
                                   data_layout="NCHW8c",
                                   kernel_layout="OIHW",
                                   out_layout="",
                                   out_dtype=""):
    r"""Variant of 2D depthwise convolution.

    This operator takes the weight as the depthwise convolution kernel
    and depthwise convolves it with data to produce an output, following a specialized
    NCHWc data layout.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    kernel : tvm.relay.Expr
        The kernel expressions.

    strides : tuple of int, optional
        The strides of convolution.

    padding : tuple of int, optional
        The padding of convolution on both sides of inputs before convolution.

    dilation : tuple of int, optional
        Specifies the dilation rate to be used for dilated convolution.

    groups : int, optional
        Number of groups for grouped convolution.

    channels : int, optional
        Number of output channels of this convolution.

    kernel_size : tuple of int, optional
        The spatial dimensions of the convolution kernel.

    data_layout : str, optional
        Layout of the input.

    kernel_layout : str, optional
        Layout of the weight.

    out_layout : str, optional
        Layout of the output, by default, out_layout is the same as data_layout

    out_dtype : str, optional
        Specifies the output data type for mixed precision conv2d.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    # convert 2-way padding to 4-way padding
    padding = get_pad_tuple2d(padding)
    return _make.contrib_depthwise_conv2d_NCHWc(data, kernel, strides, padding, dilation,
                                                groups, channels, kernel_size, data_layout,
                                                kernel_layout, out_layout, out_dtype)


def contrib_conv2d_winograd_weight_transform(weight,
                                             tile_size):
    r"""Weight Transformation part for 2D convolution with winograd algorithm.

    We separate this as a single op to enable pre-compute for inference.
    Use this together with nn.contrib_conv2d_winograd_without_weight_transform

    Parameters
    ----------
    weight : tvm.relay.Expr
        The weight expressions.

    tile_size : int
        The Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size)


def contrib_conv3d_winograd_weight_transform(weight,
                                             tile_size):
    r"""Weight Transformation part for 3D convolution with winograd algorithm.

    We separate this as a single op to enable pre-compute for inference.
    Use this together with nn.contrib_conv3d_winograd_without_weight_transform

    Parameters
    ----------
    weight : tvm.relay.Expr
        The weight expressions.

    tile_size : int
        The Tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3)

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.contrib_conv3d_winograd_weight_transform(weight, tile_size)


def contrib_conv2d_winograd_nnpack_weight_transform(weight,
                                                    convolution_algorithm,
                                                    out_dtype=""):
    r"""Weight Transformation part for 2D convolution with winograd algorithm.

    We separate this as a single op to enable pre-compute for inference.
    Use this together with nn.contrib_conv2d_winograd_without_weight_transform

    Parameters
    ----------
    weight : tvm.relay.Expr
        The weight expressions.

    convolution_algorithm : int
        The NNPACK convolution algorithm to use for the Winograd transform.

    out_dtype : str, optional
        Specifies the output data type.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.contrib_conv2d_winograd_nnpack_weight_transform(
        weight, convolution_algorithm, out_dtype)


def deformable_conv2d(data,
                      offset,
                      weight,
                      strides=(1, 1),
                      padding=(0, 0),
                      dilation=(1, 1),
                      deformable_groups=1,
                      groups=1,
                      channels=None,
                      kernel_size=None,
                      data_layout='NCHW',
                      kernel_layout='OIHW',
                      out_layout='',
                      out_dtype=''):
    r""" Deformable 2d convolution.

    The deformable convolution operation is described in https://arxiv.org/abs/1703.06211

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    offset : tvm.relay.Expr
        The offset expressions.

    weight : tvm.relay.Expr
        The weight expressions.

    strides : tuple of int, optional
        The strides of convolution.

    padding : tuple of int, optional
        The padding of convolution on both sides of inputs before convolution.

    dilation : tuple of int, optional
        Specifies the dilation rate to be used for dilated convolution.

    deformable_groups : int, optional
        Number of deformable groups.

    groups : int, optional
        Number of groups for grouped convolution.

    channels : int, optional
        Number of output channels of this convolution.

    kernel_size : tuple of int, optional
        The spatial dimensions of the convolution kernel.

    data_layout : str, optional
        Layout of the input.

    kernel_layout : str, optional
        Layout of the weight.

    out_layout : str, optional
        Layout of the output, by default, out_layout is the same as data_layout

    out_dtype : str, optional
        Specifies the output data type for mixed precision conv2d.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
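
    Examples
    --------
    A minimal sketch; shapes are illustrative. The offset layout assumed here
    follows the common convention of one (dy, dx) pair per kernel element and
    deformable group, laid out over the output grid:

    .. code-block:: python

        from tvm import relay

        data = relay.var("data", shape=(1, 32, 28, 28), dtype="float32")
        weight = relay.var("weight", shape=(64, 32, 3, 3), dtype="float32")
        # 1 deformable group * 3 * 3 kernel elements * 2 coordinates = 18 channels.
        offset = relay.var("offset", shape=(1, 18, 28, 28), dtype="float32")
        out = relay.nn.deformable_conv2d(data, offset, weight, padding=(1, 1),
                                         channels=64, kernel_size=(3, 3))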

    """
    # convert 2-way padding to 4-way padding
    padding = get_pad_tuple2d(padding)
    return _make.deformable_conv2d(data, offset, weight, strides, padding, dilation,
                                   deformable_groups, groups, channels, kernel_size, data_layout,
                                   kernel_layout, out_layout, out_dtype)


def bitpack(data,
            bits=1,
            pack_axis=1,
            bit_axis=2,
            pack_type="uint32",
            name="BitPack"):
    r"""Tensor packing for bitserial operations.
    The values along the input tensor's pack_axis are quantized
    and packed together into the specified pack_type in a new
    bit axis.

    For example, consider bitpacking where data is a tensor with shape [1, 64, 128, 128],
    pack_axis=1, bit_axis=4, pack_type=uint8, and bits=2. The output in this case will
    be of shape [1, 8, 128, 128, 2]. The dimension of axis 1 has been reduced by a factor
    of 8 since each value is packed into an 8-bit uint8. Axis 4 is now two bitplanes
    representing the quantized value of the incoming data. The output tensor is now
    ready to be used in a bitserial operation.

    Parameters
    ----------
    data : tvm.relay.Expr
        The incoming tensor to be packed.

    bits : int
        Number of bits that should be packed.

    pack_axis : int
        Axis that should be decomposed and packed.

    bit_axis : int
        New axis containing bitplane.

    pack_type : str
        Datatype to pack bits into.

    name : str, optional
        Name of the operation.

    Returns
    -------
    result : tvm.relay.Expr
        The packed tensor.
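
    Examples
    --------
    A sketch of the worked example above; the input dtype is an illustrative
    assumption:

    .. code-block:: python

        from tvm import relay

        data = relay.var("data", shape=(1, 64, 128, 128), dtype="int16")
        # Packs axis 1 into uint8 words and adds a bit axis at position 4,
        # giving an output of shape (1, 8, 128, 128, 2).
        packed = relay.nn.bitpack(data, bits=2, pack_axis=1, bit_axis=4,
                                  pack_type="uint8")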
    """
    return _make.bitpack(data, bits, pack_axis, bit_axis, pack_type, name)


def bitserial_conv2d(data,
                     weight,
                     strides=(1, 1),
                     padding=(0, 0),
                     channels=None,
                     kernel_size=(3, 3),
                     activation_bits=1,
                     weight_bits=1,
                     data_layout='NCHW',
                     kernel_layout='OIHW',
                     pack_dtype='uint32',
                     out_dtype='int16',
                     unipolar=True):
    r"""2D convolution using bitserial computation.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    weight : tvm.relay.Expr
        The weight expressions.

    strides : tuple of int, optional
        The strides of convolution.

    padding : tuple of int, optional
        The padding of convolution on both sides of inputs before convolution.

    channels : int, optional
        Number of output channels of this convolution.

    kernel_size : tuple of int, optional
        The spatial dimensions of the convolution kernel.

    activation_bits : int
        Number of bits to pack for activations.

    weight_bits : int
        Number of bits to pack for weights.

    data_layout : str, optional
        Layout of the input.

    kernel_layout : str, optional
        Layout of the kernel

    pack_dtype: str, optional
        Datatype to pack bits into.

    out_dtype : str, optional
        Specifies the output data type for mixed precision conv2d.

    unipolar : bool, optional
        Whether to use unipolar or bipolar quantization for inputs.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
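
    Examples
    --------
    A minimal sketch; the shapes, dtypes and bit widths are illustrative
    assumptions:

    .. code-block:: python

        from tvm import relay

        data = relay.var("data", shape=(1, 64, 56, 56), dtype="uint8")
        weight = relay.var("weight", shape=(64, 64, 3, 3), dtype="uint8")
        out = relay.nn.bitserial_conv2d(data, weight, kernel_size=(3, 3),
                                        channels=64, padding=(1, 1),
                                        activation_bits=2, weight_bits=1)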
    """
    # convert 2-way padding to 4-way padding
    padding = get_pad_tuple2d(padding)
    return _make.bitserial_conv2d(data, weight, strides, padding, channels,
                                  kernel_size, activation_bits, weight_bits,
                                  data_layout, kernel_layout, pack_dtype,
                                  out_dtype, unipolar)


def bitserial_dense(data,
                    weight,
                    units=None,
                    data_bits=1,
                    weight_bits=1,
                    pack_dtype='uint32',
                    out_dtype='int16',
                    unipolar=True):
    """Bitserial Dense operator.
    Applies matrix multiplication of two quantized matrices
    using a fast bitserial algorithm.

    .. math::

        Y = X * W

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    weight : tvm.relay.Expr
        The weight expressions.

    units : int, optional
        Number of hidden units of the dense transformation.

    data_bits : int
        Number of bits the incoming tensor should be packed with.

    weight_bits : int
        Number of bits the weight tensor should be packed with.

    pack_dtype : str, optional
        Datatype to pack individual bits into before computation.

    out_dtype : str, optional
        Specifies the output data type for mixed precision dense.

    unipolar : bool, optional
        Whether to use unipolar or bipolar quantization for inputs.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.bitserial_dense(data, weight, units, data_bits, weight_bits,
                                 pack_dtype, out_dtype, unipolar)


def cross_entropy(predictions, targets):
    """CrossEntropy without logits.

    Parameters
    ----------
    predictions : tvm.relay.Expr
      The predictions.

    targets : tvm.relay.Expr
      The targets.

    Returns
    -------
    result : tvm.relay.Expr
      The computed result.
    """
    return _make.cross_entropy(predictions, targets)


def cross_entropy_with_logits(predictions, targets):
    """CrossEntropy with logits.

    Parameters
    ----------
    predictions : tvm.relay.Expr
      The predictions.

    targets : tvm.relay.Expr
      The targets.

    Returns
    -------
    result : tvm.relay.Expr
      The computed result.
    """
    return _make.cross_entropy_with_logits(predictions, targets)


def depth_to_space(data, block_size, layout='NCHW', mode='DCR'):
    """Convert channels into spatial blocks.

    Parameters
    ----------
    data : tvm.relay.Expr
        Input data with channels divisible by block_size**2

    block_size : int
        Size of blocks to convert channels into.

    layout : string
        One of NCHW or NHWC, indicates channel axis.

    mode : string
        One of DCR or CDR, indicates which order channels
        are accessed in.

    Returns
    -------
    result : tvm.relay.Expr
        Tensor with shape [in_batch, in_channel / (block_size * block_size),
                           in_height * block_size, in_width * block_size]
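
    Examples
    --------
    A minimal sketch in NCHW layout: with block_size=2, groups of 4 input
    channels become 2x2 spatial blocks, so (1, 16, 8, 8) becomes (1, 4, 16, 16):

    .. code-block:: python

        from tvm import relay

        data = relay.var("data", shape=(1, 16, 8, 8), dtype="float32")
        out = relay.nn.depth_to_space(data, block_size=2)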
    """
    return _make.depth_to_space(data, block_size, layout, mode)


def space_to_depth(data, block_size, layout='NCHW'):
    """Convert spatial blocks into channels.

    Parameters
    ----------
    data : tvm.relay.Expr
        Input data with spatial dimensions divisible by block_size

    block_size : int
        Size of blocks to decompose into channels.

    layout : string
        One of NCHW or NHWC, indicates channel axis.

    Returns
    -------
    result : tvm.relay.Expr
        Tensor with shape [in_batch, in_channel * block_size * block_size,
                           in_height / block_size, in_width / block_size]
    """
    return _make.space_to_depth(data, block_size, layout)


def adaptive_max_pool2d(data,
                        output_size=None,
                        layout="NCHW"):
    r"""2D adaptive max pooling operator. This operator is experimental.

    This operator takes data as input and does 2D max value calculation
    across each window represented by WxH.


    In the default case, where the data_layout is `NCHW`, the operator takes
    a data Tensor with shape `(batch_size, in_channels, height, width)`
    and produces an output Tensor with shape
    (batch_size, in_channels, output_height, output_width).

    The pooling kernel and stride sizes are automatically chosen for
    desired output sizes.

    For output_size:
        If this argument is not provided, input height and width will be used
        as output height and width.

        If a single integer is provided for output_size, the output size is
        (N x C x output_size x output_size) for any input (NCHW).

        If a tuple of integers (height, width) are provided for output_size,
        the output size is (N x C x height x width) for any input (NCHW).

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    output_size : int or tuple of int, optional
        Output height and width.

    layout : str, optional
        Layout of the input.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
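
    Examples
    --------
    A minimal sketch; the input shape is illustrative:

    .. code-block:: python

        from tvm import relay

        data = relay.var("data", shape=(1, 32, 224, 224), dtype="float32")
        # Pool any spatial size down to a fixed 7x7 output.
        out = relay.nn.adaptive_max_pool2d(data, output_size=(7, 7))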
    """
    output_size = output_size or []
    return _make.adaptive_max_pool2d(data, output_size, layout)


def adaptive_avg_pool2d(data,
                        output_size=None,
                        layout="NCHW"):
    r"""2D adaptive average pooling operator. This operator is experimental.

    This operator takes data as input and does 2D average value calculation
    across each window represented by WxH.


    In the default case, where the data_layout is `NCHW`, the operator takes
    a data Tensor with shape `(batch_size, in_channels, height, width)`
    and produces an output Tensor with shape
    (batch_size, in_channels, output_height, output_width).

    The pooling kernel and stride sizes are automatically chosen for
    desired output sizes.

    For output_size:
        If this argument is not provided, input height and width will be used
        as output height and width.

        If a single integer is provided for output_size, the output size is
        (N x C x output_size x output_size) for any input (NCHW).

        If a tuple of integers (height, width) are provided for output_size,
        the output size is (N x C x height x width) for any input (NCHW).

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    output_size : int or tuple of int, optional
        Output height and width.

    layout : str, optional
        Layout of the input.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    output_size = output_size or []
    return _make.adaptive_avg_pool2d(data, output_size, layout)


def adaptive_max_pool3d(data,
                        output_size=None,
                        layout="NCDHW"):
    r"""3D adaptive max pooling operator. This operator is experimental.

    This operator takes data as input and does 3D max value calculation
    across each window represented by DxWxH.

    In the default case, where the data_layout is `NCDHW`, the operator takes
    a data Tensor with shape `(batch_size, in_channels, depth, height, width)`
    and produces an output Tensor with shape
    (batch_size, in_channels, output_depth, output_height, output_width).

    The pooling kernel and stride sizes are automatically chosen for
    desired output sizes.

    For output_size:
        If this argument is not provided, input depth, height and width will be used
        as output depth, height and width.

        If a single integer is provided for output_size, the output size is
        (N x C x output_size x output_size x output_size) for any input (NCDHW).

        If a tuple of integers (depth, height, width) are provided for output_size,
        the output size is (N x C x depth x height x width) for any input (NCDHW).

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    output_size : int or tuple of int, optional
        Output depth, height and width.

    layout : str, optional
        Layout of the input.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    output_size = output_size or []
    return _make.adaptive_max_pool3d(data, output_size, layout)


def adaptive_avg_pool3d(data,
                        output_size=None,
                        layout="NCDHW"):
    r"""3D adaptive avg pooling operator. This operator is experimental.

    This operator takes data as input and does 3D avg value calculation
    across each window represented by DxWxH.

    In the default case, where the data_layout is `NCDHW`, the operator takes
    a data Tensor with shape `(batch_size, in_channels, depth, height, width)`
    and produces an output Tensor with shape
    (batch_size, in_channels, output_depth, output_height, output_width).

    The pooling kernel and stride sizes are automatically chosen for
    desired output sizes.

    For output_size:
        If this argument is not provided, input depth, height and width will be used
        as output depth, height and width.

        If a single integer is provided for output_size, the output size is
        (N x C x output_size x output_size x output_size) for any input (NCDHW).

        If a tuple of integers (depth, height, width) are provided for output_size,
        the output size is (N x C x depth x height x width) for any input (NCDHW).

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    output_size : int or tuple of int, optional
        Output depth, height and width.

    layout : str, optional
        Layout of the input.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    output_size = output_size or []
    return _make.adaptive_avg_pool3d(data, output_size, layout)