ir_visitor.h 4.38 KB
Newer Older
tqchen committed
1 2
/*!
 *  Copyright (c) 2016 by Contributors
 * \file tvm/ir_visitor.h
 * \brief Visitor to quickly visit IR trees
 */
#ifndef TVM_IR_VISITOR_H_
#define TVM_IR_VISITOR_H_

9
#include "ir.h"
10
#include "tvm/node/ir_functor.h"
tqchen committed
11 12 13 14 15 16 17 18

namespace tvm {
namespace ir {

/*!
 * \brief a base class for visitor to iterative traverse the IR
 *
 *  This IRVisitor is implemented via IRFunctor
19
 *  This enables extensions of possible new Node.
tqchen committed
20
 *
21 22 23
 * \sa ExprFunctor, StmtFunctor, PostOrderVisit
 *
 * \note If you need to return values during Visit:
ziheng committed
24
 *  - If it is mutation of the IR, use IRMutator
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
 *  - If you want to return other things, consider use ExprFunctor/StmtFunctor
 *  - Watch out for possible bug pattern if you use IRVisitor to simulate returns.
 *
 * \code
 *
 * // This is an example code to show cases for traps in IRVisitor
 * // The use case is to count number of Variables in the ir tree.
 * class MyCounter : public IRVisitor {
 *  public:
 *   int Count(const NodeRef& n) {
 *     ret_ = 0;
 *     this->Visit(n);
 *     return ret_;
 *   }
 *   void Visit_(const Variable* op) final {
 *     ret_ = 1;
 *   }
 *   void Visit_(const Add* op) final {
 *     ret_ = count(op->a) + count(op->b);
 *   }

 *  private:
 *   int ret_;
 * };
 * MyCounter counter;
 * Var x("x");
 * // this returns 2
 * CHECK_EQ(counter.Count(x + x), 2);
 * // Think what is the result of the following count
 * counter.count(Max::make(x, x));
 * // The result is actually 1
 * // This is because Visit is not overriden for Max
 * // so it simply calls Visit for the left and right children
 * // and because Count is not called, ret_ is not cleared.
 * // There can also be cases where ret_ is forgetten to be set.
 *
 * // These traps may not happen if we program carefully
 * // But it is recommended to use ExprFunctor, which allows direct
 * // return the value, this helps us to avoid such problems.
 *
65
 * \endcode
tqchen committed
66
 */
67
class TVM_DLL IRVisitor {
tqchen committed
68 69 70 71
 public:
  /*!
   * \brief recursively visit an IR node
   */
72
  virtual void Visit(const NodeRef& node) {
tqchen committed
73 74 75 76 77 78
    static const FVisit& f = vtable();
    if (node.defined()) f(node, this);
  }
  /*! \brief destructor */
  virtual ~IRVisitor() {}
  /*! \brief functor type of visitor */
79
  using FVisit = IRFunctor<void(const NodeRef&, IRVisitor*)>;
tqchen committed
80 81
  /*! \return internal vtable*/
  static FVisit& vtable();
82 83 84
  // overloadable visit function.
  virtual void Visit_(const Variable* op);
  virtual void Visit_(const LetStmt* op);
85 86
  virtual void Visit_(const AttrStmt* op);
  virtual void Visit_(const IfThenElse* op);
87 88 89 90 91 92 93
  virtual void Visit_(const For* op);
  virtual void Visit_(const Allocate* op);
  virtual void Visit_(const Load* op);
  virtual void Visit_(const Store* op);
  virtual void Visit_(const Let* op);
  virtual void Visit_(const Free* op);
  virtual void Visit_(const Call* op);
94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
  virtual void Visit_(const Add* op);
  virtual void Visit_(const Sub* op);
  virtual void Visit_(const Mul* op);
  virtual void Visit_(const Div* op);
  virtual void Visit_(const Mod* op);
  virtual void Visit_(const Min* op);
  virtual void Visit_(const Max* op);
  virtual void Visit_(const EQ* op);
  virtual void Visit_(const NE* op);
  virtual void Visit_(const LT* op);
  virtual void Visit_(const LE* op);
  virtual void Visit_(const GT* op);
  virtual void Visit_(const GE* op);
  virtual void Visit_(const And* op);
  virtual void Visit_(const Or* op);
  virtual void Visit_(const Reduce* op);
  virtual void Visit_(const Cast* op);
  virtual void Visit_(const Not* op);
  virtual void Visit_(const Select* op);
  virtual void Visit_(const Ramp* op);
  virtual void Visit_(const Broadcast* op);
  virtual void Visit_(const AssertStmt* op);
  virtual void Visit_(const ProducerConsumer* op);
  virtual void Visit_(const Provide* op);
  virtual void Visit_(const Realize* op);
119
  virtual void Visit_(const Prefetch* op);
120 121 122 123 124 125
  virtual void Visit_(const Block* op);
  virtual void Visit_(const Evaluate* op);
  virtual void Visit_(const IntImm* op);
  virtual void Visit_(const UIntImm* op);
  virtual void Visit_(const FloatImm* op);
  virtual void Visit_(const StringImm* op);
tqchen committed
126 127 128 129
};

/*!
 * \brief recursively visit the ir in post DFS order node, apply fvisit
ziheng committed
130
 * Each node is guaranteed to be visited only once.
tqchen committed
131 132 133
 * \param node The ir to be visited.
 * \param fvisit The visitor function to be applied.
 */
134
TVM_DLL void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit);
tqchen committed
135 136 137 138 139

}  // namespace ir
}  // namespace tvm

#endif  // TVM_IR_VISITOR_H_