Commit f1aabedc by Tianqi Chen Committed by GitHub

[PASS] Update coproc sync (#634)

parent 32b0fff2
...@@ -201,7 +201,8 @@ def lower(sch, ...@@ -201,7 +201,8 @@ def lower(sch,
add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else [] add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0] lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1] lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
lower_phase2 = [x[1] for x in add_lower_pass if x[0] > 1] lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]
# normalize schedule first # normalize schedule first
sch = sch.normalize() sch = sch.normalize()
# Phase 0 # Phase 0
...@@ -213,6 +214,9 @@ def lower(sch, ...@@ -213,6 +214,9 @@ def lower(sch,
# Phase 1 # Phase 1
stmt = ir_pass.StorageFlatten(stmt, binds, 64) stmt = ir_pass.StorageFlatten(stmt, binds, 64)
stmt = ir_pass.CanonicalSimplify(stmt) stmt = ir_pass.CanonicalSimplify(stmt)
for f in lower_phase1:
stmt = f(stmt)
# Phase 2
if not simple_mode: if not simple_mode:
stmt = ir_pass.LoopPartition(stmt) stmt = ir_pass.LoopPartition(stmt)
stmt = ir_pass.VectorizeLoop(stmt) stmt = ir_pass.VectorizeLoop(stmt)
...@@ -224,14 +228,14 @@ def lower(sch, ...@@ -224,14 +228,14 @@ def lower(sch,
cfg.auto_unroll_max_step, cfg.auto_unroll_max_step,
cfg.auto_unroll_max_depth, cfg.auto_unroll_max_depth,
cfg.unroll_explicit) cfg.unroll_explicit)
for f in lower_phase1: for f in lower_phase2:
stmt = f(stmt) stmt = f(stmt)
# Phase 2 # Phase 2
stmt = ir_pass.Simplify(stmt) stmt = ir_pass.Simplify(stmt)
stmt = ir_pass.LowerStorageAccessInfo(stmt) stmt = ir_pass.LowerStorageAccessInfo(stmt)
stmt = ir_pass.RemoveNoOp(stmt) stmt = ir_pass.RemoveNoOp(stmt)
stmt = ir_pass.RewriteUnsafeSelect(stmt) stmt = ir_pass.RewriteUnsafeSelect(stmt)
for f in lower_phase2: for f in lower_phase3:
stmt = f(stmt) stmt = f(stmt)
if simple_mode: if simple_mode:
return stmt return stmt
......
...@@ -338,6 +338,256 @@ class CoProcBarrierDetector : public StorageAccessVisitor { ...@@ -338,6 +338,256 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
}; };
// Visitor that plans dependency push/pop synchronization between
// coproc_scope regions of a single coprocessor axis.
//
// Every AttrStmt with key attr::coproc_scope attached to coproc_axis_
// carries an integer context id. Whenever control may flow from a region
// with context `from` to a region with a different context `to`, a call to
// "<coproc_name>.coproc_dep_push"(from, to) is scheduled after the former and
// a call to "<coproc_name>.coproc_dep_pop"(from, to) before the latter.
//
// The visitor does not rewrite the statement; the planned calls are
// collected in insert_before_ / insert_after_, keyed by the node they
// attach to.
class CoProcInstDepDetector : public IRVisitor {
 public:
  // coproc_axis: axis whose coproc_scope attributes are tracked.
  // coproc_name: prefix used to build the push/pop intrinsic names.
  explicit CoProcInstDepDetector(
      const IterVar& coproc_axis,
      const std::string& coproc_name)
      : coproc_axis_(coproc_axis) {
    sync_push_name_ = coproc_name + ".coproc_dep_push";
    sync_pop_name_ = coproc_name + ".coproc_dep_pop";
  }

  // Visit stmt and fill insert_before_ / insert_after_.
  // Pops still pending at the entry of the first state and pushes still
  // pending at the exit of the last state get a matching counterpart so
  // that every push is paired with a pop.
  void Plan(Stmt stmt) {
    this->Visit(stmt);
    if (last_state_.node != nullptr) {
      MatchFixEnterPop(first_state_);
      MatchFixExitPush(last_state_);
    }
  }

  // A coproc_scope on the tracked axis forms a leaf state whose enter and
  // exit context are both the scope's context id. The scope body is not
  // visited further. Any other attribute is traversed normally.
  void Visit_(const AttrStmt* op) final {
    if (op->attr_key == attr::coproc_scope &&
        op->node.same_as(coproc_axis_)) {
      const IntImm* ctx_id = op->value.as<IntImm>();
      CHECK(ctx_id != nullptr);
      curr_state_.clear();
      curr_state_.node = op->body.get();
      curr_state_.enter_ctx.insert(ctx_id->value);
      curr_state_.exit_ctx.insert(ctx_id->value);
      UpdateState();
    } else {
      IRVisitor::Visit_(op);
    }
  }

  // Summarize a loop as a single state of the enclosing sequence.
  void Visit_(const For* op) final {
    // Analyze the body with a fresh first/last pair; the enclosing
    // sequence's pair is parked in the temporaries and restored below.
    SyncState temp_first, temp_last;
    std::swap(first_state_, temp_first);
    std::swap(last_state_, temp_last);
    this->Visit(op->body);
    curr_state_.clear();
    if (last_state_.node != nullptr) {
      curr_state_.node = op;
      CHECK(first_state_.node != nullptr);
      // loop carry dependency: the last state of one iteration is followed
      // by the first state of the next iteration.
      InjectSync(last_state_, first_state_,
                 &(curr_state_.exit_push),
                 &(curr_state_.enter_pop));
      curr_state_.enter_ctx = first_state_.enter_ctx;
      // The loop is left through the exit of its last state, so the loop's
      // exit contexts are that state's exit contexts (not its enter
      // contexts); this matches how IfThenElse summarizes its branches.
      curr_state_.exit_ctx = last_state_.exit_ctx;
    }
    std::swap(first_state_, temp_first);
    std::swap(last_state_, temp_last);
    if (curr_state_.node != nullptr) {
      UpdateState();
    }
  }

  // Summarize a conditional as a single state of the enclosing sequence.
  // Each branch is analyzed on its own; its pending enter pops / exit
  // pushes are matched locally, and the branch's enter/exit contexts are
  // merged into the state of the conditional. The condition expression is
  // not visited.
  void Visit_(const IfThenElse* op) final {
    SyncState temp_first, temp_last, curr_state;
    std::swap(first_state_, temp_first);
    std::swap(last_state_, temp_last);
    {
      // then stmt
      this->Visit(op->then_case);
      if (last_state_.node != nullptr) {
        curr_state.node = op;
        MatchFixEnterPop(first_state_);
        MatchFixExitPush(last_state_);
        curr_state.enter_ctx.insert(
            first_state_.enter_ctx.begin(),
            first_state_.enter_ctx.end());
        curr_state.exit_ctx.insert(
            last_state_.exit_ctx.begin(),
            last_state_.exit_ctx.end());
      }
      // start the else branch from a clean slate.
      first_state_.clear();
      last_state_.clear();
    }
    if (op->else_case.defined()) {
      this->Visit(op->else_case);
      if (last_state_.node != nullptr) {
        curr_state.node = op;
        MatchFixEnterPop(first_state_);
        MatchFixExitPush(last_state_);
        curr_state.enter_ctx.insert(
            first_state_.enter_ctx.begin(),
            first_state_.enter_ctx.end());
        curr_state.exit_ctx.insert(
            last_state_.exit_ctx.begin(),
            last_state_.exit_ctx.end());
      }
    }
    // update in the trace.
    std::swap(first_state_, temp_first);
    std::swap(last_state_, temp_last);
    std::swap(curr_state_, curr_state);
    if (curr_state_.node != nullptr) {
      UpdateState();
    }
  }

  // insert before is stored in reverse order
  // the first element is closest to the node.
  std::unordered_map<const Node*, std::vector<Stmt> > insert_before_;
  // statements to be inserted after the node.
  std::unordered_map<const Node*, std::vector<Stmt> > insert_after_;

 private:
  // state in the sync entry
  struct SyncState {
    // The statement of the state.
    const Node* node{nullptr};
    // Set of all possible contexts in the entering moment.
    std::unordered_set<int> enter_ctx;
    // Set of all possible contexts in the exit moment.
    std::unordered_set<int> exit_ctx;
    // existing pop performed at enter, as (from, to) pairs
    std::vector<std::pair<int, int> > enter_pop;
    // existing push performed at exit, as (from, to) pairs
    std::vector<std::pair<int, int> > exit_push;
    // clear the state
    void clear() {
      node = nullptr;
      enter_ctx.clear();
      exit_ctx.clear();
      enter_pop.clear();
      exit_push.clear();
    }
  };

  // Inject proper sync between the pair (prev, next), where next may run
  // right after prev.
  // Records the push/pop sequence that could possibly be un-matched:
  // on return, *prev_exit_push holds the pushes performed at the exit of
  // prev and *next_enter_pop the pops performed at the enter of next,
  // after considering the existing unmatched events and the added events.
  // Both output vectors are cleared first.
  void InjectSync(const SyncState& prev,
                  const SyncState& next,
                  std::vector<std::pair<int, int> >* prev_exit_push,
                  std::vector<std::pair<int, int> >* next_enter_pop) {
    prev_exit_push->clear();
    next_enter_pop->clear();
    // quick path: nothing pending and exactly one context on each side.
    if (prev.exit_push.size() == 0 && next.enter_pop.size() == 0 &&
        prev.exit_ctx.size() == 1 && next.enter_ctx.size() == 1) {
      int from = *prev.exit_ctx.begin();
      int to = *next.enter_ctx.begin();
      if (from != to) {
        insert_after_[prev.node].emplace_back(MakePush(from, to));
        insert_before_[next.node].emplace_back(MakePop(from, to));
        prev_exit_push->emplace_back(std::make_pair(from, to));
        next_enter_pop->emplace_back(std::make_pair(from, to));
      }
      return;
    }
    // complicated path.
    std::vector<std::pair<int, int> > vpush = prev.exit_push;
    std::vector<std::pair<int, int> > vpop = next.enter_pop;
    // every cross-context transition that needs a push/pop pair.
    std::vector<std::pair<int, int> > pending;
    for (int from : prev.exit_ctx) {
      for (int to : next.enter_ctx) {
        if (from != to) {
          pending.emplace_back(std::make_pair(from, to));
        }
      }
    }
    // policy 1: add the push / pop of a pending pair unless the
    // corresponding side already performs it.
    std::vector<Stmt> prev_after, next_before;
    for (const std::pair<int, int>& p : pending) {
      if (std::find(prev.exit_push.begin(),
                    prev.exit_push.end(), p) ==
          prev.exit_push.end()) {
        vpush.push_back(p);
        prev_after.emplace_back(MakePush(p.first, p.second));
      }
      if (std::find(next.enter_pop.begin(),
                    next.enter_pop.end(), p) ==
          next.enter_pop.end()) {
        vpop.push_back(p);
        next_before.emplace_back(MakePop(p.first, p.second));
      }
    }
    // fix pending: a push without a pop on the other side is popped right
    // after prev; a pop without a push is preceded by a push before next.
    // Matched events are reported back to the caller.
    for (const std::pair<int, int>& p : vpush) {
      if (std::find(vpop.begin(), vpop.end(), p) == vpop.end()) {
        prev_after.emplace_back(MakePop(p.first, p.second));
      } else {
        prev_exit_push->push_back(p);
      }
    }
    for (const std::pair<int, int>& p : vpop) {
      if (std::find(vpush.begin(), vpush.end(), p) == vpush.end()) {
        next_before.emplace_back(MakePush(p.first, p.second));
      } else {
        next_enter_pop->push_back(p);
      }
    }
    if (prev_after.size() != 0) {
      auto &v1 = insert_after_[prev.node];
      v1.insert(v1.end(), prev_after.begin(), prev_after.end());
    }
    if (next_before.size() != 0) {
      auto &v2 = insert_before_[next.node];
      v2.insert(v2.end(), next_before.begin(), next_before.end());
    }
  }

  // Schedule a push before state.node for every pop that state performs
  // at enter.
  void MatchFixEnterPop(const SyncState& state) {
    if (state.enter_pop.size() == 0) return;
    auto &vec = insert_before_[state.node];
    for (const std::pair<int, int>& p : state.enter_pop) {
      vec.push_back(MakePush(p.first, p.second));
    }
  }

  // Schedule a pop after state.node for every push that state performs
  // at exit.
  void MatchFixExitPush(const SyncState& state) {
    if (state.exit_push.size() == 0) return;
    auto &vec = insert_after_[state.node];
    for (const std::pair<int, int>& p : state.exit_push) {
      vec.push_back(MakePop(p.first, p.second));
    }
  }

  // Append curr_state_ to the current sequence of states.
  // If a previous state exists, sync is injected between it and
  // curr_state_, which then becomes the last state (curr_state_ receives
  // the old last state through the swap). Otherwise curr_state_ becomes
  // both the first and the last state.
  void UpdateState() {
    if (last_state_.node != nullptr) {
      // the reported pending events are not needed here.
      std::vector<std::pair<int, int> > t1, t2;
      InjectSync(last_state_, curr_state_, &t1, &t2);
      std::swap(last_state_, curr_state_);
    } else {
      CHECK(first_state_.node == nullptr);
      first_state_ = curr_state_;
      last_state_ = curr_state_;
    }
  }

  // Build an Evaluate of the push intrinsic call with arguments (from, to).
  Stmt MakePush(int from, int to) {
    return Evaluate::make(Call::make(
        Int(32), sync_push_name_,
        {make_const(Int(32), from), make_const(Int(32), to)},
        Call::Intrinsic));
  }

  // Build an Evaluate of the pop intrinsic call with arguments (from, to).
  Stmt MakePop(int from, int to) {
    return Evaluate::make(Call::make(
        Int(32), sync_pop_name_,
        {make_const(Int(32), from), make_const(Int(32), to)},
        Call::Intrinsic));
  }

  // sync states: first and last state of the sequence currently being
  // analyzed, and the state that is about to be appended to it.
  SyncState first_state_, last_state_, curr_state_;
  // Variables
  IterVar coproc_axis_;
  // names of the push / pop intrinsics.
  std::string sync_push_name_, sync_pop_name_;
};
class CoProcSyncInserter : public IRMutator { class CoProcSyncInserter : public IRMutator {
public: public:
Stmt Insert(Stmt stmt) { Stmt Insert(Stmt stmt) {
...@@ -372,6 +622,18 @@ class CoProcSyncInserter : public IRMutator { ...@@ -372,6 +622,18 @@ class CoProcSyncInserter : public IRMutator {
auto& vec = insert_after_[kv.first]; auto& vec = insert_after_[kv.first];
vec.insert(vec.end(), kv.second.begin(), kv.second.end()); vec.insert(vec.end(), kv.second.begin(), kv.second.end());
} }
// Detect barrier
CoProcInstDepDetector sync_detector(
*visitor.coproc_.begin(), coproc_name);
sync_detector.Plan(stmt);
for (const auto& kv : sync_detector.insert_before_) {
auto& vec = insert_before_[kv.first];
vec.insert(vec.end(), kv.second.begin(), kv.second.end());
}
for (const auto& kv : sync_detector.insert_after_) {
auto& vec = insert_after_[kv.first];
vec.insert(vec.end(), kv.second.begin(), kv.second.end());
}
return Mutate(stmt); return Mutate(stmt);
} }
...@@ -379,7 +641,8 @@ class CoProcSyncInserter : public IRMutator { ...@@ -379,7 +641,8 @@ class CoProcSyncInserter : public IRMutator {
Stmt before, after; Stmt before, after;
auto it = insert_before_.find(stmt.get()); auto it = insert_before_.find(stmt.get());
if (it != insert_before_.end()) { if (it != insert_before_.end()) {
before = MergeSeq(it->second); before = MergeSeq(std::vector<Stmt>(
it->second.rbegin(), it->second.rend()));
} }
it = insert_after_.find(stmt.get()); it = insert_after_.find(stmt.get());
if (it != insert_after_.end()) { if (it != insert_after_.end()) {
...@@ -396,10 +659,13 @@ class CoProcSyncInserter : public IRMutator { ...@@ -396,10 +659,13 @@ class CoProcSyncInserter : public IRMutator {
} }
private: private:
// insert before is stored in reverse order
// the first element is closest to the node.
std::unordered_map<const Node*, std::vector<Stmt> > insert_before_; std::unordered_map<const Node*, std::vector<Stmt> > insert_before_;
std::unordered_map<const Node*, std::vector<Stmt> > insert_after_; std::unordered_map<const Node*, std::vector<Stmt> > insert_after_;
}; };
Stmt CoProcSync(Stmt stmt) { Stmt CoProcSync(Stmt stmt) {
return CoProcSyncInserter().Insert(stmt); return CoProcSyncInserter().Insert(stmt);
} }
......
...@@ -189,7 +189,7 @@ class StoragePlanRewriter : public IRMutator { ...@@ -189,7 +189,7 @@ class StoragePlanRewriter : public IRMutator {
if (attach_map_.count(nullptr)) { if (attach_map_.count(nullptr)) {
std::vector<Stmt> nest; std::vector<Stmt> nest;
for (StorageEntry* e : attach_map_.at(nullptr)) { for (StorageEntry* e : attach_map_.at(nullptr)) {
CHECK_EQ(e->scope.rank, 0); // CHECK_EQ(e->scope.rank, 0);
if (e->new_alloc.defined()) { if (e->new_alloc.defined()) {
nest.emplace_back(AttrStmt::make( nest.emplace_back(AttrStmt::make(
e->alloc_var, attr::storage_scope, e->alloc_var, attr::storage_scope,
...@@ -395,6 +395,12 @@ class StoragePlanRewriter : public IRMutator { ...@@ -395,6 +395,12 @@ class StoragePlanRewriter : public IRMutator {
e->new_alloc = Allocate::make( e->new_alloc = Allocate::make(
e->alloc_var, alloc_type, e->allocs[0]->extents, e->alloc_var, alloc_type, e->allocs[0]->extents,
e->allocs[0]->condition, Evaluate::make(0)); e->allocs[0]->condition, Evaluate::make(0));
if (e->scope.tag.length() != 0) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
CHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
<< "Allocation exceed bound of memory tag " << e->scope.to_string();
}
} else { } else {
// Build a merged allocation // Build a merged allocation
Expr combo_size; Expr combo_size;
......
...@@ -71,7 +71,7 @@ struct ThreadScope { ...@@ -71,7 +71,7 @@ struct ThreadScope {
*/ */
static ThreadScope make(const std::string& s) { static ThreadScope make(const std::string& s) {
ThreadScope r; ThreadScope r;
if (s == "vthread") { if (s == "vthread" || s == "cthread") {
// virtual thread at the same level as local // virtual thread at the same level as local
r.rank = 1; r.rank = 1;
r.dim_index = -1; r.dim_index = -1;
......
...@@ -58,6 +58,27 @@ def test_coproc_sync(): ...@@ -58,6 +58,27 @@ def test_coproc_sync():
assert(blist[-1].value.args[3].value == 10) assert(blist[-1].value.args[3].value == 10)
def test_coproc_sync2():
    """Run CoProcSync on coproc scopes placed before and inside a loop.

    The IR is a store under coproc context 2, followed by a loop whose
    body stores under context 1 and then under context 2. The test only
    checks that the pass runs; it makes no assertion on the output.
    """
    builder = tvm.ir_builder.create()
    extent = tvm.var("n")
    coproc_axis = tvm.thread_axis((0, 1), "cop")
    cthread = tvm.thread_axis("cthread")
    buf = builder.allocate("float32", 128, name="A")
    builder.scope_attr(cthread, "virtual_thread", 2)

    def store_in_context(ctx_id, value):
        # Emit `A[cthread] = value` inside its own coproc_scope region.
        with builder.new_scope():
            builder.scope_attr(coproc_axis, "coproc_scope", ctx_id)
            buf[cthread] = value

    # Region ahead of the loop.
    store_in_context(2, 0.0)
    # Two regions with different contexts inside the loop body.
    with builder.for_range(0, extent, name="i") as i:
        store_in_context(1, 1.0)
        store_in_context(2, 1.0)

    stmt = builder.get()
    stmt = tvm.ir_pass.CoProcSync(stmt)
if __name__ == "__main__": if __name__ == "__main__":
test_coproc_sync() test_coproc_sync()
test_storage_sync() test_storage_sync()
test_coproc_sync2()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment