/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/arithmetic/pattern_match.h
 *
 * \brief Internal tool for expression-template based pattern matching.
 *
 * It helps to simplify pattern matching and rewrites.
 * All the patterns are generated via expression template during compile time,
 * so the result code should be as efficient as manually written pattern match code.
 *
 * The code below shows how to use the pattern matcher.
 *
 * \code
 *
 *  // max(x + z, y + z) => max(x, y) + z
 *  arith::PVar<Expr> x, y, z;
 *
 *  // The following code tries to match the declared pattern.
 *  // Match will fill the result of match into PVar if successful.
 *  // Note that z occurs twice in the pattern,
 *  // an equality check is performed to ensure each occurrence of z
 *  // is equivalent to each other.
 *  if (max(x + z, y + z).Match(expr)) {
 *    // Eval evaluates a pattern with the current matched value.
 *    // The filled value is valid until the next call to Match.
 *    return (max(x, y) + z).Eval();
 *  }
 *
 *  tvm::Var tx, ty;
 *  arith::PVar<Integer> c;
 *  arith::PVar<Var> v;
 *  // We can match integer and Var, both of which are
 *  // special case container of Expr
 *  CHECK((v * c).Match(tx * 3));
 *  CHECK_EQ(c.Eval()->value, 3);
 *  // cannot match c to ty
 *  CHECK(!(v * c).Match(tx * ty));
 *
 * \endcode
 *
 * \note The pattern matcher is not threadsafe,
 *       do not use the same PVar in multiple threads.
 *
 *       Please be aware that the filled value in a PVar
 *       can be overridden in the next call to Match.
 */
#ifndef TVM_ARITHMETIC_PATTERN_MATCH_H_
#define TVM_ARITHMETIC_PATTERN_MATCH_H_

#include <tvm/ir_pass.h>
#include <tuple>
#include "const_fold.h"

namespace tvm {
namespace arith {
/*!
 * \brief CRTP base shared by every pattern type.
 *
 * A concrete pattern derives from Pattern<Itself> and supplies
 * - InitMatch_(): reset any per-match state before a new match starts.
 * - Match_(value): test value against the pattern.
 *
 * The patterns defined in this file additionally provide Eval(), which
 * constructs a new value from what the pattern variables captured.
 *
 * \tparam Derived The concrete pattern type.
 */
template<typename Derived>
class Pattern {
 public:
  /*!
   * \brief How this pattern is stored inside a composite pattern.
   *
   *  The default stores a copy (nest by value). A derived class may
   *  shadow this alias, e.g. with const Derived&, to be nested by
   *  reference instead.
   *
   *  The trick of the Nested typedef originates from Eigen.
   */
  using Nested = Derived;
  /*! \return This object viewed as its concrete Derived type. */
  const Derived& derived() const {
    return static_cast<const Derived&>(*this);
  }
  /*!
   * \brief Reset the pattern state, then test value against the pattern.
   *
   * InitMatch_ of the derived class always runs before Match_, so the
   * outcome of a previous Match call does not leak into this one.
   *
   * \param value The value to be matched.
   * \return whether value matches the pattern.
   *
   * \tparam NodeType The type of the value being matched.
   */
  template<typename NodeType>
  bool Match(const NodeType& value) const {
    const Derived& self = derived();
    self.InitMatch_();
    return self.Match_(value);
  }
};

/*!
 * \brief Default equality checker used when matching.
 *
 * The primary template simply forwards to operator== of T.
 *
 * \tparam T The type of the two values being compared.
 */
template<typename T>
class PEqualChecker {
 public:
  /*!
   * \param lhs The first value.
   * \param rhs The second value.
   * \return The result of lhs == rhs.
   */
  bool operator()(const T& lhs, const T& rhs) const {
    const bool is_equal = (lhs == rhs);
    return is_equal;
  }
};

/*!
 * \brief Equality checker for Expr.
 *
 * Two expressions are equal when they refer to the same node, or
 * otherwise when ir::Equal reports them as equal.
 */
template<>
class PEqualChecker<Expr> {
 public:
  bool operator()(const Expr& lhs, const Expr& rhs) const {
    // The identity test runs first; ir::Equal is only consulted when it fails.
    return lhs.same_as(rhs) || ir::Equal(lhs, rhs);
  }
};

/*!
 * \brief Equality checker for Integer: compares the value fields.
 */
template<>
class PEqualChecker<Integer> {
 public:
  bool operator()(const Integer& lhs, const Integer& rhs) const {
    const auto lhs_value = lhs->value;
    const auto rhs_value = rhs->value;
    return lhs_value == rhs_value;
  }
};

/*!
 * \brief Equality checker for Var: decided solely by same_as.
 */
template<>
class PEqualChecker<Var> {
 public:
  bool operator()(const Var& lhs, const Var& rhs) const {
    const bool identical = lhs.same_as(rhs);
    return identical;
  }
};

/*!
 * \brief Pattern variable container.
 *
 * PVar acts as a "hole" in a pattern. The first value it meets during a
 * match is captured; every later occurrence within the same match has to
 * compare equal (via PEqualChecker<T>) to the captured value.
 *
 * \tparam T the type of the hole.
 *
 * \note PVar is not thread safe.
 *       Do not use the same PVar in multiple threads.
 */
template<typename T>
class PVar : public Pattern<PVar<T> > {
 public:
  // Composite patterns hold a reference to the PVar instead of a copy.
  using Nested = const PVar<T>&;

  /*! \brief Forget the captured value so that a new match can start. */
  void InitMatch_() const {
    filled_ = false;
  }

  /*!
   * \brief Capture value, or compare it with the value already captured.
   * \param value The candidate value.
   * \return true when the hole was empty, otherwise the equality result.
   */
  bool Match_(const T& value) const {
    if (filled_) {
      return PEqualChecker<T>()(value_, value);
    }
    value_ = value;
    filled_ = true;
    return true;
  }

  /*!
   * \brief Match a value whose static type is a base class of T.
   *
   * The value is accepted only if as<T::ContainerType>() yields a
   * non-null pointer; it is then converted with GetRef<T> and handed to
   * the overload above.
   *
   * \param value The candidate value.
   * \return whether the value matched.
   */
  template<typename NodeRefType,
           typename = typename std::enable_if<
             std::is_base_of<NodeRefType, T>::value>::type>
  bool Match_(const NodeRefType& value) const {
    const auto* ptr = value.template as<typename T::ContainerType>();
    if (ptr == nullptr) return false;
    return Match_(GetRef<T>(ptr));
  }

  /*! \return The captured value; CHECK-fails when nothing was captured. */
  T Eval() const {
    CHECK(filled_);
    return value_;
  }

 protected:
  /*! \brief The matched value */
  mutable T value_;
  /*! \brief whether the variable has been filled */
  mutable bool filled_{false};
};

/*!
 * \brief Constant pattern: matches values equal to a fixed value.
 *
 * \tparam T the type of the constant.
 */
template<typename T>
class PConst : public Pattern<PConst<T> > {
 public:
  /*! \param value The constant to compare against. */
  PConst(T value)  // NOLINT(*)
      : value_(value) {}

  /*! \brief Nothing to reset: the pattern carries no per-match state. */
  void InitMatch_() const {}

  /*!
   * \param value The candidate value.
   * \return whether PEqualChecker<T> considers value equal to the constant.
   */
  bool Match_(const T& value) const {
    PEqualChecker<T> is_equal;
    return is_equal(value_, value);
  }

  /*! \return The stored constant. */
  T Eval() const {
    return value_;
  }

 private:
  /*! \brief The constant held by this pattern. */
  const T value_;
};

/*!
 * \brief Pattern binary expression.
 * \tparam NodeType The AST node type.
 * \tparam TA The pattern type of the first operand.
 * \tparam TB The pattern type of the second operand.
 */
template<typename NodeType, typename TA, typename TB>
class PBinaryExpr :
      public Pattern<PBinaryExpr<NodeType, TA, TB> > {
 public:
  /*!
   * \param a The pattern of the first operand.
   * \param b The pattern of the second operand.
   */
  PBinaryExpr(const TA& a, const TB& b) : a_(a), b_(b) {}

  /*! \brief Reset both operand patterns. */
  void InitMatch_() const {
    a_.InitMatch_();
    b_.InitMatch_();
  }

  /*!
   * \param node The node to be matched.
   * \return true when node is a NodeType whose a and b fields match
   *         the operand patterns, tried in that order.
   */
  bool Match_(const NodeRef& node) const {
    const NodeType* ptr = node.as<NodeType>();
    if (ptr == nullptr) return false;
    // b_ is only tried once a_ has matched.
    return a_.Match_(ptr->a) && b_.Match_(ptr->b);
  }

  /*!
   * \return The result of TryConstFold on the evaluated operands when it
   *         is defined, otherwise a fresh NodeType built from them.
   */
  Expr Eval() const {
    Expr lhs = a_.Eval();
    Expr rhs = b_.Eval();
    Expr folded = TryConstFold<NodeType>(lhs, rhs);
    return folded.defined() ? folded : NodeType::make(lhs, rhs);
  }

 private:
  typename TA::Nested a_;
  typename TB::Nested b_;
};

/*!
 * \brief Integer constant pattern whose evaluated type follows another pattern.
 *
 * Matching accepts an ir::IntImm node holding the stored value.
 * Evaluation calls make_const with the type of the evaluated reference
 * pattern and the stored value.
 *
 * \tparam TA The pattern type of the reference pattern.
 */
template<typename TA>
class PConstWithTypeLike :
      public Pattern<PConstWithTypeLike<TA> > {
 public:
  /*!
   * \param ref The pattern that provides the type on Eval.
   * \param value The integer constant.
   */
  PConstWithTypeLike(const TA& ref, int64_t value)
      : ref_(ref), value_(value) {}

  /*! \brief Intentionally empty; the reference pattern is not reset here. */
  void InitMatch_() const {}

  /*!
   * \param node The node to be matched.
   * \return whether node is an ir::IntImm whose value equals the constant.
   */
  bool Match_(const NodeRef& node) const {
    const ir::IntImm* imm = node.as<ir::IntImm>();
    return imm != nullptr && imm->value == value_;
  }

  /*! \return The constant, typed like the evaluated reference pattern. */
  Expr Eval() const {
    const auto like_type = ref_.Eval().type();
    return make_const(like_type, value_);
  }

 private:
  typename TA::Nested ref_;
  int64_t value_;
};


/*!
 * \brief Define the pattern-building overloads of a binary function or operator.
 *
 * Three overloads of FuncName are generated, each returning a PBinaryExpr
 * over NodeName:
 * - (pattern, pattern)
 * - (pattern, int64_t): the integer is wrapped in a PConstWithTypeLike that
 *   refers to the pattern operand, then forwarded to the first overload.
 * - (int64_t, pattern): as above, with the integer on the left-hand side.
 *
 * \param FuncName The name of the generated function or operator.
 * \param NodeName The AST node type handed to PBinaryExpr.
 * \param CheckStep A statement pasted at the top of each overload body; may be empty.
 *
 * \note The int64_t overloads execute CheckStep and then call the
 *       (pattern, pattern) overload, which executes CheckStep again.
 */
#define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep)     \
  template<typename TA, typename TB>                                \
  inline PBinaryExpr<NodeName, TA, TB>                              \
  FuncName(const Pattern<TA>& a, const Pattern<TB>& b) {            \
    CheckStep;                                                      \
    return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived()); \
  }                                                                 \
  template<typename TA>                                             \
  inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA> >         \
  FuncName(const Pattern<TA>& a, int64_t b) {                       \
    CheckStep;                                                      \
    return FuncName(a, PConstWithTypeLike<TA>(a.derived(), b));     \
  }                                                                 \
  template<typename TA>                                             \
  inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA>          \
  FuncName(int64_t b, const Pattern<TA>& a) {                       \
    CheckStep;                                                      \
    return FuncName(PConstWithTypeLike<TA>(a.derived(), b), a);     \
  }

/*!
 * \brief Shorthand for TVM_PATTERN_BINARY_OP_EX with an empty CheckStep.
 *
 * \param FuncName The name of the generated function or operator.
 * \param NodeName The AST node type handed to PBinaryExpr.
 */
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \
  TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, )


// raise ambiguity error for operator overload of / and %
// NOTE(review): DivAmbiguityError is not defined in this file; it is
// presumably declared in one of the included headers -- confirm.
// Use the explicitly named div/truncdiv/truncmod/floordiv/floormod below.
TVM_PATTERN_BINARY_OP_EX(operator/, ir::Div, DivAmbiguityError(a));
TVM_PATTERN_BINARY_OP_EX(operator%, ir::Mod, DivAmbiguityError(a));

// arithmetic expressions
TVM_PATTERN_BINARY_OP(operator+, ir::Add);
TVM_PATTERN_BINARY_OP(operator-, ir::Sub);
TVM_PATTERN_BINARY_OP(operator*, ir::Mul);
TVM_PATTERN_BINARY_OP(min, ir::Min);
TVM_PATTERN_BINARY_OP(max, ir::Max);
// div and truncdiv both build patterns over ir::Div.
TVM_PATTERN_BINARY_OP(div, ir::Div);
TVM_PATTERN_BINARY_OP(truncdiv, ir::Div);
TVM_PATTERN_BINARY_OP(truncmod, ir::Mod);
TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv);
TVM_PATTERN_BINARY_OP(floormod, ir::FloorMod);

// logical expressions
// These overloads only apply to Pattern operands; they build patterns over
// the comparison / boolean nodes rather than computing a bool.
TVM_PATTERN_BINARY_OP(operator>, ir::GT);
TVM_PATTERN_BINARY_OP(operator>=, ir::GE);
TVM_PATTERN_BINARY_OP(operator<, ir::LT);
TVM_PATTERN_BINARY_OP(operator<=, ir::LE);
TVM_PATTERN_BINARY_OP(operator==, ir::EQ);
TVM_PATTERN_BINARY_OP(operator!=, ir::NE);
TVM_PATTERN_BINARY_OP(operator&&, ir::And);
TVM_PATTERN_BINARY_OP(operator||, ir::Or);

/*!
 * \brief Pattern not expression.
 * \tparam TA The pattern type of the operand.
 */
template<typename TA>
class PNotExpr : public Pattern<PNotExpr<TA> > {
 public:
  /*! \param value The pattern of the negated operand. */
  explicit PNotExpr(const TA& value)
      : value_(value) {}

  /*! \brief Reset the operand pattern. */
  void InitMatch_() const {
    value_.InitMatch_();
  }

  /*!
   * \param node The node to be matched.
   * \return true when node is an ir::Not whose operand matches.
   */
  bool Match_(const NodeRef& node) const {
    const ir::Not* ptr = node.as<ir::Not>();
    return ptr != nullptr && value_.Match_(ptr->a);
  }

  /*! \return An ir::Not wrapping the evaluated operand. */
  Expr Eval() const {
    const auto operand = value_.Eval();
    return ir::Not::make(operand);
  }

 private:
  typename TA::Nested value_;
};

/*!
 * \brief Construct a not pattern.
 *
 * \param value The pattern of the operand.
 * \return The result pattern.
 *
 * \tparam TA The pattern type of the operand.
 */
template<typename TA>
inline PNotExpr<TA> operator!(const Pattern<TA>& value) {
  using Result = PNotExpr<TA>;
  return Result(value.derived());
}

// select
/*!
 * \brief Pattern select expression.
 * \tparam TCond The pattern type of the condition.
 * \tparam TA The pattern type of the true operand.
 * \tparam TB The pattern type of the false operand.
 */
template<typename TCond, typename TA, typename TB>
class PSelectExpr :
      public Pattern<PSelectExpr<TCond, TA, TB> > {
 public:
  /*!
   * \param condition The pattern of the condition.
   * \param true_value The pattern of the value chosen when condition is true.
   * \param false_value The pattern of the value chosen when condition is false.
   */
  PSelectExpr(const TCond& condition,
              const TA& true_value,
              const TB& false_value)
      : condition_(condition),
        true_value_(true_value),
        false_value_(false_value) {}

  /*! \brief Reset the three sub-patterns. */
  void InitMatch_() const {
    condition_.InitMatch_();
    true_value_.InitMatch_();
    false_value_.InitMatch_();
  }

  /*!
   * \param node The node to be matched.
   * \return true when node is an ir::Select whose condition, true value
   *         and false value match the sub-patterns, tried in that order.
   */
  bool Match_(const NodeRef& node) const {
    const ir::Select* ptr = node.as<ir::Select>();
    if (ptr == nullptr) return false;
    // Matching stops at the first sub-pattern that fails.
    return condition_.Match_(ptr->condition) &&
        true_value_.Match_(ptr->true_value) &&
        false_value_.Match_(ptr->false_value);
  }

  /*! \return An ir::Select built from the evaluated sub-patterns. */
  Expr Eval() const {
    return ir::Select::make(condition_.Eval(),
                            true_value_.Eval(),
                            false_value_.Eval());
  }

 private:
  typename TCond::Nested condition_;
  typename TA::Nested true_value_;
  typename TB::Nested false_value_;
};

/*!
 * \brief Construct a select pattern.
 *
 * \param condition The condition expression.
 * \param true_value The value when condition is true.
 * \param false_value The value when condition is false.
 *
 * \return The result pattern.
 *
 * \tparam TCond The pattern type of the condition.
 * \tparam TA The pattern type of the true operand.
 * \tparam TB The pattern type of the false operand.
 */
template<typename TCond, typename TA, typename TB>
inline PSelectExpr<TCond, TA, TB>
select(const Pattern<TCond>& condition,
       const Pattern<TA>& true_value,
       const Pattern<TB>& false_value) {
  using Result = PSelectExpr<TCond, TA, TB>;
  return Result(condition.derived(),
                true_value.derived(),
                false_value.derived());
}

/*!
 * \brief Pattern cast expression.
 * \tparam DType The Pattern type of dtype.
 * \tparam TA The pattern type of the first operand.
 */
template<typename DType, typename TA>
class PCastExpr :
      public Pattern<PCastExpr<DType, TA> > {
 public:
  /*!
   * \param dtype The pattern of the target data type.
   * \param value The pattern of the value being cast.
   */
  PCastExpr(const DType& dtype, const TA& value)
      : dtype_(dtype), value_(value) {
  }

  /*! \brief Reset the type pattern and the value pattern. */
  void InitMatch_() const {
    dtype_.InitMatch_();
    value_.InitMatch_();
  }

  /*!
   * \param node The node to be matched.
   * \return true when node is an ir::Cast whose type and value match
   *         the sub-patterns, tried in that order.
   */
  bool Match_(const NodeRef& node) const {
    const ir::Cast* ptr = node.as<ir::Cast>();
    if (ptr == nullptr) return false;
    return dtype_.Match_(ptr->type) && value_.Match_(ptr->value);
  }

  /*! \return An ir::Cast built from the evaluated sub-patterns. */
  Expr Eval() const {
    return ir::Cast::make(dtype_.Eval(),
                          value_.Eval());
  }

 private:
  typename DType::Nested dtype_;
  typename TA::Nested value_;
};

/*!
 * \brief Construct a cast pattern.
 *
 * \param dtype The target data type, can be PVar<Type> or PConst<Type>.
 * \param value The pattern of the value being cast.
 *
 * \return The result pattern.
 *
 * \tparam DType The pattern type of type.
 * \tparam TA The pattern type of value.
 */
template<typename DType, typename TA>
inline PCastExpr<DType, TA>
cast(const Pattern<DType>& dtype, const Pattern<TA>& value) {
  using Result = PCastExpr<DType, TA>;
  return Result(dtype.derived(), value.derived());
}

/*!
 * \brief Pattern ramp expression.
 * \tparam TBase The pattern type of the base.
 * \tparam TStride The pattern type of the stride.
 * \tparam TLanes The pattern type of the lanes.
 */
template<typename TBase, typename TStride, typename TLanes>
class PRampExpr :
      public Pattern<PRampExpr<TBase, TStride, TLanes> > {
 public:
  /*!
   * \param base The base pattern.
   * \param stride The stride pattern.
   * \param lanes The lanes pattern.
   */
  PRampExpr(const TBase& base,
            const TStride& stride,
            const TLanes& lanes)
      : base_(base), stride_(stride), lanes_(lanes) {
  }

  /*! \brief Reset the three sub-patterns. */
  void InitMatch_() const {
    base_.InitMatch_();
    stride_.InitMatch_();
    lanes_.InitMatch_();
  }

  /*!
   * \param node The node to be matched.
   * \return true when node is an ir::Ramp whose base, stride and lanes
   *         match the sub-patterns, tried in that order.
   */
  bool Match_(const NodeRef& node) const {
    const ir::Ramp* ptr = node.as<ir::Ramp>();
    if (ptr == nullptr) return false;
    // Matching stops at the first sub-pattern that fails.
    return base_.Match_(ptr->base) &&
        stride_.Match_(ptr->stride) &&
        lanes_.Match_(ptr->lanes);
  }

  /*! \return An ir::Ramp built from the evaluated sub-patterns. */
  Expr Eval() const {
    return ir::Ramp::make(base_.Eval(),
                          stride_.Eval(),
                          lanes_.Eval());
  }

 private:
  typename TBase::Nested base_;
  typename TStride::Nested stride_;
  typename TLanes::Nested lanes_;
};

/*!
 * \brief Construct a ramp pattern.
 *
 * \param base The base pattern.
 * \param stride The stride pattern.
 * \param lanes The lanes pattern.
 *
 * \return The result pattern.
 *
 * \tparam TBase The pattern type of the base.
 * \tparam TStride The pattern type of the stride.
 * \tparam TLanes The pattern type of the lanes.
 */
template<typename TBase, typename TStride, typename TLanes>
inline PRampExpr<TBase, TStride, TLanes>
ramp(const Pattern<TBase>& base,
     const Pattern<TStride>& stride,
     const Pattern<TLanes>& lanes) {
  using Result = PRampExpr<TBase, TStride, TLanes>;
  return Result(base.derived(),
                stride.derived(),
                lanes.derived());
}

/*!
 * \brief Pattern broadcast expression.
 * \tparam TA The pattern type of the value.
 * \tparam TLanes The pattern type of the lanes.
 */
template<typename TA, typename TLanes>
class PBroadcastExpr :
      public Pattern<PBroadcastExpr<TA, TLanes> > {
 public:
  /*!
   * \param value The value pattern.
   * \param lanes The lanes pattern.
   */
  PBroadcastExpr(const TA& value,
                 const TLanes& lanes)
      : value_(value), lanes_(lanes) {
  }

  /*! \brief Reset the value pattern and the lanes pattern. */
  void InitMatch_() const {
    value_.InitMatch_();
    lanes_.InitMatch_();
  }

  /*!
   * \param node The node to be matched.
   * \return true when node is an ir::Broadcast whose value and lanes
   *         match the sub-patterns, tried in that order.
   */
  bool Match_(const NodeRef& node) const {
    const ir::Broadcast* ptr = node.as<ir::Broadcast>();
    if (ptr == nullptr) return false;
    return value_.Match_(ptr->value) && lanes_.Match_(ptr->lanes);
  }

  /*! \return An ir::Broadcast built from the evaluated sub-patterns. */
  Expr Eval() const {
    return ir::Broadcast::make(value_.Eval(),
                               lanes_.Eval());
  }

 private:
  typename TA::Nested value_;
  typename TLanes::Nested lanes_;
};

/*!
 * \brief Construct a broadcast pattern.
 *
 * \param value The value pattern.
 * \param lanes The lanes pattern.
 *
 * \return The result pattern.
 *
 * \tparam TA The pattern type of the value.
 * \tparam TLanes The pattern type of the lanes.
 */
template<typename TA, typename TLanes>
inline PBroadcastExpr<TA, TLanes>
broadcast(const Pattern<TA>& value, const Pattern<TLanes>& lanes) {
  using Result = PBroadcastExpr<TA, TLanes>;
  return Result(value.derived(), lanes.derived());
}

// internal namespace
namespace detail {
// implementation details for PCallExpr
/*!
 * \brief Recursive step of tuple_for_each.
 *
 * Calls f(I, element I of the tuple) and then continues with index I + 1.
 * The recursion ends in the stop == true specialization below.
 *
 * \tparam stop Whether the end of the tuple has been reached.
 * \tparam I The index of the element handled by this step.
 * \tparam F The functor type.
 */
template<bool stop, std::size_t I, typename F>
struct tuple_for_each_dispatcher {
  template<typename TTuple>
  static void run(F& f, const TTuple& tuple) { // NOLINT(*)
    constexpr std::size_t kNext = I + 1;
    constexpr bool kDone = (kNext == std::tuple_size<TTuple>::value);
    using NextStep = tuple_for_each_dispatcher<kDone, kNext, F>;
    f(I, std::get<I>(tuple));
    NextStep::run(f, tuple);
  }
};

/*! \brief Terminal step of tuple_for_each: nothing is left to visit. */
template<std::size_t I, typename F>
struct tuple_for_each_dispatcher<true, I, F> {
  template<typename TTuple>
  static void run(F& f, const TTuple& tuple) {} // NOLINT(*)
};

/*!
 * \brief Invoke f(index, element) on every element of tuple, in index order.
 *
 * An empty tuple results in no call at all.
 *
 * \param f The functor, taken by reference so it can accumulate state.
 * \param tuple The tuple to be visited.
 */
template<typename F, typename TTuple>
inline void tuple_for_each(F& f, const TTuple& tuple) {  // NOLINT(*)
  constexpr bool kEmpty = (std::tuple_size<TTuple>::value == 0);
  using FirstStep = tuple_for_each_dispatcher<kEmpty, 0, F>;
  FirstStep::run(f, tuple);
}

/*!
 * \brief tuple_for_each functor that calls InitMatch_ on each argument pattern.
 */
struct PCallExprInitMatchFunctor {
  /*!
   * \param pattern The argument pattern to reset.
   * \note The index argument is accepted but not used.
   */
  template<typename T>
  void operator()(size_t /* index */, const T& pattern) const {
    pattern.InitMatch_();
  }
};

/*!
 * \brief tuple_for_each functor that matches argument patterns against call args.
 *
 * matched_ starts out true and turns false at the first argument that
 * does not match; patterns visited after that are not tried.
 */
struct PCallExprMatchFunctor {
  /*! \brief The call whose arguments are being matched. */
  const ir::Call* call_;
  /*! \brief Whether all arguments visited so far have matched. */
  bool matched_{true};

  /*! \param call The call whose arguments are being matched. */
  explicit PCallExprMatchFunctor(const ir::Call* call)
      : call_(call) {}

  /*!
   * \param i The index of the argument.
   * \param pattern The pattern for argument i.
   */
  template<typename T>
  void operator()(size_t i, const T& pattern) {
    if (!matched_) return;
    matched_ = pattern.Match_(call_->args[i]);
  }
};

struct PCallExprEvalArgsFunctor {
  Array<Expr> args_;

  template<typename T>
  void operator()(size_t i, const T& pattern) {
    args_.push_back(pattern.Eval());
  }
};
}  // namespace detail

/*!
 * \brief Pattern CallExpr expression.
 * \tparam Op The operator functor class.
 * \tparam TArgs The arguments.
 * \note Op functor contains the name of the function (kName) and
 *          the implementation of Eval.
 */
template<typename Op, typename ...TArgs>
class PCallExpr :
      public Pattern<PCallExpr<Op, TArgs...> > {
 public:
  /*! \param args The argument patterns. */
  explicit PCallExpr(const TArgs&... args)
      : args_(args...) {
  }

  /*! \brief Reset every argument pattern. */
  void InitMatch_() const {
    detail::PCallExprInitMatchFunctor init_each;
    detail::tuple_for_each(init_each, args_);
  }

  /*!
   * \param node The node to be matched.
   * \return true when node is an ir::Call with as many arguments as there
   *         are argument patterns, named Op::kName, and every argument
   *         matches its pattern.
   */
  bool Match_(const NodeRef& node) const {
    const ir::Call* call = node.as<ir::Call>();
    if (call == nullptr) return false;
    // The arity is checked before the name, and both before any argument.
    if (call->args.size() != sizeof...(TArgs)) return false;
    if (call->name != Op::kName) return false;
    detail::PCallExprMatchFunctor match_each(call);
    detail::tuple_for_each(match_each, args_);
    return match_each.matched_;
  }

  /*! \return The result of Op::Eval on the evaluated argument patterns. */
  Expr Eval() const {
    detail::PCallExprEvalArgsFunctor collect;
    detail::tuple_for_each(collect, args_);
    return Op::Eval(collect.args_);
  }

 private:
  std::tuple<typename TArgs::Nested...> args_;
};

// arithmetic intrinsics
/*!
 * \brief Define a two-argument intrinsic call pattern.
 *
 * The macro generates
 * - struct OpName, holding the intrinsic name in kName and an Eval that
 *   builds an ir::Call with call type ir::Call::PureIntrinsic, typed like
 *   the first argument.
 * - FuncName(pattern, pattern), returning PCallExpr<OpName, TA, TB>.
 *
 * \param FuncName The name of the generated function or operator.
 * \param OpName The name of the generated operator functor struct.
 * \param IntrinStr The name of the intrinsic, as a string literal.
 */
#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr)        \
  struct OpName {                                                     \
    static Expr Eval(Array<Expr> args) {                              \
      return ir::Call::make(args[0].type(), kName, args,              \
                            ir::Call::PureIntrinsic);                 \
    }                                                                 \
    static constexpr const char* kName = IntrinStr;                   \
  };                                                                  \
  template<typename TA, typename TB>                                  \
  inline PCallExpr<OpName, TA, TB>                                    \
  FuncName(const Pattern<TA>& a, const Pattern<TB>& b) {              \
    return PCallExpr<OpName, TA, TB>(a.derived(), b.derived());             \
  }

// Shift and bitwise operators on patterns.
TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, "shift_left");
TVM_PATTERN_BINARY_INTRIN(operator>>, PRightShiftOp, "shift_right");
TVM_PATTERN_BINARY_INTRIN(operator&, PBitwiseAndOp, "bitwise_and");
TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, "bitwise_or");
TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor");

// unary intrinsics
/*!
 * \brief Define a one-argument intrinsic call pattern.
 *
 * The macro generates
 * - struct OpName, holding the intrinsic name in kName and an Eval that
 *   builds an ir::Call with call type ir::Call::PureIntrinsic, typed like
 *   the argument.
 * - FuncName(pattern), returning PCallExpr<OpName, TA>.
 *
 * \param FuncName The name of the generated function or operator.
 * \param OpName The name of the generated operator functor struct.
 * \param IntrinStr The name of the intrinsic, as a string literal.
 */
#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr)         \
  struct OpName {                                                     \
    static Expr Eval(Array<Expr> args) {                              \
      return ir::Call::make(args[0].type(), kName, args,              \
                            ir::Call::PureIntrinsic);                 \
    }                                                                 \
    static constexpr const char* kName = IntrinStr;                   \
  };                                                                  \
  template<typename TA>                                               \
  inline PCallExpr<OpName, TA>                                        \
  FuncName(const Pattern<TA>& a) {                                    \
    return PCallExpr<OpName, TA>(a.derived());                           \
  }

// Bitwise not on a pattern.
TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not");

// if_then_else
/*!
 * \brief Operator functor of the tvm_if_then_else intrinsic.
 */
struct PIfThenElseOp {
  /*! \brief The name of the intrinsic. */
  static constexpr const char* kName = "tvm_if_then_else";
  /*!
   * \param args The call arguments; the result takes the type of args[1].
   * \return A PureIntrinsic ir::Call named kName over args.
   */
  static Expr Eval(Array<Expr> args) {
    const auto result_type = args[1].type();
    return ir::Call::make(result_type, kName, args, ir::Call::PureIntrinsic);
  }
};

/*!
 * \brief Construct a if_then_else pattern.
 *
 * \param cond The condition expression.
 * \param true_value The value when condition is true.
 * \param false_value The value when condition is false.
 *
 * \return The result pattern.
 *
 * \tparam TCond The pattern type of the condition.
 * \tparam TA The pattern type of the true operand.
 * \tparam TB The pattern type of the false operand.
 */
template<typename TCond, typename TA, typename TB>
inline PCallExpr<PIfThenElseOp, TCond, TA, TB>
if_then_else(const Pattern<TCond>& cond,
             const Pattern<TA>& true_value,
             const Pattern<TB>& false_value) {
  using Result = PCallExpr<PIfThenElseOp, TCond, TA, TB>;
  return Result(cond.derived(),
                true_value.derived(),
                false_value.derived());
}

}  // namespace arith
}  // namespace tvm
#endif  // TVM_ARITHMETIC_PATTERN_MATCH_H_