pattern_match.h 22 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
/*!
 *  Copyright (c) 2019 by Contributors
 * \file tvm/arithmetic/pattern_match.h
 *
 * \brief Internal tool for expression-template based pattern matching.
 *
 * It helps to simplify pattern matching and rewrites.
 * All the patterns are generated via expression template during compile time,
 * so the result code should be as efficient as manually written pattern match code.
 *
 * The code below shows how to use the pattern matcher.
 *
 * \code
 *
 *  // max(x + z, y + z) => max(x, y) + z
 *  arith::PVar<Expr> x, y, z;
 *
 *  // The following code tries to match the declared pattern.
 *  // Match will fill the result of match into PVar if successful.
 *  // Note that z occurs twice in the pattern;
 *  // an equality check is performed to ensure that all
 *  // occurrences of z are equivalent to each other.
 *  if (max(x + z, y + z).Match(expr)) {
 *    // Eval evaluates a pattern with the current matched value.
 *    // The filled value is valid until the next call to Match.
 *    return (max(x, y) + z).Eval();
 *  }
 *
 *  tvm::Var tx, ty;
 *  arith::PVar<Integer> c;
 *  arith::PVar<Var> v;
 *  // We can match integer and Var, both of which are
 *  // special case container of Expr
 *  CHECK((v * c).Match(tx * 3));
 *  CHECK_EQ(c.Eval()->value, 3);
 *  // cannot match c to ty
 *  CHECK(!(v * c).Match(tx * ty));
 *
 * \endcode
 *
 * \note The pattern matcher is not threadsafe,
 *       do not use the same PVar in multiple threads.
 *
 *       Please be aware that the filled value in a PVar
 *       can be overridden in the next call to Match.
 */
#ifndef TVM_ARITHMETIC_PATTERN_MATCH_H_
#define TVM_ARITHMETIC_PATTERN_MATCH_H_

#include <tvm/ir_pass.h>
#include <tuple>
71
#include "const_fold.h"
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142

namespace tvm {
namespace arith {
/*!
 * \brief Base class of all the patterns.
 *
 * There are two major member functions supported by each pattern.
 * - Match: checks if value matches the pattern.
 * - Eval: construct a new value based on matched values in PVar.
 *
 * We use curiously recurring template pattern to construct
 * expression templates.
 *
 * \tparam Derived The type of the derived class.
 */
template<typename Derived>
class Pattern {
 public:
  /*!
   * \brief How an instance of Derived is stored inside an enclosing pattern.
   *
   *  Derived classes may shadow this alias. Intermediate expression
   *  patterns keep this default and are stored by value (a copy).
   *  PVar redefines it as const PVar<T>& so that it is stored by reference.
   *
   *  The idea of a per-type Nested alias is borrowed from Eigen.
   */
  using Nested = Derived;
  /*!
   * \brief Try to match value against this pattern.
   *
   *  First every hole reachable from this pattern is reset (InitMatch_).
   *  Then the structural match (Match_) runs and fills the holes.
   *  Filled values stay valid until the next call to Match.
   *
   * \param value The value to match.
   * \return true if value matches the pattern.
   */
  template<typename NodeType>
  bool Match(const NodeType& value) const {
    const Derived& self = derived();
    // Reset state from any previous match before matching again.
    self.InitMatch_();
    return self.Match_(value);
  }
  /*! \return This object viewed as the CRTP-derived pattern type. */
  const Derived& derived() const {
    const Derived* self = static_cast<const Derived*>(this);
    return *self;
  }
};

/*!
 * \brief Default deep equality checker
 * \tparam T the comparison point.
 */
template<typename T>
class PEqualChecker {
 public:
  /*!
   * \brief Compare two matched values of type T.
   * \return the result of lhs == rhs.
   */
  bool operator()(const T& lhs, const T& rhs) const {
    const bool same = (lhs == rhs);
    return same;
  }
};

template<>
class PEqualChecker<Expr> {
 public:
  /*!
   * \brief Two expressions are equal when they are the same object,
   *  or otherwise when ir::Equal reports them as equal.
   */
  bool operator()(const Expr& lhs, const Expr& rhs) const {
    // The identity test is a cheap shortcut before calling ir::Equal.
    return lhs.same_as(rhs) || ir::Equal(lhs, rhs);
  }
};

143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
template<>
class PEqualChecker<Integer> {
 public:
  bool operator()(const Integer& lhs, const Integer& rhs) const {
    return lhs->value == rhs->value;
  }
};

template<>
class PEqualChecker<Var> {
 public:
  /*!
   * \brief Variables are compared only with same_as.
   *  Two distinct Var handles with the same name are not treated as equal here.
   */
  bool operator()(const Var& lhs, const Var& rhs) const {
    const bool identical = lhs.same_as(rhs);
    return identical;
  }
};

159 160 161 162 163 164 165 166 167 168 169 170 171 172
/*!
 * \brief Pattern variable container.
 *
 * PVar is used as a "hole" in the pattern that can be matched.
 *
 * \tparam T the type of the hole.
 *
 * \note PVar is not thread safe.
 *       Do not use the same PVar in multiple threads.
 */
template<typename T>
class PVar : public Pattern<PVar<T> > {
 public:
  // Store PVars by reference in the expression.
173
  using Nested = const PVar<T>&;
174 175 176 177 178 179 180 181 182 183 184 185 186 187 188

  void InitMatch_() const {
    filled_ = false;
  }

  bool Match_(const T& value) const {
    if (!filled_) {
      value_ = value;
      filled_ = true;
      return true;
    } else {
      return PEqualChecker<T>()(value_, value);
    }
  }

189 190 191 192 193 194 195 196 197 198 199
  template<typename NodeRefType,
           typename = typename std::enable_if<
             std::is_base_of<NodeRefType, T>::value>::type>
  bool Match_(const NodeRefType& value) const {
    if (const auto* ptr = value.template as<typename T::ContainerType>()) {
      return Match_(GetRef<T>(ptr));
    } else {
      return false;
    }
  }

200 201 202 203 204
  T Eval() const {
    CHECK(filled_);
    return value_;
  }

205
 protected:
206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231
  /*! \brief The matched value */
  mutable T value_;
  /*! \brief whether the variable has been filled */
  mutable bool filled_{false};
};

/*!
 * \brief Constant pattern container.
 *
 *  Matches only values that compare equal, via PEqualChecker<T>, to the
 *  stored constant. It evaluates to that constant.
 *
 * \tparam T the type of the constant.
 */
template<typename T>
class PConst : public Pattern<PConst<T> > {
 public:
  // Deliberately implicit (see the NOLINT) so a plain T converts to a pattern.
  PConst(T value)  // NOLINT(*)
      : value_(value) {}

  /*! \brief A constant holds no match state, so there is nothing to reset. */
  void InitMatch_() const {}

  /*! \return whether value equals the stored constant. */
  bool Match_(const T& value) const {
    return PEqualChecker<T>()(value_, value);
  }

  /*! \return the stored constant. */
  T Eval() const {
    return value_;
  }

 private:
  /*! \brief The constant value. */
  const T value_;
};

/*!
 * \brief Pattern binary expression.
 * \tparam NodeType The AST node type.
 * \tparam TA The pattern type of the first operand.
 * \tparam TB The pattern type of the second operand.
 */
template<typename NodeType, typename TA, typename TB>
class PBinaryExpr :
      public Pattern<PBinaryExpr<NodeType, TA, TB> > {
 public:
  PBinaryExpr(const TA& a, const TB& b) : a_(a), b_(b) {}

  /*! \brief Reset the holes of both operand patterns. */
  void InitMatch_() const {
    a_.InitMatch_();
    b_.InitMatch_();
  }

  /*!
   * \brief Match a node of type NodeType.
   *
   *  Its operand a is matched against the first pattern, then its operand b
   *  against the second. Matching stops at the first failure.
   */
  bool Match_(const NodeRef& node) const {
    if (const NodeType* ptr = node.as<NodeType>()) {
      if (!a_.Match_(ptr->a)) return false;
      if (!b_.Match_(ptr->b)) return false;
      return true;
    } else {
      return false;
    }
  }

  /*!
   * \brief Build a NodeType expression from the evaluated operands.
   *
   *  TryConstFold<NodeType> is tried first. If it yields a defined
   *  expression, that expression is returned instead of a new node.
   */
  Expr Eval() const {
    Expr lhs = a_.Eval();
    Expr rhs = b_.Eval();
    Expr ret = TryConstFold<NodeType>(lhs, rhs);
    if (ret.defined()) return ret;
    return NodeType::make(lhs, rhs);
  }

 private:
  typename TA::Nested a_;
  typename TB::Nested b_;
};

277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302
template<typename TA>
class PConstWithTypeLike :
      public Pattern<PConstWithTypeLike<TA> > {
 public:
  PConstWithTypeLike(const TA& ref, int64_t value)
      : ref_(ref), value_(value) {}

  void InitMatch_() const {}

  bool Match_(const NodeRef& node) const {
    if (const ir::IntImm* ptr = node.as<ir::IntImm>()) {
      return ptr->value == value_;
    } else {
      return false;
    }
  }

  Expr Eval() const {
    return make_const(ref_.Eval().type(), value_);
  }

 private:
  typename TA::Nested ref_;
  int64_t value_;
};

/*!
 * \brief Define the pattern builder FuncName for the binary AST node NodeName.
 *
 *  Three overloads are generated:
 *  - pattern op pattern;
 *  - pattern op int64_t, where the integer becomes a PConstWithTypeLike
 *    typed like the pattern operand;
 *  - int64_t op pattern, handled the same way.
 */
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName)                   \
  template<typename TA, typename TB>                                \
  inline PBinaryExpr<NodeName, TA, TB>                              \
  FuncName(const Pattern<TA>& a, const Pattern<TB>& b) {            \
    return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived()); \
  }                                                                 \
  template<typename TA>                                             \
  inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA> >         \
  FuncName(const Pattern<TA>& a, int64_t b) {                       \
    return FuncName(a, PConstWithTypeLike<TA>(a.derived(), b));     \
  }                                                                 \
  template<typename TA>                                             \
  inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA>          \
  FuncName(int64_t b, const Pattern<TA>& a) {                       \
    return FuncName(PConstWithTypeLike<TA>(a.derived(), b), a);     \
  }

// arithmetic expressions
TVM_PATTERN_BINARY_OP(operator+, ir::Add);
TVM_PATTERN_BINARY_OP(operator-, ir::Sub);
TVM_PATTERN_BINARY_OP(operator*, ir::Mul);
TVM_PATTERN_BINARY_OP(operator/, ir::Div);
TVM_PATTERN_BINARY_OP(operator%, ir::Mod);
TVM_PATTERN_BINARY_OP(min, ir::Min);
TVM_PATTERN_BINARY_OP(max, ir::Max);

// comparison expressions
TVM_PATTERN_BINARY_OP(operator>, ir::GT);
TVM_PATTERN_BINARY_OP(operator>=, ir::GE);
TVM_PATTERN_BINARY_OP(operator<, ir::LT);
TVM_PATTERN_BINARY_OP(operator<=, ir::LE);
TVM_PATTERN_BINARY_OP(operator==, ir::EQ);
TVM_PATTERN_BINARY_OP(operator!=, ir::NE);

// logical expressions
TVM_PATTERN_BINARY_OP(operator&&, ir::And);
TVM_PATTERN_BINARY_OP(operator||, ir::Or);

/*!
 * \brief Pattern not expression.
 * \tparam TA The pattern type of the operand.
 */
template<typename TA>
class PNotExpr : public Pattern<PNotExpr<TA> > {
 public:
  explicit PNotExpr(const TA& value)
      : value_(value) {}

  /*! \brief Reset the holes of the operand pattern. */
  void InitMatch_() const {
    value_.InitMatch_();
  }

  /*! \brief Match an ir::Not node whose operand matches the inner pattern. */
  bool Match_(const NodeRef& node) const {
    const ir::Not* not_node = node.as<ir::Not>();
    if (not_node == nullptr) return false;
    return value_.Match_(not_node->a);
  }

  /*! \brief Build ir::Not over the evaluated operand. */
  Expr Eval() const {
    auto operand = value_.Eval();
    return ir::Not::make(operand);
  }

 private:
  /*! \brief The operand pattern, stored as TA::Nested. */
  typename TA::Nested value_;
};

template<typename TA>
inline PNotExpr<TA> operator!(const Pattern<TA>& value) {
  // Wrap the concrete (CRTP-derived) pattern in a logical-not pattern.
  const TA& operand = value.derived();
  return PNotExpr<TA>(operand);
}

// select
/*!
 * \brief Pattern select expression.
 * \tparam TCond The pattern type of the condition.
 * \tparam TA The pattern type of the true operand.
 * \tparam TB The pattern type of the false operand.
 */
template<typename TCond, typename TA, typename TB>
class PSelectExpr :
      public Pattern<PSelectExpr<TCond, TA, TB> > {
 public:
  PSelectExpr(const TCond& condition,
              const TA& true_value,
              const TB& false_value)
      : condition_(condition),
        true_value_(true_value),
        false_value_(false_value) {}

  /*! \brief Reset the holes of all three sub-patterns. */
  void InitMatch_() const {
    condition_.InitMatch_();
    true_value_.InitMatch_();
    false_value_.InitMatch_();
  }

  /*!
   * \brief Match an ir::Select node.
   *
   *  The condition, true value and false value are matched in that order,
   *  stopping at the first failure.
   */
  bool Match_(const NodeRef& node) const {
    const ir::Select* select_node = node.as<ir::Select>();
    if (select_node == nullptr) return false;
    return condition_.Match_(select_node->condition) &&
           true_value_.Match_(select_node->true_value) &&
           false_value_.Match_(select_node->false_value);
  }

  /*! \brief Build ir::Select from the evaluated sub-patterns. */
  Expr Eval() const {
    return ir::Select::make(condition_.Eval(),
                            true_value_.Eval(),
                            false_value_.Eval());
  }

 private:
  typename TCond::Nested condition_;
  typename TA::Nested true_value_;
  typename TB::Nested false_value_;
};

/*!
 * \brief Construct a select pattern.
 *
 * \param condition The condition expression.
 * \param true_value The value when condition is true.
 * \param false_value The value when condition is false.
 *
 * \return The result pattern.
 *
 * \tparam TCond The pattern type of the condition.
 * \tparam TA The pattern type of the true operand.
 * \tparam TB The pattern type of the false operand.
 */
template<typename TCond, typename TA, typename TB>
inline PSelectExpr<TCond, TA, TB>
select(const Pattern<TCond>& condition,
       const Pattern<TA>& true_value,
       const Pattern<TB>& false_value) {
  using ResultPattern = PSelectExpr<TCond, TA, TB>;
  // Pass the concrete sub-patterns; PSelectExpr decides how to store them.
  return ResultPattern(condition.derived(),
                       true_value.derived(),
                       false_value.derived());
}

/*!
 * \brief Pattern cast expression.
 * \tparam DType The Pattern type of dtype.
 * \tparam TA The pattern type of the value being cast.
 */
template<typename DType, typename TA>
class PCastExpr :
      public Pattern<PCastExpr<DType, TA> > {
 public:
  PCastExpr(const DType& dtype, const TA& value)
      : dtype_(dtype), value_(value) {}

  /*! \brief Reset the holes of the dtype and value patterns. */
  void InitMatch_() const {
    dtype_.InitMatch_();
    value_.InitMatch_();
  }

  /*!
   * \brief Match an ir::Cast node.
   *
   *  Its target type is matched against the dtype pattern first, then its
   *  operand against the value pattern.
   */
  bool Match_(const NodeRef& node) const {
    const ir::Cast* cast_node = node.as<ir::Cast>();
    if (cast_node == nullptr) return false;
    return dtype_.Match_(cast_node->type) &&
           value_.Match_(cast_node->value);
  }

  /*! \brief Build ir::Cast of the evaluated value to the evaluated dtype. */
  Expr Eval() const {
    return ir::Cast::make(dtype_.Eval(), value_.Eval());
  }

 private:
  typename DType::Nested dtype_;
  typename TA::Nested value_;
};

/*!
 * \brief Construct a cast pattern.
 *
 * \param dtype The target data type, can be PVar<Type> or PConst<Type>.
 * \param value The pattern of the value to cast.
 *
 * \return The result pattern.
 *
 * \tparam DType The pattern type of type.
 * \tparam TA The pattern type of value.
 */
template<typename DType, typename TA>
inline PCastExpr<DType, TA>
cast(const Pattern<DType>& dtype, const Pattern<TA>& value) {
  const DType& dtype_pattern = dtype.derived();
  const TA& value_pattern = value.derived();
  return PCastExpr<DType, TA>(dtype_pattern, value_pattern);
}

/*!
 * \brief Pattern ramp expression.
 * \tparam TBase The pattern type of the base.
 * \tparam TStride The pattern type of the stride.
 * \tparam TLanes The pattern type of the lanes.
 */
template<typename TBase, typename TStride, typename TLanes>
class PRampExpr :
      public Pattern<PRampExpr<TBase, TStride, TLanes> > {
 public:
  PRampExpr(const TBase& base,
            const TStride& stride,
            const TLanes& lanes)
      : base_(base), stride_(stride), lanes_(lanes) {}

  /*! \brief Reset the holes of the base, stride and lanes patterns. */
  void InitMatch_() const {
    base_.InitMatch_();
    stride_.InitMatch_();
    lanes_.InitMatch_();
  }

  /*!
   * \brief Match an ir::Ramp node.
   *
   *  Its base, stride and lanes are matched in that order, stopping at the
   *  first failure.
   */
  bool Match_(const NodeRef& node) const {
    const ir::Ramp* ramp_node = node.as<ir::Ramp>();
    if (ramp_node == nullptr) return false;
    return base_.Match_(ramp_node->base) &&
           stride_.Match_(ramp_node->stride) &&
           lanes_.Match_(ramp_node->lanes);
  }

  /*! \brief Build ir::Ramp from the evaluated sub-patterns. */
  Expr Eval() const {
    return ir::Ramp::make(base_.Eval(),
                          stride_.Eval(),
                          lanes_.Eval());
  }

 private:
  typename TBase::Nested base_;
  typename TStride::Nested stride_;
  typename TLanes::Nested lanes_;
};

/*!
 * \brief Construct a ramp pattern.
 *
 * \param base The base pattern.
 * \param stride The stride pattern.
 * \param lanes The lanes pattern.
 *
 * \return The result pattern.
 *
 * \tparam TBase The pattern type of the base.
 * \tparam TStride The pattern type of the stride.
 * \tparam TLanes The pattern type of the lanes.
 */
template<typename TBase, typename TStride, typename TLanes>
inline PRampExpr<TBase, TStride, TLanes>
ramp(const Pattern<TBase>& base,
     const Pattern<TStride>& stride,
     const Pattern<TLanes>& lanes) {
  using ResultPattern = PRampExpr<TBase, TStride, TLanes>;
  return ResultPattern(base.derived(),
                       stride.derived(),
                       lanes.derived());
}

/*!
 * \brief Pattern broadcast expression.
 * \tparam TA The pattern type of the value.
 * \tparam TLanes The pattern type of the lanes.
 */
template<typename TA, typename TLanes>
class PBroadcastExpr :
      public Pattern<PBroadcastExpr<TA, TLanes> > {
 public:
  PBroadcastExpr(const TA& value,
                 const TLanes& lanes)
      : value_(value), lanes_(lanes) {}

  /*! \brief Reset the holes of the value and lanes patterns. */
  void InitMatch_() const {
    value_.InitMatch_();
    lanes_.InitMatch_();
  }

  /*!
   * \brief Match an ir::Broadcast node.
   *
   *  Its value is matched first, then its lanes.
   */
  bool Match_(const NodeRef& node) const {
    const ir::Broadcast* bcast_node = node.as<ir::Broadcast>();
    if (bcast_node == nullptr) return false;
    return value_.Match_(bcast_node->value) &&
           lanes_.Match_(bcast_node->lanes);
  }

  /*! \brief Build ir::Broadcast from the evaluated sub-patterns. */
  Expr Eval() const {
    return ir::Broadcast::make(value_.Eval(), lanes_.Eval());
  }

 private:
  typename TA::Nested value_;
  typename TLanes::Nested lanes_;
};

/*!
 * \brief Construct a broadcast pattern.
 *
 * \param value The value pattern.
 * \param lanes The lanes pattern.
 *
 * \return The result pattern.
 *
 * \tparam TA The pattern type of the value.
 * \tparam TLanes The pattern type of the lanes.
 */
template<typename TA, typename TLanes>
inline PBroadcastExpr<TA, TLanes>
broadcast(const Pattern<TA>& value, const Pattern<TLanes>& lanes) {
  const TA& value_pattern = value.derived();
  const TLanes& lanes_pattern = lanes.derived();
  return PBroadcastExpr<TA, TLanes>(value_pattern, lanes_pattern);
}

// internal namespace
namespace detail {
// Helpers used by PCallExpr to visit each element of its argument tuple.

/*!
 * \brief Compile-time loop over the elements of a tuple.
 *
 *  run() calls f(I, std::get<I>(tuple)) and then recurses with I + 1.
 *  The specialization for stop == true ends the recursion.
 */
template<bool stop, std::size_t I, typename F>
struct tuple_for_each_dispatcher {
  template<typename TTuple>
  static void run(F& f, const TTuple& tuple) {  // NOLINT(*)
    f(I, std::get<I>(tuple));
    // Continue with the next index; stop once it reaches the tuple size.
    constexpr std::size_t kNext = I + 1;
    tuple_for_each_dispatcher<
      kNext == std::tuple_size<TTuple>::value, kNext, F>::run(f, tuple);
  }
};

template<std::size_t I, typename F>
struct tuple_for_each_dispatcher<true, I, F> {
  // Past the last element: nothing is left to visit.
  template<typename TTuple>
  static void run(F& f, const TTuple& tuple) {}  // NOLINT(*)
};

/*!
 * \brief Call f(i, element) for every element of tuple, in index order.
 *  An empty tuple results in no calls.
 */
template<typename F, typename TTuple>
inline void tuple_for_each(F& f, const TTuple& tuple) {  // NOLINT(*)
  constexpr bool kEmpty = std::tuple_size<TTuple>::value == 0;
  tuple_for_each_dispatcher<kEmpty, 0, F>::run(f, tuple);
}

/*! \brief Resets the match state of each argument pattern. */
struct PCallExprInitMatchFunctor {
  template<typename T>
  void operator()(size_t i, const T& pattern) const {
    pattern.InitMatch_();
  }
};

/*!
 * \brief Matches the i-th argument pattern against call_->args[i].
 *
 *  matched_ becomes false at the first failing argument.
 *  The remaining argument patterns are then skipped and not matched.
 */
struct PCallExprMatchFunctor {
  const ir::Call* call_;
  bool matched_{true};

  explicit PCallExprMatchFunctor(const ir::Call* call)
      : call_(call) {}

  template<typename T>
  void operator()(size_t i, const T& pattern) {
    if (!matched_) return;
    matched_ = pattern.Match_(call_->args[i]);
  }
};

/*! \brief Appends each evaluated argument pattern to args_, in order. */
struct PCallExprEvalArgsFunctor {
  Array<Expr> args_;

  template<typename T>
  void operator()(size_t i, const T& pattern) {
    args_.push_back(pattern.Eval());
  }
};
}  // namespace detail

/*!
 * \brief Pattern CallExpr expression.
 * \tparam Op The operator functor class.
 * \tparam TArgs The arguments.
 * \note Op functor contains the name of the function and
 *          the implementation of Eval.
 */
template<typename Op, typename ...TArgs>
class PCallExpr :
      public Pattern<PCallExpr<Op, TArgs...> > {
 public:
  explicit PCallExpr(const TArgs&... args)
      : args_(args...) {}

  /*! \brief Reset the holes of every argument pattern. */
  void InitMatch_() const {
    detail::PCallExprInitMatchFunctor reset;
    detail::tuple_for_each(reset, args_);
  }

  /*!
   * \brief Match an ir::Call node.
   *
   *  The call must have exactly sizeof...(TArgs) arguments and be named
   *  Op::kName. Each argument must match the corresponding argument pattern.
   */
  bool Match_(const NodeRef& node) const {
    const ir::Call* call = node.as<ir::Call>();
    if (call == nullptr) return false;
    if (call->args.size() != sizeof...(TArgs)) return false;
    if (call->name != Op::kName) return false;
    detail::PCallExprMatchFunctor matcher(call);
    detail::tuple_for_each(matcher, args_);
    return matcher.matched_;
  }

  /*! \brief Evaluate every argument pattern and let Op build the call. */
  Expr Eval() const {
    detail::PCallExprEvalArgsFunctor collector;
    detail::tuple_for_each(collector, args_);
    return Op::Eval(collector.args_);
  }

 private:
  /*! \brief Argument patterns, each stored as its Nested type. */
  std::tuple<typename TArgs::Nested...> args_;
};

// arithmetic intrinsics (shifts and bitwise operations)
/*!
 * \brief Define an op struct OpName for the pure intrinsic IntrinStr, plus a
 *  pattern builder FuncName(a, b) that matches and builds a two-argument
 *  call to it. The built call takes the type of its first argument.
 */
#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr)        \
  struct OpName {                                                     \
    static constexpr const char* kName = IntrinStr;                   \
    static Expr Eval(Array<Expr> args) {                              \
      return ir::Call::make(args[0].type(), kName, args,              \
                            ir::Call::PureIntrinsic);                 \
    }                                                                 \
  };                                                                  \
  template<typename TA, typename TB>                                  \
  inline PCallExpr<OpName, TA, TB>                                    \
  FuncName(const Pattern<TA>& a, const Pattern<TB>& b) {              \
    return PCallExpr<OpName, TA, TB>(a.derived(), b.derived());       \
  }

TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, "shift_left");
TVM_PATTERN_BINARY_INTRIN(operator>>, PRightShiftOp, "shift_right");
TVM_PATTERN_BINARY_INTRIN(operator&, PBitwiseAndOp, "bitwise_and");
TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, "bitwise_or");
TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor");

// unary intrinsics
/*!
 * \brief Define an op struct OpName for the pure intrinsic IntrinStr, plus a
 *  pattern builder FuncName(a) that matches and builds a one-argument call
 *  to it. The built call takes the type of its argument.
 */
#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr)         \
  struct OpName {                                                     \
    static constexpr const char* kName = IntrinStr;                   \
    static Expr Eval(Array<Expr> args) {                              \
      return ir::Call::make(args[0].type(), kName, args,              \
                            ir::Call::PureIntrinsic);                 \
    }                                                                 \
  };                                                                  \
  template<typename TA>                                               \
  inline PCallExpr<OpName, TA>                                        \
  FuncName(const Pattern<TA>& a) {                                    \
    return PCallExpr<OpName, TA>(a.derived());                        \
  }

TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not");

// if_then_else
/*! \brief Op descriptor for the pure intrinsic "tvm_if_then_else". */
struct PIfThenElseOp {
  static constexpr const char* kName = "tvm_if_then_else";
  /*!
   * \brief Build the intrinsic call from args = {cond, true_value, false_value}.
   *  The call takes the type of the true value, args[1].
   */
  static Expr Eval(Array<Expr> args) {
    const auto result_type = args[1].type();
    return ir::Call::make(result_type, kName, args, ir::Call::PureIntrinsic);
  }
};

/*!
 * \brief Construct an if_then_else pattern.
 *
 * \param cond The condition expression.
 * \param true_value The value when condition is true.
 * \param false_value The value when condition is false.
 *
 * \return The result pattern.
 *
 * \tparam TCond The pattern type of the condition.
 * \tparam TA The pattern type of the true operand.
 * \tparam TB The pattern type of the false operand.
 */
template<typename TCond, typename TA, typename TB>
inline PCallExpr<PIfThenElseOp, TCond, TA, TB>
if_then_else(const Pattern<TCond>& cond,
             const Pattern<TA>& true_value,
             const Pattern<TB>& false_value) {
  using ResultPattern = PCallExpr<PIfThenElseOp, TCond, TA, TB>;
  // Argument order matches the intrinsic: condition, true value, false value.
  return ResultPattern(cond.derived(),
                       true_value.derived(),
                       false_value.derived());
}

}  // namespace arith
}  // namespace tvm
#endif  // TVM_ARITHMETIC_PATTERN_MATCH_H_