/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/arithmetic/pattern_match.h
 *
 * \brief Internal tool for expression-template based pattern matching.
 *
 * It helps to simplify pattern matching and rewrites.
 * All the patterns are generated via expression templates at compile time,
 * so the resulting code should be as efficient as hand-written pattern-match code.
 *
 * The code below shows how to use the pattern matcher.
 *
 * \code
 *
 *  // max(x + z, y + z) => max(x, y) + z
 *  arith::PVar<PrimExpr> x, y, z;
 *
 *  // The following code tries to match the declared pattern.
 *  // Match will fill the result of match into PVar if successful.
 *  // Note that z occurs twice in the pattern, so an equality check
 *  // is performed to ensure all occurrences of z are equivalent.
 *  if (max(x + z, y + z).Match(expr)) {
 *    // Eval evaluates a pattern with the current matched value.
 *    // The filled value is valid until the next call to Match.
 *    return (max(x, y) + z).Eval();
 *  }
 *
 *  tvm::tir::Var tx, ty;
 *  arith::PVar<IntImm> c;
 *  arith::PVar<Var> v;
 *  // We can match IntImm and Var, both of which are
 *  // special-case containers of PrimExpr.
 *  CHECK((v * c).Match(tx * 3));
 *  CHECK_EQ(c.Eval()->value, 3);
 *  // cannot match c to ty
 *  CHECK(!(v * c).Match(tx * ty));
 *
 * \endcode
 *
 * \note The pattern matcher is not threadsafe,
 *       do not use the same PVar in multiple threads.
 *
 *       Please be aware that the filled value in a PVar
 *       can be overridden in the next call to Match.
 */
#ifndef TVM_ARITH_PATTERN_MATCH_H_
#define TVM_ARITH_PATTERN_MATCH_H_

#include <tvm/tir/ir_pass.h>

#include <cstddef>
#include <tuple>
#include <type_traits>

#include "const_fold.h"

namespace tvm {
namespace arith {
/*!
 * \brief CRTP base class shared by every pattern node.
 *
 * Each concrete pattern (the Derived type) supplies:
 * - InitMatch_(): reset per-match state before a new Match.
 * - Match_(value): test value against the pattern, filling PVars on success.
 * - Eval(): rebuild a value from the currently filled PVars.
 *
 * Combining patterns yields new pattern types that store their operands,
 * i.e. patterns are expression templates built with the curiously
 * recurring template pattern.
 *
 * \tparam Derived The concrete pattern type deriving from this class.
 */
template<typename Derived>
class Pattern {
 public:
  /*!
   * \brief How a parent pattern stores this pattern as an operand.
   *
   *  By default an operand is stored by value (a copy of Derived).
   *  A Derived class may shadow this alias to be stored by reference
   *  instead, e.g. `const Derived&`.
   *
   *  The Nested typedef trick is borrowed from Eigen.
   *
   * \note Intermediate expressions are nested by value,
   *       PVars are nested by reference.
   */
  using Nested = Derived;

  /*!
   * \brief Reset the pattern state and test whether value matches it.
   *
   * On success the PVars reachable from this pattern hold the matched
   * sub-values; they stay valid until the next call to Match.
   *
   * \param value The value to match against.
   * \tparam NodeType The type of the value being matched.
   * \return true if value matches the pattern, false otherwise.
   */
  template<typename NodeType>
  bool Match(const NodeType& value) const {
    const Derived& self = derived();
    // Clear state left over from any earlier match before matching again.
    self.InitMatch_();
    return self.Match_(value);
  }

  /*! \return This object viewed as its concrete Derived type. */
  const Derived& derived() const {
    const Derived* self = static_cast<const Derived*>(this);
    return *self;
  }
};

/*!
 * \brief Equality predicate used to compare two matched values.
 *
 * The primary template defers to operator==; specializations supply
 * a different comparison for particular node reference types.
 *
 * \tparam T The type of the values being compared.
 */
template<typename T>
class PEqualChecker {
 public:
  /*!
   * \param lhs The first value.
   * \param rhs The second value.
   * \return Whether lhs == rhs.
   */
  bool operator()(const T& lhs, const T& rhs) const {
    const bool equal = (lhs == rhs);
    return equal;
  }
};

template<>
134
class PEqualChecker<PrimExpr> {
135
 public:
136
  bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
137
    if (lhs.same_as(rhs)) return true;
138
    return tir::Equal(lhs, rhs);
139 140 141
  }
};

142
template<>
143
class PEqualChecker<IntImm> {
144
 public:
145
  bool operator()(const IntImm& lhs, const IntImm& rhs) const {
146 147 148 149 150 151 152 153 154 155 156 157
    return lhs->value == rhs->value;
  }
};

template<>
class PEqualChecker<Var> {
 public:
  bool operator()(const Var& lhs, const Var& rhs) const {
    return lhs.same_as(rhs);
  }
};

158 159 160 161 162 163 164 165 166 167 168 169 170 171
/*!
 * \brief Pattern variable container.
 *
 * PVar is used as a "hole" in the pattern that can be matched.
 *
 * \tparam T the type of the hole.
 *
 * \note PVar is not thread safe.
 *       Do not use the same PVar in multiple threads.
 */
template<typename T>
class PVar : public Pattern<PVar<T> > {
 public:
  // Store PVars by reference in the expression.
172
  using Nested = const PVar<T>&;
173 174 175 176 177 178 179 180 181 182 183 184 185 186 187

  void InitMatch_() const {
    filled_ = false;
  }

  bool Match_(const T& value) const {
    if (!filled_) {
      value_ = value;
      filled_ = true;
      return true;
    } else {
      return PEqualChecker<T>()(value_, value);
    }
  }

188 189 190 191 192 193 194 195 196 197 198
  template<typename NodeRefType,
           typename = typename std::enable_if<
             std::is_base_of<NodeRefType, T>::value>::type>
  bool Match_(const NodeRefType& value) const {
    if (const auto* ptr = value.template as<typename T::ContainerType>()) {
      return Match_(GetRef<T>(ptr));
    } else {
      return false;
    }
  }

199 200 201 202 203
  T Eval() const {
    CHECK(filled_);
    return value_;
  }

204
 protected:
205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230
  /*! \brief The matched value */
  mutable T value_;
  /*! \brief whether the variable has been filled */
  mutable bool filled_{false};
};

/*!
 * \brief A pattern leaf that holds a fixed value.
 *
 * Matching succeeds only when the candidate equals the stored value
 * under PEqualChecker<T>; Eval returns the stored value.
 *
 * \tparam T The type of the constant.
 */
template<typename T>
class PConst : public Pattern<PConst<T> > {
 public:
  /*!
   * \brief Wrap a constant.
   * \param value The constant to store.
   * \note Deliberately non-explicit; the NOLINT silences the lint check.
   */
  PConst(T value)  // NOLINT(*)
      : value_(value) {}

  /*! \brief A constant has no per-match state to reset. */
  void InitMatch_() const {}

  /*! \return Whether value equals the stored constant. */
  bool Match_(const T& value) const {
    PEqualChecker<T> equal;
    return equal(value_, value);
  }

  /*! \return The stored constant. */
  T Eval() const {
    return value_;
  }

 private:
  /*! \brief The constant this pattern stands for. */
  const T value_;
};

/*!
 * \brief Pattern for a binary expression node such as Add or Max.
 *
 * Matches a NodeType whose operands `a` and `b` match the two
 * sub-patterns, checked left operand first.
 *
 * \tparam NodeType The AST node type to match and to build in Eval.
 * \tparam TA The pattern type of the first operand.
 * \tparam TB The pattern type of the second operand.
 */
template<typename NodeType, typename TA, typename TB>
class PBinaryExpr :
      public Pattern<PBinaryExpr<NodeType, TA, TB> > {
 public:
  PBinaryExpr(const TA& a, const TB& b) : a_(a), b_(b) {}

  /*! \brief Reset both operand patterns. */
  void InitMatch_() const {
    a_.InitMatch_();
    b_.InitMatch_();
  }

  /*!
   * \param node The candidate node.
   * \return Whether node is a NodeType whose operands both match.
   */
  bool Match_(const ObjectRef& node) const {
    const NodeType* op = node.as<NodeType>();
    if (op == nullptr) return false;
    // Short-circuit: the right operand is only tried once the left matched.
    return a_.Match_(op->a) && b_.Match_(op->b);
  }

  /*!
   * \brief Build the expression from the operands' current values.
   *
   * TryConstFold is attempted first; a new NodeType is only created
   * when folding returns an undefined result.
   */
  PrimExpr Eval() const {
    PrimExpr lhs = a_.Eval();
    PrimExpr rhs = b_.Eval();
    PrimExpr folded = TryConstFold<NodeType>(lhs, rhs);
    if (folded.defined()) return folded;
    return NodeType::make(lhs, rhs);
  }

 private:
  /*! \brief Pattern for the first operand. */
  typename TA::Nested a_;
  /*! \brief Pattern for the second operand. */
  typename TB::Nested b_;
};

276 277 278 279 280 281 282 283 284
template<typename TA>
class PConstWithTypeLike :
      public Pattern<PConstWithTypeLike<TA> > {
 public:
  PConstWithTypeLike(const TA& ref, int64_t value)
      : ref_(ref), value_(value) {}

  void InitMatch_() const {}

285
  bool Match_(const ObjectRef& node) const {
286
    if (const tir::IntImmNode* ptr = node.as<tir::IntImmNode>()) {
287 288 289 290 291 292
      return ptr->value == value_;
    } else {
      return false;
    }
  }

293
  PrimExpr Eval() const {
294
    return tir::make_const(ref_.Eval().dtype(), value_);
295 296 297 298 299 300 301
  }

 private:
  typename TA::Nested ref_;
  int64_t value_;
};

/*!
 * \brief Define the pattern builders for one binary operator.
 *
 * Three overloads of FuncName are generated:
 *  - (pattern, pattern) returns PBinaryExpr<NodeName, TA, TB>;
 *  - (pattern, int64_t) wraps the integer in a PConstWithTypeLike of the
 *    left pattern;
 *  - (int64_t, pattern) wraps the integer in a PConstWithTypeLike of the
 *    right pattern.
 *
 * CheckStep is a statement placed at the top of every overload; it may
 * refer to the pattern parameter `a`.
 */
#define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep)         \
  template<typename TA, typename TB>                                    \
  inline PBinaryExpr<NodeName, TA, TB>                                  \
  FuncName(const Pattern<TA>& a, const Pattern<TB>& b) {                \
    CheckStep;                                                          \
    return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived());     \
  }                                                                     \
  template<typename TA>                                                 \
  inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA> >             \
  FuncName(const Pattern<TA>& a, int64_t b) {                           \
    CheckStep;                                                          \
    PConstWithTypeLike<TA> rhs_const(a.derived(), b);                   \
    return FuncName(a, rhs_const);                                      \
  }                                                                     \
  template<typename TA>                                                 \
  inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA>              \
  FuncName(int64_t b, const Pattern<TA>& a) {                           \
    CheckStep;                                                          \
    PConstWithTypeLike<TA> lhs_const(a.derived(), b);                   \
    return FuncName(lhs_const, a);                                      \
  }

/*! \brief TVM_PATTERN_BINARY_OP_EX with an empty check step. */
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \
  TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, )


// Plain `/` and `%` are rejected: each overload first calls
// DivAmbiguityError(a), which raises an ambiguity error. Use an explicit
// flavour (div, truncdiv, truncmod, floordiv, floormod) instead.
TVM_PATTERN_BINARY_OP_EX(operator/, tir::DivNode, DivAmbiguityError(a));
TVM_PATTERN_BINARY_OP_EX(operator%, tir::ModNode, DivAmbiguityError(a));

// Arithmetic: +, -, *.
TVM_PATTERN_BINARY_OP(operator+, tir::AddNode);
TVM_PATTERN_BINARY_OP(operator-, tir::SubNode);
TVM_PATTERN_BINARY_OP(operator*, tir::MulNode);

// Arithmetic: min / max.
TVM_PATTERN_BINARY_OP(min, tir::MinNode);
TVM_PATTERN_BINARY_OP(max, tir::MaxNode);

// Arithmetic: named division / modulo flavours.
TVM_PATTERN_BINARY_OP(div, tir::DivNode);
TVM_PATTERN_BINARY_OP(truncdiv, tir::DivNode);
TVM_PATTERN_BINARY_OP(truncmod, tir::ModNode);
TVM_PATTERN_BINARY_OP(floordiv, tir::FloorDivNode);
TVM_PATTERN_BINARY_OP(floormod, tir::FloorModNode);

// Logical: comparisons.
TVM_PATTERN_BINARY_OP(operator>, tir::GTNode);
TVM_PATTERN_BINARY_OP(operator>=, tir::GENode);
TVM_PATTERN_BINARY_OP(operator<, tir::LTNode);
TVM_PATTERN_BINARY_OP(operator<=, tir::LENode);
TVM_PATTERN_BINARY_OP(operator==, tir::EQNode);
TVM_PATTERN_BINARY_OP(operator!=, tir::NENode);

// Logical: boolean connectives.
TVM_PATTERN_BINARY_OP(operator&&, tir::AndNode);
TVM_PATTERN_BINARY_OP(operator||, tir::OrNode);

/*!
 * \brief Pattern for logical negation (tir::NotNode).
 * \tparam TA The pattern type of the negated operand.
 */
template<typename TA>
class PNotExpr : public Pattern<PNotExpr<TA> > {
 public:
  /*! \param value Pattern for the operand being negated. */
  explicit PNotExpr(const TA& value)
      : value_(value) {}

  /*! \brief Reset the operand pattern. */
  void InitMatch_() const {
    value_.InitMatch_();
  }

  /*! \return Whether node is a NotNode whose operand matches value_. */
  bool Match_(const ObjectRef& node) const {
    if (const tir::NotNode* ptr = node.as<tir::NotNode>()) {
      return value_.Match_(ptr->a);
    }
    return false;
  }

  /*! \return A NotNode built from the operand's current value. */
  PrimExpr Eval() const {
    return tir::NotNode::make(value_.Eval());
  }

 private:
  /*! \brief Pattern for the negated operand. */
  typename TA::Nested value_;
};

/*!
 * \brief Build a logical-not pattern, e.g. `!x`.
 * \param value The operand pattern.
 * \return The negation pattern.
 */
template<typename TA>
inline PNotExpr<TA> operator!(const Pattern<TA>& value) {
  const TA& operand = value.derived();
  return PNotExpr<TA>(operand);
}

/*!
 * \brief Pattern for a select expression (tir::SelectNode).
 *
 * Matches when the condition, true value and false value each match
 * their sub-pattern, checked in that order.
 *
 * \tparam TCond The pattern type of the condition.
 * \tparam TA The pattern type of the true operand.
 * \tparam TB The pattern type of the false operand.
 */
template<typename TCond, typename TA, typename TB>
class PSelectExpr :
      public Pattern<PSelectExpr<TCond, TA, TB> > {
 public:
  PSelectExpr(const TCond& condition,
              const TA& true_value,
              const TB& false_value)
      : condition_(condition),
        true_value_(true_value),
        false_value_(false_value) {}

  /*! \brief Reset all three sub-patterns. */
  void InitMatch_() const {
    condition_.InitMatch_();
    true_value_.InitMatch_();
    false_value_.InitMatch_();
  }

  /*! \return Whether node is a SelectNode whose three parts all match. */
  bool Match_(const ObjectRef& node) const {
    const tir::SelectNode* sel = node.as<tir::SelectNode>();
    if (sel == nullptr) return false;
    return condition_.Match_(sel->condition) &&
           true_value_.Match_(sel->true_value) &&
           false_value_.Match_(sel->false_value);
  }

  /*! \return A SelectNode built from the sub-patterns' current values. */
  PrimExpr Eval() const {
    return tir::SelectNode::make(condition_.Eval(),
                                 true_value_.Eval(),
                                 false_value_.Eval());
  }

 private:
  /*! \brief Pattern for the condition. */
  typename TCond::Nested condition_;
  /*! \brief Pattern for the value taken when the condition holds. */
  typename TA::Nested true_value_;
  /*! \brief Pattern for the value taken otherwise. */
  typename TB::Nested false_value_;
};

/*!
 * \brief Construct a select pattern.
 *
 * \param condition The condition expression.
 * \param true_value The value when condition is true.
 * \param false_value The value when condition is false.
 *
 * \return The result pattern.
 *
 * \tparam TCond The pattern type of the condition.
 * \tparam TA The pattern type of the true operand.
 * \tparam TB The pattern type of the false operand.
 */
template<typename TCond, typename TA, typename TB>
inline PSelectExpr<TCond, TA, TB>
select(const Pattern<TCond>& condition,
       const Pattern<TA>& true_value,
       const Pattern<TB>& false_value) {
  return PSelectExpr<TCond, TA, TB>(
      condition.derived(), true_value.derived(), false_value.derived());
}

/*!
 * \brief Pattern for a cast expression (tir::CastNode).
 *
 * Matches when the cast's dtype and its input value match the two
 * sub-patterns, dtype first.
 *
 * \tparam DType The pattern type of the target dtype.
 * \tparam TA The pattern type of the value being cast.
 */
template<typename DType, typename TA>
class PCastExpr :
      public Pattern<PCastExpr<DType, TA> > {
 public:
  PCastExpr(const DType& dtype, const TA& value)
      : dtype_(dtype), value_(value) {}

  /*! \brief Reset the dtype and value patterns. */
  void InitMatch_() const {
    dtype_.InitMatch_();
    value_.InitMatch_();
  }

  /*! \return Whether node is a CastNode whose dtype and value both match. */
  bool Match_(const ObjectRef& node) const {
    const tir::CastNode* cast_node = node.as<tir::CastNode>();
    if (cast_node == nullptr) return false;
    return dtype_.Match_(cast_node->dtype) && value_.Match_(cast_node->value);
  }

  /*! \return A CastNode built from the sub-patterns' current values. */
  PrimExpr Eval() const {
    return tir::CastNode::make(dtype_.Eval(), value_.Eval());
  }

 private:
  /*! \brief Pattern for the target dtype. */
  typename DType::Nested dtype_;
  /*! \brief Pattern for the value being cast. */
  typename TA::Nested value_;
};

/*!
 * \brief Construct a cast pattern.
 *
 * \param dtype The target data type pattern, e.g. a PVar or PConst over
 *              the dtype type.
 * \param value The pattern for the value being cast.
 *
 * \return The result pattern.
 *
 * \tparam DType The pattern type of the target dtype.
 * \tparam TA The pattern type of the value.
 */
template<typename DType, typename TA>
inline PCastExpr<DType, TA>
cast(const Pattern<DType>& dtype, const Pattern<TA>& value) {
  return PCastExpr<DType, TA>(dtype.derived(), value.derived());
}

/*!
 * \brief Pattern for a ramp expression (tir::RampNode).
 *
 * Matches when base, stride and lanes each match their sub-pattern,
 * checked in that order.
 *
 * \tparam TBase The pattern type of the base.
 * \tparam TStride The pattern type of the stride.
 * \tparam TLanes The pattern type of the lanes.
 */
template<typename TBase, typename TStride, typename TLanes>
class PRampExpr :
      public Pattern<PRampExpr<TBase, TStride, TLanes> > {
 public:
  PRampExpr(const TBase& base,
            const TStride& stride,
            const TLanes& lanes)
      : base_(base), stride_(stride), lanes_(lanes) {}

  /*! \brief Reset the base, stride and lanes patterns. */
  void InitMatch_() const {
    base_.InitMatch_();
    stride_.InitMatch_();
    lanes_.InitMatch_();
  }

  /*! \return Whether node is a RampNode whose three fields all match. */
  bool Match_(const ObjectRef& node) const {
    const tir::RampNode* ramp_node = node.as<tir::RampNode>();
    if (ramp_node == nullptr) return false;
    return base_.Match_(ramp_node->base) &&
           stride_.Match_(ramp_node->stride) &&
           lanes_.Match_(ramp_node->lanes);
  }

  /*! \return A RampNode built from the sub-patterns' current values. */
  PrimExpr Eval() const {
    return tir::RampNode::make(base_.Eval(), stride_.Eval(), lanes_.Eval());
  }

 private:
  /*! \brief Pattern for the base. */
  typename TBase::Nested base_;
  /*! \brief Pattern for the stride. */
  typename TStride::Nested stride_;
  /*! \brief Pattern for the lane count. */
  typename TLanes::Nested lanes_;
};

/*!
 * \brief Build a ramp pattern from base, stride and lanes patterns.
 *
 * \param base Pattern for the ramp's base.
 * \param stride Pattern for the ramp's stride.
 * \param lanes Pattern for the ramp's lane count.
 *
 * \return The ramp pattern.
 *
 * \tparam TBase The pattern type of the base.
 * \tparam TStride The pattern type of the stride.
 * \tparam TLanes The pattern type of the lanes.
 */
template<typename TBase, typename TStride, typename TLanes>
inline PRampExpr<TBase, TStride, TLanes>
ramp(const Pattern<TBase>& base,
     const Pattern<TStride>& stride,
     const Pattern<TLanes>& lanes) {
  using Result = PRampExpr<TBase, TStride, TLanes>;
  return Result(base.derived(), stride.derived(), lanes.derived());
}

/*!
 * \brief Pattern for a broadcast expression (tir::BroadcastNode).
 *
 * Matches when the broadcast value and lane count each match their
 * sub-pattern, value first.
 *
 * \tparam TA The pattern type of the value.
 * \tparam TLanes The pattern type of the lanes.
 */
template<typename TA, typename TLanes>
class PBroadcastExpr :
      public Pattern<PBroadcastExpr<TA, TLanes> > {
 public:
  PBroadcastExpr(const TA& value,
                 const TLanes& lanes)
      : value_(value), lanes_(lanes) {}

  /*! \brief Reset the value and lanes patterns. */
  void InitMatch_() const {
    value_.InitMatch_();
    lanes_.InitMatch_();
  }

  /*! \return Whether node is a BroadcastNode whose value and lanes match. */
  bool Match_(const ObjectRef& node) const {
    const tir::BroadcastNode* bcast = node.as<tir::BroadcastNode>();
    if (bcast == nullptr) return false;
    return value_.Match_(bcast->value) && lanes_.Match_(bcast->lanes);
  }

  /*! \return A BroadcastNode built from the sub-patterns' current values. */
  PrimExpr Eval() const {
    return tir::BroadcastNode::make(value_.Eval(), lanes_.Eval());
  }

 private:
  /*! \brief Pattern for the broadcast value. */
  typename TA::Nested value_;
  /*! \brief Pattern for the lane count. */
  typename TLanes::Nested lanes_;
};

/*!
 * \brief Build a broadcast pattern from value and lanes patterns.
 *
 * \param value Pattern for the value being broadcast.
 * \param lanes Pattern for the lane count.
 *
 * \return The broadcast pattern.
 *
 * \tparam TA The pattern type of the value.
 * \tparam TLanes The pattern type of the lanes.
 */
template<typename TA, typename TLanes>
inline PBroadcastExpr<TA, TLanes>
broadcast(const Pattern<TA>& value, const Pattern<TLanes>& lanes) {
  using Result = PBroadcastExpr<TA, TLanes>;
  return Result(value.derived(), lanes.derived());
}

// internal namespace
namespace detail {
// implementation details for  CallExpr
template<bool stop, std::size_t I, typename F>
struct tuple_for_each_dispatcher {
  template<typename TTuple>
  static void run(F& f, const TTuple& tuple) { // NOLINT(*)
    f(I, std::get<I>(tuple));
    tuple_for_each_dispatcher<
      (I + 1) == std::tuple_size<TTuple>::value, (I + 1), F>
        ::run(f, tuple);
  }
};

template<std::size_t I, typename F>
struct tuple_for_each_dispatcher<true, I, F> {
  template<typename TTuple>
  static void run(F& f, const TTuple& tuple) {} // NOLINT(*)
};

template<typename F, typename TTuple>
inline void tuple_for_each(F& f, const TTuple& tuple) {  // NOLINT(*)
  tuple_for_each_dispatcher<std::tuple_size<TTuple>::value == 0, 0, F>
      ::run(f, tuple);
}

struct PCallExprInitMatchFunctor {
  template<typename T>
  void operator()(size_t i, const T& pattern) const {
    pattern.InitMatch_();
  }
};

struct PCallExprMatchFunctor {
665
  const tir::CallNode* call_;
666 667
  bool matched_{true};

668
  explicit PCallExprMatchFunctor(const tir::CallNode* call)
669 670 671 672 673 674 675 676 677
      : call_(call) {}

  template<typename T>
  void operator()(size_t i, const T& pattern) {
    matched_ = matched_ && pattern.Match_(call_->args[i]);
  }
};

struct PCallExprEvalArgsFunctor {
678
  Array<PrimExpr> args_;
679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706

  template<typename T>
  void operator()(size_t i, const T& pattern) {
    args_.push_back(pattern.Eval());
  }
};
}  // namespace detail

/*!
 * \brief Pattern for a call expression (tir::CallNode).
 *
 * A call matches when it has exactly sizeof...(TArgs) arguments, its
 * name equals Op::kName, and each argument matches the corresponding
 * sub-pattern in order.
 *
 * \tparam Op Operator descriptor supplying the call name (kName) and a
 *            static Eval that rebuilds the call from evaluated arguments.
 * \tparam TArgs The pattern types of the arguments.
 */
template<typename Op, typename ...TArgs>
class PCallExpr :
      public Pattern<PCallExpr<Op, TArgs...> > {
 public:
  explicit PCallExpr(const TArgs&... args)
      : args_(args...) {}

  /*! \brief Reset every argument pattern. */
  void InitMatch_() const {
    detail::PCallExprInitMatchFunctor reset_each;
    detail::tuple_for_each(reset_each, args_);
  }

  /*! \return Whether node is a matching call with matching arguments. */
  bool Match_(const ObjectRef& node) const {
    const tir::CallNode* call = node.as<tir::CallNode>();
    if (call == nullptr) return false;
    // Cheap arity and name checks before visiting any argument pattern.
    if (call->args.size() != sizeof...(TArgs)) return false;
    if (call->name != Op::kName) return false;
    detail::PCallExprMatchFunctor match_each(call);
    detail::tuple_for_each(match_each, args_);
    return match_each.matched_;
  }

  /*! \return The call rebuilt by Op::Eval from the evaluated arguments. */
  PrimExpr Eval() const {
    detail::PCallExprEvalArgsFunctor collect;
    detail::tuple_for_each(collect, args_);
    return Op::Eval(collect.args_);
  }

 private:
  /*! \brief The argument patterns, each stored as its Nested type. */
  std::tuple<typename TArgs::Nested...> args_;
};

// arithemetic intrinsics
730 731
#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr)          \
  struct OpName {                                                       \
732
    static PrimExpr Eval(Array<PrimExpr> args) {                                \
733 734
      return tir::CallNode::make(args[0].dtype(), kName, args,           \
                                tir::CallNode::PureIntrinsic);           \
735 736 737 738 739 740 741
    }                                                                   \
    static constexpr const char* kName = IntrinStr;                     \
  };                                                                    \
  template<typename TA, typename TB>                                    \
  inline PCallExpr<OpName, TA, TB>                                      \
  FuncName(const Pattern<TA>& a, const Pattern<TB>& b) {                \
    return PCallExpr<OpName, TA, TB>(a.derived(), b.derived());         \
742 743 744 745 746 747 748 749 750
  }

TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, "shift_left");
TVM_PATTERN_BINARY_INTRIN(operator>>, PRightShiftOp, "shift_right");
TVM_PATTERN_BINARY_INTRIN(operator&, PBitwiseAndOp, "bitwise_and");
TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, "bitwise_or");
TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor");

// unary intrinsics
751 752
#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr)           \
  struct OpName {                                                       \
753
    static PrimExpr Eval(Array<PrimExpr> args) {                                \
754 755
      return tir::CallNode::make(args[0].dtype(), kName, args,           \
                                tir::CallNode::PureIntrinsic);           \
756 757 758 759 760 761 762
    }                                                                   \
    static constexpr const char* kName = IntrinStr;                     \
  };                                                                    \
  template<typename TA>                                                 \
  inline PCallExpr<OpName, TA>                                          \
  FuncName(const Pattern<TA>& a) {                                      \
    return PCallExpr<OpName, TA>(a.derived());                          \
763 764 765 766 767 768
  }

TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not");

// if_then_else
struct PIfThenElseOp {
769
  static PrimExpr Eval(Array<PrimExpr> args) {
770
    return tir::CallNode::make(
771
        args[1].dtype(), kName, args,
772
        tir::CallNode::PureIntrinsic);
773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800
  }
  static constexpr const char* kName = "tvm_if_then_else";
};

/*!
 * \brief Construct an if_then_else pattern.
 *
 * \param cond The condition expression.
 * \param true_value The value when condition is true.
 * \param false_value The value when condition is false.
 *
 * \return The result pattern.
 *
 * \tparam TCond The pattern type of the condition.
 * \tparam TA The pattern type of the true operand.
 * \tparam TB The pattern type of the false operand.
 */
template<typename TCond, typename TA, typename TB>
inline PCallExpr<PIfThenElseOp, TCond, TA, TB>
if_then_else(const Pattern<TCond>& cond,
             const Pattern<TA>& true_value,
             const Pattern<TB>& false_value) {
  return PCallExpr<PIfThenElseOp, TCond, TA, TB>(
      cond.derived(), true_value.derived(), false_value.derived());
}

}  // namespace arith
}  // namespace tvm
801
#endif  // TVM_ARITH_PATTERN_MATCH_H_