Commit 1587038e by Tianqi Chen Committed by GitHub

[PASS] Fix reuse small buffer in storage rewrite (#1012)

parent 6b0950dd
...@@ -824,6 +824,7 @@ class StoragePlanRewriter : public IRMutator { ...@@ -824,6 +824,7 @@ class StoragePlanRewriter : public IRMutator {
if (e->attach_scope_ != attach_scope) continue; if (e->attach_scope_ != attach_scope) continue;
if (e->scope != scope) continue; if (e->scope != scope) continue;
if (e->elem_type != op->type.element_of()) continue; if (e->elem_type != op->type.element_of()) continue;
e->const_nbits = std::max(const_nbits, e->const_nbits);
const_free_map_.erase(it); const_free_map_.erase(it);
return e; return e;
} }
......
...@@ -76,7 +76,7 @@ def test_alloc_different_dtypes(): ...@@ -76,7 +76,7 @@ def test_alloc_different_dtypes():
E = ib.allocate(dtype, length, name="E", scope="local.L0A") E = ib.allocate(dtype, length, name="E", scope="local.L0A")
E[j] = A[j].astype(dtype) + B[j].astype(dtype) + C[j].astype(dtype) + D[j].astype(dtype) E[j] = A[j].astype(dtype) + B[j].astype(dtype) + C[j].astype(dtype) + D[j].astype(dtype)
return ib.get() return ib.get()
def dtype_bit_len(dtype): def dtype_bit_len(dtype):
index = 0 index = 0
for i in dtype: for i in dtype:
...@@ -109,7 +109,7 @@ def test_alloc_different_dtypes(): ...@@ -109,7 +109,7 @@ def test_alloc_different_dtypes():
dtype_list = ["float64", "int32", "uint16", "int8"] dtype_list = ["float64", "int32", "uint16", "int8"]
dtype_test(dtype_list, length) dtype_test(dtype_list, length)
dtype_list = ["int8", "int32", "uint16", "uint8"] dtype_list = ["int8", "int32", "uint16", "uint8"]
dtype_test(dtype_list, length) dtype_test(dtype_list, length)
...@@ -313,7 +313,7 @@ def test_inplace_rule3(): ...@@ -313,7 +313,7 @@ def test_inplace_rule3():
B2L = s.cache_read(B2, scope_tb, [B7, B9]) B2L = s.cache_read(B2, scope_tb, [B7, B9])
B4L = s.cache_read(B4, scope_tb, [B7, B12]) B4L = s.cache_read(B4, scope_tb, [B7, B12])
B3L = s.cache_read(B3, scope_tb, [B9, B13]) B3L = s.cache_read(B3, scope_tb, [B9, B13])
B0L = s.cache_read(B0, scope_tb, [B10, B12]) B0L = s.cache_read(B0, scope_tb, [B10, B12])
B8L = s.cache_write(B8, scope_tb) B8L = s.cache_write(B8, scope_tb)
B11L = s.cache_write(B11, scope_tb) B11L = s.cache_write(B11, scope_tb)
...@@ -324,7 +324,7 @@ def test_inplace_rule3(): ...@@ -324,7 +324,7 @@ def test_inplace_rule3():
B10L = s.cache_write(B10, scope_tb) B10L = s.cache_write(B10, scope_tb)
B12L = s.cache_write(B12, scope_tb) B12L = s.cache_write(B12, scope_tb)
B13L = s.cache_write(B13, scope_tb) B13L = s.cache_write(B13, scope_tb)
s[B12].compute_inline() s[B12].compute_inline()
s[B13].compute_inline() s[B13].compute_inline()
s[B8].compute_inline() s[B8].compute_inline()
...@@ -334,12 +334,12 @@ def test_inplace_rule3(): ...@@ -334,12 +334,12 @@ def test_inplace_rule3():
s[B7].compute_inline() s[B7].compute_inline()
s[B9].compute_inline() s[B9].compute_inline()
s[B10].compute_inline() s[B10].compute_inline()
s = s.normalize() s = s.normalize()
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map) assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
B0a = tvm.decl_buffer(B0.shape, B0.dtype, name='B0') B0a = tvm.decl_buffer(B0.shape, B0.dtype, name='B0')
B1a = tvm.decl_buffer(B1.shape, B1.dtype, name='B1') B1a = tvm.decl_buffer(B1.shape, B1.dtype, name='B1')
B2a = tvm.decl_buffer(B2.shape, B2.dtype, name='B2') B2a = tvm.decl_buffer(B2.shape, B2.dtype, name='B2')
...@@ -411,6 +411,38 @@ def test_alloc_seq_type2(): ...@@ -411,6 +411,38 @@ def test_alloc_seq_type2():
tvm.ir_pass.PostOrderVisit(body, verify) tvm.ir_pass.PostOrderVisit(body, verify)
assert num_alloc[0] == 1 assert num_alloc[0] == 1
def test_reuse_small_buffer():
ib = tvm.ir_builder.create()
n = tvm.var("n")
with ib.for_range(0, n, name="i") as i:
with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("int16", 200, name="A", scope="local.L0A")
A[j] = tvm.const(1, "int16")
B = ib.allocate("int16", 200, name="B", scope="local.L0A")
B[j] = tvm.const(1, "int16")
B1 = ib.allocate("int16", 200, name="B1", scope="local.L0A")
B1[j] = A[j] + B[j]
C = ib.allocate("int16", 400, name="C", scope="local.L0A")
C[j] = tvm.const(1, "int16")
D = ib.allocate("int16", 400, name="D", scope="local.L0A")
D[j] = tvm.const(1, "int16")
E = ib.allocate("int16", 400, name="E", scope="local.L0A")
E[j] = C[j]
body = ib.get()
body = tvm.ir_pass.StorageRewrite(body)
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.stmt.Allocate):
num_alloc[0] += 1
assert n.extents[0].value == 800
tvm.ir_pass.PostOrderVisit(body, verify)
assert num_alloc[0] == 1
if __name__ == "__main__": if __name__ == "__main__":
test_alloc_seq() test_alloc_seq()
test_alloc_different_dtypes() test_alloc_different_dtypes()
...@@ -423,4 +455,4 @@ if __name__ == "__main__": ...@@ -423,4 +455,4 @@ if __name__ == "__main__":
test_inplace_rule3() test_inplace_rule3()
test_alloc_seq_type() test_alloc_seq_type()
test_alloc_seq_type2() test_alloc_seq_type2()
test_reuse_small_buffer()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment