Commit 29338ea4 by Tianqi Chen Committed by GitHub

[PASS] Allow allocation in parallel scope (#305)

parent 11328f64
......@@ -53,8 +53,6 @@ class LinearAccessPatternFinder final : public IRVisitor {
return std::move(linear_seq_);
void Visit_(const Allocate* op) final {
<< "Allocation inside parallel is not yet handled.";
size_t level = scope_.size();
const Variable* buf = op->buffer_var.get();
......@@ -140,6 +138,9 @@ class LinearAccessPatternFinder final : public IRVisitor {
in_thread_env_ = true;
in_thread_env_ = false;
} else if (op->attr_key == attr::pragma_scope &&
op-><StringImm>()->value == "parallel_launch_point") {
} else if (op->attr_key == attr::storage_scope) {
const Variable* buf = op-><Variable>();
storage_scope_[buf] =
......@@ -149,20 +150,14 @@ class LinearAccessPatternFinder final : public IRVisitor {
void Visit_(const For* op) final {
if (op->for_type == ForType::Parallel) {
bool in_par = in_parallel_env_;
in_parallel_env_ = true;
in_parallel_env_ = in_par;
} else {
void Visit_(const IfThenElse* op) final {
void Visit_(const For* op) final {
// Get storage scope of buffer.
StorageScope GetScope(const Variable* buf) const {
......@@ -172,8 +167,6 @@ class LinearAccessPatternFinder final : public IRVisitor {
// Whether already in thread env.
bool in_thread_env_{false};
// Whether already in parallel env.
bool in_parallel_env_{false};
// linearized access sequence.
std::vector<StmtEntry> linear_seq_;
// The scope stack.
......@@ -267,27 +260,22 @@ class StoragePlanRewriter : public IRMutator {
return IRMutator::Mutate_(op, e);
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
CHECK(op->attr_key != attr::virtual_thread)
<< "InjectVirtualThread before StoragePlan";
if (op->attr_key == attr::storage_scope) {
return this->Mutate(op->body);
} else if (op->attr_key == attr::thread_extent) {
// remake all the allocation at the thread extent.
} else if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::pragma_scope) {
// remake all the allocation at the attach scope.
if (attach_map_.count(op)) {
std::vector<Stmt> nest;
for (StorageEntry* e : {
e->alloc_var, attr::storage_scope,
auto& svec = attach_map_[op];
Stmt stmt = IRMutator::Mutate_(op, s);
op =<AttrStmt>();
Stmt body = MergeNest(nest, op->body);
return AttrStmt::make(
op->node, op->attr_key, op->value, body);
op->node, op->attr_key, op->value,
MakeAttach(svec, op->body));
} else {
return IRMutator::Mutate_(op, s);
......@@ -305,8 +293,19 @@ class StoragePlanRewriter : public IRMutator {
Stmt Mutate_(const For* op, const Stmt& s) final {
CHECK(op->for_type != ForType::Vectorized)
<< "VectorizeLoop before LiftStorageAlloc";
return IRMutator::Mutate_(op, s);
// remake all the allocation at the attach scope.
if (attach_map_.count(op)) {
auto& svec = attach_map_[op];
Stmt stmt = IRMutator::Mutate_(op, s);
op =<For>();
return For::make(
op->loop_var, op->min, op->extent, op->for_type, op->device_api,
MakeAttach(svec, op->body));
} else {
return IRMutator::Mutate_(op, s);
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
return this->Mutate(op->body);
......@@ -336,6 +335,18 @@ class StoragePlanRewriter : public IRMutator {
// the address becomes alloc_var + sizeof(elem_type) * elem_offset;
uint64_t elem_offset{0};
Stmt MakeAttach(const std::vector<StorageEntry*>& svec,
Stmt body) {
std::vector<Stmt> nest;
for (StorageEntry* e : svec) {
e->alloc_var, attr::storage_scope,
return MergeNest(nest, body);
// Remap the index
Expr RemapIndex(Type dtype, Expr index, StorageEntry* e) {
CHECK_EQ(dtype.element_of(), e->elem_type);
......@@ -461,31 +472,49 @@ class StoragePlanRewriter : public IRMutator {
void PlanNewScope(const Node* op) {
if (thread_scope_ != nullptr) {
CHECK(thread_scope_ == op);
// erase all memory attached to this scope.
for (auto it = const_free_map_.begin(); it != const_free_map_.end();) {
if (it->second->attach_scope_ == op) {
it = const_free_map_.erase(it);
} else {
for (auto it = sym_free_list_.begin(); it != sym_free_list_.end();) {
if ((*it)->attach_scope_ == op) {
it = sym_free_list_.erase(it);
} else {
thread_scope_ = nullptr;
} else {
thread_scope_ = op;
// Memory plan algorithm
void PlanMemory(const std::vector<StmtEntry>& seq) {
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
if (s.stmt->is_type<AttrStmt>()) {
const auto* op = static_cast<const AttrStmt*>(s.stmt);
CHECK_EQ(op->attr_key, attr::thread_extent);
if (thread_scope_ != nullptr) {
CHECK(thread_scope_ == op);
// erase all non-global memory from constant free map.
for (auto it = const_free_map_.begin();
it != const_free_map_.end();) {
if (it->second->scope.rank != 0) {
it = const_free_map_.erase(it);
} else {
CHECK(op->attr_key == attr::thread_extent ||
op->attr_key == attr::pragma_scope);
} else if (s.stmt->is_type<For>()) {
const auto* op = static_cast<const For*>(s.stmt);
if (op->for_type == ForType::Parallel) {
if (thread_scope_ == nullptr || thread_scope_ == op) {
thread_scope_ = nullptr;
} else {
thread_scope_ = op;
} else if (s.stmt->is_type<Allocate>()) {
const auto* op = static_cast<const Allocate*>(s.stmt);
StorageEntry* e = this->FindAlloc(op, s.alloc_scope);
StorageEntry* e = this->FindAlloc(op, thread_scope_, s.alloc_scope);
alloc_map_[op->buffer_var.get()] = e;
......@@ -499,11 +528,12 @@ class StoragePlanRewriter : public IRMutator {
// Allocate new storage entry.
StorageEntry* NewAlloc(const Allocate* op,
const Node* attach_scope,
const StorageScope& scope,
size_t const_nbits) {
// Re-use not successful, allocate a new buffer.
std::unique_ptr<StorageEntry> entry(new StorageEntry());
entry->attach_scope_ = thread_scope_;
entry->attach_scope_ = attach_scope;
entry->scope = scope;
entry->elem_type = op->type.element_of();
entry->const_nbits = const_nbits;
......@@ -512,6 +542,7 @@ class StoragePlanRewriter : public IRMutator {
return e;
StorageEntry* FindAlloc(const Allocate* op,
const Node* attach_scope,
const StorageScope& scope) {
// skip plan for local variable,
// compiler can do a better job with register allocation.
......@@ -519,13 +550,13 @@ class StoragePlanRewriter : public IRMutator {
uint64_t const_nbits = static_cast<uint64_t>(
op->constant_allocation_size() * op->type.bits() * op->type.lanes());
if (scope.rank > 1 || op->type.is_handle()) {
return NewAlloc(op, scope, const_nbits);
return NewAlloc(op, attach_scope, scope, const_nbits);
// disable reuse of small arrays, they will be lowered to registers in LLVM
if (const_nbits > 0 &&
const_nbits <= 32 &&
scope.tag.length() == 0) {
return NewAlloc(op, scope, const_nbits);
return NewAlloc(op, attach_scope, scope, const_nbits);
if (const_nbits != 0) {
// constant allocation.
......@@ -534,6 +565,7 @@ class StoragePlanRewriter : public IRMutator {
auto end = const_free_map_.upper_bound(const_nbits * match_range);
for (auto it = mid; it != end; ++it) {
StorageEntry *e = it->second;
if (e->attach_scope_ != attach_scope) continue;
if (e->scope != scope) continue;
if (e->elem_type != op->type.element_of()) continue;
e->const_nbits = std::max(const_nbits, e->const_nbits);
......@@ -543,6 +575,7 @@ class StoragePlanRewriter : public IRMutator {
for (auto it = mid; it != begin;) {
StorageEntry *e = it->second;
if (e->attach_scope_ != attach_scope) continue;
if (e->scope != scope) continue;
if (e->elem_type != op->type.element_of()) continue;
......@@ -553,13 +586,14 @@ class StoragePlanRewriter : public IRMutator {
for (auto it = sym_free_list_.begin();
it != sym_free_list_.end(); ++it) {
StorageEntry* e = *it;
if (e->attach_scope_ != attach_scope) continue;
if (e->scope != scope) continue;
if (e->elem_type != op->type.element_of()) continue;
return e;
return NewAlloc(op, scope, const_nbits);
return NewAlloc(op, attach_scope, scope, const_nbits);
// simulated free.
void Free(const Variable* var) {
......@@ -96,9 +96,11 @@ def test_llvm_vadd_pipeline():
B = tvm.compute((n,), lambda i: A[i], name='B')
C = tvm.compute((n,), lambda i: B[i] + tvm.const(1, A.dtype), name='C')
s = tvm.create_schedule(C.op)
xo, xi = s[C].split(C.op.axis[0], factor=2)
xo, xi = s[C].split(C.op.axis[0], nparts=2)
_, xi = s[C].split(xi, factor=2)
s[B].compute_at(s[C], xo)
xo, xi = s[B].split(B.op.axis[0], factor=2)
# build and invoke the kernel.
......@@ -112,6 +114,7 @@ def test_llvm_vadd_pipeline():
c.asnumpy(), a.asnumpy() + 1)
check_llvm(64, 2)
check_llvm(512, 2)
def test_llvm_madd_pipeline():
......@@ -98,8 +98,35 @@ def test_storage_share_gpu():
assert alloc_stats["global"] == 2
assert alloc_stats["shared"] == num_stage
def test_parallel_alloc():
    """Check where StorageRewrite places an allocation made inside a parallel loop.

    Two IR fragments are built with the IR builder and run through
    ``tvm.ir_pass.StorageRewrite``:

    1. A parallel loop as the root statement; the test expects an
       ``Allocate`` node two ``.body`` levels below the root.
    2. A serial loop whose body is tagged with the ``pragma_scope`` attribute
       ``"parallel_launch_point"`` and then contains the parallel loop; the
       test expects an ``Allocate`` node three ``.body`` levels below the root.
    """
    # Case 1: allocation nested inside a parallel loop with no launch point.
    ib = tvm.ir_builder.create()
    n = tvm.var("n")
    with ib.for_range(0, n, name="i", for_type="parallel") as i:
        with ib.for_range(0, 10, name="j") as j:
            A = ib.allocate("float32", n, name="A", scope="global")
            A[j] = A[j] + 2

    body = ib.get()
    body = tvm.ir_pass.StorageRewrite(body)
    assert isinstance(body.body.body, tvm.stmt.Allocate)

    # Case 2: the parallel loop sits under an explicit parallel launch point.
    ib = tvm.ir_builder.create()
    n = tvm.var("n")
    with ib.for_range(0, n, name="t") as i:
        # Everything emitted after this call in the current scope becomes the
        # body of the pragma_scope attribute statement.
        ib.scope_attr(
            tvm.const(1), "pragma_scope",
            tvm.make.StringImm("parallel_launch_point"))
        with ib.for_range(0, n, name="i", for_type="parallel") as i:
            with ib.for_range(0, 10, name="j") as j:
                A = ib.allocate("float32", n, name="A", scope="global")
                A[j] = A[j] + 2

    body = ib.get()
    body = tvm.ir_pass.StorageRewrite(body)
    assert isinstance(body.body.body.body, tvm.stmt.Allocate)
if __name__ == "__main__":
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment