/*!
 *  Copyright (c) 2016 by Contributors
 * \file ir_visitor.h
 * \brief Visitor to quickly visit IR trees
 */
#ifndef TVM_IR_VISITOR_H_
#define TVM_IR_VISITOR_H_

#include <tvm/ir_functor.h>
#include "./ir.h"

namespace tvm {
namespace ir {

/*!
 * \brief a base class for visitor to iterative traverse the IR
 *
 *  This IRVisitor is implemented via IRFunctor
 *  This enables extensions of possible new Node.
 *
 * \sa ExprFunctor, StmtFunctor, PostOrderVisit
 *
 * \note If you need to return values during Visit:
 *  - If it is mutation of the IR, use IRMutator
 *  - If you want to return other things, consider use ExprFunctor/StmtFunctor
 *  - Watch out for possible bug pattern if you use IRVisitor to simulate returns.
 *
 * \code
 *
 * // This is an example code to show cases for traps in IRVisitor
 * // The use case is to count number of Variables in the ir tree.
 * class MyCounter : public IRVisitor {
 *  public:
 *   int Count(const NodeRef& n) {
 *     ret_ = 0;
 *     this->Visit(n);
 *     return ret_;
 *   }
 *   void Visit_(const Variable* op) final {
 *     ret_ = 1;
 *   }
 *   void Visit_(const Add* op) final {
 *     ret_ = count(op->a) + count(op->b);
 *   }

 *  private:
 *   int ret_;
 * };
 * MyCounter counter;
 * Var x("x");
 * // this returns 2
 * CHECK_EQ(counter.Count(x + x), 2);
 * // Think what is the result of the following count
 * counter.count(Max::make(x, x));
 * // The result is actually 1
 * // This is because Visit is not overriden for Max
 * // so it simply calls Visit for the left and right children
 * // and because Count is not called, ret_ is not cleared.
 * // There can also be cases where ret_ is forgetten to be set.
 *
 * // These traps may not happen if we program carefully
 * // But it is recommended to use ExprFunctor, which allows direct
 * // return the value, this helps us to avoid such problems.
 *
 * \endcode
 */
class IRVisitor {
 public:
  /*!
   * \brief recursively visit an IR node
   */
  virtual void Visit(const NodeRef& node) {
    static const FVisit& f = vtable();
    if (node.defined()) f(node, this);
  }
  /*! \brief destructor */
  virtual ~IRVisitor() {}
  /*! \brief functor type of visitor */
  using FVisit = IRFunctor<void(const NodeRef&, IRVisitor*)>;
  /*! \return internal vtable*/
  static FVisit& vtable();
  // overloadable visit function.
  virtual void Visit_(const Variable* op);
  virtual void Visit_(const LetStmt* op);
  virtual void Visit_(const AttrStmt* op);
  virtual void Visit_(const IfThenElse* op);
  virtual void Visit_(const For* op);
  virtual void Visit_(const Allocate* op);
  virtual void Visit_(const Load* op);
  virtual void Visit_(const Store* op);
  virtual void Visit_(const Let* op);
  virtual void Visit_(const Free* op);
  virtual void Visit_(const Call* op);
  virtual void Visit_(const Add* op);
  virtual void Visit_(const Sub* op);
  virtual void Visit_(const Mul* op);
  virtual void Visit_(const Div* op);
  virtual void Visit_(const Mod* op);
  virtual void Visit_(const Min* op);
  virtual void Visit_(const Max* op);
  virtual void Visit_(const EQ* op);
  virtual void Visit_(const NE* op);
  virtual void Visit_(const LT* op);
  virtual void Visit_(const LE* op);
  virtual void Visit_(const GT* op);
  virtual void Visit_(const GE* op);
  virtual void Visit_(const And* op);
  virtual void Visit_(const Or* op);
  virtual void Visit_(const Reduce* op);
  virtual void Visit_(const Cast* op);
  virtual void Visit_(const Not* op);
  virtual void Visit_(const Select* op);
  virtual void Visit_(const Ramp* op);
  virtual void Visit_(const Broadcast* op);
  virtual void Visit_(const AssertStmt* op);
  virtual void Visit_(const ProducerConsumer* op);
  virtual void Visit_(const Provide* op);
  virtual void Visit_(const Realize* op);
  virtual void Visit_(const Prefetch* op);
  virtual void Visit_(const Block* op);
  virtual void Visit_(const Evaluate* op);
  virtual void Visit_(const IntImm* op);
  virtual void Visit_(const UIntImm* op);
  virtual void Visit_(const FloatImm* op);
  virtual void Visit_(const StringImm* op);
};

/*!
 * \brief recursively visit the ir in post DFS order node, apply fvisit
 * Each node is guaranteed to be visited only once.
 * \param node The ir to be visited.
 * \param fvisit The visitor function to be applied.
 */
void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit);

}  // namespace ir
}  // namespace tvm

#endif  // TVM_IR_VISITOR_H_