Commit 38274115 by libing4752 Committed by Tianqi Chen

enhance access_ptr that args can support Expr (#970)

parent 078c767c
......@@ -55,7 +55,7 @@ class Buffer : public NodeRef {
* \param offset The offset of ptr.
*/
TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle(),
int content_lanes = 1, int offset = 0) const;
int content_lanes = 1, Expr offset = make_const(Int(32), 0)) const;
/*!
* \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index
......
......@@ -2,12 +2,34 @@
from __future__ import absolute_import as _abs
from ._ffi.base import string_types
from ._ffi.node import NodeBase, register_node
from ._ffi.function import _init_api
from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.function import _init_api, Function
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
from . import _api_internal
from . import tensor as _tensor
from . import expr as _expr
from . import container as _container
def convert(value):
"""Convert value to TVM node or function.
Parameters
----------
value : python value
Returns
-------
tvm_val : Node or Function
Converted value in TVM
"""
if isinstance(value, (Function, NodeBase)):
return value
if callable(value):
return _convert_tvm_func(value)
return _convert_to_node(value)
@register_node
class Buffer(NodeBase):
"""Symbolic data buffer in TVM.
......@@ -45,7 +67,7 @@ class Buffer(NodeBase):
The number of lanes for the data type. This value
is greater than one for vector types.
offset: int, optional
offset: Expr, optional
The offset of pointer. We can use it to offset by
the number of elements from the address of ptr.
......@@ -60,6 +82,8 @@ class Buffer(NodeBase):
buffer.access_ptr(Buffer.READ | Buffer.WRITE)
# Get access ptr for read/write with str flag
buffer.access_ptr("rw")
# Get access ptr for read with offset
buffer.access_ptr("r", offset = 100)
"""
if isinstance(access_mask, string_types):
mask = 0
......@@ -71,6 +95,7 @@ class Buffer(NodeBase):
else:
raise ValueError("Unknown access_mask %s" % access_mask)
access_mask = mask
offset = convert(offset)
return _api_internal._BufferAccessPtr(self, access_mask, ptr_type,
content_lanes, offset)
......
......@@ -335,7 +335,7 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
0);
}
Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, int offset) const {
Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr offset) const {
const BufferNode* self = operator->();
Expr e_dtype;
Expr extent;
......
......@@ -31,6 +31,15 @@ def test_buffer_access_ptr_offset():
offset = tvm.ir_pass.Simplify(aptr.args[2])
assert tvm.ir_pass.Equal(offset, 100)
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
v = tvm.var('int32')
aptr = Ab.access_ptr("rw", offset=100 + 100 + v)
offset = tvm.ir_pass.Simplify(aptr.args[2])
assert tvm.ir_pass.Equal(offset, 200 + v)
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
aptr = Ab.access_ptr("rw", offset=tvm.call_extern('int32', "test_call", 100 + 100 + v))
offset = tvm.ir_pass.Simplify(aptr.args[2])
assert tvm.ir_pass.Equal(offset, tvm.call_extern('int32', "test_call", 200 + v))
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
def test_buffer_index_merge_mult_mod():
m = tvm.var('m')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment