Commit e8afa1b4 by xqdan Committed by Tianqi Chen

[PASS] Support buffer reuse for different types (#891)

[PASS] Support buffer reuse for different types
parent 61cdf903
...@@ -502,7 +502,6 @@ class StoragePlanRewriter : public IRMutator { ...@@ -502,7 +502,6 @@ class StoragePlanRewriter : public IRMutator {
} }
// Remap the index // Remap the index
Expr RemapIndex(Type dtype, Expr index, StorageEntry* e) { Expr RemapIndex(Type dtype, Expr index, StorageEntry* e) {
CHECK_EQ(dtype.element_of(), e->elem_type);
if (e->bits_offset == 0) return index; if (e->bits_offset == 0) return index;
uint64_t elem_bits = dtype.bits() * dtype.lanes(); uint64_t elem_bits = dtype.bits() * dtype.lanes();
CHECK_EQ(e->bits_offset % elem_bits, 0U); CHECK_EQ(e->bits_offset % elem_bits, 0U);
...@@ -564,16 +563,21 @@ class StoragePlanRewriter : public IRMutator { ...@@ -564,16 +563,21 @@ class StoragePlanRewriter : public IRMutator {
Expr combo_size; Expr combo_size;
for (const Allocate* op : e->allocs) { for (const Allocate* op : e->allocs) {
Expr sz = arith::ComputeReduce<Mul>(op->extents, make_const(Int(32), 1)); Expr sz = arith::ComputeReduce<Mul>(op->extents, make_const(Int(32), 1));
if (alloc_type.lanes() != op->type.lanes()) { // transform to bits
sz = (sz * make_const(sz.type(), op->type.lanes()) + auto sz_nbits = sz * (op->type.bits() * op->type.lanes());
make_const(sz.type(), alloc_type.lanes() - 1)) /
make_const(sz.type(), alloc_type.lanes());
}
if (combo_size.defined()) { if (combo_size.defined()) {
combo_size = max(combo_size, sz); combo_size = max(combo_size, sz_nbits);
} else { } else {
combo_size = sz; combo_size = sz_nbits;
}
} }
// transform to alloc bytes
auto type_bits = alloc_type.bits() * alloc_type.lanes();
bool divided = can_prove(combo_size % type_bits == 0);
combo_size = combo_size / type_bits;
// round up when the bit size is not an exact multiple of the element size
if (!divided) {
combo_size += make_const(Int(32), 1);
} }
combo_size = ir::Simplify(combo_size); combo_size = ir::Simplify(combo_size);
e->new_alloc = Allocate::make( e->new_alloc = Allocate::make(
...@@ -784,8 +788,9 @@ class StoragePlanRewriter : public IRMutator { ...@@ -784,8 +788,9 @@ class StoragePlanRewriter : public IRMutator {
// skip plan for local variable, // skip plan for local variable,
// compiler can do a better job with register allocation. // compiler can do a better job with register allocation.
const uint64_t match_range = 16; const uint64_t match_range = 16;
uint64_t op_elem_bits = op->type.bits() * op->type.lanes();
uint64_t const_nbits = static_cast<uint64_t>( uint64_t const_nbits = static_cast<uint64_t>(
op->constant_allocation_size() * op->type.bits() * op->type.lanes()); op->constant_allocation_size() * op_elem_bits);
// disable reuse of small arrays, they will be lowered to registers in LLVM // disable reuse of small arrays, they will be lowered to registers in LLVM
// This rules only apply if we are using non special memory // This rules only apply if we are using non special memory
if (scope.tag.length() == 0) { if (scope.tag.length() == 0) {
...@@ -801,15 +806,18 @@ class StoragePlanRewriter : public IRMutator { ...@@ -801,15 +806,18 @@ class StoragePlanRewriter : public IRMutator {
auto begin = const_free_map_.lower_bound(const_nbits / match_range); auto begin = const_free_map_.lower_bound(const_nbits / match_range);
auto mid = const_free_map_.lower_bound(const_nbits); auto mid = const_free_map_.lower_bound(const_nbits);
auto end = const_free_map_.upper_bound(const_nbits * match_range); auto end = const_free_map_.upper_bound(const_nbits * match_range);
// start looking at the buffer that is bigger than the required size first
for (auto it = mid; it != end; ++it) { for (auto it = mid; it != end; ++it) {
StorageEntry *e = it->second; StorageEntry *e = it->second;
if (e->attach_scope_ != attach_scope) continue; if (e->attach_scope_ != attach_scope) continue;
if (e->scope != scope) continue; if (e->scope != scope) continue;
if (e->elem_type != op->type.element_of()) continue; // when not divided, no reuse, eg, float4 vs float3
if (e->bits_offset % op_elem_bits != 0) continue;
e->const_nbits = std::max(const_nbits, e->const_nbits); e->const_nbits = std::max(const_nbits, e->const_nbits);
const_free_map_.erase(it); const_free_map_.erase(it);
return e; return e;
} }
// then start looking at smaller buffers.
for (auto it = mid; it != begin;) { for (auto it = mid; it != begin;) {
--it; --it;
StorageEntry *e = it->second; StorageEntry *e = it->second;
......
...@@ -54,10 +54,27 @@ def test_alloc_different_dtypes(): ...@@ -54,10 +54,27 @@ def test_alloc_different_dtypes():
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
base_dtype = dtype_list[0] base_dtype = dtype_list[0]
global_a = tvm.placeholder((length,), name = "global_a", dtype = base_dtype) global_a = tvm.placeholder((length,), name = "global_a", dtype = base_dtype)
for index, dtype in enumerate(dtype_list): assert len(dtype_list) == 4
with ib.for_range(0, length, name="j") as j: with ib.for_range(0, length, name="j") as j:
A = ib.allocate(dtype, length, name="A_" + str(index), scope="local.L0A") dtype = dtype_list[0]
A = ib.allocate(dtype, length, name="A", scope="local.L0A")
A[j] = tvm.const(1, dtype = dtype) A[j] = tvm.const(1, dtype = dtype)
with ib.for_range(0, length, name="j") as j:
dtype = dtype_list[1]
B = ib.allocate(dtype, length, name="B", scope="local.L0A")
B[j] = tvm.const(1, dtype = dtype)
with ib.for_range(0, length, name="j") as j:
dtype = dtype_list[2]
C = ib.allocate(dtype, length, name="C", scope="local.L0A")
C[j] = tvm.const(1, dtype = dtype)
with ib.for_range(0, length, name="j") as j:
dtype = dtype_list[3]
D = ib.allocate(dtype, length, name="D", scope="local.L0A")
D[j] = tvm.const(1, dtype = dtype)
with ib.for_range(0, length, name="j") as j:
dtype = "int8"
E = ib.allocate(dtype, length, name="E", scope="local.L0A")
E[j] = A[j].astype(dtype) + B[j].astype(dtype) + C[j].astype(dtype) + D[j].astype(dtype)
return ib.get() return ib.get()
def dtype_bit_len(dtype): def dtype_bit_len(dtype):
...@@ -342,6 +359,58 @@ def test_inplace_rule3(): ...@@ -342,6 +359,58 @@ def test_inplace_rule3():
assert n.extents[0].value == 70 assert n.extents[0].value == 70
tvm.ir_pass.PostOrderVisit(stmt, verify) tvm.ir_pass.PostOrderVisit(stmt, verify)
def test_alloc_seq_type():
    # Regression test for StorageRewrite buffer reuse across dtypes:
    # allocates a mix of float32 and int16 buffers (all 200 elements, all in
    # scope "local.L0A") inside the same loop nest, runs the StorageRewrite
    # pass, and checks that everything is merged into exactly one allocation
    # whose extent is 500.
    # NOTE(review): indentation was reconstructed from a whitespace-mangled
    # diff; the nesting below assumes all allocations live in the j-loop.
    ib = tvm.ir_builder.create()
    n = tvm.var("n")
    with ib.for_range(0, n, name="i") as i:
        with ib.for_range(0, 10, name="j") as j:
            # Two float32 buffers live at the same time.
            A = ib.allocate("float32", 200, name="A", scope="local.L0A")
            A1 = ib.allocate("float32", 200, name="A1", scope="local.L0A")
            A[j] = 1.2
            A1[j] = 1.3
            # Three int16 buffers; D reads B and C, so they overlap.
            B = ib.allocate("int16", 200, name="B", scope="local.L0A")
            B[j] = tvm.const(1, "int16")
            C = ib.allocate("int16", 200, name="C", scope="local.L0A")
            C[j] = tvm.const(1, "int16")
            D = ib.allocate("int16", 200, name="D", scope="local.L0A")
            D[j] = B[j] + C[j]
            # A2 reads A, extending A's liveness past the int16 buffers.
            A2 = ib.allocate("float32", 200, name="A2", scope="local.L0A")
            A2[j] = A[j]
    body = ib.get()
    body = tvm.ir_pass.StorageRewrite(body)
    num_alloc = [0]
    def verify(n):
        # Count Allocate nodes in the rewritten IR; the single merged
        # allocation is expected to have an extent of 500 elements.
        if isinstance(n, tvm.stmt.Allocate):
            num_alloc[0] += 1
            assert n.extents[0].value == 500
    tvm.ir_pass.PostOrderVisit(body, verify)
    # The pass must have fused every buffer into one allocation.
    assert num_alloc[0] == 1
def test_alloc_seq_type2():
    # Companion test for dtype-aware buffer reuse in StorageRewrite:
    # three non-overlapping buffers of equal bit size (float32 x 200,
    # int16 x 400, float32 x 200) allocated in consecutive j-loops should be
    # folded into a single allocation with extent 200.
    # NOTE(review): indentation was reconstructed from a whitespace-mangled
    # diff; the three j-loops are assumed to nest inside the i-loop.
    ib = tvm.ir_builder.create()
    n = tvm.var("n")
    with ib.for_range(0, n, name="i") as i:
        with ib.for_range(0, 10, name="j") as j:
            A = ib.allocate("float32", 200, name="A", scope="local.L0A")
            A[j] = 1.2
        with ib.for_range(0, 20, name="j") as j:
            # Same total bit size as A: 400 * 16 bits == 200 * 32 bits.
            B = ib.allocate("int16", 400, name="B", scope="local.L0A")
            B[j] = tvm.const(1, "int16")
        with ib.for_range(0, 10, name="j") as j:
            C = ib.allocate("float32", 200, name="C", scope="local.L0A")
            C[j] = 1.2
    body = ib.get()
    body = tvm.ir_pass.StorageRewrite(body)
    num_alloc = [0]
    def verify(n):
        # The rewritten IR must contain one Allocate of 200 elements.
        if isinstance(n, tvm.stmt.Allocate):
            num_alloc[0] += 1
            assert n.extents[0].value == 200
    tvm.ir_pass.PostOrderVisit(body, verify)
    assert num_alloc[0] == 1
if __name__ == "__main__": if __name__ == "__main__":
test_alloc_seq() test_alloc_seq()
test_alloc_different_dtypes() test_alloc_different_dtypes()
...@@ -352,3 +421,6 @@ if __name__ == "__main__": ...@@ -352,3 +421,6 @@ if __name__ == "__main__":
test_storage_share_gpu() test_storage_share_gpu()
test_inplace_rule2() test_inplace_rule2()
test_inplace_rule3() test_inplace_rule3()
test_alloc_seq_type()
test_alloc_seq_type2()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment