Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
090468aa
Commit
090468aa
authored
Aug 15, 2017
by
Tianqi Chen
Committed by
GitHub
Aug 15, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] RewriteUnsafeSelect lowers unsafe select to condition expr (#335)
parent
25ded693
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
208 additions
and
0 deletions
+208
-0
include/tvm/ir.h
+8
-0
include/tvm/ir_pass.h
+7
-0
python/tvm/build_module.py
+1
-0
src/api/api_pass.cc
+1
-0
src/codegen/codegen_c.cc
+8
-0
src/codegen/llvm/codegen_llvm.cc
+25
-0
src/pass/rewrite_unsafe_select.cc
+115
-0
tests/python/unittest/test_codegen_llvm.py
+21
-0
tests/python/unittest/test_pass_rewrite_unsafe_select.py
+22
-0
No files found.
include/tvm/ir.h
View file @
090468aa
...
@@ -218,6 +218,14 @@ namespace intrinsic {
...
@@ -218,6 +218,14 @@ namespace intrinsic {
*/
*/
constexpr
const
char
*
tvm_address_of
=
"tvm_address_of"
;
constexpr
const
char
*
tvm_address_of
=
"tvm_address_of"
;
/*!
/*!
* \brief Same as select, used for unsafe memory access.
*
* Type tvm_if_then_else(cond, a, b) {
* return cond ? a : b;
* }
*/
constexpr
const
char
*
tvm_if_then_else
=
"tvm_if_then_else"
;
/*!
* \brief Get head access address with memory access pattern info.
* \brief Get head access address with memory access pattern info.
*
*
* This operator also marks range of the memory access
* This operator also marks range of the memory access
...
...
include/tvm/ir_pass.h
View file @
090468aa
...
@@ -267,6 +267,13 @@ Stmt CoProcSync(Stmt stmt);
...
@@ -267,6 +267,13 @@ Stmt CoProcSync(Stmt stmt);
Stmt
LiftAttrScope
(
Stmt
stmt
,
std
::
string
attr_key
);
Stmt
LiftAttrScope
(
Stmt
stmt
,
std
::
string
attr_key
);
/*!
/*!
* \brief Detect and rewrite unsafe select that contains memory access.
* \param stmt The statment to be rewritten.
* \return Transformed stmt.
*/
Stmt
RewriteUnsafeSelect
(
Stmt
stmt
);
/*!
* \brief Lower attached storage access information.
* \brief Lower attached storage access information.
* Do this pass after all storage access analysis finish.
* Do this pass after all storage access analysis finish.
*
*
...
...
python/tvm/build_module.py
View file @
090468aa
...
@@ -211,6 +211,7 @@ def lower(sch,
...
@@ -211,6 +211,7 @@ def lower(sch,
stmt
=
ir_pass
.
Simplify
(
stmt
)
stmt
=
ir_pass
.
Simplify
(
stmt
)
stmt
=
ir_pass
.
LowerStorageAccessInfo
(
stmt
)
stmt
=
ir_pass
.
LowerStorageAccessInfo
(
stmt
)
stmt
=
ir_pass
.
RemoveNoOp
(
stmt
)
stmt
=
ir_pass
.
RemoveNoOp
(
stmt
)
stmt
=
ir_pass
.
RewriteUnsafeSelect
(
stmt
)
if
simple_mode
:
if
simple_mode
:
return
stmt
return
stmt
return
ir_pass
.
MakeAPI
(
stmt
,
name
,
arg_list
,
0
,
cfg
.
restricted_func
)
return
ir_pass
.
MakeAPI
(
stmt
,
name
,
arg_list
,
0
,
cfg
.
restricted_func
)
...
...
src/api/api_pass.cc
View file @
090468aa
...
@@ -85,6 +85,7 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
...
@@ -85,6 +85,7 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
REGISTER_PASS1
(
ConvertSSA
);
REGISTER_PASS1
(
ConvertSSA
);
REGISTER_PASS1
(
VerifySSA
);
REGISTER_PASS1
(
VerifySSA
);
REGISTER_PASS1
(
RewriteUnsafeSelect
);
REGISTER_PASS4
(
Inline
);
REGISTER_PASS4
(
Inline
);
REGISTER_PASS3
(
StorageFlatten
);
REGISTER_PASS3
(
StorageFlatten
);
REGISTER_PASS1
(
VectorizeLoop
);
REGISTER_PASS1
(
VectorizeLoop
);
...
...
src/codegen/codegen_c.cc
View file @
090468aa
...
@@ -482,6 +482,14 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
...
@@ -482,6 +482,14 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
PrintBinaryIntrinsitc
(
op
,
" << "
,
os
,
this
);
PrintBinaryIntrinsitc
(
op
,
" << "
,
os
,
this
);
}
else
if
(
op
->
is_intrinsic
(
Call
::
shift_right
))
{
}
else
if
(
op
->
is_intrinsic
(
Call
::
shift_right
))
{
PrintBinaryIntrinsitc
(
op
,
" >> "
,
os
,
this
);
PrintBinaryIntrinsitc
(
op
,
" >> "
,
os
,
this
);
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_if_then_else
))
{
os
<<
"("
;
PrintExpr
(
op
->
args
[
0
],
os
);
os
<<
" ? "
;
PrintExpr
(
op
->
args
[
1
],
os
);
os
<<
" : "
;
PrintExpr
(
op
->
args
[
2
],
os
);
os
<<
")"
;
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_address_of
))
{
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_address_of
))
{
const
Load
*
l
=
op
->
args
[
0
].
as
<
Load
>
();
const
Load
*
l
=
op
->
args
[
0
].
as
<
Load
>
();
CHECK
(
op
->
args
.
size
()
==
1
&&
l
);
CHECK
(
op
->
args
.
size
()
==
1
&&
l
);
...
...
src/codegen/llvm/codegen_llvm.cc
View file @
090468aa
...
@@ -1028,6 +1028,31 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
...
@@ -1028,6 +1028,31 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
llvm
::
Value
*
ptr
=
MakeValue
(
op
->
args
[
0
]);
llvm
::
Value
*
ptr
=
MakeValue
(
op
->
args
[
0
]);
return
builder_
->
CreateICmpEQ
(
return
builder_
->
CreateICmpEQ
(
ptr
,
llvm
::
Constant
::
getNullValue
(
ptr
->
getType
()));
ptr
,
llvm
::
Constant
::
getNullValue
(
ptr
->
getType
()));
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_if_then_else
))
{
using
llvm
::
BasicBlock
;
CHECK_EQ
(
op
->
args
.
size
(),
3U
);
llvm
::
Value
*
cond
=
MakeValue
(
op
->
args
[
0
]);
BasicBlock
*
then_block
=
BasicBlock
::
Create
(
*
ctx_
,
"if_then"
,
function_
);
BasicBlock
*
else_block
=
BasicBlock
::
Create
(
*
ctx_
,
"if_else"
,
function_
);
BasicBlock
*
end_block
=
BasicBlock
::
Create
(
*
ctx_
,
"if_end"
,
function_
);
builder_
->
CreateCondBr
(
cond
,
then_block
,
else_block
);
// Then
builder_
->
SetInsertPoint
(
then_block
);
llvm
::
Value
*
then_value
=
MakeValue
(
op
->
args
[
1
]);
builder_
->
CreateBr
(
end_block
);
builder_
->
SetInsertPoint
(
else_block
);
// else
llvm
::
Value
*
else_value
=
MakeValue
(
op
->
args
[
2
]);
builder_
->
CreateBr
(
end_block
);
builder_
->
SetInsertPoint
(
end_block
);
// phi
llvm
::
PHINode
*
phi
=
builder_
->
CreatePHI
(
then_value
->
getType
(),
2
);
phi
->
addIncoming
(
then_value
,
then_block
);
phi
->
addIncoming
(
else_value
,
else_block
);
return
phi
;
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_struct_get
))
{
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_struct_get
))
{
CHECK_EQ
(
op
->
args
.
size
(),
3U
);
CHECK_EQ
(
op
->
args
.
size
(),
3U
);
int
kind
=
op
->
args
[
2
].
as
<
IntImm
>
()
->
value
;
int
kind
=
op
->
args
[
2
].
as
<
IntImm
>
()
->
value
;
...
...
src/pass/rewrite_unsafe_select.cc
0 → 100644
View file @
090468aa
/*!
* Copyright (c) 2017 by Contributors
* \file unsafe_select_rewrite.cc
* \brief Rewrite uinsafe select expression.
*/
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
namespace
tvm
{
namespace
ir
{
// For now, rewrite unsafe select expression to if_then_else
// TODO(tqchen) pattern matching to support masked load
class
UnsafeExprDetector
:
public
ExprFunctor
<
bool
(
const
Expr
&
n
)
>
{
public
:
// select itself is always considered safe if condition is safe
// Because we will issue guard to make sure it is.
bool
VisitExpr_
(
const
Select
*
op
)
{
return
VisitExpr
(
op
->
condition
);
}
bool
VisitExpr_
(
const
Call
*
op
)
{
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_if_then_else
))
{
return
VisitExpr
(
op
->
args
[
0
]);
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_address_of
))
{
const
Load
*
l
=
op
->
args
[
0
].
as
<
Load
>
();
return
this
->
VisitExpr
(
l
->
index
);
}
else
if
(
op
->
is_pure
())
{
for
(
Expr
e
:
op
->
args
)
{
if
(
VisitExpr
(
e
))
return
true
;
}
return
false
;
}
else
{
return
true
;
}
}
bool
VisitExpr_
(
const
Load
*
op
)
{
// Load is considered unsafe.
return
true
;
}
bool
VisitExpr_
(
const
Add
*
op
)
final
{
return
BinaryOp
(
op
);
}
bool
VisitExpr_
(
const
Sub
*
op
)
final
{
return
BinaryOp
(
op
);
}
bool
VisitExpr_
(
const
Mul
*
op
)
final
{
return
BinaryOp
(
op
);
}
bool
VisitExpr_
(
const
Div
*
op
)
final
{
return
BinaryOp
(
op
);
}
bool
VisitExpr_
(
const
Mod
*
op
)
final
{
return
BinaryOp
(
op
);
}
bool
VisitExpr_
(
const
Min
*
op
)
final
{
return
BinaryOp
(
op
);
}
bool
VisitExpr_
(
const
Max
*
op
)
final
{
return
BinaryOp
(
op
);
}
bool
VisitExpr_
(
const
EQ
*
op
)
final
{
return
BinaryOp
(
op
);
}
bool
VisitExpr_
(
const
NE
*
op
)
final
{
return
BinaryOp
(
op
);
}
bool
VisitExpr_
(
const
LT
*
op
)
final
{
return
BinaryOp
(
op
);
}
bool
VisitExpr_
(
const
LE
*
op
)
final
{
return
BinaryOp
(
op
);
}
bool
VisitExpr_
(
const
GT
*
op
)
final
{
return
BinaryOp
(
op
);
}
bool
VisitExpr_
(
const
GE
*
op
)
final
{
return
BinaryOp
(
op
);
}
bool
VisitExpr_
(
const
And
*
op
)
final
{
return
BinaryOp
(
op
);
}
bool
VisitExpr_
(
const
Or
*
op
)
final
{
return
BinaryOp
(
op
);
}
bool
VisitExpr_
(
const
Not
*
op
)
final
{
return
VisitExpr
(
op
->
a
);
}
bool
VisitExpr_
(
const
Let
*
op
)
final
{
return
VisitExpr
(
op
->
body
)
&&
VisitExpr
(
op
->
value
);
}
bool
VisitExpr_
(
const
Cast
*
op
)
final
{
return
VisitExpr
(
op
->
value
);
}
bool
VisitExpr_
(
const
Broadcast
*
op
)
final
{
return
VisitExpr
(
op
->
value
);
}
bool
VisitExpr_
(
const
Ramp
*
op
)
final
{
return
VisitExpr
(
op
->
base
)
&&
VisitExpr
(
op
->
stride
);
}
bool
VisitExpr_
(
const
Shuffle
*
op
)
final
{
for
(
Expr
e
:
op
->
vectors
)
{
if
(
VisitExpr
(
e
))
return
true
;
}
return
false
;
}
bool
VisitExpr_
(
const
Variable
*
op
)
final
{
return
false
;
}
bool
VisitExpr_
(
const
IntImm
*
op
)
final
{
return
false
;
}
bool
VisitExpr_
(
const
FloatImm
*
op
)
final
{
return
false
;
}
bool
VisitExpr_
(
const
StringImm
*
op
)
final
{
return
false
;
}
private
:
template
<
typename
T
>
bool
BinaryOp
(
const
T
*
op
)
{
return
VisitExpr
(
op
->
a
)
&&
VisitExpr
(
op
->
b
);
}
};
class
UnsafeSelectRewriter
:
public
IRMutator
{
public
:
Expr
Mutate_
(
const
Select
*
op
,
const
Expr
&
e
)
{
Expr
expr
=
IRMutator
::
Mutate_
(
op
,
e
);
op
=
expr
.
as
<
Select
>
();
UnsafeExprDetector
unsafe
;
if
(
unsafe
.
VisitExpr
(
op
->
true_value
)
||
unsafe
.
VisitExpr
(
op
->
false_value
))
{
return
Call
::
make
(
op
->
type
,
intrinsic
::
tvm_if_then_else
,
{
op
->
condition
,
op
->
true_value
,
op
->
false_value
},
Call
::
Intrinsic
);
}
else
{
return
expr
;
}
}
};
Stmt
RewriteUnsafeSelect
(
Stmt
stmt
)
{
return
UnsafeSelectRewriter
().
Mutate
(
stmt
);
}
}
// namespace ir
}
// namespace tvm
tests/python/unittest/test_codegen_llvm.py
View file @
090468aa
...
@@ -201,7 +201,28 @@ def test_multiple_func():
...
@@ -201,7 +201,28 @@ def test_multiple_func():
check_llvm
()
check_llvm
()
def
test_llvm_select
():
def
check_llvm
(
n
,
offset
):
if
not
tvm
.
module
.
enabled
(
"llvm"
):
return
A
=
tvm
.
placeholder
((
n
,
),
name
=
'A'
)
C
=
tvm
.
compute
((
n
,),
lambda
i
:
tvm
.
select
(
i
>=
offset
,
A
[
i
],
0.0
),
name
=
'C'
)
s
=
tvm
.
create_schedule
(
C
.
op
)
# build and invoke the kernel.
f
=
tvm
.
build
(
s
,
[
A
,
C
],
"llvm"
)
ctx
=
tvm
.
cpu
(
0
)
# launch the kernel.
a
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
(
n
,))
.
astype
(
A
.
dtype
),
ctx
)
c
=
tvm
.
nd
.
empty
((
n
,),
A
.
dtype
,
ctx
)
f
(
a
,
c
)
c_np
=
a
.
asnumpy
()
c_np
[:
offset
]
=
0
np
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
c_np
)
check_llvm
(
64
,
8
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_llvm_select
()
test_llvm_vadd_pipeline
()
test_llvm_vadd_pipeline
()
test_llvm_add_pipeline
()
test_llvm_add_pipeline
()
test_llvm_intrin
()
test_llvm_intrin
()
...
...
tests/python/unittest/test_pass_rewrite_unsafe_select.py
0 → 100644
View file @
090468aa
import
tvm
def
test_rewrite_select
():
ib
=
tvm
.
ir_builder
.
create
()
A
=
ib
.
allocate
(
"float32"
,
100
,
name
=
"A"
,
scope
=
"global"
)
i
=
tvm
.
var
(
"i"
)
y
=
tvm
.
select
(
i
>
1
,
A
[
i
-
1
],
1.0
)
yy
=
tvm
.
ir_pass
.
RewriteUnsafeSelect
(
tvm
.
make
.
Evaluate
(
y
))
.
value
z
=
tvm
.
select
(
tvm
.
select
(
i
>
1
,
A
[
i
-
1
],
1.0
)
>
0.0
,
A
[
i
],
0.1
)
zz
=
tvm
.
ir_pass
.
RewriteUnsafeSelect
(
tvm
.
make
.
Evaluate
(
z
))
.
value
a
=
tvm
.
select
(
i
>
10
,
y
,
z
)
aa
=
tvm
.
ir_pass
.
RewriteUnsafeSelect
(
tvm
.
make
.
Evaluate
(
a
))
.
value
assert
yy
.
name
==
"tvm_if_then_else"
assert
zz
.
name
==
"tvm_if_then_else"
assert
isinstance
(
aa
,
tvm
.
expr
.
Select
)
if
__name__
==
"__main__"
:
test_rewrite_select
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment