Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
bc48811f
Commit
bc48811f
authored
Oct 31, 2018
by
kun-zh
Committed by
Tianqi Chen
Oct 30, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix a bug in inject-virtual-thread (#2039)
parent
b840e960
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
1 deletions
+20
-1
src/pass/inject_virtual_thread.cc
+1
-1
tests/python/unittest/test_pass_inject_vthread.py
+19
-0
No files found.
src/pass/inject_virtual_thread.cc
View file @
bc48811f
...
...
@@ -321,7 +321,7 @@ class VTInjector : public IRMutator {
CHECK_EQ
(
max_loop_depth_
,
0
);
Stmt
then_case
=
this
->
Mutate
(
op
->
then_case
);
Stmt
else_case
;
if
(
else_case
.
defined
())
{
if
(
op
->
else_case
.
defined
())
{
int
temp
=
max_loop_depth_
;
max_loop_depth_
=
0
;
else_case
=
this
->
Mutate
(
op
->
else_case
);
...
...
tests/python/unittest/test_pass_inject_vthread.py
View file @
bc48811f
...
...
@@ -60,7 +60,26 @@ def test_vthread_extern():
assert
stmt
.
body
.
body
.
body
.
body
.
body
.
body
.
extents
[
0
]
.
value
==
2
assert
len
(
stmt
.
body
.
body
.
body
.
body
.
body
.
body
.
extents
)
==
3
def
test_vthread_if_then_else
():
nthread
=
2
tx
=
tvm
.
thread_axis
(
"vthread"
)
ib
=
tvm
.
ir_builder
.
create
()
A
=
ib
.
pointer
(
"float32"
,
name
=
"A"
)
with
ib
.
for_range
(
0
,
100
)
as
i
:
ib
.
scope_attr
(
tx
,
"virtual_thread"
,
nthread
)
B
=
ib
.
allocate
(
"float32"
,
128
,
name
=
"B"
,
scope
=
"shared"
)
with
ib
.
if_scope
(
i
==
0
):
B
[
i
]
=
A
[
i
*
nthread
+
tx
]
with
ib
.
else_scope
():
B
[
i
]
=
A
[
i
*
nthread
+
tx
]
+
1
with
ib
.
if_scope
(
i
==
0
):
B
[
i
]
=
A
[
i
*
nthread
+
tx
]
+
2
stmt
=
ib
.
get
()
stmt
=
tvm
.
ir_pass
.
InjectVirtualThread
(
stmt
)
assert
stmt
.
body
.
body
.
body
.
first
.
else_case
!=
None
assert
stmt
.
body
.
body
.
body
.
rest
.
else_case
==
None
if
__name__
==
"__main__"
:
test_vthread_extern
()
test_vthread
()
test_vthread_if_then_else
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment