Commit 5c07413c by Ziheng Jiang Committed by Tianqi Chen

[PASS] Change IRVisitor interfaces to function override (#42)

* [PASS] Change IRVisitor interfaces to function override

* [PASS] Change IRMutator interfaces to overloadable function
parent b8f0ec50
......@@ -16,7 +16,8 @@ namespace ir {
/*!
* \brief a base class for mutator to iterative mutate the IR
*
* This IRMutator is implemented via IRFunctor instead of Visitor Pattern.
* This IRMutator is implemented via Visitor Pattern.
* Also you can implement via IRFunctor.
* This enables easy extensions of possible new Node.
* It also makes changing return types easier.
*
......@@ -54,20 +55,91 @@ class IRMutator {
static FMutateStmt& vtable_stmt(); // NOLINT(*)
// Set of overloadable functions
// The underscore allows Mutate not to be shadowed by inheritance
virtual Stmt Mutate_(const Variable* op, const Stmt& s);
virtual Stmt Mutate_(const LetStmt* op, const Stmt& s);
virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s);
virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
virtual Stmt Mutate_(const For* op, const Stmt& s);
virtual Stmt Mutate_(const Provide* op, const Stmt& s);
virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Load* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Let* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
virtual Stmt Mutate_(const Call* op, const Stmt& s);
virtual Stmt Mutate_(const Add* op, const Stmt& e);
virtual Stmt Mutate_(const Sub* op, const Stmt& e);
virtual Stmt Mutate_(const Mul* op, const Stmt& e);
virtual Stmt Mutate_(const Div* op, const Stmt& e);
virtual Stmt Mutate_(const Mod* op, const Stmt& e);
virtual Stmt Mutate_(const Min* op, const Stmt& e);
virtual Stmt Mutate_(const Max* op, const Stmt& e);
virtual Stmt Mutate_(const EQ* op, const Stmt& e);
virtual Stmt Mutate_(const NE* op, const Stmt& e);
virtual Stmt Mutate_(const LT* op, const Stmt& e);
virtual Stmt Mutate_(const LE* op, const Stmt& e);
virtual Stmt Mutate_(const GT* op, const Stmt& e);
virtual Stmt Mutate_(const GE* op, const Stmt& e);
virtual Stmt Mutate_(const And* op, const Stmt& e);
virtual Stmt Mutate_(const Or* op, const Stmt& e);
virtual Stmt Mutate_(const Reduce* op, const Stmt& s);
virtual Stmt Mutate_(const Cast* op, const Stmt& s);
virtual Stmt Mutate_(const Not* op, const Stmt& s);
virtual Stmt Mutate_(const Select* op, const Stmt& s);
virtual Stmt Mutate_(const Ramp* op, const Stmt& s);
virtual Stmt Mutate_(const Broadcast* op, const Stmt& e);
virtual Stmt Mutate_(const AssertStmt* op, const Stmt& e);
virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& e);
virtual Stmt Mutate_(const Provide* op, const Stmt& e);
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Block* op, const Stmt& s);
virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& s);
virtual Stmt Mutate_(const Evaluate* op, const Stmt& e);
virtual Stmt Mutate_(const IntImm* op, const Stmt& e);
virtual Stmt Mutate_(const UIntImm* op, const Stmt& e);
virtual Stmt Mutate_(const FloatImm* op, const Stmt& e);
virtual Stmt Mutate_(const StringImm* op, const Stmt& e);
virtual Expr Mutate_(const Variable* op, const Expr& e);
virtual Expr Mutate_(const LetStmt* op, const Expr& e);
virtual Expr Mutate_(const AttrStmt* op, const Expr& e);
virtual Expr Mutate_(const IfThenElse* op, const Expr& e);
virtual Expr Mutate_(const For* op, const Expr& e);
virtual Expr Mutate_(const Allocate* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& e);
virtual Expr Mutate_(const Store* op, const Expr& e);
virtual Expr Mutate_(const Let* op, const Expr& e);
virtual Expr Mutate_(const Free* op, const Expr& e);
virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Add* op, const Expr& e);
virtual Expr Mutate_(const Sub* op, const Expr& e);
virtual Expr Mutate_(const Mul* op, const Expr& e);
virtual Expr Mutate_(const Div* op, const Expr& e);
virtual Expr Mutate_(const Mod* op, const Expr& e);
virtual Expr Mutate_(const Min* op, const Expr& e);
virtual Expr Mutate_(const Max* op, const Expr& e);
virtual Expr Mutate_(const EQ* op, const Expr& e);
virtual Expr Mutate_(const NE* op, const Expr& e);
virtual Expr Mutate_(const LT* op, const Expr& e);
virtual Expr Mutate_(const LE* op, const Expr& e);
virtual Expr Mutate_(const GT* op, const Expr& e);
virtual Expr Mutate_(const GE* op, const Expr& e);
virtual Expr Mutate_(const And* op, const Expr& e);
virtual Expr Mutate_(const Or* op, const Expr& e);
virtual Expr Mutate_(const Reduce* op, const Expr& e);
virtual Expr Mutate_(const Cast* op, const Expr& e);
virtual Expr Mutate_(const Not* op, const Expr& e);
virtual Expr Mutate_(const Select* op, const Expr& e);
virtual Expr Mutate_(const Ramp* op, const Expr& e);
virtual Expr Mutate_(const Broadcast* op, const Expr& e);
virtual Expr Mutate_(const AssertStmt* op, const Expr& e);
virtual Expr Mutate_(const ProducerConsumer* op, const Expr& e);
virtual Expr Mutate_(const Provide* op, const Expr& e);
virtual Expr Mutate_(const Realize* op, const Expr& e);
virtual Expr Mutate_(const Block* op, const Expr& e);
virtual Expr Mutate_(const Evaluate* op, const Expr& e);
virtual Expr Mutate_(const IntImm* op, const Expr& e);
virtual Expr Mutate_(const UIntImm* op, const Expr& e);
virtual Expr Mutate_(const FloatImm* op, const Expr& e);
virtual Expr Mutate_(const StringImm* op, const Expr& e);
};
/*!
......
......@@ -36,16 +36,47 @@ class IRVisitor {
static FVisit& vtable();
// overloadable visit function.
virtual void Visit_(const Variable* op);
virtual void Visit_(const AttrStmt* op);
virtual void Visit_(const LetStmt* op);
virtual void Visit_(const AttrStmt* op);
virtual void Visit_(const IfThenElse* op);
virtual void Visit_(const For* op);
virtual void Visit_(const Allocate* op);
virtual void Visit_(const IfThenElse* op);
virtual void Visit_(const Load* op);
virtual void Visit_(const Store* op);
virtual void Visit_(const Let* op);
virtual void Visit_(const Free* op);
virtual void Visit_(const Call* op);
virtual void Visit_(const Add* op);
virtual void Visit_(const Sub* op);
virtual void Visit_(const Mul* op);
virtual void Visit_(const Div* op);
virtual void Visit_(const Mod* op);
virtual void Visit_(const Min* op);
virtual void Visit_(const Max* op);
virtual void Visit_(const EQ* op);
virtual void Visit_(const NE* op);
virtual void Visit_(const LT* op);
virtual void Visit_(const LE* op);
virtual void Visit_(const GT* op);
virtual void Visit_(const GE* op);
virtual void Visit_(const And* op);
virtual void Visit_(const Or* op);
virtual void Visit_(const Reduce* op);
virtual void Visit_(const Cast* op);
virtual void Visit_(const Not* op);
virtual void Visit_(const Select* op);
virtual void Visit_(const Ramp* op);
virtual void Visit_(const Broadcast* op);
virtual void Visit_(const AssertStmt* op);
virtual void Visit_(const ProducerConsumer* op);
virtual void Visit_(const Provide* op);
virtual void Visit_(const Realize* op);
virtual void Visit_(const Block* op);
virtual void Visit_(const Evaluate* op);
virtual void Visit_(const IntImm* op);
virtual void Visit_(const UIntImm* op);
virtual void Visit_(const FloatImm* op);
virtual void Visit_(const StringImm* op);
};
/*!
......
......@@ -34,9 +34,6 @@ IRVisitor::FVisit& IRVisitor::vtable() { // NOLINT(*)
static FVisit inst; return inst;
}
void NoOp(const NodeRef& n, IRVisitor* v) {
}
inline void VisitArray(const Array<Expr>& arr, IRVisitor* v) {
for (size_t i = 0; i < arr.size(); i++) {
v->Visit(arr[i]);
......@@ -51,24 +48,6 @@ inline void VisitRDom(const Array<IterVar>& rdom, IRVisitor* v) {
}
}
#define DISPATCH_TO_VISIT(OP) \
set_dispatch<OP>([](const OP* op, IRVisitor* v) { \
v->Visit_(op); \
})
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.DISPATCH_TO_VISIT(Variable)
.DISPATCH_TO_VISIT(LetStmt)
.DISPATCH_TO_VISIT(AttrStmt)
.DISPATCH_TO_VISIT(IfThenElse)
.DISPATCH_TO_VISIT(For)
.DISPATCH_TO_VISIT(Allocate)
.DISPATCH_TO_VISIT(Load)
.DISPATCH_TO_VISIT(Store)
.DISPATCH_TO_VISIT(Let)
.DISPATCH_TO_VISIT(Call)
.DISPATCH_TO_VISIT(Free);
void IRVisitor::Visit_(const Variable* op) {}
void IRVisitor::Visit_(const LetStmt *op) {
......@@ -128,91 +107,146 @@ void IRVisitor::Visit_(const Call *op) {
VisitArray(op->args, this);
}
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) {
VisitRDom(op->axis, v);
v->Visit(op->source);
})
.set_dispatch<IntImm>(NoOp)
.set_dispatch<UIntImm>(NoOp)
.set_dispatch<FloatImm>(NoOp)
.set_dispatch<StringImm>(NoOp);
#define DEFINE_BINOP_VISIT_(OP) \
void IRVisitor::Visit_(const OP* op) { \
this->Visit(op->a); \
this->Visit(op->b); \
}
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Cast>([](const Cast* op, IRVisitor* v) {
v->Visit(op->value);
});
DEFINE_BINOP_VISIT_(Add)
DEFINE_BINOP_VISIT_(Sub)
DEFINE_BINOP_VISIT_(Mul)
DEFINE_BINOP_VISIT_(Div)
DEFINE_BINOP_VISIT_(Mod)
DEFINE_BINOP_VISIT_(Min)
DEFINE_BINOP_VISIT_(Max)
DEFINE_BINOP_VISIT_(EQ)
DEFINE_BINOP_VISIT_(NE)
DEFINE_BINOP_VISIT_(LT)
DEFINE_BINOP_VISIT_(LE)
DEFINE_BINOP_VISIT_(GT)
DEFINE_BINOP_VISIT_(GE)
DEFINE_BINOP_VISIT_(And)
DEFINE_BINOP_VISIT_(Or)
void IRVisitor::Visit_(const Reduce* op) {
VisitRDom(op->axis, this);
this->Visit(op->source);
}
// binary operator
template<typename T>
inline void Binary(const T* op, IRVisitor* v) {
v->Visit(op->a);
v->Visit(op->b);
void IRVisitor::Visit_(const Cast* op) {
this->Visit(op->value);
}
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Add>(Binary<Add>)
.set_dispatch<Sub>(Binary<Sub>)
.set_dispatch<Mul>(Binary<Mul>)
.set_dispatch<Div>(Binary<Div>)
.set_dispatch<Mod>(Binary<Mod>)
.set_dispatch<Min>(Binary<Min>)
.set_dispatch<Max>(Binary<Max>)
.set_dispatch<EQ>(Binary<EQ>)
.set_dispatch<NE>(Binary<NE>)
.set_dispatch<LT>(Binary<LT>)
.set_dispatch<LE>(Binary<LE>)
.set_dispatch<GT>(Binary<GT>)
.set_dispatch<GE>(Binary<GE>)
.set_dispatch<And>(Binary<And>)
.set_dispatch<Or>(Binary<Or>);
void IRVisitor::Visit_(const Not* op) {
this->Visit(op->a);
}
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Not>([](const Not* op, IRVisitor* v) {
v->Visit(op->a);
})
.set_dispatch<Select>([](const Select *op, IRVisitor* v) {
v->Visit(op->condition);
v->Visit(op->true_value);
v->Visit(op->false_value);
})
.set_dispatch<Ramp>([](const Ramp *op, IRVisitor* v) {
v->Visit(op->base);
v->Visit(op->stride);
})
.set_dispatch<Broadcast>([](const Broadcast *op, IRVisitor* v) {
v->Visit(op->value);
});
void IRVisitor::Visit_(const Select* op) {
this->Visit(op->condition);
this->Visit(op->true_value);
this->Visit(op->false_value);
}
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<AssertStmt>([](const AssertStmt *op, IRVisitor* v) {
v->Visit(op->condition);
v->Visit(op->message);
})
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, IRVisitor* v) {
v->Visit(op->body);
})
.set_dispatch<Provide>([](const Provide *op, IRVisitor* v) {
VisitArray(op->args, v);
v->Visit(op->value);
})
.set_dispatch<Realize>([](const Realize *op, IRVisitor* v) {
void IRVisitor::Visit_(const Ramp *op) {
this->Visit(op->base);
this->Visit(op->stride);
}
void IRVisitor::Visit_(const Broadcast *op) {
this->Visit(op->value);
}
void IRVisitor::Visit_(const AssertStmt *op) {
this->Visit(op->condition);
this->Visit(op->message);
}
void IRVisitor::Visit_(const ProducerConsumer *op) {
this->Visit(op->body);
}
void IRVisitor::Visit_(const Provide *op) {
VisitArray(op->args, this);
this->Visit(op->value);
}
void IRVisitor::Visit_(const Realize *op) {
// Mutate the bounds
for (size_t i = 0; i < op->bounds.size(); i++) {
v->Visit(op->bounds[i]->min);
v->Visit(op->bounds[i]->extent);
}
v->Visit(op->body);
v->Visit(op->condition);
})
.set_dispatch<Block>([](const Block *op, IRVisitor* v) {
v->Visit(op->first);
v->Visit(op->rest);
})
.set_dispatch<Evaluate>([](const Evaluate *op, IRVisitor* v) {
v->Visit(op->value);
});
for (size_t i = 0; i < op->bounds.size(); i++) {
this->Visit(op->bounds[i]->min);
this->Visit(op->bounds[i]->extent);
}
this->Visit(op->body);
this->Visit(op->condition);
}
void IRVisitor::Visit_(const Block *op) {
this->Visit(op->first);
this->Visit(op->rest);
}
void IRVisitor::Visit_(const Evaluate *op) {
this->Visit(op->value);
}
#define DEFINE_OP_NO_VISIT_(OP) \
void IRVisitor::Visit_(const OP* op) {}
DEFINE_OP_NO_VISIT_(IntImm)
DEFINE_OP_NO_VISIT_(UIntImm)
DEFINE_OP_NO_VISIT_(FloatImm)
DEFINE_OP_NO_VISIT_(StringImm)
#define DISPATCH_TO_VISIT(OP) \
set_dispatch<OP>([](const OP* op, IRVisitor* v) { \
v->Visit_(op); \
})
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.DISPATCH_TO_VISIT(Variable)
.DISPATCH_TO_VISIT(LetStmt)
.DISPATCH_TO_VISIT(AttrStmt)
.DISPATCH_TO_VISIT(IfThenElse)
.DISPATCH_TO_VISIT(For)
.DISPATCH_TO_VISIT(Allocate)
.DISPATCH_TO_VISIT(Load)
.DISPATCH_TO_VISIT(Store)
.DISPATCH_TO_VISIT(Let)
.DISPATCH_TO_VISIT(Free)
.DISPATCH_TO_VISIT(Call)
.DISPATCH_TO_VISIT(Add)
.DISPATCH_TO_VISIT(Sub)
.DISPATCH_TO_VISIT(Mul)
.DISPATCH_TO_VISIT(Div)
.DISPATCH_TO_VISIT(Mod)
.DISPATCH_TO_VISIT(Min)
.DISPATCH_TO_VISIT(Max)
.DISPATCH_TO_VISIT(EQ)
.DISPATCH_TO_VISIT(NE)
.DISPATCH_TO_VISIT(LT)
.DISPATCH_TO_VISIT(LE)
.DISPATCH_TO_VISIT(GT)
.DISPATCH_TO_VISIT(GE)
.DISPATCH_TO_VISIT(And)
.DISPATCH_TO_VISIT(Or)
.DISPATCH_TO_VISIT(Reduce)
.DISPATCH_TO_VISIT(Cast)
.DISPATCH_TO_VISIT(Not)
.DISPATCH_TO_VISIT(Select)
.DISPATCH_TO_VISIT(Ramp)
.DISPATCH_TO_VISIT(Broadcast)
.DISPATCH_TO_VISIT(AssertStmt)
.DISPATCH_TO_VISIT(ProducerConsumer)
.DISPATCH_TO_VISIT(Provide)
.DISPATCH_TO_VISIT(Realize)
.DISPATCH_TO_VISIT(Block)
.DISPATCH_TO_VISIT(Evaluate)
.DISPATCH_TO_VISIT(IntImm)
.DISPATCH_TO_VISIT(UIntImm)
.DISPATCH_TO_VISIT(FloatImm)
.DISPATCH_TO_VISIT(StringImm);
} // namespace ir
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment