Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
0c523787
Commit
0c523787
authored
Aug 29, 2018
by
Lianmin Zheng
Committed by
Tianqi Chen
Aug 29, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] Enhance gpu verify pass (#1660)
parent
9f99a4fa
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
39 additions
and
1 deletions
+39
-1
src/pass/verify_gpu_code.cc
+15
-1
tests/python/unittest/test_pass_verify_gpu_code.py
+24
-0
No files found.
src/pass/verify_gpu_code.cc
View file @
0c523787
...
...
@@ -86,17 +86,29 @@ class GPUCodeVerifier : public IRVisitor {
// record the number of threads in a block
std
::
string
name
=
var
.
get
()
->
name_hint
;
if
(
name
==
"threadIdx.x"
||
name
==
"threadIdx.y"
||
name
==
"threadIdx.z"
)
{
size_t
length
=
static_cast
<
size_t
>
(
extent
->
value
);
if
(
!
visited_threads_
.
count
(
name
))
{
visited_threads_
.
insert
(
name
);
size_t
length
=
static_cast
<
size_t
>
(
extent
->
value
);
thread_per_block_
*=
length
;
if
(
name
==
"threadIdx.x"
)
{
valid_
&=
length
<=
max_thread_x_
;
thread_x_extent_
=
length
;
}
else
if
(
name
==
"threadIdx.y"
)
{
valid_
&=
length
<=
max_thread_y_
;
thread_y_extent_
=
length
;
}
else
if
(
name
==
"threadIdx.z"
)
{
valid_
&=
length
<=
max_thread_z_
;
thread_z_extent_
=
length
;
}
}
else
{
// the thread should be bound to axes with the same length
if
(
name
==
"threadIdx.x"
)
{
valid_
&=
length
==
thread_x_extent_
;
}
else
if
(
name
==
"threadIdx.y"
)
{
valid_
&=
length
==
thread_y_extent_
;
}
else
if
(
name
==
"threadIdx.z"
)
{
valid_
&=
length
==
thread_z_extent_
;
}
}
}
...
...
@@ -111,6 +123,8 @@ class GPUCodeVerifier : public IRVisitor {
std
::
unordered_set
<
const
tvm
::
Variable
*>
visited_shared_buffers_
;
std
::
unordered_set
<
std
::
string
>
visited_threads_
;
size_t
thread_x_extent_
,
thread_y_extent_
,
thread_z_extent_
;
size_t
local_memory_per_block_
;
size_t
shared_memory_per_block_
;
size_t
thread_per_block_
;
...
...
tests/python/unittest/test_pass_verify_gpu_code.py
View file @
0c523787
...
...
@@ -162,8 +162,32 @@ def test_multiple_kernels():
tvm
.
build
(
s
,
[
A
,
C
],
target
)
assert
valid
[
0
]
def
test_wrong_bind
():
N
=
1024
A
=
tvm
.
placeholder
((
N
,
N
-
1
),
name
=
'A'
)
B
=
tvm
.
compute
((
N
,
N
-
1
),
lambda
i
,
j
:
A
[
i
,
j
])
s
=
tvm
.
create_schedule
([
B
.
op
])
# bind a thread axis to two loop axes with different lengths
s
[
B
]
.
bind
(
s
[
B
]
.
op
.
axis
[
0
],
tvm
.
thread_axis
(
"threadIdx.x"
))
s
[
B
]
.
bind
(
s
[
B
]
.
op
.
axis
[
1
],
tvm
.
thread_axis
(
"threadIdx.x"
))
for
target
in
[
'opencl'
,
'cuda'
]:
if
not
tvm
.
context
(
target
)
.
exist
:
continue
valid
=
[
None
]
with
tvm
.
build_config
(
**
{
"add_lower_pass"
:
[
(
2
,
get_verify_pass
(
valid
,
max_threads_per_block
=
N
*
N
))]}):
tvm
.
build
(
s
,
[
A
,
B
],
target
)
assert
not
valid
[
0
]
if
__name__
==
"__main__"
:
test_local_memory
()
test_shared_memory
()
test_num_thread
()
test_multiple_kernels
()
test_wrong_bind
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment