Commit 79e482bc by Tianqi Chen Committed by GitHub

[PASS] Memory barrier detection, storage access lower. (#317)

parent afa20869
......@@ -267,6 +267,15 @@ Stmt CoProcSync(Stmt stmt);
Stmt LiftAttrScope(Stmt stmt, std::string attr_key);
/*!
* \brief Lower attached storage access information.
* Do this pass after all storage access analysis finish.
*
* \param stmt The stmt to be trasnformed
* \return Transformed stmt.
*/
Stmt LowerStorageAccessInfo(Stmt stmt);
/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
......
......@@ -23,11 +23,17 @@ struct MemoryInfoNode : public Node {
int max_num_bits;
/*! \brief maximum number of bits to be used in simd op */
int max_simd_bits;
/*!
* \brief head address of the buffer, if visible to CPU
* This address can be None.
*/
Expr head_address;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("unit_bits", &unit_bits);
v->Visit("max_num_bits", &max_num_bits);
v->Visit("max_simd_bits", &max_simd_bits);
v->Visit("head_address", &head_address);
}
static constexpr const char* _type_key = "MemoryInfo";
......
......@@ -197,7 +197,6 @@ def lower(sch,
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.StorageRewrite(stmt)
stmt = ir_pass.CoProcSync(stmt)
cfg = BuildConfig.current
stmt = ir_pass.UnrollLoop(
stmt,
......@@ -210,6 +209,7 @@ def lower(sch,
stmt = ir_pass.Simplify(stmt)
if simple_mode:
return stmt
stmt = ir_pass.LowerStorageAccessInfo(stmt)
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
......
......@@ -95,6 +95,7 @@ REGISTER_PASS2(BindDeviceType);
REGISTER_PASS1(SplitHostDevice);
REGISTER_PASS1(StorageRewrite);
REGISTER_PASS1(CoProcSync);
REGISTER_PASS1(LowerStorageAccessInfo);
REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(InjectPrefetch);
REGISTER_PASS1(LoopPartition);
......
......@@ -6,16 +6,13 @@
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/target_info.h>
#include <unordered_set>
#include "./ir_util.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
using runtime::StorageScope;
inline Expr ConstInt32(size_t index) {
CHECK_LE(index, std::numeric_limits<int>::max());
return make_const(Int(32), static_cast<int>(index));
......@@ -69,14 +66,7 @@ class BuiltinLower : public IRMutator {
// Lower allocate to device allocate when needed.
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>();
// For special memory, remove allocate.
auto it = storage_info_.find(op->buffer_var.get());
if (it != storage_info_.end() && it->second.scope.tag.length() != 0) {
++it->second.alloc_count;
CHECK_LE(it->second.alloc_count, 1)
<< "Double allocation of " << it->second.scope.to_string();
return op->body;
}
if (op->new_expr.defined()) return stmt;
// Get constant allocation bound.
int64_t dev_type;
int64_t nbytes = GetVectorBytes(op->type);
......@@ -139,25 +129,12 @@ class BuiltinLower : public IRMutator {
CHECK(!device_type_.defined());
device_type_ = op->value;
return Mutate(op->body);
} else if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
StorageEntry e;
e.scope = scope;
if (scope.tag.length() != 0) {
e.info = GetMemoryInfo(op->value.as<StringImm>()->value);
CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string();
}
storage_info_[buf] = e;
return IRMutator::Mutate_(op, s);
} else {
return IRMutator::Mutate_(op, s);
}
}
Expr Mutate_(const Call* op, const Expr &e) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
return MakeAccessPtr(op, e);
} else if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
return MakeCallPacked(op, e);
} else if (op->is_intrinsic(intrinsic::tvm_stack_make_shape)) {
return MakeShape(op, e);
......@@ -167,14 +144,6 @@ class BuiltinLower : public IRMutator {
return IRMutator::Mutate_(op, e);
}
}
Expr Convert(Type t, Expr e) {
if (e.type() != t) {
return Cast::make(t, e);
} else {
return e;
}
}
// call shape
Expr MakeShape(const Call* op, const Expr& e) {
size_t stack_begin = run_shape_stack_;
......@@ -183,7 +152,7 @@ class BuiltinLower : public IRMutator {
op = expr.as<Call>();
for (size_t i = 0; i < op->args.size(); ++i) {
prep_seq_.emplace_back(
Store::make(stack_shape_, Convert(Int(64), op->args[i]),
Store::make(stack_shape_, cast(Int(64), op->args[i]),
ConstInt32(stack_begin +i), const_true(1)));
}
return AddressOffset(stack_shape_, Int(64), stack_begin);
......@@ -224,15 +193,15 @@ class BuiltinLower : public IRMutator {
}
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset,
Convert(UInt(64), byte_offset)));
cast(UInt(64), byte_offset)));
CHECK(device_type_.defined()) << "Unknown device type in current IR";
CHECK(device_id_.defined()) << "Unknown device id in current IR";
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId,
Convert(Int(32), device_id_)));
cast(Int(32), device_id_)));
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType,
Convert(Int(32), device_type_)));
cast(Int(32), device_type_)));
return TVMStructGet(Handle(), stack_array_, idx, intrinsic::kArrAddr);
}
// call packled.
......@@ -280,33 +249,6 @@ class BuiltinLower : public IRMutator {
Int(32), intrinsic::tvm_call_packed_lowered,
packed_args, Call::Intrinsic);
}
// tvm_access_ptr
Expr MakeAccessPtr(const Call* op, const Expr& e) {
// Specially handle the buffer packed intrinsic
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Call>();
CHECK_EQ(op->args.size(), 5U);
Type dtype = op->args[0].type();
const Variable* buffer = op->args[1].as<Variable>();
Expr offset = op->args[2];
auto it = storage_info_.find(buffer);
if (it != storage_info_.end() && it->second.scope.tag.length() != 0) {
return MakeTaggedAccessPtr(
op->type, dtype, offset,
it->second.info.defined() ? it->second.info->unit_bits : 8);
}
CHECK(op->type.is_handle());
// Change to address_of
return AddressOffset(Var(op->args[1].node_), dtype, offset);
}
Expr MakeTaggedAccessPtr(Type ptr_type, Type dtype,
Expr offset, int unit_bits) {
int dtype_bits = dtype.bits() * dtype.lanes();
CHECK_EQ(unit_bits % dtype_bits, 0);
return Convert(ptr_type,
ir::Simplify(offset / make_const(offset.type(), unit_bits / dtype_bits)));
}
private:
bool IsArrayHandle(const Expr& arg) {
......@@ -337,17 +279,6 @@ class BuiltinLower : public IRMutator {
uint64_t max_shape_stack_{0};
uint64_t max_array_stack_{0};
uint64_t max_arg_stack_{0};
// The storage entry.
struct StorageEntry {
// Whether it is tagged memory.
StorageScope scope;
// The memory info if any.
MemoryInfo info;
// Allocation counter
int alloc_count{0};
};
// The storage scope of each buffer
std::unordered_map<const Variable*, StorageEntry> storage_info_;
};
LoweredFunc LowerTVMBuiltin(LoweredFunc f) {
......
......@@ -2,7 +2,12 @@
* Copyright (c) 2017 by Contributors
* \file storage_access.cc
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/target_info.h>
#include "./ir_util.h"
#include "./storage_access.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace ir {
......@@ -191,5 +196,110 @@ StorageScope StorageAccessVisitor::GetScope(const Variable* buf) const {
if (it == storage_scope_.end()) return s;
return it->second;
}
class StorageAccessInfoLower : public IRMutator {
public:
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
// Lower allocate to device allocate when needed.
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>();
// For special memory, remove allocate, or use head expr
auto it = storage_info_.find(op->buffer_var.get());
if (it != storage_info_.end() && it->second.info.defined()) {
const MemoryInfo& info = it->second.info;
++it->second.alloc_count;
CHECK_LE(it->second.alloc_count, 1)
<< "Double allocation of " << it->second.scope.to_string();
if (info->head_address.defined()) {
return Allocate::make(
op->buffer_var, op->type, op->extents, op->condition,
op->body, info->head_address, "nop");
}
return op->body;
} else {
return stmt;
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
StorageEntry e;
e.scope = scope;
if (scope.tag.length() != 0) {
e.info = GetMemoryInfo(op->value.as<StringImm>()->value);
CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string();
}
storage_info_[buf] = e;
return IRMutator::Mutate_(op, s);
} else {
return IRMutator::Mutate_(op, s);
}
}
Expr Mutate_(const Call* op, const Expr &e) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
return MakeAccessPtr(op, e);
} else {
return IRMutator::Mutate_(op, e);
}
}
private:
// tvm_access_ptr
Expr MakeAccessPtr(const Call* op, const Expr& e) {
// Specially handle the buffer packed intrinsic
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Call>();
CHECK_EQ(op->args.size(), 5U);
Type dtype = op->args[0].type();
const Variable* buffer = op->args[1].as<Variable>();
Var buffer_var(op->args[1].node_);
Expr offset = op->args[2];
auto it = storage_info_.find(buffer);
if (it != storage_info_.end() && it->second.info.defined()) {
return MakeTaggedAccessPtr(
op->type, buffer_var, dtype, offset,
it->second.info);
}
CHECK(op->type.is_handle());
// Change to address_of
return AddressOffset(buffer_var, dtype, offset);
}
Expr MakeTaggedAccessPtr(Type ptr_type,
Var buffer_var,
Type dtype,
Expr offset,
const MemoryInfo& info) {
if (ptr_type.is_handle()) {
CHECK(info->head_address.defined())
<< buffer_var << " is not adddressable.";
return AddressOffset(buffer_var, dtype, offset);
}
int dtype_bits = dtype.bits() * dtype.lanes();
CHECK_EQ(info->unit_bits % dtype_bits, 0);
return cast(ptr_type,
ir::Simplify(offset / make_const(
offset.type(), info->unit_bits / dtype_bits)));
}
// The storage entry.
struct StorageEntry {
// Whether it is tagged memory.
StorageScope scope;
// The memory info if any.
MemoryInfo info;
// Allocation counter
int alloc_count{0};
};
// The storage scope of each buffer
std::unordered_map<const Variable*, StorageEntry> storage_info_;
};
Stmt LowerStorageAccessInfo(Stmt stmt) {
return StorageAccessInfoLower().Mutate(stmt);
}
} // namespace ir
} // namespace tvm
......@@ -86,7 +86,6 @@ class StorageFlattener : public IRMutator {
return this->Mutate(op->body);
} else {
// create a buffer entry
// TODO(tqchen) allow permutation and inference of index dimension.
BufferEntry e;
e.bounds = op->bounds;
Array<Expr> shape;
......
......@@ -153,7 +153,6 @@ class ThreadSyncInserter : public IRMutator {
Stmt Mutate(Stmt stmt) final {
if (syncs_.size() == 0) return stmt;
stmt = IRMutator::Mutate(stmt);
if (syncs_.count(stmt.get())) {
Stmt barrier;
if (sync_scope_.rank == 0) {
......@@ -164,7 +163,11 @@ class ThreadSyncInserter : public IRMutator {
{StringImm::make(sync_scope_.to_string())},
Call::Intrinsic));
}
// Mutate after query, to avoid stmt change.
stmt = IRMutator::Mutate(stmt);
stmt = Block::make(barrier, stmt);
} else {
stmt = IRMutator::Mutate(stmt);
}
return stmt;
}
......@@ -296,201 +299,5 @@ LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) {
return LoweredFunc(n);
}
// Visitor to find touched set by co-processor scope.
class CoProcTouchedBuffer : public IRVisitor {
public:
void Visit_(const Load* op) final {
if (in_scope_) {
touched_.insert(op->buffer_var.get());
}
IRVisitor::Visit_(op);
}
void Visit_(const Store* op) final {
if (in_scope_) {
touched_.insert(op->buffer_var.get());
}
IRVisitor::Visit_(op);
}
void Visit_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr) && in_scope_) {
const Variable* buffer = op->args[1].as<Variable>();
touched_.insert(buffer);
}
IRVisitor::Visit_(op);
}
void Visit_(const AttrStmt* op) final {
if (op->attr_key == attr::coproc_scope && !in_scope_) {
in_scope_ = true;
IterVar iv(op->node.node_);
coproc_.insert(iv);
IRVisitor::Visit_(op);
in_scope_ = false;
} else {
IRVisitor::Visit_(op);
}
}
std::unordered_set<const Variable*> touched_;
std::unordered_set<IterVar> coproc_;
private:
bool in_scope_{false};
};
// Synchronization planning with co-processor.
class CoProcSyncPlanner : public StorageAccessVisitor {
public:
void Plan(const Stmt& stmt) {
CoProcTouchedBuffer visitor;
visitor.Visit(stmt);
touched_ = std::move(visitor.touched_);
if (!touched_.empty()) {
this->Visit(stmt);
PlanWriteSync(scope_.back(), nullptr, true);
CHECK_EQ(visitor.coproc_.size(), 1U);
if (write_sync_.size() == 0) {
write_sync_[stmt.get()] = GetWriteSync(
(*visitor.coproc_.begin())->var->name_hint + ".coproc_sync");
}
}
}
// Write synchronization to be inserted before or after stmt.
std::unordered_map<const Node*, std::vector<Stmt> > write_sync_;
protected:
bool Enabled(const Variable* buf,
const StorageScope& scope) const final {
return touched_.count(buf) && scope == global_scope_;
}
// Plan the sync
std::vector<AccessEntry> Summarize(
std::vector<StmtEntry> seq, const For* loop) final {
return PlanWriteSync(seq, loop, false);
}
private:
// Plan write synchronization if write is not coherent
std::vector<AccessEntry> PlanWriteSync(
std::vector<StmtEntry> seq, const For* loop,
bool force_sync_at_end) {
// detect write barriers
// access by the co-processor.
std::vector<AccessEntry> co_access;
bool contain_sync = false;
auto find_conflict = [&](const AccessEntry& acc) {
for (const AccessEntry& x : co_access) {
if (x.buffer.same_as(acc.buffer) &&
((acc.type == kRead && x.type == kWrite) ||
acc.type == kWrite)) {
return true;
}
}
return false;
};
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
bool sync_write = false;
for (const AccessEntry& acc : s.access) {
if (acc.threads.size() == 0 && find_conflict(acc)) {
sync_write = true; break;
}
if (acc.type == kSync) {
co_access.clear();
contain_sync = true;
}
}
if (sync_write) {
CHECK_NE(i, 0U);
write_sync_[seq[i - 1].stmt] = GetWriteSync(co_access);
co_access.clear();
contain_sync = true;
}
for (const AccessEntry& acc : s.access) {
if (acc.threads.size() != 0) {
co_access.push_back(acc);
}
}
}
bool sync_at_end = force_sync_at_end;
if (loop != nullptr && !sync_at_end) {
// loop carray dependency
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
for (const AccessEntry& acc : s.access) {
if (acc.threads.size() == 0 && find_conflict(acc)) {
sync_at_end = true; break;
}
}
if (write_sync_.count(s.stmt) || sync_at_end) break;
}
}
if (sync_at_end && co_access.size() != 0) {
CHECK_NE(seq.size(), 0);
contain_sync = true;
write_sync_[seq.back().stmt] = GetWriteSync(co_access);
co_access.clear();
}
if (contain_sync) {
AccessEntry e;
e.type = kSync;
e.scope = global_scope_;
co_access.insert(co_access.begin(), e);
}
return co_access;
}
// Add write Synchronization
std::vector<Stmt> GetWriteSync(const std::vector<AccessEntry>& co_access) {
// Does not consider memory coherence, need runtime.
CHECK_NE(co_access.size(), 0U);
CHECK_EQ(co_access[0].threads.size(), 1U);
return GetWriteSync(co_access[0].threads[0]->var->name_hint + ".coproc_sync");
}
std::vector<Stmt> GetWriteSync(std::string sync_name) {
std::vector<Stmt> stmts;
stmts.emplace_back(
Evaluate::make(Call::make(
Int(32),
sync_name,
{}, Call::Intrinsic)));
return stmts;
}
std::unordered_set<const Variable*> touched_;
StorageScope global_scope_ = StorageScope::make("global");
};
class CoProcSyncInserter : public IRMutator {
public:
explicit CoProcSyncInserter(
const std::unordered_map<const Node*, std::vector<Stmt> >& write_sync)
: write_sync_(write_sync) {}
Stmt Mutate(Stmt stmt) final {
stmt = IRMutator::Mutate(stmt);
auto it = write_sync_.find(stmt.get());
if (it != write_sync_.end()) {
stmt = Block::make(stmt, MergeSeq(it->second));
}
return stmt;
}
private:
const std::unordered_map<const Node*, std::vector<Stmt> >& write_sync_;
};
Stmt CoProcSync(Stmt stmt) {
CoProcSyncPlanner planner;
planner.Plan(stmt);
if (planner.write_sync_.size() != 0) {
return CoProcSyncInserter(planner.write_sync_).Mutate(stmt);
} else {
return stmt;
}
}
} // namespace ir
} // namespace tvm
......@@ -32,16 +32,31 @@ def test_coproc_sync():
ib = tvm.ir_builder.create()
n = tvm.var("n")
cp = tvm.thread_axis((0, 1), "cop")
A = ib.allocate("float32", n, name="A", scope="global")
@tvm.register_func("tvm.info.mem.global.cache")
def meminfo_cache():
return tvm.make.node(
"MemoryInfo",
unit_bits=8,
max_simd_bits=32,
max_num_bits=128,
head_address=tvm.call_extern("handle", "global_cache"))
A = ib.allocate("float32", 128, name="A", scope="global.cache")
with ib.for_range(0, n, name="i") as i:
A[i] = A[i] + 1
with ib.for_range(0, 10, name="j") as j:
ib.scope_attr(cp, "coproc_scope", 1)
A[j] = A[j] + 2
body = ib.get()
body = tvm.ir_pass.CoProcSync(body)
body = body.body.body.body
assert(tvm.make.stmt_list(body)[-1].value.name == "cop.coproc_sync")
with ib.for_range(0, 8, name="k") as k:
with ib.for_range(0, 10, name="j") as j:
ib.scope_attr(cp, "coproc_scope", 1)
A[j] = A[j + k * 10] + 2
stmt = ib.get()
stmt = tvm.ir_pass.CoProcSync(stmt)
body = stmt.body.body.body
blist = tvm.make.stmt_list(body)
assert(blist[1].value.name == "cop.coproc_read_barrier")
assert(blist[1].value.args[3].value == 80)
assert(blist[-2].value.name == "cop.coproc_sync")
assert(blist[-1].value.name == "cop.coproc_write_barrier")
assert(blist[-1].value.args[3].value == 10)
if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment