Commit 28bb0f68 by Tianqi Chen Committed by GitHub

[PASS] Prepare storage rewrite for unified buffer (#885)

* [PASS] Prepare storage rewrite for unified buffer

* more comments
parent 04fb5509
...@@ -363,7 +363,7 @@ class StoragePlanRewriter : public IRMutator { ...@@ -363,7 +363,7 @@ class StoragePlanRewriter : public IRMutator {
Expr Mutate_(const Variable* op, const Expr& e) final { Expr Mutate_(const Variable* op, const Expr& e) final {
auto it = alloc_map_.find(op); auto it = alloc_map_.find(op);
if (it != alloc_map_.end()) { if (it != alloc_map_.end()) {
if (it->second->elem_offset != 0) { if (it->second->bits_offset != 0) {
LOG(WARNING) << "Use a merged buffer variable address, could cause error"; LOG(WARNING) << "Use a merged buffer variable address, could cause error";
} }
return it->second->alloc_var; return it->second->alloc_var;
...@@ -381,11 +381,10 @@ class StoragePlanRewriter : public IRMutator { ...@@ -381,11 +381,10 @@ class StoragePlanRewriter : public IRMutator {
const StorageEntry* se = it->second; const StorageEntry* se = it->second;
Expr offset = Mutate(op->args[2]); Expr offset = Mutate(op->args[2]);
Expr extent = Mutate(op->args[3]); Expr extent = Mutate(op->args[3]);
CHECK_EQ(se->elem_type, dtype.element_of()) uint64_t elem_bits = dtype.bits() * dtype.lanes();
<< " buffer=" << buffer->name_hint; CHECK_EQ(se->bits_offset % elem_bits, 0U);
CHECK_EQ(se->elem_offset % dtype.lanes(), 0); if (se->bits_offset != 0) {
if (se->elem_offset != 0) { offset = make_const(offset.type(), se->bits_offset / elem_bits) + offset;
offset = make_const(offset.type(), se->elem_offset / dtype.lanes()) + offset;
} }
return Call::make( return Call::make(
op->type, op->name, op->type, op->name,
...@@ -465,8 +464,17 @@ class StoragePlanRewriter : public IRMutator { ...@@ -465,8 +464,17 @@ class StoragePlanRewriter : public IRMutator {
// The allocation element type. // The allocation element type.
Type elem_type; Type elem_type;
// This is non-zero if this allocate is folded into another one // This is non-zero if this allocate is folded into another one
// the address becomes alloc_var + sizeof(elem_type) * elem_offset; // the address(in bits) becomes alloc_var + bits_offset;
uint64_t elem_offset{0}; // can be effectively converted to the element type.
// We need to convert bit_offset to offset of specific element type later.
//
// We use bits(instead of bytes) to support non-conventional indexing in hardware.
// When we are merging buffer together, the bits_offset are set to be aligned
// to certain value given by the max_simd_bits property of the special memory.
//
// This allows effective sharing among different types as long as their alignment
// requirement fits into the max_simd_bits.
uint64_t bits_offset{0};
}; };
// Alllocate entry of node. // Alllocate entry of node.
...@@ -495,8 +503,10 @@ class StoragePlanRewriter : public IRMutator { ...@@ -495,8 +503,10 @@ class StoragePlanRewriter : public IRMutator {
// Remap the index // Remap the index
Expr RemapIndex(Type dtype, Expr index, StorageEntry* e) { Expr RemapIndex(Type dtype, Expr index, StorageEntry* e) {
CHECK_EQ(dtype.element_of(), e->elem_type); CHECK_EQ(dtype.element_of(), e->elem_type);
if (e->elem_offset == 0) return index; if (e->bits_offset == 0) return index;
return make_const(index.type(), e->elem_offset) + index; uint64_t elem_bits = dtype.bits() * dtype.lanes();
CHECK_EQ(e->bits_offset % elem_bits, 0U);
return make_const(index.type(), e->bits_offset / elem_bits) + index;
} }
// Prepare the new allocations // Prepare the new allocations
void PrepareNewAlloc() { void PrepareNewAlloc() {
...@@ -526,7 +536,7 @@ class StoragePlanRewriter : public IRMutator { ...@@ -526,7 +536,7 @@ class StoragePlanRewriter : public IRMutator {
for (size_t i = 0; i < vec.size(); ++i) { for (size_t i = 0; i < vec.size(); ++i) {
StorageEntry* e = vec[i]; StorageEntry* e = vec[i];
// already merged // already merged
if (e->elem_offset != 0) continue; if (e->bits_offset != 0) continue;
if (e->merged_children.size() != 0) { if (e->merged_children.size() != 0) {
NewAllocTagMerged(e); continue; NewAllocTagMerged(e); continue;
} }
...@@ -580,10 +590,13 @@ class StoragePlanRewriter : public IRMutator { ...@@ -580,10 +590,13 @@ class StoragePlanRewriter : public IRMutator {
CHECK_NE(e->const_nbits, 0U); CHECK_NE(e->const_nbits, 0U);
MemoryInfo info = GetMemoryInfo(e->scope.to_string()); MemoryInfo info = GetMemoryInfo(e->scope.to_string());
uint64_t total_bits = e->const_nbits; uint64_t total_bits = e->const_nbits;
size_t align = 1; // By default, align to 32 bits.
size_t align = 32;
if (info.defined()) { if (info.defined()) {
align = info->max_simd_bits; align = info->max_simd_bits;
} }
// Always align to max_simd_bits
// so we can remap types by keeping this property
if (total_bits % align != 0) { if (total_bits % align != 0) {
total_bits += align - (total_bits % align); total_bits += align - (total_bits % align);
} }
...@@ -591,7 +604,7 @@ class StoragePlanRewriter : public IRMutator { ...@@ -591,7 +604,7 @@ class StoragePlanRewriter : public IRMutator {
for (StorageEntry* child : e->merged_children) { for (StorageEntry* child : e->merged_children) {
CHECK_NE(child->const_nbits, 0U); CHECK_NE(child->const_nbits, 0U);
CHECK_NE(total_bits, 0U); CHECK_NE(total_bits, 0U);
child->elem_offset = total_bits / child->elem_type.bits(); child->bits_offset = total_bits;
child->alloc_var = e->alloc_var; child->alloc_var = e->alloc_var;
total_bits += child->const_nbits; total_bits += child->const_nbits;
if (total_bits % align != 0) { if (total_bits % align != 0) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment