Commit 6588662f by Tianqi Chen (committed by GitHub)

[RUNTIME] More reliable thread enumeration (#1017)

parent fcd32e9b
@@ -44,7 +44,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
     if (info.alignment > 16) {
       info.alignment = 16;
     }
-    if (info.scope.rank == 2) {
+    if (info.scope.rank == runtime::StorageRank::kLocal) {
       // const int local_address_space = 5;
       // TODO(tqchen): for higher version of LLVM, local address space can be set.
       llvm::AllocaInst* alloca = builder_->CreateAlloca(
@@ -54,7 +54,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
       }
       buf = alloca;
     } else {
-      CHECK_EQ(info.scope.rank, 1)
+      CHECK(info.scope.rank == runtime::StorageRank::kShared)
           << "Can only allocate shared or local memory inside kernel";
       // Shared memory: address space == 3
       const unsigned shared_address_space = 3;
......
@@ -47,7 +47,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
     if (info.alignment > 16) {
       info.alignment = 16;
     }
-    if (info.scope.rank == 2) {
+    if (info.scope.rank == runtime::StorageRank::kLocal) {
       // const int local_address_space = 5;
       // TODO(tqchen): for higher version of LLVM, local address space can be set.
       llvm::AllocaInst* alloca = builder_->CreateAlloca(
@@ -57,7 +57,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
       }
       buf = alloca;
     } else {
-      CHECK_EQ(info.scope.rank, 1)
+      CHECK(info.scope.rank == runtime::StorageRank::kShared)
          << "Can only allocate shared or local memory inside kernel";
       // Shared memory: address space == 3
       const unsigned shared_address_space = 3;
......
@@ -561,13 +561,13 @@ void CodeGenSPIRV::VisitStmt_(const Allocate* op) {
   spirv::Value buf;
   StorageInfo& info = storage_info_[op->buffer_var.get()];
   spirv::SType etype = builder_->GetSType(op->type);
-  if (info.scope.rank == 2) {
+  if (info.scope.rank == runtime::StorageRank::kLocal) {
     buf = builder_->Allocate(
         etype, static_cast<uint32_t>(constant_size),
         spv::StorageClassFunction);
   } else {
     // shared memory
-    CHECK_EQ(info.scope.rank, 1)
+    CHECK(info.scope.rank == runtime::StorageRank::kShared)
         << "Can only allocate shared or local memory inside kernel";
     // Shared memory
     buf = builder_->Allocate(
......
@@ -210,7 +210,8 @@ void StorageAccessVisitor::Visit_(const Call* op) {
 StorageScope StorageAccessVisitor::GetScope(const Variable* buf) const {
   auto it = storage_scope_.find(buf);
-  StorageScope s; s.rank = 0;
+  StorageScope s;
+  s.rank = StorageRank::kGlobal;
   if (it == storage_scope_.end()) return s;
   return it->second;
 }
......
@@ -17,6 +17,7 @@ namespace tvm {
 namespace ir {
 using runtime::StorageScope;
+using runtime::StorageRank;
 /*!
  * \brief Base class of storage access analysis
  */
......
@@ -23,6 +23,7 @@ namespace tvm {
 namespace ir {
 using HalideIR::Internal::Region;
+using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
 using intrinsic::tvm_address_of;
@@ -141,7 +142,8 @@ class StorageFlattener : public IRMutator {
       const std::string& strkey = it->second;
       if (strkey.length() == 0) {
         if (curr_thread_scope_.size() != 0) {
-          skey.rank = curr_thread_scope_.back().rank + 1;
+          skey.rank = runtime::DefaultStorageRank(
+              curr_thread_scope_.back().rank);
         }
       } else {
         skey = StorageScope::make(strkey);
......
@@ -19,6 +19,7 @@
 namespace tvm {
 namespace ir {
+using runtime::StorageRank;
 using runtime::StorageScope;
 // Find a linear pattern of storage acess
@@ -794,7 +795,7 @@ class StoragePlanRewriter : public IRMutator {
     // disable reuse of small arrays, they will be lowered to registers in LLVM
     // This rules only apply if we are using non special memory
     if (scope.tag.length() == 0) {
-      if (scope.rank > 1 || op->type.is_handle()) {
+      if (scope.rank >= StorageRank::kWarp || op->type.is_handle()) {
        return NewAlloc(op, attach_scope, scope, const_nbits);
      }
      if (const_nbits > 0 && const_nbits <= 32) {
@@ -853,7 +854,8 @@ class StoragePlanRewriter : public IRMutator {
     // This rules only apply if we are using non special memory
     if (e->scope.tag.length() == 0) {
       // Disable sharing of local memory.
-      if (e->scope.rank > 1 || e->allocs[0]->type.is_handle()) return;
+      if (e->scope.rank >= StorageRank::kWarp ||
+          e->allocs[0]->type.is_handle()) return;
       // disable reuse of small arrays
       if (e->const_nbits > 0 && e->const_nbits <= 32) return;
     }
......
@@ -189,7 +189,7 @@ class ThreadSyncInserter : public IRMutator {
     if (syncs_.size() == 0) return stmt;
     if (syncs_.count(stmt.get())) {
       Stmt barrier;
-      if (sync_scope_.rank == 0) {
+      if (sync_scope_.rank == StorageRank::kGlobal) {
         barrier = MakeGlobalBarrier();
       } else {
         barrier = Evaluate::make(
@@ -206,15 +206,15 @@ class ThreadSyncInserter : public IRMutator {
     return stmt;
   }
   Expr Mutate_(const Load* op, const Expr& e) final {
-    if (sync_scope_.rank == 0 &&
-        GetScope(op->buffer_var.get()).rank == 0) {
+    if (sync_scope_.rank == StorageRank::kGlobal &&
+        GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
       ++rw_stats_[op->buffer_var].read_count;
     }
     return IRMutator::Mutate_(op, e);
   }
   Stmt Mutate_(const Store* op, const Stmt& s) final {
-    if (sync_scope_.rank == 0 &&
-        GetScope(op->buffer_var.get()).rank == 0) {
+    if (sync_scope_.rank == StorageRank::kGlobal &&
+        GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
       ++rw_stats_[op->buffer_var].write_count;
     }
     return IRMutator::Mutate_(op, s);
@@ -228,7 +228,7 @@ class ThreadSyncInserter : public IRMutator {
     thread_extents_.pop_back();
     std::swap(temp, in_thread_env_);
     // first thread scope.
-    if (!in_thread_env_ && sync_scope_.rank == 0) {
+    if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) {
       ret = InitGlobalBarrier(ret.as<AttrStmt>());
       num_blocks_ = Expr();
       is_lead_ = Expr();
@@ -253,7 +253,8 @@ class ThreadSyncInserter : public IRMutator {
   // Get current storage scope.
   StorageScope GetScope(const Variable* buf) const {
     auto it = storage_scope_.find(buf);
-    StorageScope s; s.rank = 0;
+    StorageScope s;
+    s.rank = StorageRank::kGlobal;
     if (it == storage_scope_.end()) return s;
     return it->second;
   }
@@ -279,7 +280,7 @@ class ThreadSyncInserter : public IRMutator {
     return Block::make(prep, body);
   }
   Stmt MakeGlobalBarrier() {
-    CHECK_EQ(sync_scope_.rank, 0);
+    CHECK(sync_scope_.rank == StorageRank::kGlobal);
     if (!num_blocks_.defined()) {
       CHECK(!is_lead_.defined());
       num_work_dim_ = thread_extents_.size();
......
@@ -13,10 +13,47 @@
 namespace tvm {
 namespace runtime {
+/*!
+ * \brief Memory hierachy rank in the storage system
+ * \note The global rank and shared rank have one to one
+ *       correspondence to the thread rank.
+ */
+enum class StorageRank {
+  /*! \brief global memory */
+  kGlobal = 0,
+  /*! \brief shared memory among thread group */
+  kShared = 1,
+  /*!
+   * \brief reserved for warp memory.
+   *  This is only used by programming model.
+   *  There is no such memory usually in GPU.
+   *  Instead, we can simulate it by registers and shuffle.
+   */
+  kWarp = 2,
+  /*! \brief thread local memory */
+  kLocal = 3
+};
+/*!
+ * \param thread_scope_rank The thread scope rank
+ * \return default storage rank given the thread scope
+ */
+inline StorageRank DefaultStorageRank(int thread_scope_rank) {
+  switch (thread_scope_rank) {
+    case -1: return StorageRank::kGlobal;
+    case 0: return StorageRank::kShared;
+    case 1: return StorageRank::kLocal;
+    default: {
+      LOG(FATAL) << "unknown rank";
+      return StorageRank::kGlobal;
+    }
+  }
+}
 /*! \brief class to represent storage scope */
 struct StorageScope {
   /*! \brief The rank of the storage */
-  int rank{0};
+  StorageRank rank{StorageRank::kGlobal};
   /*! \brief tag for special purpose memory. */
   std::string tag;
   // comparator
@@ -29,9 +66,10 @@ struct StorageScope {
   inline std::string to_string() const {
     std::string ret;
     switch (rank) {
-      case 0: return "global" + tag;
-      case 1: return "shared" + tag;
-      case 2: return "local" + tag;
+      case StorageRank::kGlobal: return "global" + tag;
+      case StorageRank::kShared: return "shared" + tag;
+      case StorageRank::kWarp: return "warp" + tag;
+      case StorageRank::kLocal: return "local" + tag;
       default: LOG(FATAL) << "unknown storage scope"; return "";
     }
   }
@@ -43,13 +81,16 @@ struct StorageScope {
   static StorageScope make(const std::string& s) {
     StorageScope r;
     if (s.compare(0, 6, "global") == 0) {
-      r.rank = 0;
+      r.rank = StorageRank::kGlobal;
       r.tag = s.substr(6, std::string::npos);
     } else if (s.compare(0, 6, "shared") == 0) {
-      r.rank = 1;
+      r.rank = StorageRank::kShared;
       r.tag = s.substr(6, std::string::npos);
+    } else if (s.compare(0, 4, "warp") == 0) {
+      r.rank = StorageRank::kWarp;
+      r.tag = s.substr(4, std::string::npos);
     } else if (s.compare(0, 5, "local") == 0) {
-      r.rank = 2;
+      r.rank = StorageRank::kLocal;
       r.tag = s.substr(5, std::string::npos);
     } else {
       LOG(FATAL) << "unknown storage scope " << s;
......
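Side note (not part of the commit): a minimal standalone sketch of the ordering the new enum introduces. The enum body is copied from the hunk above; the main() harness is hypothetical and only illustrates why checks such as scope.rank >= StorageRank::kWarp in the storage rewriter now cover both warp and thread-local storage, since scoped enums compare by their underlying values (kGlobal < kShared < kWarp < kLocal).

// Illustration only: enum copied from the diff above, demo harness is hypothetical.
#include <iostream>

enum class StorageRank {
  kGlobal = 0,  // global memory
  kShared = 1,  // shared memory among thread group
  kWarp = 2,    // reserved warp memory (simulated via registers and shuffles)
  kLocal = 3    // thread local memory
};

int main() {
  // Scoped enums compare by underlying value, so the rank ordering is preserved.
  for (StorageRank r : {StorageRank::kGlobal, StorageRank::kShared,
                        StorageRank::kWarp, StorageRank::kLocal}) {
    bool reuse_disabled = (r >= StorageRank::kWarp);  // check used by StoragePlanRewriter
    std::cout << "rank " << static_cast<int>(r)
              << " reuse disabled: " << std::boolalpha << reuse_disabled << "\n";
  }
  return 0;
}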
@@ -16,8 +16,9 @@
 namespace tvm {
 namespace schedule {
-using runtime::ThreadScope;
+using runtime::StorageRank;
 using runtime::StorageScope;
+using runtime::ThreadScope;
 /*! \brief The graph context used during bound inference. */
 struct GraphContext {
@@ -41,7 +42,7 @@ bool NeedRelax(const IterVar& iv,
   if (tag.length() == 0 || tag == "pipeline") {
     return !found_attach;
   }
-  return scope.rank <= ThreadScope::make(tag).rank;
+  return static_cast<int>(scope.rank) <= ThreadScope::make(tag).rank;
 }
 // infer storage scope, if not given
@@ -50,16 +51,17 @@ StorageScope InferStorageScope(
   if (stage->scope.length() != 0) {
     return StorageScope::make(stage->scope);
   }
-  int max_rank = 0;
+  int max_rank = -1;
   for (IterVar iv : ctx.attach_path.at(stage->op)) {
     auto it = ctx.bind_map.find(iv);
     const std::string& tag = (
         it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag);
     if (tag != "pipeline" && tag.length() != 0) {
-      max_rank = std::max(max_rank, ThreadScope::make(tag).rank + 1);
+      max_rank = std::max(max_rank, ThreadScope::make(tag).rank);
     }
   }
-  StorageScope s; s.rank = max_rank;
+  StorageScope s;
+  s.rank = runtime::DefaultStorageRank(max_rank);
   return s;
 }
......
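Side note (not part of the commit): a small standalone sketch of the thread-rank to storage-rank mapping that InferStorageScope and StorageFlattener now route through runtime::DefaultStorageRank. The function body is restated from the runtime header changed above; the interpretation in the comments (thread rank 0 meaning a block-level axis, 1 a thread-level axis) follows the usual TVM ThreadScope convention and is an assumption on my part, not something shown in this diff.

// Illustration only: DefaultStorageRank restated from the diff, harness is hypothetical.
#include <iostream>
#include <stdexcept>

enum class StorageRank { kGlobal = 0, kShared = 1, kWarp = 2, kLocal = 3 };

// Map the deepest thread rank found on the attach path (-1 if none) to the
// default storage rank of a tensor computed at that point.
inline StorageRank DefaultStorageRank(int thread_scope_rank) {
  switch (thread_scope_rank) {
    case -1: return StorageRank::kGlobal;  // no thread axis on the attach path
    case 0:  return StorageRank::kShared;  // attached under a block-level axis (assumption)
    case 1:  return StorageRank::kLocal;   // attached under a thread-level axis (assumption)
    default: throw std::invalid_argument("unknown thread scope rank");
  }
}

int main() {
  for (int max_rank : {-1, 0, 1}) {
    std::cout << "thread rank " << max_rank << " -> storage rank "
              << static_cast<int>(DefaultStorageRank(max_rank)) << "\n";
  }
  return 0;
}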