nn.cc 35.6 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
9
 *
10
 *   http://www.apache.org/licenses/LICENSE-2.0
11
 *
12 13 14 15 16 17 18 19
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24
/*!
 * \file nn.cc
 * \brief Property def of nn operators.
 */

25
#include <tvm/data_layout.h>
26 27
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
28
#include <tvm/relay/attrs/image.h>
29
#include <topi/nn.h>
30
#include <topi/nn/bias_add.h>
31
#include <topi/nn/softmax.h>
32
#include <topi/nn/flatten.h>
33
#include <vector>
34
#include <string>
35
#include "../type_relations.h"
36
#include "../../pass/infer_layout_util.h"
雾雨魔理沙 committed
37
#include "../op_common.h"
38
#include "nn.h"
39 40 41 42

namespace tvm {
namespace relay {

43
// relay.nn.bias_add
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
TVM_REGISTER_NODE_TYPE(BiasAddAttrs);

/*!
 * \brief Type relation for nn.bias_add.
 *
 * types holds [data, bias, output]. Once data is a known tensor type, bias is
 * assigned a 1-D tensor type whose extent is data->shape[axis] and the output
 * is assigned the type of data.
 *
 * \param types The input and output types of the call (exactly three).
 * \param num_inputs Number of inputs (not used here).
 * \param attrs Operator attributes; must be BiasAddAttrs.
 * \param reporter Reporter used to assign the inferred types.
 * \return false while the data type is unresolved, true otherwise.
 */
bool BiasAddRel(const Array<Type>& types,
                int num_inputs,
                const Attrs& attrs,
                const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  const BiasAddAttrs* param = attrs.as<BiasAddAttrs>();
  CHECK(param != nullptr);
  const int ndim = static_cast<int>(data->shape.size());
  // A negative axis counts back from the last dimension.
  const int axis = param->axis < 0 ? param->axis + ndim : param->axis;
  // axis indexes data->shape below, so it has to lie in [0, ndim).
  CHECK(axis >= 0 && axis < ndim)
      << "axis " << param->axis << " is out of range";

  // assign bias type (1-D, length of the selected axis) and output type
  reporter->Assign(types[1], TensorTypeNode::make(
      {data->shape[axis]}, data->dtype));
  reporter->Assign(types[2], types[0]);
  return true;
}


// Positional relay function to create the bias_add operator used by frontend FFI.
// Stores `axis` in a fresh BiasAddAttrs node and wraps (data, bias) in a call
// to the registered "nn.bias_add" op.
Expr MakeBiasAdd(Expr data,
                 Expr bias,
                 int axis) {
  static const Op& bias_add_op = Op::Get("nn.bias_add");
  auto bias_attrs = make_node<BiasAddAttrs>();
  bias_attrs->axis = axis;
  return CallNode::make(bias_add_op, {data, bias}, Attrs(bias_attrs), {});
}


// Expose MakeBiasAdd to the frontend FFI.
TVM_REGISTER_API("relay.op.nn._make.bias_add")
.set_body_typed(MakeBiasAdd);


// Operator registration: description, attribute type, arguments, type
// relation and the compute rule that forwards to topi::nn::bias_add.
RELAY_REGISTER_OP("nn.bias_add")
.describe(R"code(Add bias to an axis of the input.

)code" TVM_ADD_FILELINE)
.set_attrs_type<BiasAddAttrs>()
.set_num_inputs(2)
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("bias", "1D Tensor", "Bias.")
.set_support_level(1)
.add_type_rel("BiasAdd", BiasAddRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<Tensor>& inputs,
                                        const Type& out_type, const Target& target) {
    const auto* param = attrs.as<BiasAddAttrs>();
    // Fail loudly instead of dereferencing a null attrs node.
    CHECK(param != nullptr);
    return tvm::Array<tvm::Tensor>{topi::nn::bias_add(inputs[0], inputs[1], param->axis)};
});
101 102


103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160
// relay.nn.fifo_buffer
TVM_REGISTER_NODE_TYPE(FIFOBufferAttrs);

// Positional constructor for nn.fifo_buffer: records `axis` in a
// FIFOBufferAttrs node and builds the call on (input, buffer).
Expr MakeFIFOBuffer(Expr input, Expr buffer, int axis) {
  static const Op& fifo_op = Op::Get("nn.fifo_buffer");
  auto fifo_attrs = make_node<FIFOBufferAttrs>();
  fifo_attrs->axis = axis;
  return CallNode::make(fifo_op, {input, buffer}, Attrs(fifo_attrs), {});
}

/*!
 * \brief Type relation for nn.fifo_buffer.
 *
 * types holds [input, buffer, output]. Input and buffer must have the same
 * rank; every dimension except `axis` is asserted equal through the reporter,
 * and the output is assigned the shape and dtype of the buffer.
 *
 * \param types The input and output types of the call (exactly three).
 * \param num_inputs Number of inputs (not used here).
 * \param attrs Operator attributes; must be FIFOBufferAttrs.
 * \param reporter Reporter used for shape assertions and type assignment.
 * \return false while input or buffer is unresolved, true otherwise.
 */
bool FIFOBufferRel(const Array<Type>& types,
                   int num_inputs,
                   const Attrs& attrs,
                   const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* input = types[0].as<TensorTypeNode>();
  const auto* buffer = types[1].as<TensorTypeNode>();
  const FIFOBufferAttrs* param = attrs.as<FIFOBufferAttrs>();
  if (input == nullptr || buffer == nullptr) {
    return false;
  }
  CHECK(param != nullptr);
  CHECK_EQ(input->shape.size(), buffer->shape.size());

  const int ndim = static_cast<int>(buffer->shape.size());
  // A negative axis counts back from the last dimension.
  const int axis = param->axis < 0 ? param->axis + ndim : param->axis;
  // The axis indexes both shapes below, so reject anything outside [0, ndim)
  // up front rather than relying on a reporter assertion whose result would
  // not stop the indexing.
  CHECK(axis >= 0 && axis < ndim)
      << "axis " << param->axis << " is out of range";
  const size_t buffer_axis = static_cast<size_t>(axis);

  for (size_t i = 0; i < buffer->shape.size(); ++i) {
    if (i != buffer_axis) {
      reporter->AssertEQ(input->shape[i], buffer->shape[i]);
    }
  }
  reporter->Assert(input->shape[buffer_axis] < buffer->shape[buffer_axis]);

  Array<tvm::Expr> oshape = buffer->shape;

  reporter->Assign(types[2], TensorTypeNode::make(oshape, buffer->dtype));
  return true;
}

// Expose MakeFIFOBuffer to the frontend FFI.
TVM_REGISTER_API("relay.op.nn._make.fifo_buffer")
.set_body_typed(MakeFIFOBuffer);

// Operator registration for nn.fifo_buffer (no compute rule is attached here).
RELAY_REGISTER_OP("nn.fifo_buffer")
.describe(R"code(FIFO buffer
Compute equivalent of

```
concat(buffer, data, axis=axis) \
.slice_axis(axis=axis, begin=data.shape[axis], end=data.shape[axis]+buffer.shape[axis])
```

Useful for
* Encoding explicit re-use of computation in convolution ops operated on a sliding window input
* Implementing a FIFO queue to cache intermediate results, e.g. as in Fast WaveNet.
)code" TVM_ADD_FILELINE)
.set_attrs_type<FIFOBufferAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "Latest input")
.add_argument("buffer", "Tensor", "Buffer storing latest [length_buffer] inputs")
.set_support_level(3)
.add_type_rel("FIFOBuffer", FIFOBufferRel);


170
// relay.nn.dense
171 172 173 174 175
TVM_REGISTER_NODE_TYPE(DenseAttrs);

// Positional relay function to create dense operator used by frontend FFI.
// `units` and `out_dtype` are stored on a DenseAttrs node attached to the call.
Expr MakeDense(Expr data,
               Expr weight,
               IndexExpr units,
               DataType out_dtype) {
  static const Op& dense_op = Op::Get("nn.dense");
  auto dense_attrs = make_node<DenseAttrs>();
  dense_attrs->units = units;
  dense_attrs->out_dtype = out_dtype;
  return CallNode::make(dense_op, {data, weight}, Attrs(dense_attrs), {});
}


// Expose MakeDense to the frontend FFI.
TVM_REGISTER_API("relay.op.nn._make.dense")
.set_body_typed(MakeDense);


// Operator registration for nn.dense; typing is delegated to DenseRel.
RELAY_REGISTER_OP("nn.dense")
.describe(R"code(Applies a linear transformation: :math:`Y = XW^T`.

- **data**: `(x1, x2, ..., xn, input_dim)`
- **weight**: `(units, input_dim)`
- **out**: `(x1, x2, ..., xn, units)`.

)code" TVM_ADD_FILELINE)
.set_attrs_type<DenseAttrs>()
.set_num_inputs(2)
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("weight", "2D Tensor", "Weight matrix.")
.set_support_level(1)
.add_type_rel("Dense", DenseRel<DenseAttrs>);
204

205 206
// relay.leaky_relu
TVM_REGISTER_NODE_TYPE(LeakyReluAttrs);
207 208 209 210 211 212 213 214 215 216 217 218

// Positional relay function to create leaky relu operator used by frontend FFI.
// `alpha` is stored on a LeakyReluAttrs node attached to the call.
Expr MakeLeakyRelu(Expr data,
                   double alpha) {
  static const Op& leaky_relu_op = Op::Get("nn.leaky_relu");
  auto leaky_attrs = make_node<LeakyReluAttrs>();
  leaky_attrs->alpha = alpha;
  return CallNode::make(leaky_relu_op, {data}, Attrs(leaky_attrs), {});
}


// Expose MakeLeakyRelu to the frontend FFI.
TVM_REGISTER_API("relay.op.nn._make.leaky_relu")
.set_body_typed(MakeLeakyRelu);


// Operator registration for nn.leaky_relu; the compute rule forwards to
// topi::leaky_relu with the alpha stored in the attributes.
RELAY_REGISTER_OP("nn.leaky_relu")
.describe(R"code(Leaky version of a Rectified Linear Unit.

`y = x > 0 ? x : alpha * x`

)code" TVM_ADD_FILELINE)
.set_attrs_type<LeakyReluAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "Input data.")
.set_support_level(3)
.add_type_rel("Identity", IdentityRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const Attrs& attrs,
                    const Array<Tensor>& inputs,
                    const Type& out_type,
                    const Target& target) {
    const auto* param = attrs.as<LeakyReluAttrs>();
    // Fail loudly instead of dereferencing a null attrs node.
    CHECK(param != nullptr);
    return Array<Tensor>{ topi::leaky_relu(inputs[0], param->alpha) };
});
242 243


244
// relay.prelu
Siju committed
245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269
TVM_REGISTER_NODE_TYPE(PReluAttrs);

/*!
 * \brief Type relation for nn.prelu.
 *
 * types holds [data, alpha, output]. Once data is a known tensor type, alpha
 * is assigned a 1-D tensor type of extent data->shape[axis] and the output is
 * assigned the shape and dtype of data.
 *
 * \param types The input and output types of the call (exactly three).
 * \param num_inputs Number of inputs (not used here).
 * \param attrs Operator attributes; must be PReluAttrs.
 * \param reporter Reporter used to assign the inferred types.
 * \return false while the data type is unresolved, true otherwise.
 */
bool PReluRel(const Array<Type>& types,
              int num_inputs,
              const Attrs& attrs,
              const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  const PReluAttrs* param = attrs.as<PReluAttrs>();
  CHECK(param != nullptr);

  // axis indexes data->shape directly (no negative-axis wrapping is applied),
  // so it must lie in [0, ndim).
  const int ndim = static_cast<int>(data->shape.size());
  CHECK(param->axis >= 0 && param->axis < ndim)
    << "Wrong axis ("  << param->axis << ")value.";

  // assign alpha type
  Array<IndexExpr> alpha_shape({data->shape[param->axis]});
  reporter->Assign(types[1], TensorTypeNode::make(alpha_shape, data->dtype));

  // assign output type
  reporter->Assign(types[2], TensorTypeNode::make(data->shape, data->dtype));
  return true;
}

270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286
/*!
 * \brief Layout inference for nn.prelu.
 *
 * Keeps the original data layout for the first input and the output, and
 * reports layout "C" for the alpha input. New input layouts, when defined,
 * are only checked for arity.
 *
 * \tparam T Attribute type (not used by the body).
 * \return {{data_layout, "C"}, {data_layout}} as input and output layouts.
 */
template<typename T>
Array<Array<Layout> > PReluInferCorrectLayout(
    const Attrs& attrs,
    const Array<Layout>& new_in_layouts,
    const Array<Layout>& old_in_layouts,
    const Array<Array<IndexExpr>> &old_in_shapes) {
  CHECK_EQ(old_in_layouts.size(), 2U);
  CHECK_EQ(old_in_shapes.size(), 2U);
  if (new_in_layouts.defined()) {
    CHECK_EQ(new_in_layouts.size(), 2U);
  }

  const Layout data_layout = old_in_layouts[0];
  const Array<Layout> input_layouts{data_layout, Layout("C")};
  const Array<Layout> output_layouts{data_layout};
  return Array<Array<Layout> >{input_layouts, output_layouts};
}

Siju committed
287 288 289 290 291 292 293 294 295 296 297 298
// Positional relay function to create prelu operator used by frontend FFI.
// `axis` is stored on a PReluAttrs node attached to the call.
Expr MakePRelu(Expr data,
               Expr alpha,
               int axis) {
  static const Op& prelu_op = Op::Get("nn.prelu");
  auto prelu_attrs = make_node<PReluAttrs>();
  prelu_attrs->axis = axis;
  return CallNode::make(prelu_op, {data, alpha}, Attrs(prelu_attrs), {});
}


// Expose MakePRelu to the frontend FFI.
TVM_REGISTER_API("relay.op.nn._make.prelu")
.set_body_typed(MakePRelu);


// Operator registration for nn.prelu; the compute rule forwards to
// topi::prelu with the axis stored in the attributes.
RELAY_REGISTER_OP("nn.prelu")
.describe(R"code(Parametric version of a Rectified Linear Unit.
It accepts two arguments: an input ``x`` and a channelwise slope ``alpha``
and computes the output as :math:`PReLU(x) y = x > 0 ? x : alpha * x`,
where :math:`*` is an channelwise multiplication for each sample in the batch.
)code" TVM_ADD_FILELINE)
.set_attrs_type<PReluAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "Input data.")
.add_argument("alpha", "Tensor", "Input channelwise alpha.")
.set_support_level(3)
.add_type_rel("PRelu", PReluRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PReluInferCorrectLayout<PReluAttrs>)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const Attrs& attrs,
                    const Array<Tensor>& inputs,
                    const Type& out_type,
                    const Target& target) {
    const auto* param = attrs.as<PReluAttrs>();
    // Fail loudly instead of dereferencing a null attrs node.
    CHECK(param != nullptr);
    return Array<Tensor>{ topi::prelu(inputs[0], inputs[1], param->axis)};
});
Siju committed
323 324


325 326 327
// relay.softmax
TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);

328
// FFI constructor for nn.softmax: stores `axis` in a SoftmaxAttrs node and
// builds the call on `data`.
TVM_REGISTER_API("relay.op.nn._make.softmax")
.set_body_typed<Call(Expr, int)>([](Expr data, int axis) {
  static const Op& softmax_op = Op::Get("nn.softmax");
  auto softmax_attrs = make_node<SoftmaxAttrs>();
  softmax_attrs->axis = axis;
  return CallNode::make(softmax_op, {data}, Attrs(softmax_attrs), {});
});


// Operator registration for nn.softmax; the compute rule forwards to
// topi::nn::softmax along the configured axis.
RELAY_REGISTER_OP("nn.softmax")
    .describe(R"code(Softmax layer.

.. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}

.. note::
    This operator can be optimized away for inference.

- **data**: The input data
)code" TVM_ADD_FILELINE)
.set_attrs_type<SoftmaxAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
                                         const Array<Tensor>& inputs,
                                         const Type& out_type,
                                         const Target& target) {
  const auto* softmax_attrs = attrs.as<SoftmaxAttrs>();
  CHECK(softmax_attrs != nullptr);
  return Array<Tensor>{ topi::nn::softmax(inputs[0], softmax_attrs->axis) };
});
360

361

362
// relay.nn.log_softmax
363
// FFI constructor for nn.log_softmax; reuses SoftmaxAttrs to carry `axis`.
TVM_REGISTER_API("relay.op.nn._make.log_softmax")
.set_body_typed<Call(Expr, int)>([](Expr data, int axis) {
  static const Op& log_softmax_op = Op::Get("nn.log_softmax");
  auto softmax_attrs = make_node<SoftmaxAttrs>();
  softmax_attrs->axis = axis;
  return CallNode::make(log_softmax_op, {data}, Attrs(softmax_attrs), {});
});

// Operator registration for nn.log_softmax. The compute rule only accepts the
// last axis (given as -1 or ndim - 1) and forwards to topi::nn::log_softmax.
RELAY_REGISTER_OP("nn.log_softmax")
    .describe(R"code(Computes log softmax.

.. math:: \text{log_softmax}(x)_i = \log \frac{exp(x_i)}{\sum_j exp(x_j)}

.. note::
    This operator can be optimized away for inference.

- **data**: The input data
)code" TVM_ADD_FILELINE)
.set_attrs_type<SoftmaxAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
                                         const Array<Tensor>& inputs,
                                         const Type& out_type,
                                         const Target& target) {
  const auto* softmax_attrs = attrs.as<SoftmaxAttrs>();
  CHECK(softmax_attrs != nullptr);
  const int32_t last_axis = static_cast<int32_t>(inputs[0].ndim()) - 1;
  CHECK(softmax_attrs->axis == -1 || softmax_attrs->axis == last_axis)
      << "log_softmax currently only works on last dimension";
  return Array<Tensor>{ topi::nn::log_softmax(inputs[0]) };
});

397

398
// relay.nn.batch_flatten
399
bool BatchFlattenRel(const Array<Type>& types,
400 401 402
                     int num_inputs,
                     const Attrs& attrs,
                     const TypeReporter& reporter) {
403 404 405 406 407
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;
  if (data->shape.size() == 0) return false;

408
  auto target_dim = make_const(DataType::Int(32), 1);
409 410

  for (uint32_t i = 1; i < data->shape.size(); ++i) {
411 412 413 414 415 416
    if (!data->shape[i].as<ir::Any>()) {
      target_dim = target_dim * data->shape[i];
    } else {
      target_dim = data->shape[i];
      break;
    }
417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432
  }

  std::vector<IndexExpr> oshape({data->shape[0], target_dim});

  // assign output type
  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
  return true;
}

// Builds a call to nn.batch_flatten on `data`; the op carries no attributes.
Expr MakeBatchFlatten(Expr data) {
  static const Op& batch_flatten_op = Op::Get("nn.batch_flatten");
  return CallNode::make(batch_flatten_op, {data}, Attrs(), {});
}


// Expose MakeBatchFlatten to the frontend FFI.
TVM_REGISTER_API("relay.op.nn._make.batch_flatten")
.set_body_typed(MakeBatchFlatten);


// Operator registration for nn.batch_flatten; the compute rule forwards to
// topi::nn::flatten.
RELAY_REGISTER_OP("nn.batch_flatten")
.describe(R"code(Flattens the input into a 2-D array.

For an input array with shape ``(d1, d2, ..., dk)``, `batch_flatten` operation reshapes
the input array into an output array of shape ``(d1, d2*...*dk)``.

Example::

    x = [[
        [1,2,3],
        [4,5,6],
        [7,8,9]
    ],
    [   [1,2,3],
        [4,5,6],
        [7,8,9]
    ]],

    batch_flatten(x) = [[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
       [ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.]]

)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("BatchFlatten", BatchFlattenRel)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const Attrs& attrs,
                    const Array<Tensor>& inputs,
                    const Type& out_type,
                    const Target& target) {
    const Tensor& input = inputs[0];
    return Array<Tensor>{ topi::nn::flatten(input) };
});
469

470 471 472

// relu
// FFI constructor for nn.relu; the op carries no attributes.
TVM_REGISTER_API("relay.op.nn._make.relu")
.set_body_typed<Call(Expr)>([](Expr data) {
    static const Op& relu_op = Op::Get("nn.relu");
    return CallNode::make(relu_op, {data}, Attrs(), {});
  });

// Operator registration for nn.relu; the compute rule calls topi::relu with a
// threshold of 0.0f.
RELAY_REGISTER_OP("nn.relu")
.describe(R"code(Returns the relu input array, computed element-wise.

.. math::
   max(x, 0)

)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
                                         const Array<Tensor>& inputs,
                                         const Type& out_type,
                                         const Target& target) {
  const Tensor& input = inputs[0];
  return Array<Tensor>{ topi::relu(input, 0.0f) };
});
雾雨魔理沙 committed
496

497 498

// Positional relay function to create LRN operator used by frontend FFI.
499 500
TVM_REGISTER_NODE_TYPE(LRNAttrs);

501
// Builds a call to nn.lrn on `data`; size, axis, alpha, beta and bias are
// copied verbatim into an LRNAttrs node attached to the call.
Expr MakeLRN(Expr data,
             int size,
             int axis,
             double alpha,
             double beta,
             double bias) {
  static const Op& lrn_op = Op::Get("nn.lrn");
  auto lrn_attrs = make_node<LRNAttrs>();
  lrn_attrs->size = size;
  lrn_attrs->axis = axis;
  lrn_attrs->alpha = alpha;
  lrn_attrs->beta = beta;
  lrn_attrs->bias = bias;
  return CallNode::make(lrn_op, {data}, Attrs(lrn_attrs), {});
}

// Expose MakeLRN to the frontend FFI.
TVM_REGISTER_API("relay.op.nn._make.lrn")
.set_body_typed(MakeLRN);

// Operator registration for nn.lrn (no compute rule is attached here).
RELAY_REGISTER_OP("nn.lrn")
.describe(R"code(LRN layer.

Normalize the input in a local region across or within feature maps.
Each input value is divided by (1 + (\alpha/n) \sum_i x_i^2)^\beta,
where n is the size of each local region, and the sum is taken over the region
centered at that value (zero padding is added where necessary).

.. math::

    data / (bias + (alpha * sum_data ^2 /size))^beta

- **data**: The input tensor.
)code" TVM_ADD_FILELINE)
.set_attrs_type<LRNAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.add_type_rel("Identity", IdentityRel);


// Positional relay function to create L2Normalize operator used by frontend FFI.
543 544
TVM_REGISTER_NODE_TYPE(L2NormalizeAttrs);

545 546
// Builds a call to nn.l2_normalize on `data`; eps is copied and the axis list
// is moved into an L2NormalizeAttrs node attached to the call.
Expr MakeL2Normalize(Expr data,
                     double eps,
                     Array<Integer> axis) {
  static const Op& l2_normalize_op = Op::Get("nn.l2_normalize");
  auto l2_attrs = make_node<L2NormalizeAttrs>();
  l2_attrs->eps = eps;
  l2_attrs->axis = std::move(axis);
  return CallNode::make(l2_normalize_op, {data}, Attrs(l2_attrs), {});
}

// Expose MakeL2Normalize to the frontend FFI.
TVM_REGISTER_API("relay.op.nn._make.l2_normalize")
.set_body_typed(MakeL2Normalize);

// Operator registration for nn.l2_normalize (no compute rule is attached here).
RELAY_REGISTER_OP("nn.l2_normalize")
.describe(R"code(L2 Normalization layer.

Normalizes along dimension axis using an L2 norm

.. math::
    output = x / sqrt(max(sum(x^2), epsilon))

- **data**: The input tensor.
)code" TVM_ADD_FILELINE)
.set_attrs_type<L2NormalizeAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.add_type_rel("Identity", IdentityRel);

575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600
// Dropout
TVM_REGISTER_NODE_TYPE(DropoutAttrs);

bool DropoutRel(const Array<Type>& types,
                int num_inputs,
                const Attrs& attrs,
                const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  // dropout returns the original tensor with dropout applied
  // and a mask tensor (1.0 where element not dropped, 0.0 where dropped)
  auto ret_type = TensorTypeNode::make(data->shape, data->dtype);
  reporter->Assign(types[1], TupleTypeNode::make(Array<Type>({ret_type, ret_type})));
  return true;
}

// Builds a call to nn.dropout on `data`; `rate` is stored on a DropoutAttrs
// node attached to the call.
Expr MakeDropout(Expr data, double rate) {
  static const Op& dropout_op = Op::Get("nn.dropout");
  auto dropout_attrs = make_node<DropoutAttrs>();
  dropout_attrs->rate = rate;
  return CallNode::make(dropout_op, {data}, Attrs(dropout_attrs), {});
}

// Expose MakeDropout to the frontend FFI.
TVM_REGISTER_API("relay.op.nn._make.dropout")
.set_body_typed(MakeDropout);

// Operator registration for nn.dropout (no compute rule is attached here).
RELAY_REGISTER_OP("nn.dropout")
.describe(R"code(Applies the dropout operation to the input array.

During training, each element of the input is set to zero with probability ``p``.
The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input unchanged.

)code" TVM_ADD_FILELINE)
.set_attrs_type<DropoutAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "Input to which dropout will be applied.")
.set_support_level(1)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.add_type_rel("Dropout", DropoutRel);

// batch_norm
TVM_REGISTER_NODE_TYPE(BatchNormAttrs);

bool BatchNormRel(const Array<Type>& types,
                  int num_inputs,
                  const Attrs& attrs,
                  const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 6);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  const BatchNormAttrs* param = attrs.as<BatchNormAttrs>();

  // axis of -1 means use the last dimension
  CHECK(param->axis >= -1 && param->axis < (int)data->shape.size());
  int axis = (param->axis != -1) ? param->axis : data->shape.size() - 1;
633
  auto axis_size = data->shape[axis];
634 635

  // if we are using beta and gamma, they need to be of shape (dim,)
636 637 638 639
  reporter->Assign(types[1], TensorTypeNode::make({axis_size}, data->dtype));
  reporter->Assign(types[2], TensorTypeNode::make({axis_size}, data->dtype));
  reporter->Assign(types[3], TensorTypeNode::make({axis_size}, data->dtype));
  reporter->Assign(types[4], TensorTypeNode::make({axis_size}, data->dtype));
640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664

  // output is a tuple of the normed data (same shape as input), new running mean,
  // and new running average (the latter two are both vectors of length dim)
  std::vector<Type> fields;
  auto vec_ty = TensorTypeNode::make(Array<IndexExpr>({data->shape[axis]}),
                                     data->dtype);
  fields.push_back(TensorTypeNode::make(data->shape, data->dtype));
  fields.push_back(vec_ty);
  fields.push_back(vec_ty);
  reporter->Assign(types[5], TupleTypeNode::make(Array<Type>(fields)));
  return true;
}

// Builds a call to nn.batch_norm on (data, gamma, beta, moving_mean,
// moving_var); axis, epsilon, center and scale are stored on a BatchNormAttrs
// node attached to the call.
Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var,
                   int axis, double epsilon, bool center, bool scale) {
  static const Op& batch_norm_op = Op::Get("nn.batch_norm");
  auto bn_attrs = make_node<BatchNormAttrs>();
  bn_attrs->axis = axis;
  bn_attrs->epsilon = epsilon;
  bn_attrs->center = center;
  bn_attrs->scale = scale;
  return CallNode::make(batch_norm_op,
                        {data, gamma, beta, moving_mean, moving_var},
                        Attrs(bn_attrs), {});
}

// Expose MakeBatchNorm to the frontend FFI.
TVM_REGISTER_API("relay.op.nn._make.batch_norm")
.set_body_typed(MakeBatchNorm);

// Operator registration for nn.batch_norm (no compute rule is attached here).
RELAY_REGISTER_OP("nn.batch_norm")
.describe(R"code(Batch normalization layer (Ioffe and Szegedy, 2014).
Normalizes the input at each batch, i.e. applies a transformation
that maintains the mean activation close to 0 and the activation
standard deviation close to 1.

.. math::

  data\_mean[i] = mean(data[:,i,:,...]) \\
  data\_var[i] = var(data[:,i,:,...])

Then compute the normalized output, which has the same shape as input, as following:

.. math::

  out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}} \
* gamma[i] + beta[i]

Both *mean* and *var* returns a scalar by treating the input as a vector.

Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta`` have shape *(k,)*.

Besides the inputs and the outputs, this operator accepts two auxiliary
states, ``moving_mean`` and ``moving_var``, which are *k*-length
vectors. They are global statistics for the whole dataset, which are updated
by::

  moving_mean = moving_mean * momentum + data_mean * (1 - momentum)
  moving_var = moving_var * momentum + data_var * (1 - momentum)

The parameter ``axis`` specifies which axis of the input shape denotes
the 'channel' (separately normalized groups).  The default is 1.  Specifying -1 sets the channel
axis to be the last item in the input shape.

.. note::
    This operator can be optimized away for inference.
)code" TVM_ADD_FILELINE)
.set_attrs_type<BatchNormAttrs>()
.set_num_inputs(5)
.add_argument("data", "Tensor", "Input to which batch_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.add_argument("moving_mean", "Tensor", "Running mean of input.")
.add_argument("moving_var", "Tensor", "Running variance of input.")
.set_support_level(1)
.add_type_rel("BatchNorm", BatchNormRel);

714

715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775
// instance_norm
TVM_REGISTER_NODE_TYPE(InstanceNormAttrs);

/*!
 * \brief Type relation for nn.instance_norm.
 *
 * types holds [data, gamma, beta, output]. gamma and beta are assigned a 1-D
 * tensor type of extent data->shape[axis]; the output is assigned the shape
 * and dtype of data.
 *
 * \param types The input and output types of the call (exactly four).
 * \param num_inputs Number of inputs (not used here).
 * \param attrs Operator attributes; must be InstanceNormAttrs.
 * \param reporter Reporter used to assign the inferred types.
 * \return false while the data type is unresolved, true otherwise.
 */
bool InstanceNormRel(const Array<Type>& types,
                     int num_inputs,
                     const Attrs& attrs,
                     const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 4);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;
  const InstanceNormAttrs* param = attrs.as<InstanceNormAttrs>();
  CHECK(param != nullptr);
  // Resolve a negative axis in signed arithmetic so no int/size_t mixing occurs.
  const int ndim = static_cast<int>(data->shape.size());
  const int axis = param->axis >= 0 ? param->axis : param->axis + ndim;
  CHECK(axis >= 0 && axis < ndim);
  reporter->Assign(types[1], TensorTypeNode::make({data->shape[axis]}, data->dtype));
  reporter->Assign(types[2], TensorTypeNode::make({data->shape[axis]}, data->dtype));
  reporter->Assign(types[3], TensorTypeNode::make(data->shape, data->dtype));

  return true;
}

// Builds a call to nn.instance_norm on (data, gamma, beta); axis, epsilon,
// center and scale are stored on an InstanceNormAttrs node attached to the call.
Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon,
                      bool center, bool scale) {
  static const Op& instance_norm_op = Op::Get("nn.instance_norm");
  auto in_attrs = make_node<InstanceNormAttrs>();
  in_attrs->axis = axis;
  in_attrs->epsilon = epsilon;
  in_attrs->center = center;
  in_attrs->scale = scale;
  return CallNode::make(instance_norm_op, {data, gamma, beta}, Attrs(in_attrs), {});
}

// Expose MakeInstanceNorm to the frontend FFI. set_body_typed derives the
// argument count from the function signature, matching the other ops here.
TVM_REGISTER_API("relay.op.nn._make.instance_norm")
.set_body_typed(MakeInstanceNorm);

// Operator registration for nn.instance_norm (no compute rule is attached here).
RELAY_REGISTER_OP("nn.instance_norm")
.describe(R"code(Instance Normalization (Ulyanov and et al., 2016)
Applies instance normalization to the n-dimensional input array.

.. math::

    out = \frac{data - mean(data)}{\sqrt{var(data)+\epsilon}}
        * gamma + beta

The instance normalization is similar to batch normalization, but unlike
batch normalization, the mean and var are calculated per-dimension
separately for each object(instance) in a mini-batch, not over a batch.
And the same normalization is applied both at test and train time.

Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
have shape *(k,)*.

The parameter ``axis`` specifies which axis of the input shape denotes
the 'channel'.  The default is 1. Specifying -1 sets the channel axis
to be the last item in the input shape.

.. note::

    This operator can be optimized away for inference.
)code" TVM_ADD_FILELINE)
.set_attrs_type<InstanceNormAttrs>()
.set_num_inputs(3)
.add_argument("data", "Tensor", "Input to which instance_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.set_support_level(1)
.add_type_rel("InstanceNorm", InstanceNormRel);


785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823
// layer_norm
TVM_REGISTER_NODE_TYPE(LayerNormAttrs);

/*!
 * \brief Type relation for nn.layer_norm.
 *
 * types holds [data, gamma, beta, output]. gamma and beta are assigned a 1-D
 * tensor type of extent data->shape[axis]; the output is assigned the shape
 * and dtype of data.
 *
 * \param types The input and output types of the call (exactly four).
 * \param num_inputs Number of inputs (not used here).
 * \param attrs Operator attributes; must be LayerNormAttrs.
 * \param reporter Reporter used to assign the inferred types.
 * \return false while the data type is unresolved, true otherwise.
 */
bool LayerNormRel(const Array<Type>& types,
                  int num_inputs,
                  const Attrs& attrs,
                  const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 4);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;
  const LayerNormAttrs* param = attrs.as<LayerNormAttrs>();
  CHECK(param != nullptr);
  // Resolve a negative axis in signed arithmetic so no int/size_t mixing occurs.
  const int ndim = static_cast<int>(data->shape.size());
  const int axis = param->axis >= 0 ? param->axis : param->axis + ndim;
  CHECK(axis >= 0 && axis < ndim);
  reporter->Assign(types[1], TensorTypeNode::make({data->shape[axis]}, data->dtype));
  reporter->Assign(types[2], TensorTypeNode::make({data->shape[axis]}, data->dtype));
  reporter->Assign(types[3], TensorTypeNode::make(data->shape, data->dtype));

  return true;
}

// Builds a call to nn.layer_norm on (data, gamma, beta); axis, epsilon,
// center and scale are stored on a LayerNormAttrs node attached to the call.
Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon,
                   bool center, bool scale) {
  static const Op& layer_norm_op = Op::Get("nn.layer_norm");
  auto ln_attrs = make_node<LayerNormAttrs>();
  ln_attrs->axis = axis;
  ln_attrs->epsilon = epsilon;
  ln_attrs->center = center;
  ln_attrs->scale = scale;
  return CallNode::make(layer_norm_op, {data, gamma, beta}, Attrs(ln_attrs), {});
}

// Expose MakeLayerNorm to the frontend FFI. set_body_typed derives the
// argument count from the function signature, matching the other ops here.
TVM_REGISTER_API("relay.op.nn._make.layer_norm")
.set_body_typed(MakeLayerNorm);

// Operator registration for nn.layer_norm (no compute rule is attached here).
// NOTE(review): the description string is empty; consider documenting the op.
RELAY_REGISTER_OP("nn.layer_norm")
.describe(R"code(
)code" TVM_ADD_FILELINE)
.set_attrs_type<LayerNormAttrs>()
.set_num_inputs(3)
.add_argument("data", "Tensor", "Input to which layer_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.set_support_level(1)
.add_type_rel("LayerNorm", LayerNormRel);

832 833 834 835 836 837 838 839 840
// relay.nn.batch_matmul
bool BatchMatmulRel(const Array<Type>& types,
                    int num_inputs,
                    const Attrs& attrs,
                    const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* x = types[0].as<TensorTypeNode>();
  const auto* y = types[1].as<TensorTypeNode>();
  if (x == nullptr || y == nullptr) return false;
841
  CHECK(x->shape.size() == 3 && y->shape.size() == 3);
842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868
  CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
      << "BatchDot: batch dimension doesn't match, "
      << " x shape=" << x->shape
      << ", y shape=" << y->shape;
  CHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
      << "BatchDot: shapes of x and y is inconsistent, "
      << " x shape=" << x->shape
      << ", y shape=" << y->shape;

  Array<tvm::Expr> oshape = x->shape;
  oshape.Set(2, y->shape[1]);

  // assign output type
  reporter->Assign(types[2], TensorTypeNode::make(oshape, x->dtype));
  return true;
}


/*!
 * \brief Build a call to nn.batch_matmul; positional entry point for the frontend FFI.
 * \param x First operand.
 * \param y Second operand.
 * \return Call expression applying nn.batch_matmul to (x, y) with no attributes.
 */
Expr MakeBatchMatmul(Expr x, Expr y) {
  // The operator handle is looked up once and reused by every call.
  static const Op& batch_matmul_op = Op::Get("nn.batch_matmul");
  return CallNode::make(batch_matmul_op, {x, y}, Attrs(), {});
}


// Expose MakeBatchMatmul to the frontend FFI.
TVM_REGISTER_API("relay.op.nn._make.batch_matmul")
.set_body_typed(MakeBatchMatmul);


// Operator registration for nn.batch_matmul: two inputs, no attributes,
// typed by BatchMatmulRel.
RELAY_REGISTER_OP("nn.batch_matmul")
    .describe(R"code(Computes matrix multiplication of `x` and `y` when `x` and `y`
are data in batch.

.. math::

  batch\_matmul(x, y)[i, :, :] = matmul(x[i, :, :], y[i, :, :]^T)

- **x**: `(b, m, k)`
- **y**: `(b, n, k)`
- **out**: `(b, m, n)`.

)code" TVM_ADD_FILELINE)
    .set_num_inputs(2)
    .add_argument("x", "3D Tensor", "First input.")
    .add_argument("y", "3D Tensor", "Second input.")
    .set_support_level(10)
    .add_type_rel("BatchMatmul", BatchMatmulRel);


892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917
// relay.nn.cross_entropy
bool CrossEntropyRel(const Array<Type>& types,
                    int num_inputs,
                    const Attrs& attrs,
                    const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* x = types[0].as<TensorTypeNode>();
  const auto* y = types[1].as<TensorTypeNode>();
  if (x == nullptr || y == nullptr) return false;
  CHECK(x->shape.size() == 2 && y->shape.size() == 2)
    << "CrossEntropy: shapes of x and y is inconsistent, "
    << "x shape = " << x->shape << ", "
    << "y shape = " << y->shape;
  CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
    << "CrossEntropy: shapes of x and y is inconsistent, "
    << "x shape = " << x->shape << ", "
    << "y shape = " << y->shape;
  CHECK(reporter->AssertEQ(x->shape[1], y->shape[1]))
    << "CrossEntropy: shapes of x and y is inconsistent, "
    << "x shape = " << x->shape << ", "
    << "y shape = " << y->shape;
  // assign output type
  reporter->Assign(types[2], TensorTypeNode::make({}, x->dtype));
  return true;
}

918
// Positional relay function to create cross_entropy operator used by frontend FFI.
919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940
Expr MakeCrossEntropy(Expr predictions, Expr targets) {
  static const Op& op = Op::Get("nn.cross_entropy");
  return CallNode::make(op, {predictions, targets}, Attrs(), {});
}


// Expose MakeCrossEntropy to the frontend FFI.
TVM_REGISTER_API("relay.op.nn._make.cross_entropy")
    .set_body_typed(MakeCrossEntropy);


// Operator registration for nn.cross_entropy: two inputs, no attributes,
// typed by CrossEntropyRel (which requires both inputs to be rank 2).
RELAY_REGISTER_OP("nn.cross_entropy")
.describe(R"code(
Computes cross entropy given predictions and targets.
Do log on the data - do not accept logits.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("x", "2D Tensor", "Predictions.")
.add_argument("y", "2D Tensor", "Targets.")
.set_support_level(10)
.add_type_rel("CrossEntropy", CrossEntropyRel);


941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962
// Positional relay function to create cross_entropy_with_logits operator used by frontend FFI.
Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) {
  static const Op& op = Op::Get("nn.cross_entropy_with_logits");
  return CallNode::make(op, {predictions, targets}, Attrs(), {});
}


// Expose MakeCrossEntropyWithLogits to the frontend FFI.
TVM_REGISTER_API("relay.op.nn._make.cross_entropy_with_logits")
    .set_body_typed(MakeCrossEntropyWithLogits);


// Operator registration for nn.cross_entropy_with_logits: two inputs, no
// attributes, typed by CrossEntropyRel (which requires both inputs to be rank 2).
RELAY_REGISTER_OP("nn.cross_entropy_with_logits")
.describe(R"code(
Computes cross entropy given predictions and targets.
Accept logits.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("x", "2D Tensor", "Predictions.")
.add_argument("y", "2D Tensor", "Targets.")
.set_support_level(10)
.add_type_rel("CrossEntropy", CrossEntropyRel);

963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079
// Depth to space and space to depth
TVM_REGISTER_NODE_TYPE(SubPixelAttrs);

bool DepthToSpaceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  static const Layout kNCHW("NCHW");

  const SubPixelAttrs* param = attrs.as<SubPixelAttrs>();
  CHECK(param != nullptr);
  const int block_size = param->block_size;
  const Layout in_layout(param->layout);
  auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW);
  CHECK(layout_converter.defined())
      << "DepthToSpace only support input layouts that are convertible from NCHW."
      << " But got " << in_layout;

  auto oshape = layout_converter.ForwardShape(data->shape);
  oshape.Set(1, indexdiv(oshape[1], (block_size * block_size)));
  oshape.Set(2, oshape[2] * block_size);
  oshape.Set(3, oshape[3] * block_size);

  // Assign output type
  reporter->Assign(types[1],
                   TensorTypeNode::make(layout_converter.BackwardShape(oshape), data->dtype));

  return true;
}

/*!
 * \brief Build a call to nn.depth_to_space; positional entry point for the frontend FFI.
 * \param data Input expression.
 * \param block_size Stored in SubPixelAttrs::block_size.
 * \param layout Stored in SubPixelAttrs::layout.
 * \param mode Stored in SubPixelAttrs::mode.
 * \return Call expression applying nn.depth_to_space to data with the given attributes.
 */
Expr MakeDepthToSpace(Expr data, int block_size, std::string layout, std::string mode) {
  static const Op& depth_to_space_op = Op::Get("nn.depth_to_space");
  auto sub_pixel_attrs = make_node<SubPixelAttrs>();
  sub_pixel_attrs->mode = std::move(mode);
  sub_pixel_attrs->layout = std::move(layout);
  sub_pixel_attrs->block_size = block_size;
  return CallNode::make(depth_to_space_op, {data}, Attrs(sub_pixel_attrs), {});
}

// Expose MakeDepthToSpace to the frontend FFI.
TVM_REGISTER_API("relay.op.nn._make.depth_to_space")
    .set_body_typed(MakeDepthToSpace);

// Operator registration for nn.depth_to_space: one input, attributes in
// SubPixelAttrs, typed by DepthToSpaceRel. The documented output channel count
// is in_channels / (block_size * block_size), matching the type relation.
RELAY_REGISTER_OP("nn.depth_to_space")
    .describe(R"code(Rearrange input channels into spatial pixels.

- **data**: data is a 4D array of shape
            (batch, in_channels, in_height, in_width) for NCHW

- **out**: Output is a 4D array of shape
           (batch, in_channels / (block_size * block_size), in_height * block_size, in_width * block_size) for NCHW.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<SubPixelAttrs>()
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor")
    .set_support_level(5)
    .add_type_rel("DepthToSpace", DepthToSpaceRel);

bool SpaceToDepthRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  static const Layout kNCHW("NCHW");

  const SubPixelAttrs* param = attrs.as<SubPixelAttrs>();
  CHECK(param != nullptr);
  const int block_size = param->block_size;
  const Layout in_layout(param->layout);
  auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW);
  CHECK(layout_converter.defined())
      << "SpaceToDepth only support input layouts that are convertible from NCHW."
      << " But got " << in_layout;

  auto oshape = layout_converter.ForwardShape(data->shape);
  oshape.Set(1, oshape[1] * (block_size * block_size));
  oshape.Set(2, indexdiv(oshape[2], block_size));
  oshape.Set(3, indexdiv(oshape[3], block_size));

  // Assign output type
  reporter->Assign(types[1],
                   TensorTypeNode::make(layout_converter.BackwardShape(oshape), data->dtype));

  return true;
}

/*!
 * \brief Build a call to nn.space_to_depth; positional entry point for the frontend FFI.
 * \param data Input expression.
 * \param block_size Stored in SubPixelAttrs::block_size.
 * \param layout Stored in SubPixelAttrs::layout.
 * \return Call expression applying nn.space_to_depth to data with the given attributes.
 */
Expr MakeSpaceToDepth(Expr data, int block_size, std::string layout) {
  static const Op& space_to_depth_op = Op::Get("nn.space_to_depth");
  auto sub_pixel_attrs = make_node<SubPixelAttrs>();
  sub_pixel_attrs->layout = std::move(layout);
  sub_pixel_attrs->block_size = block_size;
  return CallNode::make(space_to_depth_op, {data}, Attrs(sub_pixel_attrs), {});
}

// Expose MakeSpaceToDepth to the frontend FFI.
TVM_REGISTER_API("relay.op.nn._make.space_to_depth")
    .set_body_typed(MakeSpaceToDepth);

// Operator registration for nn.space_to_depth: one input, attributes in
// SubPixelAttrs, typed by SpaceToDepthRel.
RELAY_REGISTER_OP("nn.space_to_depth")
    .describe(R"code(Rearrange spatial pixels into new output channels.

- **data**: data is a 4D array of shape
            (batch, in_channels, in_height, in_width) for NCHW

- **out**: Output is a 4D array of shape
           (batch, in_channels * block_size * block_size, in_height / block_size, in_width / block_size) for NCHW.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<SubPixelAttrs>()
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor")
    .set_support_level(5)
    .add_type_rel("SpaceToDepth", SpaceToDepthRel);

}  // namespace relay
}  // namespace tvm