ir_visitor.cc 5.91 KB
Newer Older
tqchen committed
1 2 3 4 5 6 7 8 9 10 11 12 13
/*!
 *  Copyright (c) 2016 by Contributors
 * \file ir_visitor.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <functional>
#include <unordered_set>
#include <utility>

namespace tvm {
namespace ir {
// visitor to implement apply
class IRApplyVisit : public IRVisitor {
 public:
14
  explicit IRApplyVisit(std::function<void(const NodeRef&)> f) : f_(f) {}
tqchen committed
15

16
  void Visit(const NodeRef& node) final {
tqchen committed
17 18
    if (visited_.count(node.get()) != 0) return;
    visited_.insert(node.get());
tqchen committed
19
    IRVisitor::Visit(node);
tqchen committed
20 21 22 23
    f_(node);
  }

 private:
24
  std::function<void(const NodeRef&)> f_;
tqchen committed
25 26
  std::unordered_set<const Node*> visited_;
};
tqchen committed
27

tqchen committed
28

29
void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit) {
tqchen committed
30 31 32
  IRApplyVisit(fvisit).Visit(node);
}

tqchen committed
33 34 35 36
// The dispatch table lives in a function-local static, so it is constructed
// on first use, including from the TVM_STATIC_IR_FUNCTOR registration below.
IRVisitor::FVisit& IRVisitor::vtable() {  // NOLINT(*)
  static FVisit table;
  return table;
}

tqchen committed
37
inline void VisitArray(const Array<Expr>& arr, IRVisitor* v) {
tqchen committed
38
  for (size_t i = 0; i < arr.size(); i++) {
tqchen committed
39
    v->Visit(arr[i]);
tqchen committed
40 41 42
  }
}

tqchen committed
43 44 45
inline void VisitRDom(const Array<IterVar>& rdom, IRVisitor* v) {
  for (size_t i = 0; i < rdom.size(); i++) {
    Range r = rdom[i]->dom;
tqchen committed
46 47
    v->Visit(r->min);
    v->Visit(r->extent);
tqchen committed
48 49 50
  }
}

51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
// Variable: no children are visited.
void IRVisitor::Visit_(const Variable* /*op*/) {}

void IRVisitor::Visit_(const LetStmt *op) {
  this->Visit(op->value);
  this->Visit(op->body);
}

void IRVisitor::Visit_(const AttrStmt* op) {
  this->Visit(op->value);
  this->Visit(op->body);
}

// For: the loop range (min, extent), then the body.
// The loop variable is not visited here.
void IRVisitor::Visit_(const For* op) {
  this->Visit(op->min);
  this->Visit(op->extent);
  this->Visit(op->body);
}

/*!
 * Allocate: every extent in order, then the body, then the condition, and
 * finally new_expr when it is defined. Post-order callbacks can observe this
 * order (note body precedes condition), so it is kept exactly.
 */
void IRVisitor::Visit_(const Allocate* op) {
  for (size_t dim = 0; dim < op->extents.size(); ++dim) {
    this->Visit(op->extents[dim]);
  }
  this->Visit(op->body);
  this->Visit(op->condition);
  if (!op->new_expr.defined()) return;
  this->Visit(op->new_expr);
}

void IRVisitor::Visit_(const Load *op) {
  this->Visit(op->index);
84
  this->Visit(op->predicate);
85 86 87 88 89
}

void IRVisitor::Visit_(const Store *op) {
  this->Visit(op->value);
  this->Visit(op->index);
90
  this->Visit(op->predicate);
91 92
}

93 94 95 96 97 98 99 100
// IfThenElse: condition, then-branch, and the else-branch only if present.
void IRVisitor::Visit_(const IfThenElse* op) {
  this->Visit(op->condition);
  this->Visit(op->then_case);
  if (!op->else_case.defined()) return;
  this->Visit(op->else_case);
}

101 102 103 104 105 106 107 108 109 110
void IRVisitor::Visit_(const Let *op) {
  this->Visit(op->value);
  this->Visit(op->body);
}

// Free: no children are visited.
void IRVisitor::Visit_(const Free* /*op*/) {}

// Call: each argument, in order (via the VisitArray helper above).
void IRVisitor::Visit_(const Call* op) {
  VisitArray(op->args, /*v=*/this);
}
tqchen committed
111

112 113 114 115 116
#define DEFINE_BINOP_VISIT_(OP)                     \
  void IRVisitor::Visit_(const OP* op) {            \
    this->Visit(op->a);                             \
    this->Visit(op->b);                             \
  }
tqchen committed
117

118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
DEFINE_BINOP_VISIT_(Add)
DEFINE_BINOP_VISIT_(Sub)
DEFINE_BINOP_VISIT_(Mul)
DEFINE_BINOP_VISIT_(Div)
DEFINE_BINOP_VISIT_(Mod)
DEFINE_BINOP_VISIT_(Min)
DEFINE_BINOP_VISIT_(Max)
DEFINE_BINOP_VISIT_(EQ)
DEFINE_BINOP_VISIT_(NE)
DEFINE_BINOP_VISIT_(LT)
DEFINE_BINOP_VISIT_(LE)
DEFINE_BINOP_VISIT_(GT)
DEFINE_BINOP_VISIT_(GE)
DEFINE_BINOP_VISIT_(And)
DEFINE_BINOP_VISIT_(Or)

void IRVisitor::Visit_(const Reduce* op) {
  VisitRDom(op->axis, this);
136
  VisitArray(op->source, this);
137
}
tqchen committed
138

139 140
void IRVisitor::Visit_(const Cast* op) {
  this->Visit(op->value);
tqchen committed
141 142
}

143 144 145
void IRVisitor::Visit_(const Not* op) {
  this->Visit(op->a);
}
tqchen committed
146

147 148 149 150 151
void IRVisitor::Visit_(const Select* op) {
  this->Visit(op->condition);
  this->Visit(op->true_value);
  this->Visit(op->false_value);
}
tqchen committed
152

153 154 155 156 157 158 159 160 161 162 163 164
void IRVisitor::Visit_(const Ramp *op) {
  this->Visit(op->base);
  this->Visit(op->stride);
}

void IRVisitor::Visit_(const Broadcast *op) {
  this->Visit(op->value);
}

void IRVisitor::Visit_(const AssertStmt *op) {
  this->Visit(op->condition);
  this->Visit(op->message);
165
  this->Visit(op->body);
166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186
}

void IRVisitor::Visit_(const ProducerConsumer *op) {
  this->Visit(op->body);
}

void IRVisitor::Visit_(const Provide *op) {
  VisitArray(op->args, this);
  this->Visit(op->value);
}

// Realize: min and extent of every bound, then the body, then the condition.
void IRVisitor::Visit_(const Realize* op) {
  for (size_t b = 0; b < op->bounds.size(); ++b) {
    const auto& bound = op->bounds[b];
    this->Visit(bound->min);
    this->Visit(bound->extent);
  }
  this->Visit(op->body);
  this->Visit(op->condition);
}

187 188 189 190 191 192 193
// Prefetch: only the min and extent of every bound are visited.
void IRVisitor::Visit_(const Prefetch* op) {
  for (size_t b = 0; b < op->bounds.size(); ++b) {
    const auto& bound = op->bounds[b];
    this->Visit(bound->min);
    this->Visit(bound->extent);
  }
}

194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257
void IRVisitor::Visit_(const Block *op) {
  this->Visit(op->first);
  this->Visit(op->rest);
}

void IRVisitor::Visit_(const Evaluate *op) {
  this->Visit(op->value);
}

// Immediate constants: their handlers are intentionally empty.
#define DEFINE_OP_NO_VISIT_(OP)                     \
  void IRVisitor::Visit_(const OP* /*op*/) {}

DEFINE_OP_NO_VISIT_(IntImm)
DEFINE_OP_NO_VISIT_(UIntImm)
DEFINE_OP_NO_VISIT_(FloatImm)
DEFINE_OP_NO_VISIT_(StringImm)

// Register a vtable entry that forwards a node of type OP to the matching
// Visit_ overload of the visitor it is dispatched on.
#define DISPATCH_TO_VISIT(OP)                                 \
  set_dispatch<OP>([](const OP* node, IRVisitor* visitor) {   \
      visitor->Visit_(node);                                  \
    })

// Route every node type that has a Visit_ overload in this file through the
// IRVisitor dispatch table. Registration order is kept as originally written.
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.DISPATCH_TO_VISIT(Variable)
.DISPATCH_TO_VISIT(LetStmt)
.DISPATCH_TO_VISIT(AttrStmt)
.DISPATCH_TO_VISIT(IfThenElse)
.DISPATCH_TO_VISIT(For)
.DISPATCH_TO_VISIT(Allocate)
.DISPATCH_TO_VISIT(Load)
.DISPATCH_TO_VISIT(Store)
.DISPATCH_TO_VISIT(Let)
.DISPATCH_TO_VISIT(Free)
.DISPATCH_TO_VISIT(Call)
// Binary operators.
.DISPATCH_TO_VISIT(Add)
.DISPATCH_TO_VISIT(Sub)
.DISPATCH_TO_VISIT(Mul)
.DISPATCH_TO_VISIT(Div)
.DISPATCH_TO_VISIT(Mod)
.DISPATCH_TO_VISIT(Min)
.DISPATCH_TO_VISIT(Max)
.DISPATCH_TO_VISIT(EQ)
.DISPATCH_TO_VISIT(NE)
.DISPATCH_TO_VISIT(LT)
.DISPATCH_TO_VISIT(LE)
.DISPATCH_TO_VISIT(GT)
.DISPATCH_TO_VISIT(GE)
.DISPATCH_TO_VISIT(And)
.DISPATCH_TO_VISIT(Or)
// Remaining expressions and statements.
.DISPATCH_TO_VISIT(Reduce)
.DISPATCH_TO_VISIT(Cast)
.DISPATCH_TO_VISIT(Not)
.DISPATCH_TO_VISIT(Select)
.DISPATCH_TO_VISIT(Ramp)
.DISPATCH_TO_VISIT(Broadcast)
.DISPATCH_TO_VISIT(AssertStmt)
.DISPATCH_TO_VISIT(ProducerConsumer)
.DISPATCH_TO_VISIT(Provide)
.DISPATCH_TO_VISIT(Realize)
.DISPATCH_TO_VISIT(Block)
.DISPATCH_TO_VISIT(Evaluate)
// Immediates (handlers are no-ops).
.DISPATCH_TO_VISIT(IntImm)
.DISPATCH_TO_VISIT(UIntImm)
.DISPATCH_TO_VISIT(FloatImm)
.DISPATCH_TO_VISIT(StringImm)
.DISPATCH_TO_VISIT(Prefetch);
tqchen committed
260 261 262

}  // namespace ir
}  // namespace tvm