Commit 326fff5c by Zhi Committed by Yizhi Liu

[TVM][Bugfix] fix storage_rewrite bug when input is big (#2580)

* fix storage_rewrite bug when input is big

* cast when necessary

* simplification

* simplification

* int64->uint32

* revert uint32->int64
parent d20646c7
......@@ -550,8 +550,10 @@ class StoragePlanRewriter : public IRMutator {
}
if (e->allocs.size() == 1) {
// simply use the original allocation.
Expr sz = arith::ComputeReduce<Mul>(e->allocs[0]->extents,
make_const(Int(32), 1));
e->new_alloc = Allocate::make(
e->alloc_var, alloc_type, e->allocs[0]->extents,
e->alloc_var, alloc_type, {sz},
e->allocs[0]->condition, Evaluate::make(0));
if (e->scope.tag.length() != 0) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
......@@ -564,8 +566,19 @@ class StoragePlanRewriter : public IRMutator {
Expr combo_size;
for (const Allocate* op : e->allocs) {
Expr sz = arith::ComputeReduce<Mul>(op->extents, make_const(Int(32), 1));
auto nbits = op->type.bits() * op->type.lanes();
if (const auto* imm = sz.as<IntImm>()) {
if (imm->value > std::numeric_limits<int>::max() / nbits) {
LOG(WARNING) << "The allocation requires : " << imm->value
<< " * " << nbits
<< " bits, which is greater than the maximum of"
" int32. The size is cast to int64."
<< "\n";
sz = make_const(Int(64), imm->value);
}
}
// transform to bits
auto sz_nbits = sz * (op->type.bits() * op->type.lanes());
auto sz_nbits = sz * nbits;
if (combo_size.defined()) {
combo_size = max(combo_size, sz_nbits);
} else {
......
......@@ -477,6 +477,30 @@ def test_replace_dataflow():
assert isinstance(bounds, tvm.container.Map)
def test_large_input():
@tvm.hybrid.script
def compute(a, b):
n = 16384
c = output_tensor((n, n), 'int32')
for i in range(n):
for j in range(n):
c[i, j] = a[i, j] - b[i, j]
return c
n = 16384
shape = (n, n)
a = tvm.placeholder(shape, name='a', dtype='int32')
b = tvm.placeholder(shape, name='b', dtype='int32')
c = tvm.compute(shape, lambda i, j: compute(a, b)[i, j])
c = tvm.compute(shape, lambda i, j: 1 + c[i, j])
s = tvm.create_schedule(c.op)
stmt = tvm.lower(s, [a, b, c], simple_mode=True)
def verify(n):
if isinstance(n, tvm.stmt.Allocate):
assert n.extents[0].value == 268435456
tvm.ir_pass.PostOrderVisit(stmt, verify)
if __name__ == "__main__":
test_alloc_seq()
test_alloc_different_dtypes()
......@@ -492,3 +516,4 @@ if __name__ == "__main__":
test_alloc_seq_type2()
test_reuse_small_buffer()
test_replace_dataflow()
test_large_input()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment