ir_visitor.cc 5.23 KB
Newer Older
tqchen committed
1 2 3 4 5 6 7 8 9 10 11 12 13
/*!
 *  Copyright (c) 2016 by Contributors
 * \file ir_visitor.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <unordered_set>

namespace tvm {
namespace ir {
// visitor to implement apply
class IRApplyVisit : public IRVisitor {
 public:
14
  explicit IRApplyVisit(std::function<void(const NodeRef&)> f) : f_(f) {}
tqchen committed
15

16
  void Visit(const NodeRef& node) final {
tqchen committed
17 18
    if (visited_.count(node.get()) != 0) return;
    visited_.insert(node.get());
tqchen committed
19
    IRVisitor::Visit(node);
tqchen committed
20 21 22 23
    f_(node);
  }

 private:
24
  std::function<void(const NodeRef&)> f_;
tqchen committed
25 26
  std::unordered_set<const Node*> visited_;
};
tqchen committed
27

tqchen committed
28

29
void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit) {
tqchen committed
30 31 32
  IRApplyVisit(fvisit).Visit(node);
}

tqchen committed
33 34 35 36
// Returns a reference to the function-local static dispatch table of IRVisitor.
IRVisitor::FVisit& IRVisitor::vtable() {  // NOLINT(*)
  static FVisit table;
  return table;
}

37
// Dispatch handler that does nothing; both arguments are ignored.
void NoOp(const NodeRef& /*n*/, IRVisitor* /*v*/) {}

tqchen committed
40
inline void VisitArray(const Array<Expr>& arr, IRVisitor* v) {
tqchen committed
41
  for (size_t i = 0; i < arr.size(); i++) {
tqchen committed
42
    v->Visit(arr[i]);
tqchen committed
43 44 45
  }
}

tqchen committed
46 47 48
inline void VisitRDom(const Array<IterVar>& rdom, IRVisitor* v) {
  for (size_t i = 0; i < rdom.size(); i++) {
    Range r = rdom[i]->dom;
tqchen committed
49 50
    v->Visit(r->min);
    v->Visit(r->extent);
tqchen committed
51 52 53
  }
}

54 55 56 57
// Registers a handler for node type OP that forwards to the visitor's
// Visit_ overload taking a const OP*.
#define DISPATCH_TO_VISIT(OP)                                  \
  set_dispatch<OP>([](const OP* node, IRVisitor* visitor) {    \
      visitor->Visit_(node);                                   \
    })
tqchen committed
58 59

// Node types whose handling is routed to the IRVisitor::Visit_ overloads
// defined in this file.
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.DISPATCH_TO_VISIT(Variable)
.DISPATCH_TO_VISIT(LetStmt)
.DISPATCH_TO_VISIT(AttrStmt)
.DISPATCH_TO_VISIT(IfThenElse)
.DISPATCH_TO_VISIT(For)
.DISPATCH_TO_VISIT(Allocate)
.DISPATCH_TO_VISIT(Load)
.DISPATCH_TO_VISIT(Store)
.DISPATCH_TO_VISIT(Let)
.DISPATCH_TO_VISIT(Call)
.DISPATCH_TO_VISIT(Free);

// Nothing is visited for a Variable; the body is intentionally empty.
void IRVisitor::Visit_(const Variable* /*op*/) {
}

void IRVisitor::Visit_(const LetStmt *op) {
  this->Visit(op->value);
  this->Visit(op->body);
}

void IRVisitor::Visit_(const AttrStmt* op) {
  this->Visit(op->value);
  this->Visit(op->body);
}

void IRVisitor::Visit_(const For *op) {
  IRVisitor* v = this;
  v->Visit(op->min);
  v->Visit(op->extent);
  v->Visit(op->body);
}

void IRVisitor::Visit_(const Allocate *op) {
  IRVisitor* v = this;
  for (size_t i = 0; i < op->extents.size(); i++) {
    v->Visit(op->extents[i]);
  }
  v->Visit(op->body);
  v->Visit(op->condition);
  if (op->new_expr.defined()) {
    v->Visit(op->new_expr);
  }
}

void IRVisitor::Visit_(const Load *op) {
  this->Visit(op->index);
}

void IRVisitor::Visit_(const Store *op) {
  this->Visit(op->value);
  this->Visit(op->index);
}

112 113 114 115 116 117 118 119
void IRVisitor::Visit_(const IfThenElse *op) {
  this->Visit(op->condition);
  this->Visit(op->then_case);
  if (op->else_case.defined()) {
    this->Visit(op->else_case);
  }
}

120 121 122 123 124 125 126 127 128 129
void IRVisitor::Visit_(const Let *op) {
  this->Visit(op->value);
  this->Visit(op->body);
}

// Nothing is visited for a Free node; the body is intentionally empty.
void IRVisitor::Visit_(const Free* /*op*/) {
}

void IRVisitor::Visit_(const Call *op) {
  VisitArray(op->args, this);
}
tqchen committed
130 131

// Reduce visits the domains of its axis and then its source; the immediate
// node types are registered with the do-nothing handler.
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Reduce>([](const Reduce* reduce, IRVisitor* visitor) {
    VisitRDom(reduce->axis, visitor);
    visitor->Visit(reduce->source);
  })
.set_dispatch<IntImm>(NoOp)
.set_dispatch<UIntImm>(NoOp)
.set_dispatch<FloatImm>(NoOp)
.set_dispatch<StringImm>(NoOp);
tqchen committed
140 141 142

// Cast visits the value being cast.
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Cast>([](const Cast* cast, IRVisitor* visitor) {
    visitor->Visit(cast->value);
  });

// binary operator
template<typename T>
inline void Binary(const T* op, IRVisitor* v) {
tqchen committed
149 150
  v->Visit(op->a);
  v->Visit(op->b);
tqchen committed
151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
}

// Every node type below is handled by Binary<T>.
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
// arithmetic
.set_dispatch<Add>(Binary<Add>)
.set_dispatch<Sub>(Binary<Sub>)
.set_dispatch<Mul>(Binary<Mul>)
.set_dispatch<Div>(Binary<Div>)
.set_dispatch<Mod>(Binary<Mod>)
.set_dispatch<Min>(Binary<Min>)
.set_dispatch<Max>(Binary<Max>)
// comparison
.set_dispatch<EQ>(Binary<EQ>)
.set_dispatch<NE>(Binary<NE>)
.set_dispatch<LT>(Binary<LT>)
.set_dispatch<LE>(Binary<LE>)
.set_dispatch<GT>(Binary<GT>)
.set_dispatch<GE>(Binary<GE>)
// logical
.set_dispatch<And>(Binary<And>)
.set_dispatch<Or>(Binary<Or>);

// Handlers for Not, Select, Ramp and Broadcast; each visits the node's
// fields in the order written below.
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Not>([](const Not* node, IRVisitor* visitor) {
    visitor->Visit(node->a);
  })
.set_dispatch<Select>([](const Select* node, IRVisitor* visitor) {
    visitor->Visit(node->condition);
    visitor->Visit(node->true_value);
    visitor->Visit(node->false_value);
  })
.set_dispatch<Ramp>([](const Ramp* node, IRVisitor* visitor) {
    visitor->Visit(node->base);
    visitor->Visit(node->stride);
  })
.set_dispatch<Broadcast>([](const Broadcast* node, IRVisitor* visitor) {
    visitor->Visit(node->value);
  });

// Handlers for AssertStmt, ProducerConsumer, Provide, Realize, Block and
// Evaluate; each visits the node's fields in the order written below.
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<AssertStmt>([](const AssertStmt* node, IRVisitor* visitor) {
    visitor->Visit(node->condition);
    visitor->Visit(node->message);
  })
.set_dispatch<ProducerConsumer>([](const ProducerConsumer* node, IRVisitor* visitor) {
    visitor->Visit(node->body);
  })
.set_dispatch<Provide>([](const Provide* node, IRVisitor* visitor) {
    VisitArray(node->args, visitor);
    visitor->Visit(node->value);
  })
.set_dispatch<Realize>([](const Realize* node, IRVisitor* visitor) {
    // Visit the min and extent of every bound (read-only; nothing is mutated).
    for (size_t idx = 0; idx < node->bounds.size(); ++idx) {
      visitor->Visit(node->bounds[idx]->min);
      visitor->Visit(node->bounds[idx]->extent);
    }

    visitor->Visit(node->body);
    visitor->Visit(node->condition);
  })
.set_dispatch<Block>([](const Block* node, IRVisitor* visitor) {
    visitor->Visit(node->first);
    visitor->Visit(node->rest);
  })
.set_dispatch<Evaluate>([](const Evaluate* node, IRVisitor* visitor) {
    visitor->Visit(node->value);
  });

}  // namespace ir
}  // namespace tvm