Commit fd96d285 by Tianqi Chen Committed by GitHub

[PASS] More storage sync. (#297)

parent 581be165
......@@ -143,6 +143,8 @@ namespace attr {
constexpr const char* thread_extent = "thread_extent";
/*! \brief Mark launching of a virtual thread. */
constexpr const char* virtual_thread = "virtual_thread";
/*! \brief Mark region is processed by a co-processor */
constexpr const char* coproc_scope = "coproc_scope";
/*! \brief Mark the scope as volatile access for certain handle. */
constexpr const char* volatile_scope = "volatile_scope";
/*!
......
......@@ -250,6 +250,14 @@ Stmt StorageRewrite(Stmt stmt);
Stmt LoopPartition(Stmt stmt);
/*!
* \brief Detect and insert sync points to co-processor.
*
 * \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt CoProcSync(Stmt stmt);
/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
......
......@@ -193,6 +193,7 @@ def lower(sch,
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.StorageRewrite(stmt)
stmt = ir_pass.CoProcSync(stmt)
cfg = BuildConfig.current
stmt = ir_pass.UnrollLoop(
stmt,
......
......@@ -77,4 +77,24 @@ def stmt_seq(*args):
ret = value if ret is None else Block(ret, value)
return ret if ret else Evaluate(0)
def stmt_list(stmt):
    """Flatten nested blocks into a plain list of statements.

    Parameters
    ----------
    stmt : Stmt
        A statement, possibly a Block or ProducerConsumer wrapper.

    Returns
    -------
    stmt_list : list of Stmt
        The unpacked list of statements, in execution order.
    """
    # ProducerConsumer is a transparent wrapper: unpack its body.
    if isinstance(stmt, _stmt.ProducerConsumer):
        return stmt_list(stmt.body)
    # A Block is a pair (first, rest); flatten both sides recursively.
    if isinstance(stmt, _stmt.Block):
        return stmt_list(stmt.first) + stmt_list(stmt.rest)
    # Leaf statement: single-element list.
    return [stmt]
_init_api("tvm.make")
......@@ -94,6 +94,7 @@ REGISTER_PASS5(MakeAPI);
REGISTER_PASS2(BindDeviceType);
REGISTER_PASS1(SplitHostDevice);
REGISTER_PASS1(StorageRewrite);
REGISTER_PASS1(CoProcSync);
REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(InjectPrefetch);
REGISTER_PASS1(LoopPartition);
......
......@@ -14,7 +14,7 @@ void StorageAccessVisitor::Visit_(const Load* op) {
CHECK(allow_append_);
AccessEntry e;
e.threads = env_threads();
e.buffer = buf;
e.buffer = op->buffer_var;
e.dtype = op->type.element_of();
e.touched = arith::IntSet::vector(op->index);
e.type = kRead;
......@@ -34,7 +34,7 @@ void StorageAccessVisitor::Visit_(const Store* op) {
if (Enabled(buf, scope)) {
AccessEntry e;
e.threads = env_threads();
e.buffer = buf;
e.buffer = op->buffer_var;
e.dtype = op->value.type().element_of();
e.touched = arith::IntSet::vector(op->index);
e.type = kWrite;
......@@ -69,6 +69,11 @@ void StorageAccessVisitor::Visit_(const AttrStmt* op) {
storage_scope_[buf] =
StorageScope::make(op->value.as<StringImm>()->value);
IRVisitor::Visit_(op);
} else if (op->attr_key == attr::coproc_scope) {
IterVar iv(op->node.node_);
env_threads_.push_back(iv);
IRVisitor::Visit_(op);
env_threads_.CopyOnWrite()->data.pop_back();
} else if (op->attr_key == attr::thread_extent) {
IterVar iv(op->node.node_);
env_threads_.push_back(iv);
......@@ -102,11 +107,13 @@ void StorageAccessVisitor::Visit_(const For* op) {
relax_map[op->loop_var.get()] = arith::IntSet::range(
Range::make_by_min_extent(op->min, op->extent));
for (AccessEntry& e : s.access) {
if (e.buffer != nullptr) {
if (e.buffer.defined()) {
CHECK(e.touched.defined());
e.touched = arith::EvalSet(e.touched, relax_map);
}
}
}
if (!s.access.empty()) {
scope_.back().emplace_back(std::move(s));
}
}
......@@ -148,7 +155,7 @@ void StorageAccessVisitor::Visit_(const Call* op) {
AccessEntry e;
e.threads = env_threads();
e.dtype = dtype;
e.buffer = buffer;
e.buffer = VarExpr(op->args[1].node_);
e.touched = arith::IntSet::range(
Range::make_by_min_extent(offset, extent));
e.scope = scope;
......
......@@ -27,14 +27,16 @@ class StorageAccessVisitor : public IRVisitor {
kRead,
kWrite,
kSync,
kAlloc
kAlloc,
// acquired version of read, only need to handle WAR dep.
kReadAcquire
};
/*! \brief An access entry */
struct AccessEntry {
/*! \brief The thread index that access this entry */
Array<IterVar> threads;
/*! \brief The buffer variable, if any */
const Variable* buffer{nullptr};
VarExpr buffer;
/*! \brief The access data type */
Type dtype;
/*! \brief The touched access range */
......@@ -104,6 +106,8 @@ class StorageAccessVisitor : public IRVisitor {
* \return The scope of the final buffer array.
*/
StorageScope GetScope(const Variable* buf) const;
// access scope
std::vector<std::vector<StmtEntry> > scope_;
private:
// whether access appending is enabled.
......@@ -116,8 +120,6 @@ class StorageAccessVisitor : public IRVisitor {
StmtEntry curr_stmt_;
// The involving threads
Array<IterVar> env_threads_;
// access scope
std::vector<std::vector<StmtEntry> > scope_;
// The storage scope of each buffer
std::unordered_map<const Variable*, StorageScope> storage_scope_;
};
......
......@@ -37,7 +37,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
// if it is a loop, rotate two times to consider effect of loop.
size_t max_seq = seq.size();
if (loop != 0) max_seq *= 2;
if (loop != nullptr) max_seq *= 2;
// simulation-based approach to find dependencies
for (size_t i = 0; i < max_seq; ++i) {
const StmtEntry& s = seq[i % seq.size()];
......@@ -125,7 +125,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
bool FindConflict(const std::vector<AccessEntry>& vec,
const AccessEntry& e) {
for (const AccessEntry& x : vec) {
if (x.buffer == e.buffer) {
if (x.buffer.same_as(e.buffer)) {
// Assumes no race between threads
// Same index value means no conflicts
// TODO(tqchen) more standard set based testing.
......@@ -296,5 +296,189 @@ LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) {
return LoweredFunc(n);
}
// Visitor to find touched set by co-processor scope.
class CoProcTouchedBuffer : public IRVisitor {
public:
void Visit_(const Load* op) final {
if (in_scope_) {
touched_.insert(op->buffer_var.get());
}
IRVisitor::Visit_(op);
}
void Visit_(const Store* op) final {
if (in_scope_) {
touched_.insert(op->buffer_var.get());
}
IRVisitor::Visit_(op);
}
void Visit_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
const Variable* buffer = op->args[1].as<Variable>();
touched_.insert(buffer);
}
IRVisitor::Visit_(op);
}
void Visit_(const AttrStmt* op) final {
if (op->attr_key == attr::coproc_scope && !in_scope_) {
in_scope_ = true;
IRVisitor::Visit_(op);
in_scope_ = false;
} else {
IRVisitor::Visit_(op);
}
}
std::unordered_set<const Variable*> touched_;
private:
bool in_scope_{false};
};
// Synchronization planning with co-processor.
//
// Walks the statement, records global-scope accesses to buffers that the
// co-processor touches, and decides after which statements a
// "<thread>.coproc_sync" barrier must be inserted so that host reads/writes
// do not race with pending co-processor writes.
class CoProcSyncPlanner : public StorageAccessVisitor {
 public:
  // Run the planning pass over stmt. Results are left in write_sync_.
  void Plan(const Stmt& stmt) {
    // First find out which buffers the co-processor touches at all;
    // if none, there is nothing to synchronize and we skip the visit.
    CoProcTouchedBuffer visitor;
    visitor.Visit(stmt);
    touched_ = std::move(visitor.touched_);
    if (!touched_.empty()) {
      this->Visit(stmt);
      // Top level: force a final sync so no coproc write is left pending
      // when the function exits.
      PlanWriteSync(scope_.back(), nullptr, true);
    }
  }
  // Write synchronization to be inserted after the keyed stmt.
  std::unordered_map<const Node*, std::vector<Stmt> > write_sync_;

 protected:
  // Only track accesses to coproc-touched buffers in global scope.
  bool Enabled(const Variable* buf,
               const StorageScope& scope) const final {
    return touched_.count(buf) && scope == global_scope_;
  }
  // Summarize a statement sequence by planning its write syncs;
  // returns the accesses still visible to the enclosing scope.
  std::vector<AccessEntry> Summarize(
      std::vector<StmtEntry> seq, const For* loop) final {
    return PlanWriteSync(seq, loop, false);
  }

 private:
  // Plan write synchronization if write is not coherent.
  //
  // \param seq The statement sequence of one scope.
  // \param loop The enclosing loop, or nullptr if not inside a For.
  // \param force_sync_at_end If true, always sync pending coproc
  //        accesses at the end of the sequence (used at top level).
  // \return The accesses that remain unsynchronized and must be
  //         reported to the parent scope.
  std::vector<AccessEntry> PlanWriteSync(
      std::vector<StmtEntry> seq, const For* loop,
      bool force_sync_at_end) {
    // detect write barriers
    // co_access holds accesses made by the co-processor that have not
    // yet been covered by a sync point.
    std::vector<AccessEntry> co_access;
    bool contain_sync = false;

    // Conflict check of acc against the pending coproc accesses:
    // a read after a pending write, or any write, needs a barrier
    // (RAW, WAW and WAR hazards; concurrent reads are fine).
    auto find_conflict = [&](const AccessEntry& acc) {
      for (const AccessEntry& x : co_access) {
        if (x.buffer.same_as(acc.buffer) &&
            ((acc.type == kRead && x.type == kWrite) ||
             acc.type == kWrite)) {
          return true;
        }
      }
      return false;
    };
    for (size_t i = 0; i < seq.size(); ++i) {
      const StmtEntry& s = seq[i];
      bool sync_write = false;
      for (const AccessEntry& acc : s.access) {
        // threads.size() == 0 means the access happens outside any
        // coproc/thread scope, i.e. on the host side (coproc_scope
        // pushes an IterVar into env_threads — see StorageAccessVisitor).
        if (acc.threads.size() == 0 && find_conflict(acc)) {
          sync_write = true; break;
        }
        // An existing sync clears all pending coproc accesses.
        if (acc.type == kSync) {
          co_access.clear();
          contain_sync = true;
        }
      }
      if (sync_write) {
        // Insert the barrier after the previous statement, i.e. before
        // the conflicting host access in s.
        CHECK_NE(i, 0U);
        write_sync_[seq[i - 1].stmt] = GetWriteSync(co_access);
        co_access.clear();
        contain_sync = true;
      }
      // Record this statement's coproc accesses as now pending.
      for (const AccessEntry& acc : s.access) {
        if (acc.threads.size() != 0) {
          co_access.push_back(acc);
        }
      }
    }
    bool sync_at_end = force_sync_at_end;
    if (loop != nullptr && !sync_at_end) {
      // loop carry dependency: re-scan the sequence as if it ran a
      // second iteration; a host access conflicting with accesses left
      // pending from the first iteration requires a sync at loop end.
      for (size_t i = 0; i < seq.size(); ++i) {
        const StmtEntry& s = seq[i];
        for (const AccessEntry& acc : s.access) {
          if (acc.threads.size() == 0 && find_conflict(acc)) {
            sync_at_end = true; break;
          }
        }
        // Stop once a sync already covers the carried accesses.
        if (write_sync_.count(s.stmt) || sync_at_end) break;
      }
    }
    if (sync_at_end && co_access.size() != 0) {
      CHECK_NE(seq.size(), 0);
      contain_sync = true;
      write_sync_[seq.back().stmt] = GetWriteSync(co_access);
      co_access.clear();
    }
    if (contain_sync) {
      // Report a leading kSync entry so the parent scope knows the
      // earlier accesses of this sequence are already synchronized.
      AccessEntry e;
      e.type = kSync;
      e.scope = global_scope_;
      co_access.insert(co_access.begin(), e);
    }
    return co_access;
  }
  // Build the write-synchronization statements for the pending accesses.
  // Emits an intrinsic call named "<thread_var>.coproc_sync".
  std::vector<Stmt> GetWriteSync(const std::vector<AccessEntry>& co_access) {
    // Does not consider memory coherence, need runtime.
    CHECK_NE(co_access.size(), 0U);
    // NOTE(review): assumes exactly one coproc thread axis on the entry —
    // confirm this holds for nested coproc scopes.
    CHECK_EQ(co_access[0].threads.size(), 1U);
    std::string sync_name = co_access[0].threads[0]->var->name_hint + ".coproc_sync";
    std::vector<Stmt> stmts;
    stmts.emplace_back(
        Evaluate::make(Call::make(
            Int(32),
            sync_name,
            {}, Call::Intrinsic)));
    return stmts;
  }

  // Buffers touched by the co-processor (from CoProcTouchedBuffer).
  std::unordered_set<const Variable*> touched_;
  // The global storage scope; the only scope this planner tracks.
  StorageScope global_scope_ = StorageScope::make("global");
};
// Mutator that materializes a sync plan: after every statement keyed in
// write_sync_, the planned synchronization statements are appended.
class CoProcSyncInserter : public IRMutator {
 public:
  explicit CoProcSyncInserter(
      const std::unordered_map<const Node*, std::vector<Stmt> >& write_sync)
      : write_sync_(write_sync) {}

  Stmt Mutate(Stmt stmt) final {
    // Rewrite children first so lookup keys match the rewritten nodes.
    stmt = IRMutator::Mutate(stmt);
    const auto it = write_sync_.find(stmt.get());
    if (it == write_sync_.end()) return stmt;
    // Chain the sync statements directly after this statement.
    return Block::make(stmt, MergeSeq(it->second));
  }

 private:
  // Statement node -> sync statements to insert after it (not owned).
  const std::unordered_map<const Node*, std::vector<Stmt> >& write_sync_;
};
/*!
 * \brief Detect and insert sync points to the co-processor.
 * \param stmt The stmt to be transformed.
 * \return The transformed stmt (unchanged when no sync is needed).
 */
Stmt CoProcSync(Stmt stmt) {
  CoProcSyncPlanner planner;
  planner.Plan(stmt);
  // Fast path: no synchronization required anywhere.
  if (planner.write_sync_.empty()) return stmt;
  return CoProcSyncInserter(planner.write_sync_).Mutate(stmt);
}
} // namespace ir
} // namespace tvm
......@@ -24,7 +24,26 @@ def test_storage_sync():
flist = tvm.ir_pass.SplitHostDevice(f)
f = flist[1]
f = tvm.ir_pass.ThreadSync(f, "shared")
print(f.body)
body_list = tvm.make.stmt_list(f.body.body.body.body)
assert(body_list[1].value.name == "tvm_storage_sync")
def test_coproc_sync():
    """CoProcSync should append a "cop.coproc_sync" intrinsic after the
    co-processor writes that a following host access depends on."""
    builder = tvm.ir_builder.create()
    extent = tvm.var("n")
    coproc = tvm.thread_axis((0, 1), "cop")
    buf = builder.allocate("float32", extent, name="A", scope="global")
    with builder.for_range(0, extent, name="i") as i:
        # Host-side write, followed by co-processor writes to the same buffer.
        buf[i] = buf[i] + 1
        with builder.for_range(0, 10, name="j") as j:
            builder.scope_attr(coproc, "coproc_scope", 1)
            buf[j] = buf[j] + 2
    stmt = builder.get()
    stmt = tvm.ir_pass.CoProcSync(stmt)
    stmt = stmt.body.body.body
    # The sync intrinsic must be the last statement of the rewritten body.
    assert tvm.make.stmt_list(stmt)[-1].value.name == "cop.coproc_sync"
if __name__ == "__main__":
test_coproc_sync()
test_storage_sync()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment