Commit 38274115 by libing4752 Committed by Tianqi Chen

enhance access_ptr that args can support Expr (#970)

parent 078c767c
...@@ -55,7 +55,7 @@ class Buffer : public NodeRef { ...@@ -55,7 +55,7 @@ class Buffer : public NodeRef {
* \param offset The offset of ptr. * \param offset The offset of ptr.
*/ */
TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle(), TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle(),
int content_lanes = 1, int offset = 0) const; int content_lanes = 1, Expr offset = make_const(Int(32), 0)) const;
/*! /*!
* \brief Create an Expr that does a vector load at begin index. * \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index * \param begin The beginning index
......
...@@ -2,12 +2,34 @@ ...@@ -2,12 +2,34 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ffi.base import string_types from ._ffi.base import string_types
from ._ffi.node import NodeBase, register_node from ._ffi.node import NodeBase, register_node
from ._ffi.function import _init_api from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.function import _init_api, Function
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
from . import _api_internal from . import _api_internal
from . import tensor as _tensor from . import tensor as _tensor
from . import expr as _expr from . import expr as _expr
from . import container as _container from . import container as _container
def convert(value):
"""Convert value to TVM node or function.
Parameters
----------
value : python value
Returns
-------
tvm_val : Node or Function
Converted value in TVM
"""
if isinstance(value, (Function, NodeBase)):
return value
if callable(value):
return _convert_tvm_func(value)
return _convert_to_node(value)
@register_node @register_node
class Buffer(NodeBase): class Buffer(NodeBase):
"""Symbolic data buffer in TVM. """Symbolic data buffer in TVM.
...@@ -45,7 +67,7 @@ class Buffer(NodeBase): ...@@ -45,7 +67,7 @@ class Buffer(NodeBase):
The number of lanes for the data type. This value The number of lanes for the data type. This value
is greater than one for vector types. is greater than one for vector types.
offset: int, optional offset: Expr, optional
The offset of pointer. We can use it to offset by The offset of pointer. We can use it to offset by
the number of elements from the address of ptr. the number of elements from the address of ptr.
...@@ -60,6 +82,8 @@ class Buffer(NodeBase): ...@@ -60,6 +82,8 @@ class Buffer(NodeBase):
buffer.access_ptr(Buffer.READ | Buffer.WRITE) buffer.access_ptr(Buffer.READ | Buffer.WRITE)
# Get access ptr for read/write with str flag # Get access ptr for read/write with str flag
buffer.access_ptr("rw") buffer.access_ptr("rw")
# Get access ptr for read with offset
buffer.access_ptr("r", offset = 100)
""" """
if isinstance(access_mask, string_types): if isinstance(access_mask, string_types):
mask = 0 mask = 0
...@@ -71,6 +95,7 @@ class Buffer(NodeBase): ...@@ -71,6 +95,7 @@ class Buffer(NodeBase):
else: else:
raise ValueError("Unknown access_mask %s" % access_mask) raise ValueError("Unknown access_mask %s" % access_mask)
access_mask = mask access_mask = mask
offset = convert(offset)
return _api_internal._BufferAccessPtr(self, access_mask, ptr_type, return _api_internal._BufferAccessPtr(self, access_mask, ptr_type,
content_lanes, offset) content_lanes, offset)
......
...@@ -335,7 +335,7 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const { ...@@ -335,7 +335,7 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
0); 0);
} }
Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, int offset) const { Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr offset) const {
const BufferNode* self = operator->(); const BufferNode* self = operator->();
Expr e_dtype; Expr e_dtype;
Expr extent; Expr extent;
......
...@@ -31,6 +31,15 @@ def test_buffer_access_ptr_offset(): ...@@ -31,6 +31,15 @@ def test_buffer_access_ptr_offset():
offset = tvm.ir_pass.Simplify(aptr.args[2]) offset = tvm.ir_pass.Simplify(aptr.args[2])
assert tvm.ir_pass.Equal(offset, 100) assert tvm.ir_pass.Equal(offset, 100)
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
v = tvm.var('int32')
aptr = Ab.access_ptr("rw", offset=100 + 100 + v)
offset = tvm.ir_pass.Simplify(aptr.args[2])
assert tvm.ir_pass.Equal(offset, 200 + v)
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
aptr = Ab.access_ptr("rw", offset=tvm.call_extern('int32', "test_call", 100 + 100 + v))
offset = tvm.ir_pass.Simplify(aptr.args[2])
assert tvm.ir_pass.Equal(offset, tvm.call_extern('int32', "test_call", 200 + v))
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
def test_buffer_index_merge_mult_mod(): def test_buffer_index_merge_mult_mod():
m = tvm.var('m') m = tvm.var('m')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment