Commit e8afa1b4 by xqdan Committed by Tianqi Chen

[PASS] Support buffer reuse for different types (#891)

[PASS] Support buffer reuse for different types
parent 61cdf903
......@@ -502,7 +502,6 @@ class StoragePlanRewriter : public IRMutator {
}
// Remap the index
Expr RemapIndex(Type dtype, Expr index, StorageEntry* e) {
CHECK_EQ(dtype.element_of(), e->elem_type);
if (e->bits_offset == 0) return index;
uint64_t elem_bits = dtype.bits() * dtype.lanes();
CHECK_EQ(e->bits_offset % elem_bits, 0U);
......@@ -564,17 +563,22 @@ class StoragePlanRewriter : public IRMutator {
Expr combo_size;
for (const Allocate* op : e->allocs) {
Expr sz = arith::ComputeReduce<Mul>(op->extents, make_const(Int(32), 1));
if (alloc_type.lanes() != op->type.lanes()) {
sz = (sz * make_const(sz.type(), op->type.lanes()) +
make_const(sz.type(), alloc_type.lanes() - 1)) /
make_const(sz.type(), alloc_type.lanes());
}
// transform to bits
auto sz_nbits = sz * (op->type.bits() * op->type.lanes());
if (combo_size.defined()) {
combo_size = max(combo_size, sz);
combo_size = max(combo_size, sz_nbits);
} else {
combo_size = sz;
combo_size = sz_nbits;
}
}
// transform to alloc bytes
auto type_bits = alloc_type.bits() * alloc_type.lanes();
bool divided = can_prove(combo_size % type_bits == 0);
combo_size = combo_size / type_bits;
// round up for can not divided
if (!divided) {
combo_size += make_const(Int(32), 1);
}
combo_size = ir::Simplify(combo_size);
e->new_alloc = Allocate::make(
e->alloc_var, alloc_type, {combo_size}, const_true(),
......@@ -784,8 +788,9 @@ class StoragePlanRewriter : public IRMutator {
// skip plan for local variable,
// compiler can do a better job with register allocation.
const uint64_t match_range = 16;
uint64_t op_elem_bits = op->type.bits() * op->type.lanes();
uint64_t const_nbits = static_cast<uint64_t>(
op->constant_allocation_size() * op->type.bits() * op->type.lanes());
op->constant_allocation_size() * op_elem_bits);
// disable reuse of small arrays, they will be lowered to registers in LLVM
// This rules only apply if we are using non special memory
if (scope.tag.length() == 0) {
......@@ -801,15 +806,18 @@ class StoragePlanRewriter : public IRMutator {
auto begin = const_free_map_.lower_bound(const_nbits / match_range);
auto mid = const_free_map_.lower_bound(const_nbits);
auto end = const_free_map_.upper_bound(const_nbits * match_range);
// start looking at the buffer that is bigger than the required size first
for (auto it = mid; it != end; ++it) {
StorageEntry *e = it->second;
if (e->attach_scope_ != attach_scope) continue;
if (e->scope != scope) continue;
if (e->elem_type != op->type.element_of()) continue;
// when not divided, no reuse, eg, float4 vs float3
if (e->bits_offset % op_elem_bits != 0) continue;
e->const_nbits = std::max(const_nbits, e->const_nbits);
const_free_map_.erase(it);
return e;
}
// then start looking at smaller buffers.
for (auto it = mid; it != begin;) {
--it;
StorageEntry *e = it->second;
......
......@@ -54,10 +54,27 @@ def test_alloc_different_dtypes():
ib = tvm.ir_builder.create()
base_dtype = dtype_list[0]
global_a = tvm.placeholder((length,), name = "global_a", dtype = base_dtype)
for index, dtype in enumerate(dtype_list):
with ib.for_range(0, length, name="j") as j:
A = ib.allocate(dtype, length, name="A_" + str(index), scope="local.L0A")
A[j] = tvm.const(1, dtype = dtype)
assert len(dtype_list) == 4
with ib.for_range(0, length, name="j") as j:
dtype = dtype_list[0]
A = ib.allocate(dtype, length, name="A", scope="local.L0A")
A[j] = tvm.const(1, dtype = dtype)
with ib.for_range(0, length, name="j") as j:
dtype = dtype_list[1]
B = ib.allocate(dtype, length, name="B", scope="local.L0A")
B[j] = tvm.const(1, dtype = dtype)
with ib.for_range(0, length, name="j") as j:
dtype = dtype_list[2]
C = ib.allocate(dtype, length, name="C", scope="local.L0A")
C[j] = tvm.const(1, dtype = dtype)
with ib.for_range(0, length, name="j") as j:
dtype = dtype_list[3]
D = ib.allocate(dtype, length, name="D", scope="local.L0A")
D[j] = tvm.const(1, dtype = dtype)
with ib.for_range(0, length, name="j") as j:
dtype = "int8"
E = ib.allocate(dtype, length, name="E", scope="local.L0A")
E[j] = A[j].astype(dtype) + B[j].astype(dtype) + C[j].astype(dtype) + D[j].astype(dtype)
return ib.get()
def dtype_bit_len(dtype):
......@@ -342,6 +359,58 @@ def test_inplace_rule3():
assert n.extents[0].value == 70
tvm.ir_pass.PostOrderVisit(stmt, verify)
def test_alloc_seq_type():
ib = tvm.ir_builder.create()
n = tvm.var("n")
with ib.for_range(0, n, name="i") as i:
with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("float32", 200, name="A", scope="local.L0A")
A1 = ib.allocate("float32", 200, name="A1", scope="local.L0A")
A[j] = 1.2
A1[j] = 1.3
B = ib.allocate("int16", 200, name="B", scope="local.L0A")
B[j] = tvm.const(1, "int16")
C = ib.allocate("int16", 200, name="C", scope="local.L0A")
C[j] = tvm.const(1, "int16")
D = ib.allocate("int16", 200, name="D", scope="local.L0A")
D[j] = B[j] + C[j]
A2 = ib.allocate("float32", 200, name="A2", scope="local.L0A")
A2[j] = A[j]
body = ib.get()
body = tvm.ir_pass.StorageRewrite(body)
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.stmt.Allocate):
num_alloc[0] += 1
assert n.extents[0].value == 500
tvm.ir_pass.PostOrderVisit(body, verify)
assert num_alloc[0] == 1
def test_alloc_seq_type2():
ib = tvm.ir_builder.create()
n = tvm.var("n")
with ib.for_range(0, n, name="i") as i:
with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("float32", 200, name="A", scope="local.L0A")
A[j] = 1.2
with ib.for_range(0, 20, name="j") as j:
B = ib.allocate("int16", 400, name="B", scope="local.L0A")
B[j] = tvm.const(1, "int16")
with ib.for_range(0, 10, name="j") as j:
C = ib.allocate("float32", 200, name="C", scope="local.L0A")
C[j] = 1.2
body = ib.get()
body = tvm.ir_pass.StorageRewrite(body)
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.stmt.Allocate):
num_alloc[0] += 1
assert n.extents[0].value == 200
tvm.ir_pass.PostOrderVisit(body, verify)
assert num_alloc[0] == 1
if __name__ == "__main__":
test_alloc_seq()
test_alloc_different_dtypes()
......@@ -352,3 +421,6 @@ if __name__ == "__main__":
test_storage_share_gpu()
test_inplace_rule2()
test_inplace_rule3()
test_alloc_seq_type()
test_alloc_seq_type2()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment