Commit 7cd7dbff by Gaoxiong Committed by Tianqi Chen

Fix non-zero extent of access_ptr out of range (#1937) (#1939)

parent 150f7a8b
......@@ -357,9 +357,9 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr
} else if (self->strides.size() == self->shape.size()) {
int highest_dim = 0;
extent = arith::ComputeExpr<ir::Mul>(
self->strides[highest_dim], self->shape[highest_dim]);
self->strides[highest_dim], self->shape[highest_dim]) - offset;
} else {
extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr());
extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr()) - offset;
}
Expr elem_offset = self->elem_offset + offset;
if (content_lanes > 1) {
......
......@@ -41,6 +41,18 @@ def test_buffer_access_ptr_offset():
assert tvm.ir_pass.Equal(offset, tvm.call_extern('int32', "test_call", 200 + v))
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
def test_buffer_access_ptr_extent():
    """Check argument 3 of the access_ptr call (presumably the extent).

    Three cases are covered:
      * compact buffer, no offset      -> m * n
      * compact buffer, offset of 100  -> m * n - 100
      * strided buffer, offset of 100  -> strides[0] * m - 100
    """
    m = tvm.var('m')
    n = tvm.var('n')

    # Buffer declared without explicit strides.
    compact_buf = tvm.decl_buffer((m, n), tvm.float32)

    ptr_no_offset = compact_buf.access_ptr("rw")
    assert tvm.ir_pass.Equal(ptr_no_offset.args[3], m * n)

    ptr_with_offset = compact_buf.access_ptr("rw", offset=100)
    assert tvm.ir_pass.Equal(ptr_with_offset.args[3], m * n - 100)

    # Buffer declared with explicit strides; the expected value is built
    # from the leading stride rather than from the product of the shape.
    strided_buf = tvm.decl_buffer((m, n), tvm.float32, strides=[n + 1, 1])
    ptr_strided = strided_buf.access_ptr("rw", offset=100)
    expected_extent = strided_buf.strides[0] * m - 100
    assert tvm.ir_pass.Equal(ptr_strided.args[3], expected_extent)
def test_buffer_vload():
m = tvm.var('m')
n = tvm.var('n')
......@@ -84,5 +96,6 @@ if __name__ == "__main__":
test_buffer()
test_buffer_access_ptr()
test_buffer_access_ptr_offset()
test_buffer_access_ptr_extent()
test_buffer_vload()
test_buffer_index_merge_mult_mod()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment