Commit a4ea0f4b by Siyuan Feng Committed by Tianqi Chen

[IR] fix style in ir_mutator and ir_visitor (#4561)

parent 44cb1054
...@@ -45,7 +45,7 @@ class IRTransformer final : public IRMutator { ...@@ -45,7 +45,7 @@ class IRTransformer final : public IRMutator {
} }
private: private:
template<typename T> template <typename T>
T MutateInternal(T node) { T MutateInternal(T node) {
if (only_enable_.size() && if (only_enable_.size() &&
!only_enable_.count(node->type_index())) { !only_enable_.count(node->type_index())) {
...@@ -89,11 +89,11 @@ IRMutator::FMutateStmt& IRMutator::vtable_stmt() { // NOLINT(*) ...@@ -89,11 +89,11 @@ IRMutator::FMutateStmt& IRMutator::vtable_stmt() { // NOLINT(*)
static FMutateStmt inst; return inst; static FMutateStmt inst; return inst;
} }
inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) { inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator* m) {
return UpdateArray(arr, [&m] (const Expr& e) { return m->Mutate(e); }); return UpdateArray(arr, [&m](const Expr& e) { return m->Mutate(e); });
} }
inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) { inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator* m) {
std::vector<IterVar> new_dom(rdom.size()); std::vector<IterVar> new_dom(rdom.size());
bool changed = false; bool changed = false;
for (size_t i = 0; i < rdom.size(); i++) { for (size_t i = 0; i < rdom.size(); i++) {
...@@ -133,7 +133,7 @@ Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) { ...@@ -133,7 +133,7 @@ Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) {
} }
} }
Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) { Stmt IRMutator::Mutate_(const LetStmt* op, const Stmt& s) {
Expr value = this->Mutate(op->value); Expr value = this->Mutate(op->value);
Stmt body = this->Mutate(op->body); Stmt body = this->Mutate(op->body);
if (value.same_as(op->value) && if (value.same_as(op->value) &&
...@@ -144,7 +144,7 @@ Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) { ...@@ -144,7 +144,7 @@ Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) {
} }
} }
Stmt IRMutator::Mutate_(const For *op, const Stmt& s) { Stmt IRMutator::Mutate_(const For* op, const Stmt& s) {
Expr min = this->Mutate(op->min); Expr min = this->Mutate(op->min);
Expr extent = this->Mutate(op->extent); Expr extent = this->Mutate(op->extent);
Stmt body = this->Mutate(op->body); Stmt body = this->Mutate(op->body);
...@@ -185,7 +185,7 @@ Stmt IRMutator::Mutate_(const Allocate* op, const Stmt& s) { ...@@ -185,7 +185,7 @@ Stmt IRMutator::Mutate_(const Allocate* op, const Stmt& s) {
} }
} }
Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) { Stmt IRMutator::Mutate_(const IfThenElse* op, const Stmt& s) {
Expr condition = this->Mutate(op->condition); Expr condition = this->Mutate(op->condition);
Stmt then_case = this->Mutate(op->then_case); Stmt then_case = this->Mutate(op->then_case);
Stmt else_case; Stmt else_case;
...@@ -201,7 +201,7 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) { ...@@ -201,7 +201,7 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
} }
} }
Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) { Stmt IRMutator::Mutate_(const Store* op, const Stmt& s) {
Expr value = this->Mutate(op->value); Expr value = this->Mutate(op->value);
Expr index = this->Mutate(op->index); Expr index = this->Mutate(op->index);
Expr pred = this->Mutate(op->predicate); Expr pred = this->Mutate(op->predicate);
...@@ -233,7 +233,7 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) { ...@@ -233,7 +233,7 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
Expr old_extent = op->bounds[i]->extent; Expr old_extent = op->bounds[i]->extent;
Expr new_min = m->Mutate(old_min); Expr new_min = m->Mutate(old_min);
Expr new_extent = m->Mutate(old_extent); Expr new_extent = m->Mutate(old_extent);
if (!new_min.same_as(old_min)) bounds_changed = true; if (!new_min.same_as(old_min)) bounds_changed = true;
if (!new_extent.same_as(old_extent)) bounds_changed = true; if (!new_extent.same_as(old_extent)) bounds_changed = true;
new_bounds.push_back( new_bounds.push_back(
Range::make_by_min_extent(new_min, new_extent)); Range::make_by_min_extent(new_min, new_extent));
...@@ -263,7 +263,7 @@ Stmt IRMutator::Mutate_(const Prefetch* op, const Stmt& s) { ...@@ -263,7 +263,7 @@ Stmt IRMutator::Mutate_(const Prefetch* op, const Stmt& s) {
Expr old_extent = op->bounds[i]->extent; Expr old_extent = op->bounds[i]->extent;
Expr new_min = m->Mutate(old_min); Expr new_min = m->Mutate(old_min);
Expr new_extent = m->Mutate(old_extent); Expr new_extent = m->Mutate(old_extent);
if (!new_min.same_as(old_min)) bounds_changed = true; if (!new_min.same_as(old_min)) bounds_changed = true;
if (!new_extent.same_as(old_extent)) bounds_changed = true; if (!new_extent.same_as(old_extent)) bounds_changed = true;
new_bounds.push_back( new_bounds.push_back(
Range::make_by_min_extent(new_min, new_extent)); Range::make_by_min_extent(new_min, new_extent));
...@@ -288,7 +288,7 @@ Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) { ...@@ -288,7 +288,7 @@ Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) {
} }
} }
Stmt IRMutator::Mutate_(const AssertStmt *op, const Stmt& s) { Stmt IRMutator::Mutate_(const AssertStmt* op, const Stmt& s) {
Expr condition = this->Mutate(op->condition); Expr condition = this->Mutate(op->condition);
Expr message = this->Mutate(op->message); Expr message = this->Mutate(op->message);
Stmt body = this->Mutate(op->body); Stmt body = this->Mutate(op->body);
...@@ -302,7 +302,7 @@ Stmt IRMutator::Mutate_(const AssertStmt *op, const Stmt& s) { ...@@ -302,7 +302,7 @@ Stmt IRMutator::Mutate_(const AssertStmt *op, const Stmt& s) {
} }
} }
Stmt IRMutator::Mutate_(const ProducerConsumer *op, const Stmt& s) { Stmt IRMutator::Mutate_(const ProducerConsumer* op, const Stmt& s) {
Stmt body = this->Mutate(op->body); Stmt body = this->Mutate(op->body);
if (body.same_as(op->body)) { if (body.same_as(op->body)) {
return s; return s;
...@@ -311,7 +311,7 @@ Stmt IRMutator::Mutate_(const ProducerConsumer *op, const Stmt& s) { ...@@ -311,7 +311,7 @@ Stmt IRMutator::Mutate_(const ProducerConsumer *op, const Stmt& s) {
} }
} }
Stmt IRMutator::Mutate_(const Evaluate *op, const Stmt& s) { Stmt IRMutator::Mutate_(const Evaluate* op, const Stmt& s) {
Expr v = this->Mutate(op->value); Expr v = this->Mutate(op->value);
if (v.same_as(op->value)) { if (v.same_as(op->value)) {
return s; return s;
...@@ -320,7 +320,7 @@ Stmt IRMutator::Mutate_(const Evaluate *op, const Stmt& s) { ...@@ -320,7 +320,7 @@ Stmt IRMutator::Mutate_(const Evaluate *op, const Stmt& s) {
} }
} }
Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) { Stmt IRMutator::Mutate_(const Free* op, const Stmt& s) {
return s; return s;
} }
...@@ -348,11 +348,11 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -348,11 +348,11 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
return m->Mutate_(static_cast<const OP*>(node.get()), e); \ return m->Mutate_(static_cast<const OP*>(node.get()), e); \
}) })
Expr IRMutator::Mutate_(const Variable *op, const Expr& e) { Expr IRMutator::Mutate_(const Variable* op, const Expr& e) {
return e; return e;
} }
Expr IRMutator::Mutate_(const Load *op, const Expr& e) { Expr IRMutator::Mutate_(const Load* op, const Expr& e) {
Expr index = this->Mutate(op->index); Expr index = this->Mutate(op->index);
Expr pred = this->Mutate(op->predicate); Expr pred = this->Mutate(op->predicate);
if (index.same_as(op->index) && pred.same_as(op->predicate)) { if (index.same_as(op->index) && pred.same_as(op->predicate)) {
...@@ -362,7 +362,7 @@ Expr IRMutator::Mutate_(const Load *op, const Expr& e) { ...@@ -362,7 +362,7 @@ Expr IRMutator::Mutate_(const Load *op, const Expr& e) {
} }
} }
Expr IRMutator::Mutate_(const Let *op, const Expr& e) { Expr IRMutator::Mutate_(const Let* op, const Expr& e) {
Expr value = this->Mutate(op->value); Expr value = this->Mutate(op->value);
Expr body = this->Mutate(op->body); Expr body = this->Mutate(op->body);
if (value.same_as(op->value) && if (value.same_as(op->value) &&
...@@ -413,8 +413,8 @@ DEFINE_BIOP_EXPR_MUTATE_(GE) ...@@ -413,8 +413,8 @@ DEFINE_BIOP_EXPR_MUTATE_(GE)
DEFINE_BIOP_EXPR_MUTATE_(And) DEFINE_BIOP_EXPR_MUTATE_(And)
DEFINE_BIOP_EXPR_MUTATE_(Or) DEFINE_BIOP_EXPR_MUTATE_(Or)
Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) { Expr IRMutator::Mutate_(const Reduce* op, const Expr& e) {
Array<IterVar> new_axis = MutateIterVarArr(op->axis, this); Array<IterVar> new_axis = MutateIterVarArr(op->axis, this);
Array<Expr> new_source = MutateArray(op->source, this); Array<Expr> new_source = MutateArray(op->source, this);
Expr new_cond = this->Mutate(op->condition); Expr new_cond = this->Mutate(op->condition);
if (op->axis.same_as(new_axis) && if (op->axis.same_as(new_axis) &&
...@@ -427,7 +427,7 @@ Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) { ...@@ -427,7 +427,7 @@ Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) {
} }
} }
Expr IRMutator::Mutate_(const Cast *op, const Expr& e) { Expr IRMutator::Mutate_(const Cast* op, const Expr& e) {
Expr value = this->Mutate(op->value); Expr value = this->Mutate(op->value);
if (value.same_as(op->value)) { if (value.same_as(op->value)) {
return e; return e;
...@@ -436,7 +436,7 @@ Expr IRMutator::Mutate_(const Cast *op, const Expr& e) { ...@@ -436,7 +436,7 @@ Expr IRMutator::Mutate_(const Cast *op, const Expr& e) {
} }
} }
Expr IRMutator::Mutate_(const Not *op, const Expr& e) { Expr IRMutator::Mutate_(const Not* op, const Expr& e) {
Expr a = this->Mutate(op->a); Expr a = this->Mutate(op->a);
if (a.same_as(op->a)) { if (a.same_as(op->a)) {
return e; return e;
...@@ -445,7 +445,7 @@ Expr IRMutator::Mutate_(const Not *op, const Expr& e) { ...@@ -445,7 +445,7 @@ Expr IRMutator::Mutate_(const Not *op, const Expr& e) {
} }
} }
Expr IRMutator::Mutate_(const Select *op, const Expr& e) { Expr IRMutator::Mutate_(const Select* op, const Expr& e) {
Expr cond = this->Mutate(op->condition); Expr cond = this->Mutate(op->condition);
Expr t = this->Mutate(op->true_value); Expr t = this->Mutate(op->true_value);
Expr f = this->Mutate(op->false_value); Expr f = this->Mutate(op->false_value);
...@@ -458,7 +458,7 @@ Expr IRMutator::Mutate_(const Select *op, const Expr& e) { ...@@ -458,7 +458,7 @@ Expr IRMutator::Mutate_(const Select *op, const Expr& e) {
} }
} }
Expr IRMutator::Mutate_(const Ramp *op, const Expr& e) { Expr IRMutator::Mutate_(const Ramp* op, const Expr& e) {
Expr base = this->Mutate(op->base); Expr base = this->Mutate(op->base);
Expr stride = this->Mutate(op->stride); Expr stride = this->Mutate(op->stride);
if (base.same_as(op->base) && if (base.same_as(op->base) &&
...@@ -469,7 +469,7 @@ Expr IRMutator::Mutate_(const Ramp *op, const Expr& e) { ...@@ -469,7 +469,7 @@ Expr IRMutator::Mutate_(const Ramp *op, const Expr& e) {
} }
} }
Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) { Expr IRMutator::Mutate_(const Broadcast* op, const Expr& e) {
Expr value = this->Mutate(op->value); Expr value = this->Mutate(op->value);
if (value.same_as(op->value)) { if (value.same_as(op->value)) {
return e; return e;
...@@ -478,7 +478,7 @@ Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) { ...@@ -478,7 +478,7 @@ Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) {
} }
} }
Expr IRMutator::Mutate_(const Shuffle *op, const Expr& e) { Expr IRMutator::Mutate_(const Shuffle* op, const Expr& e) {
auto new_vec = MutateArray(op->vectors, this); auto new_vec = MutateArray(op->vectors, this);
if (new_vec.same_as(op->vectors)) { if (new_vec.same_as(op->vectors)) {
return e; return e;
......
...@@ -43,7 +43,6 @@ class IRApplyVisit : public IRVisitor { ...@@ -43,7 +43,6 @@ class IRApplyVisit : public IRVisitor {
std::unordered_set<const Node*> visited_; std::unordered_set<const Node*> visited_;
}; };
void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit) { void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit) {
IRApplyVisit(fvisit).Visit(node); IRApplyVisit(fvisit).Visit(node);
} }
...@@ -68,7 +67,7 @@ inline void VisitRDom(const Array<IterVar>& rdom, IRVisitor* v) { ...@@ -68,7 +67,7 @@ inline void VisitRDom(const Array<IterVar>& rdom, IRVisitor* v) {
void IRVisitor::Visit_(const Variable* op) {} void IRVisitor::Visit_(const Variable* op) {}
void IRVisitor::Visit_(const LetStmt *op) { void IRVisitor::Visit_(const LetStmt* op) {
this->Visit(op->value); this->Visit(op->value);
this->Visit(op->body); this->Visit(op->body);
} }
...@@ -78,14 +77,14 @@ void IRVisitor::Visit_(const AttrStmt* op) { ...@@ -78,14 +77,14 @@ void IRVisitor::Visit_(const AttrStmt* op) {
this->Visit(op->body); this->Visit(op->body);
} }
void IRVisitor::Visit_(const For *op) { void IRVisitor::Visit_(const For* op) {
IRVisitor* v = this; IRVisitor* v = this;
v->Visit(op->min); v->Visit(op->min);
v->Visit(op->extent); v->Visit(op->extent);
v->Visit(op->body); v->Visit(op->body);
} }
void IRVisitor::Visit_(const Allocate *op) { void IRVisitor::Visit_(const Allocate* op) {
IRVisitor* v = this; IRVisitor* v = this;
for (size_t i = 0; i < op->extents.size(); i++) { for (size_t i = 0; i < op->extents.size(); i++) {
v->Visit(op->extents[i]); v->Visit(op->extents[i]);
...@@ -97,18 +96,18 @@ void IRVisitor::Visit_(const Allocate *op) { ...@@ -97,18 +96,18 @@ void IRVisitor::Visit_(const Allocate *op) {
} }
} }
void IRVisitor::Visit_(const Load *op) { void IRVisitor::Visit_(const Load* op) {
this->Visit(op->index); this->Visit(op->index);
this->Visit(op->predicate); this->Visit(op->predicate);
} }
void IRVisitor::Visit_(const Store *op) { void IRVisitor::Visit_(const Store* op) {
this->Visit(op->value); this->Visit(op->value);
this->Visit(op->index); this->Visit(op->index);
this->Visit(op->predicate); this->Visit(op->predicate);
} }
void IRVisitor::Visit_(const IfThenElse *op) { void IRVisitor::Visit_(const IfThenElse* op) {
this->Visit(op->condition); this->Visit(op->condition);
this->Visit(op->then_case); this->Visit(op->then_case);
if (op->else_case.defined()) { if (op->else_case.defined()) {
...@@ -116,14 +115,14 @@ void IRVisitor::Visit_(const IfThenElse *op) { ...@@ -116,14 +115,14 @@ void IRVisitor::Visit_(const IfThenElse *op) {
} }
} }
void IRVisitor::Visit_(const Let *op) { void IRVisitor::Visit_(const Let* op) {
this->Visit(op->value); this->Visit(op->value);
this->Visit(op->body); this->Visit(op->body);
} }
void IRVisitor::Visit_(const Free* op) {} void IRVisitor::Visit_(const Free* op) {}
void IRVisitor::Visit_(const Call *op) { void IRVisitor::Visit_(const Call* op) {
VisitArray(op->args, this); VisitArray(op->args, this);
} }
...@@ -171,38 +170,38 @@ void IRVisitor::Visit_(const Select* op) { ...@@ -171,38 +170,38 @@ void IRVisitor::Visit_(const Select* op) {
this->Visit(op->false_value); this->Visit(op->false_value);
} }
void IRVisitor::Visit_(const Ramp *op) { void IRVisitor::Visit_(const Ramp* op) {
this->Visit(op->base); this->Visit(op->base);
this->Visit(op->stride); this->Visit(op->stride);
} }
void IRVisitor::Visit_(const Shuffle *op) { void IRVisitor::Visit_(const Shuffle* op) {
for (const auto &elem : op->indices) for (const auto& elem : op->indices)
this->Visit(elem); this->Visit(elem);
for (const auto &elem : op->vectors) for (const auto& elem : op->vectors)
this->Visit(elem); this->Visit(elem);
} }
void IRVisitor::Visit_(const Broadcast *op) { void IRVisitor::Visit_(const Broadcast* op) {
this->Visit(op->value); this->Visit(op->value);
} }
void IRVisitor::Visit_(const AssertStmt *op) { void IRVisitor::Visit_(const AssertStmt* op) {
this->Visit(op->condition); this->Visit(op->condition);
this->Visit(op->message); this->Visit(op->message);
this->Visit(op->body); this->Visit(op->body);
} }
void IRVisitor::Visit_(const ProducerConsumer *op) { void IRVisitor::Visit_(const ProducerConsumer* op) {
this->Visit(op->body); this->Visit(op->body);
} }
void IRVisitor::Visit_(const Provide *op) { void IRVisitor::Visit_(const Provide* op) {
VisitArray(op->args, this); VisitArray(op->args, this);
this->Visit(op->value); this->Visit(op->value);
} }
void IRVisitor::Visit_(const Realize *op) { void IRVisitor::Visit_(const Realize* op) {
for (size_t i = 0; i < op->bounds.size(); i++) { for (size_t i = 0; i < op->bounds.size(); i++) {
this->Visit(op->bounds[i]->min); this->Visit(op->bounds[i]->min);
this->Visit(op->bounds[i]->extent); this->Visit(op->bounds[i]->extent);
...@@ -212,19 +211,19 @@ void IRVisitor::Visit_(const Realize *op) { ...@@ -212,19 +211,19 @@ void IRVisitor::Visit_(const Realize *op) {
this->Visit(op->condition); this->Visit(op->condition);
} }
void IRVisitor::Visit_(const Prefetch *op) { void IRVisitor::Visit_(const Prefetch* op) {
for (size_t i = 0; i < op->bounds.size(); i++) { for (size_t i = 0; i < op->bounds.size(); i++) {
this->Visit(op->bounds[i]->min); this->Visit(op->bounds[i]->min);
this->Visit(op->bounds[i]->extent); this->Visit(op->bounds[i]->extent);
} }
} }
void IRVisitor::Visit_(const Block *op) { void IRVisitor::Visit_(const Block* op) {
this->Visit(op->first); this->Visit(op->first);
this->Visit(op->rest); this->Visit(op->rest);
} }
void IRVisitor::Visit_(const Evaluate *op) { void IRVisitor::Visit_(const Evaluate* op) {
this->Visit(op->value); this->Visit(op->value);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment