Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
e3695cad
Commit
e3695cad
authored
May 09, 2017
by
Tianqi Chen
Committed by
GitHub
May 09, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[BUGFIX/PASS] Fix Vectorize with If condition (#135)
parent
e9debc9b
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
44 additions
and
15 deletions
+44
-15
python/tvm/ir_builder.py
+15
-2
src/pass/vectorize_loop.cc
+1
-1
tests/python/unittest/test_pass_vectorize.py
+28
-12
No files found.
python/tvm/ir_builder.py
View file @
e3695cad
...
@@ -144,7 +144,7 @@ class IRBuilder(object):
...
@@ -144,7 +144,7 @@ class IRBuilder(object):
value
=
_make
.
StringImm
(
value
)
value
=
_make
.
StringImm
(
value
)
self
.
emit
(
lambda
x
:
_make
.
AttrStmt
(
node
,
attr_key
,
value
,
x
))
self
.
emit
(
lambda
x
:
_make
.
AttrStmt
(
node
,
attr_key
,
value
,
x
))
def
for_range
(
self
,
begin
,
end
,
name
=
"i"
,
dtype
=
"int32"
):
def
for_range
(
self
,
begin
,
end
,
name
=
"i"
,
dtype
=
"int32"
,
for_type
=
"serial"
):
"""Create a for iteration scope.
"""Create a for iteration scope.
Parameters
Parameters
...
@@ -161,6 +161,9 @@ class IRBuilder(object):
...
@@ -161,6 +161,9 @@ class IRBuilder(object):
dtype : str, optional
dtype : str, optional
The data type of iteration variable.
The data type of iteration variable.
for_type : str, optional
The special tag on the for loop.
Returns
Returns
-------
-------
loop_scope : With.Scope of Var
loop_scope : With.Scope of Var
...
@@ -179,8 +182,18 @@ class IRBuilder(object):
...
@@ -179,8 +182,18 @@ class IRBuilder(object):
loop_var
=
_api
.
var
(
name
,
dtype
=
dtype
)
loop_var
=
_api
.
var
(
name
,
dtype
=
dtype
)
extent
=
end
if
begin
==
0
else
_pass
.
Simplify
(
end
-
begin
)
extent
=
end
if
begin
==
0
else
_pass
.
Simplify
(
end
-
begin
)
def
_exit_cb
():
def
_exit_cb
():
if
for_type
==
"serial"
:
for_type_id
=
0
elif
for_type
==
"parallel"
:
for_type_id
=
1
elif
for_type
==
"vectorize"
:
for_type_id
=
2
elif
for_type
==
"unroll"
:
for_type_id
=
3
else
:
raise
ValueError
(
"Unknown for_type"
)
self
.
emit
(
_make
.
For
(
self
.
emit
(
_make
.
For
(
loop_var
,
begin
,
extent
,
0
,
0
,
self
.
_pop_seq
()))
loop_var
,
begin
,
extent
,
for_type_id
,
0
,
self
.
_pop_seq
()))
return
WithScope
(
loop_var
,
_exit_cb
)
return
WithScope
(
loop_var
,
_exit_cb
)
def
if_scope
(
self
,
cond
):
def
if_scope
(
self
,
cond
):
...
...
src/pass/vectorize_loop.cc
View file @
e3695cad
...
@@ -252,7 +252,7 @@ class Vectorizer : public IRMutator {
...
@@ -252,7 +252,7 @@ class Vectorizer : public IRMutator {
}
}
Stmt
then_case
=
this
->
Mutate
(
op
->
then_case
);
Stmt
then_case
=
this
->
Mutate
(
op
->
then_case
);
Stmt
else_case
;
Stmt
else_case
;
if
(
else_case
.
defined
())
{
if
(
op
->
else_case
.
defined
())
{
else_case
=
this
->
Mutate
(
op
->
else_case
);
else_case
=
this
->
Mutate
(
op
->
else_case
);
}
}
if
(
condition
.
same_as
(
op
->
condition
)
&&
if
(
condition
.
same_as
(
op
->
condition
)
&&
...
...
tests/python/unittest/test_pass_vectorize.py
View file @
e3695cad
...
@@ -3,22 +3,38 @@ import tvm
...
@@ -3,22 +3,38 @@ import tvm
def
test_vectorize_loop
():
def
test_vectorize_loop
():
dtype
=
'int64'
dtype
=
'int64'
n
=
tvm
.
var
(
'n'
)
n
=
tvm
.
var
(
'n'
)
Ab
=
tvm
.
decl_buffer
((
n
,
),
dtype
)
ib
=
tvm
.
ir_builder
.
create
()
i
=
tvm
.
var
(
'i'
)
A
=
ib
.
pointer
(
"float32"
,
name
=
"A"
)
j
=
tvm
.
var
(
'j'
)
with
ib
.
for_range
(
0
,
n
)
as
i
:
VECTORIZE
=
2
with
ib
.
for_range
(
0
,
4
,
for_type
=
"vectorize"
)
as
j
:
# for i in 0 to n-1:
A
[
j
+
1
]
=
A
[
i
]
+
1
stmt
=
tvm
.
make
.
For
(
stmt
=
ib
.
get
()
i
,
n
,
2
,
0
,
0
,
tvm
.
make
.
For
(
j
,
0
,
4
,
VECTORIZE
,
0
,
tvm
.
make
.
Store
(
Ab
.
data
,
tvm
.
make
.
Load
(
dtype
,
Ab
.
data
,
i
)
+
1
,
j
+
1
)))
assert
isinstance
(
stmt
.
body
,
tvm
.
stmt
.
For
)
assert
isinstance
(
stmt
.
body
,
tvm
.
stmt
.
For
)
stmt
=
tvm
.
ir_pass
.
VectorizeLoop
(
stmt
)
stmt
=
tvm
.
ir_pass
.
VectorizeLoop
(
stmt
)
assert
isinstance
(
stmt
,
tvm
.
stmt
.
For
)
assert
isinstance
(
stmt
,
tvm
.
stmt
.
For
)
assert
not
isinstance
(
stmt
.
body
,
tvm
.
stmt
.
For
)
assert
not
isinstance
(
stmt
.
body
,
tvm
.
stmt
.
For
)
print
(
stmt
)
assert
isinstance
(
stmt
.
body
.
index
,
tvm
.
expr
.
Ramp
)
assert
isinstance
(
stmt
.
body
.
value
,
tvm
.
expr
.
Broadcast
)
def
test_vectorize_with_if
():
n
=
tvm
.
var
(
'n'
)
x
=
tvm
.
var
(
'x'
)
ib
=
tvm
.
ir_builder
.
create
()
A
=
ib
.
pointer
(
"float32"
,
name
=
"A"
)
with
ib
.
for_range
(
0
,
4
,
for_type
=
"vectorize"
)
as
i
:
with
ib
.
if_scope
(
x
<
n
):
A
[
i
]
=
A
[
i
]
+
1
with
ib
.
else_scope
():
with
ib
.
if_scope
(
i
<
n
):
A
[
i
]
=
2.0
stmt
=
ib
.
get
()
stmt
=
tvm
.
ir_pass
.
VectorizeLoop
(
stmt
)
assert
isinstance
(
stmt
,
tvm
.
stmt
.
IfThenElse
)
assert
isinstance
(
stmt
.
then_case
.
index
,
tvm
.
expr
.
Ramp
)
assert
isinstance
(
stmt
.
then_case
.
value
,
tvm
.
expr
.
Add
)
assert
stmt
.
then_case
.
value
.
dtype
==
"float32x4"
assert
isinstance
(
stmt
.
else_case
,
tvm
.
stmt
.
For
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_vectorize_with_if
()
test_vectorize_loop
()
test_vectorize_loop
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment