/*!
 *  Copyright (c) 2017 by Contributors
 * \file ir_functor_ext.h
 * \brief More powerful Visitor that allows defining function signatures.
 */
#ifndef TVM_IR_FUNCTOR_EXT_H_
#define TVM_IR_FUNCTOR_EXT_H_

#include <tvm/ir_functor.h>
#include "./ir.h"

namespace tvm {
namespace ir {

/*!
 * \brief A dynamic functor that dispatches on the node type of its first Expr argument.
 *
 *  You can use this as a more powerful Visitor, since it allows you to
 *  define the function signature of the visit functions.
 *
 *  This helps you avoid book-keeping the return value of a Visitor via state,
 *  which can easily cause bugs when the state is incorrectly maintained.
 *
 * \code
 *  // A functor that sets every variable to b and calculates the result.
 *  class MyExprFunctor
 *    : public ir::ExprFunctor<int(const Expr&, int)> {
 *   public:
 *    int VisitExpr_(const Variable* op, int b) final {
 *      return b;
 *    }
 *    int VisitExpr_(const IntImm* op, int b) final {
 *      return op->value;
 *    }
 *    int VisitExpr_(const Add* op, int b) final {
 *      return VisitExpr(op->a, b) + VisitExpr(op->b, b);
 *    }
 *  };
 *  MyExprFunctor f;
 *  Var x("x");
 *  CHECK_EQ(f(x + 1, 2), 3);
 * \endcode
 *
 * \note Why do we need this more powerful Functor:
 *
 *  We often need to implement transformer tasks.
 *  Say we want to take an Expr and transform it into some analysis result.
 *  This can easily be done incorrectly using a plain Visitor. See IRVisitor's
 *  documentation for possible error cases.
 *
 * \tparam FType The function signature.
 *  This type is only defined for FType with function signature R(const Expr&, Args...)
 */
template<typename FType>
class ExprFunctor;
/*!
 * \brief Same as ExprFunctor, except that it is applied to statements.
 * \tparam FType The function signature.
 *  This type is only defined for FType with function signature R(const Stmt&, Args...)
 */
template<typename FType>
class StmtFunctor;

// Default body for the overridable ExprFunctor::VisitExpr_ overloads:
// forwards the node pointer and the extra arguments to VisitExprDefault_.
// It must expand inside a member function whose parameters are named
// `op` and `args` and whose argument pack type is named `Args`.
#define EXPR_FUNCTOR_DEFAULT {                                      \
    return VisitExprDefault_(op, std::forward<Args>(args)...);      \
  }
// Default body for the overridable StmtFunctor::VisitStmt_ overloads:
// forwards the node pointer and the extra arguments to VisitStmtDefault_.
// It must expand inside a member function whose parameters are named
// `op` and `args` and whose argument pack type is named `Args`.
#define STMT_FUNCTOR_DEFAULT {                                      \
    return VisitStmtDefault_(op, std::forward<Args>(args)...);      \
}

// Registers, in the local variable `vtable`, the handler for node type OP:
// a lambda that casts the node held by `n` to `const OP*` and calls the
// matching self->VisitExpr_ overload with the forwarded extra arguments.
// The expansion site must have `vtable`, `TSelf` and `Args` in scope.
#define IR_EXPR_FUNCTOR_DISPATCH(OP)                                    \
  vtable.template set_dispatch<OP>(                                     \
      [](const NodeRef& n, TSelf* self, Args... args) {                 \
        return self->VisitExpr_(static_cast<const OP*>(n.node_.get()),  \
                                std::forward<Args>(args)...);           \
      });                                                               \

// Registers, in the local variable `vtable`, the handler for node type OP:
// a lambda that casts the node held by `n` to `const OP*` and calls the
// matching self->VisitStmt_ overload with the forwarded extra arguments.
// The expansion site must have `vtable`, `TSelf` and `Args` in scope.
#define IR_STMT_FUNCTOR_DISPATCH(OP)                                    \
  vtable.template set_dispatch<OP>(                                     \
      [](const NodeRef& n, TSelf* self, Args... args) {                 \
        return self->VisitStmt_(static_cast<const OP*>(n.node_.get()),  \
                                std::forward<Args>(args)...);           \
      });                                                               \

/*!
 * \brief Definition of ExprFunctor for signatures of the form
 *  R(const Expr&, Args...).
 *
 *  VisitExpr looks up the node type of the expression in a dispatch table
 *  and calls the matching VisitExpr_ overload. Subclasses override the
 *  VisitExpr_ overloads they care about; every overload that is not
 *  overridden falls back to VisitExprDefault_.
 *
 * \tparam R The result type of the functor.
 * \tparam Args The types of the additional arguments passed to each visit call.
 */
template<typename R, typename ...Args>
class ExprFunctor<R(const Expr& n, Args...)> {
 private:
  /*! \brief This specialization; handlers in the table receive a pointer to it. */
  using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
  /*! \brief Type of the dispatch table keyed by node type. */
  using FType = IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;

 public:
  /*! \brief the result type of this functor */
  using result_type = R;
  /*! \brief virtual destructor */
  virtual ~ExprFunctor() {}
  /*!
   * \brief Same as call.
   * \param n The expression node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  R operator()(const Expr& n, Args... args) {
    return VisitExpr(n, std::forward<Args>(args)...);
  }
  /*!
   * \brief The functor call.
   *
   *  Dispatches through a table that is built by InitVTable() the first
   *  time this function runs and is then reused for every later call on
   *  this specialization.
   *
   * \param n The expression node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  virtual R VisitExpr(const Expr& n, Args... args) {
    static FType vtable = InitVTable();
    return vtable(n, this, std::forward<Args>(args)...);
  }
  // Functions that can be overridden by subclasses, one per expression node
  // type. Each default body forwards to VisitExprDefault_.
  virtual R VisitExpr_(const Variable* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Load* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Let* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Call* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Add* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Sub* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Mul* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Div* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Mod* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Min* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Max* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const EQ* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const NE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const LT* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const LE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const GT* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const GE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const And* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Or* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Reduce* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Cast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Not* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Select* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Ramp* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Broadcast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const Shuffle* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const IntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  /*!
   * \brief Fallback used by every VisitExpr_ overload that is not overridden.
   *
   *  The base implementation logs a fatal error that names the node's
   *  type key. The additional arguments are ignored.
   *
   * \param op The node that has no specific handler.
   * \return A value-initialized R, which satisfies the return type.
   */
  virtual R VisitExprDefault_(const Node* op, Args ...) {
    LOG(FATAL) << "Do not have a default for " << op->type_key();
    return R();
  }

 private:
  /*!
   * \brief Build the dispatch table that maps each expression node type
   *  to its VisitExpr_ overload.
   *
   *  Every node type that has a VisitExpr_ overload above needs an entry
   *  here: VisitExpr reaches the overloads only through this table, so an
   *  overload without an entry can never be called via VisitExpr.
   *
   * \return The populated dispatch table.
   */
  static FType InitVTable() {
    FType vtable;
    // Set dispatch
    IR_EXPR_FUNCTOR_DISPATCH(Variable);
    IR_EXPR_FUNCTOR_DISPATCH(Load);
    IR_EXPR_FUNCTOR_DISPATCH(Let);
    IR_EXPR_FUNCTOR_DISPATCH(Call);
    IR_EXPR_FUNCTOR_DISPATCH(Add);
    IR_EXPR_FUNCTOR_DISPATCH(Sub);
    IR_EXPR_FUNCTOR_DISPATCH(Mul);
    IR_EXPR_FUNCTOR_DISPATCH(Div);
    IR_EXPR_FUNCTOR_DISPATCH(Mod);
    IR_EXPR_FUNCTOR_DISPATCH(Min);
    IR_EXPR_FUNCTOR_DISPATCH(Max);
    IR_EXPR_FUNCTOR_DISPATCH(EQ);
    IR_EXPR_FUNCTOR_DISPATCH(NE);
    IR_EXPR_FUNCTOR_DISPATCH(LT);
    IR_EXPR_FUNCTOR_DISPATCH(LE);
    IR_EXPR_FUNCTOR_DISPATCH(GT);
    IR_EXPR_FUNCTOR_DISPATCH(GE);
    IR_EXPR_FUNCTOR_DISPATCH(And);
    IR_EXPR_FUNCTOR_DISPATCH(Or);
    IR_EXPR_FUNCTOR_DISPATCH(Reduce);
    IR_EXPR_FUNCTOR_DISPATCH(Cast);
    IR_EXPR_FUNCTOR_DISPATCH(Not);
    IR_EXPR_FUNCTOR_DISPATCH(Select);
    IR_EXPR_FUNCTOR_DISPATCH(Ramp);
    IR_EXPR_FUNCTOR_DISPATCH(Broadcast);
    IR_EXPR_FUNCTOR_DISPATCH(Shuffle);
    IR_EXPR_FUNCTOR_DISPATCH(IntImm);
    IR_EXPR_FUNCTOR_DISPATCH(UIntImm);
    IR_EXPR_FUNCTOR_DISPATCH(FloatImm);
    IR_EXPR_FUNCTOR_DISPATCH(StringImm);
    return vtable;
  }
};

/*!
 * \brief Definition of StmtFunctor for signatures of the form
 *  R(const Stmt&, Args...).
 *
 *  VisitStmt looks up the node type of the statement in a dispatch table
 *  and calls the matching VisitStmt_ overload. Subclasses override the
 *  VisitStmt_ overloads they care about; every overload that is not
 *  overridden falls back to VisitStmtDefault_.
 *
 * \tparam R The result type of the functor.
 * \tparam Args The types of the additional arguments passed to each visit call.
 */
template<typename R, typename ...Args>
class StmtFunctor<R(const Stmt& n, Args... args)> {
 private:
  /*! \brief This specialization; handlers in the table receive a pointer to it. */
  using TSelf = StmtFunctor<R(const Stmt& n, Args... args)>;
  /*! \brief Type of the dispatch table keyed by node type. */
  using FType = IRFunctor<R(const NodeRef& n, TSelf* self, Args... args)>;

 public:
  /*! \brief the result type of this functor */
  using result_type = R;
  /*! \brief virtual destructor */
  virtual ~StmtFunctor() {}
  /*!
   * \brief Same as call.
   * \param n The stmt node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  R operator()(const Stmt& n, Args... args) {
    return VisitStmt(n, std::forward<Args>(args)...);
  }
  /*!
   * \brief The functor call.
   *
   *  Dispatches through a table that is built by InitVTable() the first
   *  time this function runs and is then reused for every later call on
   *  this specialization.
   *
   * \param n The stmt node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  virtual R VisitStmt(const Stmt& n, Args... args) {
    static FType vtable = InitVTable();
    return vtable(n, this, std::forward<Args>(args)...);
  }
  // Functions that can be overridden by subclasses, one per statement node
  // type. Each default body forwards to VisitStmtDefault_.
  virtual R VisitStmt_(const LetStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const AttrStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const IfThenElse* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const For* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const Allocate* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const Store* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const Free* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const AssertStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const ProducerConsumer* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const Prefetch* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT;
  /*!
   * \brief Fallback used by every VisitStmt_ overload that is not overridden.
   *
   *  The base implementation logs a fatal error that names the node's
   *  type key. The additional arguments are ignored.
   *
   * \param op The node that has no specific handler.
   * \return A value-initialized R, which satisfies the return type.
   */
  virtual R VisitStmtDefault_(const Node* op, Args ...) {
    LOG(FATAL) << "Do not have a default for " << op->type_key();
    return R();
  }

 private:
  /*!
   * \brief Build the dispatch table that maps each statement node type
   *  to its VisitStmt_ overload.
   *
   *  VisitStmt reaches the overloads only through this table, so every
   *  node type with a VisitStmt_ overload above needs an entry here.
   *
   * \return The populated dispatch table.
   */
  static FType InitVTable() {
    FType vtable;
    IR_STMT_FUNCTOR_DISPATCH(LetStmt);
    IR_STMT_FUNCTOR_DISPATCH(AttrStmt);
    IR_STMT_FUNCTOR_DISPATCH(IfThenElse);
    IR_STMT_FUNCTOR_DISPATCH(For);
    IR_STMT_FUNCTOR_DISPATCH(Allocate);
    IR_STMT_FUNCTOR_DISPATCH(Store);
    IR_STMT_FUNCTOR_DISPATCH(Free);
    IR_STMT_FUNCTOR_DISPATCH(AssertStmt);
    IR_STMT_FUNCTOR_DISPATCH(ProducerConsumer);
    IR_STMT_FUNCTOR_DISPATCH(Provide);
    IR_STMT_FUNCTOR_DISPATCH(Realize);
    IR_STMT_FUNCTOR_DISPATCH(Prefetch);
    IR_STMT_FUNCTOR_DISPATCH(Block);
    IR_STMT_FUNCTOR_DISPATCH(Evaluate);
    return vtable;
  }
};

#undef IR_STMT_FUNCTOR_DISPATCH
#undef IR_EXPR_FUNCTOR_DISPATCH
#undef EXPR_FUNCTOR_DEFAULT
#undef STMT_FUNCTOR_DEFAULT

}  // namespace ir
}  // namespace tvm
#endif  // TVM_IR_FUNCTOR_EXT_H_