Commit e4b40b53 by Tianqi Chen Committed by GitHub

[PASS] Enhance LiftAttrScope (#632)

* [PASS] Enhance LiftAttrScope

* update vt
parent 182a7852
......@@ -229,7 +229,8 @@ class VTInjector : public IRMutator {
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(s, true);
} else if (!allow_share_ && !vt_loop_injected_ &&
op->attr_key == attr::coproc_uop_scope) {
(op->attr_key == attr::coproc_uop_scope ||
op->attr_key == attr::coproc_scope)) {
return InjectVTLoop(s, true);
} else {
Stmt body = Mutate(op->body);
......
......@@ -7,6 +7,7 @@
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include "./ir_util.h"
namespace tvm {
namespace ir {
......@@ -57,41 +58,16 @@ class AttrScopeLifter : public IRMutator {
}
Stmt Mutate_(const Block* op, const Stmt& s) final {
Stmt first = this->Mutate(op->first);
NodeRef first_node_;
Expr first_value_;
std::swap(first_node_, attr_node_);
std::swap(first_value_, attr_value_);
Stmt rest = this->Mutate(op->rest);
if (attr_node_.defined() &&
attr_value_.defined() &&
first_node_.defined() &&
first_value_.defined() &&
attr_node_.same_as(first_node_) &&
attr_value_.same_as(first_value_)) {
if (first.same_as(op->first) && rest.same_as(op->rest)) {
std::vector<Stmt> seq;
FlattenSeq(op->first, &seq);
FlattenSeq(op->rest, &seq);
seq = MutateSeq(seq);
if (seq.size() == 2 &&
seq[0].same_as(op->first) &&
seq[1].same_as(op->rest)) {
return s;
} else {
return Block::make(first, rest);
}
} else {
if (first_node_.defined()) {
first = AttrStmt::make(
first_node_, attr_key_, first_value_, first);
}
if (attr_node_.defined()) {
rest = AttrStmt::make(
attr_node_, attr_key_, attr_value_, rest);
// undefine them
attr_node_ = NodeRef();
attr_value_ = Expr();
}
if (first.same_as(op->first) && rest.same_as(op->rest)) {
return s;
} else {
return Block::make(first, rest);
}
}
return MergeSeq(seq);
}
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
......@@ -99,17 +75,17 @@ class AttrScopeLifter : public IRMutator {
return IRMutator::Mutate_(op, s);
}
Stmt then_case = this->Mutate(op->then_case);
NodeRef first_node_;
Expr first_value_;
std::swap(first_node_, attr_node_);
std::swap(first_value_, attr_value_);
NodeRef first_node;
Expr first_value;
std::swap(first_node, attr_node_);
std::swap(first_value, attr_value_);
Stmt else_case = this->Mutate(op->else_case);
if (attr_node_.defined() &&
attr_value_.defined() &&
first_node_.defined() &&
first_value_.defined() &&
attr_node_.same_as(first_node_) &&
attr_value_.same_as(first_value_)) {
first_node.defined() &&
first_value.defined() &&
attr_node_.same_as(first_node) &&
ValueSame(attr_value_, first_value)) {
if (then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
......@@ -117,9 +93,9 @@ class AttrScopeLifter : public IRMutator {
return IfThenElse::make(op->condition, then_case, else_case);
}
} else {
if (first_node_.defined()) {
if (first_node.defined()) {
then_case = AttrStmt::make(
first_node_, attr_key_, first_value_, then_case);
first_node, attr_key_, first_value, then_case);
}
if (attr_node_.defined()) {
else_case = AttrStmt::make(
......@@ -138,6 +114,82 @@ class AttrScopeLifter : public IRMutator {
}
private:
void FlattenSeq(Stmt s, std::vector<Stmt>* res) {
if (const Block* op = s.as<Block>()) {
FlattenSeq(op->first, res);
FlattenSeq(op->rest, res);
} else if (const ProducerConsumer* op = s.as<ProducerConsumer>()) {
if (!op->is_producer) {
FlattenSeq(op->body, res);
} else {
res->emplace_back(s);
}
} else {
res->emplace_back(s);
}
}
std::vector<Stmt> MutateSeq(const std::vector<Stmt>& seq) {
std::vector<Stmt> res_seq;
NodeRef curr_node;
Expr curr_value;
Stmt curr_stmt;
for (const Stmt & stmt : seq) {
attr_node_ = NodeRef();
attr_value_ = Expr();
Stmt rest = this->Mutate(stmt);
if (attr_node_.defined() &&
attr_value_.defined() &&
curr_node.defined() &&
curr_value.defined() &&
attr_node_.same_as(curr_node) &&
ValueSame(attr_value_, curr_value)) {
curr_stmt = Block::make(curr_stmt, rest);
} else {
if (curr_stmt.defined()) {
if (curr_node.defined()) {
curr_stmt = AttrStmt::make(
curr_node, attr_key_, curr_value, curr_stmt);
}
res_seq.push_back(curr_stmt);
}
curr_stmt = rest;
curr_node = attr_node_;
curr_value = attr_value_;
}
}
if (curr_stmt.defined()) {
// keep attr_node_, attr_node_
if (res_seq.size() == 0) {
return {curr_stmt};
}
if (curr_node.defined()) {
curr_stmt = AttrStmt::make(
curr_node, attr_key_, curr_value, curr_stmt);
}
res_seq.push_back(curr_stmt);
// reset
attr_node_ = NodeRef();
attr_value_ = Expr();
}
return res_seq;
}
// value comparison that also compares content of int constant
static bool ValueSame(const Expr& a, const Expr& b) {
if (a.same_as(b)) return true;
if (a->type_key() != b->type_key()) return false;
if (a.type() != b.type()) return false;
if (const IntImm* op = a.as<IntImm>()) {
return op->value == b.as<IntImm>()->value;
}
if (const UIntImm* op = a.as<UIntImm>()) {
return op->value == b.as<UIntImm>()->value;
}
return false;
}
std::string attr_key_;
NodeRef attr_node_;
Expr attr_value_;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment