Commit e4b500b6 by Tianqi Chen Committed by GitHub

[PASS][FIX] Fix LiftAttrScope with if (#309)

* [PASS][FIX] Fix LiftAttrScope with if

* [PASS] Fix on proc sync

* fix
parent 19381b51
...@@ -95,7 +95,7 @@ class AttrScopeLifter : public IRMutator { ...@@ -95,7 +95,7 @@ class AttrScopeLifter : public IRMutator {
} }
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final { Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
if (!op->then_case.defined()) { if (!op->else_case.defined()) {
return IRMutator::Mutate_(op, s); return IRMutator::Mutate_(op, s);
} }
Stmt then_case = this->Mutate(op->then_case); Stmt then_case = this->Mutate(op->then_case);
......
...@@ -312,7 +312,7 @@ class CoProcTouchedBuffer : public IRVisitor { ...@@ -312,7 +312,7 @@ class CoProcTouchedBuffer : public IRVisitor {
IRVisitor::Visit_(op); IRVisitor::Visit_(op);
} }
void Visit_(const Call* op) final { void Visit_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { if (op->is_intrinsic(intrinsic::tvm_access_ptr) && in_scope_) {
const Variable* buffer = op->args[1].as<Variable>(); const Variable* buffer = op->args[1].as<Variable>();
touched_.insert(buffer); touched_.insert(buffer);
} }
...@@ -321,6 +321,8 @@ class CoProcTouchedBuffer : public IRVisitor { ...@@ -321,6 +321,8 @@ class CoProcTouchedBuffer : public IRVisitor {
void Visit_(const AttrStmt* op) final { void Visit_(const AttrStmt* op) final {
if (op->attr_key == attr::coproc_scope && !in_scope_) { if (op->attr_key == attr::coproc_scope && !in_scope_) {
in_scope_ = true; in_scope_ = true;
IterVar iv(op->node.node_);
coproc_.insert(iv);
IRVisitor::Visit_(op); IRVisitor::Visit_(op);
in_scope_ = false; in_scope_ = false;
} else { } else {
...@@ -329,6 +331,7 @@ class CoProcTouchedBuffer : public IRVisitor { ...@@ -329,6 +331,7 @@ class CoProcTouchedBuffer : public IRVisitor {
} }
std::unordered_set<const Variable*> touched_; std::unordered_set<const Variable*> touched_;
std::unordered_set<IterVar> coproc_;
private: private:
bool in_scope_{false}; bool in_scope_{false};
...@@ -344,6 +347,11 @@ class CoProcSyncPlanner : public StorageAccessVisitor { ...@@ -344,6 +347,11 @@ class CoProcSyncPlanner : public StorageAccessVisitor {
if (!touched_.empty()) { if (!touched_.empty()) {
this->Visit(stmt); this->Visit(stmt);
PlanWriteSync(scope_.back(), nullptr, true); PlanWriteSync(scope_.back(), nullptr, true);
CHECK_EQ(visitor.coproc_.size(), 1U);
if (write_sync_.size() == 0) {
write_sync_[stmt.get()] = GetWriteSync(
(*visitor.coproc_.begin())->var->name_hint + ".coproc_sync");
}
} }
} }
...@@ -438,7 +446,10 @@ class CoProcSyncPlanner : public StorageAccessVisitor { ...@@ -438,7 +446,10 @@ class CoProcSyncPlanner : public StorageAccessVisitor {
// Does not consider memory coherence, need runtime. // Does not consider memory coherence, need runtime.
CHECK_NE(co_access.size(), 0U); CHECK_NE(co_access.size(), 0U);
CHECK_EQ(co_access[0].threads.size(), 1U); CHECK_EQ(co_access[0].threads.size(), 1U);
std::string sync_name = co_access[0].threads[0]->var->name_hint + ".coproc_sync"; return GetWriteSync(co_access[0].threads[0]->var->name_hint + ".coproc_sync");
}
std::vector<Stmt> GetWriteSync(std::string sync_name) {
std::vector<Stmt> stmts; std::vector<Stmt> stmts;
stmts.emplace_back( stmts.emplace_back(
Evaluate::make(Call::make( Evaluate::make(Call::make(
...@@ -447,6 +458,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor { ...@@ -447,6 +458,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor {
{}, Call::Intrinsic))); {}, Call::Intrinsic)));
return stmts; return stmts;
} }
std::unordered_set<const Variable*> touched_; std::unordered_set<const Variable*> touched_;
StorageScope global_scope_ = StorageScope::make("global"); StorageScope global_scope_ = StorageScope::make("global");
}; };
......
...@@ -11,9 +11,10 @@ def test_coproc_lift(): ...@@ -11,9 +11,10 @@ def test_coproc_lift():
with ib.for_range(0, 10, name="j") as j: with ib.for_range(0, 10, name="j") as j:
ib.scope_attr(cp, "coproc_uop_scope", value) ib.scope_attr(cp, "coproc_uop_scope", value)
A[i] = A[i] + 1 A[i] = A[i] + 1
with ib.for_range(0, 10, name="j") as j: with ib.if_scope(i.equal(0)):
ib.scope_attr(cp, "coproc_uop_scope", value) with ib.for_range(0, 10, name="j") as j:
A[j] = A[j] + 2 ib.scope_attr(cp, "coproc_uop_scope", value)
A[j] = A[j] + 2
body = ib.get() body = ib.get()
body = tvm.ir_pass.LiftAttrScope(body, "coproc_uop_scope") body = tvm.ir_pass.LiftAttrScope(body, "coproc_uop_scope")
assert body.body.body.node == cp assert body.body.body.node == cp
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment