Commit 5c07413c by Ziheng Jiang Committed by Tianqi Chen

[PASS] Change IRVisitor interfaces to function override (#42)

* [PASS] Change IRVisitor interfaces to function override

* [PASS] Change IRMutator interfaces to overloadable function
parent b8f0ec50
...@@ -16,7 +16,8 @@ namespace ir { ...@@ -16,7 +16,8 @@ namespace ir {
/*! /*!
* \brief a base class for mutator to iterative mutate the IR * \brief a base class for mutator to iterative mutate the IR
* *
* This IRMutator is implemented via IRFunctor instead of Visitor Pattern. * This IRMutator is implemented via Visitor Pattern.
* Also you can implement via IRFunctor.
* This enables easy extensions of possible new Node. * This enables easy extensions of possible new Node.
* It also makes changing return types easier. * It also makes changing return types easier.
* *
...@@ -54,20 +55,91 @@ class IRMutator { ...@@ -54,20 +55,91 @@ class IRMutator {
static FMutateStmt& vtable_stmt(); // NOLINT(*) static FMutateStmt& vtable_stmt(); // NOLINT(*)
// Set of overloadable functions // Set of overloadable functions
// The underscore allows Mutate not to be shadowed by inheritance // The underscore allows Mutate not to be shadowed by inheritance
virtual Stmt Mutate_(const Variable* op, const Stmt& s);
virtual Stmt Mutate_(const LetStmt* op, const Stmt& s); virtual Stmt Mutate_(const LetStmt* op, const Stmt& s);
virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s); virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s);
virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
virtual Stmt Mutate_(const For* op, const Stmt& s); virtual Stmt Mutate_(const For* op, const Stmt& s);
virtual Stmt Mutate_(const Provide* op, const Stmt& s);
virtual Stmt Mutate_(const Allocate* op, const Stmt& s); virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
virtual Stmt Mutate_(const Realize* op, const Stmt& s); virtual Stmt Mutate_(const Load* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s); virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Let* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s); virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s); virtual Stmt Mutate_(const Call* op, const Stmt& s);
virtual Stmt Mutate_(const Add* op, const Stmt& e);
virtual Stmt Mutate_(const Sub* op, const Stmt& e);
virtual Stmt Mutate_(const Mul* op, const Stmt& e);
virtual Stmt Mutate_(const Div* op, const Stmt& e);
virtual Stmt Mutate_(const Mod* op, const Stmt& e);
virtual Stmt Mutate_(const Min* op, const Stmt& e);
virtual Stmt Mutate_(const Max* op, const Stmt& e);
virtual Stmt Mutate_(const EQ* op, const Stmt& e);
virtual Stmt Mutate_(const NE* op, const Stmt& e);
virtual Stmt Mutate_(const LT* op, const Stmt& e);
virtual Stmt Mutate_(const LE* op, const Stmt& e);
virtual Stmt Mutate_(const GT* op, const Stmt& e);
virtual Stmt Mutate_(const GE* op, const Stmt& e);
virtual Stmt Mutate_(const And* op, const Stmt& e);
virtual Stmt Mutate_(const Or* op, const Stmt& e);
virtual Stmt Mutate_(const Reduce* op, const Stmt& s);
virtual Stmt Mutate_(const Cast* op, const Stmt& s);
virtual Stmt Mutate_(const Not* op, const Stmt& s);
virtual Stmt Mutate_(const Select* op, const Stmt& s);
virtual Stmt Mutate_(const Ramp* op, const Stmt& s);
virtual Stmt Mutate_(const Broadcast* op, const Stmt& e);
virtual Stmt Mutate_(const AssertStmt* op, const Stmt& e);
virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& e);
virtual Stmt Mutate_(const Provide* op, const Stmt& e);
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Block* op, const Stmt& s); virtual Stmt Mutate_(const Block* op, const Stmt& s);
virtual Expr Mutate_(const Call* op, const Expr& e); virtual Stmt Mutate_(const Evaluate* op, const Stmt& e);
virtual Expr Mutate_(const Load* op, const Expr& s); virtual Stmt Mutate_(const IntImm* op, const Stmt& e);
virtual Stmt Mutate_(const UIntImm* op, const Stmt& e);
virtual Stmt Mutate_(const FloatImm* op, const Stmt& e);
virtual Stmt Mutate_(const StringImm* op, const Stmt& e);
virtual Expr Mutate_(const Variable* op, const Expr& e); virtual Expr Mutate_(const Variable* op, const Expr& e);
virtual Expr Mutate_(const LetStmt* op, const Expr& e);
virtual Expr Mutate_(const AttrStmt* op, const Expr& e);
virtual Expr Mutate_(const IfThenElse* op, const Expr& e);
virtual Expr Mutate_(const For* op, const Expr& e);
virtual Expr Mutate_(const Allocate* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& e);
virtual Expr Mutate_(const Store* op, const Expr& e);
virtual Expr Mutate_(const Let* op, const Expr& e); virtual Expr Mutate_(const Let* op, const Expr& e);
virtual Expr Mutate_(const Free* op, const Expr& e);
virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Add* op, const Expr& e);
virtual Expr Mutate_(const Sub* op, const Expr& e);
virtual Expr Mutate_(const Mul* op, const Expr& e);
virtual Expr Mutate_(const Div* op, const Expr& e);
virtual Expr Mutate_(const Mod* op, const Expr& e);
virtual Expr Mutate_(const Min* op, const Expr& e);
virtual Expr Mutate_(const Max* op, const Expr& e);
virtual Expr Mutate_(const EQ* op, const Expr& e);
virtual Expr Mutate_(const NE* op, const Expr& e);
virtual Expr Mutate_(const LT* op, const Expr& e);
virtual Expr Mutate_(const LE* op, const Expr& e);
virtual Expr Mutate_(const GT* op, const Expr& e);
virtual Expr Mutate_(const GE* op, const Expr& e);
virtual Expr Mutate_(const And* op, const Expr& e);
virtual Expr Mutate_(const Or* op, const Expr& e);
virtual Expr Mutate_(const Reduce* op, const Expr& e);
virtual Expr Mutate_(const Cast* op, const Expr& e);
virtual Expr Mutate_(const Not* op, const Expr& e);
virtual Expr Mutate_(const Select* op, const Expr& e);
virtual Expr Mutate_(const Ramp* op, const Expr& e);
virtual Expr Mutate_(const Broadcast* op, const Expr& e);
virtual Expr Mutate_(const AssertStmt* op, const Expr& e);
virtual Expr Mutate_(const ProducerConsumer* op, const Expr& e);
virtual Expr Mutate_(const Provide* op, const Expr& e);
virtual Expr Mutate_(const Realize* op, const Expr& e);
virtual Expr Mutate_(const Block* op, const Expr& e);
virtual Expr Mutate_(const Evaluate* op, const Expr& e);
virtual Expr Mutate_(const IntImm* op, const Expr& e);
virtual Expr Mutate_(const UIntImm* op, const Expr& e);
virtual Expr Mutate_(const FloatImm* op, const Expr& e);
virtual Expr Mutate_(const StringImm* op, const Expr& e);
}; };
/*! /*!
......
...@@ -36,16 +36,47 @@ class IRVisitor { ...@@ -36,16 +36,47 @@ class IRVisitor {
static FVisit& vtable(); static FVisit& vtable();
// overloadable visit function. // overloadable visit function.
virtual void Visit_(const Variable* op); virtual void Visit_(const Variable* op);
virtual void Visit_(const AttrStmt* op);
virtual void Visit_(const LetStmt* op); virtual void Visit_(const LetStmt* op);
virtual void Visit_(const AttrStmt* op);
virtual void Visit_(const IfThenElse* op);
virtual void Visit_(const For* op); virtual void Visit_(const For* op);
virtual void Visit_(const Allocate* op); virtual void Visit_(const Allocate* op);
virtual void Visit_(const IfThenElse* op);
virtual void Visit_(const Load* op); virtual void Visit_(const Load* op);
virtual void Visit_(const Store* op); virtual void Visit_(const Store* op);
virtual void Visit_(const Let* op); virtual void Visit_(const Let* op);
virtual void Visit_(const Free* op); virtual void Visit_(const Free* op);
virtual void Visit_(const Call* op); virtual void Visit_(const Call* op);
virtual void Visit_(const Add* op);
virtual void Visit_(const Sub* op);
virtual void Visit_(const Mul* op);
virtual void Visit_(const Div* op);
virtual void Visit_(const Mod* op);
virtual void Visit_(const Min* op);
virtual void Visit_(const Max* op);
virtual void Visit_(const EQ* op);
virtual void Visit_(const NE* op);
virtual void Visit_(const LT* op);
virtual void Visit_(const LE* op);
virtual void Visit_(const GT* op);
virtual void Visit_(const GE* op);
virtual void Visit_(const And* op);
virtual void Visit_(const Or* op);
virtual void Visit_(const Reduce* op);
virtual void Visit_(const Cast* op);
virtual void Visit_(const Not* op);
virtual void Visit_(const Select* op);
virtual void Visit_(const Ramp* op);
virtual void Visit_(const Broadcast* op);
virtual void Visit_(const AssertStmt* op);
virtual void Visit_(const ProducerConsumer* op);
virtual void Visit_(const Provide* op);
virtual void Visit_(const Realize* op);
virtual void Visit_(const Block* op);
virtual void Visit_(const Evaluate* op);
virtual void Visit_(const IntImm* op);
virtual void Visit_(const UIntImm* op);
virtual void Visit_(const FloatImm* op);
virtual void Visit_(const StringImm* op);
}; };
/*! /*!
......
...@@ -16,11 +16,6 @@ IRMutator::FMutateStmt& IRMutator::vtable_stmt() { // NOLINT(*) ...@@ -16,11 +16,6 @@ IRMutator::FMutateStmt& IRMutator::vtable_stmt() { // NOLINT(*)
static FMutateStmt inst; return inst; static FMutateStmt inst; return inst;
} }
// const expr
inline Expr ReturnSelfExpr(const NodeRef&, const Expr& e, IRMutator*) {
return e;
}
inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) { inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
std::vector<Expr> new_arr(arr.size()); std::vector<Expr> new_arr(arr.size());
bool changed = false; bool changed = false;
...@@ -58,47 +53,33 @@ inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) { ...@@ -58,47 +53,33 @@ inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) {
} }
} }
// Mutate Stmt
#define DISPATCH_TO_MUTATE_STMT(OP) \ #define DISPATCH_TO_MUTATE_STMT(OP) \
set_dispatch<OP>([](const OP* op, const Stmt& s, IRMutator* m) { \ set_dispatch<OP>([](const OP* op, const Stmt& s, IRMutator* m) { \
return m->Mutate_(op, s); \ return m->Mutate_(op, s); \
}) })
#define DISPATCH_TO_MUTATE_EXPR(OP) \ Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) {
set_dispatch<OP>([](const OP* op, const Expr& e, IRMutator* m) { \
return m->Mutate_(op, e); \
})
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.DISPATCH_TO_MUTATE_STMT(LetStmt)
.DISPATCH_TO_MUTATE_STMT(AttrStmt)
.DISPATCH_TO_MUTATE_STMT(Provide)
.DISPATCH_TO_MUTATE_STMT(Realize)
.DISPATCH_TO_MUTATE_STMT(Store)
.DISPATCH_TO_MUTATE_STMT(IfThenElse)
.DISPATCH_TO_MUTATE_STMT(For)
.DISPATCH_TO_MUTATE_STMT(Allocate)
.DISPATCH_TO_MUTATE_STMT(Block)
.DISPATCH_TO_MUTATE_STMT(Free);
Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) {
Expr value = this->Mutate(op->value); Expr value = this->Mutate(op->value);
Stmt body = this->Mutate(op->body); Stmt body = this->Mutate(op->body);
if (value.same_as(op->value) && if (value.same_as(op->value) &&
body.same_as(op->body)) { body.same_as(op->body)) {
return s; return s;
} else { } else {
return LetStmt::make(op->var, value, body); return AttrStmt::make(op->node, op->type_key, value, body);
} }
} }
Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) { Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) {
Expr value = this->Mutate(op->value); Expr value = this->Mutate(op->value);
Stmt body = this->Mutate(op->body); Stmt body = this->Mutate(op->body);
if (value.same_as(op->value) && if (value.same_as(op->value) &&
body.same_as(op->body)) { body.same_as(op->body)) {
return s; return s;
} else { } else {
return AttrStmt::make(op->node, op->type_key, value, body); return LetStmt::make(op->var, value, body);
} }
} }
...@@ -143,6 +124,36 @@ Stmt IRMutator::Mutate_(const Allocate* op, const Stmt& s) { ...@@ -143,6 +124,36 @@ Stmt IRMutator::Mutate_(const Allocate* op, const Stmt& s) {
} }
} }
Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
Stmt then_case = this->Mutate(op->then_case);
Stmt else_case;
if (else_case.defined()) {
else_case = this->Mutate(op->else_case);
}
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
} else {
return IfThenElse::make(condition, then_case, else_case);
}
}
Stmt IRMutator::Mutate_(const Load *op, const Stmt& s) {
return s;
}
Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) {
Expr value = this->Mutate(op->value);
Expr index = this->Mutate(op->index);
if (value.same_as(op->value) && index.same_as(op->index)) {
return s;
} else {
return Store::make(op->buffer_var, value, index);
}
}
Stmt IRMutator::Mutate_(const Provide* op, const Stmt& s) { Stmt IRMutator::Mutate_(const Provide* op, const Stmt& s) {
auto new_args = MutateArray(op->args, this); auto new_args = MutateArray(op->args, this);
auto new_value = this->Mutate(op->value); auto new_value = this->Mutate(op->value);
...@@ -183,63 +194,137 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) { ...@@ -183,63 +194,137 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
} }
} }
Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) { Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) {
Expr value = this->Mutate(op->value); Stmt first = this->Mutate(op->first);
Expr index = this->Mutate(op->index); Stmt rest = this->Mutate(op->rest);
if (value.same_as(op->value) && index.same_as(op->index)) { if (first.same_as(op->first) &&
rest.same_as(op->rest)) {
return s; return s;
} else { } else {
return Store::make(op->buffer_var, value, index); return Block::make(first, rest);
} }
} }
Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) { Stmt IRMutator::Mutate_(const AssertStmt *op, const Stmt& s) {
return s;
}
Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
Expr condition = this->Mutate(op->condition); Expr condition = this->Mutate(op->condition);
Stmt then_case = this->Mutate(op->then_case); Expr message = this->Mutate(op->message);
Stmt else_case;
if (else_case.defined()) { if (condition.same_as(op->condition) && message.same_as(op->message)) {
else_case = this->Mutate(op->else_case);
}
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s; return s;
} else { } else {
return IfThenElse::make(condition, then_case, else_case); return AssertStmt::make(condition, message);
} }
} }
Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) { Stmt IRMutator::Mutate_(const ProducerConsumer *op, const Stmt& s) {
Stmt first = this->Mutate(op->first); Stmt body = this->Mutate(op->body);
Stmt rest = this->Mutate(op->rest); if (body.same_as(op->body)) {
if (first.same_as(op->first) &&
rest.same_as(op->rest)) {
return s; return s;
} else { } else {
return Block::make(first, rest); return ProducerConsumer::make(op->func, op->is_producer, body);
} }
} }
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) Stmt IRMutator::Mutate_(const Evaluate *op, const Stmt& s) {
.DISPATCH_TO_MUTATE_EXPR(Call) Expr v = this->Mutate(op->value);
.DISPATCH_TO_MUTATE_EXPR(Let) if (v.same_as(op->value)) {
.DISPATCH_TO_MUTATE_EXPR(Load) return s;
.DISPATCH_TO_MUTATE_EXPR(Variable);
Expr IRMutator::Mutate_(const Call* op, const Expr& e) {
auto new_args = MutateArray(op->args, this);
if (op->args.same_as(new_args)) {
return e;
} else { } else {
return Call::make(op->type, op->name, new_args, op->call_type, return Evaluate::make(v);
op->func, op->value_index);
} }
} }
#define DEFINE_OP_RETURN_SELF_STMT_MUTATE_(OP) \
Stmt IRMutator::Mutate_(const OP *op, const Stmt& s) { \
return s; \
}
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Variable)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Let)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Free)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Call)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Add)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Sub)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Mul)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Div)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Mod)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Min)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Max)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(EQ)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(NE)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(LT)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(LE)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(GT)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(GE)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(And)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Or)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Reduce)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Cast)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Not)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Select)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Ramp)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Broadcast)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(IntImm)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(UIntImm)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(FloatImm)
DEFINE_OP_RETURN_SELF_STMT_MUTATE_(StringImm)
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.DISPATCH_TO_MUTATE_STMT(Variable)
.DISPATCH_TO_MUTATE_STMT(LetStmt)
.DISPATCH_TO_MUTATE_STMT(AttrStmt)
.DISPATCH_TO_MUTATE_STMT(IfThenElse)
.DISPATCH_TO_MUTATE_STMT(For)
.DISPATCH_TO_MUTATE_STMT(Allocate)
.DISPATCH_TO_MUTATE_STMT(Load)
.DISPATCH_TO_MUTATE_STMT(Store)
.DISPATCH_TO_MUTATE_STMT(Let)
.DISPATCH_TO_MUTATE_STMT(Free)
.DISPATCH_TO_MUTATE_STMT(Call)
.DISPATCH_TO_MUTATE_STMT(Add)
.DISPATCH_TO_MUTATE_STMT(Sub)
.DISPATCH_TO_MUTATE_STMT(Mul)
.DISPATCH_TO_MUTATE_STMT(Div)
.DISPATCH_TO_MUTATE_STMT(Mod)
.DISPATCH_TO_MUTATE_STMT(Min)
.DISPATCH_TO_MUTATE_STMT(Max)
.DISPATCH_TO_MUTATE_STMT(EQ)
.DISPATCH_TO_MUTATE_STMT(NE)
.DISPATCH_TO_MUTATE_STMT(LT)
.DISPATCH_TO_MUTATE_STMT(LE)
.DISPATCH_TO_MUTATE_STMT(GT)
.DISPATCH_TO_MUTATE_STMT(GE)
.DISPATCH_TO_MUTATE_STMT(And)
.DISPATCH_TO_MUTATE_STMT(Or)
.DISPATCH_TO_MUTATE_STMT(Reduce)
.DISPATCH_TO_MUTATE_STMT(Cast)
.DISPATCH_TO_MUTATE_STMT(Not)
.DISPATCH_TO_MUTATE_STMT(Select)
.DISPATCH_TO_MUTATE_STMT(Ramp)
.DISPATCH_TO_MUTATE_STMT(Broadcast)
.DISPATCH_TO_MUTATE_STMT(AssertStmt)
.DISPATCH_TO_MUTATE_STMT(ProducerConsumer)
.DISPATCH_TO_MUTATE_STMT(Provide)
.DISPATCH_TO_MUTATE_STMT(Realize)
.DISPATCH_TO_MUTATE_STMT(Block)
.DISPATCH_TO_MUTATE_STMT(Evaluate)
.DISPATCH_TO_MUTATE_STMT(IntImm)
.DISPATCH_TO_MUTATE_STMT(UIntImm)
.DISPATCH_TO_MUTATE_STMT(FloatImm)
.DISPATCH_TO_MUTATE_STMT(StringImm);
// Mutate Expr
#define DISPATCH_TO_MUTATE_EXPR(OP) \
set_dispatch<OP>([](const OP* op, const Expr& e, IRMutator* m) { \
return m->Mutate_(op, e); \
})
Expr IRMutator::Mutate_(const Variable *op, const Expr& e) {
return e;
}
Expr IRMutator::Mutate_(const Load *op, const Expr& e) { Expr IRMutator::Mutate_(const Load *op, const Expr& e) {
Expr index = this->Mutate(op->index); Expr index = this->Mutate(op->index);
if (index.same_as(op->index)) { if (index.same_as(op->index)) {
...@@ -249,11 +334,6 @@ Expr IRMutator::Mutate_(const Load *op, const Expr& e) { ...@@ -249,11 +334,6 @@ Expr IRMutator::Mutate_(const Load *op, const Expr& e) {
} }
} }
Expr IRMutator::Mutate_(const Variable *op, const Expr& e) {
return e;
}
Expr IRMutator::Mutate_(const Let *op, const Expr& e) { Expr IRMutator::Mutate_(const Let *op, const Expr& e) {
Expr value = this->Mutate(op->value); Expr value = this->Mutate(op->value);
Expr body = this->Mutate(op->body); Expr body = this->Mutate(op->body);
...@@ -265,130 +345,172 @@ Expr IRMutator::Mutate_(const Let *op, const Expr& e) { ...@@ -265,130 +345,172 @@ Expr IRMutator::Mutate_(const Let *op, const Expr& e) {
} }
} }
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) Expr IRMutator::Mutate_(const Call* op, const Expr& e) {
.set_dispatch<Reduce>([](const Reduce* op, const Expr& e, IRMutator* m) { auto new_args = MutateArray(op->args, this);
Array<IterVar> new_axis = MutateIterVarArr(op->axis, m); if (op->args.same_as(new_args)) {
Expr new_source = m->Mutate(op->source); return e;
if (op->axis.same_as(new_axis) && } else {
op->source.same_as(new_source)) { return Call::make(op->type, op->name, new_args, op->call_type,
return e; op->func, op->value_index);
} else { }
return Reduce::make(op->op, new_source, new_axis); }
}
});
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) #define DEFINE_BIOP_EXPR_MUTATE_(OP) \
.set_dispatch<IntImm>(ReturnSelfExpr) Expr IRMutator::Mutate_(const OP* op, const Expr& e) { \
.set_dispatch<UIntImm>(ReturnSelfExpr) Expr a = this->Mutate(op->a); \
.set_dispatch<FloatImm>(ReturnSelfExpr) Expr b = this->Mutate(op->b); \
.set_dispatch<StringImm>(ReturnSelfExpr); if (a.same_as(op->a) && \
b.same_as(op->b)) { \
return e; \
} else { \
return OP::make(a, b); \
} \
}
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) DEFINE_BIOP_EXPR_MUTATE_(Add)
.set_dispatch<Cast>([](const Cast* op, const Expr& e, IRMutator* m) { DEFINE_BIOP_EXPR_MUTATE_(Sub)
Expr value = m->Mutate(op->value); DEFINE_BIOP_EXPR_MUTATE_(Mul)
if (value.same_as(op->value)) { DEFINE_BIOP_EXPR_MUTATE_(Div)
return e; DEFINE_BIOP_EXPR_MUTATE_(Mod)
} else { DEFINE_BIOP_EXPR_MUTATE_(Min)
return Cast::make(op->type, value); DEFINE_BIOP_EXPR_MUTATE_(Max)
} DEFINE_BIOP_EXPR_MUTATE_(EQ)
}); DEFINE_BIOP_EXPR_MUTATE_(NE)
DEFINE_BIOP_EXPR_MUTATE_(LT)
// binary operator DEFINE_BIOP_EXPR_MUTATE_(LE)
template<typename T> DEFINE_BIOP_EXPR_MUTATE_(GT)
inline Expr Binary(const T* op, const Expr& e, IRMutator* m) { DEFINE_BIOP_EXPR_MUTATE_(GE)
Expr a = m->Mutate(op->a); DEFINE_BIOP_EXPR_MUTATE_(And)
Expr b = m->Mutate(op->b); DEFINE_BIOP_EXPR_MUTATE_(Or)
if (a.same_as(op->a) &&
b.same_as(op->b)) { Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) {
Array<IterVar> new_axis = MutateIterVarArr(op->axis, this);
Expr new_source = this->Mutate(op->source);
if (op->axis.same_as(new_axis) &&
op->source.same_as(new_source)) {
return e; return e;
} else { } else {
return T::make(a, b); return Reduce::make(op->op, new_source, new_axis);
} }
} }
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) Expr IRMutator::Mutate_(const Cast *op, const Expr& e) {
.set_dispatch<Add>(Binary<Add>) Expr value = this->Mutate(op->value);
.set_dispatch<Sub>(Binary<Sub>) if (value.same_as(op->value)) {
.set_dispatch<Mul>(Binary<Mul>) return e;
.set_dispatch<Div>(Binary<Div>) } else {
.set_dispatch<Mod>(Binary<Mod>) return Cast::make(op->type, value);
.set_dispatch<Min>(Binary<Min>) }
.set_dispatch<Max>(Binary<Max>) }
.set_dispatch<EQ>(Binary<EQ>)
.set_dispatch<NE>(Binary<NE>) Expr IRMutator::Mutate_(const Not *op, const Expr& e) {
.set_dispatch<LT>(Binary<LT>) Expr a = this->Mutate(op->a);
.set_dispatch<LE>(Binary<LE>) if (a.same_as(op->a)) {
.set_dispatch<GT>(Binary<GT>) return e;
.set_dispatch<GE>(Binary<GE>) } else {
.set_dispatch<And>(Binary<And>) return Not::make(a);
.set_dispatch<Or>(Binary<Or>); }
}
Expr IRMutator::Mutate_(const Select *op, const Expr& e) {
Expr cond = this->Mutate(op->condition);
Expr t = this->Mutate(op->true_value);
Expr f = this->Mutate(op->false_value);
if (cond.same_as(op->condition) &&
t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
return e;
} else {
return Select::make(cond, t, f);
}
}
Expr IRMutator::Mutate_(const Ramp *op, const Expr& e) {
Expr base = this->Mutate(op->base);
Expr stride = this->Mutate(op->stride);
if (base.same_as(op->base) &&
stride.same_as(op->stride)) {
return e;
} else {
return Ramp::make(base, stride, op->lanes);
}
}
Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) {
Expr value = this->Mutate(op->value);
if (value.same_as(op->value)) {
return e;
} else {
return Broadcast::make(value, op->lanes);
}
}
#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \
Expr IRMutator::Mutate_(const OP *op, const Expr& e) { \
return e; \
}
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(LetStmt)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(AttrStmt)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(For)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IfThenElse)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Allocate)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Store)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Free)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(AssertStmt)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(ProducerConsumer)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Provide)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Realize)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Block)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Evaluate)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm)
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Not>([](const Not* op, const Expr& e, IRMutator* m) { .DISPATCH_TO_MUTATE_EXPR(Variable)
Expr a = m->Mutate(op->a); .DISPATCH_TO_MUTATE_EXPR(LetStmt)
if (a.same_as(op->a)) { .DISPATCH_TO_MUTATE_EXPR(AttrStmt)
return e; .DISPATCH_TO_MUTATE_EXPR(IfThenElse)
} else { .DISPATCH_TO_MUTATE_EXPR(For)
return Not::make(a); .DISPATCH_TO_MUTATE_EXPR(Allocate)
} .DISPATCH_TO_MUTATE_EXPR(Load)
}) .DISPATCH_TO_MUTATE_EXPR(Store)
.set_dispatch<Select>([](const Select *op, const Expr& e, IRMutator* m) { .DISPATCH_TO_MUTATE_EXPR(Let)
Expr cond = m->Mutate(op->condition); .DISPATCH_TO_MUTATE_EXPR(Free)
Expr t = m->Mutate(op->true_value); .DISPATCH_TO_MUTATE_EXPR(Call)
Expr f = m->Mutate(op->false_value); .DISPATCH_TO_MUTATE_EXPR(Add)
if (cond.same_as(op->condition) && .DISPATCH_TO_MUTATE_EXPR(Sub)
t.same_as(op->true_value) && .DISPATCH_TO_MUTATE_EXPR(Mul)
f.same_as(op->false_value)) { .DISPATCH_TO_MUTATE_EXPR(Div)
return e; .DISPATCH_TO_MUTATE_EXPR(Mod)
} else { .DISPATCH_TO_MUTATE_EXPR(Min)
return Select::make(cond, t, f); .DISPATCH_TO_MUTATE_EXPR(Max)
} .DISPATCH_TO_MUTATE_EXPR(EQ)
}) .DISPATCH_TO_MUTATE_EXPR(NE)
.set_dispatch<Ramp>([](const Ramp *op, const Expr& e, IRMutator* m) { .DISPATCH_TO_MUTATE_EXPR(LT)
Expr base = m->Mutate(op->base); .DISPATCH_TO_MUTATE_EXPR(LE)
Expr stride = m->Mutate(op->stride); .DISPATCH_TO_MUTATE_EXPR(GT)
if (base.same_as(op->base) && .DISPATCH_TO_MUTATE_EXPR(GE)
stride.same_as(op->stride)) { .DISPATCH_TO_MUTATE_EXPR(And)
return e; .DISPATCH_TO_MUTATE_EXPR(Or)
} else { .DISPATCH_TO_MUTATE_EXPR(Reduce)
return Ramp::make(base, stride, op->lanes); .DISPATCH_TO_MUTATE_EXPR(Cast)
} .DISPATCH_TO_MUTATE_EXPR(Not)
}) .DISPATCH_TO_MUTATE_EXPR(Select)
.set_dispatch<Broadcast>([](const Broadcast *op, const Expr& e, IRMutator* m) { .DISPATCH_TO_MUTATE_EXPR(Ramp)
Expr value = m->Mutate(op->value); .DISPATCH_TO_MUTATE_EXPR(Broadcast)
if (value.same_as(op->value)) { .DISPATCH_TO_MUTATE_EXPR(AssertStmt)
return e; .DISPATCH_TO_MUTATE_EXPR(ProducerConsumer)
} else { .DISPATCH_TO_MUTATE_EXPR(Provide)
return Broadcast::make(value, op->lanes); .DISPATCH_TO_MUTATE_EXPR(Realize)
} .DISPATCH_TO_MUTATE_EXPR(Block)
}); .DISPATCH_TO_MUTATE_EXPR(Evaluate)
.DISPATCH_TO_MUTATE_EXPR(IntImm)
.DISPATCH_TO_MUTATE_EXPR(UIntImm)
.DISPATCH_TO_MUTATE_EXPR(FloatImm)
.DISPATCH_TO_MUTATE_EXPR(StringImm);
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.set_dispatch<AssertStmt>([](const AssertStmt *op, const Stmt& s, IRMutator* m) {
Expr condition = m->Mutate(op->condition);
Expr message = m->Mutate(op->message);
if (condition.same_as(op->condition) && message.same_as(op->message)) {
return s;
} else {
return AssertStmt::make(condition, message);
}
})
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, const Stmt& s, IRMutator* m) {
Stmt body = m->Mutate(op->body);
if (body.same_as(op->body)) {
return s;
} else {
return ProducerConsumer::make(op->func, op->is_producer, body);
}
})
.set_dispatch<Evaluate>([](const Evaluate *op, const Stmt& s, IRMutator* m) {
Expr v = m->Mutate(op->value);
if (v.same_as(op->value)) {
return s;
} else {
return Evaluate::make(v);
}
});
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -34,9 +34,6 @@ IRVisitor::FVisit& IRVisitor::vtable() { // NOLINT(*) ...@@ -34,9 +34,6 @@ IRVisitor::FVisit& IRVisitor::vtable() { // NOLINT(*)
static FVisit inst; return inst; static FVisit inst; return inst;
} }
void NoOp(const NodeRef& n, IRVisitor* v) {
}
inline void VisitArray(const Array<Expr>& arr, IRVisitor* v) { inline void VisitArray(const Array<Expr>& arr, IRVisitor* v) {
for (size_t i = 0; i < arr.size(); i++) { for (size_t i = 0; i < arr.size(); i++) {
v->Visit(arr[i]); v->Visit(arr[i]);
...@@ -51,24 +48,6 @@ inline void VisitRDom(const Array<IterVar>& rdom, IRVisitor* v) { ...@@ -51,24 +48,6 @@ inline void VisitRDom(const Array<IterVar>& rdom, IRVisitor* v) {
} }
} }
#define DISPATCH_TO_VISIT(OP) \
set_dispatch<OP>([](const OP* op, IRVisitor* v) { \
v->Visit_(op); \
})
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.DISPATCH_TO_VISIT(Variable)
.DISPATCH_TO_VISIT(LetStmt)
.DISPATCH_TO_VISIT(AttrStmt)
.DISPATCH_TO_VISIT(IfThenElse)
.DISPATCH_TO_VISIT(For)
.DISPATCH_TO_VISIT(Allocate)
.DISPATCH_TO_VISIT(Load)
.DISPATCH_TO_VISIT(Store)
.DISPATCH_TO_VISIT(Let)
.DISPATCH_TO_VISIT(Call)
.DISPATCH_TO_VISIT(Free);
void IRVisitor::Visit_(const Variable* op) {} void IRVisitor::Visit_(const Variable* op) {}
void IRVisitor::Visit_(const LetStmt *op) { void IRVisitor::Visit_(const LetStmt *op) {
...@@ -128,91 +107,146 @@ void IRVisitor::Visit_(const Call *op) { ...@@ -128,91 +107,146 @@ void IRVisitor::Visit_(const Call *op) {
VisitArray(op->args, this); VisitArray(op->args, this);
} }
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) #define DEFINE_BINOP_VISIT_(OP) \
.set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) { void IRVisitor::Visit_(const OP* op) { \
VisitRDom(op->axis, v); this->Visit(op->a); \
v->Visit(op->source); this->Visit(op->b); \
}) }
.set_dispatch<IntImm>(NoOp)
.set_dispatch<UIntImm>(NoOp)
.set_dispatch<FloatImm>(NoOp)
.set_dispatch<StringImm>(NoOp);
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) DEFINE_BINOP_VISIT_(Add)
.set_dispatch<Cast>([](const Cast* op, IRVisitor* v) { DEFINE_BINOP_VISIT_(Sub)
v->Visit(op->value); DEFINE_BINOP_VISIT_(Mul)
}); DEFINE_BINOP_VISIT_(Div)
DEFINE_BINOP_VISIT_(Mod)
DEFINE_BINOP_VISIT_(Min)
DEFINE_BINOP_VISIT_(Max)
DEFINE_BINOP_VISIT_(EQ)
DEFINE_BINOP_VISIT_(NE)
DEFINE_BINOP_VISIT_(LT)
DEFINE_BINOP_VISIT_(LE)
DEFINE_BINOP_VISIT_(GT)
DEFINE_BINOP_VISIT_(GE)
DEFINE_BINOP_VISIT_(And)
DEFINE_BINOP_VISIT_(Or)
void IRVisitor::Visit_(const Reduce* op) {
VisitRDom(op->axis, this);
this->Visit(op->source);
}
// binary operator void IRVisitor::Visit_(const Cast* op) {
template<typename T> this->Visit(op->value);
inline void Binary(const T* op, IRVisitor* v) {
v->Visit(op->a);
v->Visit(op->b);
} }
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) void IRVisitor::Visit_(const Not* op) {
.set_dispatch<Add>(Binary<Add>) this->Visit(op->a);
.set_dispatch<Sub>(Binary<Sub>) }
.set_dispatch<Mul>(Binary<Mul>)
.set_dispatch<Div>(Binary<Div>)
.set_dispatch<Mod>(Binary<Mod>)
.set_dispatch<Min>(Binary<Min>)
.set_dispatch<Max>(Binary<Max>)
.set_dispatch<EQ>(Binary<EQ>)
.set_dispatch<NE>(Binary<NE>)
.set_dispatch<LT>(Binary<LT>)
.set_dispatch<LE>(Binary<LE>)
.set_dispatch<GT>(Binary<GT>)
.set_dispatch<GE>(Binary<GE>)
.set_dispatch<And>(Binary<And>)
.set_dispatch<Or>(Binary<Or>);
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) void IRVisitor::Visit_(const Select* op) {
.set_dispatch<Not>([](const Not* op, IRVisitor* v) { this->Visit(op->condition);
v->Visit(op->a); this->Visit(op->true_value);
}) this->Visit(op->false_value);
.set_dispatch<Select>([](const Select *op, IRVisitor* v) { }
v->Visit(op->condition);
v->Visit(op->true_value);
v->Visit(op->false_value);
})
.set_dispatch<Ramp>([](const Ramp *op, IRVisitor* v) {
v->Visit(op->base);
v->Visit(op->stride);
})
.set_dispatch<Broadcast>([](const Broadcast *op, IRVisitor* v) {
v->Visit(op->value);
});
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) void IRVisitor::Visit_(const Ramp *op) {
.set_dispatch<AssertStmt>([](const AssertStmt *op, IRVisitor* v) { this->Visit(op->base);
v->Visit(op->condition); this->Visit(op->stride);
v->Visit(op->message); }
})
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, IRVisitor* v) { void IRVisitor::Visit_(const Broadcast *op) {
v->Visit(op->body); this->Visit(op->value);
}) }
.set_dispatch<Provide>([](const Provide *op, IRVisitor* v) {
VisitArray(op->args, v); void IRVisitor::Visit_(const AssertStmt *op) {
v->Visit(op->value); this->Visit(op->condition);
}) this->Visit(op->message);
.set_dispatch<Realize>([](const Realize *op, IRVisitor* v) { }
void IRVisitor::Visit_(const ProducerConsumer *op) {
this->Visit(op->body);
}
void IRVisitor::Visit_(const Provide *op) {
VisitArray(op->args, this);
this->Visit(op->value);
}
void IRVisitor::Visit_(const Realize *op) {
// Mutate the bounds // Mutate the bounds
for (size_t i = 0; i < op->bounds.size(); i++) { for (size_t i = 0; i < op->bounds.size(); i++) {
v->Visit(op->bounds[i]->min); this->Visit(op->bounds[i]->min);
v->Visit(op->bounds[i]->extent); this->Visit(op->bounds[i]->extent);
} }
v->Visit(op->body); this->Visit(op->body);
v->Visit(op->condition); this->Visit(op->condition);
}) }
.set_dispatch<Block>([](const Block *op, IRVisitor* v) {
v->Visit(op->first); void IRVisitor::Visit_(const Block *op) {
v->Visit(op->rest); this->Visit(op->first);
}) this->Visit(op->rest);
.set_dispatch<Evaluate>([](const Evaluate *op, IRVisitor* v) { }
v->Visit(op->value);
}); void IRVisitor::Visit_(const Evaluate *op) {
this->Visit(op->value);
}
#define DEFINE_OP_NO_VISIT_(OP) \
void IRVisitor::Visit_(const OP* op) {}
DEFINE_OP_NO_VISIT_(IntImm)
DEFINE_OP_NO_VISIT_(UIntImm)
DEFINE_OP_NO_VISIT_(FloatImm)
DEFINE_OP_NO_VISIT_(StringImm)
#define DISPATCH_TO_VISIT(OP) \
set_dispatch<OP>([](const OP* op, IRVisitor* v) { \
v->Visit_(op); \
})
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.DISPATCH_TO_VISIT(Variable)
.DISPATCH_TO_VISIT(LetStmt)
.DISPATCH_TO_VISIT(AttrStmt)
.DISPATCH_TO_VISIT(IfThenElse)
.DISPATCH_TO_VISIT(For)
.DISPATCH_TO_VISIT(Allocate)
.DISPATCH_TO_VISIT(Load)
.DISPATCH_TO_VISIT(Store)
.DISPATCH_TO_VISIT(Let)
.DISPATCH_TO_VISIT(Free)
.DISPATCH_TO_VISIT(Call)
.DISPATCH_TO_VISIT(Add)
.DISPATCH_TO_VISIT(Sub)
.DISPATCH_TO_VISIT(Mul)
.DISPATCH_TO_VISIT(Div)
.DISPATCH_TO_VISIT(Mod)
.DISPATCH_TO_VISIT(Min)
.DISPATCH_TO_VISIT(Max)
.DISPATCH_TO_VISIT(EQ)
.DISPATCH_TO_VISIT(NE)
.DISPATCH_TO_VISIT(LT)
.DISPATCH_TO_VISIT(LE)
.DISPATCH_TO_VISIT(GT)
.DISPATCH_TO_VISIT(GE)
.DISPATCH_TO_VISIT(And)
.DISPATCH_TO_VISIT(Or)
.DISPATCH_TO_VISIT(Reduce)
.DISPATCH_TO_VISIT(Cast)
.DISPATCH_TO_VISIT(Not)
.DISPATCH_TO_VISIT(Select)
.DISPATCH_TO_VISIT(Ramp)
.DISPATCH_TO_VISIT(Broadcast)
.DISPATCH_TO_VISIT(AssertStmt)
.DISPATCH_TO_VISIT(ProducerConsumer)
.DISPATCH_TO_VISIT(Provide)
.DISPATCH_TO_VISIT(Realize)
.DISPATCH_TO_VISIT(Block)
.DISPATCH_TO_VISIT(Evaluate)
.DISPATCH_TO_VISIT(IntImm)
.DISPATCH_TO_VISIT(UIntImm)
.DISPATCH_TO_VISIT(FloatImm)
.DISPATCH_TO_VISIT(StringImm);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment