Commit 160e4107 by Yizhi Liu Committed by Tianqi Chen

fix buffer elem_offset calculation (#1762)

parent 5ed52a5f
......@@ -226,16 +226,12 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
Expr base = n->elem_offset;
if (n->strides.size() == 0) {
CHECK_EQ(n->shape.size(), index.size());
if (n->shape.size() != 0) {
if (is_zero(base)) {
base = index[0];
} else {
base = base + index[0];
if (index.size() > 0) {
Expr offset = index[0];
for (size_t i = 1; i < index.size(); ++i) {
offset = MergeMulMod(offset * n->shape[i] + index[i]);
}
}
base = MergeMulMod(base);
for (size_t i = 1; i < index.size(); ++i) {
base = MergeMulMod(base * n->shape[i] + index[i]);
base = base + offset;
}
} else {
CHECK_EQ(n->strides.size(), index.size());
......
......@@ -41,6 +41,14 @@ def test_buffer_access_ptr_offset():
assert tvm.ir_pass.Equal(offset, tvm.call_extern('int32', "test_call", 200 + v))
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
def test_buffer_vload():
m = tvm.var('m')
n = tvm.var('n')
Ab = tvm.decl_buffer((m, n), tvm.float32, elem_offset=100)
load = Ab.vload([2, 3])
offset = tvm.ir_pass.Simplify(load.index)
assert tvm.ir_pass.Equal(offset, n * 2 + 103)
def test_buffer_index_merge_mult_mod():
m = tvm.var('m')
n = tvm.var('n')
......@@ -76,4 +84,5 @@ if __name__ == "__main__":
test_buffer()
test_buffer_access_ptr()
test_buffer_access_ptr_offset()
test_buffer_vload()
test_buffer_index_merge_mult_mod()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment