Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
d56c777a
Commit
d56c777a
authored
Feb 07, 2018
by
Lianmin Zheng
Committed by
Tianqi Chen
Feb 06, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
support to keep trivial loops with extent of 1 (#877)
parent
b21aee7d
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
70 additions
and
38 deletions
+70
-38
include/tvm/operation.h
+11
-5
include/tvm/schedule_pass.h
+2
-1
src/api/api_schedule.cc
+8
-1
src/codegen/build_module.cc
+1
-1
src/op/compute_op.cc
+12
-9
src/op/compute_op.h
+9
-3
src/op/cross_thread_reduction.cc
+3
-2
src/op/extern_op.cc
+2
-1
src/op/op_util.cc
+4
-3
src/op/op_util.h
+3
-1
src/op/placeholder_op.cc
+2
-1
src/op/scan_op.cc
+3
-2
src/op/tensorize.cc
+3
-2
src/schedule/schedule_ops.cc
+7
-6
No files found.
include/tvm/operation.h
View file @
d56c777a
...
...
@@ -117,11 +117,13 @@ class OperationNode : public FunctionBaseNode {
* \brief Build the statement that provide the output tensors.
* \param stage The schedule stage of the op.
* \param dom_map The domain map of all iteration domains.
* \param del_trivial_loop Whether eliminate trivial loop with extent of 1
* \return A statement that add production and wraps consumer.
*/
virtual
Stmt
BuildProvide
(
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
)
const
=
0
;
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
del_trivial_loop
)
const
=
0
;
static
constexpr
const
char
*
_type_key
=
"Operation"
;
...
...
@@ -160,7 +162,8 @@ class PlaceholderOpNode : public OperationNode {
const
Stmt
&
body
)
const
final
;
Stmt
BuildProvide
(
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
)
const
final
;
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
del_trivial_loop
)
const
final
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name"
,
&
name
);
...
...
@@ -211,7 +214,8 @@ class ComputeOpNode : public OperationNode {
const
Stmt
&
body
)
const
final
;
Stmt
BuildProvide
(
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
)
const
final
;
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
del_trivial_loop
)
const
final
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name"
,
&
name
);
...
...
@@ -282,7 +286,8 @@ class ScanOpNode : public OperationNode {
const
Stmt
&
body
)
const
final
;
Stmt
BuildProvide
(
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
)
const
final
;
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
del_trivial_loop
)
const
final
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name"
,
&
name
);
...
...
@@ -345,7 +350,8 @@ class ExternOpNode : public OperationNode {
const
Stmt
&
body
)
const
final
;
Stmt
BuildProvide
(
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
)
const
final
;
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
del_trivial_loop
)
const
final
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name"
,
&
name
);
...
...
include/tvm/schedule_pass.h
View file @
d56c777a
...
...
@@ -29,9 +29,10 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
*
* \param s The schedule to be realized
* \param dom_map The domain of each iter vars.
* \param del_trivial_loop Whether delete trivial loops with extent of 1
* \return the result Stmt
*/
Stmt
ScheduleOps
(
Schedule
s
,
Map
<
IterVar
,
Range
>
dom_map
);
Stmt
ScheduleOps
(
Schedule
s
,
Map
<
IterVar
,
Range
>
dom_map
,
bool
del_trivial_loop
);
/*!
* \brief To automatically inline the element-wise operations.
...
...
src/api/api_schedule.cc
View file @
d56c777a
...
...
@@ -24,6 +24,14 @@ TVM_REGISTER_API("schedule.AutoInlineInjective")
AutoInlineInjective
(
args
[
0
]);
});
TVM_REGISTER_API
(
"schedule.ScheduleOps"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
if
(
args
.
size
()
==
2
)
*
ret
=
ScheduleOps
(
args
[
0
],
args
[
1
],
true
);
else
*
ret
=
ScheduleOps
(
args
[
0
],
args
[
1
],
args
[
2
]);
});
#define REGISTER_SCHEDULE_PASS1(PassName) \
TVM_REGISTER_API("schedule."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
...
...
@@ -43,7 +51,6 @@ REGISTER_SCHEDULE_PASS2(PostDFSOrder);
REGISTER_SCHEDULE_PASS1
(
CreateAttachPath
);
REGISTER_SCHEDULE_PASS1
(
ScanGetBody
);
REGISTER_SCHEDULE_PASS1
(
ScanFixPointAnalysis
);
REGISTER_SCHEDULE_PASS2
(
ScheduleOps
);
}
// namespace schedule
}
// namespace tvm
src/codegen/build_module.cc
View file @
d56c777a
...
...
@@ -211,7 +211,7 @@ Stmt BuildStmt(Schedule sch,
// Phase 0
auto
bounds
=
schedule
::
InferBound
(
sch
);
auto
stmt
=
schedule
::
ScheduleOps
(
sch
,
bounds
);
auto
stmt
=
schedule
::
ScheduleOps
(
sch
,
bounds
,
true
);
stmt
=
ir
::
InjectPrefetch
(
stmt
);
// Phase 1
...
...
src/op/compute_op.cc
View file @
d56c777a
...
...
@@ -305,9 +305,10 @@ Stmt MakeProvide(const ComputeOpNode* op,
Stmt
MakeComputeStmt
(
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
)
{
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
del_trivial_loop
)
{
// grab the nest structure
ComputeLoopNest
n
=
ComputeLoopNest
::
make
(
self
,
stage
,
dom_map
);
ComputeLoopNest
n
=
ComputeLoopNest
::
make
(
self
,
stage
,
dom_map
,
del_trivial_loop
);
// Normal loop structure
n
.
init_nest
.
emplace_back
(
op
::
MakeIfNest
(
n
.
init_predicates
));
n
.
main_nest
.
emplace_back
(
op
::
MakeIfNest
(
n
.
main_predicates
));
...
...
@@ -387,28 +388,30 @@ ComputeType DetectComputeType(const ComputeOpNode* self,
// implement the provide utility.
Stmt
ComputeOpNode
::
BuildProvide
(
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
)
const
{
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
del_trivial_loop
)
const
{
CHECK_EQ
(
stage
->
op
.
operator
->
(),
this
);
ComputeType
ctype
=
DetectComputeType
(
this
,
stage
);
if
(
ctype
==
ComputeType
::
kCrossThreadReduction
)
{
// specially handle cross thread reduction.
return
MakeCrossThreadReduction
(
this
,
stage
,
dom_map
);
return
MakeCrossThreadReduction
(
this
,
stage
,
dom_map
,
del_trivial_loop
);
}
else
if
(
ctype
==
ComputeType
::
kTensorize
)
{
return
MakeTensorize
(
this
,
stage
,
dom_map
);
return
MakeTensorize
(
this
,
stage
,
dom_map
,
del_trivial_loop
);
}
else
{
return
MakeComputeStmt
(
this
,
stage
,
dom_map
);
return
MakeComputeStmt
(
this
,
stage
,
dom_map
,
del_trivial_loop
);
}
}
ComputeLoopNest
ComputeLoopNest
::
make
(
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
)
{
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
del_trivial_loop
)
{
CHECK_EQ
(
stage
->
op
.
operator
->
(),
self
);
ComputeLoopNest
ret
;
// make main loop nest
ret
.
main_nest
=
op
::
MakeLoopNest
(
stage
,
dom_map
,
0
,
false
,
std
::
unordered_set
<
IterVar
>
(),
&
ret
.
main_vmap
);
stage
,
dom_map
,
0
,
false
,
std
::
unordered_set
<
IterVar
>
(),
&
ret
.
main_vmap
,
del_trivial_loop
);
ret
.
main_predicates
=
schedule
::
MakeBoundCheck
(
stage
,
dom_map
,
ret
.
main_vmap
,
false
,
std
::
unordered_set
<
IterVar
>
());
...
...
@@ -450,7 +453,7 @@ ComputeLoopNest ComputeLoopNest::make(
}
ret
.
init_nest
=
op
::
MakeLoopNest
(
stage
,
dom_map
,
begin_loop
,
true
,
skip_iter
,
&
(
ret
.
init_vmap
));
skip_iter
,
&
(
ret
.
init_vmap
)
,
del_trivial_loop
);
ret
.
init_predicates
=
schedule
::
MakeBoundCheck
(
stage
,
dom_map
,
ret
.
init_vmap
,
true
,
skip_iter
);
for
(
auto
&
e
:
ret
.
init_predicates
)
{
...
...
src/op/compute_op.h
View file @
d56c777a
...
...
@@ -37,12 +37,14 @@ struct ComputeLoopNest {
* \param self The pointer to compute op.
* \param stage The scxhedule stage.
* \param dom_map The domain map.
* \param del_trivial_loop Whether eliminate trivial loops with extent of 1
* \return The constructed loop nest
*/
static
ComputeLoopNest
make
(
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
);
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
del_trivial_loop
);
};
/*!
...
...
@@ -50,23 +52,27 @@ struct ComputeLoopNest {
* \param self The pointer to ComputeOpNode
* \param stage The schedule stage.
* \param dom_map The domain map.
* \param del_trivial_loop Wheter eliminate trivial loops with extent of 1
* \return The created statement.
*/
Stmt
MakeCrossThreadReduction
(
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
);
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
del_trivial_loop
);
/*!
* \brief Build body of compute for tensorization.
* \param self The pointer to ComputeOpNode
* \param stage The schedule stage.
* \param dom_map The domain map.
* \param del_trivial_loop Wheter eliminate trivial loops with extent of 1
* \return The created statement.
*/
Stmt
MakeTensorize
(
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
);
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
del_trivial_loop
);
}
// namespace tvm
#endif // TVM_OP_COMPUTE_OP_H_
src/op/cross_thread_reduction.cc
View file @
d56c777a
...
...
@@ -13,14 +13,15 @@ using namespace ir;
Stmt
MakeCrossThreadReduction
(
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
)
{
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
del_trivial_loop
)
{
Array
<
Expr
>
args
;
for
(
IterVar
iv
:
self
->
axis
)
{
args
.
push_back
(
iv
->
var
);
}
std
::
unordered_map
<
IterVar
,
Expr
>
value_map
;
auto
nest
=
op
::
MakeLoopNest
(
stage
,
dom_map
,
0
,
false
,
std
::
unordered_set
<
IterVar
>
(),
&
value_map
);
stage
,
dom_map
,
0
,
false
,
std
::
unordered_set
<
IterVar
>
(),
&
value_map
,
del_trivial_loop
);
auto
conds
=
schedule
::
MakeBoundCheck
(
stage
,
dom_map
,
value_map
,
false
,
std
::
unordered_set
<
IterVar
>
());
...
...
src/op/extern_op.cc
View file @
d56c777a
...
...
@@ -128,7 +128,8 @@ Stmt ExternOpNode::BuildRealize(
Stmt
ExternOpNode
::
BuildProvide
(
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
)
const
{
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
del_trivial_loop
)
const
{
CHECK_EQ
(
stage
->
op
.
operator
->
(),
this
);
Stmt
ret
=
AttrStmt
::
make
(
make_zero
(
Int
(
32
)),
attr
::
extern_scope
,
0
,
this
->
body
);
auto
f_push_bind
=
[
&
ret
](
Buffer
buffer
,
Tensor
tensor
)
{
...
...
src/op/op_util.cc
View file @
d56c777a
...
...
@@ -23,7 +23,8 @@ MakeLoopNest(const Stage& stage,
size_t
begin_iter_pos
,
bool
new_loop_var
,
const
std
::
unordered_set
<
IterVar
>&
skip_iter
,
std
::
unordered_map
<
IterVar
,
Expr
>*
p_value_map
)
{
std
::
unordered_map
<
IterVar
,
Expr
>*
p_value_map
,
bool
del_trivial_loop
)
{
auto
leaf_iter_vars
=
stage
->
leaf_iter_vars
;
Stmt
no_op
=
Evaluate
::
make
(
0
);
// create the loop nest
...
...
@@ -75,7 +76,7 @@ MakeLoopNest(const Stage& stage,
AttrStmt
::
make
(
iv
,
ir
::
attr
::
pragma_scope
,
p
,
no_op
));
}
}
if
(
is_one
(
dom
->
extent
))
{
if
(
del_trivial_loop
&&
is_one
(
dom
->
extent
))
{
nest
[
i
+
1
].
emplace_back
(
LetStmt
::
make
(
var
,
dom
->
min
,
no_op
));
value_map
[
iv
]
=
dom
->
min
;
...
...
@@ -130,7 +131,7 @@ MakeLoopNest(const Stage& stage,
// annotate the extent of the IterVar
nest
[
i
+
1
].
emplace_back
(
AttrStmt
::
make
(
bind_iv
,
ir
::
attr
::
thread_extent
,
dom
->
extent
,
no_op
));
if
(
is_one
(
dom
->
extent
))
{
if
(
del_trivial_loop
&&
is_one
(
dom
->
extent
))
{
value_map
[
iv
]
=
dom
->
min
;
}
else
{
value_map
[
iv
]
=
var
;
...
...
src/op/op_util.h
View file @
d56c777a
...
...
@@ -29,6 +29,7 @@ using ir::MergeNest;
* \param new_loop_var Whether create new loop variable.
* \param skip_iter Whether skip certain iteration.
* \param p_value_map The result value of each IterVar.
* \param del_trivial_loop Whether eliminate trivial loops with extent of 1
*/
std
::
vector
<
std
::
vector
<
Stmt
>
>
MakeLoopNest
(
const
Stage
&
stage
,
...
...
@@ -36,7 +37,8 @@ MakeLoopNest(const Stage& stage,
size_t
begin_iter_pos
,
bool
new_loop_var
,
const
std
::
unordered_set
<
IterVar
>&
skip_iter
,
std
::
unordered_map
<
IterVar
,
Expr
>*
p_value_map
);
std
::
unordered_map
<
IterVar
,
Expr
>*
p_value_map
,
bool
del_trivial_loop
);
/*!
* \brief Create a nest of if checking the predicates.
...
...
src/op/placeholder_op.cc
View file @
d56c777a
...
...
@@ -78,7 +78,8 @@ Stmt PlaceholderOpNode::BuildRealize(
Stmt
PlaceholderOpNode
::
BuildProvide
(
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
)
const
{
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
del_trivial_loop
)
const
{
return
Stmt
();
}
}
// namespace tvm
src/op/scan_op.cc
View file @
d56c777a
...
...
@@ -252,7 +252,8 @@ Stmt ScanOpNode::BuildRealize(
Stmt
ScanOpNode
::
BuildProvide
(
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
)
const
{
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
del_trivial_loop
)
const
{
CHECK_EQ
(
stage
->
op
.
operator
->
(),
this
);
Stmt
provide
=
AttrStmt
::
make
(
stage
->
op
,
attr
::
scan_update_scope
,
this
->
scan_axis
->
var
,
...
...
@@ -270,7 +271,7 @@ Stmt ScanOpNode::BuildProvide(
std
::
unordered_map
<
IterVar
,
Expr
>
vmap
;
std
::
unordered_set
<
IterVar
>
empty
;
auto
nest
=
op
::
MakeLoopNest
(
stage
,
dom_map
,
0
,
false
,
empty
,
&
vmap
);
stage
,
dom_map
,
0
,
false
,
empty
,
&
vmap
,
del_trivial_loop
);
nest
[
begin_scan
].
push_back
(
init
);
nest
.
push_back
(
op
::
MakeIfNest
(
...
...
src/op/tensorize.cc
View file @
d56c777a
...
...
@@ -369,14 +369,15 @@ Stmt TransformUpdate(const Stage& stage,
Stmt
MakeTensorize
(
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
)
{
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
del_trivial_loop
)
{
std
::
unordered_map
<
IterVar
,
Range
>
out_dom
;
std
::
unordered_map
<
Tensor
,
Array
<
Range
>
>
in_region
;
size_t
tloc
=
InferTensorizeRegion
(
self
,
stage
,
dom_map
,
&
out_dom
,
&
in_region
);
TensorIntrin
intrin
=
stage
->
iter_var_attrs
.
at
(
stage
->
leaf_iter_vars
[
tloc
])
->
tensor_intrin
;
CHECK
(
intrin
.
defined
());
ComputeLoopNest
n
=
ComputeLoopNest
::
make
(
self
,
stage
,
dom_map
);
ComputeLoopNest
n
=
ComputeLoopNest
::
make
(
self
,
stage
,
dom_map
,
del_trivial_loop
);
VerifyTensorizeLoopNest
(
self
,
stage
,
n
,
tloc
);
VerifyTensorizeBody
(
self
,
stage
,
out_dom
,
in_region
,
intrin
);
// Start bind data.
...
...
src/schedule/schedule_ops.cc
View file @
d56c777a
...
...
@@ -22,8 +22,9 @@ using namespace ir;
Stmt
MakePipeline
(
const
Stage
&
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
Stmt
consumer
)
{
Stmt
producer
=
s
->
op
->
BuildProvide
(
s
,
dom_map
);
Stmt
consumer
,
bool
del_trivial_loop
)
{
Stmt
producer
=
s
->
op
->
BuildProvide
(
s
,
dom_map
,
del_trivial_loop
);
if
(
producer
.
defined
())
{
producer
=
ProducerConsumer
::
make
(
s
->
op
,
true
,
producer
);
}
...
...
@@ -68,7 +69,7 @@ class InjectAttach : public IRMutator {
found_attach
=
true
;
stmt
=
AttrStmt
::
make
(
op
->
node
,
op
->
attr_key
,
op
->
value
,
MakePipeline
(
stage_
,
dom_map_
,
op
->
body
));
MakePipeline
(
stage_
,
dom_map_
,
op
->
body
,
true
));
}
}
return
stmt
;
...
...
@@ -107,7 +108,7 @@ class InjectScanStep : public IRMutator {
found_attach
=
true
;
stmt
=
AttrStmt
::
make
(
op
->
node
,
op
->
attr_key
,
op
->
value
,
MakePipeline
(
stage_
,
dom_map_
,
op
->
body
));
MakePipeline
(
stage_
,
dom_map_
,
op
->
body
,
true
));
}
}
return
stmt
;
...
...
@@ -324,7 +325,7 @@ class SchedulePostProc : public IRMutator {
};
Stmt
ScheduleOps
(
Schedule
sch
,
Map
<
IterVar
,
Range
>
dom_map_
)
{
Schedule
sch
,
Map
<
IterVar
,
Range
>
dom_map_
,
bool
del_trivial_loop
)
{
Stmt
body
=
Stmt
();
std
::
unordered_map
<
IterVar
,
Range
>
dom_map
=
as_unordered_map
(
dom_map_
);
// scan init and scan updates
...
...
@@ -374,7 +375,7 @@ Stmt ScheduleOps(
// do nothing
}
else
if
(
attach_spec
->
attach_type
==
kGroupRoot
)
{
CHECK
(
!
s
->
group
.
defined
());
body
=
MakePipeline
(
s
,
dom_map
,
body
);
body
=
MakePipeline
(
s
,
dom_map
,
body
,
del_trivial_loop
);
}
else
{
CHECK_EQ
(
attach_spec
->
attach_type
,
kScope
);
CHECK
(
body
.
defined
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment