Commit 160e4107 by Yizhi Liu Committed by Tianqi Chen

fix buffer elem_offset calculation (#1762)

parent 5ed52a5f
...@@ -226,16 +226,12 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) { ...@@ -226,16 +226,12 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
Expr base = n->elem_offset; Expr base = n->elem_offset;
if (n->strides.size() == 0) { if (n->strides.size() == 0) {
CHECK_EQ(n->shape.size(), index.size()); CHECK_EQ(n->shape.size(), index.size());
if (n->shape.size() != 0) { if (index.size() > 0) {
if (is_zero(base)) { Expr offset = index[0];
base = index[0]; for (size_t i = 1; i < index.size(); ++i) {
} else { offset = MergeMulMod(offset * n->shape[i] + index[i]);
base = base + index[0];
} }
} base = base + offset;
base = MergeMulMod(base);
for (size_t i = 1; i < index.size(); ++i) {
base = MergeMulMod(base * n->shape[i] + index[i]);
} }
} else { } else {
CHECK_EQ(n->strides.size(), index.size()); CHECK_EQ(n->strides.size(), index.size());
......
...@@ -41,6 +41,14 @@ def test_buffer_access_ptr_offset(): ...@@ -41,6 +41,14 @@ def test_buffer_access_ptr_offset():
assert tvm.ir_pass.Equal(offset, tvm.call_extern('int32', "test_call", 200 + v)) assert tvm.ir_pass.Equal(offset, tvm.call_extern('int32', "test_call", 200 + v))
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
def test_buffer_vload():
m = tvm.var('m')
n = tvm.var('n')
Ab = tvm.decl_buffer((m, n), tvm.float32, elem_offset=100)
load = Ab.vload([2, 3])
offset = tvm.ir_pass.Simplify(load.index)
assert tvm.ir_pass.Equal(offset, n * 2 + 103)
def test_buffer_index_merge_mult_mod(): def test_buffer_index_merge_mult_mod():
m = tvm.var('m') m = tvm.var('m')
n = tvm.var('n') n = tvm.var('n')
...@@ -76,4 +84,5 @@ if __name__ == "__main__": ...@@ -76,4 +84,5 @@ if __name__ == "__main__":
test_buffer() test_buffer()
test_buffer_access_ptr() test_buffer_access_ptr()
test_buffer_access_ptr_offset() test_buffer_access_ptr_offset()
test_buffer_vload()
test_buffer_index_merge_mult_mod() test_buffer_index_merge_mult_mod()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment