/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/transforms/pattern_util.h
 * \brief Header of internal operator functions.
 *  These can be used for writing passes.
 */
#ifndef TVM_RELAY_TRANSFORMS_PATTERN_UTIL_H_
#define TVM_RELAY_TRANSFORMS_PATTERN_UTIL_H_

#include <builtin_fp16.h>
#include <tvm/node/structural_equal.h>
#include <tvm/tir/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/attrs/reduce.h>
#include <tvm/relay/op_attr_types.h>

#include <string>
#include <vector>
#include <utility>

namespace tvm {
namespace relay {

/*!
 * \brief Dispatch DataType to the C++ data type
 *  during runtime.
 */
#define TVM_DTYPE_DISPATCH(type, DType, ...)            \
  if (type == DataType::Float(64)) {                    \
    typedef double DType;                               \
    {__VA_ARGS__}                                       \
  } else if (type == DataType::Float(32)) {             \
    typedef float DType;                                \
    {__VA_ARGS__}                                       \
  } else if (type == DataType::Float(16)) {             \
    typedef uint16_t DType;                             \
    {__VA_ARGS__}                                       \
  } else if (type == DataType::Int(64)) {               \
    typedef int64_t DType;                              \
    {__VA_ARGS__}                                       \
  } else if (type == DataType::Int(32)) {               \
    typedef int32_t DType;                              \
    {__VA_ARGS__}                                       \
  } else if (type == DataType::Int(16)) {               \
    typedef int16_t DType;                              \
    {__VA_ARGS__}                                       \
  } else if (type == DataType::Int(8)) {                \
    typedef int8_t DType;                               \
    {__VA_ARGS__}                                       \
  } else if (type == DataType::UInt(64)) {              \
    typedef uint64_t DType;                             \
    {__VA_ARGS__}                                       \
  } else if (type == DataType::UInt(32)) {              \
    typedef uint32_t DType;                             \
    {__VA_ARGS__}                                       \
  } else if (type == DataType::UInt(16)) {              \
    typedef uint16_t DType;                             \
    {__VA_ARGS__}                                       \
  } else if (type == DataType::UInt(8)) {               \
    typedef uint8_t DType;                              \
    {__VA_ARGS__}                                       \
  } else {                                              \
    LOG(FATAL) << "unknown data type " << type;         \
  }
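
// Example usage (illustrative sketch; `dtype` is a DataType and `arr` a
// CPU-resident runtime::NDArray, both hypothetical locals):
//
//   TVM_DTYPE_DISPATCH(dtype, DType, {
//     // Inside the block, DType is the matching C++ type, e.g. float for
//     // DataType::Float(32), so raw buffers can be indexed directly.
//     static_cast<DType*>(arr->data)[0] = DType(0);
//   });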

/*!
 * \brief Try to match lhs and rhs via broadcasting rule, such that:
 *
 * rhs matches the dimension of lhs specified by lhs_axes
 * rhs's value equals 1 on rest of dimensions.
 *
 * \param tlhs The type of left operand (data)
 * \param trhs The type of right operand (bias)
 * \param lhs_axes The axes on lhs to match.
 * \param rhs_value A squeezed version of rhs which only contains matched dimension.
 * \return Whether match is successful.
 */
inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs,
                                     const TensorTypeNode* trhs,
                                     const Array<Integer>& lhs_axes,
                                     Expr* rhs_value = nullptr) {
  if (tlhs->shape.size() < trhs->shape.size()) return false;
  StructuralEqual equal;
  size_t base = tlhs->shape.size() - trhs->shape.size();
  size_t j = 0;

  ObjectPtr<SqueezeAttrs> squeeze_attrs;
  if (rhs_value != nullptr) {
    squeeze_attrs = make_object<SqueezeAttrs>();
  }

  for (size_t i = 0; i < tlhs->shape.size(); ++i) {
    if (j < lhs_axes.size() && i == static_cast<size_t>(lhs_axes[j]->value)) {
      if (i < base || !equal(tlhs->shape[i], trhs->shape[i - base])) {
        return false;
      }
      ++j;
    } else if (i >= base) {
      if (!tir::is_const_int(trhs->shape[i - base], 1)) {
        return false;
      }
      if (rhs_value != nullptr) {
        squeeze_attrs->axis.push_back(static_cast<int>(i - base));
      }
    }
  }
  if (rhs_value != nullptr && squeeze_attrs->axis.size() != 0) {
    static const Op& squeeze_op = Op::Get("squeeze");
    *rhs_value = Call(squeeze_op, {*rhs_value}, Attrs(squeeze_attrs), {});
  }
  return true;
}
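
// Example (illustrative sketch; `tlhs`, `trhs`, and `bias` are hypothetical):
// for data of shape (N, C, H, W) and a bias of shape (C, 1, 1), matching
// axis 1 succeeds and rewrites the bias into a squeezed 1-D tensor (C,):
//
//   Expr squeezed = bias;
//   if (MatchBroadcastToLeftAxes(tlhs, trhs, {1}, &squeezed)) {
//     // squeezed now refers to squeeze(bias, axis=[1, 2])
//   }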

/*!
 * \brief Expand 1D Tensor to match axis.
 *
 * The result bias can be used to add or multiply to
 * the target Tensor on the specified axis via broadcasting rule.
 *
 * \param bias The bias.
 * \param target_ndim Target dimension.
 * \param axes The axes on the output we want to match on.
 * \return The expanded bias.
 */
inline Expr ExpandBiasToMatchAxis(Expr bias,
                                  int target_ndim,
                                  const Array<Integer>& axes) {
  static const Op& expand_dims = Op::Get("expand_dims");
  for (size_t i = axes.size(); i != 0; --i) {
    if (i == axes.size()) {
      // Pad the trailing dimensions past the last matched axis.
      int64_t num_pad_axis = target_ndim - axes[i - 1]->value - 1;
      if (num_pad_axis > 0) {
        auto attrs = make_object<ExpandDimsAttrs>();
        attrs->axis = i;
        attrs->num_newaxis = static_cast<int>(num_pad_axis);
        bias = Call(expand_dims, {bias}, Attrs(attrs), {});
      }
    } else {
      // Fill the gap between consecutive matched axes.
      int64_t diff = axes[i]->value - axes[i - 1]->value;
      CHECK_GE(diff, 0L);
      if (diff > 0) {
        auto attrs = make_object<ExpandDimsAttrs>();
        attrs->axis = i;
        attrs->num_newaxis = static_cast<int>(diff);
        bias = Call(expand_dims, {bias}, Attrs(attrs), {});
      }
    }
  }
  return bias;
}
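
// Example (illustrative sketch; `bias` is a hypothetical tensor of shape
// (C,)): matching axis 1 of a 4-D (N, C, H, W) target yields shape (C, 1, 1),
// which broadcasts over H and W when added to the target:
//
//   Expr expanded = ExpandBiasToMatchAxis(bias, /*target_ndim=*/4, {1});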

/*!
 * \brief Check if the call is depthwise conv2d.
 *
 * \param call The conv2d call.
 * \param param The conv2d attributes.
 * \param kernel_layout The layout of the conv2d kernel.
 * \return Whether it is depthwise_conv2d.
 */
inline bool IsDepthwiseConv2D(const Call& call,
                              const Conv2DAttrs* param,
                              const Layout& kernel_layout) {
  static const Layout kOIHW("OIHW");
  const auto bilayout = tir::BijectiveLayout(kernel_layout, kOIHW);
  auto wshape = bilayout.ForwardShape(call->args[1]->type_as<TensorTypeNode>()->shape);
  return tir::is_const_int(wshape[0], param->groups) &&
      tir::is_const_int(wshape[1], 1);
}
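
// For intuition (illustrative): in the OIHW view, a depthwise weight has
// shape (C, 1, kH, kW) with groups == C, so the check above amounts to
// wshape[0] == groups && wshape[1] == 1.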

/*!
 * \brief Get super-dimension of output channels of conv2d
 * \param call The conv2d call.
 * \return Super-dimension size of output channels of conv2d.
 */
inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
  auto param = call->attrs.as<Conv2DAttrs>();
  auto tweight = call->args[1]->type_as<TensorTypeNode>();
  auto index = param->kernel_layout.find('O');
  CHECK_NE(index, std::string::npos);
  auto channels = tir::as_const_int(tweight->shape[index]);
  return *channels;
}

/*!
 * \brief Is single value tensor (scalar).
 * \param expr The expr.
 * \return True if single value tensor.
 */
inline bool IsScalar(const Expr& expr) {
  if (auto tensor_type = expr->checked_type().as<TensorTypeNode>()) {
    for (auto dim_index_expr : tensor_type->shape) {
      if (auto dim_index = dim_index_expr.as<IntImmNode>()) {
        if (dim_index->value != 1) {
          return false;
        }
      } else {
        return false;
      }
    }
  } else {
    return false;
  }
  return true;
}

/*!
 * \brief Check if expr is a const scalar.
 * \param expr The expr.
 * \return True if const scalar.
 */
inline bool IsConstScalar(const Expr& expr) {
  const auto* const_expr = expr.as<ConstantNode>();
  if (const_expr) {
    return const_expr->is_scalar();
  }
  return false;
}

/*!
 * \brief Create a Constant with a scalar
 *
 * \param dtype The data type.
 * \param value The value of the scalar.
 * \return A Constant.
 */
template <typename T>
inline Constant MakeConstantScalar(DataType dtype, T value) {
  runtime::NDArray arr = runtime::NDArray::Empty({}, dtype, {kDLCPU, 0});
  TVM_DTYPE_DISPATCH(dtype, DType, {
    if (dtype == DataType::Float(16)) {
      // convert to float16
      // storage is uint16_t
      *static_cast<DType*>(arr->data) =
          __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value));
    } else {
      *static_cast<DType*>(arr->data) = value;
    }
  })
  return Constant(arr);
}
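
// Example (illustrative sketch): materializing 0.5f as a Relay constant,
// e.g. to serve as a multiplier when rewriting a division by 2 in a pass:
//
//   Constant half = MakeConstantScalar(DataType::Float(32), 0.5f);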

/*!
 * \brief Create a Constant with a tensor.
 *
 * \param dtype The data type.
 * \param shape The shape of the tensor.
 * \param value The vector of the tensor values.
 * \return A Constant.
 */
template <typename T>
static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> shape,
                                          std::vector<T> value) {
  runtime::NDArray arr = runtime::NDArray::Empty(shape, dtype, {kDLCPU, 0});
  TVM_DTYPE_DISPATCH(dtype, DType, {
    for (size_t i = 0; i < value.size(); i++) {
      if (dtype == DataType::Float(16)) {
        // convert to float16
        // storage is uint16_t
        // Similar handling as that in MakeConstantScalar
        *(static_cast<DType*>(arr->data) + i) =
            __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
                static_cast<float>(value[i]));
      } else {
        *(static_cast<DType*>(arr->data) + i) = value[i];
      }
    }
  })
  return Constant(arr);
}

/*!
 * \brief Check if two expressions are equal scalars.
 * \param a The expression to be checked.
 * \param b The expression to be checked.
 * \return Whether two expressions are equal scalars.
 */
inline bool IsEqualScalar(const Expr& a, const Expr& b) {
  const auto* constant_a = a.as<ConstantNode>();
  const auto* constant_b = b.as<ConstantNode>();
  if (!constant_a || !constant_b || !constant_a->is_scalar() || !constant_b->is_scalar()) {
    return false;
  }
  return tvm::StructuralEqual()(a, b);
}

inline Expr GetField(Expr t, size_t i) {
  return TupleGetItem(t, i);
}

inline Expr Pair(Expr l, Expr r) {
  return Tuple({l, r});
}

inline Expr Exp(Expr e) {
  static const Op& op = Op::Get("exp");
  return Call(op, {e});
}

inline Expr FastExp(Expr e) {
  static const Op& op = Op::Get("fast_exp");
  return Call(op, {e});
}

inline Expr FastErf(Expr e) {
  static const Op& op = Op::Get("fast_erf");
  return Call(op, {e});
}

inline Expr FastTanh(Expr e) {
  static const Op& op = Op::Get("fast_tanh");
  return Call(op, {e});
}

inline Expr Log(Expr e) {
  static const Op& op = Op::Get("log");
  return Call(op, {e});
}

/*!
 * \brief Get an immediate scalar from a Constant expr.
 *
 * \param expr The Constant expr.
 * \return A scalar with type T.
 */
template <typename T>
T GetScalarFromConstant(Expr expr) {
  const auto* n = expr.as<ConstantNode>();
  CHECK(n) << "Expr must be a constant expr - " << AsText(expr, false);
  CHECK(n->is_scalar());
  return static_cast<T*>(n->data->data)[0];
}
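
// Example (illustrative sketch; `scale` is a hypothetical float32 scalar
// Constant):
//
//   float v = GetScalarFromConstant<float>(scale);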

inline Expr Cast(Expr x, DataType dtype) {
  static const Op& op = Op::Get("cast");
  auto attrs = make_object<CastAttrs>();
  attrs->dtype = dtype;
  return Call(op, {x}, Attrs(attrs), {});
}

inline Expr Negative(Expr x) {
  static const Op& op = Op::Get("negative");
  return Call(op, {x}, Attrs(), {});
}


inline Expr Sqrt(Expr x) {
  static const Op& op = Op::Get("sqrt");
  return Call(op, {x}, Attrs(), {});
}


inline Expr Relu(Expr x) {
  static const Op& op = Op::Get("nn.relu");
  return Call(op, {x}, Attrs(), {});
}


inline Expr Round(Expr x) {
  static const Op& op = Op::Get("round");
  return Call(op, {x}, Attrs(), {});
}


inline Expr Clip(Expr x, double a_min, double a_max) {
  static const Op& op = Op::Get("clip");
  auto attrs = make_object<ClipAttrs>();
  attrs->a_min = a_min;
  attrs->a_max = a_max;
  return Call(op, {x}, Attrs(attrs), {});
}
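
// Example (illustrative sketch): these wrappers compose into larger rewrites;
// e.g. a saturating cast to uint8 can be spelled as:
//
//   Expr q = Cast(Clip(Round(x), 0.0, 255.0), DataType::UInt(8));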


inline Expr Add(Expr lhs, Expr rhs) {
  static const Op& op = Op::Get("add");
  return Call(op, {lhs, rhs}, Attrs(), {});
}


inline Expr Subtract(Expr lhs, Expr rhs) {
  static const Op& op = Op::Get("subtract");
  return Call(op, {lhs, rhs}, Attrs(), {});
}


inline Expr Multiply(Expr lhs, Expr rhs) {
  static const Op& op = Op::Get("multiply");
  return Call(op, {lhs, rhs}, Attrs(), {});
}


inline Expr Divide(Expr lhs, Expr rhs) {
  static const Op& op = Op::Get("divide");
  return Call(op, {lhs, rhs}, Attrs(), {});
}

inline Expr Maximum(Expr lhs, Expr rhs) {
  static const Op& op = Op::Get("maximum");
  return Call(op, {lhs, rhs}, Attrs(), {});
}

inline Expr ZerosLike(Expr e) {
  static const Op& op = Op::Get("zeros_like");
  return Call(op, {e});
}

inline Expr Zeros(Array<IndexExpr> shape, DataType dtype) {
  auto attrs = make_object<InitOpAttrs>();
  attrs->shape = std::move(shape);
  attrs->dtype = std::move(dtype);
  static const Op& op = Op::Get("zeros");
  return Call(op, {}, Attrs(attrs), {});
}

inline Expr OnesLike(Expr e) {
  static const Op& op = Op::Get("ones_like");
  return Call(op, {e});
}

inline Expr CollapseSumLike(Expr e) {
  static const Op& op = Op::Get("collapse_sum_like");
  return Call(op, {e});
}

inline Expr Power(Expr lhs, Expr rhs) {
  static const Op& op = Op::Get("power");
  return Call(op, {lhs, rhs}, Attrs(), {});
}


inline Expr RightShift(Expr x, Expr nbit) {
  static const Op& op = Op::Get("right_shift");
  return Call(op, {x, nbit}, Attrs(), {});
}


inline Expr LeftShift(Expr x, Expr nbit) {
  static const Op& op = Op::Get("left_shift");
  return Call(op, {x, nbit}, Attrs(), {});
}


inline Expr ReshapeLike(Expr lhs, Expr rhs) {
  static const Op& op = Op::Get("reshape_like");
  return Call(op, {lhs, rhs}, Attrs(), {});
}


inline Expr Copy(Expr data) {
  static const Op& op = Op::Get("copy");
  return Call(op, {data}, Attrs(), {});
}


inline Expr Mean(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
  auto attrs = make_object<ReduceAttrs>();
  attrs->axis = std::move(axis);
  attrs->keepdims = keepdims;
  attrs->exclude = exclude;
  static const Op& op = Op::Get("mean");
  return Call(op, {data}, Attrs(attrs), {});
}

inline Expr Variance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude) {
  auto attrs = make_object<ReduceAttrs>();
  attrs->axis = std::move(axis);
  attrs->keepdims = keepdims;
  attrs->exclude = exclude;
  static const Op& op = Op::Get("variance");
  return Call(op, {data, mean}, Attrs(attrs), {});
}
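
// Example (illustrative sketch; `x` is a hypothetical NCHW tensor): a
// per-channel standardization, of the kind produced when unpacking
// batch_norm-like patterns, can be assembled from the reducers above:
//
//   Expr mu = Mean(x, {0, 2, 3}, /*keepdims=*/true, /*exclude=*/false);
//   Expr sigma2 = Variance(x, mu, {0, 2, 3}, /*keepdims=*/true, /*exclude=*/false);
//   Expr norm = Divide(Subtract(x, mu), Sqrt(sigma2));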


static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) {
  static const Op& op = Op::Get("where");
  return Call(op, {condition, x, y});
}

static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) {
  static const Op& op = Op::Get("greater_equal");
  return Call(op, {lhs, rhs}, Attrs(), {});
}

static inline Expr Full(Expr fill_value,
                        Array<IndexExpr> shape,
                        DataType dtype) {
  auto attrs = make_object<InitOpAttrs>();
  attrs->shape = std::move(shape);
  attrs->dtype = std::move(dtype);
  static const Op& op = Op::Get("full");
  return Call(op, {fill_value}, Attrs(attrs), {});
}

static inline Expr Conv2D(Expr data, Expr weight, Array<IndexExpr> strides,
                          Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups,
                          IndexExpr channels, Array<IndexExpr> kernel_size, std::string data_layout,
                          std::string kernel_layout, std::string out_layout, DataType out_dtype) {
  auto attrs = make_object<Conv2DAttrs>();
  attrs->strides = std::move(strides);
  attrs->padding = std::move(padding);
  attrs->dilation = std::move(dilation);
  attrs->groups = groups;
  attrs->channels = std::move(channels);
  attrs->kernel_size = std::move(kernel_size);
  attrs->data_layout = std::move(data_layout);
  attrs->kernel_layout = std::move(kernel_layout);
  attrs->out_layout = std::move(out_layout);
  attrs->out_dtype = std::move(out_dtype);
  static const Op& op = Op::Get("nn.conv2d");
  return Call(op, {data, weight}, Attrs(attrs), {});
}
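
// Example (illustrative sketch; `data` and `weight` are hypothetical): a
// 3x3, stride-1, pad-1 NCHW convolution producing 64 output channels:
//
//   Expr y = Conv2D(data, weight, /*strides=*/{1, 1}, /*padding=*/{1, 1},
//                   /*dilation=*/{1, 1}, /*groups=*/1, /*channels=*/64,
//                   /*kernel_size=*/{3, 3}, "NCHW", "OIHW", /*out_layout=*/"",
//                   DataType::Float(32));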

static inline Expr Dense(Expr data,
                         Expr weight,
                         IndexExpr units,
                         DataType out_dtype) {
  auto attrs = make_object<DenseAttrs>();
  attrs->units = units;
  attrs->out_dtype = out_dtype;
  static const Op& op = Op::Get("nn.dense");
  return Call(op, {data, weight}, Attrs(attrs), {});
}

static inline Expr Sum(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
  auto attrs = make_object<ReduceAttrs>();
  attrs->axis = std::move(axis);
  attrs->keepdims = keepdims;
  attrs->exclude = exclude;
  static const Op& op = Op::Get("sum");
  return Call(op, {data}, Attrs(attrs), {});
}

static inline Expr Reshape(Expr data, Array<Integer> newshape) {
  auto attrs = make_object<ReshapeAttrs>();
  attrs->newshape = std::move(newshape);
  attrs->reverse = false;
  static const Op& op = Op::Get("reshape");
  return Call(op, {data}, Attrs(attrs), {});
}

static inline Expr AvgPool2D(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
                             Array<IndexExpr> padding, std::string layout, bool ceil_mode,
                             bool count_include_pad) {
  auto attrs = make_object<AvgPool2DAttrs>();
  attrs->pool_size = std::move(pool_size);
  attrs->strides = std::move(strides);
  attrs->padding = std::move(padding);
  attrs->layout = std::move(layout);
  attrs->ceil_mode = ceil_mode;
  attrs->count_include_pad = count_include_pad;
  static const Op& op = Op::Get("nn.avg_pool2d");
  return Call(op, {data}, Attrs(attrs), {});
}

static inline Expr Pad(Expr data, Array<Array<IndexExpr>> pad_width, double pad_value,
                       std::string pad_mode) {
  auto attrs = make_object<PadAttrs>();
  attrs->pad_value = pad_value;
  attrs->pad_width = std::move(pad_width);
  attrs->pad_mode = std::move(pad_mode);
  static const Op& op = Op::Get("nn.pad");
  return Call(op, {data}, Attrs(attrs), {});
}

static inline Expr Tile(Expr data, Array<Integer> reps) {
  auto attrs = make_object<TileAttrs>();
  attrs->reps = reps;
  static const Op& op = Op::Get("tile");
  return Call(op, {data}, Attrs(attrs), {});
}

Expr MakeBroadCastTo(Expr data, Array<IndexExpr> shape);

Expr MakeConcatenate(Expr data, int axis);

Expr MakeRepeat(Expr data, int repeats, int axis);

Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);

Expr MakeStack(Expr data, int axis);

Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis);

Expr MakeSqueeze(Expr data, Array<Integer> axis);

Expr MakeExpandDims(Expr data, int axis, int num_newaxis);

Expr MakeLayoutTransform(Expr data, std::string src_layout, std::string dst_layout);

Expr StopFusion(Expr data);

Expr CastHint(Expr data, DataType dtype);

}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_TRANSFORMS_PATTERN_UTIL_H_