Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
cfdc5119
Commit
cfdc5119
authored
Mar 28, 2018
by
Lianmin Zheng
Committed by
Tianqi Chen
Mar 28, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
delete init part when keeping trivial loop (#1031)
parent
ca9ec009
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
66 additions
and
56 deletions
+66
-56
include/tvm/operation.h
+6
-6
include/tvm/schedule_pass.h
+5
-2
src/api/api_schedule.cc
+1
-1
src/codegen/build_module.cc
+1
-1
src/op/compute_op.cc
+15
-10
src/op/compute_op.h
+6
-6
src/op/cross_thread_reduction.cc
+2
-2
src/op/extern_op.cc
+1
-1
src/op/op_util.cc
+3
-3
src/op/op_util.h
+2
-2
src/op/placeholder_op.cc
+1
-1
src/op/scan_op.cc
+2
-2
src/op/tensorize.cc
+2
-2
src/schedule/schedule_ops.cc
+19
-17
No files found.
include/tvm/operation.h
View file @
cfdc5119
...
...
@@ -117,13 +117,13 @@ class OperationNode : public FunctionBaseNode {
* \brief Build the statement that provide the output tensors.
* \param stage The schedule stage of the op.
* \param dom_map The domain map of all iteration domains.
* \param de
l_trivial_loop Whether eliminate trivial loop
with extent of 1
* \param de
bug_keep_trivial_loop Whether keep trivial loops
with extent of 1
* \return A statement that add production and wraps consumer.
*/
virtual
Stmt
BuildProvide
(
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
const
=
0
;
bool
de
bug_keep
_trivial_loop
)
const
=
0
;
static
constexpr
const
char
*
_type_key
=
"Operation"
;
...
...
@@ -163,7 +163,7 @@ class PlaceholderOpNode : public OperationNode {
Stmt
BuildProvide
(
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
const
final
;
bool
de
bug_keep
_trivial_loop
)
const
final
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name"
,
&
name
);
...
...
@@ -215,7 +215,7 @@ class ComputeOpNode : public OperationNode {
Stmt
BuildProvide
(
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
const
final
;
bool
de
bug_keep
_trivial_loop
)
const
final
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name"
,
&
name
);
...
...
@@ -287,7 +287,7 @@ class ScanOpNode : public OperationNode {
Stmt
BuildProvide
(
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
const
final
;
bool
de
bug_keep
_trivial_loop
)
const
final
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name"
,
&
name
);
...
...
@@ -351,7 +351,7 @@ class ExternOpNode : public OperationNode {
Stmt
BuildProvide
(
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
const
final
;
bool
de
bug_keep
_trivial_loop
)
const
final
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name"
,
&
name
);
...
...
include/tvm/schedule_pass.h
View file @
cfdc5119
...
...
@@ -29,10 +29,13 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
*
* \param s The schedule to be realized
* \param dom_map The domain of each iter vars.
* \param del_trivial_loop Whether delete trivial loops with extent of 1
* \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 during lowering.
* This is a debug feature for dataflow/axis analysis.
* Note: If this is true, The lowered IR may be incorrect,
* because we will also delete the init part of reduction
* \return the result Stmt
*/
Stmt
ScheduleOps
(
Schedule
s
,
Map
<
IterVar
,
Range
>
dom_map
,
bool
de
l
_trivial_loop
);
Stmt
ScheduleOps
(
Schedule
s
,
Map
<
IterVar
,
Range
>
dom_map
,
bool
de
bug_keep
_trivial_loop
);
/*!
* \brief To automatically inline the element-wise operations.
...
...
src/api/api_schedule.cc
View file @
cfdc5119
...
...
@@ -27,7 +27,7 @@ TVM_REGISTER_API("schedule.AutoInlineInjective")
TVM_REGISTER_API
(
"schedule.ScheduleOps"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
if
(
args
.
size
()
==
2
)
*
ret
=
ScheduleOps
(
args
[
0
],
args
[
1
],
tru
e
);
*
ret
=
ScheduleOps
(
args
[
0
],
args
[
1
],
fals
e
);
else
*
ret
=
ScheduleOps
(
args
[
0
],
args
[
1
],
args
[
2
]);
});
...
...
src/codegen/build_module.cc
View file @
cfdc5119
...
...
@@ -349,7 +349,7 @@ Stmt BuildStmt(Schedule sch,
// Phase 0
auto
bounds
=
schedule
::
InferBound
(
sch
);
auto
stmt
=
schedule
::
ScheduleOps
(
sch
,
bounds
,
tru
e
);
auto
stmt
=
schedule
::
ScheduleOps
(
sch
,
bounds
,
fals
e
);
stmt
=
ir
::
InjectPrefetch
(
stmt
);
// Phase 1
...
...
src/op/compute_op.cc
View file @
cfdc5119
...
...
@@ -296,9 +296,9 @@ Stmt MakeProvide(const ComputeOpNode* op,
Stmt
MakeComputeStmt
(
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
{
bool
de
bug_keep
_trivial_loop
)
{
// grab the nest structure
ComputeLoopNest
n
=
ComputeLoopNest
::
make
(
self
,
stage
,
dom_map
,
de
l
_trivial_loop
);
ComputeLoopNest
n
=
ComputeLoopNest
::
make
(
self
,
stage
,
dom_map
,
de
bug_keep
_trivial_loop
);
// Normal loop structure
n
.
init_nest
.
emplace_back
(
op
::
MakeIfNest
(
n
.
init_predicates
));
n
.
main_nest
.
emplace_back
(
op
::
MakeIfNest
(
n
.
main_predicates
));
...
...
@@ -319,7 +319,11 @@ Stmt MakeComputeStmt(const ComputeOpNode* self,
n
.
main_nest
.
begin
()
+
n
.
num_common_loop
+
1
,
n
.
main_nest
.
end
());
provide
=
op
::
Substitute
(
provide
,
n
.
main_vmap
);
provide
=
MergeNest
(
reduce
,
provide
);
return
MergeNest
(
common
,
Block
::
make
(
init
,
provide
));
if
(
debug_keep_trivial_loop
)
{
return
MergeNest
(
common
,
provide
);
}
else
{
return
MergeNest
(
common
,
Block
::
make
(
init
,
provide
));
}
}
else
{
std
::
vector
<
Stmt
>
provides
;
for
(
size_t
i
=
0
;
i
<
self
->
body
.
size
();
++
i
)
{
...
...
@@ -379,16 +383,16 @@ ComputeType DetectComputeType(const ComputeOpNode* self,
Stmt
ComputeOpNode
::
BuildProvide
(
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
const
{
bool
de
bug_keep
_trivial_loop
)
const
{
CHECK_EQ
(
stage
->
op
.
operator
->
(),
this
);
ComputeType
ctype
=
DetectComputeType
(
this
,
stage
);
if
(
ctype
==
ComputeType
::
kCrossThreadReduction
)
{
// specially handle cross thread reduction.
return
MakeCrossThreadReduction
(
this
,
stage
,
dom_map
,
de
l
_trivial_loop
);
return
MakeCrossThreadReduction
(
this
,
stage
,
dom_map
,
de
bug_keep
_trivial_loop
);
}
else
if
(
ctype
==
ComputeType
::
kTensorize
)
{
return
MakeTensorize
(
this
,
stage
,
dom_map
,
de
l
_trivial_loop
);
return
MakeTensorize
(
this
,
stage
,
dom_map
,
de
bug_keep
_trivial_loop
);
}
else
{
return
MakeComputeStmt
(
this
,
stage
,
dom_map
,
de
l
_trivial_loop
);
return
MakeComputeStmt
(
this
,
stage
,
dom_map
,
de
bug_keep
_trivial_loop
);
}
}
...
...
@@ -396,12 +400,13 @@ ComputeLoopNest ComputeLoopNest::make(
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
{
bool
de
bug_keep
_trivial_loop
)
{
CHECK_EQ
(
stage
->
op
.
operator
->
(),
self
);
ComputeLoopNest
ret
;
// make main loop nest
ret
.
main_nest
=
op
::
MakeLoopNest
(
stage
,
dom_map
,
0
,
false
,
std
::
unordered_set
<
IterVar
>
(),
&
ret
.
main_vmap
,
del_trivial_loop
);
stage
,
dom_map
,
0
,
false
,
std
::
unordered_set
<
IterVar
>
(),
&
ret
.
main_vmap
,
debug_keep_trivial_loop
);
ret
.
main_predicates
=
schedule
::
MakeBoundCheck
(
stage
,
dom_map
,
ret
.
main_vmap
,
false
,
std
::
unordered_set
<
IterVar
>
());
...
...
@@ -443,7 +448,7 @@ ComputeLoopNest ComputeLoopNest::make(
}
ret
.
init_nest
=
op
::
MakeLoopNest
(
stage
,
dom_map
,
begin_loop
,
true
,
skip_iter
,
&
(
ret
.
init_vmap
),
de
l
_trivial_loop
);
skip_iter
,
&
(
ret
.
init_vmap
),
de
bug_keep
_trivial_loop
);
ret
.
init_predicates
=
schedule
::
MakeBoundCheck
(
stage
,
dom_map
,
ret
.
init_vmap
,
true
,
skip_iter
);
for
(
auto
&
e
:
ret
.
init_predicates
)
{
...
...
src/op/compute_op.h
View file @
cfdc5119
...
...
@@ -37,14 +37,14 @@ struct ComputeLoopNest {
* \param self The pointer to compute op.
* \param stage The scxhedule stage.
* \param dom_map The domain map.
* \param de
l_trivial_loop Whether eliminate
trivial loops with extent of 1
* \param de
bug_keep_trivial_loop Whether keep
trivial loops with extent of 1
* \return The constructed loop nest
*/
static
ComputeLoopNest
make
(
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
);
bool
de
bug_keep
_trivial_loop
);
};
/*!
...
...
@@ -52,27 +52,27 @@ struct ComputeLoopNest {
* \param self The pointer to ComputeOpNode
* \param stage The schedule stage.
* \param dom_map The domain map.
* \param de
l_trivial_loop Wheter eliminate
trivial loops with extent of 1
* \param de
bug_keep_trivial_loop Whether keep
trivial loops with extent of 1
* \return The created statement.
*/
Stmt
MakeCrossThreadReduction
(
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
);
bool
de
bug_keep
_trivial_loop
);
/*!
* \brief Build body of compute for tensorization.
* \param self The pointer to ComputeOpNode
* \param stage The schedule stage.
* \param dom_map The domain map.
* \param de
l_trivial_loop Wheter eliminate
trivial loops with extent of 1
* \param de
bug_keep_trivial_loop Whether keep
trivial loops with extent of 1
* \return The created statement.
*/
Stmt
MakeTensorize
(
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
);
bool
de
bug_keep
_trivial_loop
);
}
// namespace tvm
#endif // TVM_OP_COMPUTE_OP_H_
src/op/cross_thread_reduction.cc
View file @
cfdc5119
...
...
@@ -14,14 +14,14 @@ Stmt MakeCrossThreadReduction(
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
{
bool
de
bug_keep
_trivial_loop
)
{
Array
<
Expr
>
args
;
for
(
IterVar
iv
:
self
->
axis
)
{
args
.
push_back
(
iv
->
var
);
}
std
::
unordered_map
<
IterVar
,
Expr
>
value_map
;
auto
nest
=
op
::
MakeLoopNest
(
stage
,
dom_map
,
0
,
false
,
std
::
unordered_set
<
IterVar
>
(),
&
value_map
,
de
l
_trivial_loop
);
stage
,
dom_map
,
0
,
false
,
std
::
unordered_set
<
IterVar
>
(),
&
value_map
,
de
bug_keep
_trivial_loop
);
auto
conds
=
schedule
::
MakeBoundCheck
(
stage
,
dom_map
,
value_map
,
false
,
std
::
unordered_set
<
IterVar
>
());
...
...
src/op/extern_op.cc
View file @
cfdc5119
...
...
@@ -129,7 +129,7 @@ Stmt ExternOpNode::BuildRealize(
Stmt
ExternOpNode
::
BuildProvide
(
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
const
{
bool
de
bug_keep
_trivial_loop
)
const
{
CHECK_EQ
(
stage
->
op
.
operator
->
(),
this
);
Stmt
ret
=
AttrStmt
::
make
(
make_zero
(
Int
(
32
)),
attr
::
extern_scope
,
0
,
this
->
body
);
auto
f_push_bind
=
[
&
ret
](
Buffer
buffer
,
Tensor
tensor
)
{
...
...
src/op/op_util.cc
View file @
cfdc5119
...
...
@@ -24,7 +24,7 @@ MakeLoopNest(const Stage& stage,
bool
new_loop_var
,
const
std
::
unordered_set
<
IterVar
>&
skip_iter
,
std
::
unordered_map
<
IterVar
,
Expr
>*
p_value_map
,
bool
de
l
_trivial_loop
)
{
bool
de
bug_keep
_trivial_loop
)
{
auto
leaf_iter_vars
=
stage
->
leaf_iter_vars
;
Stmt
no_op
=
Evaluate
::
make
(
0
);
// create the loop nest
...
...
@@ -76,7 +76,7 @@ MakeLoopNest(const Stage& stage,
AttrStmt
::
make
(
iv
,
ir
::
attr
::
pragma_scope
,
p
,
no_op
));
}
}
if
(
del
_trivial_loop
&&
is_one
(
dom
->
extent
))
{
if
(
!
debug_keep
_trivial_loop
&&
is_one
(
dom
->
extent
))
{
nest
[
i
+
1
].
emplace_back
(
LetStmt
::
make
(
var
,
dom
->
min
,
no_op
));
value_map
[
iv
]
=
dom
->
min
;
...
...
@@ -131,7 +131,7 @@ MakeLoopNest(const Stage& stage,
// annotate the extent of the IterVar
nest
[
i
+
1
].
emplace_back
(
AttrStmt
::
make
(
bind_iv
,
ir
::
attr
::
thread_extent
,
dom
->
extent
,
no_op
));
if
(
del
_trivial_loop
&&
is_one
(
dom
->
extent
))
{
if
(
!
debug_keep
_trivial_loop
&&
is_one
(
dom
->
extent
))
{
value_map
[
iv
]
=
dom
->
min
;
}
else
{
value_map
[
iv
]
=
var
;
...
...
src/op/op_util.h
View file @
cfdc5119
...
...
@@ -29,7 +29,7 @@ using ir::MergeNest;
* \param new_loop_var Whether create new loop variable.
* \param skip_iter Whether skip certain iteration.
* \param p_value_map The result value of each IterVar.
* \param de
l_trivial_loop Whether eliminate
trivial loops with extent of 1
* \param de
bug_keep_trivial_loop Whether keep
trivial loops with extent of 1
*/
std
::
vector
<
std
::
vector
<
Stmt
>
>
MakeLoopNest
(
const
Stage
&
stage
,
...
...
@@ -38,7 +38,7 @@ MakeLoopNest(const Stage& stage,
bool
new_loop_var
,
const
std
::
unordered_set
<
IterVar
>&
skip_iter
,
std
::
unordered_map
<
IterVar
,
Expr
>*
p_value_map
,
bool
de
l
_trivial_loop
);
bool
de
bug_keep
_trivial_loop
);
/*!
* \brief Create a nest of if checking the predicates.
...
...
src/op/placeholder_op.cc
View file @
cfdc5119
...
...
@@ -79,7 +79,7 @@ Stmt PlaceholderOpNode::BuildRealize(
Stmt
PlaceholderOpNode
::
BuildProvide
(
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
const
{
bool
de
bug_keep
_trivial_loop
)
const
{
return
Stmt
();
}
}
// namespace tvm
src/op/scan_op.cc
View file @
cfdc5119
...
...
@@ -253,7 +253,7 @@ Stmt ScanOpNode::BuildRealize(
Stmt
ScanOpNode
::
BuildProvide
(
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
const
{
bool
de
bug_keep
_trivial_loop
)
const
{
CHECK_EQ
(
stage
->
op
.
operator
->
(),
this
);
Stmt
provide
=
AttrStmt
::
make
(
stage
->
op
,
attr
::
scan_update_scope
,
this
->
scan_axis
->
var
,
...
...
@@ -271,7 +271,7 @@ Stmt ScanOpNode::BuildProvide(
std
::
unordered_map
<
IterVar
,
Expr
>
vmap
;
std
::
unordered_set
<
IterVar
>
empty
;
auto
nest
=
op
::
MakeLoopNest
(
stage
,
dom_map
,
0
,
false
,
empty
,
&
vmap
,
de
l
_trivial_loop
);
stage
,
dom_map
,
0
,
false
,
empty
,
&
vmap
,
de
bug_keep
_trivial_loop
);
nest
[
begin_scan
].
push_back
(
init
);
nest
.
push_back
(
op
::
MakeIfNest
(
...
...
src/op/tensorize.cc
View file @
cfdc5119
...
...
@@ -370,14 +370,14 @@ Stmt TransformUpdate(const Stage& stage,
Stmt
MakeTensorize
(
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
{
bool
de
bug_keep
_trivial_loop
)
{
std
::
unordered_map
<
IterVar
,
Range
>
out_dom
;
std
::
unordered_map
<
Tensor
,
Array
<
Range
>
>
in_region
;
size_t
tloc
=
InferTensorizeRegion
(
self
,
stage
,
dom_map
,
&
out_dom
,
&
in_region
);
TensorIntrin
intrin
=
stage
->
iter_var_attrs
.
at
(
stage
->
leaf_iter_vars
[
tloc
])
->
tensor_intrin
;
CHECK
(
intrin
.
defined
());
ComputeLoopNest
n
=
ComputeLoopNest
::
make
(
self
,
stage
,
dom_map
,
de
l
_trivial_loop
);
ComputeLoopNest
n
=
ComputeLoopNest
::
make
(
self
,
stage
,
dom_map
,
de
bug_keep
_trivial_loop
);
VerifyTensorizeLoopNest
(
self
,
stage
,
n
,
tloc
);
VerifyTensorizeBody
(
self
,
stage
,
out_dom
,
in_region
,
intrin
);
// Start bind data.
...
...
src/schedule/schedule_ops.cc
View file @
cfdc5119
...
...
@@ -23,8 +23,8 @@ using namespace ir;
Stmt
MakePipeline
(
const
Stage
&
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
Stmt
consumer
,
bool
de
l
_trivial_loop
)
{
Stmt
producer
=
s
->
op
->
BuildProvide
(
s
,
dom_map
,
de
l
_trivial_loop
);
bool
de
bug_keep
_trivial_loop
)
{
Stmt
producer
=
s
->
op
->
BuildProvide
(
s
,
dom_map
,
de
bug_keep
_trivial_loop
);
if
(
producer
.
defined
())
{
producer
=
ProducerConsumer
::
make
(
s
->
op
,
true
,
producer
);
}
...
...
@@ -58,9 +58,9 @@ class InjectAttach : public IRMutator {
InjectAttach
(
const
Stage
&
stage
,
const
Stage
&
attach_spec
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
bool
de
bug_keep
_trivial_loop
)
:
stage_
(
stage
),
attach_spec_
(
attach_spec
),
dom_map_
(
dom_map
),
de
l_trivial_loop_
(
del
_trivial_loop
)
{}
de
bug_keep_trivial_loop_
(
debug_keep
_trivial_loop
)
{}
Stmt
Mutate
(
Stmt
stmt
)
final
{
CHECK
(
stmt
.
defined
());
...
...
@@ -76,7 +76,7 @@ class InjectAttach : public IRMutator {
found_attach
=
true
;
stmt
=
AttrStmt
::
make
(
op
->
node
,
op
->
attr_key
,
op
->
value
,
MakePipeline
(
stage_
,
dom_map_
,
op
->
body
,
de
l
_trivial_loop_
));
MakePipeline
(
stage_
,
dom_map_
,
op
->
body
,
de
bug_keep
_trivial_loop_
));
}
}
return
stmt
;
...
...
@@ -91,8 +91,9 @@ class InjectAttach : public IRMutator {
const
Stage
&
attach_spec_
;
// domain map
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map_
;
// whether delete trivial loops with extent of 1
bool
del_trivial_loop_
;
// Whether keep trivial loops with extent of 1 during lowering.
// This is a debug feature for dataflow/axis analysis
bool
debug_keep_trivial_loop_
;
};
// inject the operator's realization on the stmt.
...
...
@@ -102,9 +103,9 @@ class InjectScanStep : public IRMutator {
const
Operation
&
scan_op
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
is_init
,
bool
de
l
_trivial_loop
)
bool
de
bug_keep
_trivial_loop
)
:
stage_
(
stage
),
scan_op_
(
scan_op
),
dom_map_
(
dom_map
),
is_init_
(
is_init
),
de
l_trivial_loop_
(
del
_trivial_loop
)
{}
dom_map_
(
dom_map
),
is_init_
(
is_init
),
de
bug_keep_trivial_loop_
(
debug_keep
_trivial_loop
)
{}
Stmt
Mutate
(
Stmt
stmt
)
final
{
CHECK
(
stmt
.
defined
());
...
...
@@ -118,7 +119,7 @@ class InjectScanStep : public IRMutator {
found_attach
=
true
;
stmt
=
AttrStmt
::
make
(
op
->
node
,
op
->
attr_key
,
op
->
value
,
MakePipeline
(
stage_
,
dom_map_
,
op
->
body
,
de
l
_trivial_loop_
));
MakePipeline
(
stage_
,
dom_map_
,
op
->
body
,
de
bug_keep
_trivial_loop_
));
}
}
return
stmt
;
...
...
@@ -135,8 +136,9 @@ class InjectScanStep : public IRMutator {
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map_
;
// whether it is init.
bool
is_init_
;
// whether delete trivial loops with extent of 1
bool
del_trivial_loop_
;
// Whether keep trivial loops with extent of 1 during lowering.
// This is a debug feature for dataflow/axis analysis
bool
debug_keep_trivial_loop_
;
};
// Postprocessing of schedule op
...
...
@@ -337,7 +339,7 @@ class SchedulePostProc : public IRMutator {
};
Stmt
ScheduleOps
(
Schedule
sch
,
Map
<
IterVar
,
Range
>
dom_map_
,
bool
de
l
_trivial_loop
)
{
Schedule
sch
,
Map
<
IterVar
,
Range
>
dom_map_
,
bool
de
bug_keep
_trivial_loop
)
{
Stmt
body
=
Stmt
();
std
::
unordered_map
<
IterVar
,
Range
>
dom_map
=
as_unordered_map
(
dom_map_
);
// scan init and scan updates
...
...
@@ -372,14 +374,14 @@ Stmt ScheduleOps(
if
(
scan_init
.
count
(
s
->
op
))
{
CHECK
(
body
.
defined
());
InjectScanStep
mu
(
s
,
scan_init
.
at
(
s
->
op
),
dom_map
,
true
,
de
l
_trivial_loop
);
InjectScanStep
mu
(
s
,
scan_init
.
at
(
s
->
op
),
dom_map
,
true
,
de
bug_keep
_trivial_loop
);
body
=
mu
.
Mutate
(
body
);
CHECK
(
mu
.
found_attach
)
<<
"did not find attachment point for scan.init"
;
}
else
if
(
attach_spec
->
attach_type
==
kScanUpdate
)
{
// Handle scan update
CHECK
(
body
.
defined
());
InjectScanStep
mu
(
s
,
attach_spec
->
attach_stage
->
op
,
dom_map
,
false
,
de
l
_trivial_loop
);
InjectScanStep
mu
(
s
,
attach_spec
->
attach_stage
->
op
,
dom_map
,
false
,
de
bug_keep
_trivial_loop
);
body
=
mu
.
Mutate
(
body
);
CHECK
(
mu
.
found_attach
)
<<
"did not find attachment point for scan.update"
;
...
...
@@ -387,11 +389,11 @@ Stmt ScheduleOps(
// do nothing
}
else
if
(
attach_spec
->
attach_type
==
kGroupRoot
)
{
CHECK
(
!
s
->
group
.
defined
());
body
=
MakePipeline
(
s
,
dom_map
,
body
,
de
l
_trivial_loop
);
body
=
MakePipeline
(
s
,
dom_map
,
body
,
de
bug_keep
_trivial_loop
);
}
else
{
CHECK_EQ
(
attach_spec
->
attach_type
,
kScope
);
CHECK
(
body
.
defined
());
InjectAttach
mutator
(
s
,
attach_spec
,
dom_map
,
de
l
_trivial_loop
);
InjectAttach
mutator
(
s
,
attach_spec
,
dom_map
,
de
bug_keep
_trivial_loop
);
body
=
mutator
.
Mutate
(
body
);
CHECK
(
mutator
.
found_attach
)
<<
"did not find attachment point for "
<<
s
<<
" in "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment