Commit e4b40b53 by Tianqi Chen Committed by GitHub

[PASS] Enhance LiftAttrScope (#632)

* [PASS] Enhance LiftAttrScope

* update vt
parent 182a7852
...@@ -229,7 +229,8 @@ class VTInjector : public IRMutator { ...@@ -229,7 +229,8 @@ class VTInjector : public IRMutator {
if (visit_touched_var_ && !vt_loop_injected_) { if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(s, true); return InjectVTLoop(s, true);
} else if (!allow_share_ && !vt_loop_injected_ && } else if (!allow_share_ && !vt_loop_injected_ &&
op->attr_key == attr::coproc_uop_scope) { (op->attr_key == attr::coproc_uop_scope ||
op->attr_key == attr::coproc_scope)) {
return InjectVTLoop(s, true); return InjectVTLoop(s, true);
} else { } else {
Stmt body = Mutate(op->body); Stmt body = Mutate(op->body);
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
*/ */
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include "./ir_util.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
...@@ -57,41 +58,16 @@ class AttrScopeLifter : public IRMutator { ...@@ -57,41 +58,16 @@ class AttrScopeLifter : public IRMutator {
} }
Stmt Mutate_(const Block* op, const Stmt& s) final { Stmt Mutate_(const Block* op, const Stmt& s) final {
Stmt first = this->Mutate(op->first); std::vector<Stmt> seq;
NodeRef first_node_; FlattenSeq(op->first, &seq);
Expr first_value_; FlattenSeq(op->rest, &seq);
std::swap(first_node_, attr_node_); seq = MutateSeq(seq);
std::swap(first_value_, attr_value_); if (seq.size() == 2 &&
Stmt rest = this->Mutate(op->rest); seq[0].same_as(op->first) &&
if (attr_node_.defined() && seq[1].same_as(op->rest)) {
attr_value_.defined() && return s;
first_node_.defined() &&
first_value_.defined() &&
attr_node_.same_as(first_node_) &&
attr_value_.same_as(first_value_)) {
if (first.same_as(op->first) && rest.same_as(op->rest)) {
return s;
} else {
return Block::make(first, rest);
}
} else {
if (first_node_.defined()) {
first = AttrStmt::make(
first_node_, attr_key_, first_value_, first);
}
if (attr_node_.defined()) {
rest = AttrStmt::make(
attr_node_, attr_key_, attr_value_, rest);
// undefine them
attr_node_ = NodeRef();
attr_value_ = Expr();
}
if (first.same_as(op->first) && rest.same_as(op->rest)) {
return s;
} else {
return Block::make(first, rest);
}
} }
return MergeSeq(seq);
} }
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final { Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
...@@ -99,17 +75,17 @@ class AttrScopeLifter : public IRMutator { ...@@ -99,17 +75,17 @@ class AttrScopeLifter : public IRMutator {
return IRMutator::Mutate_(op, s); return IRMutator::Mutate_(op, s);
} }
Stmt then_case = this->Mutate(op->then_case); Stmt then_case = this->Mutate(op->then_case);
NodeRef first_node_; NodeRef first_node;
Expr first_value_; Expr first_value;
std::swap(first_node_, attr_node_); std::swap(first_node, attr_node_);
std::swap(first_value_, attr_value_); std::swap(first_value, attr_value_);
Stmt else_case = this->Mutate(op->else_case); Stmt else_case = this->Mutate(op->else_case);
if (attr_node_.defined() && if (attr_node_.defined() &&
attr_value_.defined() && attr_value_.defined() &&
first_node_.defined() && first_node.defined() &&
first_value_.defined() && first_value.defined() &&
attr_node_.same_as(first_node_) && attr_node_.same_as(first_node) &&
attr_value_.same_as(first_value_)) { ValueSame(attr_value_, first_value)) {
if (then_case.same_as(op->then_case) && if (then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) { else_case.same_as(op->else_case)) {
return s; return s;
...@@ -117,9 +93,9 @@ class AttrScopeLifter : public IRMutator { ...@@ -117,9 +93,9 @@ class AttrScopeLifter : public IRMutator {
return IfThenElse::make(op->condition, then_case, else_case); return IfThenElse::make(op->condition, then_case, else_case);
} }
} else { } else {
if (first_node_.defined()) { if (first_node.defined()) {
then_case = AttrStmt::make( then_case = AttrStmt::make(
first_node_, attr_key_, first_value_, then_case); first_node, attr_key_, first_value, then_case);
} }
if (attr_node_.defined()) { if (attr_node_.defined()) {
else_case = AttrStmt::make( else_case = AttrStmt::make(
...@@ -138,6 +114,82 @@ class AttrScopeLifter : public IRMutator { ...@@ -138,6 +114,82 @@ class AttrScopeLifter : public IRMutator {
} }
private: private:
void FlattenSeq(Stmt s, std::vector<Stmt>* res) {
if (const Block* op = s.as<Block>()) {
FlattenSeq(op->first, res);
FlattenSeq(op->rest, res);
} else if (const ProducerConsumer* op = s.as<ProducerConsumer>()) {
if (!op->is_producer) {
FlattenSeq(op->body, res);
} else {
res->emplace_back(s);
}
} else {
res->emplace_back(s);
}
}
std::vector<Stmt> MutateSeq(const std::vector<Stmt>& seq) {
std::vector<Stmt> res_seq;
NodeRef curr_node;
Expr curr_value;
Stmt curr_stmt;
for (const Stmt & stmt : seq) {
attr_node_ = NodeRef();
attr_value_ = Expr();
Stmt rest = this->Mutate(stmt);
if (attr_node_.defined() &&
attr_value_.defined() &&
curr_node.defined() &&
curr_value.defined() &&
attr_node_.same_as(curr_node) &&
ValueSame(attr_value_, curr_value)) {
curr_stmt = Block::make(curr_stmt, rest);
} else {
if (curr_stmt.defined()) {
if (curr_node.defined()) {
curr_stmt = AttrStmt::make(
curr_node, attr_key_, curr_value, curr_stmt);
}
res_seq.push_back(curr_stmt);
}
curr_stmt = rest;
curr_node = attr_node_;
curr_value = attr_value_;
}
}
if (curr_stmt.defined()) {
// keep attr_node_, attr_node_
if (res_seq.size() == 0) {
return {curr_stmt};
}
if (curr_node.defined()) {
curr_stmt = AttrStmt::make(
curr_node, attr_key_, curr_value, curr_stmt);
}
res_seq.push_back(curr_stmt);
// reset
attr_node_ = NodeRef();
attr_value_ = Expr();
}
return res_seq;
}
// value comparison that also compares content of int constant
static bool ValueSame(const Expr& a, const Expr& b) {
if (a.same_as(b)) return true;
if (a->type_key() != b->type_key()) return false;
if (a.type() != b.type()) return false;
if (const IntImm* op = a.as<IntImm>()) {
return op->value == b.as<IntImm>()->value;
}
if (const UIntImm* op = a.as<UIntImm>()) {
return op->value == b.as<UIntImm>()->value;
}
return false;
}
std::string attr_key_; std::string attr_key_;
NodeRef attr_node_; NodeRef attr_node_;
Expr attr_value_; Expr attr_value_;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment