Commit 6588662f by Tianqi Chen, committed by GitHub

[RUNTIME] More reliable thread enumeration (#1017)

parent fcd32e9b
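The core of this change is replacing the raw integer rank stored in StorageScope with a typed StorageRank enum, so scope checks name the rank they compare against instead of a magic number. The following is a minimal, self-contained sketch of that before/after pattern, simplified from the real tvm/runtime headers (the actual struct carries more members and helpers); it only illustrates the comparison style the diff applies throughout the codegens and passes.

// Minimal sketch, assuming a simplified StorageScope; not TVM's actual header.
#include <iostream>
#include <string>

enum class StorageRank { kGlobal = 0, kShared = 1, kWarp = 2, kLocal = 3 };

struct StorageScope {
  StorageRank rank{StorageRank::kGlobal};  // previously: int rank{0};
  std::string tag;                         // suffix for special-purpose memory
};

int main() {
  StorageScope scope;
  scope.rank = StorageRank::kShared;
  // Before this commit the check read: if (scope.rank == 1) { ... }
  if (scope.rank == StorageRank::kShared) {
    std::cout << "allocate in shared memory" << std::endl;
  }
  return 0;
}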
......@@ -44,7 +44,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
if (info.alignment > 16) {
info.alignment = 16;
}
if (info.scope.rank == 2) {
if (info.scope.rank == runtime::StorageRank::kLocal) {
// const int local_address_space = 5;
// TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = builder_->CreateAlloca(
......@@ -54,7 +54,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
}
buf = alloca;
} else {
CHECK_EQ(info.scope.rank, 1)
CHECK(info.scope.rank == runtime::StorageRank::kShared)
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
const unsigned shared_address_space = 3;
......
......@@ -47,7 +47,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
if (info.alignment > 16) {
info.alignment = 16;
}
if (info.scope.rank == 2) {
if (info.scope.rank == runtime::StorageRank::kLocal) {
// const int local_address_space = 5;
// TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = builder_->CreateAlloca(
......@@ -57,7 +57,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
}
buf = alloca;
} else {
CHECK_EQ(info.scope.rank, 1)
CHECK(info.scope.rank == runtime::StorageRank::kShared)
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
const unsigned shared_address_space = 3;
......
......@@ -561,13 +561,13 @@ void CodeGenSPIRV::VisitStmt_(const Allocate* op) {
spirv::Value buf;
StorageInfo& info = storage_info_[op->buffer_var.get()];
spirv::SType etype = builder_->GetSType(op->type);
if (info.scope.rank == 2) {
if (info.scope.rank == runtime::StorageRank::kLocal) {
buf = builder_->Allocate(
etype, static_cast<uint32_t>(constant_size),
spv::StorageClassFunction);
} else {
// shared memory
CHECK_EQ(info.scope.rank, 1)
CHECK(info.scope.rank == runtime::StorageRank::kShared)
<< "Can only allocate shared or local memory inside kernel";
// Shared memory
buf = builder_->Allocate(
......
......@@ -210,7 +210,8 @@ void StorageAccessVisitor::Visit_(const Call* op) {
StorageScope StorageAccessVisitor::GetScope(const Variable* buf) const {
auto it = storage_scope_.find(buf);
StorageScope s; s.rank = 0;
StorageScope s;
s.rank = StorageRank::kGlobal;
if (it == storage_scope_.end()) return s;
return it->second;
}
......
......@@ -17,6 +17,7 @@ namespace tvm {
namespace ir {
using runtime::StorageScope;
using runtime::StorageRank;
/*!
* \brief Base class of storage access analysis
*/
......
......@@ -23,6 +23,7 @@ namespace tvm {
namespace ir {
using HalideIR::Internal::Region;
using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;
using intrinsic::tvm_address_of;
......@@ -141,7 +142,8 @@ class StorageFlattener : public IRMutator {
const std::string& strkey = it->second;
if (strkey.length() == 0) {
if (curr_thread_scope_.size() != 0) {
skey.rank = curr_thread_scope_.back().rank + 1;
skey.rank = runtime::DefaultStorageRank(
curr_thread_scope_.back().rank);
}
} else {
skey = StorageScope::make(strkey);
......
......@@ -19,6 +19,7 @@
namespace tvm {
namespace ir {
using runtime::StorageRank;
using runtime::StorageScope;
// Find a linear pattern of storage access
......@@ -794,7 +795,7 @@ class StoragePlanRewriter : public IRMutator {
// disable reuse of small arrays, they will be lowered to registers in LLVM
// These rules only apply if we are using non-special memory
if (scope.tag.length() == 0) {
if (scope.rank > 1 || op->type.is_handle()) {
if (scope.rank >= StorageRank::kWarp || op->type.is_handle()) {
return NewAlloc(op, attach_scope, scope, const_nbits);
}
if (const_nbits > 0 && const_nbits <= 32) {
......@@ -853,7 +854,8 @@ class StoragePlanRewriter : public IRMutator {
// These rules only apply if we are using non-special memory
if (e->scope.tag.length() == 0) {
// Disable sharing of local memory.
if (e->scope.rank > 1 || e->allocs[0]->type.is_handle()) return;
if (e->scope.rank >= StorageRank::kWarp ||
e->allocs[0]->type.is_handle()) return;
// disable reuse of small arrays
if (e->const_nbits > 0 && e->const_nbits <= 32) return;
}
......
......@@ -189,7 +189,7 @@ class ThreadSyncInserter : public IRMutator {
if (syncs_.size() == 0) return stmt;
if (syncs_.count(stmt.get())) {
Stmt barrier;
if (sync_scope_.rank == 0) {
if (sync_scope_.rank == StorageRank::kGlobal) {
barrier = MakeGlobalBarrier();
} else {
barrier = Evaluate::make(
......@@ -206,15 +206,15 @@ class ThreadSyncInserter : public IRMutator {
return stmt;
}
Expr Mutate_(const Load* op, const Expr& e) final {
if (sync_scope_.rank == 0 &&
GetScope(op->buffer_var.get()).rank == 0) {
if (sync_scope_.rank == StorageRank::kGlobal &&
GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
++rw_stats_[op->buffer_var].read_count;
}
return IRMutator::Mutate_(op, e);
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
if (sync_scope_.rank == 0 &&
GetScope(op->buffer_var.get()).rank == 0) {
if (sync_scope_.rank == StorageRank::kGlobal &&
GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
++rw_stats_[op->buffer_var].write_count;
}
return IRMutator::Mutate_(op, s);
......@@ -228,7 +228,7 @@ class ThreadSyncInserter : public IRMutator {
thread_extents_.pop_back();
std::swap(temp, in_thread_env_);
// first thread scope.
if (!in_thread_env_ && sync_scope_.rank == 0) {
if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) {
ret = InitGlobalBarrier(ret.as<AttrStmt>());
num_blocks_ = Expr();
is_lead_ = Expr();
......@@ -253,7 +253,8 @@ class ThreadSyncInserter : public IRMutator {
// Get current storage scope.
StorageScope GetScope(const Variable* buf) const {
auto it = storage_scope_.find(buf);
StorageScope s; s.rank = 0;
StorageScope s;
s.rank = StorageRank::kGlobal;
if (it == storage_scope_.end()) return s;
return it->second;
}
......@@ -279,7 +280,7 @@ class ThreadSyncInserter : public IRMutator {
return Block::make(prep, body);
}
Stmt MakeGlobalBarrier() {
CHECK_EQ(sync_scope_.rank, 0);
CHECK(sync_scope_.rank == StorageRank::kGlobal);
if (!num_blocks_.defined()) {
CHECK(!is_lead_.defined());
num_work_dim_ = thread_extents_.size();
......
......@@ -13,10 +13,47 @@
namespace tvm {
namespace runtime {
/*!
 * \brief Memory hierarchy rank in the storage system.
 * \note The global and shared ranks have a one-to-one
 *  correspondence with the thread rank.
*/
enum class StorageRank {
/*! \brief global memory */
kGlobal = 0,
/*! \brief shared memory among thread group */
kShared = 1,
/*!
 * \brief Reserved for warp memory.
 *  This rank exists only in the programming model;
 *  GPUs usually have no such physical memory.
 *  Instead, it can be emulated with registers and warp shuffles.
*/
kWarp = 2,
/*! \brief thread local memory */
kLocal = 3
};
/*!
* \param thread_scope_rank The thread scope rank
* \return default storage rank given the thread scope
*/
inline StorageRank DefaultStorageRank(int thread_scope_rank) {
switch (thread_scope_rank) {
case -1: return StorageRank::kGlobal;
case 0: return StorageRank::kShared;
case 1: return StorageRank::kLocal;
default: {
LOG(FATAL) << "unknown rank";
return StorageRank::kGlobal;
}
}
}
/*! \brief class to represent storage scope */
struct StorageScope {
/*! \brief The rank of the storage */
int rank{0};
StorageRank rank{StorageRank::kGlobal};
/*! \brief tag for special purpose memory. */
std::string tag;
// comparator
......@@ -29,9 +66,10 @@ struct StorageScope {
inline std::string to_string() const {
std::string ret;
switch (rank) {
case 0: return "global" + tag;
case 1: return "shared" + tag;
case 2: return "local" + tag;
case StorageRank::kGlobal: return "global" + tag;
case StorageRank::kShared: return "shared" + tag;
case StorageRank::kWarp: return "warp" + tag;
case StorageRank::kLocal: return "local" + tag;
default: LOG(FATAL) << "unknown storage scope"; return "";
}
}
......@@ -43,13 +81,16 @@ struct StorageScope {
static StorageScope make(const std::string& s) {
StorageScope r;
if (s.compare(0, 6, "global") == 0) {
r.rank = 0;
r.rank = StorageRank::kGlobal;
r.tag = s.substr(6, std::string::npos);
} else if (s.compare(0, 6, "shared") == 0) {
r.rank = 1;
r.rank = StorageRank::kShared;
r.tag = s.substr(6, std::string::npos);
} else if (s.compare(0, 4, "warp") == 0) {
r.rank = StorageRank::kWarp;
r.tag = s.substr(4, std::string::npos);
} else if (s.compare(0, 5, "local") == 0) {
r.rank = 2;
r.rank = StorageRank::kLocal;
r.tag = s.substr(5, std::string::npos);
} else {
LOG(FATAL) << "unknown storage scope " << s;
......
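For readers who want to try the new scope strings, here is a hedged, self-contained restatement of the parsing rules in the hunk above, including the added "warp" scope. It is a simplified mirror of StorageScope::make, not the real tvm::runtime API (which LOG(FATAL)s on unknown prefixes), and the ".foo" tag suffix in the usage lines is made up purely for illustration.

// Simplified parser mirroring the hunk above; assumptions noted in comments.
#include <cassert>
#include <string>

enum class StorageRank { kGlobal = 0, kShared = 1, kWarp = 2, kLocal = 3 };

struct StorageScope {
  StorageRank rank{StorageRank::kGlobal};
  std::string tag;
  static StorageScope make(const std::string& s) {
    StorageScope r;
    if (s.compare(0, 6, "global") == 0) {
      r.rank = StorageRank::kGlobal; r.tag = s.substr(6);
    } else if (s.compare(0, 6, "shared") == 0) {
      r.rank = StorageRank::kShared; r.tag = s.substr(6);
    } else if (s.compare(0, 4, "warp") == 0) {
      r.rank = StorageRank::kWarp; r.tag = s.substr(4);
    } else if (s.compare(0, 5, "local") == 0) {
      r.rank = StorageRank::kLocal; r.tag = s.substr(5);
    }  // real code: else LOG(FATAL) << "unknown storage scope " << s;
    return r;
  }
};

int main() {
  assert(StorageScope::make("warp").rank == StorageRank::kWarp);   // new scope string
  assert(StorageScope::make("shared.foo").tag == ".foo");          // ".foo" is a hypothetical tag
  assert(StorageScope::make("local").rank == StorageRank::kLocal);
  return 0;
}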
......@@ -16,8 +16,9 @@
namespace tvm {
namespace schedule {
using runtime::ThreadScope;
using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;
/*! \brief The graph context used during bound inference. */
struct GraphContext {
......@@ -41,7 +42,7 @@ bool NeedRelax(const IterVar& iv,
if (tag.length() == 0 || tag == "pipeline") {
return !found_attach;
}
return scope.rank <= ThreadScope::make(tag).rank;
return static_cast<int>(scope.rank) <= ThreadScope::make(tag).rank;
}
// infer storage scope, if not given
......@@ -50,16 +51,17 @@ StorageScope InferStorageScope(
if (stage->scope.length() != 0) {
return StorageScope::make(stage->scope);
}
int max_rank = 0;
int max_rank = -1;
for (IterVar iv : ctx.attach_path.at(stage->op)) {
auto it = ctx.bind_map.find(iv);
const std::string& tag = (
it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag);
if (tag != "pipeline" && tag.length() != 0) {
max_rank = std::max(max_rank, ThreadScope::make(tag).rank + 1);
max_rank = std::max(max_rank, ThreadScope::make(tag).rank);
}
}
StorageScope s; s.rank = max_rank;
StorageScope s;
s.rank = runtime::DefaultStorageRank(max_rank);
return s;
}
......
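To make the new mapping concrete: with TVM's thread scope convention (blockIdx has thread rank 0, threadIdx rank 1, and -1 means no thread binding), DefaultStorageRank defaults the storage to global, shared, and local scope respectively, which is what the rewritten InferStorageScope above relies on. Below is a standalone sketch of the function added in this commit, runnable outside TVM.

// Standalone copy of DefaultStorageRank plus the assumed thread-rank convention
// (blockIdx -> 0, threadIdx -> 1, unbound -> -1).
#include <iostream>

enum class StorageRank { kGlobal = 0, kShared = 1, kWarp = 2, kLocal = 3 };

inline StorageRank DefaultStorageRank(int thread_scope_rank) {
  switch (thread_scope_rank) {
    case -1: return StorageRank::kGlobal;  // not bound to any thread axis
    case 0:  return StorageRank::kShared;  // bound under blockIdx
    case 1:  return StorageRank::kLocal;   // bound under threadIdx
    default: return StorageRank::kGlobal;  // real code: LOG(FATAL) << "unknown rank"
  }
}

int main() {
  std::cout << static_cast<int>(DefaultStorageRank(-1)) << "\n";  // 0 -> kGlobal
  std::cout << static_cast<int>(DefaultStorageRank(0)) << "\n";   // 1 -> kShared
  std::cout << static_cast<int>(DefaultStorageRank(1)) << "\n";   // 3 -> kLocal
  return 0;
}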