Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
be07fac5
Unverified
Commit
be07fac5
authored
Jan 08, 2019
by
Tianqi Chen
Committed by
GitHub
Jan 08, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] not vectorize if_then_else (#2389)
parent
a12c556a
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
73 additions
and
0 deletions
+73
-0
src/codegen/llvm/codegen_llvm.cc
+2
-0
src/pass/vectorize_loop.cc
+42
-0
tests/python/unittest/test_pass_vectorize.py
+29
-0
No files found.
src/codegen/llvm/codegen_llvm.cc
View file @
be07fac5
...
...
@@ -654,6 +654,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_handle_is_null
))
{
return
builder_
->
CreateIsNull
(
MakeValue
(
op
->
args
[
0
]));
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_if_then_else
))
{
CHECK_EQ
(
op
->
args
[
0
].
type
().
lanes
(),
1
)
<<
"if_then_else can only take scalar condition"
;
using
llvm
::
BasicBlock
;
BasicBlock
*
then_block
=
BasicBlock
::
Create
(
*
ctx_
,
"if_then"
,
function_
);
...
...
src/pass/vectorize_loop.cc
View file @
be07fac5
...
...
@@ -83,6 +83,19 @@ class Vectorizer : public IRMutator {
// user mutate from parent.
using
IRMutator
::
Mutate
;
Stmt
Mutate
(
Stmt
stmt
)
final
{
CHECK
(
!
need_scalarize_
);
Stmt
ret
=
IRMutator
::
Mutate
(
stmt
);
if
(
need_scalarize_
)
{
need_scalarize_
=
false
;
return
Scalarize
(
stmt
);
}
else
{
return
ret
;
}
}
Expr
Mutate_
(
const
Add
*
op
,
const
Expr
&
e
)
final
{
return
AddSubVec
(
op
,
e
);
}
...
...
@@ -200,10 +213,37 @@ class Vectorizer : public IRMutator {
return
e
;
}
}
// IfThenElse expr
Expr
MutateIfThenElseExpr_
(
const
Call
*
op
,
const
Expr
&
e
)
{
Expr
cond
=
this
->
Mutate
(
op
->
args
[
0
]);
if
(
cond
.
type
().
is_vector
())
{
need_scalarize_
=
true
;
return
e
;
}
Expr
t
=
this
->
Mutate
(
op
->
args
[
1
]);
Expr
f
=
this
->
Mutate
(
op
->
args
[
2
]);
if
(
cond
.
same_as
(
op
->
args
[
0
])
&&
t
.
same_as
(
op
->
args
[
1
])
&&
f
.
same_as
(
op
->
args
[
2
]))
{
return
e
;
}
else
{
int
lanes
=
std
::
max
(
t
.
type
().
lanes
(),
f
.
type
().
lanes
());
t
=
BroadcastTo
(
t
,
lanes
);
f
=
BroadcastTo
(
f
,
lanes
);
return
Call
::
make
(
op
->
type
.
with_lanes
(
lanes
),
op
->
name
,
{
cond
,
t
,
f
},
op
->
call_type
,
op
->
func
,
op
->
value_index
);
}
}
// Call
Expr
Mutate_
(
const
Call
*
op
,
const
Expr
&
e
)
final
{
if
(
op
->
name
==
intrinsic
::
tvm_if_then_else
)
{
return
MutateIfThenElseExpr_
(
op
,
e
);
}
int
lane
=
0
;
Array
<
Expr
>
new_args
=
MutateArray
(
op
->
args
,
&
lane
);
// normal code path.
if
(
op
->
args
.
same_as
(
new_args
))
{
return
e
;
}
else
{
...
...
@@ -367,6 +407,8 @@ class Vectorizer : public IRMutator {
int
var_lanes_
;
// ramp representing the var.
Expr
ramp_
;
// flag to mark requirment of scalarization.
bool
need_scalarize_
{
false
};
// The lets
std
::
unordered_map
<
const
Variable
*
,
Expr
>
lets_
;
// mutate array, with given lane requirement
...
...
tests/python/unittest/test_pass_vectorize.py
View file @
be07fac5
...
...
@@ -53,7 +53,36 @@ def test_vectorize_with_if():
assert
stmt
.
then_case
.
value
.
dtype
==
"float32x4"
assert
isinstance
(
stmt
.
else_case
,
tvm
.
stmt
.
For
)
def
test_vectorize_if_then_else
():
n
=
tvm
.
var
(
'n'
)
x
=
tvm
.
var
(
'x'
)
ib
=
tvm
.
ir_builder
.
create
()
A
=
ib
.
pointer
(
"float32"
,
name
=
"A"
)
with
ib
.
for_range
(
0
,
4
,
for_type
=
"vectorize"
)
as
i
:
A
[
i
]
=
tvm
.
call_intrin
(
"float32"
,
"tvm_if_then_else"
,
i
>
0
,
A
[
i
]
+
1
,
A
[
i
])
stmt
=
ib
.
get
()
stmt
=
tvm
.
ir_pass
.
VectorizeLoop
(
stmt
)
assert
isinstance
(
stmt
,
tvm
.
stmt
.
For
)
ib
=
tvm
.
ir_builder
.
create
()
A
=
ib
.
pointer
(
"float32"
,
name
=
"A"
)
with
ib
.
for_range
(
0
,
n
)
as
k
:
with
ib
.
for_range
(
0
,
4
,
for_type
=
"vectorize"
)
as
i
:
A
[
k
*
4
+
i
]
=
tvm
.
call_intrin
(
"float32"
,
"tvm_if_then_else"
,
k
>
0
,
A
[
k
*
4
+
i
],
0
)
stmt
=
ib
.
get
()
assert
isinstance
(
stmt
.
body
,
tvm
.
stmt
.
For
)
stmt
=
tvm
.
ir_pass
.
VectorizeLoop
(
stmt
)
assert
not
isinstance
(
stmt
.
body
,
tvm
.
stmt
.
For
)
assert
isinstance
(
stmt
.
body
.
value
.
args
[
2
],
tvm
.
expr
.
Broadcast
)
if
__name__
==
"__main__"
:
test_vectorize_vector
()
test_vectorize_with_if
()
test_vectorize_loop
()
test_vectorize_if_then_else
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment