Commit 5fc4bc57 by libing4752 Committed by Tianqi Chen

[PASS] enhance storage_rewrite to support different dtypes for unified buffer (#805)

* modified schedule_dataflow_rewrite.cc to fix losing tensor problem

* modified schedule_dataflow_rewrite.cc for lint scan

* modified schedule_dataflow_rewrite.cc for lint scan

* using tensor's value_index to index output of stage op

* repair address offset for different kinds of dtype

* bc

* aaa

* aaaaa

* repair address for different dtypes

* remove nonsense files

* add whitespace of line 581

* use base alloc elem_type

* enhance the test case to cover base buffers of 64 bits, 32 bits, 16 bits, and 8 bits

* use extents[0].type() as the dtype of the offset

* clear program writes
parent adceea22
......@@ -576,33 +576,33 @@ class StoragePlanRewriter : public IRMutator {
// allocate with element type.
CHECK_NE(e->const_nbits, 0U);
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
uint64_t total_bits = e->const_nbits;
size_t align = 1;
if (info.defined()) {
align = (info->max_simd_bits + e->elem_type.bits() - 1) / e->elem_type.bits();
align = info->max_simd_bits;
}
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
if (total_elem % align != 0) {
total_elem += align - (total_elem % align);
if (total_bits % align != 0) {
total_bits += align - (total_bits % align);
}
e->alloc_var = e->allocs[0]->buffer_var;
for (StorageEntry* child : e->merged_children) {
CHECK_NE(e->const_nbits, 0U);
CHECK_NE(total_elem, 0U);
size_t num_elem = child->const_nbits / child->elem_type.bits();
child->elem_offset = total_elem;
CHECK_NE(child->const_nbits, 0U);
CHECK_NE(total_bits, 0U);
child->elem_offset = total_bits / child->elem_type.bits();
child->alloc_var = e->alloc_var;
total_elem += num_elem;
if (total_elem % align != 0) {
total_elem += align - (total_elem % align);
total_bits += child->const_nbits;
if (total_bits % align != 0) {
total_bits += align - (total_bits % align);
}
}
uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
Expr alloc_size = make_const(e->allocs[0]->extents[0].type(),
total_elem);
(total_bits + type_bits - 1) / type_bits);
e->new_alloc = Allocate::make(
e->alloc_var, e->elem_type, {alloc_size}, const_true(),
Evaluate::make(0));
if (info.defined()) {
CHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
CHECK_LE(total_bits, info->max_num_bits)
<< "Allocation exceed bound of memory tag " << e->scope.to_string();
}
}
......
......@@ -49,6 +49,52 @@ def test_alloc_seq():
tvm.ir_pass.PostOrderVisit(body, verify)
assert num_alloc[0] == 1
def test_alloc_different_dtypes():
    """Check StorageRewrite on same-scope allocations that differ in dtype.

    For several dtype lists, builds a statement containing one loop per dtype,
    each allocating a ``length``-element buffer in scope "local.L0A", runs
    ``StorageRewrite`` and asserts that every remaining ``Allocate`` node has
    the expected extent, expressed in elements of the first (base) dtype.
    """
    def stmt_generater(dtype_list, length):
        """Build the statement: one loop per dtype that allocates and writes a buffer."""
        ib = tvm.ir_builder.create()
        for index, dtype in enumerate(dtype_list):
            with ib.for_range(0, length, name="j") as j:
                A = ib.allocate(dtype, length, name="A_" + str(index), scope="local.L0A")
                A[j] = tvm.const(1, dtype = dtype)
        return ib.get()

    def dtype_bit_len(dtype):
        """Return the numeric suffix of a dtype name, e.g. "float16" -> 16."""
        index = 0
        for i in dtype:
            if i.isdigit():
                break
            index += 1
        return int(dtype[index:])

    def offset_generater(dtype_list, length):
        """Return the expected allocation extent in units of the base dtype.

        The total bit count of all buffers is divided by the bit width of the
        first dtype, rounding up. Integer arithmetic keeps the result an int on
        both Python 2 and Python 3 (``/`` would yield a float on Python 3).
        NOTE(review): alignment padding is not modelled here; this assumes the
        sizes used by the test need none -- confirm against the memory info
        registered for "local.L0A", if any.
        """
        dtype_len_list = [dtype_bit_len(i) for i in dtype_list]
        base_len = dtype_len_list[0]
        total_bits = sum(bits * length for bits in dtype_len_list)
        return (total_bits + base_len - 1) // base_len

    def dtype_test(dtype_list, length):
        """Rewrite the generated statement and verify the extent of each Allocate."""
        body = stmt_generater(dtype_list, length)
        offset = offset_generater(dtype_list, length)

        def verify(n):
            # Only Allocate nodes are checked; all other visited nodes are ignored.
            if isinstance(n, tvm.stmt.Allocate):
                assert n.extents[0].value == offset

        body = tvm.ir_pass.StorageRewrite(body)
        tvm.ir_pass.PostOrderVisit(body, verify)

    length = 1024
    # Each list starts with a base dtype of a different width: 16, 32, 64 and 8 bits.
    dtype_list = ["float16", "int32", "uint16", "int8"]
    dtype_test(dtype_list, length)
    dtype_list = ["float32", "int32", "uint16", "int8"]
    dtype_test(dtype_list, length)
    dtype_list = ["float64", "int32", "uint16", "int8"]
    dtype_test(dtype_list, length)
    dtype_list = ["int8", "int32", "uint16", "uint8"]
    dtype_test(dtype_list, length)
def test_inplace_rule():
......@@ -91,7 +137,6 @@ def test_storage_combine():
s = tvm.create_schedule(B.op)
for S in stages[:-1]:
s[S].set_scope("global:tag")
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
......@@ -215,6 +260,7 @@ def test_inplace_rule2():
if __name__ == "__main__":
test_alloc_seq()
test_alloc_different_dtypes()
test_inplace_rule()
test_storage_share()
test_parallel_alloc()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment