transform.cc 90.1 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
9
 *
10
 *   http://www.apache.org/licenses/LICENSE-2.0
11
 *
12 13 14 15 16 17 18 19
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24
/*!
 * \file transform.cc
 * \brief Transform operators.
 */
#include <tvm/relay/op.h>
25
#include <tvm/ir/error.h>
26
#include <tvm/relay/attrs/transform.h>
27 28 29
#include <tvm/tir/op.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/data_layout.h>
30
#include <tvm/runtime/packed_func.h>
31
#include <topi/transform.h>
32
#include <topi/elemwise.h>
33 34
#include <topi/broadcast.h>
#include <topi/reduction.h>
35
#include <topi/nn.h>
36
#include <vector>
37
#include "../op_common.h"
38
#include "../../../arith/compute_expr.h"
39 40
#include "../../transforms/infer_layout_util.h"
#include "../../transforms/pattern_util.h"
41
#include "transform.h"
42 43 44

namespace tvm {
namespace relay {
45
using tir::IntImmNode;
46

47 48
// relay.cast
TVM_REGISTER_NODE_TYPE(CastAttrs);
49

50 51 52 53 54 55 56 57 58 59 60 61 62
bool CastRel(const Array<Type>& types,
             int num_inputs,
             const Attrs& attrs,
             const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    CHECK(types[0].as<IncompleteTypeNode>())
        << "cast: expect input type to be TensorType but get "
        << types[0];
    return false;
  }
  const auto* param = attrs.as<CastAttrs>();
63
  reporter->Assign(types[1], TensorType(
64 65 66 67
      data->shape, param->dtype));
  return true;
}

68
Array<te::Tensor> CastCompute(const Attrs& attrs,
69 70
                              const Array<te::Tensor>& inputs,
                              const Type& out_type) {
71 72 73 74 75 76
  const CastAttrs *param = attrs.as<CastAttrs>();
  CHECK(param != nullptr);
  DataType dtype = param->dtype;
  return { topi::cast(inputs[0], dtype) };
}

77 78
Expr MakeCast(Expr data,
              DataType dtype) {
79
  auto attrs = make_object<CastAttrs>();
80 81
  attrs->dtype = dtype;
  static const Op& op = Op::Get("cast");
82
  return Call(op, {data}, Attrs(attrs), {});
83 84
}

85
// FFI hook so the Python frontend can construct cast nodes.
TVM_REGISTER_GLOBAL("relay.ir.cast")
.set_body_typed(MakeCast);

// Operator registration: wires up the type relation, compute, and
// layout behavior for "cast".
RELAY_REGISTER_OP("cast")
.describe(R"code(Cast the data into a new data type.

)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type<CastAttrs>()
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Cast", CastRel)
.set_attr<FTVMCompute>("FTVMCompute", CastCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
100

101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121

// relay.cast_like
bool CastLikeRel(const Array<Type>& types,
                 int num_inputs,
                 const Attrs& attrs,
                 const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    CHECK(types[0].as<IncompleteTypeNode>())
        << "cast: expect input type to be TensorType but get "
        << types[0];
    return false;
  }
  const auto* dtype_like = types[1].as<TensorTypeNode>();
  if (dtype_like == nullptr) {
    CHECK(types[1].as<IncompleteTypeNode>())
        << "cast: expect input type to be TensorType but get "
        << types[1];
    return false;
  }
122
  reporter->Assign(types[2], TensorType(data->shape, dtype_like->dtype));
123 124 125 126
  return true;
}


127
Array<te::Tensor> CastLikeCompute(const Attrs& attrs,
128 129
                                  const Array<te::Tensor>& inputs,
                                  const Type& out_type) {
130 131 132 133 134 135 136
  return { topi::cast(inputs[0], inputs[1]->dtype) };
}


// Build a relay.cast_like call expression; the dtype is taken from the
// second operand at type-inference time, so no attrs are needed.
Expr MakeCastLike(Expr data,
                  Expr dtype_like) {
  static const Op& op = Op::Get("cast_like");
  Array<Expr> call_args = {data, dtype_like};
  return Call(op, call_args, Attrs(), {});
}


141
// FFI hook so the Python frontend can construct cast_like nodes.
TVM_REGISTER_GLOBAL("relay.ir.cast_like")
.set_body_typed(MakeCastLike);

// Operator registration for "cast_like": element-wise and layout-agnostic.
RELAY_REGISTER_OP("cast_like")
.describe(R"code(Cast the data into the type of another tensor.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("dtype_like", "Tensor", "The tensor to cast to.")
.set_support_level(3)
.add_type_rel("CastLike", CastLikeRel)
.set_attr<FTVMCompute>("FTVMCompute", CastLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);


157 158 159
// Lowering for "reinterpret": bitwise reinterpretation to the attr dtype.
// Reuses CastAttrs since only a target dtype is needed.
Array<te::Tensor> ReinterpretCompute(const Attrs& attrs,
                                     const Array<te::Tensor>& inputs,
                                     const Type& out_type) {
  const auto* reinterpret_attrs = attrs.as<CastAttrs>();
  CHECK(reinterpret_attrs != nullptr);
  return Array<te::Tensor>{ topi::reinterpret(inputs[0], reinterpret_attrs->dtype) };
}

// Build a relay.reinterpret call expression with the target dtype.
Expr MakeReinterpret(Expr data, DataType dtype) {
  static const Op& op = Op::Get("reinterpret");
  auto attrs = make_object<CastAttrs>();
  attrs->dtype = dtype;
  return Call(op, {data}, Attrs(attrs), {});
}

173
// FFI hook: unpack the (data, dtype) packed-call arguments and build a
// reinterpret call expression.
TVM_REGISTER_GLOBAL("relay._make.reinterpret").set_body([](const TVMArgs& args, TVMRetValue* rv) {
  runtime::detail::unpack_call<Expr, 2>(MakeReinterpret, args, rv);
});

// Operator registration for "reinterpret". Reuses CastRel: the output has
// the input's shape with the attr dtype.
RELAY_REGISTER_OP("reinterpret")
.describe(R"code(Reinterpret the data into a new data type.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type<CastAttrs>()
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Reinterpret", CastRel)
.set_attr<FTVMCompute>("FTVMCompute", ReinterpretCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
188

189
// relay.expand_dims
TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);

/*!
 * \brief Type relation for "expand_dims".
 *
 * Inserts `num_newaxis` axes of length 1 at position `axis`. A negative
 * axis counts from the end: -1 appends after the last dimension.
 */
bool ExpandDimsRel(const Array<Type>& types,
                   int num_inputs,
                   const Attrs& attrs,
                   const TypeReporter& reporter) {
  // `types` contains: [data, result]
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    CHECK(types[0].as<IncompleteTypeNode>())
        << "expand_dims: expect input type to be TensorType but get "
        << types[0];
    return false;
  }
  const auto* param = attrs.as<ExpandDimsAttrs>();
  const int ndim = static_cast<int>(data->shape.size());
  const int axis = param->axis;
  const int num_newaxis = param->num_newaxis;
  CHECK(num_newaxis >= 0)
    << "expand_dims only accepts `num_newaxis >= 0`"
    << ", but got num_newaxis = " << num_newaxis;
  CHECK(-ndim - 1 <= axis && axis <= ndim)
    << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]"
    << ", but got axis = " << axis
    << ", and data.ndim = " << ndim;
  // Normalize a negative axis: the insertion point `pivot` is in [0, ndim].
  const int pivot = axis < 0 ? ndim + axis + 1 : axis;
  std::vector<IndexExpr> oshape;
  oshape.reserve(ndim + num_newaxis);
  // Output shape = dims before pivot + num_newaxis ones + dims from pivot on.
  for (int i = 0; i < pivot; ++i) {
    oshape.emplace_back(data->shape[i]);
  }
  for (int i = 0; i < num_newaxis; ++i) {
    oshape.emplace_back(1);
  }
  for (int i = pivot; i < ndim; ++i) {
    oshape.emplace_back(data->shape[i]);
  }
  reporter->Assign(types[1], TensorType(oshape, data->dtype));
  return true;
}

232
// Lowering for "expand_dims": delegate to topi with the attr axis/count.
Array<te::Tensor> ExpandDimsCompute(const Attrs& attrs,
                                    const Array<te::Tensor>& inputs,
                                    const Type& out_type) {
  const auto* expand_attrs = attrs.as<ExpandDimsAttrs>();
  CHECK(expand_attrs != nullptr);
  return Array<te::Tensor>{
      topi::expand_dims(inputs[0], expand_attrs->axis, expand_attrs->num_newaxis) };
}

240 241 242
// Build a relay.expand_dims call expression.
Expr MakeExpandDims(Expr data,
                    int axis,
                    int num_newaxis) {
  static const Op& op = Op::Get("expand_dims");
  auto attrs = make_object<ExpandDimsAttrs>();
  attrs->axis = axis;
  attrs->num_newaxis = num_newaxis;
  return Call(op, {data}, Attrs(attrs), {});
}

250
// FFI hook so the Python frontend can construct expand_dims nodes.
TVM_REGISTER_GLOBAL("relay.op._make.expand_dims")
.set_body_typed(MakeExpandDims);

// Operator registration for "expand_dims".
RELAY_REGISTER_OP("expand_dims")
.describe(R"code(Insert `num_newaxis` axises at the position given by `axis`

- **data**: The input data to the operator.

)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type<ExpandDimsAttrs>()
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("ExpandDims", ExpandDimsRel)
.set_attr<FTVMCompute>("FTVMCompute", ExpandDimsCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);
266

267
// relay.concatenate
// Register the attrs node type for reflection/FFI.
TVM_REGISTER_NODE_TYPE(ConcatenateAttrs);

270
// Lowering for "concatenate": join all inputs along the attr axis.
Array<te::Tensor> ConcatenateCompute(const Attrs& attrs,
                                     const Array<te::Tensor>& inputs,
                                     const Type& out_type) {
  const auto* concat_attrs = attrs.as<ConcatenateAttrs>();
  CHECK(concat_attrs != nullptr);
  return Array<te::Tensor>{ topi::concatenate(inputs, concat_attrs->axis) };
}

278 279
// Build a relay.concatenate call expression over a tuple of tensors.
Expr MakeConcatenate(Expr data,
                     int axis) {
  static const Op& op = Op::Get("concatenate");
  auto attrs = make_object<ConcatenateAttrs>();
  attrs->axis = axis;
  return Call(op, {data}, Attrs(attrs), {});
}

286
// FFI hook so the Python frontend can construct concatenate nodes.
TVM_REGISTER_GLOBAL("relay.op._make.concatenate")
.set_body_typed(MakeConcatenate);

// Operator registration for "concatenate". The type relation and layout
// inference come from shared helpers declared in the included headers.
RELAY_REGISTER_OP("concatenate")
.describe(R"code(Concatenate the input tensors along the given axis.

- **data** : A list of tensors.

- **axis** : The axis along which the tensors are concatenated.

)code" TVM_ADD_FILELINE)
.set_attrs_type<ConcatenateAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input list of tensors.")
.set_support_level(1)
.add_type_rel("Concatenate", ConcatenateRel<ConcatenateAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConcatenateLayout)
.set_attr<FTVMCompute>("FTVMCompute", ConcatenateCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

TVM_REGISTER_NODE_TYPE(StackAttrs);

bool StackRel(const Array<Type>& types,
              int num_inputs,
              const Attrs& attrs,
              const TypeReporter& reporter) {
  // types: [data, result]
  CHECK_EQ(types.size(), 2);
  const auto* tensor_tuple = types[0].as<TupleTypeNode>();
  if (tensor_tuple == nullptr) {
    CHECK(types[0].as<IncompleteTypeNode>())
        << "cast: expect input type to be TupleType but get "
        << types[0];
    return false;
  }
  const auto* param = attrs.as<StackAttrs>();
  const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
  const int ndim = static_cast<int>(first->shape.size());
324 325 326 327 328 329 330 331 332 333

  // Sanity check: axis
  int axis = param->axis;
  CHECK(-ndim <= axis && axis < ndim)
    << "stack only accepts `axis` in [-ndim, ndim)"
    << ", but got axis = " << axis
    << ", and ndim = " << ndim;
  axis = axis < 0 ? ndim + axis + 1: axis;

  // Sanity check: ndim and dtype.
334 335 336 337 338 339 340
  const DataType dtype = first->dtype;
  for (const Type& ele : tensor_tuple->fields) {
    const auto& e = Downcast<TensorType>(ele);
    int e_ndim = static_cast<int>(e->shape.size());
    const DataType& e_dtype = e->dtype;
    CHECK_EQ(e_ndim, ndim) << "relay.stack requires all tensors have the same ndim";
    CHECK_EQ(e_dtype, dtype) << "relay.stack requires all tensors have the same dtype";
341 342 343
    for (size_t j = 0; j < first->shape.size(); ++j) {
      if (j == static_cast<size_t>(axis)) continue;
      if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue;
344
      throw Error("relay.stack requires all tensors have the same shape "
345 346
                         "on non-stacking axes");
    }
347
  }
348

349 350 351 352 353 354 355 356 357 358 359
  // Calculate shape
  std::vector<IndexExpr> oshape;
  oshape.reserve(ndim + 1);
  const int stack_dim = static_cast<int>(tensor_tuple->fields.size());
  for (int i = 0; i < axis; ++i) {
    oshape.emplace_back(first->shape[i]);
  }
  oshape.emplace_back(stack_dim);
  for (int i = axis; i < ndim; ++i) {
    oshape.emplace_back(first->shape[i]);
  }
360
  reporter->Assign(types[1], TensorType(oshape, dtype));
361 362 363
  return true;
}

364
// Lowering for "stack": delegate to topi::stack along the attr axis.
Array<te::Tensor> StackCompute(const Attrs& attrs,
                               const Array<te::Tensor>& inputs,
                               const Type& out_type) {
  const auto* stack_attrs = attrs.as<StackAttrs>();
  CHECK(stack_attrs != nullptr);
  return Array<te::Tensor>{ topi::stack(inputs, stack_attrs->axis) };
}

// Build a relay.stack call expression over a tuple of tensors.
Expr MakeStack(Expr data,
               int axis) {
  static const Op& op = Op::Get("stack");
  auto attrs = make_object<StackAttrs>();
  attrs->axis = axis;
  return Call(op, {data}, Attrs(attrs), {});
}

380
// FFI hook so the Python frontend can construct stack nodes.
TVM_REGISTER_GLOBAL("relay.op._make.stack")
.set_body_typed(MakeStack);

// Operator registration for "stack".
RELAY_REGISTER_OP("stack")
.describe(R"code(Stack the input tensors along the given axis.

- **data** : A list of tensors.

- **axis** : The axis along which the tensors are stacked.

)code" TVM_ADD_FILELINE)
.set_attrs_type<StackAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input list of tensors.")
.set_support_level(3)
.add_type_rel("Stack", StackRel)
.set_attr<FTVMCompute>("FTVMCompute", StackCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
398 399

/* relay.transpose */
TVM_REGISTER_NODE_TYPE(TransposeAttrs);

/*!
 * \brief Type relation for "transpose".
 *
 * With no `axes` given, the dimension order is fully reversed; otherwise
 * output dimension i is data dimension axes[i]. Negative axes are
 * normalized and duplicates rejected.
 */
bool TransposeRel(const Array<Type>& types,
                  int num_inputs,
                  const Attrs& attrs,
                  const TypeReporter& reporter) {
  // types: [data, result]
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    CHECK(types[0].as<IncompleteTypeNode>())
        << "transpose: expect input type to be TensorType but get "
        << types[0];
    return false;
  }
  const auto* param = attrs.as<TransposeAttrs>();
  const int ndim = data->shape.size();
  const Array<Integer>& axes = param->axes;
  // check dimension match
  CHECK(!axes.defined() || static_cast<int>(axes.size()) == ndim)
    << "Dimension mismatch: axes has " << axes.size() << " elements"
    << ", but data.ndim = " << ndim;
  // construct int_axes
  std::vector<int> int_axes;
  int_axes.reserve(ndim);
  // used not defined to check if it is None.
  if (!axes.defined()) {
    // Default: reverse the axis order.
    for (int i = ndim - 1; i >= 0; --i) {
      int_axes.push_back(i);
    }
  } else {
    std::vector<int> axis_used(ndim, 0);
    for (const Integer& e : axes) {
      int64_t axis = e;
      // sanity check for axis and ndim
      CHECK(-ndim <= axis && axis < ndim)
        << "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)"
        << ", but got axis = " << axis
        << ", and data.ndim = " << ndim;
      axis = axis < 0 ? axis + ndim : axis;
      // sanity check for duplication
      CHECK(!axis_used[axis]) << "Duplicate axes in transpose: " << axis;
      axis_used[axis] = 1;
      int_axes.push_back(static_cast<int>(axis));
    }
  }
  // Permute the input shape through the normalized axis list.
  std::vector<IndexExpr> oshape;
  oshape.reserve(ndim);
  for (int axis : int_axes) {
    oshape.push_back(data->shape[axis]);
  }
  reporter->Assign(types[1], TensorType(oshape, data->dtype));
  return true;
}

455
// Lowering for "transpose": forward the axes attr straight to topi.
Array<te::Tensor> TransposeCompute(const Attrs& attrs,
                                   const Array<te::Tensor>& inputs,
                                   const Type& out_type) {
  const auto* transpose_attrs = attrs.as<TransposeAttrs>();
  CHECK(transpose_attrs != nullptr);
  return Array<te::Tensor>{ topi::transpose(inputs[0], transpose_attrs->axes) };
}

463
// Build a relay.transpose call expression (axes may be undefined for a
// full reversal).
Expr MakeTranspose(Expr data,
                   Array<Integer> axes) {
  static const Op& op = Op::Get("transpose");
  auto attrs = make_object<TransposeAttrs>();
  attrs->axes = std::move(axes);
  return Call(op, {data}, Attrs(attrs), {});
}

471
// FFI hook so the Python frontend can construct transpose nodes.
TVM_REGISTER_GLOBAL("relay.op._make.transpose")
.set_body_typed(MakeTranspose);

// Operator registration for "transpose".
RELAY_REGISTER_OP("transpose")
.describe(R"code(Permutes the dimensions of an array.

- **data**: The input data to the operator.

- **axes**: The target axes order, reverse order if not specified.

)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type<TransposeAttrs>()
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Transpose", TransposeRel)
.set_attr<FTVMCompute>("FTVMCompute", TransposeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
489 490

/* relay.reshape */
TVM_REGISTER_NODE_TYPE(ReshapeAttrs);

/*!
 * \brief Type relation for "reshape".
 *
 * Entries of `newshape` may use the special values 0, -1, -2, -3, -4;
 * see the describe() string on the operator registration for their
 * meanings. When param->reverse is set, both the data shape and the
 * newshape are processed back-to-front and the result is flipped back
 * at the end.
 */
bool ReshapeRel(const Array<Type>& types,
                int num_inputs,
                const Attrs& attrs,
                const TypeReporter& reporter) {
  // types: [data, result]
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    CHECK(types[0].as<IncompleteTypeNode>())
        << "reshape: expect input type to be TensorType but get "
        << types[0];
    return false;
  }

  const auto* param = attrs.as<ReshapeAttrs>();
  Array<IndexExpr> data_shape;
  Array<Integer> newshape;
  if (param->reverse) {
    data_shape.assign(data->shape.rbegin(), data->shape.rend());
    newshape.assign(param->newshape.rbegin(), param->newshape.rend());
  } else {
    data_shape = data->shape;
    newshape = param->newshape;
  }
  Array<IndexExpr> oshape;
  // Track which input/output dims are pinned by special values so the
  // -1 inference below only divides by the genuinely free dims.
  std::unordered_set<size_t> used_input_dims;
  std::unordered_set<size_t> used_output_dims;
  size_t src_idx = 0;   // next unconsumed dim of data_shape
  int infer_idx = -1;   // output position of the single -1, if any

  for (size_t i = 0; i < newshape.size(); ++i) {
    int svalue = newshape[i]->value;
    // special flag handling for shape inference.
    if (svalue > 0) {
      oshape.push_back(newshape[i]);
      ++src_idx;
    } else if (svalue == 0) {
      // keep same
      CHECK_LT(src_idx, data_shape.size());
      used_input_dims.insert(src_idx);
      used_output_dims.insert(oshape.size());
      oshape.push_back(data_shape[src_idx++]);
    } else if (svalue == -1) {
      // inference based on rest
      CHECK_LT(infer_idx, 0)
          << "One and only one dim can be inferred";
      infer_idx = i;
      // placeholder; patched with the inferred value after the loop
      oshape.push_back(1);
      ++src_idx;
    } else if (svalue == -2) {
      // copy all remaining dims from source
      while (src_idx < data_shape.size()) {
        used_input_dims.insert(src_idx);
        used_output_dims.insert(oshape.size());
        oshape.push_back(data_shape[src_idx++]);
      }
    } else if (svalue == -3) {
      // merge two dims from source
      CHECK_LT(src_idx + 1, data_shape.size());
      used_input_dims.insert(src_idx);
      IndexExpr d1 = data_shape[src_idx++];
      used_input_dims.insert(src_idx);
      IndexExpr d2 = data_shape[src_idx++];
      used_output_dims.insert(oshape.size());
      // A dynamic (Any) factor makes the product dynamic too.
      if (d1.as<Any>() || d2.as<Any>()) {
        oshape.push_back(Any::make());
      } else {
        oshape.push_back(d1 * d2);
      }
    } else if (svalue == -4) {
      // split the source dim s into two dims
      // read the left dim and then the right dim (either can be -1)
      CHECK_LT(i + 2, newshape.size());
      CHECK_LT(src_idx, data_shape.size());
      used_input_dims.insert(src_idx);
      IndexExpr d0 = data_shape[src_idx++];
      Integer d1 = newshape[++i];
      Integer d2 = newshape[++i];
      if (d1->value == -1) {
        CHECK(d2->value != -1)
            << "Split dims cannot both be -1.";
        used_output_dims.insert(oshape.size());
        if (d0.as<Any>()) {
          oshape.push_back(Any::make());
        } else {
          oshape.push_back(indexdiv(d0, d2));
        }
        used_output_dims.insert(oshape.size());
        oshape.push_back(d2);
      } else {
        used_output_dims.insert(oshape.size());
        oshape.push_back(d1);
        used_output_dims.insert(oshape.size());
        if (d2->value == -1) {
          if (d0.as<Any>()) {
            oshape.push_back(Any::make());
          } else {
            oshape.push_back(indexdiv(d0, d1));
          }
        } else {
          oshape.push_back(d2);
        }
      }
    } else {
      CHECK(false) << "Unsupported special value: " << svalue;
    }
  }

  if (infer_idx >= 0) {
    // Resolve the single -1: product of the unused input dims divided by
    // the explicitly specified output dims; any Any makes it dynamic.
    IndexExpr infer_dim = 1;
    for (size_t i = 0; i < data_shape.size(); ++i) {
      if (used_input_dims.count(i) != 0) {
        continue;
      }
      if (data_shape[i].as<Any>()) {
        infer_dim = Any::make();
        break;
      }
      infer_dim *= data_shape[i];
    }
    if (!infer_dim.as<Any>()) {
      for (size_t i = 0; i < oshape.size(); ++i) {
        if (used_output_dims.count(i) != 0) {
          continue;
        }
        if (oshape[i].as<Any>()) {
          infer_dim = Any::make();
          break;
        }
        infer_dim = indexdiv(infer_dim, oshape[i]);
      }
    }
    oshape.Set(infer_idx, infer_dim);
  }

  if (param->reverse) {
    // Reverse mode computed everything on flipped shapes; flip back here.
    reporter->Assign(types[1], TensorType(
        Array<IndexExpr>(oshape.rbegin(), oshape.rend()), data->dtype));
  } else {
    reporter->Assign(types[1], TensorType(oshape, data->dtype));
  }
  return true;
}

637
// Lowering for "reshape"/"reshape_like": the target shape comes from the
// inferred output type; dynamic (Any) dims are replaced by their vars.
Array<te::Tensor> ReshapeCompute(const Attrs& attrs,
                                 const Array<te::Tensor>& inputs,
                                 const Type& out_type) {
  const auto* out_ttype = out_type.as<TensorTypeNode>();
  CHECK(out_ttype != nullptr);
  Array<IndexExpr> newshape;
  for (auto dim : out_ttype->shape) {
    const auto* any_dim = dim.as<tir::AnyNode>();
    if (any_dim != nullptr) {
      newshape.push_back(any_dim->ToVar());
    } else {
      newshape.push_back(dim);
    }
  }
  return Array<te::Tensor>{ topi::reshape(inputs[0], newshape) };
}

653
// Build a (forward-mode) relay.reshape call expression.
Expr MakeReshape(Expr data,
                 Array<Integer> newshape) {
  static const Op& op = Op::Get("reshape");
  auto attrs = make_object<ReshapeAttrs>();
  attrs->newshape = std::move(newshape);
  attrs->reverse = false;
  return Call(op, {data}, Attrs(attrs), {});
}

662
// FFI hook so the Python frontend can construct reshape nodes.
TVM_REGISTER_GLOBAL("relay.op._make.reshape")
.set_body_typed(MakeReshape);

// Operator registration for "reshape". The describe() string documents
// the special newshape values handled in ReshapeRel.
RELAY_REGISTER_OP("reshape")
.describe(R"code(Reshapes the input array.

Example::

To give user more convenience in without doing manual shape inference,
some dimensions of the shape can take special values from the set {0, -1, -2, -3, -4}.
The significance of each is explained below:

- ``0``  copy this dimension from the input to the output shape.

Example::

- data.shape = (2,3,4), newshape = (4,0,2), result.shape = (4,3,2)
- data.shape = (2,3,4), newshape = (2,0,0), result.shape = (2,3,4)

- ``-1`` infers the dimension of the output shape by using the remainder of the input dimensions
keeping the size of the new array same as that of the input array.
At most one dimension of shape can be -1.

Example::

- data.shape = (2,3,4), newshape = (6,1,-1), result.shape = (6,1,4)
- data.shape = (2,3,4), newshape = (3,-1,8), result.shape = (3,1,8)
- data.shape = (2,3,4), newshape = (-1,), result.shape = (24,)

- ``-2`` copy all/remainder of the input dimensions to the output shape.

Example::

- data.shape = (2,3,4), newshape = (-2,), result.shape = (2,3,4)
- data.shape = (2,3,4), newshape = (2,-2), result.shape = (2,3,4)
- data.shape = (2,3,4), newshape = (-2,1,1), result.shape = (2,3,4,1,1)

- ``-3`` use the product of two consecutive dimensions of the input shape as the output dimension.

Example::

- data.shape = (2,3,4), newshape = (-3,4), result.shape = (6,4)
- data.shape = (2,3,4,5), newshape = (-3,-3), result.shape = (6,20)
- data.shape = (2,3,4), newshape = (0,-3), result.shape = (2,12)
- data.shape = (2,3,4), newshape = (-3,-2), result.shape = (6,4)

- ``-4`` split one dimension of the input into two dimensions passed subsequent to -4 in shape (can contain -1).

Example::

- data.shape = (2,3,4), newshape = (-4,1,2,-2), result.shape =(1,2,3,4)
- data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4)

)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type<ReshapeAttrs>()
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Reshape", ReshapeRel)
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
723

Siju committed
724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745

/*!
* \brief ReshapeLikeRel User defined type constraint function.
* \param types The solved types: [data, shape_like, result].
* \param num_inputs Number of input types in the args.
* \param attrs The additional attributes of the operator.
* \param reporter The reporter to report solution to.
* \return False if the relation has not been resolved, it might be resolved later.
*  True if this relation has been resolved.
*/
bool ReshapeLikeRel(const Array<Type>& types,
                    int num_inputs,
                    const Attrs& attrs,
                    const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    return false;
  }
  const auto* reshape_like = types[1].as<TensorTypeNode>();
  if (reshape_like == nullptr) {
    return false;
  }
  // Only check When input data has static shape.
  bool is_static_shape = true;
  for (size_t i = 0; i < data->shape.size(); ++i) {
    if (!data->shape[i].as<IntImmNode>()) {
      is_static_shape = false;
      break;
    }
  }
  if (is_static_shape) {
    // Element counts must match for the reshape to be valid.
    CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size()))
      << "Reshape inputs size should be compatible.";
  }
  // Result takes shape_like's shape and data's dtype.
  reporter->Assign(types[2], TensorType(reshape_like->shape, data->dtype));
  return true;
}


// Build a relay.reshape_like call expression; the shape is taken from
// the second operand at type-inference time, so no attrs are needed.
Expr MakeReshapeLike(Expr data,
                     Expr shape_like) {
  static const Op& op = Op::Get("reshape_like");
  Array<Expr> call_args = {data, shape_like};
  return Call(op, call_args, Attrs(), {});
}


770
// FFI hook so the Python frontend can construct reshape_like nodes.
TVM_REGISTER_GLOBAL("relay.op._make.reshape_like")
.set_body_typed(MakeReshapeLike);


// Operator registration for "reshape_like". Shares ReshapeCompute, since
// lowering only needs the inferred output shape.
RELAY_REGISTER_OP("reshape_like")
.describe(R"code(Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
the input array into an output array with the same shape as the second input array.
.. note::
    Sizes for both array should be compatible.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("shape_like", "Tensor", "Shape tensor.")
.set_support_level(3)
.add_type_rel("ReshapeLike", ReshapeLikeRel)
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
Siju committed
788

789 790 791 792 793 794 795 796 797 798 799 800
// ArgWhere
bool ArgWhereRel(const Array<Type>& types,
                 int num_inputs,
                 const Attrs& attrs,
                 const TypeReporter& reporter) {
  CHECK_EQ(num_inputs, 1);
  auto tt = types[0].as<TensorTypeNode>();
  CHECK(tt != nullptr);
  const auto& input_shape = tt->shape;
  const auto& input_rank = input_shape.size();
  std::vector<IndexExpr> result_shape;
  result_shape.push_back(Any::make());
801
  result_shape.push_back(IntImm(DataType::Int(32), input_rank));
802
  reporter->Assign(types[1], TensorType(result_shape, DataType::Int(32)));
803 804 805
  return true;
}

806
// FFI hook: argwhere carries no attrs, so the call is built inline.
TVM_REGISTER_GLOBAL("relay.op._make.argwhere")
.set_body_typed([](Expr data) {
  static const Op& op = Op::Get("argwhere");
  return Call(op, {data}, Attrs(), {});
});

// Operator registration for "argwhere" (data-dependent output shape).
RELAY_REGISTER_OP("argwhere")
.describe(R"doc(Find the indices of elements of a tensor that are
non-zero)doc" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("condition", "Tensor", "The input condition tensor.")
.add_type_rel("ArgWhere", ArgWhereRel)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_support_level(10);
Siju committed
821

Siva committed
822 823 824 825 826 827 828 829 830 831 832 833 834
// Take
TVM_REGISTER_NODE_TYPE(TakeAttrs);

bool TakeRel(const Array<Type>& types,
             int num_inputs,
             const Attrs& attrs,
             const TypeReporter& reporter) {
  // `types` contains: [data, indices, result]
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  CHECK(data != nullptr);
  const auto* indices = types[1].as<TensorTypeNode>();
  CHECK(indices != nullptr);
835
  CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer";
Siva committed
836 837 838 839
  const auto param = attrs.as<TakeAttrs>();
  CHECK(param != nullptr);

  if (!param->axis.defined()) {
840
    std::vector<IndexExpr> oshape(indices->shape.begin(), indices->shape.end());
841
    reporter->Assign(types[2], TensorType(oshape, data->dtype));
Siva committed
842 843 844 845 846 847
    return true;
  }

  std::vector<IndexExpr> oshape;
  const auto ndim_data = static_cast<int>(data->shape.size());
  const auto ndim_indices = static_cast<int>(indices->shape.size());
848
  int axis = static_cast<int>(param->axis->value);
Siva committed
849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864
  if (axis < 0) axis += ndim_data;
  CHECK_LE(axis, ndim_data)
    << "axis should be with in data shape"
    << ", but got = " << axis;

  oshape.reserve(ndim_data - 1 + ndim_indices);
  for (int i = 0; i < axis; ++i) {
    oshape.emplace_back(data->shape[i]);
  }
  for (int i = 0; i < ndim_indices; ++i) {
    oshape.emplace_back(indices->shape[i]);
  }
  for (int i = axis+1; i < ndim_data; ++i) {
    oshape.emplace_back(data->shape[i]);
  }

865
  reporter->Assign(types[2], TensorType(oshape, data->dtype));
Siva committed
866 867 868
  return true;
}

869
// Lower take to topi, dispatching on whether an axis was supplied.
Array<te::Tensor> TakeCompute(const Attrs& attrs,
                              const Array<te::Tensor>& inputs,
                              const Type& out_type) {
  const auto* param = attrs.as<TakeAttrs>();
  CHECK(param != nullptr);
  if (param->axis.defined()) {
    return Array<te::Tensor>{ topi::take(inputs[0], inputs[1], param->axis, param->mode) };
  }
  return Array<te::Tensor>{ topi::take(inputs[0], inputs[1], param->mode) };
}

Siva committed
881 882
Expr MakeTake(Expr data,
              Expr indices,
883 884
              Integer axis,
              std::string mode) {
885
  auto attrs = make_object<TakeAttrs>();
886
  attrs->axis = std::move(axis);
887
  attrs->mode = std::move(mode);
Siva committed
888
  static const Op& op = Op::Get("take");
889
  return Call(op, {data, indices}, Attrs(attrs), {});
Siva committed
890 891
}

892
TVM_REGISTER_GLOBAL("relay.op._make.take")
893
.set_body_typed(MakeTake);
Siva committed
894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917

RELAY_REGISTER_OP("take")
.describe(R"code(Take elements from an array along an axis.

When axis is not None, this function does the same thing as 'fancy' indexing
(indexing arrays using arrays); however, it can be easier to use if you need
elements along a given axis.

**Note** that when axis is none the flattened input array is used.

Examples::

  a = [[ 1, 2],
       [ 3, 4]]
  indices = [3, 0, 2]
  take(a, indices) = [ 4, 1, 3]

  a = [[ 1., 2.],
       [ 3., 4.]]
  indices = [1, 0]
  take(a, indices, axis=1) = [[ 2., 1.],
                              [ 4., 3.]]

)code" TVM_ADD_FILELINE)
918
.set_attrs_type<TakeAttrs>()
Siva committed
919 920 921
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("indices", "Tensor", "The indices tensor.")
922
.set_support_level(3)
923 924 925 926
.add_type_rel("Take", TakeRel)
.set_attr<FTVMCompute>("FTVMCompute", TakeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

Siva committed
927

928
// Init ops
929
TVM_REGISTER_NODE_TYPE(InitOpAttrs);
930 931 932 933 934 935

bool FullRel(const Array<Type>& types,
             int num_inputs,
             const Attrs& attrs,
             const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
936
  const InitOpAttrs* param = attrs.as<InitOpAttrs>();
937 938 939 940 941 942 943 944 945 946 947 948 949 950
  const auto* fill_value = types[0].as<TensorTypeNode>();
  if (fill_value == nullptr) {
    return false;
  }

  DataType out_dtype = param->dtype;
  if (out_dtype.bits() == 0) {
    out_dtype = fill_value->dtype;
  }

  CHECK_EQ(fill_value->shape.size(), 0)
    << "Fill value should be a scalar but has dimension "
    << fill_value->shape.size() << ".";

951
  reporter->Assign(types[1], TensorType(param->shape, out_dtype));
952 953 954
  return true;
}

955
Array<te::Tensor> FullCompute(const Attrs& attrs,
                              const Array<te::Tensor>& inputs,
                              const Type& out_type) {
  const auto* out_ttype = out_type.as<TensorTypeNode>();
  // Guard against an unresolved output type instead of dereferencing null.
  CHECK(out_ttype != nullptr);
  return { topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]()) };
}

962 963 964
Expr MakeFull(Expr fill_value,
              Array<IndexExpr> shape,
              DataType dtype) {
965
  auto attrs = make_object<InitOpAttrs>();
966 967 968
  attrs->shape = std::move(shape);
  attrs->dtype = std::move(dtype);
  static const Op& op = Op::Get("full");
969
  return Call(op, {fill_value}, Attrs(attrs), {});
970 971
}

972
TVM_REGISTER_GLOBAL("relay.op._make.full")
973
.set_body_typed(MakeFull);
974 975 976 977 978

RELAY_REGISTER_OP("full")
.describe(R"code(Fill array with scalar value.

)code" TVM_ADD_FILELINE)
979
.set_attrs_type<InitOpAttrs>()
980 981 982
.set_num_inputs(1)
.add_argument("fill_value", "double", "The value to fill.")
.set_support_level(3)
983 984 985
.add_type_rel("Full", FullRel)
.set_attr<FTVMCompute>("FTVMCompute", FullCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise);
986

987 988 989 990 991 992 993
bool InitOpRel(const Array<Type>& types,
               int num_inputs,
               const Attrs& attrs,
               const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 1);
  const InitOpAttrs* param = attrs.as<InitOpAttrs>();

994
  reporter->Assign(types[0], TensorType(param->shape, param->dtype));
995 996 997 998 999
  return true;
}

Expr MakeZeros(Array<IndexExpr> shape,
               DataType dtype) {
1000
  auto attrs = make_object<InitOpAttrs>();
1001 1002 1003
  attrs->shape = std::move(shape);
  attrs->dtype = std::move(dtype);
  static const Op& op = Op::Get("zeros");
1004
  return Call(op, {}, Attrs(attrs), {});
1005 1006
}

1007
TVM_REGISTER_GLOBAL("relay.op._make.zeros")
1008
.set_body_typed(MakeZeros);
1009 1010 1011 1012 1013

RELAY_REGISTER_OP("zeros")
.describe(R"code(Fill array with zeros.

)code" TVM_ADD_FILELINE)
1014
.set_attrs_type<InitOpAttrs>()
1015 1016 1017 1018 1019 1020
.set_num_inputs(0)
.set_support_level(3)
.add_type_rel("InitOp", InitOpRel);

Expr MakeOnes(Array<IndexExpr> shape,
              DataType dtype) {
1021
  auto attrs = make_object<InitOpAttrs>();
1022 1023 1024
  attrs->shape = std::move(shape);
  attrs->dtype = std::move(dtype);
  static const Op& op = Op::Get("ones");
1025
  return Call(op, {}, Attrs(attrs), {});
1026 1027
}

1028
TVM_REGISTER_GLOBAL("relay.op._make.ones")
1029
.set_body_typed(MakeOnes);
1030 1031 1032 1033 1034

RELAY_REGISTER_OP("ones")
.describe(R"code(Fill array with ones.

)code" TVM_ADD_FILELINE)
1035
.set_attrs_type<InitOpAttrs>()
1036 1037 1038 1039
.set_num_inputs(0)
.set_support_level(3)
.add_type_rel("InitOp", InitOpRel);

1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057
bool FullLikeRel(const Array<Type>& types,
                 int num_inputs,
                 const Attrs& attrs,
                 const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    return false;
  }
  const auto* fill_value = types[1].as<TensorTypeNode>();
  if (fill_value == nullptr) {
    return false;
  }

  CHECK_EQ(fill_value->shape.size(), 0)
    << "The fill value should be a scalar but here it has dimension "
    << fill_value->shape.size() << ".";

1058
  reporter->Assign(types[2], TensorType(data->shape, data->dtype));
1059 1060 1061
  return true;
}

1062
Array<te::Tensor> FullLikeCompute(const Attrs& attrs,
                                  const Array<te::Tensor>& inputs,
                                  const Type& out_type) {
  // inputs[0] supplies shape/dtype, inputs[1] the scalar fill value.
  return { topi::full_like(inputs[0], inputs[1]()) };
}

1068 1069 1070
Expr MakeFullLike(Expr data,
                  Expr fill_value) {
  static const Op& op = Op::Get("full_like");
1071
  return Call(op, {data, fill_value}, Attrs(), {});
1072 1073
}

1074
TVM_REGISTER_GLOBAL("relay.op._make.full_like")
1075
.set_body_typed(MakeFullLike);
1076 1077 1078 1079 1080 1081 1082 1083 1084 1085

RELAY_REGISTER_OP("full_like")
.describe(R"code(Return an scalar value array with the same shape
and type as the input array.

)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("fill_value", "double", "Scalar value to fill.")
.set_support_level(3)
1086 1087 1088
.add_type_rel("FullLike", FullLikeRel)
.set_attr<FTVMCompute>("FTVMCompute", FullLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise);
1089

1090 1091 1092
// arange operator
TVM_REGISTER_NODE_TYPE(ArangeAttrs);

1093
double ToScalar(const runtime::NDArray& array) {
1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124
  if (array->dtype.code == kDLInt) {
    if (array->dtype.bits == 8) {
      return reinterpret_cast<int8_t*>(array->data)[0];
    } else if (array->dtype.bits == 16) {
      return reinterpret_cast<int16_t*>(array->data)[0];
    } else if (array->dtype.bits == 32) {
      return reinterpret_cast<int32_t*>(array->data)[0];
    } else if (array->dtype.bits == 64) {
      return reinterpret_cast<int64_t*>(array->data)[0];
    }
  } else if (array->dtype.code == kDLUInt) {
    if (array->dtype.bits == 8) {
      return reinterpret_cast<uint8_t*>(array->data)[0];
    } else if (array->dtype.bits == 16) {
      return reinterpret_cast<uint16_t*>(array->data)[0];
    } else if (array->dtype.bits == 32) {
      return reinterpret_cast<uint32_t*>(array->data)[0];
    } else if (array->dtype.bits == 64) {
      return reinterpret_cast<uint64_t*>(array->data)[0];
    }
  } else if (array->dtype.code == kDLFloat) {
#if (__ARM_FP16_FORMAT_IEEE == 1)
    if (array->dtype.bits == 16) {
      return reinterpret_cast<__fp16*>(array->data)[0];
    }
#endif
    if (array->dtype.bits == 32) {
      return reinterpret_cast<float*>(array->data)[0];
    } else if (array->dtype.bits == 64) {
      return reinterpret_cast<double*>(array->data)[0];
    }
1125
  }
1126
  LOG(FATAL) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
1127 1128
  // make compiler happy
  return -std::numeric_limits<double>::infinity();
1129 1130
}

1131 1132
bool ArangeRel(const Array<Type>& types,
               int num_inputs,
1133
               const Attrs& raw_attrs,
1134
               const TypeReporter& reporter) {
1135 1136 1137 1138 1139 1140
  CHECK_EQ(types.size(), 4);
  const ArangeAttrs* attrs = raw_attrs.as<ArangeAttrs>();
  const ConstantNode *cstart, *cstop, *cstep;

  reporter->Assign(types[0], types[1]);
  reporter->Assign(types[1], types[2]);
1141
  reporter->Assign(types[2], TensorType({}, attrs->dtype));
1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152

  if ((cstart = attrs->start.as<ConstantNode>()) &&
      (cstop = attrs->stop.as<ConstantNode>()) &&
      (cstep = attrs->step.as<ConstantNode>())) {
    double start = ToScalar(cstart->data);
    double stop = ToScalar(cstop->data);
    double step = ToScalar(cstep->data);
    int32_t num_elem = static_cast<int32_t>(std::ceil((stop - start) / step));
    CHECK_GT(num_elem, 0)
        << "Invalid arange attributes (start, stop, step): " << attrs->start
        << ", " << attrs->stop << ", " << attrs->step;
1153
    reporter->Assign(types[3], TensorType({num_elem}, attrs->dtype));
1154 1155
    return true;
  } else {
1156
    reporter->Assign(types[3], TensorType({Any::make()}, attrs->dtype));
1157
    return true;
1158
  }
1159 1160
}

1161 1162 1163
// Build a 1-D tensor [start, start + step, start + 2*step, ...] whose length
// is a fresh symbolic variable; the concrete extent is resolved at runtime.
inline te::Tensor DynamicArange(const te::Tensor& start,
                                 const te::Tensor& stop,
                                 const te::Tensor& step,
                                 tvm::DataType dtype,
                                 std::string name = "tensor",
                                 std::string tag = topi::kInjective) {
  tvm::PrimExpr num_elem = tvm::tir::Var("num_elem");
  return te::compute(
      {num_elem},
      [&](const Array<tvm::tir::Var>& idx) {
        return tvm::cast(dtype, start[0] + step[0] * idx[0]);
      },
      name, tag);
}

1173
Array<te::Tensor> ArangeCompute(const Attrs& attrs,
                                const Array<te::Tensor>& inputs,
                                const Type& out_type) {
  const ArangeAttrs* param = attrs.as<ArangeAttrs>();
  // inputs = [start, stop, step], each a scalar tensor.
  const te::Tensor& start = inputs[0];
  const te::Tensor& stop = inputs[1];
  const te::Tensor& step = inputs[2];
  return { DynamicArange(start, stop, step, param->dtype) };
}

1183 1184 1185
Expr MakeArange(Expr start,
                Expr stop,
                Expr step,
1186
                DataType dtype) {
1187
  auto attrs = make_object<ArangeAttrs>();
1188 1189 1190 1191
  attrs->start = start;
  attrs->stop = stop;
  attrs->step = step;
  attrs->dtype = dtype;
1192
  static const Op& op = Op::Get("arange");
1193
  return Call(op, {start, stop, step}, Attrs(attrs), {});
1194 1195
}

1196
TVM_REGISTER_GLOBAL("relay.op._make.arange")
1197
.set_body_typed(MakeArange);
1198

1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211
// An issue with the existing design is that we require dependency
// to type the operator precisely.
//
// Supporting this in general is challenging so we duplicate the
// secondary arguments as args and attributes.
//
// In this way reify the arguments at both the value and type level.
//
// In the case our arguments are constant we can immediately recover
// the type of arange.
//
// In general I think we should avoid this pattern, and introduce
// a secondary shape analysis to recover more precise information.
1212 1213 1214 1215
RELAY_REGISTER_OP("arange")
.describe(R"code(Returns evenly spaced values within a given interval.

)code" TVM_ADD_FILELINE)
.set_attrs_type<ArangeAttrs>()
.set_num_inputs(3)
.set_support_level(3)
.add_type_rel("Arange", ArangeRel)
.set_attr<FTVMCompute>("FTVMCompute", ArangeCompute)
// TODO(@icemelon): Change arange to kOpaque because FuseOps doesn't consider dynamic shape
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<AnyCodegenStrategy>("AnyCodegenStrategy", kVariableDimensions);
1224

1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261
// repeat operator
TVM_REGISTER_NODE_TYPE(RepeatAttrs);

bool RepeatRel(const Array<Type>& types,
               int num_inputs,
               const Attrs& attrs,
               const TypeReporter& reporter) {
  // `types` contains: [data, result]
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    CHECK(types[0].as<IncompleteTypeNode>())
        << "repeat: expect input type to be TensorType but get "
        << types[0];
    return false;
  }
  const auto* param = attrs.as<RepeatAttrs>();
  const int ndim = static_cast<int>(data->shape.size());
  const int repeats = param->repeats;
  const int axis = param->axis;
  CHECK(repeats >= 1)
    << "repeat only accepts `repeats >= 1`"
    << ", but got repeats = " << repeats;
  CHECK(-ndim - 1 <= axis && axis <= ndim)
    << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
    << ", but got axis = " << axis
    << ", and data.ndim = " << ndim;
  const int pivot = axis < 0 ? ndim + axis : axis;
  std::vector<IndexExpr> oshape;
  oshape.reserve(ndim + repeats);
  for (int i = 0; i < pivot; ++i) {
    oshape.emplace_back(data->shape[i]);
  }
  oshape.emplace_back(data->shape[pivot] * repeats);
  for (int i = pivot + 1; i < ndim; ++i) {
    oshape.emplace_back(data->shape[i]);
  }
1262
  reporter->Assign(types[1], TensorType(oshape, data->dtype));
1263 1264 1265
  return true;
}

1266
Array<te::Tensor> RepeatCompute(const Attrs& attrs,
                                const Array<te::Tensor>& inputs,
                                const Type& out_type) {
  // Forward directly to the topi implementation.
  const RepeatAttrs* param = attrs.as<RepeatAttrs>();
  CHECK(param != nullptr);
  return { topi::repeat(inputs[0], param->repeats, param->axis) };
}

Expr MakeRepeat(Expr data,
1275 1276
                int repeats,
                int axis) {
1277
  auto attrs = make_object<RepeatAttrs>();
1278 1279 1280
  attrs->repeats = repeats;
  attrs->axis = axis;
  static const Op& op = Op::Get("repeat");
1281
  return Call(op, {data}, Attrs(attrs), {});
1282 1283
}

1284
TVM_REGISTER_GLOBAL("relay.op._make.repeat")
1285
.set_body_typed(MakeRepeat);
1286 1287 1288 1289 1290 1291 1292 1293

RELAY_REGISTER_OP("repeat")
.describe(R"code(Repeat elements of an array `repeats` times along axis `axis`

- **data**: The input data to the operator.

)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
1294
.set_attrs_type<RepeatAttrs>()
1295
.add_argument("data", "Tensor", "The input tensor.")
1296
.set_support_level(3)
1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320
.add_type_rel("Repeat", RepeatRel)
.set_attr<FTVMCompute>("FTVMCompute", RepeatCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);

// tile operator
TVM_REGISTER_NODE_TYPE(TileAttrs);

bool TileRel(const Array<Type>& types,
             int num_inputs,
             const Attrs& attrs,
             const TypeReporter& reporter) {
  // `types` contains: [data, result]
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    CHECK(types[0].as<IncompleteTypeNode>())
        << "tile: expect input type to be TensorType but get "
        << types[0];
    return false;
  }
  const auto* param = attrs.as<TileAttrs>();
  const size_t ndim = data->shape.size();
  const Array<Integer>& reps = param->reps;
  // check dimension match
1321
  CHECK(reps.defined())
1322 1323
    << "repetition array is not defined. data.ndim = " << ndim;
  const size_t rndim = reps.size();
1324
  for (size_t i = 0; i < rndim; ++i) {
1325
    if (const tvm::tir::IntImmNode* val = reps[i].as<tvm::tir::IntImmNode>()) {
1326 1327 1328 1329
      CHECK_GT(val->value, 0)
          << "Tile reps value should always be larger than 0, but get: " << val->value;
    }
  }
1330 1331 1332 1333 1334 1335 1336 1337
  size_t tndim = (ndim > rndim) ? ndim : rndim;
  // re-construct data shape or reps shape
  std::vector<IndexExpr> data_shape;
  std::vector<IndexExpr> reps_shape;
  data_shape.reserve(tndim);
  reps_shape.reserve(tndim);
  if (ndim == rndim) {
    for (size_t i = 0; i < tndim; ++i) {
1338 1339
      data_shape.emplace_back(data->shape[i]);
      reps_shape.emplace_back(reps[i]);
1340 1341
    }
  } else if (ndim > rndim) {
1342 1343 1344 1345 1346 1347 1348 1349 1350
    for (size_t i = 0; i < ndim; ++i) {
      data_shape.emplace_back(data->shape[i]);
    }
    for (size_t i = 0; i < (ndim - rndim); ++i) {
      reps_shape.emplace_back(1);
    }
    for (size_t i = 0; i < rndim; ++i) {
      reps_shape.emplace_back(reps[i]);
    }
1351
  } else {
1352 1353 1354 1355 1356 1357 1358 1359 1360
    for (size_t i = 0; i < rndim; ++i) {
      reps_shape.emplace_back(reps[i]);
    }
    for (size_t i = 0; i < (rndim - ndim); ++i) {
      data_shape.emplace_back(1);
    }
    for (size_t i = 0; i < ndim; ++i) {
      data_shape.emplace_back(data->shape[i]);
    }
1361 1362 1363 1364
  }
  std::vector<IndexExpr> oshape;
  oshape.reserve(tndim);
  for (size_t i = 0; i < tndim; ++i) {
1365
    // Save Any if it is dynamic shape
1366
    if (!data_shape[i].as<IntImmNode>()) {
1367 1368 1369 1370
      oshape.emplace_back(Any::make());
    } else {
      oshape.emplace_back(data_shape[i] * reps_shape[i]);
    }
1371
  }
1372
  reporter->Assign(types[1], TensorType(oshape, data->dtype));
1373 1374 1375
  return true;
}

1376
Array<te::Tensor> TileCompute(const Attrs& attrs,
                              const Array<te::Tensor>& inputs,
                              const Type& out_type) {
  // Forward directly to the topi implementation.
  const TileAttrs* param = attrs.as<TileAttrs>();
  CHECK(param != nullptr);
  return { topi::tile(inputs[0], param->reps) };
}

Expr MakeTile(Expr data,
              Array<Integer> reps) {
1386
  auto attrs = make_object<TileAttrs>();
1387 1388
  attrs->reps = reps;
  static const Op& op = Op::Get("tile");
1389
  return Call(op, {data}, Attrs(attrs), {});
1390 1391
}

1392
TVM_REGISTER_GLOBAL("relay.op._make.tile")
1393
.set_body_typed(MakeTile);
1394 1395 1396 1397 1398 1399 1400 1401

RELAY_REGISTER_OP("tile")
.describe(R"code(Repeat the whole array multiple times.

- **data**: The input data to the operator.

)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
1402
.set_attrs_type<TileAttrs>()
1403
.add_argument("data", "Tensor", "The input tensor.")
1404
.set_support_level(3)
1405 1406 1407 1408
.add_type_rel("Tile", TileRel)
.set_attr<FTVMCompute>("FTVMCompute", TileCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);

1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435
// reverse operator
TVM_REGISTER_NODE_TYPE(ReverseAttrs);

// Type relation for reverse: the output type equals the input type; only
// the axis attribute needs validation.
bool ReverseRel(const Array<Type>& types,
               int num_inputs,
               const Attrs& attrs,
               const TypeReporter& reporter) {
  // `types` contains: [data, result]
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    CHECK(types[0].as<IncompleteTypeNode>())
        << "reverse: expect input type to be TensorType but get "
        << types[0];
    return false;
  }
  const auto* param = attrs.as<ReverseAttrs>();
  CHECK(param != nullptr);  // previously dereferenced without a null check
  const int ndim = static_cast<int>(data->shape.size());
  const int axis = param->axis;
  CHECK(-ndim <= axis && axis < ndim)
    << "reverse only accepts `axis` in [-data.ndim, data.ndim - 1]"
    << ", but got axis = " << axis
    << ", and data.ndim = " << ndim;
  reporter->Assign(types[1], types[0]);
  return true;
}

1436
Array<te::Tensor> ReverseCompute(const Attrs& attrs,
                                 const Array<te::Tensor>& inputs,
                                 const Type& out_type) {
  // reverse lowers to a flip along the requested axis.
  const ReverseAttrs* param = attrs.as<ReverseAttrs>();
  CHECK(param != nullptr);
  return { topi::flip(inputs[0], param->axis) };
}

Expr MakeReverse(Expr data,
                 int axis) {
1446
  auto attrs = make_object<ReverseAttrs>();
1447 1448
  attrs->axis = axis;
  static const Op& op = Op::Get("reverse");
1449
  return Call(op, {data}, Attrs(attrs), {});
1450 1451
}

1452
TVM_REGISTER_GLOBAL("relay.op._make.reverse")
1453
.set_body_typed(MakeReverse);
1454 1455 1456 1457 1458 1459 1460 1461

RELAY_REGISTER_OP("reverse")
.describe(R"code(Reverses the order of elements along given `axis` while preserving array shape.

- **data**: The input data to the operator.

)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
1462
.set_attrs_type<ReverseAttrs>()
1463 1464 1465 1466 1467 1468
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Reverse", ReverseRel)
.set_attr<FTVMCompute>("FTVMCompute", ReverseCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

Zhi committed
1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493
// where operator
bool WhereRel(const Array<Type>& types,
              int num_inputs,
              const Attrs& attrs,
              const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 4U);
  const auto* condition = types[0].as<TensorTypeNode>();
  const auto* x = types[1].as<TensorTypeNode>();
  const auto* y = types[2].as<TensorTypeNode>();
  CHECK(condition != nullptr && x != nullptr && y != nullptr);

  const auto& cond_shape = condition->shape;
  const auto& x_shape = x->shape;
  const auto& y_shape = y->shape;
  CHECK(x_shape.size() == y_shape.size()) << "x and y must have the same size";

  if (cond_shape.size() != x_shape.size()) {
    CHECK_EQ(cond_shape.size(), 1)
        << "Shape of condition " << condition->shape
        << " must be either equal to x or has dimension of 1.";
  }
  for (size_t i = 0; i < x_shape.size(); i++) {
    CHECK(reporter->AssertEQ(x_shape[i], y_shape[i]))
        << "x and y must have the same shape: " << x_shape << " vs " << y_shape;

1494 1495 1496 1497
    if (i < cond_shape.size()) {
        CHECK(reporter->AssertEQ(cond_shape[i], x_shape[i]))
        << "condition and x must have the same shape: " << cond_shape << " vs " << x_shape;
    }
Zhi committed
1498
  }
1499
  reporter->Assign(types[3], TensorType(x_shape, x->dtype));
Zhi committed
1500 1501 1502 1503 1504 1505
  return true;
}

// Positional relay function to create where operator.
Expr MakeWhere(const Expr& condition, const Expr& x, const Expr& y) {
  static const Op& op = Op::Get("where");
  return Call(op, {condition, x, y});
}

1509
Array<te::Tensor> WhereCompute(const Attrs& attrs,
                               const Array<te::Tensor>& inputs,
                               const Type& out_type) {
  // inputs = [condition, x, y]
  return { topi::where(inputs[0], inputs[1], inputs[2]) };
}

1515
TVM_REGISTER_GLOBAL("relay.op._make.where")
.set_body_typed(MakeWhere);

RELAY_REGISTER_OP("where")
.describe(R"code(
Return the elements, either from x or y, depending on the condition.

Given three ndarrays, condition, x, and y, return an ndarray with the elements
from x or y, depending on the elements from condition are true or false.
x and y must have the same shape. If condition has the same shape as x,
each element in the output array is from x if the corresponding element
in the condition is true, and from y if false.

If condition does not have the same shape as x, it must be a 1D array whose
size is the same as x’s first dimension size. Each row of the output array
is from x’s row if the corresponding element from condition is true, and
from y’s row if false.

Note that all non-zero values are interpreted as True in condition.

Examples::

  x = [[1, 2], [3, 4]]
  y = [[5, 6], [7, 8]]
  cond = [[0, 1], [-1, 0]]
  where(cond, x, y) = [[5, 2], [3, 8]]


  cond = [1, 0]
  where(cond, x, y) = [[1, 2], [7, 8]]

)code" TVM_ADD_FILELINE)
.set_num_inputs(3)
.add_argument("condition", "Tensor", "Condition array")
.add_argument("x", "Tensor", "First array to be selected")
.add_argument("y", "Tensor", "Second array to be selected")
.set_support_level(4)
.add_type_rel("Where", WhereRel)
.set_attr<FTVMCompute>("FTVMCompute", WhereCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);
Zhi committed
1555

1556 1557 1558 1559

// Squeeze
TVM_REGISTER_NODE_TYPE(SqueezeAttrs);

1560
Expr MakeSqueeze(Expr data,
1561
                 Array<Integer> axis) {
1562
  auto attrs = make_object<SqueezeAttrs>();
1563
  attrs->axis = std::move(axis);
1564
  static const Op& op = Op::Get("squeeze");
1565
  return Call(op, {data}, Attrs(attrs), {});
1566 1567
}

1568
TVM_REGISTER_GLOBAL("relay.op._make.squeeze")
1569
.set_body_typed(MakeSqueeze);
1570

1571

1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583
// Type relation for squeeze.  `types` contains: [data, result].
// With no axis attribute every size-1 dimension is dropped (requires a fully
// static shape); with an explicit axis list only the named dimensions are
// dropped, and each must be provably of extent 1.
bool SqueezeRel(const Array<Type>& types,
                int num_inputs,
                const Attrs& attrs,
                const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    return false;
  }
  const auto* param = attrs.as<SqueezeAttrs>();
  CHECK(param != nullptr);
  std::vector<IndexExpr> result_shape;
  // if axes is None, squeeze all axes of dimension 1
  if (!param->axis.defined()) {
    for (const auto& e : data->shape) {
      if (!e.as<IntImmNode>()) {
        // A dynamic extent cannot be tested against 1 here.
        LOG(FATAL) << "axis needs to be defined for dynamic input.";
      }
      const int64_t* axis_ptr = tir::as_const_int(e);
      CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete";
      if (*axis_ptr != 1) {
        result_shape.push_back(e);
      }
    }
  } else {
    // pair up original shape with a boolean which controls whether it will be in the final shape.
    std::vector<std::pair<IndexExpr, bool> > original_shape;
    for (const auto& e : data->shape) {
      original_shape.push_back(std::pair<IndexExpr, bool>(e, true));
    }
    // Mark each requested axis (negative axes count from the end) for removal.
    for (const auto& e : param->axis) {
      int64_t axis_val = e->value;
      if (axis_val < 0) {
        axis_val += static_cast<int64_t>(original_shape.size());
      }
      CHECK_GE(axis_val, 0);
      CHECK_LT(axis_val, original_shape.size());
      original_shape.at(axis_val).second = false;
    }
    // Keep unmarked dimensions; marked ones must be statically 1.
    for (const auto& p : original_shape) {
      if (p.second) {
        result_shape.push_back(p.first);
      } else {
        const int64_t* axis_ptr = tir::as_const_int(p.first);
        CHECK(axis_ptr != nullptr) << "cannot get concrete shape of input tensor";
        CHECK_EQ(*axis_ptr, 1) << "cannot squeeze axis with dimension not equal to 1";
      }
    }
  }
  reporter->Assign(types[1], TensorType(result_shape, data->dtype));
  return true;
}

1625
Array<te::Tensor> SqueezeCompute(const Attrs& attrs,
                                 const Array<te::Tensor>& inputs,
                                 const Type& out_type) {
  // Forward directly to the topi implementation.
  const SqueezeAttrs* param = attrs.as<SqueezeAttrs>();
  CHECK(param != nullptr);
  return { topi::squeeze(inputs[0], param->axis) };
}


1634 1635 1636 1637 1638 1639 1640
RELAY_REGISTER_OP("squeeze")
.describe(R"code(Squeeze the input tensor at the dimensions given by axes

- **data**: The input data to the operator.

)code" TVM_ADD_FILELINE)
.set_attrs_type<SqueezeAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Squeeze", SqueezeRel)
.set_attr<FTVMCompute>("FTVMCompute", SqueezeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

1648

1649 1650 1651 1652 1653 1654 1655
// CollapseSumLike: <A, B> -> B where BroadCast(A, B) = A
bool CollapseSumLikeRel(const Array<Type>& types,
                        int num_inputs,
                        const Attrs& attrs,
                        const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  reporter->Assign(types[2], types[1]);
1656
  return BroadcastRel({types[0], types[1], types[0]}, 2, Attrs(), reporter);
1657 1658 1659 1660 1661
}

// Build a collapse_sum_like call node; the op carries no attributes.
Expr MakeCollapseSumLike(Expr data,
                         Expr collapse_type) {
  static const Op& op = Op::Get("collapse_sum_like");
  Array<Expr> args = {data, collapse_type};
  return Call(op, args, Attrs(), {});
}

1665
// FTVMCompute for collapse_sum_like: the target shape comes from the
// checked output type rather than from attributes.
Array<te::Tensor> CollapseSumLikeCompute(const Attrs& attrs,
                                         const Array<te::Tensor>& inputs,
                                         const Type& out_type) {
  const auto* target_type = out_type.as<TensorTypeNode>();
  CHECK(target_type != nullptr);
  return Array<te::Tensor>{ topi::collapse_sum(inputs[0], target_type->shape) };
}

1673
// FFI hook: relay.op._make.collapse_sum_like -> MakeCollapseSumLike.
TVM_REGISTER_GLOBAL("relay.op._make.collapse_sum_like")
.set_body_typed(MakeCollapseSumLike);

RELAY_REGISTER_OP("collapse_sum_like")
.describe(R"code(Collapse the first input to match the shape of the second input.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("collapse_type", "Tensor", "Provide the type to collapse to.")
.set_support_level(10)
.add_type_rel("CollapseSumLike", CollapseSumLikeRel)
.set_attr<FTVMCompute>("FTVMCompute", CollapseSumLikeCompute)
// Collapsing sums over broadcast axes, hence the reduce pattern.
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
1686

1687 1688 1689 1690 1691 1692 1693 1694 1695 1696
// BroadCastTo: <A, B> -> B where BroadCast(A, B) = B
bool BroadCastToRel(const Array<Type>& types,
                    int num_inputs,
                    const Attrs& attrs,
                    const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  auto ioattrs = attrs.as<InitOpAttrs>();
  CHECK(ioattrs);
  auto intt = types[0].as<TensorTypeNode>();
  if (intt == nullptr) { return false; }
1697
  auto type = TensorType(ioattrs->shape, intt->dtype);
1698
  reporter->Assign(types[1], type);
1699
  return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter);
1700 1701 1702 1703
}

// Build a broadcast_to call with the target shape stored in InitOpAttrs.
Expr MakeBroadCastTo(Expr data, Array<IndexExpr> shape) {
  static const Op& op = Op::Get("broadcast_to");
  auto init_attrs = make_object<InitOpAttrs>();
  init_attrs->shape = std::move(shape);
  return Call(op, {data}, Attrs(init_attrs), {});
}

1709
// FTVMCompute for broadcast_to: target shape comes from InitOpAttrs.
Array<te::Tensor> BroadCastToCompute(const Attrs& attrs,
                                     const Array<te::Tensor>& inputs,
                                     const Type& out_type) {
  const auto* init_attrs = attrs.as<InitOpAttrs>();
  CHECK(init_attrs != nullptr);
  return Array<te::Tensor>{ topi::broadcast_to(inputs[0], init_attrs->shape) };
}

1717
// FFI hook: relay.op._make.broadcast_to -> MakeBroadCastTo.
TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to")
.set_body_typed(MakeBroadCastTo);

RELAY_REGISTER_OP("broadcast_to")
.describe(R"code(Broadcast the first input to match the shape argument.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(4)
.add_type_rel("BroadCastTo", BroadCastToRel)
.set_attr<FTVMCompute>("FTVMCompute", BroadCastToCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);

1730 1731 1732 1733 1734 1735 1736
// BroadCastToLike: <A, B> -> B where BroadCast(A, B) = B
bool BroadCastToLikeRel(const Array<Type>& types,
                        int num_inputs,
                        const Attrs& attrs,
                        const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  reporter->Assign(types[2], types[1]);
1737
  return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter);
1738 1739 1740 1741 1742
}

// Build a broadcast_to_like call node; the op carries no attributes.
Expr MakeBroadCastToLike(Expr data,
                         Expr broadcast_type) {
  static const Op& op = Op::Get("broadcast_to_like");
  Array<Expr> args = {data, broadcast_type};
  return Call(op, args, Attrs(), {});
}

1746
// FTVMCompute for broadcast_to_like: the target shape is read from the
// checked output type (derived from the second input by the type relation).
Array<te::Tensor> BroadCastToLikeCompute(const Attrs& attrs,
                                         const Array<te::Tensor>& inputs,
                                         const Type& out_type) {
  const auto* target_type = out_type.as<TensorTypeNode>();
  CHECK(target_type != nullptr);
  return Array<te::Tensor>{ topi::broadcast_to(inputs[0], target_type->shape) };
}

1754
// FFI hook: relay.op._make.broadcast_to_like -> MakeBroadCastToLike.
TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to_like")
.set_body_typed(MakeBroadCastToLike);

RELAY_REGISTER_OP("broadcast_to_like")
.describe(R"code(Broadcast the first input to match the shape of the second input.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("broadcast_type", "Tensor", "Provide the type to broadcast to.")
.set_support_level(10)
.add_type_rel("BroadCastToLike", BroadCastToLikeRel)
.set_attr<FTVMCompute>("FTVMCompute", BroadCastToLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);
1767

1768

1769 1770 1771
// Adapter function to make int array.
// Checks that every defined element is a compile-time integer, then
// downcasts the whole array to Array<Integer>.
Array<Integer> GetIntArray(Array<IndexExpr> arr) {
  for (const auto& e : arr) {
    CHECK(!e.defined() || e.as<IntImmNode>()) << "Expect an int array";
  }
  return Downcast<Array<Integer> >(arr);
}


1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849
// strided_slice
TVM_REGISTER_NODE_TYPE(StridedSliceAttrs);
bool StridedSliceRel(const Array<Type>& types,
                     int num_inputs,
                     const Attrs& attrs,
                     const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  const StridedSliceAttrs *param = attrs.as<StridedSliceAttrs>();
  CHECK(param != nullptr);

  auto dshape = data->shape;
  auto num_axis = dshape.size();

  std::vector<int64_t> stride_vec;
  for (Integer i : param->strides) {
    CHECK(i.defined());
    stride_vec.push_back(i->value);
  }
  for (size_t i = stride_vec.size(); i < num_axis; ++i) {
    stride_vec.push_back(1);
  }
  const int64_t max_range = std::numeric_limits<int64_t>::max();

  std::vector<int64_t> begin_vec;
  for (size_t i = 0; i < param->begin.size(); ++i) {
    if (!param->begin[i].defined()) {
      // value=None
      begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
    } else {
      begin_vec.push_back(param->begin[i]->value);
    }
  }
  for (size_t i = begin_vec.size(); i < num_axis; ++i) {
    begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
  }

  std::vector<int64_t> end_vec;
  for (size_t i = 0; i < param->end.size(); ++i) {
    // allow end to be None
    if (!param->end[i].defined()) {
      end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
    } else {
      end_vec.push_back(param->end[i]->value);
    }
  }
  for (size_t i = end_vec.size(); i < num_axis; ++i) {
    end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
  }

  std::vector<IndexExpr> oshape(dshape.size());
  for (size_t i = 0; i < num_axis; ++i) {
    int64_t stride_v = stride_vec[i];
    int64_t begin_v = begin_vec[i];
    int64_t end_v = end_vec[i];

    if ((stride_v == 1 &&
         begin_v == 0 &&
         end_v == max_range) ||
        (stride_v == -1 &&
         begin_v == max_range &&
         end_v == 0)) {
      // Quick path, do not slice this dimension.
      oshape[i] = dshape[i];
      continue;
    }
    // Normal path, require the shape to be concrete integer.
    // Require concrete integer as symbolic inference of min/max
    // can get complicated and not very helpful.
1850
    const int64_t* p_dim_size = tir::as_const_int(dshape[i]);
1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873
    CHECK(p_dim_size)
        << "strided_slice requires sliced dimension to be concrete int";
    int64_t dim_size = p_dim_size[0];
    begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v;
    end_v = (end_v < 0) ? dim_size + end_v : end_v;

    int64_t slice_range, step;
    if (stride_v < 0) {
      if (end_v < -1) end_v = -1;
      CHECK_LT(end_v, begin_v)
          << "strided_slice get empty slice at axis " << i;
      begin_v = std::min(dim_size - 1, begin_v);
      slice_range = begin_v - end_v;
      step = -stride_v;
    } else {
      if (begin_v < 0) begin_v = 0;
      CHECK_GE(stride_v, 0);
      CHECK_LT(begin_v, end_v)
          << "strided_slice get empty slice at axis " << i;
      end_v = std::min(dim_size, end_v);
      slice_range = end_v - begin_v;
      step = stride_v;
    }
1874
    oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step);
1875
  }
1876
  reporter->Assign(types[1], TensorType(oshape, data->dtype));
1877 1878 1879 1880
  return true;
}


1881 1882 1883 1884
// Layout inference for strided_slice: when the input is converted to a new
// layout, rescale begin/end by the split factor of each axis so the slice
// still selects the same elements. Falls back to Layout::Undef() (forcing a
// transform back to the original layout) whenever the rescaling is unsafe.
Array<Array<Layout> > StridedSliceInferCorrectLayout(
    const Attrs& attrs,
    const Array<Layout>& new_in_layouts,
    const Array<Layout>& old_in_layouts,
    const Array<tvm::relay::Type>& old_in_types) {

  // Collect the (single) input shape from the old input types.
  Array<Array<IndexExpr>> old_in_shapes;
  for (auto old_in_t : old_in_types) {
    CHECK(old_in_t.as<TensorTypeNode>());
    old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
  }

  CHECK(old_in_layouts.defined());
  CHECK_EQ(old_in_layouts.size(), 1);
  CHECK(old_in_shapes.defined());
  CHECK_EQ(old_in_shapes.size(), 1);

  auto layout = old_in_layouts[0];
  if (layout.defined() && new_in_layouts.defined()) {
    CHECK_EQ(new_in_layouts.size(), 1);
    auto new_layout = new_in_layouts[0];
    auto shape = old_in_shapes[0];

    // NOTE: Discard "const" qualifier here.
    // The attrs are mutated in place so the rewritten op slices in the new
    // layout's (scaled) coordinates.
    auto *params = const_cast<StridedSliceAttrs*>(attrs.as<StridedSliceAttrs>());

    Array<Integer> new_begin, new_end;

    for (size_t i = 0; i < params->begin.size(); i++) {
      const LayoutAxis& axis = layout[i];
      if (!axis.IsPrimal()) {
        // original layout that contains splitted axes is not supported
        return {{Layout::Undef()}, {Layout::Undef()}};
      }
      // factor == -1 means the new layout does not split this axis.
      auto factor = new_layout.FactorOf(axis);
      if (factor == -1) {
        new_begin.push_back(params->begin[i]);
        new_end.push_back(params->end[i]);
      } else {
        if (params->strides.defined() && i < params->strides.size()) {
          auto stride = params->strides[i];
          // arbitrary stride is not supported
          if (stride.defined() && stride->value != 1) {
            return {{Layout::Undef()}, {Layout::Undef()}};
          }
        }
        // Undefined begin/end mean "whole axis": 0 and the axis extent.
        // NOTE(review): shape[i].as<IntImmNode>() is dereferenced without a
        // null check — this assumes the axis extent is a constant; confirm.
        int64_t begin = params->begin[i].defined() ? params->begin[i]->value : 0;
        int64_t end = params->end[i].defined() ? params->end[i]->value :
            shape[i].as<IntImmNode>()->value;
        if (begin % factor || end % factor) {
          // transform to original layout
          return {{Layout::Undef()}, {Layout::Undef()}};
        }
        new_begin.push_back(tvm::Integer(begin / factor));
        new_end.push_back(tvm::Integer(end / factor));
      }
    }
    layout = new_layout;
    params->begin = new_begin;
    params->end = new_end;
  }
  return {{layout}, {layout}};
}


1946 1947 1948 1949 1950
// Positional relay function to create StridedSlice operator used by frontend FFI.
Expr MakeStridedSlice(Expr data,
                      Array<Integer> begin,
                      Array<Integer> end,
                      Array<Integer> strides) {
  static const Op& op = Op::Get("strided_slice");
  auto slice_attrs = make_object<StridedSliceAttrs>();
  slice_attrs->begin = std::move(begin);
  slice_attrs->end = std::move(end);
  slice_attrs->strides = std::move(strides);
  return Call(op, {data}, Attrs(slice_attrs), {});
}

1959
// FTVMCompute for strided_slice: forward the attribute triple to TOPI.
Array<te::Tensor> StridedSliceCompute(const Attrs& attrs,
                                      const Array<te::Tensor>& inputs,
                                      const Type& out_type) {
  const auto* slice_attrs = attrs.as<StridedSliceAttrs>();
  CHECK(slice_attrs != nullptr);
  return Array<te::Tensor>{
    topi::strided_slice(inputs[0], slice_attrs->begin, slice_attrs->end,
                        slice_attrs->strides)
  };
}


1970
// FFI hook: relay.op._make.strided_slice -> MakeStridedSlice.
TVM_REGISTER_GLOBAL("relay.op._make.strided_slice")
.set_body_typed(MakeStridedSlice);


RELAY_REGISTER_OP("strided_slice")
    .describe(R"code(Strided slice of an array.

Examples::

  x = [[  1.,   4.,   7.,  10.],
       [  2.,   5.,   8.,  11.],
       [  3.,   6.,   9.,  12.]]

  strided_slice(x, begin=[0, 1], end=[2, 4], stride=[1, 1]) = [[ 4.,  7.,  10.],
                                                               [ 5.,  8.,  11.]]

  x = [[[ 1.,  2.],
        [ 3.,  4.]],

       [[ 5.,  6.],
        [ 7.,  8.]]]

  strided_slice(x, begin=[0, 0], end=[2, 2]) = [[[ 1.,  2.],
                                                 [ 3.,  4.]],

                                                [[ 5.,  6.],
                                                 [ 7.,  8.]]]
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(4)
.set_attrs_type<StridedSliceAttrs>()
.add_type_rel("StridedSlice", StridedSliceRel)
.set_attr<FTVMCompute>("FTVMCompute", StridedSliceCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective)
// Custom layout inference: rescales begin/end when the input layout changes.
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", StridedSliceInferCorrectLayout);
2006

2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022
// strided_set
bool StridedSetRel(const Array<Type>& types,
                   int num_inputs,
                   const Attrs& attrs,
                   const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 6);
  reporter->Assign(types[5], types[0]);
  return true;
}

// Build a strided_set call; begin/end/strides are runtime tensors here,
// so the op carries no attributes.
Expr MakeStridedSet(Expr data,
                    Expr v,
                    Expr begin,
                    Expr end,
                    Expr strides) {
  static const Op& op = Op::Get("strided_set");
  Array<Expr> call_args = {data, v, begin, end, strides};
  return Call(op, call_args, {});
}

2026
// FFI hook: relay.op._make.strided_set -> MakeStridedSet.
TVM_REGISTER_GLOBAL("relay.op._make.strided_set")
.set_body_typed(MakeStridedSet);


RELAY_REGISTER_OP("strided_set")
  .describe(R"code(Strided set of an array.
Example::

  x = [[  1.,   4.,   7.,  10.],
       [  2.,   5.,   8.,  11.],
       [  3.,   6.,   9.,  12.]]

  v = [[ 11., 22., 33.]
       [ 44., 55., 66.]]

  strided_set(x, v, begin=[0, 1], end=[2, 4], stride=[1, 1]) = \
      [[  1.,  11.,  22.,  33.],
       [  2.,  44.,  55.,  66.],
       [  3.,   6.,   9.,  12.]]
)code" TVM_ADD_FILELINE)
.set_num_inputs(5)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("v", "Tensor", "The data to set.")
.add_argument("begin", "Tensor", "Indices for the start of the slice.")
.add_argument("end", "Tensor", "Indices indicating the end of the slice.")
.add_argument("strides", "Tensor", "The strides values.")
.set_support_level(4)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.add_type_rel("StridedSet", StridedSetRel);
2055

2056
// relay.split
TVM_REGISTER_NODE_TYPE(SplitAttrs);

// Type relation for split. The output is a tuple of tensors: when
// indices_or_sections is an integer the axis is divided into that many equal
// sections; when it is a list of indices, each entry marks the start of the
// next section and the remainder of the axis forms the final field.
bool SplitRel(const Array<Type>& types,
              int num_inputs,
              const Attrs& attrs,
              const TypeReporter& reporter) {
  // `types` contains: [data, result]
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;
  CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";
  const auto param = attrs.as<SplitAttrs>();
  CHECK(param != nullptr);
  // Normalize a negative axis to its positive equivalent.
  auto axis = param->axis;
  if (axis < 0) {
    axis += data->shape.size();
  }
  CHECK_LT(axis, data->shape.size())
    << "axis should be within the input dimension range.";
  CHECK_GE(axis, 0)
    << "axis should be within the input dimension range.";

  if (const IntImmNode* sections = param->indices_or_sections.as<IntImmNode>()) {
    // Integer case: the axis extent must divide evenly into `sections`.
    CHECK(reporter->Assert(indexmod(data->shape[axis],
                                    sections->value) == tir::make_zero(DataType::Int(64))))
        << "indices_or_sections need to be able to divide input.shape[axis]";
    std::vector<Type> fields;
    for (int i = 0; i < sections->value; ++i) {
        std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
        oshape[axis] = indexdiv(oshape[axis], sections->value);
        auto vec_type = TensorType(oshape, data->dtype);
        fields.push_back(vec_type);
    }
    reporter->Assign(types[1], TupleType(Array<Type>(fields)));
  } else {
    // List case: indices must be strictly ascending and within the axis.
    auto indices = param->indices_or_sections.as<ArrayNode>()->data;
    // `begin` tracks the start of the current section.
    auto begin = IndexExpr(tir::make_zero(DataType::Int(32)));
    std::vector<Type> fields;
    for (unsigned int i = 0; i < indices.size(); ++i) {
      CHECK(reporter->Assert(Downcast<IndexExpr>(indices[i]) > begin))
          << "indices_or_sections need to be a sorted ascending list";
      std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
      oshape[axis] = Downcast<IndexExpr>(indices[i]) - begin;
      begin = Downcast<IndexExpr>(indices[i]);
      auto vec_type = TensorType(oshape, data->dtype);
      fields.push_back(vec_type);
    }
    CHECK(reporter->Assert(begin < data->shape[axis]))
        << "The sum of sections must match the input.shape[axis]";
    // The last field covers whatever is left of the axis.
    std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
    oshape[axis] = data->shape[axis] - begin;
    auto vec_type = TensorType(oshape, data->dtype);
    fields.push_back(vec_type);
    reporter->Assign(types[1], TupleType(Array<Type>(fields)));
  }
  return true;
}

2115
// FTVMCompute for split: dispatch to TOPI's equal-sections or
// explicit-indices kernel depending on the attribute variant.
Array<te::Tensor> SplitCompute(const Attrs& attrs,
                               const Array<te::Tensor>& inputs,
                               const Type& out_type) {
  const auto param = attrs.as<SplitAttrs>();
  CHECK(param != nullptr);

  if (const IntImmNode* sections = param->indices_or_sections.as<IntImmNode>()) {
    // Integer attribute: split into `sections` equal parts along the axis.
    int64_t num_sections = sections->value;
    return Array<te::Tensor>{
      topi::split_sections(inputs[0], num_sections, param->axis) };
  }
  // Otherwise the attribute holds an explicit list of split points.
  auto split_indices = Downcast<Array<Integer> >(param->indices_or_sections);
  return Array<te::Tensor>{ topi::split(inputs[0], split_indices, param->axis) };
}

Siva committed
2131
// Build a split call; indices_or_sections is either an IntImm (number of
// equal sections) or an array of split indices.
Expr MakeSplit(Expr data,
               ObjectRef indices_or_sections,
               int axis) {
  static const Op& op = Op::Get("split");
  auto split_attrs = make_object<SplitAttrs>();
  split_attrs->axis = axis;
  split_attrs->indices_or_sections = std::move(indices_or_sections);
  return Call(op, {data}, Attrs(split_attrs), {});
}

2141
// FFI hook: relay.op._make.split. Uses an untyped body because the second
// argument may be either an integer (sections) or an array (indices).
TVM_REGISTER_GLOBAL("relay.op._make.split")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    if (args.type_codes[1] == kDLInt) {
      // Note: we change it from Int(64) to Int(32) for now as
      // combine_parallel_dense will transform the graph with Int(32).
      // More investigation is needed to check which one we should use.
      *rv = MakeSplit(args[0],
                      tir::make_const(DataType::Int(32), static_cast<int>(args[1])),
                      args[2]);
    } else {
      *rv = MakeSplit(args[0], args[1], args[2]);
    }
});

RELAY_REGISTER_OP("split")
.describe(R"code(Splits an array along a particular axis into multiple sub-arrays.

Indices or sections to split into. Accepts an int or a tuple
If indices_or_sections is an integer, the input will be divided equally
along given axis. If such a split is not possible, an error is raised.

If indices_or_sections is a tuple of sorted integers,
the entries indicate where along axis the array is split.

)code" TVM_ADD_FILELINE)
.set_attrs_type<SplitAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Split", SplitRel)
.set_attr<FTVMCompute>("FTVMCompute", SplitCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
Siva committed
2173

2174

2175
// relay.slice_like
2176 2177 2178 2179 2180 2181 2182 2183 2184 2185 2186 2187 2188 2189 2190 2191 2192 2193 2194 2195 2196 2197 2198 2199 2200 2201 2202 2203
TVM_REGISTER_NODE_TYPE(SliceLikeAttrs);

/*!
* \brief SliceLikeRel User defined type constraint function.
* \param num_inputs Number of input types in the args.
* \param attrs The additional attributes of the operator.
* \param reporter The reporter to report solution to.
* \return False if the relation has not been resolved, it might be resolved later.
*  True if this relation has been resolved.
*/
bool SliceLikeRel(const Array<Type>& types,
                  int num_inputs,
                  const Attrs& attrs,
                  const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    return false;
  }

  const auto* target = types[1].as<TensorTypeNode>();
  if (target == nullptr) {
    return false;
  }

  const auto param = attrs.as<SliceLikeAttrs>();
  CHECK(param != nullptr);

2204 2205 2206
  const Array<IndexExpr>& dshape = data->shape;
  const Array<IndexExpr>& target_shape = target->shape;
  std::vector<IndexExpr> oshape(dshape.begin(), dshape.end());
2207 2208 2209 2210 2211 2212 2213 2214 2215 2216 2217 2218 2219 2220 2221 2222 2223 2224 2225 2226 2227 2228 2229 2230 2231 2232 2233

  if (!param->axes.defined()) {
    for (size_t i = 0; i < dshape.size(); ++i) {
      if (i < target_shape.size()) {
        oshape[i] = target_shape[i];
        CHECK(reporter->Assert(oshape[i] <= dshape[i]))
          << "End index of axis " << i << " exceeds input shape: "
          << oshape[i] << " vs " << dshape[i];
      }
    }
  } else {
    CHECK(param->axes.size() != 0) << "Axes cannot be empty.";
    for (Integer val : param->axes) {
      int axis = val->value;
      if (axis < 0) {
        axis += dshape.size();
      }
      CHECK(axis < static_cast<int>(target_shape.size()))
        << "Axis " << axis << " exceeds dimension "
        << target_shape.size() << " of target_shape.";
      oshape[axis] = target_shape[axis];
      CHECK(reporter->Assert(oshape[axis] <= dshape[axis]))
        << "End index of axis " << axis << " exceeds input shape: "
        << oshape[axis] << " vs " << dshape[axis];
    }
  }

2234
  reporter->Assign(types[2], TensorType(oshape, data->dtype));
2235 2236 2237 2238 2239 2240 2241
  return true;
}


// Build a slice_like call; `axes` selects which dimensions to slice
// (all leading dimensions when empty/undefined).
Expr MakeSliceLike(Expr data,
                   Expr shape_like,
                   Array<Integer> axes) {
  static const Op& op = Op::Get("slice_like");
  auto slice_attrs = make_object<SliceLikeAttrs>();
  slice_attrs->axes = std::move(axes);
  return Call(op, {data, shape_like}, Attrs(slice_attrs), {});
}

2248
// FTVMCompute for slice_like: lower to a strided_slice with begin=0,
// stride=1 on every axis, and end taken from the second input's shape for
// the selected axes (all leading axes when `axes` is undefined).
Array<te::Tensor> SliceLikeCompute(const Attrs& attrs,
                                   const Array<te::Tensor>& inputs,
                                   const Type& out_type) {
  const auto* param = attrs.as<SliceLikeAttrs>();
  CHECK(param != nullptr);
  Array<IndexExpr> src_shape = inputs[0]->shape;
  Array<IndexExpr> target_shape = inputs[1]->shape;
  Array<IndexExpr> begin_idx, end_idx, strides;
  // Default slice: the whole tensor (begin 0, stride 1, end = src extent).
  for (size_t i = 0; i < src_shape.size(); ++i) {
    begin_idx.push_back(0);
    strides.push_back(1);
  }
  end_idx = Array<IndexExpr>(src_shape);
  if (!param->axes.defined()) {
    // Slice every leading axis that the target also has.
    for (size_t i = 0; i < src_shape.size(); ++i) {
      if (i < target_shape.size()) {
        end_idx.Set(i, target_shape[i]);
        CHECK_LE(topi::GetConstInt(end_idx[i]),
                 topi::GetConstInt(src_shape[i]))
          << "End index of axis " << i << " exceeds input shape: "
          << topi::GetConstInt(end_idx[i]) << " vs "
          << topi::GetConstInt(src_shape[i]);
      }
    }
  } else {
    // Slice only the explicitly requested axes (negatives wrap around).
    for (int axis : param->axes) {
      if (axis < 0) {
        axis = static_cast<int>(src_shape.size()) + axis;
      }
      end_idx.Set(axis, target_shape[axis]);
      CHECK_LE(topi::GetConstInt(end_idx[axis]),
               topi::GetConstInt(src_shape[axis]))
        << "End index of axis " << axis << " exceeds input shape: "
        << topi::GetConstInt(end_idx[axis]) << " vs "
        << topi::GetConstInt(src_shape[axis]);
    }
  }
  return Array<te::Tensor>{
    topi::strided_slice(inputs[0],
                        GetIntArray(begin_idx),
                        GetIntArray(end_idx),
                        GetIntArray(strides))
  };
}


2294
// FFI hook: relay.op._make.slice_like -> MakeSliceLike.
TVM_REGISTER_GLOBAL("relay.op._make.slice_like")
.set_body_typed(MakeSliceLike);


RELAY_REGISTER_OP("slice_like")
.describe(R"code(Slice the first input respect to the second input.
)code" TVM_ADD_FILELINE)
.set_attrs_type<SliceLikeAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("shape_like", "Tensor", "Shape tensor.")
.set_support_level(10)
.add_type_rel("SliceLike", SliceLikeRel)
.set_attr<FTVMCompute>("FTVMCompute", SliceLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
2309

2310
// relay.layout_transform
2311 2312
TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);

2313
Array<te::Tensor> LayoutTransformCompute(const Attrs& attrs,
2314 2315
                                         const Array<te::Tensor>& inputs,
                                         const Type& out_type) {
2316
  const auto* param = attrs.as<LayoutTransformAttrs>();
2317
  CHECK(param != nullptr);
2318
  return Array<te::Tensor>{
2319
    topi::layout_transform(inputs[0], param->src_layout, param->dst_layout)
2320 2321 2322 2323 2324 2325 2326 2327 2328 2329 2330 2331 2332 2333 2334 2335
  };
}

bool LayoutTransformRel(const Array<Type>& types,
                        int num_inputs,
                        const Attrs& attrs,
                        const TypeReporter& reporter) {
  const auto* data = types[0].as<TensorTypeNode>();
  CHECK(data != nullptr);
  const LayoutTransformAttrs* params = attrs.as<LayoutTransformAttrs>();

  Layout src_layout(params->src_layout);
  Layout dst_layout(params->dst_layout);

  CHECK(src_layout.defined() && dst_layout.defined())
    << "cannot convert from/to undefined layout";
2336

2337
  auto layout_converter = tir::BijectiveLayout(src_layout, dst_layout);
2338
  CHECK(layout_converter.defined())
2339 2340
    << "cannot convert from " << params->src_layout << " to " << params->dst_layout;

2341
  const auto& out_shape = layout_converter.ForwardShape(data->shape);
2342
  reporter->Assign(types[1], TensorType(out_shape, data->dtype));
2343 2344 2345 2346 2347 2348
  return true;
}

// Build a layout_transform call converting `data` from src_layout to
// dst_layout (layouts given as strings, e.g. "NCHW").
Expr MakeLayoutTransform(Expr data,
                         std::string src_layout,
                         std::string dst_layout) {
  static const Op& op = Op::Get("layout_transform");
  auto lt_attrs = make_object<LayoutTransformAttrs>();
  lt_attrs->src_layout = std::move(src_layout);
  lt_attrs->dst_layout = std::move(dst_layout);
  return Call(op, {data}, Attrs(lt_attrs), {});
}

2356
// FFI hook: relay.op._make.layout_transform -> MakeLayoutTransform.
TVM_REGISTER_GLOBAL("relay.op._make.layout_transform")
.set_body_typed(MakeLayoutTransform);

RELAY_REGISTER_OP("layout_transform")
.describe(R"code(Transform the input data layout.

For transforming from NCHW to N16cHWC, the `__layout_transform__` operator reshapes
the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]

)code" TVM_ADD_FILELINE)
.set_attrs_type<LayoutTransformAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.add_type_rel("layout_transform", LayoutTransformRel)
.set_support_level(5)
.set_attr<FTVMCompute>("FTVMCompute", LayoutTransformCompute);

2373 2374 2375 2376

/* relay._contrib_reverse_reshape */
// Same attrs as reshape, but with `reverse` set so special values in
// newshape are inferred from right to left.
Expr MakeReverseReshape(Expr data,
                        Array<Integer> newshape) {
  static const Op& op = Op::Get("_contrib_reverse_reshape");
  auto reshape_attrs = make_object<ReshapeAttrs>();
  reshape_attrs->newshape = std::move(newshape);
  reshape_attrs->reverse = true;
  return Call(op, {data}, Attrs(reshape_attrs), {});
}

2384
// FFI hook: relay.op._make._contrib_reverse_reshape -> MakeReverseReshape.
TVM_REGISTER_GLOBAL("relay.op._make._contrib_reverse_reshape")
.set_body_typed(MakeReverseReshape);

// Reuses the reshape type relation and compute; only the `reverse` attribute
// differs (set by MakeReverseReshape).
RELAY_REGISTER_OP("_contrib_reverse_reshape")
.describe(R"code(Reshapes the input array where the special values are inferred from
right to left.

Example::

The special values have the same semantics as reshape. The difference is that
special values are inferred from right to left. It can be explained in the
example below::

- data.shape = (10,5,4), newshape = (-1,0), reshape results in (40,5)
- data.shape = (10,5,4), newshape = (-1,0), reverse_reshape results in (40,5)

)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type<ReshapeAttrs>()
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(10)
.add_type_rel("Reshape", ReshapeRel)
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

2409 2410 2411 2412 2413 2414 2415 2416 2417 2418 2419 2420 2421 2422 2423 2424 2425 2426 2427 2428 2429 2430
// gather_nd operator
bool GatherNDRel(const Array<Type>& types,
                 int num_inputs,
                 const Attrs& attrs,
                 const TypeReporter& reporter) {
  // `types` contains: [data, indices, result]
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* indices = types[1].as<TensorTypeNode>();
  if (data == nullptr) {
    CHECK(types[0].as<IncompleteTypeNode>())
        << "GatherND: expect input data type to be TensorType but get "
        << types[0];
    return false;
  }
  if (indices == nullptr) {
    CHECK(types[1].as<IncompleteTypeNode>())
        << "GatherND: expect indices type to be TensorType but get "
        << types[1];
    return false;
  }
  const size_t ndim = data->shape.size();
2431
  const IntImmNode* mdim = indices->shape[0].as<IntImmNode>();
2432 2433 2434 2435 2436 2437 2438 2439 2440
  const size_t kdim = indices->shape.size() - 1;
  CHECK(size_t(mdim->value) <= ndim)
        << "GatherND: indices shape does satisfy.";

  Array<IndexExpr> oshape;
  for (size_t i = 1; i < kdim + 1; ++i)
      oshape.push_back(indices->shape[i]);
  for (size_t i = mdim->value; i < ndim; ++i)
      oshape.push_back(data->shape[i]);
2441
  reporter->Assign(types[2], TensorType(oshape, data->dtype));
2442 2443 2444
  return true;
}

2445
// Lowers gather_nd to its TOPI kernel; inputs are [data, indices].
Array<te::Tensor> GatherNDCompute(const Attrs& attrs,
                                  const Array<te::Tensor>& inputs,
                                  const Type& out_type) {
  return { topi::gather_nd(inputs[0], inputs[1]) };
}

/*!
 * \brief Build a call to the gather_nd operator.
 * \param data The source tensor.
 * \param indices Index tensor of shape (M, Y_0, ..., Y_{K-1}).
 * \return A Call expression invoking gather_nd (the op takes no attributes).
 */
Expr MakeGatherND(Expr data,
                  Expr indices) {
  static const Op& op = Op::Get("gather_nd");
  return Call(op, {data, indices}, {});
}

2457
// FFI hook: exposes MakeGatherND to Python as relay.op._make.gather_nd.
TVM_REGISTER_GLOBAL("relay.op._make.gather_nd")
.set_body_typed(MakeGatherND);

// Registration of gather_nd: type relation GatherNDRel plus an injective
// (element-wise gather) TOPI compute.
RELAY_REGISTER_OP("gather_nd")
.describe(R"code(Gather elements or slices from data and store to
                 a tensor whose shape is defined by indices.

Given data with shape (X_0, X_1, ..., X_{N-1}) and indices with
shape (M, Y_0, ..., Y_{K-1}), the output will have shape
(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), where M <= N. If M == N,
output shape will simply be (Y_0, ..., Y_{K-1}).
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("GatherND", GatherNDRel)
.set_attr<FTVMCompute>("FTVMCompute", GatherNDCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// relay.sequence_mask
// Registers SequenceMaskAttrs in the node system so the attrs can be
// constructed and reflected over from the FFI.
TVM_REGISTER_NODE_TYPE(SequenceMaskAttrs);

bool SequenceMaskRel(const Array<Type>& types,
                     int num_inputs,
                     const Attrs& attrs,
                     const TypeReporter& reporter) {
  // `types` contains: [data, valid_length, result]
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* valid_length = types[1].as<TensorTypeNode>();
  CHECK(data);
  CHECK(valid_length);
  const auto param = attrs.as<SequenceMaskAttrs>();
  Array<IndexExpr> valid_length_shape;
  CHECK(param->axis == 0 || param->axis == 1);
  valid_length_shape.push_back(data->shape[1 - param->axis]);
2493
  reporter->Assign(types[1], TensorType(valid_length_shape, valid_length->dtype));
2494 2495 2496 2497
  reporter->Assign(types[2], types[0]);
  return true;
}

2498
// Lowers sequence_mask to its TOPI kernel; inputs are [data, valid_length].
Array<te::Tensor> SequenceMaskCompute(const Attrs& attrs,
                                      const Array<te::Tensor>& inputs,
                                      const Type& out_type) {
  const auto* param = attrs.as<SequenceMaskAttrs>();
  CHECK(param != nullptr);
  return Array<te::Tensor>{
    topi::sequence_mask(inputs[0], inputs[1], param->mask_value, param->axis) };
}

/*!
 * \brief Build a call to the sequence_mask operator.
 * \param data Input of shape [MAX_LENGTH, batch_size, ...] (axis=0) or
 *        [batch_size, MAX_LENGTH, ...] (axis=1).
 * \param valid_length 1-D tensor with the valid length of each sequence.
 * \param mask_value Scalar written to all positions past the valid length.
 * \param axis Axis of the length dimension (0 or 1).
 * \return A Call expression invoking sequence_mask.
 */
Expr MakeSequenceMask(Expr data,
                      Expr valid_length,
                      double mask_value,
                      int axis) {
  auto attrs = make_object<SequenceMaskAttrs>();
  // Plain assignment: std::move on scalar types is a no-op.
  attrs->mask_value = mask_value;
  attrs->axis = axis;
  static const Op& op = Op::Get("sequence_mask");
  return Call(op, {data, valid_length}, Attrs(attrs), {});
}

2518
// FFI hook: exposes MakeSequenceMask to Python as relay.op._make.sequence_mask.
TVM_REGISTER_GLOBAL("relay.op._make.sequence_mask")
.set_body_typed(MakeSequenceMask);

// Registration of sequence_mask: type relation SequenceMaskRel plus an
// injective TOPI compute.
RELAY_REGISTER_OP("sequence_mask")
.describe(R"code(Sets all elements outside the expected length of the sequence to a constant value.

This function takes an n-dimensional input array of the form [MAX_LENGTH, batch_size, ...] or
[batch_size, MAX_LENGTH, ...] and returns an array of the same shape.

`axis` means the axis of the length dimension and can only be 0 or 1. If axis is 0,
the data must have shape [MAX_LENGTH, batch_size, ...]. Otherwise (axis=1), the data must have
shape [batch_size, MAX_LENGTH, ...].

`valid_length` gives the length of each sequence. `valid_length` should be
a 1D int array with positive ints and has dimension [batch_size,].

Examples::

  x = [[[  1.,   2.,   3.],
        [  4.,   5.,   6.]],

       [[  7.,   8.,   9.],
        [ 10.,  11.,  12.]],

       [[ 13.,  14.,   15.],
        [ 16.,  17.,   18.]]]

  // valid_length [1, 1] means only the first block of each batch will be kept
  // and other blocks are masked with default mask value = 0
  sequence_mask(x, valid_length=[1, 1]) =
       [[[  1.,   2.,   3.],
         [  4.,   5.,   6.]],

        [[  0.,   0.,   0.],
         [  0.,   0.,   0.]],

        [[  0.,   0.,   0.],
         [  0.,   0.,   0.]]]

  // valid_length [2, 3] means the first 2 blocks of the 1st batch will be kept
  // and the first 3 blocks of the 2nd batch will be kept
  // the masked values are set to be the specified mask value = 0.1
  sequence_mask(x, valid_length=[2, 3], mask_value=0.1) =
       [[[  1.,   2.,   3.],
         [  4.,   5.,   6.]],

        [[  7.,   8.,   9.],
         [  10.,  11.,  12.]],

        [[  0.1,  0.1,  0.1],
         [  16.,  17.,  18.]]]
)code" TVM_ADD_FILELINE)
.set_attrs_type<SequenceMaskAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("valid_length", "Tensor", "The real (valid) length of each sequence.")
.set_support_level(10)
.add_type_rel("SequenceMask", SequenceMaskRel)
.set_attr<FTVMCompute>("FTVMCompute", SequenceMaskCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// relay.one_hot
// Registers OneHotAttrs in the node system so the attrs can be
// constructed and reflected over from the FFI.
TVM_REGISTER_NODE_TYPE(OneHotAttrs);

bool OneHotRel(const Array<Type>& types,
               int num_inputs,
               const Attrs& attrs,
               const TypeReporter& reporter) {
  // `types` contains: [indices, on_value, off_value, result]
  CHECK_EQ(types.size(), 4);
  const auto* indices = types[0].as<TensorTypeNode>();
  CHECK(indices);

  const auto param = attrs.as<OneHotAttrs>();
  CHECK_GT(param->depth, 0);

  Array<IndexExpr> oshape;
  int ndim = indices->shape.size() + 1;
  int indices_index = 0;
  int true_axis = (param->axis == -1) ? indices->shape.size() : param->axis;
  for (int i = 0; i < ndim; i++) {
    if (i == true_axis) {
      oshape.push_back(Integer(param->depth));
    } else {
      oshape.push_back(indices->shape[indices_index++]);
    }
  }

2606
  reporter->Assign(types[3], TensorType(oshape, param->dtype));
2607 2608 2609
  return true;
}

2610
// Lowers one_hot to its TOPI kernel; inputs are [indices, on_value, off_value].
Array<te::Tensor> OneHotCompute(const Attrs& attrs,
                                const Array<te::Tensor>& inputs,
                                const Type& out_type) {
  const auto* param = attrs.as<OneHotAttrs>();
  CHECK(param != nullptr);
  // on_value/off_value are 0-D tensors; invoking them with no indices
  // yields the scalar PrimExpr topi::one_hot expects.
  return Array<te::Tensor> {
    topi::one_hot(inputs[0],
                  inputs[1](),
                  inputs[2](),
                  param->depth,
                  param->axis,
                  param->dtype)
  };
}

/*!
 * \brief Build a call to the one_hot operator.
 * \param indices Locations set to on_value.
 * \param on_value Scalar filled at indices.
 * \param off_value Scalar filled everywhere else.
 * \param depth Extent of the new one-hot axis (must be > 0).
 * \param axis Axis to insert the one-hot dimension at (-1 appends it).
 * \param dtype Output data type.
 * \return A Call expression invoking one_hot.
 */
Expr MakeOneHot(Expr indices,
                Expr on_value,
                Expr off_value,
                int depth,
                int axis,
                DataType dtype) {
  auto attrs = make_object<OneHotAttrs>();
  // Plain assignment: std::move on an int is a no-op.
  attrs->depth = depth;
  attrs->axis = axis;
  attrs->dtype = dtype;
  static const Op& op = Op::Get("one_hot");
  return Call(op, {indices, on_value, off_value}, Attrs(attrs), {});
}

2639
// FFI hook: exposes MakeOneHot to Python as relay.op._make.one_hot.
TVM_REGISTER_GLOBAL("relay.op._make.one_hot")
.set_body_typed(MakeOneHot);

// Registration of one_hot: type relation OneHotRel plus an out-elemwise-fusable
// TOPI compute.
RELAY_REGISTER_OP("one_hot")
.describe(R"code(Returns a one-hot tensor where the locations represented by indices take value 1,
    other locations take value 0. Final dimension is <indices dimensions> x depth.

    **indices** Locations to set to 1.

    **on_value** Value to fill at indices.

    **off_value** Value to fill at all other positions besides indices.

    **depth** Depth of the one-hot dimension.

    **axis** Axis to fill.

    **dtype**)code" TVM_ADD_FILELINE)
.set_attrs_type<OneHotAttrs>()
.set_num_inputs(3)
.add_argument("indices", "Tensor", "Locations to set to on_value.")
.add_argument("on_value", "Expr", "Value to fill at indices.")
.add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.")
.set_support_level(10)
.add_type_rel("OneHot", OneHotRel)
.set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
.set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);

/* relay.unravel_index */
/*!
 * \brief Type relation for unravel_index.
 *
 * For scalar indices the output is (shape_size,), for 1-D indices of length n
 * it is (shape_size, n) — one coordinate row per axis of the target shape.
 */
bool UnRavelIndexRel(const Array<Type>& types,
                     int num_inputs,
                     const Attrs& attrs,
                     const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);

  const auto* indices = types[0].as<TensorTypeNode>();
  if (indices == nullptr) {
    CHECK(types[0].as<IncompleteTypeNode>())
        << "unravel_index: expect input type to be TensorType but get "
        << types[0];
    return false;
  }
  CHECK(indices->dtype.is_int())
      << "indices of unravel_index must be tensor of integer";

  const auto* shape = types[1].as<TensorTypeNode>();
  if (shape == nullptr) {
    CHECK(types[1].as<IncompleteTypeNode>())
        << "unravel_index: expect input type to be TensorType but get "
        << types[1];
    return false;
  }
  // BUGFIX: previously re-checked indices->dtype here; the second operand
  // (shape) is the one this check is meant to validate.
  CHECK(shape->dtype.is_int())
      << "shape of unravel_index must be tensor of integer";

  Array<IndexExpr> indices_shape;
  Array<IndexExpr> shape_shape;
  indices_shape = indices->shape;
  shape_shape = shape->shape;

  // Output is (shape_size,) for scalar indices, (shape_size, n) otherwise.
  Array<IndexExpr> oshape;
  oshape.push_back(shape_shape[0]);
  if (indices_shape.size() != 0) {
    oshape.push_back(indices_shape[0]);
  }
  reporter->Assign(types[2], TensorType(oshape, indices->dtype));
  return true;
}

// Lowers unravel_index to its TOPI kernel; inputs are [indices, shape].
Array<te::Tensor> UnRavelIndexCompute(const Attrs& attrs,
                                      const Array<te::Tensor>& inputs,
                                      const Type& out_type) {
  const te::Tensor& flat_indices = inputs[0];
  const te::Tensor& target_shape = inputs[1];
  return Array<te::Tensor>{topi::unravel_index(flat_indices, target_shape)};
}

/*!
 * \brief Build a call to the unravel_index operator.
 * \param data Flat index (or 1-D array of flat indices), integer dtype.
 * \param shape 1-D integer tensor describing the target shape.
 * \return A Call expression invoking unravel_index (the op takes no attributes).
 */
Expr MakeUnRavelIndex(Expr data,
                      Expr shape) {
  static const Op& op = Op::Get("unravel_index");
  return Call(op, {data, shape}, Attrs(), {});
}

// FFI hook: exposes MakeUnRavelIndex to Python as relay.op._make.unravel_index.
TVM_REGISTER_GLOBAL("relay.op._make.unravel_index")
.set_body_typed(MakeUnRavelIndex);

// Registration of unravel_index: type relation UnRavelIndexRel plus an
// injective TOPI compute.
RELAY_REGISTER_OP("unravel_index")
.describe(R"code(Converts a flat index or array of flat indices into a tuple of coordinate arrays.

Example::
  -  unravel_index([22, 41, 37], (7, 6)) = [[3, 6, 6], [4, 5, 1]]
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.set_support_level(3)
.add_type_rel("UnRavelIndexRel", UnRavelIndexRel)
.set_attr<FTVMCompute>("FTVMCompute", UnRavelIndexCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

}  // namespace relay
}  // namespace tvm