Commit 79e482bc by Tianqi Chen, committed by GitHub

[PASS] Memory barrier detection, storage access lower. (#317)

parent afa20869
......@@ -267,6 +267,15 @@ Stmt CoProcSync(Stmt stmt);
Stmt LiftAttrScope(Stmt stmt, std::string attr_key);
/*!
* \brief Lower attached storage access information.
* Run this pass after all storage access analyses have finished.
*
* \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt LowerStorageAccessInfo(Stmt stmt);
/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
......
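For reference, the new pass is exposed to Python through the pass registration further down, so it can be invoked like any other IR pass. A minimal sketch using an untagged "global" buffer, where the pass leaves the IR unchanged (the tagged-memory case is sketched after the MemoryInfoNode hunk below):

import tvm

ib = tvm.ir_builder.create()
A = ib.allocate("float32", 128, name="A", scope="global")
with ib.for_range(0, 128, name="i") as i:
    A[i] = A[i] + 1
# Run after all storage access analyses, as the comment above requires.
stmt = tvm.ir_pass.LowerStorageAccessInfo(ib.get())
# "global" carries no tag, so no MemoryInfo lookup happens and the IR is
# returned unchanged; tagged scopes such as "global.cache" need a
# registered MemoryInfo (see below).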
......@@ -23,11 +23,17 @@ struct MemoryInfoNode : public Node {
int max_num_bits;
/*! \brief maximum number of bits to be used in simd op */
int max_simd_bits;
/*!
* \brief Head address of the buffer, if visible to the CPU.
* This address can be None.
*/
Expr head_address;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("unit_bits", &unit_bits);
v->Visit("max_num_bits", &max_num_bits);
v->Visit("max_simd_bits", &max_simd_bits);
v->Visit("head_address", &head_address);
}
static constexpr const char* _type_key = "MemoryInfo";
......
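The new head_address field can be supplied from Python when registering memory info for a tagged scope. A minimal sketch mirroring the registration added to the test at the end of this diff ("global.cache" and the extern symbol "global_cache" are the test's names), followed by the lowering of a tagged allocation:

import tvm

@tvm.register_func("tvm.info.mem.global.cache")
def mem_info_global_cache():
    return tvm.make.node(
        "MemoryInfo",
        unit_bits=8,
        max_simd_bits=32,
        max_num_bits=128,
        head_address=tvm.call_extern("handle", "global_cache"))

ib = tvm.ir_builder.create()
A = ib.allocate("float32", 128, name="A", scope="global.cache")
with ib.for_range(0, 128, name="i") as i:
    A[i] = A[i] + 1
stmt = tvm.ir_pass.LowerStorageAccessInfo(ib.get())
# Because head_address is defined, the Allocate for A is re-created with the
# head address as its new_expr ("nop" allocation) instead of being dropped.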
......@@ -197,7 +197,6 @@ def lower(sch,
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.StorageRewrite(stmt)
stmt = ir_pass.CoProcSync(stmt)
cfg = BuildConfig.current
stmt = ir_pass.UnrollLoop(
stmt,
......@@ -210,6 +209,7 @@ def lower(sch,
stmt = ir_pass.Simplify(stmt)
if simple_mode:
return stmt
stmt = ir_pass.LowerStorageAccessInfo(stmt)
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
......
......@@ -95,6 +95,7 @@ REGISTER_PASS2(BindDeviceType);
REGISTER_PASS1(SplitHostDevice);
REGISTER_PASS1(StorageRewrite);
REGISTER_PASS1(CoProcSync);
REGISTER_PASS1(LowerStorageAccessInfo);
REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(InjectPrefetch);
REGISTER_PASS1(LoopPartition);
......
/*!
* Copyright (c) 2017 by Contributors
* \file coproc_sync.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <unordered_map>
#include <unordered_set>
#include "./ir_util.h"
#include "./storage_access.h"
namespace tvm {
namespace ir {
// Visitor to find touched set by co-processor scope.
class CoProcTouchedBuffer : public IRVisitor {
public:
void Visit_(const Load* op) final {
if (in_scope_) {
touched_[op->buffer_var.get()].coproc = true;
} else {
touched_[op->buffer_var.get()].normal = true;
}
IRVisitor::Visit_(op);
}
void Visit_(const Store* op) final {
if (in_scope_) {
touched_[op->buffer_var.get()].coproc = true;
} else {
touched_[op->buffer_var.get()].normal = true;
}
IRVisitor::Visit_(op);
}
void Visit_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
const Variable* buffer = op->args[1].as<Variable>();
if (in_scope_) {
touched_[buffer].coproc = true;
} else {
touched_[buffer].normal = true;
}
}
IRVisitor::Visit_(op);
}
void Visit_(const AttrStmt* op) final {
if (op->attr_key == attr::coproc_scope && !in_scope_) {
in_scope_ = true;
IterVar iv(op->node.node_);
coproc_.insert(iv);
IRVisitor::Visit_(op);
in_scope_ = false;
} else {
IRVisitor::Visit_(op);
}
}
// Touch Entry
struct TouchEntry {
bool normal{false};
bool coproc{false};
};
std::unordered_map<const Variable*, TouchEntry> touched_;
std::unordered_set<IterVar> coproc_;
private:
bool in_scope_{false};
};
// Synchronization planning with co-processor.
class CoProcSyncPlanner : public StorageAccessVisitor {
public:
explicit CoProcSyncPlanner(
const std::unordered_set<const Variable*>& touched,
const std::string& coproc_name)
: touched_(touched), coproc_name_(coproc_name) {
}
void Plan(const Stmt& stmt) {
this->Visit(stmt);
PlanSync(scope_.back(), nullptr, true);
if (sync_.size() == 0) {
sync_[stmt.get()] = GetSync(coproc_name_ + ".coproc_sync");
}
}
// Write synchronization to be inserted before or after stmt.
std::unordered_map<const Node*, std::vector<Stmt> > sync_;
protected:
bool Enabled(const Variable* buf,
const StorageScope& scope) const final {
return touched_.count(buf);
}
// Plan the sync
std::vector<AccessEntry> Summarize(
std::vector<StmtEntry> seq, const For* loop) final {
return PlanSync(seq, loop, false);
}
private:
// Plan write synchronization if write is not coherent
std::vector<AccessEntry> PlanSync(
std::vector<StmtEntry> seq, const For* loop,
bool force_sync_at_end) {
// Detect points that need write synchronization.
// co_access tracks accesses made by the co-processor.
std::vector<AccessEntry> co_access;
bool contain_sync = false;
auto find_conflict = [&](const AccessEntry& acc) {
for (const AccessEntry& x : co_access) {
if (x.buffer.same_as(acc.buffer) &&
((acc.type == kRead && x.type == kWrite) ||
acc.type == kWrite)) {
return true;
}
}
return false;
};
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
bool sync_write = false;
for (const AccessEntry& acc : s.access) {
if (acc.threads.size() == 0 && find_conflict(acc)) {
sync_write = true; break;
}
if (acc.type == kSync) {
co_access.clear();
contain_sync = true;
}
}
if (sync_write) {
CHECK_NE(i, 0U);
sync_[seq[i - 1].stmt] = GetSync(co_access);
co_access.clear();
contain_sync = true;
}
for (const AccessEntry& acc : s.access) {
if (acc.threads.size() != 0) {
co_access.push_back(acc);
}
}
}
bool sync_at_end = force_sync_at_end;
if (loop != nullptr && !sync_at_end) {
// loop-carried dependency
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
for (const AccessEntry& acc : s.access) {
if (acc.threads.size() == 0 && find_conflict(acc)) {
sync_at_end = true; break;
}
}
if (sync_.count(s.stmt) || sync_at_end) break;
}
}
if (sync_at_end && co_access.size() != 0) {
CHECK_NE(seq.size(), 0);
contain_sync = true;
sync_[seq.back().stmt] = GetSync(co_access);
co_access.clear();
}
if (contain_sync) {
AccessEntry e;
e.type = kSync;
co_access.insert(co_access.begin(), e);
}
return co_access;
}
// Add write synchronization
std::vector<Stmt> GetSync(const std::vector<AccessEntry>& co_access) {
// Does not consider memory coherence; requires runtime support.
CHECK_NE(co_access.size(), 0U);
CHECK_EQ(co_access[0].threads.size(), 1U);
return GetSync(coproc_name_ + ".coproc_sync");
}
std::vector<Stmt> GetSync(std::string sync_name) {
return {Evaluate::make(Call::make(
Int(32),
sync_name,
{}, Call::Intrinsic))};
}
const std::unordered_set<const Variable*>& touched_;
std::string coproc_name_;
};
// Detect memory barriers needed when the co-processor reads/writes memory.
class CoProcBarrierDetector : public StorageAccessVisitor {
public:
explicit CoProcBarrierDetector(
const std::unordered_set<const Variable*>& touched,
const std::string& coproc_name)
: touched_(touched) {
read_barrier_name_ = coproc_name + ".coproc_read_barrier";
write_barrier_name_ = coproc_name + ".coproc_write_barrier";
}
void PlanReadBarrier(Stmt stmt) {
read_barrier_ = true;
this->Visit(stmt);
}
void PlanWriteBarrier(Stmt stmt) {
read_barrier_ = false;
this->Visit(stmt);
}
std::unordered_map<const Node*, std::vector<Stmt> > barrier_before_;
std::unordered_map<const Node*, std::vector<Stmt> > barrier_after_;
protected:
bool Enabled(const Variable* buf,
const StorageScope& scope) const final {
return touched_.count(buf);
}
// Plan the sync
std::vector<AccessEntry> Summarize(
std::vector<StmtEntry> seq, const For* loop) final {
if (read_barrier_) {
return PlanReadBarrier(seq, loop);
} else {
return PlanWriteBarrier(seq, loop);
}
}
private:
// Plan write barriers at read-after-write points.
std::vector<AccessEntry> PlanWriteBarrier(
std::vector<StmtEntry> seq, const For* loop) {
std::vector<AccessEntry> read_seq;
std::unordered_map<const Variable*, std::vector<AccessEntry> > write_set;
auto fupdate = [&](size_t i, const AccessEntry& acc) {
auto it = write_set.find(acc.buffer.get());
if (it != write_set.end()) {
CHECK_NE(i, 0U);
barrier_after_[seq[i - 1].stmt].push_back(
MakeBarrier(write_barrier_name_, it->second));
write_set.erase(it);
}
};
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
for (const AccessEntry& acc : s.access) {
if (acc.threads.size() == 0 && acc.type == kRead) {
fupdate(i, acc);
read_seq.push_back(acc);
}
}
for (const AccessEntry& acc : s.access) {
if (acc.threads.size() != 0 && acc.type == kWrite) {
write_set[acc.buffer.get()].push_back(acc);
}
}
}
// loop carry
if (loop != nullptr) {
for (const AccessEntry& acc : read_seq) {
fupdate(seq.size(), acc);
}
}
for (const auto &kv : write_set) {
read_seq.insert(read_seq.end(), kv.second.begin(), kv.second.end());
}
return read_seq;
}
std::vector<AccessEntry> PlanReadBarrier(
std::vector<StmtEntry> seq, const For* loop) {
std::vector<AccessEntry> write_seq;
std::unordered_map<const Variable*, std::vector<AccessEntry> > read_set;
auto fupdate = [&](size_t i, const AccessEntry& acc) {
auto it = read_set.find(acc.buffer.get());
if (it != read_set.end()) {
CHECK_NE(i, seq.size());
barrier_before_[seq[i].stmt].push_back(
MakeBarrier(read_barrier_name_, it->second));
read_set.erase(it);
}
};
for (size_t i = seq.size(); i != 0; --i) {
const StmtEntry& s = seq[i - 1];
for (const AccessEntry& acc : s.access) {
if (acc.threads.size() == 0 && acc.type == kWrite) {
CHECK_NE(i, seq.size());
fupdate(i, acc);
write_seq.push_back(acc);
}
}
for (const AccessEntry& acc : s.access) {
if (acc.threads.size() != 0 && acc.type == kRead) {
read_set[acc.buffer.get()].push_back(acc);
}
}
}
// loop carry
if (loop != nullptr) {
for (const AccessEntry& acc : write_seq) {
fupdate(0, acc);
}
}
for (const auto &kv : read_set) {
write_seq.insert(write_seq.end(), kv.second.begin(), kv.second.end());
}
return write_seq;
}
Stmt MakeBarrier(const std::string& func, const std::vector<AccessEntry>& wvec) {
// insert write point
Array<arith::IntSet> wset;
for (const AccessEntry& acc : wvec) {
CHECK(acc.dtype == wvec[0].dtype);
wset.push_back(acc.touched);
}
Range none;
Range r = arith::Union(wset).cover_range(none);
CHECK(r.defined())
<< "Cannot deduce write range of " << wvec[0].buffer;
Expr min = r->min;
Expr extent = r->extent;
return Evaluate::make(Call::make(
Int(32), func,
{wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, Call::Intrinsic));
}
// whether we are planning read barriers (otherwise write barriers)
bool read_barrier_{false};
// barrier intrinsic names
std::string read_barrier_name_;
std::string write_barrier_name_;
const std::unordered_set<const Variable*>& touched_;
};
class CoProcSyncInserter : public IRMutator {
public:
Stmt Insert(Stmt stmt) {
CoProcTouchedBuffer visitor;
visitor.Visit(stmt);
if (visitor.coproc_.size() == 0) return stmt;
std::unordered_set<const Variable*> touched;
for (const auto &kv : visitor.touched_) {
if (kv.second.normal && kv.second.coproc) {
touched.insert(kv.first);
}
}
CHECK_EQ(visitor.coproc_.size(), 1U);
std::string coproc_name = (*visitor.coproc_.begin())->var->name_hint;
// plan sync.
CoProcSyncPlanner sync_planner(touched, coproc_name);
sync_planner.Plan(stmt);
for (const auto& kv : sync_planner.sync_) {
auto& vec = insert_after_[kv.first];
vec.insert(vec.end(), kv.second.begin(), kv.second.end());
}
// Detect barrier
CoProcBarrierDetector barrier_detector(touched, coproc_name);
barrier_detector.PlanReadBarrier(stmt);
barrier_detector.PlanWriteBarrier(stmt);
for (const auto& kv : barrier_detector.barrier_before_) {
auto& vec = insert_before_[kv.first];
vec.insert(vec.end(), kv.second.begin(), kv.second.end());
}
for (const auto& kv : barrier_detector.barrier_after_) {
auto& vec = insert_after_[kv.first];
vec.insert(vec.end(), kv.second.begin(), kv.second.end());
}
return Mutate(stmt);
}
Stmt Mutate(Stmt stmt) final {
Stmt before, after;
auto it = insert_before_.find(stmt.get());
if (it != insert_before_.end()) {
before = MergeSeq(it->second);
}
it = insert_after_.find(stmt.get());
if (it != insert_after_.end()) {
after = MergeSeq(it->second);
}
stmt = IRMutator::Mutate(stmt);
if (before.defined()) {
stmt = Block::make(before, stmt);
}
if (after.defined()) {
stmt = Block::make(stmt, after);
}
return stmt;
}
private:
std::unordered_map<const Node*, std::vector<Stmt> > insert_before_;
std::unordered_map<const Node*, std::vector<Stmt> > insert_after_;
};
Stmt CoProcSync(Stmt stmt) {
return CoProcSyncInserter().Insert(stmt);
}
} // namespace ir
} // namespace tvm
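A minimal Python sketch of driving the reworked pass, condensed from the updated test at the end of this diff; scope "global" is used here instead of the test's "global.cache", since this pass does not consult MemoryInfo:

import tvm

ib = tvm.ir_builder.create()
n = tvm.var("n")
cp = tvm.thread_axis((0, 1), "cop")
A = ib.allocate("float32", 128, name="A", scope="global")
with ib.for_range(0, n, name="i") as i:
    A[i] = A[i] + 1                         # normal (CPU-side) accesses
with ib.for_range(0, 10, name="j") as j:
    ib.scope_attr(cp, "coproc_scope", 1)    # co-processor region
    A[j] = A[j] + 2
with ib.for_range(0, 8, name="k") as k:
    with ib.for_range(0, 10, name="j") as j:
        ib.scope_attr(cp, "coproc_scope", 1)
        A[j] = A[j + k * 10] + 2
stmt = tvm.ir_pass.CoProcSync(ib.get())
# As the updated test asserts, "cop.coproc_read_barrier" and
# "cop.coproc_write_barrier" calls are inserted around the co-processor
# regions, along with a trailing "cop.coproc_sync".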
......@@ -6,16 +6,13 @@
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/target_info.h>
#include <unordered_set>
#include "./ir_util.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
using runtime::StorageScope;
inline Expr ConstInt32(size_t index) {
CHECK_LE(index, std::numeric_limits<int>::max());
return make_const(Int(32), static_cast<int>(index));
......@@ -69,14 +66,7 @@ class BuiltinLower : public IRMutator {
// Lower allocate to device allocate when needed.
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>();
// For special memory, remove allocate.
auto it = storage_info_.find(op->buffer_var.get());
if (it != storage_info_.end() && it->second.scope.tag.length() != 0) {
++it->second.alloc_count;
CHECK_LE(it->second.alloc_count, 1)
<< "Double allocation of " << it->second.scope.to_string();
return op->body;
}
if (op->new_expr.defined()) return stmt;
// Get constant allocation bound.
int64_t dev_type;
int64_t nbytes = GetVectorBytes(op->type);
......@@ -139,25 +129,12 @@ class BuiltinLower : public IRMutator {
CHECK(!device_type_.defined());
device_type_ = op->value;
return Mutate(op->body);
} else if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
StorageEntry e;
e.scope = scope;
if (scope.tag.length() != 0) {
e.info = GetMemoryInfo(op->value.as<StringImm>()->value);
CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string();
}
storage_info_[buf] = e;
return IRMutator::Mutate_(op, s);
} else {
return IRMutator::Mutate_(op, s);
}
}
Expr Mutate_(const Call* op, const Expr &e) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
return MakeAccessPtr(op, e);
} else if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
return MakeCallPacked(op, e);
} else if (op->is_intrinsic(intrinsic::tvm_stack_make_shape)) {
return MakeShape(op, e);
......@@ -167,14 +144,6 @@ class BuiltinLower : public IRMutator {
return IRMutator::Mutate_(op, e);
}
}
Expr Convert(Type t, Expr e) {
if (e.type() != t) {
return Cast::make(t, e);
} else {
return e;
}
}
// call shape
Expr MakeShape(const Call* op, const Expr& e) {
size_t stack_begin = run_shape_stack_;
......@@ -183,7 +152,7 @@ class BuiltinLower : public IRMutator {
op = expr.as<Call>();
for (size_t i = 0; i < op->args.size(); ++i) {
prep_seq_.emplace_back(
Store::make(stack_shape_, Convert(Int(64), op->args[i]),
Store::make(stack_shape_, cast(Int(64), op->args[i]),
ConstInt32(stack_begin +i), const_true(1)));
}
return AddressOffset(stack_shape_, Int(64), stack_begin);
......@@ -224,15 +193,15 @@ class BuiltinLower : public IRMutator {
}
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset,
Convert(UInt(64), byte_offset)));
cast(UInt(64), byte_offset)));
CHECK(device_type_.defined()) << "Unknown device type in current IR";
CHECK(device_id_.defined()) << "Unknown device id in current IR";
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId,
Convert(Int(32), device_id_)));
cast(Int(32), device_id_)));
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType,
Convert(Int(32), device_type_)));
cast(Int(32), device_type_)));
return TVMStructGet(Handle(), stack_array_, idx, intrinsic::kArrAddr);
}
// call packed.
......@@ -280,33 +249,6 @@ class BuiltinLower : public IRMutator {
Int(32), intrinsic::tvm_call_packed_lowered,
packed_args, Call::Intrinsic);
}
// tvm_access_ptr
Expr MakeAccessPtr(const Call* op, const Expr& e) {
// Specially handle the buffer packed intrinsic
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Call>();
CHECK_EQ(op->args.size(), 5U);
Type dtype = op->args[0].type();
const Variable* buffer = op->args[1].as<Variable>();
Expr offset = op->args[2];
auto it = storage_info_.find(buffer);
if (it != storage_info_.end() && it->second.scope.tag.length() != 0) {
return MakeTaggedAccessPtr(
op->type, dtype, offset,
it->second.info.defined() ? it->second.info->unit_bits : 8);
}
CHECK(op->type.is_handle());
// Change to address_of
return AddressOffset(Var(op->args[1].node_), dtype, offset);
}
Expr MakeTaggedAccessPtr(Type ptr_type, Type dtype,
Expr offset, int unit_bits) {
int dtype_bits = dtype.bits() * dtype.lanes();
CHECK_EQ(unit_bits % dtype_bits, 0);
return Convert(ptr_type,
ir::Simplify(offset / make_const(offset.type(), unit_bits / dtype_bits)));
}
private:
bool IsArrayHandle(const Expr& arg) {
......@@ -337,17 +279,6 @@ class BuiltinLower : public IRMutator {
uint64_t max_shape_stack_{0};
uint64_t max_array_stack_{0};
uint64_t max_arg_stack_{0};
// The storage entry.
struct StorageEntry {
// Whether it is tagged memory.
StorageScope scope;
// The memory info if any.
MemoryInfo info;
// Allocation counter
int alloc_count{0};
};
// The storage scope of each buffer
std::unordered_map<const Variable*, StorageEntry> storage_info_;
};
LoweredFunc LowerTVMBuiltin(LoweredFunc f) {
......
......@@ -2,7 +2,12 @@
* Copyright (c) 2017 by Contributors
* \file storage_access.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/target_info.h>
#include "./ir_util.h"
#include "./storage_access.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace ir {
......@@ -191,5 +196,110 @@ StorageScope StorageAccessVisitor::GetScope(const Variable* buf) const {
if (it == storage_scope_.end()) return s;
return it->second;
}
class StorageAccessInfoLower : public IRMutator {
public:
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
// Lower allocate to device allocate when needed.
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>();
// For special (tagged) memory, remove the allocation or attach the head address.
auto it = storage_info_.find(op->buffer_var.get());
if (it != storage_info_.end() && it->second.info.defined()) {
const MemoryInfo& info = it->second.info;
++it->second.alloc_count;
CHECK_LE(it->second.alloc_count, 1)
<< "Double allocation of " << it->second.scope.to_string();
if (info->head_address.defined()) {
return Allocate::make(
op->buffer_var, op->type, op->extents, op->condition,
op->body, info->head_address, "nop");
}
return op->body;
} else {
return stmt;
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
StorageEntry e;
e.scope = scope;
if (scope.tag.length() != 0) {
e.info = GetMemoryInfo(op->value.as<StringImm>()->value);
CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string();
}
storage_info_[buf] = e;
return IRMutator::Mutate_(op, s);
} else {
return IRMutator::Mutate_(op, s);
}
}
Expr Mutate_(const Call* op, const Expr &e) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
return MakeAccessPtr(op, e);
} else {
return IRMutator::Mutate_(op, e);
}
}
private:
// tvm_access_ptr
Expr MakeAccessPtr(const Call* op, const Expr& e) {
// Special handling of the tvm_access_ptr intrinsic.
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Call>();
CHECK_EQ(op->args.size(), 5U);
Type dtype = op->args[0].type();
const Variable* buffer = op->args[1].as<Variable>();
Var buffer_var(op->args[1].node_);
Expr offset = op->args[2];
auto it = storage_info_.find(buffer);
if (it != storage_info_.end() && it->second.info.defined()) {
return MakeTaggedAccessPtr(
op->type, buffer_var, dtype, offset,
it->second.info);
}
CHECK(op->type.is_handle());
// Change to address_of
return AddressOffset(buffer_var, dtype, offset);
}
Expr MakeTaggedAccessPtr(Type ptr_type,
Var buffer_var,
Type dtype,
Expr offset,
const MemoryInfo& info) {
if (ptr_type.is_handle()) {
CHECK(info->head_address.defined())
<< buffer_var << " is not addressable.";
return AddressOffset(buffer_var, dtype, offset);
}
int dtype_bits = dtype.bits() * dtype.lanes();
CHECK_EQ(info->unit_bits % dtype_bits, 0);
return cast(ptr_type,
ir::Simplify(offset / make_const(
offset.type(), info->unit_bits / dtype_bits)));
}
// The storage entry.
struct StorageEntry {
// The storage scope; a non-empty tag marks tagged memory.
StorageScope scope;
// The memory info if any.
MemoryInfo info;
// Allocation counter
int alloc_count{0};
};
// The storage scope of each buffer
std::unordered_map<const Variable*, StorageEntry> storage_info_;
};
Stmt LowerStorageAccessInfo(Stmt stmt) {
return StorageAccessInfoLower().Mutate(stmt);
}
} // namespace ir
} // namespace tvm
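For non-handle pointer types, MakeTaggedAccessPtr rescales the element offset carried by tvm_access_ptr into units of the scope's unit_bits. A small worked sketch of that arithmetic in plain Python, with assumed values (unit_bits = 256 is hypothetical, not taken from this diff):

# float32 with one lane
dtype_bits = 32 * 1
unit_bits = 256                       # assumed value from a registered MemoryInfo
assert unit_bits % dtype_bits == 0    # mirrors the CHECK_EQ in MakeTaggedAccessPtr
element_offset = 64                   # the offset argument of tvm_access_ptr
unit_offset = element_offset // (unit_bits // dtype_bits)   # 64 // 8 == 8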
......@@ -86,7 +86,6 @@ class StorageFlattener : public IRMutator {
return this->Mutate(op->body);
} else {
// create a buffer entry
// TODO(tqchen) allow permutation and inference of index dimension.
BufferEntry e;
e.bounds = op->bounds;
Array<Expr> shape;
......
......@@ -153,7 +153,6 @@ class ThreadSyncInserter : public IRMutator {
Stmt Mutate(Stmt stmt) final {
if (syncs_.size() == 0) return stmt;
stmt = IRMutator::Mutate(stmt);
if (syncs_.count(stmt.get())) {
Stmt barrier;
if (sync_scope_.rank == 0) {
......@@ -164,7 +163,11 @@ class ThreadSyncInserter : public IRMutator {
{StringImm::make(sync_scope_.to_string())},
Call::Intrinsic));
}
// Mutate after the query, so the original stmt pointer is used as the lookup key.
stmt = IRMutator::Mutate(stmt);
stmt = Block::make(barrier, stmt);
} else {
stmt = IRMutator::Mutate(stmt);
}
return stmt;
}
......@@ -296,201 +299,5 @@ LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) {
return LoweredFunc(n);
}
// Visitor to find touched set by co-processor scope.
class CoProcTouchedBuffer : public IRVisitor {
public:
void Visit_(const Load* op) final {
if (in_scope_) {
touched_.insert(op->buffer_var.get());
}
IRVisitor::Visit_(op);
}
void Visit_(const Store* op) final {
if (in_scope_) {
touched_.insert(op->buffer_var.get());
}
IRVisitor::Visit_(op);
}
void Visit_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr) && in_scope_) {
const Variable* buffer = op->args[1].as<Variable>();
touched_.insert(buffer);
}
IRVisitor::Visit_(op);
}
void Visit_(const AttrStmt* op) final {
if (op->attr_key == attr::coproc_scope && !in_scope_) {
in_scope_ = true;
IterVar iv(op->node.node_);
coproc_.insert(iv);
IRVisitor::Visit_(op);
in_scope_ = false;
} else {
IRVisitor::Visit_(op);
}
}
std::unordered_set<const Variable*> touched_;
std::unordered_set<IterVar> coproc_;
private:
bool in_scope_{false};
};
// Synchronization planning with co-processor.
class CoProcSyncPlanner : public StorageAccessVisitor {
public:
void Plan(const Stmt& stmt) {
CoProcTouchedBuffer visitor;
visitor.Visit(stmt);
touched_ = std::move(visitor.touched_);
if (!touched_.empty()) {
this->Visit(stmt);
PlanWriteSync(scope_.back(), nullptr, true);
CHECK_EQ(visitor.coproc_.size(), 1U);
if (write_sync_.size() == 0) {
write_sync_[stmt.get()] = GetWriteSync(
(*visitor.coproc_.begin())->var->name_hint + ".coproc_sync");
}
}
}
// Write synchronization to be inserted before or after stmt.
std::unordered_map<const Node*, std::vector<Stmt> > write_sync_;
protected:
bool Enabled(const Variable* buf,
const StorageScope& scope) const final {
return touched_.count(buf) && scope == global_scope_;
}
// Plan the sync
std::vector<AccessEntry> Summarize(
std::vector<StmtEntry> seq, const For* loop) final {
return PlanWriteSync(seq, loop, false);
}
private:
// Plan write synchronization if write is not coherent
std::vector<AccessEntry> PlanWriteSync(
std::vector<StmtEntry> seq, const For* loop,
bool force_sync_at_end) {
// detect write barriers
// access by the co-processor.
std::vector<AccessEntry> co_access;
bool contain_sync = false;
auto find_conflict = [&](const AccessEntry& acc) {
for (const AccessEntry& x : co_access) {
if (x.buffer.same_as(acc.buffer) &&
((acc.type == kRead && x.type == kWrite) ||
acc.type == kWrite)) {
return true;
}
}
return false;
};
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
bool sync_write = false;
for (const AccessEntry& acc : s.access) {
if (acc.threads.size() == 0 && find_conflict(acc)) {
sync_write = true; break;
}
if (acc.type == kSync) {
co_access.clear();
contain_sync = true;
}
}
if (sync_write) {
CHECK_NE(i, 0U);
write_sync_[seq[i - 1].stmt] = GetWriteSync(co_access);
co_access.clear();
contain_sync = true;
}
for (const AccessEntry& acc : s.access) {
if (acc.threads.size() != 0) {
co_access.push_back(acc);
}
}
}
bool sync_at_end = force_sync_at_end;
if (loop != nullptr && !sync_at_end) {
// loop carray dependency
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
for (const AccessEntry& acc : s.access) {
if (acc.threads.size() == 0 && find_conflict(acc)) {
sync_at_end = true; break;
}
}
if (write_sync_.count(s.stmt) || sync_at_end) break;
}
}
if (sync_at_end && co_access.size() != 0) {
CHECK_NE(seq.size(), 0);
contain_sync = true;
write_sync_[seq.back().stmt] = GetWriteSync(co_access);
co_access.clear();
}
if (contain_sync) {
AccessEntry e;
e.type = kSync;
e.scope = global_scope_;
co_access.insert(co_access.begin(), e);
}
return co_access;
}
// Add write Synchronization
std::vector<Stmt> GetWriteSync(const std::vector<AccessEntry>& co_access) {
// Does not consider memory coherence, need runtime.
CHECK_NE(co_access.size(), 0U);
CHECK_EQ(co_access[0].threads.size(), 1U);
return GetWriteSync(co_access[0].threads[0]->var->name_hint + ".coproc_sync");
}
std::vector<Stmt> GetWriteSync(std::string sync_name) {
std::vector<Stmt> stmts;
stmts.emplace_back(
Evaluate::make(Call::make(
Int(32),
sync_name,
{}, Call::Intrinsic)));
return stmts;
}
std::unordered_set<const Variable*> touched_;
StorageScope global_scope_ = StorageScope::make("global");
};
class CoProcSyncInserter : public IRMutator {
public:
explicit CoProcSyncInserter(
const std::unordered_map<const Node*, std::vector<Stmt> >& write_sync)
: write_sync_(write_sync) {}
Stmt Mutate(Stmt stmt) final {
stmt = IRMutator::Mutate(stmt);
auto it = write_sync_.find(stmt.get());
if (it != write_sync_.end()) {
stmt = Block::make(stmt, MergeSeq(it->second));
}
return stmt;
}
private:
const std::unordered_map<const Node*, std::vector<Stmt> >& write_sync_;
};
Stmt CoProcSync(Stmt stmt) {
CoProcSyncPlanner planner;
planner.Plan(stmt);
if (planner.write_sync_.size() != 0) {
return CoProcSyncInserter(planner.write_sync_).Mutate(stmt);
} else {
return stmt;
}
}
} // namespace ir
} // namespace tvm
......@@ -32,16 +32,31 @@ def test_coproc_sync():
ib = tvm.ir_builder.create()
n = tvm.var("n")
cp = tvm.thread_axis((0, 1), "cop")
A = ib.allocate("float32", n, name="A", scope="global")
@tvm.register_func("tvm.info.mem.global.cache")
def meminfo_cache():
return tvm.make.node(
"MemoryInfo",
unit_bits=8,
max_simd_bits=32,
max_num_bits=128,
head_address=tvm.call_extern("handle", "global_cache"))
A = ib.allocate("float32", 128, name="A", scope="global.cache")
with ib.for_range(0, n, name="i") as i:
A[i] = A[i] + 1
with ib.for_range(0, 10, name="j") as j:
ib.scope_attr(cp, "coproc_scope", 1)
A[j] = A[j] + 2
body = ib.get()
body = tvm.ir_pass.CoProcSync(body)
body = body.body.body.body
assert(tvm.make.stmt_list(body)[-1].value.name == "cop.coproc_sync")
with ib.for_range(0, 8, name="k") as k:
with ib.for_range(0, 10, name="j") as j:
ib.scope_attr(cp, "coproc_scope", 1)
A[j] = A[j + k * 10] + 2
stmt = ib.get()
stmt = tvm.ir_pass.CoProcSync(stmt)
body = stmt.body.body.body
blist = tvm.make.stmt_list(body)
assert(blist[1].value.name == "cop.coproc_read_barrier")
assert(blist[1].value.args[3].value == 80)
assert(blist[-2].value.name == "cop.coproc_sync")
assert(blist[-1].value.name == "cop.coproc_write_barrier")
assert(blist[-1].value.args[3].value == 10)
if __name__ == "__main__":
......