Commit 6d3a3adf by xqdan Committed by Tianqi Chen

[PASS] Improve storage rewrite(#846) (#847)

* fix #802, create cache based on sugar tensor

* [Pass] Improve storage rewrite

* fix ci

* fix comment

* fix comment
parent aedfaaec
......@@ -681,6 +681,8 @@ class StoragePlanRewriter : public IRMutator {
StorageEntry* dst_entry = nullptr;
// inplace detection
if (detect_inplace) {
// only one inplace var for s.stmt
bool inplace_found = false;
for (const Variable* src : it->second.kill) {
if (!inplace_flag.count(src) && alloc_map_.count(src)) {
InplaceOpVerifier visitor;
......@@ -693,10 +695,11 @@ class StoragePlanRewriter : public IRMutator {
ae.alloc->constant_allocation_size() *
ae.alloc->type.bits() *
ae.alloc->type.lanes());
if (src_entry->const_nbits == const_nbits) {
if (src_entry->const_nbits == const_nbits && !inplace_found) {
// successfully inplace
dst_entry = src_entry;
inplace_flag.insert(src);
inplace_found = true;
}
}
}
......
......@@ -218,7 +218,7 @@ def test_parallel_alloc():
def test_inplace_rule2():
#Test Buffer
scope_tb = "local_TB"
scope_tb = "local_TB2"
@tvm.register_func("tvm.info.mem.%s" % scope_tb)
def mem_info_inp_buffer():
return tvm.make.node("MemoryInfo",
......@@ -258,6 +258,90 @@ def test_inplace_rule2():
tvm.ir_pass.PostOrderVisit(stmt, verify)
assert num_alloc[0] == 2
def test_inplace_rule3():
    """Check StorageRewrite when one statement kills several source buffers.

    Builds B = B8 * B11 + B14 out of six placeholders, caches every
    input and intermediate stage in a custom memory scope, lowers the
    schedule, and runs StorageRewrite.  Every Allocate node that remains
    must have a first extent of 70, and at least one Allocate node must
    be present so the check cannot pass vacuously.
    """
    # Test buffer scope; the name is unique to this test so the
    # registered memory-info function does not clash with other tests.
    scope_tb = "local_TB3"

    @tvm.register_func("tvm.info.mem.%s" % scope_tb)
    def mem_info_inp_buffer():
        return tvm.make.node("MemoryInfo",
                             unit_bits=16,
                             max_simd_bits=32,
                             max_num_bits=1024*1024*1024,
                             head_address=None)

    m = 10
    B0 = tvm.placeholder((m,), name='B0')
    B1 = tvm.placeholder((m,), name='B1')
    B2 = tvm.placeholder((m,), name='B2')
    B3 = tvm.placeholder((m,), name='B3')
    B4 = tvm.placeholder((m,), name='B4')
    B5 = tvm.placeholder((m,), name='B5')

    B6 = tvm.compute((m,), lambda i: B1[i] * B5[i], name='B6')
    B7 = tvm.compute((m,), lambda i: B2[i] * B4[i], name='B7')
    B8 = tvm.compute((m,), lambda i: B6[i] - B7[i], name='B8')

    B9 = tvm.compute((m,), lambda i: B2[i] * B3[i], name='B9')
    B10 = tvm.compute((m,), lambda i: B0[i] * B5[i], name='B10')
    B11 = tvm.compute((m,), lambda i: B9[i] - B10[i], name='B11')

    B12 = tvm.compute((m,), lambda i: B0[i] * B4[i], name='B12')
    B13 = tvm.compute((m,), lambda i: B1[i] * B3[i], name='B13')
    B14 = tvm.compute((m,), lambda i: B12[i] - B13[i], name='B14')

    B = tvm.compute((m,), lambda i: B8[i] * B11[i] + B14[i], name='B')
    s = tvm.create_schedule(B.op)

    # Stage every placeholder into the test scope for its two readers.
    B1L = s.cache_read(B1, scope_tb, [B6, B13])
    B5L = s.cache_read(B5, scope_tb, [B6, B10])
    B2L = s.cache_read(B2, scope_tb, [B7, B9])
    B4L = s.cache_read(B4, scope_tb, [B7, B12])
    B3L = s.cache_read(B3, scope_tb, [B9, B13])
    B0L = s.cache_read(B0, scope_tb, [B10, B12])

    # Stage every intermediate result into the test scope.
    B8L = s.cache_write(B8, scope_tb)
    B11L = s.cache_write(B11, scope_tb)
    B14L = s.cache_write(B14, scope_tb)
    B6L = s.cache_write(B6, scope_tb)
    B7L = s.cache_write(B7, scope_tb)
    B9L = s.cache_write(B9, scope_tb)
    B10L = s.cache_write(B10, scope_tb)
    B12L = s.cache_write(B12, scope_tb)
    B13L = s.cache_write(B13, scope_tb)

    # Inline the original intermediate stages so that only their cached
    # counterparts remain as separate stages.
    s[B12].compute_inline()
    s[B13].compute_inline()
    s[B8].compute_inline()
    s[B11].compute_inline()
    s[B14].compute_inline()
    s[B6].compute_inline()
    s[B7].compute_inline()
    s[B9].compute_inline()
    s[B10].compute_inline()

    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    stmt = tvm.schedule.ScheduleOps(s, bounds)

    B0a = tvm.decl_buffer(B0.shape, B0.dtype, name='B0')
    B1a = tvm.decl_buffer(B1.shape, B1.dtype, name='B1')
    B2a = tvm.decl_buffer(B2.shape, B2.dtype, name='B2')
    B3a = tvm.decl_buffer(B3.shape, B3.dtype, name='B3')
    B4a = tvm.decl_buffer(B4.shape, B4.dtype, name='B4')
    B5a = tvm.decl_buffer(B5.shape, B5.dtype, name='B5')
    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')

    # Each placeholder is bound to its own buffer; B3 must map to B3a,
    # otherwise B2 and B3 would alias the same external buffer.
    stmt = tvm.ir_pass.StorageFlatten(
        stmt,
        {B0: B0a, B1: B1a, B2: B2a, B3: B3a, B4: B4a, B5: B5a, B: Bb},
        64)
    stmt = tvm.ir_pass.CanonicalSimplify(stmt)
    stmt = tvm.ir_pass.Simplify(stmt)
    stmt = tvm.ir_pass.StorageRewrite(stmt)

    # Verify that inplace folding works: every remaining allocation must
    # have the expected extent.  A one-element list is used as the
    # counter so the nested callback can mutate it.
    num_alloc = [0]

    def verify(n):
        if isinstance(n, tvm.stmt.Allocate):
            num_alloc[0] += 1
            assert n.extents[0].value == 70

    tvm.ir_pass.PostOrderVisit(stmt, verify)
    # Guard against a vacuous pass when no Allocate node is visited.
    assert num_alloc[0] > 0
if __name__ == "__main__":
test_alloc_seq()
test_alloc_different_dtypes()
......@@ -267,3 +351,4 @@ if __name__ == "__main__":
test_storage_combine()
test_storage_share_gpu()
test_inplace_rule2()
test_inplace_rule3()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment