Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
38274115
Commit
38274115
authored
Mar 08, 2018
by
libing4752
Committed by
Tianqi Chen
Mar 07, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
enhance access_ptr that args can support Expr (#970)
parent
078c767c
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
38 additions
and
4 deletions
+38
-4
include/tvm/buffer.h
+1
-1
python/tvm/schedule.py
+27
-2
src/lang/buffer.cc
+1
-1
tests/python/unittest/test_lang_buffer.py
+9
-0
No files found.
include/tvm/buffer.h
View file @
38274115
...
@@ -55,7 +55,7 @@ class Buffer : public NodeRef {
...
@@ -55,7 +55,7 @@ class Buffer : public NodeRef {
* \param offset The offset of ptr.
* \param offset The offset of ptr.
*/
*/
TVM_DLL
Expr
access_ptr
(
int
access_mask
,
Type
ptr_type
=
Handle
(),
TVM_DLL
Expr
access_ptr
(
int
access_mask
,
Type
ptr_type
=
Handle
(),
int
content_lanes
=
1
,
int
offset
=
0
)
const
;
int
content_lanes
=
1
,
Expr
offset
=
make_const
(
Int
(
32
),
0
)
)
const
;
/*!
/*!
* \brief Create an Expr that does a vector load at begin index.
* \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index
* \param begin The beginning index
...
...
python/tvm/schedule.py
View file @
38274115
...
@@ -2,12 +2,34 @@
...
@@ -2,12 +2,34 @@
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
from
._ffi.base
import
string_types
from
._ffi.base
import
string_types
from
._ffi.node
import
NodeBase
,
register_node
from
._ffi.node
import
NodeBase
,
register_node
from
._ffi.function
import
_init_api
from
._ffi.node
import
convert_to_node
as
_convert_to_node
from
._ffi.function
import
_init_api
,
Function
from
._ffi.function
import
convert_to_tvm_func
as
_convert_tvm_func
from
.
import
_api_internal
from
.
import
_api_internal
from
.
import
tensor
as
_tensor
from
.
import
tensor
as
_tensor
from
.
import
expr
as
_expr
from
.
import
expr
as
_expr
from
.
import
container
as
_container
from
.
import
container
as
_container
def
convert
(
value
):
"""Convert value to TVM node or function.
Parameters
----------
value : python value
Returns
-------
tvm_val : Node or Function
Converted value in TVM
"""
if
isinstance
(
value
,
(
Function
,
NodeBase
)):
return
value
if
callable
(
value
):
return
_convert_tvm_func
(
value
)
return
_convert_to_node
(
value
)
@register_node
@register_node
class
Buffer
(
NodeBase
):
class
Buffer
(
NodeBase
):
"""Symbolic data buffer in TVM.
"""Symbolic data buffer in TVM.
...
@@ -45,7 +67,7 @@ class Buffer(NodeBase):
...
@@ -45,7 +67,7 @@ class Buffer(NodeBase):
The number of lanes for the data type. This value
The number of lanes for the data type. This value
is greater than one for vector types.
is greater than one for vector types.
offset:
int
, optional
offset:
Expr
, optional
The offset of pointer. We can use it to offset by
The offset of pointer. We can use it to offset by
the number of elements from the address of ptr.
the number of elements from the address of ptr.
...
@@ -60,6 +82,8 @@ class Buffer(NodeBase):
...
@@ -60,6 +82,8 @@ class Buffer(NodeBase):
buffer.access_ptr(Buffer.READ | Buffer.WRITE)
buffer.access_ptr(Buffer.READ | Buffer.WRITE)
# Get access ptr for read/write with str flag
# Get access ptr for read/write with str flag
buffer.access_ptr("rw")
buffer.access_ptr("rw")
# Get access ptr for read with offset
buffer.access_ptr("r", offset = 100)
"""
"""
if
isinstance
(
access_mask
,
string_types
):
if
isinstance
(
access_mask
,
string_types
):
mask
=
0
mask
=
0
...
@@ -71,6 +95,7 @@ class Buffer(NodeBase):
...
@@ -71,6 +95,7 @@ class Buffer(NodeBase):
else
:
else
:
raise
ValueError
(
"Unknown access_mask
%
s"
%
access_mask
)
raise
ValueError
(
"Unknown access_mask
%
s"
%
access_mask
)
access_mask
=
mask
access_mask
=
mask
offset
=
convert
(
offset
)
return
_api_internal
.
_BufferAccessPtr
(
self
,
access_mask
,
ptr_type
,
return
_api_internal
.
_BufferAccessPtr
(
self
,
access_mask
,
ptr_type
,
content_lanes
,
offset
)
content_lanes
,
offset
)
...
...
src/lang/buffer.cc
View file @
38274115
...
@@ -335,7 +335,7 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
...
@@ -335,7 +335,7 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
0
);
0
);
}
}
Expr
Buffer
::
access_ptr
(
int
access_mask
,
Type
ptr_type
,
int
content_lanes
,
int
offset
)
const
{
Expr
Buffer
::
access_ptr
(
int
access_mask
,
Type
ptr_type
,
int
content_lanes
,
Expr
offset
)
const
{
const
BufferNode
*
self
=
operator
->
();
const
BufferNode
*
self
=
operator
->
();
Expr
e_dtype
;
Expr
e_dtype
;
Expr
extent
;
Expr
extent
;
...
...
tests/python/unittest/test_lang_buffer.py
View file @
38274115
...
@@ -31,6 +31,15 @@ def test_buffer_access_ptr_offset():
...
@@ -31,6 +31,15 @@ def test_buffer_access_ptr_offset():
offset
=
tvm
.
ir_pass
.
Simplify
(
aptr
.
args
[
2
])
offset
=
tvm
.
ir_pass
.
Simplify
(
aptr
.
args
[
2
])
assert
tvm
.
ir_pass
.
Equal
(
offset
,
100
)
assert
tvm
.
ir_pass
.
Equal
(
offset
,
100
)
assert
aptr
.
args
[
4
]
.
value
==
Buffer
.
READ
|
Buffer
.
WRITE
assert
aptr
.
args
[
4
]
.
value
==
Buffer
.
READ
|
Buffer
.
WRITE
v
=
tvm
.
var
(
'int32'
)
aptr
=
Ab
.
access_ptr
(
"rw"
,
offset
=
100
+
100
+
v
)
offset
=
tvm
.
ir_pass
.
Simplify
(
aptr
.
args
[
2
])
assert
tvm
.
ir_pass
.
Equal
(
offset
,
200
+
v
)
assert
aptr
.
args
[
4
]
.
value
==
Buffer
.
READ
|
Buffer
.
WRITE
aptr
=
Ab
.
access_ptr
(
"rw"
,
offset
=
tvm
.
call_extern
(
'int32'
,
"test_call"
,
100
+
100
+
v
))
offset
=
tvm
.
ir_pass
.
Simplify
(
aptr
.
args
[
2
])
assert
tvm
.
ir_pass
.
Equal
(
offset
,
tvm
.
call_extern
(
'int32'
,
"test_call"
,
200
+
v
))
assert
aptr
.
args
[
4
]
.
value
==
Buffer
.
READ
|
Buffer
.
WRITE
def
test_buffer_index_merge_mult_mod
():
def
test_buffer_index_merge_mult_mod
():
m
=
tvm
.
var
(
'm'
)
m
=
tvm
.
var
(
'm'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment