# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, invalid-name, unused-argument, too-many-arguments
"""Backend compiler related feature registration"""
from __future__ import absolute_import

import topi
from topi.util import get_const_tuple
from .. import op as reg
from ..op import OpPattern, schedule_injective
from .._tensor import elemwise_shape_func
from ....api import convert
from ....hybrid import script

# relu
reg.register_schedule("nn.relu", schedule_injective)
reg.register_pattern("nn.relu", OpPattern.ELEMWISE)

# softmax
@reg.register_schedule("nn.softmax")
def schedule_softmax(_, outputs, target):
    """Schedule definition of softmax"""
    with target:
        return topi.generic.schedule_softmax(outputs)


reg.register_pattern("nn.softmax", OpPattern.OPAQUE)

schedule_broadcast = schedule_injective


@reg.register_schedule("nn.log_softmax")
def schedule_log_softmax(_, outputs, target):
    """Schedule definition of log_softmax"""
    with target:
        return topi.generic.schedule_softmax(outputs)


reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)


# dense
@reg.register_compute("nn.dense")
def compute_dense(attrs, inputs, out_type, target):
    """Compute definition of dense"""
    out_dtype = attrs.out_dtype
    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
    return [topi.nn.dense(inputs[0], inputs[1], None, out_dtype)]


@reg.register_schedule("nn.dense")
def schedule_dense(attrs, outputs, target):
    """Schedule definition of dense"""
    with target:
        return topi.generic.schedule_dense(outputs)


reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute('nn.fifo_buffer')
def compute_fifo_buffer(attrs, inputs, out_type, target):
    """Compute definition of fifo_buffer"""
    return [topi.nn.fifo_buffer(inputs[0], inputs[1], axis=attrs.get_int('axis'))]


@reg.register_schedule('nn.fifo_buffer')
def schedule_fifo_buffer(attrs, outputs, target):
    """Schedule definition of fifo_buffer"""
    with target:
        return topi.generic.schedule_injective(outputs)


reg.register_pattern("nn.fifo_buffer", OpPattern.OPAQUE)


# batch_matmul
@reg.register_compute("nn.batch_matmul")
def compute_batch_matmul(attrs, inputs, out_type, target):
    """Compute definition of batch_matmul"""
    with target:
        return [topi.nn.batch_matmul(inputs[0], inputs[1])]


@reg.register_schedule("nn.batch_matmul")
def schedule_batch_matmul(attrs, outputs, target):
    """Schedule definition of batch_matmul"""
    with target:
        return topi.generic.schedule_batch_matmul(outputs)


reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)

# sparse_dense
@reg.register_compute("nn.sparse_dense")
def compute_sparse_dense(attrs, inputs, out_type, target):
    """Compute definition of sparse_dense"""
    return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3])]

@reg.register_schedule("nn.sparse_dense")
def schedule_sparse_dense(attrs, outputs, target):
    """Schedule definition of batch_matmul"""
    with target:
        return topi.generic.schedule_sparse_dense(outputs)

reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)

# sparse_transpose
@reg.register_compute("nn.sparse_transpose")
def compute_sparse_transpose(attrs, inputs, out_type, target):
    """Compute definition of sparse_transpose"""
    return topi.nn.sparse_transpose(inputs[0], inputs[1], inputs[2])

@reg.register_schedule("nn.sparse_transpose")
def schedule_sparse_transpose(attrs, outputs, target):
    """Schedule definition of batch_matmul"""
    with target:
        return topi.generic.schedule_sparse_transpose(outputs)

reg.register_pattern("nn.sparse_transpose", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


# Conv1D
@reg.register_compute("nn.conv1d")
def compute_conv1d(attrs, inputs, out_type, target):
    """Compute definition of conv1d"""
    strides = get_const_tuple(attrs.strides)
    padding = get_const_tuple(attrs.padding)
    dilation = get_const_tuple(attrs.dilation)
    layout = attrs.data_layout
    out_dtype = attrs.out_dtype
    out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
                 else out_dtype)

    assert layout in ["NCW", "NWC"]
    if dilation[0] < 1:
        raise ValueError("dilation should be a positive value")

    return [topi.nn.conv1d(inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype)]


@reg.register_schedule("nn.conv1d")
def schedule_conv1d(attrs, outs, target):
    """Schedule definition of conv1d"""
    layout = attrs.data_layout

    with target:
        if layout == "NCW":
            return topi.generic.schedule_conv1d_ncw(outs)
        elif layout == "NCW":
            return topi.generic.schedule_conv1d_nwc(outs)
    raise ValueError("No compatible schedule")


reg.register_pattern("nn.conv1d", OpPattern.OUT_ELEMWISE_FUSABLE)


# conv2d
def _find_conv2d_op(op):
    """Find the op with conv2d in its tag by traversing."""
    if 'conv2d' in op.tag:
        return op
    for tensor in op.input_tensors:
        op_ = _find_conv2d_op(tensor.op)
        if op_ is not None:
            return op_
    return None

@reg.register_compute("nn.conv2d")
def compute_conv2d(attrs, inputs, out_type, target):
    """Compute definition of conv2d"""
    padding = get_const_tuple(attrs.padding)
    strides = get_const_tuple(attrs.strides)
    dilation = get_const_tuple(attrs.dilation)
    groups = attrs.groups
    layout = attrs.data_layout
    kernel_layout = attrs.kernel_layout
    out_dtype = attrs.out_dtype
    out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
                 else out_dtype)

    assert layout in ["NCHW", "NHWC", "NCHW4c", "HWCN"]
    (dilation_h, dilation_w) = dilation
    if dilation_h < 1 or dilation_w < 1:
        raise ValueError("dilation should be positive value")

    def _get_out_depth():
        weight_shape = get_const_tuple(inputs[1].shape)
        # HW-leading kernel layouts (e.g. HWIO/HWOI, used with NHWC data)
        if kernel_layout.startswith("HW"):
            return weight_shape[2] * weight_shape[3]
        # NCHW layout.
        # In the ARM CPU contrib_spatial_pack schedule, the weight is pre-packed
        # into a five-dimensional layout.
        if len(weight_shape) == 4:
            return weight_shape[0] * weight_shape[1]
        else:
            assert len(weight_shape) == 5
            C, M, _, _, VC = weight_shape
            return C * VC * M

    if groups == 1:
        out = topi.nn.conv2d(
            inputs[0], inputs[1], strides, padding,
            dilation, layout, out_dtype)
    elif layout == "NCHW" and _get_out_depth() == groups:
        out = topi.nn.depthwise_conv2d_nchw(
            inputs[0], inputs[1], strides, padding, dilation, out_dtype)
    elif layout == "NHWC" and kernel_layout == "HWOI" and _get_out_depth() == groups:
        out = topi.nn.depthwise_conv2d_nhwc(
            inputs[0], inputs[1], strides, padding, dilation, out_dtype)
    elif layout in ['NCHW', 'NCHW4c']:
        out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups,
                                        out_dtype)
    else:
        raise ValueError("not support arbitrary group number for now")
    return [out]
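
# For example, a 32-channel depthwise convolution -- data (1, 32, 56, 56) in
# NCHW, weight (32, 1, 3, 3), groups=32 -- makes _get_out_depth() return
# 32 == groups, so the compute above dispatches to depthwise_conv2d_nchw rather
# than the grouped path.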


@reg.register_schedule("nn.conv2d")
def schedule_conv2d(attrs, outs, target):
    """Schedule definition of conv2d"""
    groups = attrs.groups
    layout = attrs.data_layout
    kernel_layout = attrs.kernel_layout

    with target:
        if groups == 1 and layout == "NCHW":
            return topi.generic.schedule_conv2d_nchw(outs)
        elif groups == 1 and layout == "NCHW4c":
            return topi.generic.schedule_conv2d_nchw(outs)
        elif groups == 1 and layout == "NHWC":
            return topi.generic.schedule_conv2d_nhwc(outs)
        elif groups == 1 and layout == "HWCN":
            return topi.generic.schedule_conv2d_hwcn(outs)
        elif groups != 1:
            # collect in_channels to distinguish depthwise and group conv2d
            op = _find_conv2d_op(outs[0].op)
            assert op is not None

            is_depthwise = 'depthwise' in op.tag
            if is_depthwise:
                if layout == "NCHW":
                    # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d.
                    return topi.generic.schedule_depthwise_conv2d_nchw(outs)
                if layout == "NHWC" and kernel_layout == "HWOI":
                    return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
            else:
                if layout in ["NCHW", "NCHW4c"]:
                    return topi.generic.schedule_group_conv2d_nchw(outs)
    raise ValueError("No compatible schedule")


@reg.register_alter_op_layout("nn.conv2d")
def alter_op_layout_conv2d(attrs, inputs, tinfos):
    """Alternate the layout of conv2d"""
    from ... import op
    return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)

@reg.register_legalize("nn.conv2d")
def legalize_conv2d(attrs, inputs, types):
    """Legalize conv2d op.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
    return topi.nn.conv2d_legalize(attrs, inputs, types)


@reg.register_convert_op_layout("nn.conv2d")
def convert_conv2d(attrs, inputs, tinfos, desired_layout):
    """Convert Layout pass registration for conv2d op.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    tinfos : list of types
        List of input and output types
    desired_layout : str
        The desired layout

    Returns
    -------
    result : tvm.relay.Expr
        The transformed expr
    """

    from tvm import relay
    data_layout = attrs['data_layout']
    kernel_layout = attrs['kernel_layout']
    data, weight = inputs
    assert desired_layout == 'NCHW', \
            "Currently only transformation to NCHW layout is supported."
    if desired_layout == 'NCHW':
        new_attrs = dict(attrs)
        new_attrs['data_layout'] = desired_layout
        new_attrs['kernel_layout'] = 'OIHW'

        if data_layout == 'NHWC' and kernel_layout == 'HWIO':
            # Convert (NHWC, HWIO) to (NCHW, OIHW)
            return relay.nn.conv2d(data, weight, **new_attrs)
        if data_layout == 'NHWC' and kernel_layout == 'HWOI':
            # Convert (NHWC, HWOI) to (NCHW, OIHW). Depthwise conv2d.
            return relay.nn.conv2d(data, weight, **new_attrs)
    return None
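
# A minimal sketch of how this hook is exercised (hypothetical input module):
# the ConvertLayout pass calls convert_conv2d for every conv2d it visits.
#
#   from tvm import relay
#   # mod contains NHWC/HWIO convolutions, e.g. imported from TensorFlow
#   mod = relay.transform.ConvertLayout('NCHW')(mod)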

reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# conv2d_transpose
@reg.register_compute("nn.conv2d_transpose")
def compute_conv2d_transpose(attrs, inputs, out_dtype, target):
    """Compute definition of conv2d_transpose"""
    padding = get_const_tuple(attrs.padding)
    strides = get_const_tuple(attrs.strides)
    dilation = get_const_tuple(attrs.dilation)
    groups = attrs.groups
    layout = attrs.data_layout
    out_dtype = attrs.out_dtype
    out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
                 else out_dtype)
    assert layout == "NCHW", "only support nchw for now"
    assert dilation == (1, 1), "not support dilate now"
    assert groups == 1, "only support groups == 1 for now"
    out = topi.nn.conv2d_transpose_nchw(
        inputs[0], inputs[1], strides, padding, out_dtype)
    output_padding = get_const_tuple(attrs.output_padding)
    out = topi.nn.pad(out,
                      [0, 0, 0, 0], [0, 0, output_padding[0], output_padding[1]])
    return [out]


@reg.register_compute("nn.conv3d")
def compute_conv3d(attrs, inputs, out_type, target):
    """Compute definition of conv3d"""
    padding = get_const_tuple(attrs.padding)
    strides = get_const_tuple(attrs.strides)
    dilation = get_const_tuple(attrs.dilation)
    groups = attrs.groups
    layout = attrs.data_layout
    out_dtype = attrs.out_dtype
    out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
                 else out_dtype)

    assert layout in ["NCDHW", "NDHWC"]
    (dilation_d, dilation_h, dilation_w) = dilation
    if dilation_d < 1 or dilation_h < 1 or dilation_w < 1:
        raise ValueError("dilation should be positive value")

    if groups == 1:
        out = topi.nn.conv3d(
            inputs[0], inputs[1], strides, padding,
            dilation, layout, out_dtype)
    else:
        raise ValueError("not support arbitrary group number for now")
    return [out]


@reg.register_schedule("nn.conv3d")
def schedule_conv3d(attrs, outs, target):
    """Schedule definition of conv3d"""
    groups = attrs.groups
    layout = attrs.data_layout

    with target:
        if groups == 1 and layout == "NCDHW":
            return topi.generic.schedule_conv3d_ncdhw(outs)
        elif groups == 1 and layout == "NDHWC":
            return topi.generic.schedule_conv3d_ndhwc(outs)

    raise ValueError("No compatible schedule")


reg.register_pattern("nn.conv3d", OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_schedule("nn.conv2d_transpose")
def schedule_conv2d_transpose(attrs, outs, target):
    """Schedule definition of conv2d_transpose"""
    with target:
        return topi.generic.schedule_conv2d_transpose_nchw(outs)


@reg.register_legalize("nn.conv2d_transpose")
def legalize_conv2d_transpose(attrs, inputs, types):
    """Legalize conv2d_transpose op.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current Transposed convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
    return topi.nn.conv2d_transpose_legalize(attrs, inputs, types)

reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)

# conv1d_transpose
@reg.register_compute("nn.conv1d_transpose")
def compute_conv1d_transpose(attrs, inputs, out_dtype, target):
    """Compute definition of conv1d_transpose"""
    padding = get_const_tuple(attrs.padding)
    strides = get_const_tuple(attrs.strides)
    dilation = get_const_tuple(attrs.dilation)
    groups = attrs.groups
    layout = attrs.data_layout
    out_dtype = attrs.out_dtype
    out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
                 else out_dtype)
    assert layout == "NCW", "conv1d_transpose ncw only supported"
    assert dilation == (1,), "conv1d_transpose dilation is not supported"
    assert groups == 1, "conv1d_transpose groups == 1 only supported"
    out = topi.nn.conv1d_transpose_ncw(
        inputs[0], inputs[1], strides, padding, out_dtype)
    output_padding = get_const_tuple(attrs.output_padding)
    out = topi.nn.pad(out,
                      [0, 0, 0], [0, 0, output_padding[0]])
    return [out]


@reg.register_schedule("nn.conv1d_transpose")
def schedule_conv1d_transpose(attrs, outs, target):
    """Schedule definition of conv1d_transpose"""
    with target:
        return topi.generic.schedule_conv1d_transpose_ncw(outs)

reg.register_pattern("nn.conv1d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)

# bias_add
reg.register_schedule("nn.bias_add", schedule_injective)
reg.register_pattern("nn.bias_add", OpPattern.BROADCAST)


# max_pool1d
@reg.register_schedule("nn.max_pool1d")
def schedule_max_pool1d(attrs, outs, target):
    """Schedule definition of max_pool1d"""
    layout = attrs.layout
    with target:
        return topi.generic.schedule_pool(outs, layout)


reg.register_pattern("nn.max_pool1d", OpPattern.OUT_ELEMWISE_FUSABLE)


# max_pool2d
@reg.register_schedule("nn.max_pool2d")
def schedule_max_pool2d(attrs, outs, target):
    """Schedule definition of max_pool2d"""
    layout = attrs.layout
    with target:
        return topi.generic.schedule_pool(outs, layout)


reg.register_pattern("nn.max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# max_pool3d
@reg.register_schedule("nn.max_pool3d")
def schedule_max_pool3d(attrs, outs, target):
    """Schedule definition of max_pool3d"""
    layout = attrs.layout
    with target:
        return topi.generic.schedule_pool(outs, layout)


reg.register_pattern("nn.max_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)


# avg_pool1d
@reg.register_schedule("nn.avg_pool1d")
def schedule_avg_pool1d(attrs, outs, target):
    """Schedule definition of avg_pool1d"""
    layout = attrs.layout
    with target:
        return topi.generic.schedule_pool(outs, layout)


reg.register_pattern("nn.avg_pool1d", OpPattern.OUT_ELEMWISE_FUSABLE)


# avg_pool2d
@reg.register_schedule("nn.avg_pool2d")
def schedule_avg_pool2d(attrs, outs, target):
    """Schedule definition of avg_pool2d"""
    layout = attrs.layout
    with target:
        return topi.generic.schedule_pool(outs, layout)

reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# avg_pool3d
@reg.register_schedule("nn.avg_pool3d")
def schedule_avg_pool3d(attrs, outs, target):
    """Schedule definition of avg_pool3d"""
    layout = attrs.layout
    with target:
        return topi.generic.schedule_pool(outs, layout)


reg.register_pattern("nn.avg_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)


# max_pool2d_grad
@reg.register_schedule("nn.max_pool2d_grad")
def schedule_max_pool2d_grad(attrs, outs, target):
    """Schedule definition of max_pool2d_grad"""
    with target:
        return topi.generic.schedule_pool_grad(outs)


reg.register_pattern("nn.max_pool2d_grad", OpPattern.OUT_ELEMWISE_FUSABLE)


# avg_pool2d_grad
@reg.register_schedule("nn.avg_pool2d_grad")
def schedule_avg_pool2d_grad(attrs, outs, target):
    """Schedule definition of avg_pool2d_grad"""
    with target:
        return topi.generic.schedule_pool_grad(outs)


reg.register_pattern("nn.avg_pool2d_grad", OpPattern.OUT_ELEMWISE_FUSABLE)


# global_max_pool2d
@reg.register_schedule("nn.global_max_pool2d")
def schedule_global_max_pool2d(_, outs, target):
    """Schedule definition of global_max_pool2d"""
    with target:
        return topi.generic.schedule_adaptive_pool(outs)


reg.register_pattern("nn.global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# global_avg_pool2d
@reg.register_schedule("nn.global_avg_pool2d")
def schedule_global_avg_pool2d(_, outs, target):
    """Schedule definition of global_avg_pool2d"""
    with target:
        return topi.generic.schedule_adaptive_pool(outs)


reg.register_pattern("nn.global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# leaky_relu
reg.register_schedule("nn.leaky_relu", schedule_broadcast)
reg.register_pattern("nn.leaky_relu", OpPattern.ELEMWISE)

# prelu
reg.register_schedule("nn.prelu", schedule_broadcast)
reg.register_pattern("nn.prelu", OpPattern.BROADCAST)

# flatten
reg.register_schedule("nn.batch_flatten", schedule_broadcast)
reg.register_pattern("nn.batch_flatten", OpPattern.INJECTIVE)


# lrn
@reg.register_compute("nn.lrn")
def compute_lrn(attrs, inputs, out_dtype, target):
    """Compute definition of lrn"""
    assert len(inputs) == 1
    return [topi.nn.lrn(inputs[0], attrs.size, attrs.axis,
                        attrs.alpha, attrs.beta, attrs.bias)]


@reg.register_schedule("nn.lrn")
def schedule_lrn(attrs, outs, target):
    """Schedule definition of lrn"""
    with target:
        return topi.generic.schedule_lrn(outs)


reg.register_pattern("nn.lrn", OpPattern.OPAQUE)


# upsampling
reg.register_schedule("nn.upsampling", reg.schedule_injective)


def schedule_upsampling(_, outs, target):
    """Schedule definition of upsampling"""
    with target:
        return topi.generic.schedule_injective(outs)

@reg.register_compute("nn.upsampling")
def compute_upsampling(attrs, inputs, out_dtype, target):
    """Compute definition of upsampling"""
    scale_h = attrs.scale_h
    scale_w = attrs.scale_w
    layout = attrs.layout
    method = attrs.method
    align_corners = attrs.align_corners
    return [topi.nn.upsampling(inputs[0], scale_h, scale_w, layout, method, align_corners)]

# upsampling3d
reg.register_schedule("nn.upsampling3d", reg.schedule_injective)

def schedule_upsampling3d(_, outs, target):
    """Schedule definition of upsampling3d"""
    with target:
        return topi.generic.schedule_injective(outs)

@reg.register_compute("nn.upsampling3d")
def compute_upsampling3d(attrs, inputs, out_dtype, target):
    """Compute definition of upsampling3d"""
    scale_d = attrs.scale_d
    scale_h = attrs.scale_h
    scale_w = attrs.scale_w
    layout = attrs.layout
    method = attrs.method
    coordinate_transformation_mode = attrs.coordinate_transformation_mode
    return [topi.nn.upsampling3d(inputs[0], scale_d, scale_h, scale_w, layout,
                                 method, coordinate_transformation_mode)]

# pad
reg.register_schedule("nn.pad", schedule_broadcast)

# mirror_pad
reg.register_schedule("nn.mirror_pad", schedule_broadcast)

@reg.register_compute("nn.mirror_pad")
def compute_mirror_pad(attrs, inputs, out_dtype, target):
    """Compute definition of mirror_pad"""
    pad_before, pad_after = list(zip(*attrs.pad_width))
    mode = attrs.mode
    out = topi.nn.mirror_pad(inputs[0], pad_before=pad_before, pad_after=pad_after, mode=mode)
    return [out]

# winograd related operators
@reg.register_compute("nn.contrib_conv2d_winograd_without_weight_transform")
def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, out_dtype, target):
    """Compute definition of conv2d_winograd_without_weight_transform"""
    # pylint: disable=assignment-from-no-return
    padding = attrs.get_int_tuple("padding")
    strides = attrs.get_int_tuple("strides")
    dilation = attrs.get_int_tuple("dilation")
    groups = attrs.get_int("groups")
    data_layout = attrs.get_str("data_layout")
    out_dtype = attrs.get_str("out_dtype")
    tile_size = attrs.get_int("tile_size")
    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
    assert dilation == (1, 1), "Do not support dilation for now"
    assert groups == 1, "Do not support arbitrary group number"

    out = topi.nn.conv2d_winograd_without_weight_transform(
        inputs[0], inputs[1], strides, padding, dilation, data_layout,
        out_dtype, tile_size)

    return [out]


@reg.register_schedule("nn.contrib_conv2d_winograd_without_weight_transform")
def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, target):
    """Schedule definition of conv2d_winograd_without_weight_transform"""
    with target:
        return topi.generic.schedule_conv2d_winograd_without_weight_transform(outs)


reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform",
                     OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("nn.contrib_conv2d_winograd_weight_transform")
def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype, target):
    """Compute definition of contrib_conv2d_winograd_weight_transform"""
    out = topi.nn.conv2d_winograd_weight_transform(
        inputs[0], attrs.get_int('tile_size'))
    return [out]


@reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform")
def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target):
    """Schedule definition of contrib_conv2d_winograd_weight_transform"""
    with target:
        return topi.generic.schedule_conv2d_winograd_weight_transform(outs)


reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform",
                     OpPattern.OUT_ELEMWISE_FUSABLE)


# winograd nnpack related operators
@reg.register_compute("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
def compute_contrib_conv2d_winograd_nnpack_without_weight_transform(
        attrs, inputs, out_dtype, target):
    """Compute definition of conv2d_winograd_nnpack_without_weight_transform"""
    # pylint: disable=assignment-from-no-return
    padding = attrs.get_int_tuple("padding")
    strides = attrs.get_int_tuple("strides")
    dilation = attrs.get_int_tuple("dilation")
    groups = attrs.get_int("groups")
    data_layout = attrs.get_str("data_layout")
    out_dtype = attrs.get_str("out_dtype")
    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
    assert dilation == (1, 1), "Do not support dilation for now"
    assert groups == 1, "Do not support arbitrary group number"

    # No bias
    out = topi.nn.conv2d_winograd_nnpack_without_weight_transform(
        inputs[0], inputs[1], None, strides, padding, dilation, data_layout,
        out_dtype)

    return [out]


@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, outs, target):
    """Schedule definition of conv2d_winograd_nnpack_without_weight_transform"""
    with target:
        return topi.generic.schedule_conv2d_winograd_nnpack_without_weight_transform(outs)


reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_without_weight_transform",
                     OpPattern.OPAQUE)


@reg.register_compute("nn.contrib_conv2d_winograd_nnpack_weight_transform")
def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_dtype, target):
    """Compute definition of contrib_conv2d_winograd_nnpack_weight_transform"""
    convolution_algorithm = attrs.get_int('convolution_algorithm')
    out = topi.nn.conv2d_winograd_nnpack_weight_transform(
        inputs[0], convolution_algorithm, out_dtype)
    return [out]


@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
    """Schedule definition of contrib_conv2d_winograd_nnpack_weight_transform"""
    with target:
        return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)


reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_weight_transform",
                     OpPattern.OPAQUE)


@reg.register_compute("nn.contrib_conv2d_NCHWc")
def compute_contrib_conv2d_NCHWc(attrs, inputs, out_dtype, target):
    """Compute definition of conv2d NCHWc"""
    # pylint: disable=assignment-from-no-return
    padding = attrs.get_int_tuple("padding")
    strides = attrs.get_int_tuple("strides")
    dilation = attrs.get_int_tuple("dilation")
    data_layout = attrs.get_str("data_layout")
    out_layout = attrs.get_str("out_layout")
    out_dtype = attrs.get_str("out_dtype")
    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype

    out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation,
                               data_layout, out_layout, out_dtype)
    return [out]


@reg.register_schedule("nn.contrib_conv2d_NCHWc")
def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
    """Schedule definition of contrib_conv2d_NCHWc"""
    with target:
        return topi.generic.schedule_conv2d_NCHWc(outs)


reg.register_pattern("nn.contrib_conv2d_NCHWc",
                     OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("nn.contrib_conv2d_NCHWc_int8")
def compute_contrib_conv2d_NCHWc_int8(attrs, inputs, out_dtype, target):
    """Compute definition of conv2d NCHWc"""
    # pylint: disable=assignment-from-no-return
    padding = attrs.get_int_tuple("padding")
    strides = attrs.get_int_tuple("strides")
    dilation = attrs.get_int_tuple("dilation")
    data_layout = attrs.get_str("data_layout")
    out_layout = attrs.get_str("out_layout")
    out_dtype = attrs.get_str("out_dtype")
    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype

    out = topi.nn.conv2d_NCHWc_int8(inputs[0], inputs[1], strides, padding, dilation,
                                    data_layout, out_layout, out_dtype)
    return [out]


@reg.register_schedule("nn.contrib_conv2d_NCHWc_int8")
def schedule_contrib_conv2d_NCHWc_int8(attrs, outs, target):
    """Schedule definition of contrib_conv2d_NCHWc_int8"""
    with target:
        return topi.generic.schedule_conv2d_NCHWc_int8(outs)


reg.register_pattern("nn.contrib_conv2d_NCHWc_int8",
                     OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("nn.contrib_depthwise_conv2d_NCHWc")
def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target):
    """Compute definition of depthwise conv2d NCHWc"""
    # pylint: disable=assignment-from-no-return
    padding = attrs.get_int_tuple("padding")
    strides = attrs.get_int_tuple("strides")
    dilation = attrs.get_int_tuple("dilation")
    data_layout = attrs.get_str("data_layout")
    out_layout = attrs.get_str("out_layout")
    out_dtype = attrs.get_str("out_dtype")
    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype

    out = topi.nn.depthwise_conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation,
                                         data_layout, out_layout, out_dtype)
    return [out]


@reg.register_schedule("nn.contrib_depthwise_conv2d_NCHWc")
def schedule_contrib_depthwise_conv2d_NCHWc(attrs, outs, target):
    """Schedule definition of contrib_conv2d_NCHWc"""
    with target:
        return topi.generic.schedule_depthwise_conv2d_NCHWc(outs)


reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc",
                     OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("nn.deformable_conv2d")
def compute_deformable_conv2d(attrs, inputs, out_dtype, target):
    """Compute definition of deformable_conv2d"""
    padding = get_const_tuple(attrs.padding)
    strides = get_const_tuple(attrs.strides)
    dilation = get_const_tuple(attrs.dilation)
    deformable_groups = attrs.deformable_groups
    groups = attrs.groups
    out_dtype = attrs.out_dtype
    out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
    with target:
        out = topi.nn.deformable_conv2d_nchw(inputs[0], inputs[1], inputs[2], strides, padding,
                                             dilation, deformable_groups, groups, out_dtype)
    return [out]


@reg.register_schedule("nn.deformable_conv2d")
def schedule_deformable_conv2d(attrs, outs, target):
    """Schedule definition of deformable_conv2d"""
    with target:
        return topi.generic.schedule_deformable_conv2d_nchw(outs)


reg.register_pattern("nn.deformable_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("nn.bitpack")
def compute_bitpack(attrs, inputs, out_dtype, target):
    """Compute definition for bitpack"""
    bits = attrs.bits
    pack_axis = attrs.pack_axis
    bit_axis = attrs.bit_axis
    pack_type = attrs.pack_type
    name = attrs.name
    with target:
        out = topi.nn.bitpack(inputs[0], bits, pack_axis, bit_axis, pack_type,
                              name)
    return [out]

@reg.register_schedule("nn.bitpack")
def schedule_bitpack(attrs, outs, target):
    """Schedule definition for bitpack"""
    with target:
        return topi.generic.schedule_bitpack(outs)

reg.register_pattern("nn.bitpack", OpPattern.INJECTIVE)


@reg.register_compute("nn.bitserial_conv2d")
def compute_bitserial_conv2d(attrs, inputs, out_dtype, target):
    """Compute definition for bitserial conv2d."""
    padding = get_const_tuple(attrs.padding)
    strides = get_const_tuple(attrs.strides)
    activation_bits = attrs.activation_bits
    weight_bits = attrs.weight_bits
    layout = attrs.data_layout
    pack_dtype = attrs.pack_dtype
    out_dtype = attrs.out_dtype
    unipolar = attrs.unipolar
    if layout == 'NCHW':
        with target:
            out = topi.nn.bitserial_conv2d_nchw(
                inputs[0], inputs[1], strides, padding, activation_bits,
                weight_bits, pack_dtype, out_dtype, unipolar)
    elif layout == 'NHWC':
        with target:
            out = topi.nn.bitserial_conv2d_nhwc(
                inputs[0], inputs[1], strides, padding, activation_bits,
                weight_bits, pack_dtype, out_dtype, unipolar)
    else:
        raise ValueError("Data layout not supported.")

    return [out]


@reg.register_schedule("nn.bitserial_conv2d")
def schedule_bitserial_conv2d(attrs, outs, target):
    """Schedule definition for bitserial conv2d."""
    layout = attrs.data_layout
    if layout == 'NCHW':
        with target:
            return topi.generic.schedule_bitserial_conv2d_nchw(outs)
    elif layout == 'NHWC':
        with target:
            return topi.generic.schedule_bitserial_conv2d_nhwc(outs)
    else:
        raise ValueError("Data layout not supported.")

@reg.register_legalize("nn.bitserial_conv2d")
def legalize_bitserial_conv2d(attrs, inputs, types):
    """Legalize bitserial_conv2d op.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
    return topi.nn.bitserial_conv2d_legalize(attrs, inputs, types)


reg.register_pattern("nn.bitserial_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# bitserial_dense
@reg.register_compute("nn.bitserial_dense")
def compute_bitserial_dense(attrs, inputs, out_type, target):
    """Compute definition of bitserial_dense"""
    data_bits = attrs.data_bits
    weight_bits = attrs.weight_bits
    pack_dtype = attrs.pack_dtype
    out_dtype = attrs.out_dtype
    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
    unipolar = attrs.unipolar
    return [
        topi.nn.bitserial_dense(
            inputs[0],
            inputs[1],
            data_bits,
            weight_bits,
            pack_dtype,
            out_dtype,
            unipolar)
    ]


@reg.register_schedule("nn.bitserial_dense")
def schedule_bitserial_dense(attrs, outputs, target):
    """Schedule definition of bitserial_dense"""
    with target:
        return topi.generic.schedule_bitserial_dense(outputs)


reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE)

@reg.register_compute("nn.cross_entropy")
def compute_cross_entropy(attrs, inputs, out_dtype, target):
    """Compute definition of cross_entropy"""
    x, y = inputs
    return [-topi.sum(topi.log(x) * y) / x.shape[0]]
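
# i.e. the mean cross entropy over the batch:
#   loss = -sum_ij(y[i, j] * log(x[i, j])) / batch_size
# where x holds predicted probabilities and y the target distribution.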


reg.register_pattern("nn.cross_entropy_with_logits", OpPattern.OPAQUE)

@reg.register_compute("nn.cross_entropy_with_logits")
def compute_cross_entropy_with_logits(attrs, inputs, out_dtype, target):
    """Compute definition of cross_entropy_with_logits"""
    x, y = inputs
    return [-topi.sum(x * y) / x.shape[0]]
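
# Same reduction as above, but the input x is taken to already be in log space,
# so no log is applied:  loss = -sum(x * y) / batch_size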


@reg.register_compute("nn.depth_to_space")
def compute_depth_to_space(attrs, inputs, out_dtype, target):
    """Compute definition of depth_to_space"""
    block_size = attrs.block_size
    layout = attrs.layout
    mode = attrs.mode
    return [topi.nn.depth_to_space(inputs[0], block_size, layout=layout, mode=mode)]

reg.register_schedule("nn.depth_to_space", schedule_injective)
reg.register_pattern("nn.depth_to_space", OpPattern.INJECTIVE)


@reg.register_compute("nn.space_to_depth")
def compute_space_to_depth(attrs, inputs, out_dtype, target):
    """Compute definition of space_to_depth"""
    block_size = attrs.block_size
    layout = attrs.layout
    return [topi.nn.space_to_depth(inputs[0], block_size, layout=layout)]

reg.register_schedule("nn.space_to_depth", schedule_injective)
reg.register_pattern("nn.space_to_depth", OpPattern.INJECTIVE)


# shape func
@script
def _conv2d_NCHWc_shape_func(dshape, kshape, strides, padding, dilation, oc_bn):
    out = output_tensor((dshape.shape[0],), "int64")
    ic_chunk = dshape[1]
    height = dshape[2]
    width = dshape[3]
    ic_bn = dshape[4]
    kheight = kshape[2]
    kwidth = kshape[3]
    dilated_kh = (kheight - 1) * dilation[0] + 1
    dilated_kw = (kwidth - 1) * dilation[1] + 1
    kflatten = int64(1)
    for i in const_range(kshape.shape[0]):
        kflatten *= kshape[i]

    oc = kflatten // (kheight * kwidth * ic_chunk * ic_bn)
    oc_chunk = oc // oc_bn

    out_height = (height + 2 * padding[0] - dilated_kh) // strides[0] + 1
    out_width = (width + 2 * padding[1] - dilated_kw) // strides[1] + 1

    out[0] = dshape[0]
    out[1] = oc_chunk
    out[2] = out_height
    out[3] = out_width
    out[4] = int64(oc_bn)
    return out

@reg.register_shape_func("nn.contrib_conv2d_NCHWc", False)
def conv2d_NCHWc_shape_func(attrs, inputs, _):
    """
    Shape function for contrib_conv2d_NCHWc op.
    """
    strides = get_const_tuple(attrs.strides)
    padding = get_const_tuple(attrs.padding)
    dilation = get_const_tuple(attrs.dilation)
    out_layout = attrs.out_layout
    oc_bn = int(out_layout[4:-1])

    return [_conv2d_NCHWc_shape_func(inputs[0], inputs[1],
                                     convert(strides), convert(padding),
                                     convert(dilation), convert(oc_bn))]
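
# For example, out_layout "NCHW8c" gives oc_bn = int("NCHW8c"[4:-1]) = 8, so the
# inferred output keeps the 8-channel inner block as its last axis.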

@script
def _pool2d_shape_func(data_shape, pool_size, strides,
                       padding, height_axis, width_axis):
    out = output_tensor((data_shape.shape[0],), "int64")
    for i in const_range(data_shape.shape[0]):
        if i == height_axis:
            out[i] = (data_shape[i] + padding[0] + padding[2] - pool_size[0]) // strides[0] + 1
        elif i == width_axis:
            out[i] = (data_shape[i] + padding[1] + padding[3] - pool_size[1]) // strides[1] + 1
        else:
            out[i] = data_shape[i]

    return out

def pool2d_shape_func(attrs, inputs, _):
    """
    Shape function for pool2d op.
    """
    pool_size = get_const_tuple(attrs.pool_size)
    strides = get_const_tuple(attrs.strides)
    padding = get_const_tuple(attrs.padding)
    layout = attrs.layout
    height_axis = layout.index("H")
    width_axis = layout.index("W")
    if len(padding) == 1:
        padding = [padding[0]] * 4
    elif len(padding) == 2:
        padding = [padding[0], padding[1], padding[0], padding[1]]

    return [_pool2d_shape_func(inputs[0], convert(pool_size),
                               convert(strides), convert(padding),
                               convert(height_axis), convert(width_axis))]

reg.register_shape_func("nn.max_pool2d", False, pool2d_shape_func)
reg.register_shape_func("nn.avg_pool2d", False, pool2d_shape_func)

@script
def _global_pool2d_shape_func(data_shape, height_axis, width_axis):
    out = output_tensor((data_shape.shape[0],), "int64")
    for i in const_range(out.shape[0]):
        if i == height_axis or i == width_axis:
            out[i] = int64(1)
        else:
            out[i] = data_shape[i]

    return out

def global_pool2d_shape_func(attrs, inputs, _):
    """
    Shape function for global pool2d op.
    """
    layout = attrs.layout
    height_axis = width_axis = 1
    for i, letter in enumerate(layout):
        if letter == "H":
            height_axis = i
        if letter == "W":
            width_axis = i
    return [_global_pool2d_shape_func(inputs[0], convert(height_axis), convert(width_axis))]

reg.register_shape_func("nn.global_max_pool2d", False, global_pool2d_shape_func)
reg.register_shape_func("nn.global_avg_pool2d", False, global_pool2d_shape_func)

@script
def _batch_flatten_shape_func(data_shape):
    out = output_tensor((2,), "int64")
    out[0] = data_shape[0]
    out[1] = int64(1)
    for i in const_range(data_shape.shape[0] - 1):
        out[1] *= data_shape[i + 1]

    return out

@reg.register_shape_func("nn.batch_flatten", False)
def batch_flatten_shape_func(attrs, inputs, _):
    """
    Shape function for batch_flatten op.
    """
    return [_batch_flatten_shape_func(inputs[0])]
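
# For example, an input of shape (2, 3, 4, 5) flattens to (2, 60).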

@script
def _dense_shape_func(data_shape, weight_shape):
    out = output_tensor((data_shape.shape[0],), "int64")
    for i in const_range(out.shape[0] - 1):
        out[i] = data_shape[i]
    out[out.shape[0] - 1] = weight_shape[0]

    return out

@reg.register_shape_func("nn.dense", False)
def dense_shape_func(attrs, inputs, _):
    """
    Shape function for dense op.
    """
    ret = [_dense_shape_func(inputs[0], inputs[1])]
    return ret
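
# For example, data of shape (8, 16) with weight of shape (4, 16) (units, in_dim)
# infers an output shape of (8, 4): all leading data dims are kept and the last
# dim becomes weight_shape[0].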

@script
def _pad_shape_func(data_shape, pad_width):
    out = output_tensor((data_shape.shape[0],), "int64")
    for i in const_range(out.shape[0]):
        out[i] = data_shape[i] + pad_width[i][0] + pad_width[i][1]

    return out

@reg.register_shape_func("nn.pad", False)
def pad_shape_func(attrs, inputs, _):
    """
    Shape function for pad op.
    """
    pad_width = []
    for pair in attrs.pad_width:
        pad_width.append(get_const_tuple(pair))
    return [_pad_shape_func(inputs[0], convert(pad_width))]
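
# For example, pad_width=((0, 0), (1, 1), (2, 2)) applied to data of shape
# (1, 3, 5) infers an output shape of (1, 5, 9).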

reg.register_shape_func("nn.bias_add", False, elemwise_shape_func)
reg.register_shape_func("nn.softmax", False, elemwise_shape_func)
reg.register_shape_func("nn.relu", False, elemwise_shape_func)