/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/ir_visitor.h
 * \brief Visitor to quickly visit IR trees
 */
#ifndef TVM_IR_VISITOR_H_
#define TVM_IR_VISITOR_H_

#include <functional>

#include "ir.h"
#include "tvm/node/ir_functor.h"

namespace tvm {
namespace ir {

/*!
 * \brief a base class for visitor to iterative traverse the IR
 *
 *  This IRVisitor is implemented via IRFunctor
37
 *  This enables extensions of possible new Node.
tqchen committed
38
 *
39 40 41
 * \sa ExprFunctor, StmtFunctor, PostOrderVisit
 *
 * \note If you need to return values during Visit:
ziheng committed
42
 *  - If it is mutation of the IR, use IRMutator
43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
 *  - If you want to return other things, consider use ExprFunctor/StmtFunctor
 *  - Watch out for possible bug pattern if you use IRVisitor to simulate returns.
 *
 * \code
 *
 * // This is an example code to show cases for traps in IRVisitor
 * // The use case is to count number of Variables in the ir tree.
 * class MyCounter : public IRVisitor {
 *  public:
 *   int Count(const NodeRef& n) {
 *     ret_ = 0;
 *     this->Visit(n);
 *     return ret_;
 *   }
 *   void Visit_(const Variable* op) final {
 *     ret_ = 1;
 *   }
 *   void Visit_(const Add* op) final {
 *     ret_ = count(op->a) + count(op->b);
 *   }

 *  private:
 *   int ret_;
 * };
 * MyCounter counter;
 * Var x("x");
 * // this returns 2
 * CHECK_EQ(counter.Count(x + x), 2);
 * // Think what is the result of the following count
 * counter.count(Max::make(x, x));
 * // The result is actually 1
 * // This is because Visit is not overriden for Max
 * // so it simply calls Visit for the left and right children
 * // and because Count is not called, ret_ is not cleared.
 * // There can also be cases where ret_ is forgetten to be set.
 *
 * // These traps may not happen if we program carefully
 * // But it is recommended to use ExprFunctor, which allows direct
 * // return the value, this helps us to avoid such problems.
 *
83
 * \endcode
tqchen committed
84
 */
85
class TVM_DLL IRVisitor {
tqchen committed
86 87 88 89
 public:
  /*!
   * \brief recursively visit an IR node
   */
90
  virtual void Visit(const NodeRef& node) {
tqchen committed
91 92 93 94 95 96
    static const FVisit& f = vtable();
    if (node.defined()) f(node, this);
  }
  /*! \brief destructor */
  virtual ~IRVisitor() {}
  /*! \brief functor type of visitor */
97
  using FVisit = IRFunctor<void(const NodeRef&, IRVisitor*)>;
tqchen committed
98 99
  /*! \return internal vtable*/
  static FVisit& vtable();
100 101 102
  // overloadable visit function.
  virtual void Visit_(const Variable* op);
  virtual void Visit_(const LetStmt* op);
103 104
  virtual void Visit_(const AttrStmt* op);
  virtual void Visit_(const IfThenElse* op);
105 106 107 108 109 110 111
  virtual void Visit_(const For* op);
  virtual void Visit_(const Allocate* op);
  virtual void Visit_(const Load* op);
  virtual void Visit_(const Store* op);
  virtual void Visit_(const Let* op);
  virtual void Visit_(const Free* op);
  virtual void Visit_(const Call* op);
112 113 114 115 116
  virtual void Visit_(const Add* op);
  virtual void Visit_(const Sub* op);
  virtual void Visit_(const Mul* op);
  virtual void Visit_(const Div* op);
  virtual void Visit_(const Mod* op);
117 118
  virtual void Visit_(const FloorDiv* op);
  virtual void Visit_(const FloorMod* op);
119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
  virtual void Visit_(const Min* op);
  virtual void Visit_(const Max* op);
  virtual void Visit_(const EQ* op);
  virtual void Visit_(const NE* op);
  virtual void Visit_(const LT* op);
  virtual void Visit_(const LE* op);
  virtual void Visit_(const GT* op);
  virtual void Visit_(const GE* op);
  virtual void Visit_(const And* op);
  virtual void Visit_(const Or* op);
  virtual void Visit_(const Reduce* op);
  virtual void Visit_(const Cast* op);
  virtual void Visit_(const Not* op);
  virtual void Visit_(const Select* op);
  virtual void Visit_(const Ramp* op);
134
  virtual void Visit_(const Shuffle* op);
135 136 137 138 139
  virtual void Visit_(const Broadcast* op);
  virtual void Visit_(const AssertStmt* op);
  virtual void Visit_(const ProducerConsumer* op);
  virtual void Visit_(const Provide* op);
  virtual void Visit_(const Realize* op);
140
  virtual void Visit_(const Prefetch* op);
141 142 143 144 145 146
  virtual void Visit_(const Block* op);
  virtual void Visit_(const Evaluate* op);
  virtual void Visit_(const IntImm* op);
  virtual void Visit_(const UIntImm* op);
  virtual void Visit_(const FloatImm* op);
  virtual void Visit_(const StringImm* op);
tqchen committed
147 148 149 150
};

/*!
 * \brief recursively visit the ir in post DFS order node, apply fvisit
ziheng committed
151
 * Each node is guaranteed to be visited only once.
tqchen committed
152 153 154
 * \param node The ir to be visited.
 * \param fvisit The visitor function to be applied.
 */
155
TVM_DLL void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit);
tqchen committed
156 157 158 159 160

}  // namespace ir
}  // namespace tvm

#endif  // TVM_IR_VISITOR_H_