Commit 293dac39 by kun-zh Committed by Tianqi Chen

support using pointer with an original offset (#826)

* When there is no intrin func, use the body for initialization (for issue 714).

* Refine code per review comments, and add a test case.

* Fix lint issues.

* Re-organize the tensorize test cases, and add a new case for non-reset
mode.

* Fix a typo.

* Delete the unit test case because it was already merged into test_schedule_tensorize.py.

* Always use the new tensor in its stage when rewriting for cache read

* revert previous changes to sync up with master

* support using the ptr with an original offset

* update test case and fix CI error
parent 0b54952b
...@@ -52,9 +52,10 @@ class Buffer : public NodeRef { ...@@ -52,9 +52,10 @@ class Buffer : public NodeRef {
* \param access_mask The access mask * \param access_mask The access mask
* \param ptr_type The type of the pointer. * \param ptr_type The type of the pointer.
* \param content_lanes The number of lanes for the (data) type. * \param content_lanes The number of lanes for the (data) type.
* \param offset The offset of ptr.
*/ */
TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle(), TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle(),
int content_lanes = 1) const; int content_lanes = 1, int offset = 0) const;
/*! /*!
* \brief Create an Expr that does a vector load at begin index. * \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index * \param begin The beginning index
......
...@@ -25,7 +25,7 @@ class Buffer(NodeBase): ...@@ -25,7 +25,7 @@ class Buffer(NodeBase):
READ = 1 READ = 1
WRITE = 2 WRITE = 2
def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1): def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
"""Get an access pointer to the head of buffer. """Get an access pointer to the head of buffer.
This is the recommended method to get buffer data This is the recommended method to get buffer data
...@@ -45,6 +45,10 @@ class Buffer(NodeBase): ...@@ -45,6 +45,10 @@ class Buffer(NodeBase):
The number of lanes for the data type. This value The number of lanes for the data type. This value
is greater than one for vector types. is greater than one for vector types.
offset: int, optional
The offset of pointer. We can use it to offset by
the number of elements from the address of ptr.
Examples Examples
-------- --------
.. code-block:: python .. code-block:: python
...@@ -68,7 +72,7 @@ class Buffer(NodeBase): ...@@ -68,7 +72,7 @@ class Buffer(NodeBase):
raise ValueError("Unknown access_mask %s" % access_mask) raise ValueError("Unknown access_mask %s" % access_mask)
access_mask = mask access_mask = mask
return _api_internal._BufferAccessPtr(self, access_mask, ptr_type, return _api_internal._BufferAccessPtr(self, access_mask, ptr_type,
content_lanes) content_lanes, offset)
def vload(self, begin, dtype=None): def vload(self, begin, dtype=None):
"""Generate an Expr that loads dtype from begin index. """Generate an Expr that loads dtype from begin index.
......
...@@ -159,7 +159,7 @@ TVM_REGISTER_API("_Buffer") ...@@ -159,7 +159,7 @@ TVM_REGISTER_API("_Buffer")
TVM_REGISTER_API("_BufferAccessPtr") TVM_REGISTER_API("_BufferAccessPtr")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Buffer() *ret = args[0].operator Buffer()
.access_ptr(args[1], args[2], args[3]); .access_ptr(args[1], args[2], args[3], args[4]);
}); });
TVM_REGISTER_API("_BufferVLoad") TVM_REGISTER_API("_BufferVLoad")
......
...@@ -335,7 +335,7 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const { ...@@ -335,7 +335,7 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
0); 0);
} }
Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes) const { Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, int offset) const {
const BufferNode* self = operator->(); const BufferNode* self = operator->();
Expr e_dtype; Expr e_dtype;
Expr extent; Expr extent;
...@@ -348,7 +348,7 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes) const ...@@ -348,7 +348,7 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes) const
} else { } else {
extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr()); extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr());
} }
Expr elem_offset = self->elem_offset; Expr elem_offset = self->elem_offset + offset;
if (content_lanes > 1) { if (content_lanes > 1) {
e_dtype = make_zero(self->dtype.with_lanes(content_lanes)); e_dtype = make_zero(self->dtype.with_lanes(content_lanes));
extent = extent / make_const(self->elem_offset.type(), content_lanes); extent = extent / make_const(self->elem_offset.type(), content_lanes);
......
...@@ -23,6 +23,15 @@ def test_buffer_access_ptr(): ...@@ -23,6 +23,15 @@ def test_buffer_access_ptr():
aptr = Ab.access_ptr("w") aptr = Ab.access_ptr("w")
assert aptr.args[4].value == Buffer.WRITE assert aptr.args[4].value == Buffer.WRITE
def test_buffer_access_ptr_offset():
m = tvm.var('m')
n = tvm.var('n')
Ab = tvm.decl_buffer((m, n), tvm.float32)
aptr = Ab.access_ptr("rw", offset=100)
offset = tvm.ir_pass.Simplify(aptr.args[2])
assert tvm.ir_pass.Equal(offset, 100)
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
def test_buffer_index_merge_mult_mod(): def test_buffer_index_merge_mult_mod():
m = tvm.var('m') m = tvm.var('m')
n = tvm.var('n') n = tvm.var('n')
...@@ -57,4 +66,5 @@ def test_buffer_index_merge_mult_mod(): ...@@ -57,4 +66,5 @@ def test_buffer_index_merge_mult_mod():
if __name__ == "__main__": if __name__ == "__main__":
test_buffer() test_buffer()
test_buffer_access_ptr() test_buffer_access_ptr()
test_buffer_access_ptr_offset()
test_buffer_index_merge_mult_mod() test_buffer_index_merge_mult_mod()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment