Commit 293dac39 by kun-zh Committed by Tianqi Chen

support using pointer with an original offset (#826)

* when there is no intrin func, using body for initialization. For issue 714.

* Refine code per review comments, and add a test case.

* Fix lint issues.

* Re-organize the tensorize test cases, and add a new case for none-reset
mode.

* Fix a typo.

* Delete the unit case because merged it into test_schedule_tensorize.py already.

* always use new tensor in its stage when rewrite for cache read

* revert previous changes to sync up with master

* support using the ptr with an original offset

* update test case and fix CI error
parent 0b54952b
......@@ -52,9 +52,10 @@ class Buffer : public NodeRef {
* \param access_mask The access mask
* \param ptr_type The type of the pointer.
* \param content_lanes The number of lanes for the (data) type.
* \param offset The offset of ptr.
*/
TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle(),
int content_lanes = 1) const;
int content_lanes = 1, int offset = 0) const;
/*!
* \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index
......
......@@ -25,7 +25,7 @@ class Buffer(NodeBase):
READ = 1
WRITE = 2
def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1):
def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
"""Get an access pointer to the head of buffer.
This is the recommended method to get buffer data
......@@ -45,6 +45,10 @@ class Buffer(NodeBase):
The number of lanes for the data type. This value
is greater than one for vector types.
offset: int, optional
The offset of pointer. We can use it to offset by
the number of elements from the address of ptr.
Examples
--------
.. code-block:: python
......@@ -68,7 +72,7 @@ class Buffer(NodeBase):
raise ValueError("Unknown access_mask %s" % access_mask)
access_mask = mask
return _api_internal._BufferAccessPtr(self, access_mask, ptr_type,
content_lanes)
content_lanes, offset)
def vload(self, begin, dtype=None):
"""Generate an Expr that loads dtype from begin index.
......
......@@ -159,7 +159,7 @@ TVM_REGISTER_API("_Buffer")
TVM_REGISTER_API("_BufferAccessPtr")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Buffer()
.access_ptr(args[1], args[2], args[3]);
.access_ptr(args[1], args[2], args[3], args[4]);
});
TVM_REGISTER_API("_BufferVLoad")
......
......@@ -335,7 +335,7 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
0);
}
Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes) const {
Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, int offset) const {
const BufferNode* self = operator->();
Expr e_dtype;
Expr extent;
......@@ -348,7 +348,7 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes) const
} else {
extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr());
}
Expr elem_offset = self->elem_offset;
Expr elem_offset = self->elem_offset + offset;
if (content_lanes > 1) {
e_dtype = make_zero(self->dtype.with_lanes(content_lanes));
extent = extent / make_const(self->elem_offset.type(), content_lanes);
......
......@@ -23,6 +23,15 @@ def test_buffer_access_ptr():
aptr = Ab.access_ptr("w")
assert aptr.args[4].value == Buffer.WRITE
def test_buffer_access_ptr_offset():
m = tvm.var('m')
n = tvm.var('n')
Ab = tvm.decl_buffer((m, n), tvm.float32)
aptr = Ab.access_ptr("rw", offset=100)
offset = tvm.ir_pass.Simplify(aptr.args[2])
assert tvm.ir_pass.Equal(offset, 100)
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
def test_buffer_index_merge_mult_mod():
m = tvm.var('m')
n = tvm.var('n')
......@@ -57,4 +66,5 @@ def test_buffer_index_merge_mult_mod():
if __name__ == "__main__":
test_buffer()
test_buffer_access_ptr()
test_buffer_access_ptr_offset()
test_buffer_index_merge_mult_mod()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment