Commit 1587038e by Tianqi Chen Committed by GitHub

[PASS] Fix reuse small buffer in storage rewrite (#1012)

parent 6b0950dd
......@@ -824,6 +824,7 @@ class StoragePlanRewriter : public IRMutator {
if (e->attach_scope_ != attach_scope) continue;
if (e->scope != scope) continue;
if (e->elem_type != op->type.element_of()) continue;
e->const_nbits = std::max(const_nbits, e->const_nbits);
const_free_map_.erase(it);
return e;
}
......
......@@ -411,6 +411,38 @@ def test_alloc_seq_type2():
tvm.ir_pass.PostOrderVisit(body, verify)
assert num_alloc[0] == 1
def test_reuse_small_buffer():
ib = tvm.ir_builder.create()
n = tvm.var("n")
with ib.for_range(0, n, name="i") as i:
with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("int16", 200, name="A", scope="local.L0A")
A[j] = tvm.const(1, "int16")
B = ib.allocate("int16", 200, name="B", scope="local.L0A")
B[j] = tvm.const(1, "int16")
B1 = ib.allocate("int16", 200, name="B1", scope="local.L0A")
B1[j] = A[j] + B[j]
C = ib.allocate("int16", 400, name="C", scope="local.L0A")
C[j] = tvm.const(1, "int16")
D = ib.allocate("int16", 400, name="D", scope="local.L0A")
D[j] = tvm.const(1, "int16")
E = ib.allocate("int16", 400, name="E", scope="local.L0A")
E[j] = C[j]
body = ib.get()
body = tvm.ir_pass.StorageRewrite(body)
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.stmt.Allocate):
num_alloc[0] += 1
assert n.extents[0].value == 800
tvm.ir_pass.PostOrderVisit(body, verify)
assert num_alloc[0] == 1
if __name__ == "__main__":
test_alloc_seq()
test_alloc_different_dtypes()
......@@ -423,4 +455,4 @@ if __name__ == "__main__":
test_inplace_rule3()
test_alloc_seq_type()
test_alloc_seq_type2()
test_reuse_small_buffer()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment