Commit 38d08357 by xqdan Committed by Tianqi Chen

#1592 [PASS] Fix missing mem CHECK in storage_rewrite (#1616)

parent 9b0e4990
...@@ -584,6 +584,12 @@ class StoragePlanRewriter : public IRMutator { ...@@ -584,6 +584,12 @@ class StoragePlanRewriter : public IRMutator {
e->new_alloc = Allocate::make( e->new_alloc = Allocate::make(
e->alloc_var, alloc_type, {combo_size}, const_true(), e->alloc_var, alloc_type, {combo_size}, const_true(),
Evaluate::make(0)); Evaluate::make(0));
if (e->scope.tag.length() != 0) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
CHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
<< "Allocation exceed bound of memory tag " << e->scope.to_string();
}
} }
} }
} }
......
...@@ -28,15 +28,30 @@ def test_storage_share(): ...@@ -28,15 +28,30 @@ def test_storage_share():
tvm.ir_pass.PostOrderVisit(stmt, verify) tvm.ir_pass.PostOrderVisit(stmt, verify)
assert num_alloc[0] == 1 assert num_alloc[0] == 1
def register_mem(scope_tb, max_bits):
    """Register a memory-info provider for the storage scope ``scope_tb``.

    A global function named ``tvm.info.mem.<scope_tb>`` is registered; when
    called it returns a ``MemoryInfo`` node whose ``max_num_bits`` is
    ``max_bits`` (``unit_bits`` is 16, ``max_simd_bits`` is 32 and
    ``head_address`` is None).

    Parameters
    ----------
    scope_tb : str
        Scope name appended to the ``tvm.info.mem.`` prefix.
    max_bits : int
        Value stored in the ``max_num_bits`` field of the returned node.
    """
    func_name = "tvm.info.mem.%s" % scope_tb

    def mem_info_inp_buffer():
        # Only max_num_bits varies between callers; the other fields are fixed.
        return tvm.make.node(
            "MemoryInfo",
            unit_bits=16,
            max_simd_bits=32,
            max_num_bits=max_bits,
            head_address=None,
        )

    # Direct-call form of the decorator: register the provider under func_name.
    tvm.register_func(func_name, mem_info_inp_buffer)
def test_alloc_seq(): def test_alloc_seq():
scope_tb = "local.L0A"
max_bits = 1024 * 1024 * 1024
register_mem(scope_tb, max_bits)
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
n = tvm.var("n") n = tvm.var("n")
with ib.for_range(0, n, name="i") as i: with ib.for_range(0, n, name="i") as i:
with ib.for_range(0, 10, name="j") as j: with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("float32", 200, name="A", scope="local.L0A") A = ib.allocate("float32", 200, name="A", scope=scope_tb)
A[j] = 1.2 A[j] = 1.2
with ib.for_range(0, 10, name="j") as j: with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("float32", 200, name="B", scope="local.L0A") A = ib.allocate("float32", 200, name="B", scope=scope_tb)
A[j] = 1.3 A[j] = 1.3
body = ib.get() body = ib.get()
...@@ -233,16 +248,9 @@ def test_parallel_alloc(): ...@@ -233,16 +248,9 @@ def test_parallel_alloc():
assert(isinstance(body.body.body.body.body, tvm.stmt.Allocate)) assert(isinstance(body.body.body.body.body, tvm.stmt.Allocate))
def test_inplace_rule2(): def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024):
#Test Buffer #Test Buffer
scope_tb = "local_TB2" register_mem(scope_tb, max_bits)
@tvm.register_func("tvm.info.mem.%s" % scope_tb)
def mem_info_inp_buffer():
return tvm.make.node("MemoryInfo",
unit_bits= 16,
max_simd_bits=32,
max_num_bits=1024*1024*1024,
head_address=None)
m = 10 m = 10
A = tvm.placeholder((m,), name='A') A = tvm.placeholder((m,), name='A')
C = tvm.placeholder((m,), name='C') C = tvm.placeholder((m,), name='C')
...@@ -275,16 +283,23 @@ def test_inplace_rule2(): ...@@ -275,16 +283,23 @@ def test_inplace_rule2():
tvm.ir_pass.PostOrderVisit(stmt, verify) tvm.ir_pass.PostOrderVisit(stmt, verify)
assert num_alloc[0] == 2 assert num_alloc[0] == 2
def test_exceed_mem():
    """Check that an over-sized allocation is rejected with the expected message.

    Runs ``test_inplace_rule2`` under the scope ``"local_TEM"`` with a
    ``max_bits`` of 639 and asserts that an exception is raised whose text
    contains ``'Allocation exceed bound of memory'``.
    """
    # The critical max_num_bits is between 639 and 640
    max_bits = 639
    expected_msg = 'Allocation exceed bound of memory'
    msg_found = False
    try:
        test_inplace_rule2("local_TEM", max_bits)
    except Exception as err:
        # Pass only if the raised error carries the bound-check message.
        msg_found = expected_msg in str(err)
    assert msg_found
def test_inplace_rule3(): def test_inplace_rule3():
#Test Buffer #Test Buffer
scope_tb = "local_TB3" scope_tb = "local_TB3"
@tvm.register_func("tvm.info.mem.%s" % scope_tb) max_bits=1024 * 1024 * 1024
def mem_info_inp_buffer():
return tvm.make.node("MemoryInfo", register_mem(scope_tb, max_bits)
unit_bits= 16,
max_simd_bits=32,
max_num_bits=1024*1024*1024,
head_address=None)
m = 10 m = 10
B0 = tvm.placeholder((m,), name='B0') B0 = tvm.placeholder((m,), name='B0')
B1 = tvm.placeholder((m,), name='B1') B1 = tvm.placeholder((m,), name='B1')
...@@ -388,17 +403,22 @@ def test_alloc_seq_type(): ...@@ -388,17 +403,22 @@ def test_alloc_seq_type():
assert num_alloc[0] == 1 assert num_alloc[0] == 1
def test_alloc_seq_type2(): def test_alloc_seq_type2():
scope_tb = "local.L0A2"
max_bits=1024 * 1024 * 1024
register_mem(scope_tb, max_bits)
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
n = tvm.var("n") n = tvm.var("n")
with ib.for_range(0, n, name="i") as i: with ib.for_range(0, n, name="i") as i:
with ib.for_range(0, 10, name="j") as j: with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("float32", 200, name="A", scope="local.L0A") A = ib.allocate("float32", 200, name="A", scope=scope_tb)
A[j] = 1.2 A[j] = 1.2
with ib.for_range(0, 20, name="j") as j: with ib.for_range(0, 20, name="j") as j:
B = ib.allocate("int16", 400, name="B", scope="local.L0A") B = ib.allocate("int16", 400, name="B", scope=scope_tb)
B[j] = tvm.const(1, "int16") B[j] = tvm.const(1, "int16")
with ib.for_range(0, 10, name="j") as j: with ib.for_range(0, 10, name="j") as j:
C = ib.allocate("float32", 200, name="C", scope="local.L0A") C = ib.allocate("float32", 200, name="C", scope=scope_tb)
C[j] = 1.2 C[j] = 1.2
body = ib.get() body = ib.get()
...@@ -465,6 +485,7 @@ if __name__ == "__main__": ...@@ -465,6 +485,7 @@ if __name__ == "__main__":
test_storage_combine() test_storage_combine()
test_storage_share_gpu() test_storage_share_gpu()
test_inplace_rule2() test_inplace_rule2()
test_exceed_mem()
test_inplace_rule3() test_inplace_rule3()
test_alloc_seq_type() test_alloc_seq_type()
test_alloc_seq_type2() test_alloc_seq_type2()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment