Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
cfdc5119
Commit
cfdc5119
authored
Mar 28, 2018
by
Lianmin Zheng
Committed by
Tianqi Chen
Mar 28, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
delete init part when keeping trivial loop (#1031)
parent
ca9ec009
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
66 additions
and
56 deletions
+66
-56
include/tvm/operation.h
+6
-6
include/tvm/schedule_pass.h
+5
-2
src/api/api_schedule.cc
+1
-1
src/codegen/build_module.cc
+1
-1
src/op/compute_op.cc
+15
-10
src/op/compute_op.h
+6
-6
src/op/cross_thread_reduction.cc
+2
-2
src/op/extern_op.cc
+1
-1
src/op/op_util.cc
+3
-3
src/op/op_util.h
+2
-2
src/op/placeholder_op.cc
+1
-1
src/op/scan_op.cc
+2
-2
src/op/tensorize.cc
+2
-2
src/schedule/schedule_ops.cc
+19
-17
No files found.
include/tvm/operation.h
View file @
cfdc5119
...
@@ -117,13 +117,13 @@ class OperationNode : public FunctionBaseNode {
...
@@ -117,13 +117,13 @@ class OperationNode : public FunctionBaseNode {
* \brief Build the statement that provide the output tensors.
* \brief Build the statement that provide the output tensors.
* \param stage The schedule stage of the op.
* \param stage The schedule stage of the op.
* \param dom_map The domain map of all iteration domains.
* \param dom_map The domain map of all iteration domains.
* \param de
l_trivial_loop Whether eliminate trivial loop
with extent of 1
* \param de
bug_keep_trivial_loop Whether keep trivial loops
with extent of 1
* \return A statement that add production and wraps consumer.
* \return A statement that add production and wraps consumer.
*/
*/
virtual
Stmt
BuildProvide
(
virtual
Stmt
BuildProvide
(
const
Stage
&
stage
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
const
=
0
;
bool
de
bug_keep
_trivial_loop
)
const
=
0
;
static
constexpr
const
char
*
_type_key
=
"Operation"
;
static
constexpr
const
char
*
_type_key
=
"Operation"
;
...
@@ -163,7 +163,7 @@ class PlaceholderOpNode : public OperationNode {
...
@@ -163,7 +163,7 @@ class PlaceholderOpNode : public OperationNode {
Stmt
BuildProvide
(
Stmt
BuildProvide
(
const
Stage
&
stage
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
const
final
;
bool
de
bug_keep
_trivial_loop
)
const
final
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"name"
,
&
name
);
...
@@ -215,7 +215,7 @@ class ComputeOpNode : public OperationNode {
...
@@ -215,7 +215,7 @@ class ComputeOpNode : public OperationNode {
Stmt
BuildProvide
(
Stmt
BuildProvide
(
const
Stage
&
stage
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
const
final
;
bool
de
bug_keep
_trivial_loop
)
const
final
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"name"
,
&
name
);
...
@@ -287,7 +287,7 @@ class ScanOpNode : public OperationNode {
...
@@ -287,7 +287,7 @@ class ScanOpNode : public OperationNode {
Stmt
BuildProvide
(
Stmt
BuildProvide
(
const
Stage
&
stage
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
const
final
;
bool
de
bug_keep
_trivial_loop
)
const
final
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"name"
,
&
name
);
...
@@ -351,7 +351,7 @@ class ExternOpNode : public OperationNode {
...
@@ -351,7 +351,7 @@ class ExternOpNode : public OperationNode {
Stmt
BuildProvide
(
Stmt
BuildProvide
(
const
Stage
&
stage
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
const
final
;
bool
de
bug_keep
_trivial_loop
)
const
final
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name"
,
&
name
);
v
->
Visit
(
"name"
,
&
name
);
...
...
include/tvm/schedule_pass.h
View file @
cfdc5119
...
@@ -29,10 +29,13 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
...
@@ -29,10 +29,13 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
*
*
* \param s The schedule to be realized
* \param s The schedule to be realized
* \param dom_map The domain of each iter vars.
* \param dom_map The domain of each iter vars.
* \param del_trivial_loop Whether delete trivial loops with extent of 1
* \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 during lowering.
* This is a debug feature for dataflow/axis analysis.
* Note: If this is true, The lowered IR may be incorrect,
* because we will also delete the init part of reduction
* \return the result Stmt
* \return the result Stmt
*/
*/
Stmt
ScheduleOps
(
Schedule
s
,
Map
<
IterVar
,
Range
>
dom_map
,
bool
de
l
_trivial_loop
);
Stmt
ScheduleOps
(
Schedule
s
,
Map
<
IterVar
,
Range
>
dom_map
,
bool
de
bug_keep
_trivial_loop
);
/*!
/*!
* \brief To automatically inline the element-wise operations.
* \brief To automatically inline the element-wise operations.
...
...
src/api/api_schedule.cc
View file @
cfdc5119
...
@@ -27,7 +27,7 @@ TVM_REGISTER_API("schedule.AutoInlineInjective")
...
@@ -27,7 +27,7 @@ TVM_REGISTER_API("schedule.AutoInlineInjective")
TVM_REGISTER_API
(
"schedule.ScheduleOps"
)
TVM_REGISTER_API
(
"schedule.ScheduleOps"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
if
(
args
.
size
()
==
2
)
if
(
args
.
size
()
==
2
)
*
ret
=
ScheduleOps
(
args
[
0
],
args
[
1
],
tru
e
);
*
ret
=
ScheduleOps
(
args
[
0
],
args
[
1
],
fals
e
);
else
else
*
ret
=
ScheduleOps
(
args
[
0
],
args
[
1
],
args
[
2
]);
*
ret
=
ScheduleOps
(
args
[
0
],
args
[
1
],
args
[
2
]);
});
});
...
...
src/codegen/build_module.cc
View file @
cfdc5119
...
@@ -349,7 +349,7 @@ Stmt BuildStmt(Schedule sch,
...
@@ -349,7 +349,7 @@ Stmt BuildStmt(Schedule sch,
// Phase 0
// Phase 0
auto
bounds
=
schedule
::
InferBound
(
sch
);
auto
bounds
=
schedule
::
InferBound
(
sch
);
auto
stmt
=
schedule
::
ScheduleOps
(
sch
,
bounds
,
tru
e
);
auto
stmt
=
schedule
::
ScheduleOps
(
sch
,
bounds
,
fals
e
);
stmt
=
ir
::
InjectPrefetch
(
stmt
);
stmt
=
ir
::
InjectPrefetch
(
stmt
);
// Phase 1
// Phase 1
...
...
src/op/compute_op.cc
View file @
cfdc5119
...
@@ -296,9 +296,9 @@ Stmt MakeProvide(const ComputeOpNode* op,
...
@@ -296,9 +296,9 @@ Stmt MakeProvide(const ComputeOpNode* op,
Stmt
MakeComputeStmt
(
const
ComputeOpNode
*
self
,
Stmt
MakeComputeStmt
(
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
{
bool
de
bug_keep
_trivial_loop
)
{
// grab the nest structure
// grab the nest structure
ComputeLoopNest
n
=
ComputeLoopNest
::
make
(
self
,
stage
,
dom_map
,
de
l
_trivial_loop
);
ComputeLoopNest
n
=
ComputeLoopNest
::
make
(
self
,
stage
,
dom_map
,
de
bug_keep
_trivial_loop
);
// Normal loop structure
// Normal loop structure
n
.
init_nest
.
emplace_back
(
op
::
MakeIfNest
(
n
.
init_predicates
));
n
.
init_nest
.
emplace_back
(
op
::
MakeIfNest
(
n
.
init_predicates
));
n
.
main_nest
.
emplace_back
(
op
::
MakeIfNest
(
n
.
main_predicates
));
n
.
main_nest
.
emplace_back
(
op
::
MakeIfNest
(
n
.
main_predicates
));
...
@@ -319,7 +319,11 @@ Stmt MakeComputeStmt(const ComputeOpNode* self,
...
@@ -319,7 +319,11 @@ Stmt MakeComputeStmt(const ComputeOpNode* self,
n
.
main_nest
.
begin
()
+
n
.
num_common_loop
+
1
,
n
.
main_nest
.
end
());
n
.
main_nest
.
begin
()
+
n
.
num_common_loop
+
1
,
n
.
main_nest
.
end
());
provide
=
op
::
Substitute
(
provide
,
n
.
main_vmap
);
provide
=
op
::
Substitute
(
provide
,
n
.
main_vmap
);
provide
=
MergeNest
(
reduce
,
provide
);
provide
=
MergeNest
(
reduce
,
provide
);
return
MergeNest
(
common
,
Block
::
make
(
init
,
provide
));
if
(
debug_keep_trivial_loop
)
{
return
MergeNest
(
common
,
provide
);
}
else
{
return
MergeNest
(
common
,
Block
::
make
(
init
,
provide
));
}
}
else
{
}
else
{
std
::
vector
<
Stmt
>
provides
;
std
::
vector
<
Stmt
>
provides
;
for
(
size_t
i
=
0
;
i
<
self
->
body
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
self
->
body
.
size
();
++
i
)
{
...
@@ -379,16 +383,16 @@ ComputeType DetectComputeType(const ComputeOpNode* self,
...
@@ -379,16 +383,16 @@ ComputeType DetectComputeType(const ComputeOpNode* self,
Stmt
ComputeOpNode
::
BuildProvide
(
Stmt
ComputeOpNode
::
BuildProvide
(
const
Stage
&
stage
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
const
{
bool
de
bug_keep
_trivial_loop
)
const
{
CHECK_EQ
(
stage
->
op
.
operator
->
(),
this
);
CHECK_EQ
(
stage
->
op
.
operator
->
(),
this
);
ComputeType
ctype
=
DetectComputeType
(
this
,
stage
);
ComputeType
ctype
=
DetectComputeType
(
this
,
stage
);
if
(
ctype
==
ComputeType
::
kCrossThreadReduction
)
{
if
(
ctype
==
ComputeType
::
kCrossThreadReduction
)
{
// specially handle cross thread reduction.
// specially handle cross thread reduction.
return
MakeCrossThreadReduction
(
this
,
stage
,
dom_map
,
de
l
_trivial_loop
);
return
MakeCrossThreadReduction
(
this
,
stage
,
dom_map
,
de
bug_keep
_trivial_loop
);
}
else
if
(
ctype
==
ComputeType
::
kTensorize
)
{
}
else
if
(
ctype
==
ComputeType
::
kTensorize
)
{
return
MakeTensorize
(
this
,
stage
,
dom_map
,
de
l
_trivial_loop
);
return
MakeTensorize
(
this
,
stage
,
dom_map
,
de
bug_keep
_trivial_loop
);
}
else
{
}
else
{
return
MakeComputeStmt
(
this
,
stage
,
dom_map
,
de
l
_trivial_loop
);
return
MakeComputeStmt
(
this
,
stage
,
dom_map
,
de
bug_keep
_trivial_loop
);
}
}
}
}
...
@@ -396,12 +400,13 @@ ComputeLoopNest ComputeLoopNest::make(
...
@@ -396,12 +400,13 @@ ComputeLoopNest ComputeLoopNest::make(
const
ComputeOpNode
*
self
,
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
{
bool
de
bug_keep
_trivial_loop
)
{
CHECK_EQ
(
stage
->
op
.
operator
->
(),
self
);
CHECK_EQ
(
stage
->
op
.
operator
->
(),
self
);
ComputeLoopNest
ret
;
ComputeLoopNest
ret
;
// make main loop nest
// make main loop nest
ret
.
main_nest
=
op
::
MakeLoopNest
(
ret
.
main_nest
=
op
::
MakeLoopNest
(
stage
,
dom_map
,
0
,
false
,
std
::
unordered_set
<
IterVar
>
(),
&
ret
.
main_vmap
,
del_trivial_loop
);
stage
,
dom_map
,
0
,
false
,
std
::
unordered_set
<
IterVar
>
(),
&
ret
.
main_vmap
,
debug_keep_trivial_loop
);
ret
.
main_predicates
=
schedule
::
MakeBoundCheck
(
ret
.
main_predicates
=
schedule
::
MakeBoundCheck
(
stage
,
dom_map
,
ret
.
main_vmap
,
false
,
stage
,
dom_map
,
ret
.
main_vmap
,
false
,
std
::
unordered_set
<
IterVar
>
());
std
::
unordered_set
<
IterVar
>
());
...
@@ -443,7 +448,7 @@ ComputeLoopNest ComputeLoopNest::make(
...
@@ -443,7 +448,7 @@ ComputeLoopNest ComputeLoopNest::make(
}
}
ret
.
init_nest
=
op
::
MakeLoopNest
(
ret
.
init_nest
=
op
::
MakeLoopNest
(
stage
,
dom_map
,
begin_loop
,
true
,
stage
,
dom_map
,
begin_loop
,
true
,
skip_iter
,
&
(
ret
.
init_vmap
),
de
l
_trivial_loop
);
skip_iter
,
&
(
ret
.
init_vmap
),
de
bug_keep
_trivial_loop
);
ret
.
init_predicates
=
schedule
::
MakeBoundCheck
(
ret
.
init_predicates
=
schedule
::
MakeBoundCheck
(
stage
,
dom_map
,
ret
.
init_vmap
,
true
,
skip_iter
);
stage
,
dom_map
,
ret
.
init_vmap
,
true
,
skip_iter
);
for
(
auto
&
e
:
ret
.
init_predicates
)
{
for
(
auto
&
e
:
ret
.
init_predicates
)
{
...
...
src/op/compute_op.h
View file @
cfdc5119
...
@@ -37,14 +37,14 @@ struct ComputeLoopNest {
...
@@ -37,14 +37,14 @@ struct ComputeLoopNest {
* \param self The pointer to compute op.
* \param self The pointer to compute op.
* \param stage The scxhedule stage.
* \param stage The scxhedule stage.
* \param dom_map The domain map.
* \param dom_map The domain map.
* \param de
l_trivial_loop Whether eliminate
trivial loops with extent of 1
* \param de
bug_keep_trivial_loop Whether keep
trivial loops with extent of 1
* \return The constructed loop nest
* \return The constructed loop nest
*/
*/
static
ComputeLoopNest
make
(
static
ComputeLoopNest
make
(
const
ComputeOpNode
*
self
,
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
);
bool
de
bug_keep
_trivial_loop
);
};
};
/*!
/*!
...
@@ -52,27 +52,27 @@ struct ComputeLoopNest {
...
@@ -52,27 +52,27 @@ struct ComputeLoopNest {
* \param self The pointer to ComputeOpNode
* \param self The pointer to ComputeOpNode
* \param stage The schedule stage.
* \param stage The schedule stage.
* \param dom_map The domain map.
* \param dom_map The domain map.
* \param de
l_trivial_loop Wheter eliminate
trivial loops with extent of 1
* \param de
bug_keep_trivial_loop Whether keep
trivial loops with extent of 1
* \return The created statement.
* \return The created statement.
*/
*/
Stmt
MakeCrossThreadReduction
(
Stmt
MakeCrossThreadReduction
(
const
ComputeOpNode
*
self
,
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
);
bool
de
bug_keep
_trivial_loop
);
/*!
/*!
* \brief Build body of compute for tensorization.
* \brief Build body of compute for tensorization.
* \param self The pointer to ComputeOpNode
* \param self The pointer to ComputeOpNode
* \param stage The schedule stage.
* \param stage The schedule stage.
* \param dom_map The domain map.
* \param dom_map The domain map.
* \param de
l_trivial_loop Wheter eliminate
trivial loops with extent of 1
* \param de
bug_keep_trivial_loop Whether keep
trivial loops with extent of 1
* \return The created statement.
* \return The created statement.
*/
*/
Stmt
MakeTensorize
(
const
ComputeOpNode
*
self
,
Stmt
MakeTensorize
(
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
);
bool
de
bug_keep
_trivial_loop
);
}
// namespace tvm
}
// namespace tvm
#endif // TVM_OP_COMPUTE_OP_H_
#endif // TVM_OP_COMPUTE_OP_H_
src/op/cross_thread_reduction.cc
View file @
cfdc5119
...
@@ -14,14 +14,14 @@ Stmt MakeCrossThreadReduction(
...
@@ -14,14 +14,14 @@ Stmt MakeCrossThreadReduction(
const
ComputeOpNode
*
self
,
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
{
bool
de
bug_keep
_trivial_loop
)
{
Array
<
Expr
>
args
;
Array
<
Expr
>
args
;
for
(
IterVar
iv
:
self
->
axis
)
{
for
(
IterVar
iv
:
self
->
axis
)
{
args
.
push_back
(
iv
->
var
);
args
.
push_back
(
iv
->
var
);
}
}
std
::
unordered_map
<
IterVar
,
Expr
>
value_map
;
std
::
unordered_map
<
IterVar
,
Expr
>
value_map
;
auto
nest
=
op
::
MakeLoopNest
(
auto
nest
=
op
::
MakeLoopNest
(
stage
,
dom_map
,
0
,
false
,
std
::
unordered_set
<
IterVar
>
(),
&
value_map
,
de
l
_trivial_loop
);
stage
,
dom_map
,
0
,
false
,
std
::
unordered_set
<
IterVar
>
(),
&
value_map
,
de
bug_keep
_trivial_loop
);
auto
conds
=
schedule
::
MakeBoundCheck
(
auto
conds
=
schedule
::
MakeBoundCheck
(
stage
,
dom_map
,
value_map
,
false
,
stage
,
dom_map
,
value_map
,
false
,
std
::
unordered_set
<
IterVar
>
());
std
::
unordered_set
<
IterVar
>
());
...
...
src/op/extern_op.cc
View file @
cfdc5119
...
@@ -129,7 +129,7 @@ Stmt ExternOpNode::BuildRealize(
...
@@ -129,7 +129,7 @@ Stmt ExternOpNode::BuildRealize(
Stmt
ExternOpNode
::
BuildProvide
(
Stmt
ExternOpNode
::
BuildProvide
(
const
Stage
&
stage
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
const
{
bool
de
bug_keep
_trivial_loop
)
const
{
CHECK_EQ
(
stage
->
op
.
operator
->
(),
this
);
CHECK_EQ
(
stage
->
op
.
operator
->
(),
this
);
Stmt
ret
=
AttrStmt
::
make
(
make_zero
(
Int
(
32
)),
attr
::
extern_scope
,
0
,
this
->
body
);
Stmt
ret
=
AttrStmt
::
make
(
make_zero
(
Int
(
32
)),
attr
::
extern_scope
,
0
,
this
->
body
);
auto
f_push_bind
=
[
&
ret
](
Buffer
buffer
,
Tensor
tensor
)
{
auto
f_push_bind
=
[
&
ret
](
Buffer
buffer
,
Tensor
tensor
)
{
...
...
src/op/op_util.cc
View file @
cfdc5119
...
@@ -24,7 +24,7 @@ MakeLoopNest(const Stage& stage,
...
@@ -24,7 +24,7 @@ MakeLoopNest(const Stage& stage,
bool
new_loop_var
,
bool
new_loop_var
,
const
std
::
unordered_set
<
IterVar
>&
skip_iter
,
const
std
::
unordered_set
<
IterVar
>&
skip_iter
,
std
::
unordered_map
<
IterVar
,
Expr
>*
p_value_map
,
std
::
unordered_map
<
IterVar
,
Expr
>*
p_value_map
,
bool
de
l
_trivial_loop
)
{
bool
de
bug_keep
_trivial_loop
)
{
auto
leaf_iter_vars
=
stage
->
leaf_iter_vars
;
auto
leaf_iter_vars
=
stage
->
leaf_iter_vars
;
Stmt
no_op
=
Evaluate
::
make
(
0
);
Stmt
no_op
=
Evaluate
::
make
(
0
);
// create the loop nest
// create the loop nest
...
@@ -76,7 +76,7 @@ MakeLoopNest(const Stage& stage,
...
@@ -76,7 +76,7 @@ MakeLoopNest(const Stage& stage,
AttrStmt
::
make
(
iv
,
ir
::
attr
::
pragma_scope
,
p
,
no_op
));
AttrStmt
::
make
(
iv
,
ir
::
attr
::
pragma_scope
,
p
,
no_op
));
}
}
}
}
if
(
del
_trivial_loop
&&
is_one
(
dom
->
extent
))
{
if
(
!
debug_keep
_trivial_loop
&&
is_one
(
dom
->
extent
))
{
nest
[
i
+
1
].
emplace_back
(
nest
[
i
+
1
].
emplace_back
(
LetStmt
::
make
(
var
,
dom
->
min
,
no_op
));
LetStmt
::
make
(
var
,
dom
->
min
,
no_op
));
value_map
[
iv
]
=
dom
->
min
;
value_map
[
iv
]
=
dom
->
min
;
...
@@ -131,7 +131,7 @@ MakeLoopNest(const Stage& stage,
...
@@ -131,7 +131,7 @@ MakeLoopNest(const Stage& stage,
// annotate the extent of the IterVar
// annotate the extent of the IterVar
nest
[
i
+
1
].
emplace_back
(
nest
[
i
+
1
].
emplace_back
(
AttrStmt
::
make
(
bind_iv
,
ir
::
attr
::
thread_extent
,
dom
->
extent
,
no_op
));
AttrStmt
::
make
(
bind_iv
,
ir
::
attr
::
thread_extent
,
dom
->
extent
,
no_op
));
if
(
del
_trivial_loop
&&
is_one
(
dom
->
extent
))
{
if
(
!
debug_keep
_trivial_loop
&&
is_one
(
dom
->
extent
))
{
value_map
[
iv
]
=
dom
->
min
;
value_map
[
iv
]
=
dom
->
min
;
}
else
{
}
else
{
value_map
[
iv
]
=
var
;
value_map
[
iv
]
=
var
;
...
...
src/op/op_util.h
View file @
cfdc5119
...
@@ -29,7 +29,7 @@ using ir::MergeNest;
...
@@ -29,7 +29,7 @@ using ir::MergeNest;
* \param new_loop_var Whether create new loop variable.
* \param new_loop_var Whether create new loop variable.
* \param skip_iter Whether skip certain iteration.
* \param skip_iter Whether skip certain iteration.
* \param p_value_map The result value of each IterVar.
* \param p_value_map The result value of each IterVar.
* \param de
l_trivial_loop Whether eliminate
trivial loops with extent of 1
* \param de
bug_keep_trivial_loop Whether keep
trivial loops with extent of 1
*/
*/
std
::
vector
<
std
::
vector
<
Stmt
>
>
std
::
vector
<
std
::
vector
<
Stmt
>
>
MakeLoopNest
(
const
Stage
&
stage
,
MakeLoopNest
(
const
Stage
&
stage
,
...
@@ -38,7 +38,7 @@ MakeLoopNest(const Stage& stage,
...
@@ -38,7 +38,7 @@ MakeLoopNest(const Stage& stage,
bool
new_loop_var
,
bool
new_loop_var
,
const
std
::
unordered_set
<
IterVar
>&
skip_iter
,
const
std
::
unordered_set
<
IterVar
>&
skip_iter
,
std
::
unordered_map
<
IterVar
,
Expr
>*
p_value_map
,
std
::
unordered_map
<
IterVar
,
Expr
>*
p_value_map
,
bool
de
l
_trivial_loop
);
bool
de
bug_keep
_trivial_loop
);
/*!
/*!
* \brief Create a nest of if checking the predicates.
* \brief Create a nest of if checking the predicates.
...
...
src/op/placeholder_op.cc
View file @
cfdc5119
...
@@ -79,7 +79,7 @@ Stmt PlaceholderOpNode::BuildRealize(
...
@@ -79,7 +79,7 @@ Stmt PlaceholderOpNode::BuildRealize(
Stmt
PlaceholderOpNode
::
BuildProvide
(
Stmt
PlaceholderOpNode
::
BuildProvide
(
const
Stage
&
stage
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
const
{
bool
de
bug_keep
_trivial_loop
)
const
{
return
Stmt
();
return
Stmt
();
}
}
}
// namespace tvm
}
// namespace tvm
src/op/scan_op.cc
View file @
cfdc5119
...
@@ -253,7 +253,7 @@ Stmt ScanOpNode::BuildRealize(
...
@@ -253,7 +253,7 @@ Stmt ScanOpNode::BuildRealize(
Stmt
ScanOpNode
::
BuildProvide
(
Stmt
ScanOpNode
::
BuildProvide
(
const
Stage
&
stage
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
const
{
bool
de
bug_keep
_trivial_loop
)
const
{
CHECK_EQ
(
stage
->
op
.
operator
->
(),
this
);
CHECK_EQ
(
stage
->
op
.
operator
->
(),
this
);
Stmt
provide
=
AttrStmt
::
make
(
Stmt
provide
=
AttrStmt
::
make
(
stage
->
op
,
attr
::
scan_update_scope
,
this
->
scan_axis
->
var
,
stage
->
op
,
attr
::
scan_update_scope
,
this
->
scan_axis
->
var
,
...
@@ -271,7 +271,7 @@ Stmt ScanOpNode::BuildProvide(
...
@@ -271,7 +271,7 @@ Stmt ScanOpNode::BuildProvide(
std
::
unordered_map
<
IterVar
,
Expr
>
vmap
;
std
::
unordered_map
<
IterVar
,
Expr
>
vmap
;
std
::
unordered_set
<
IterVar
>
empty
;
std
::
unordered_set
<
IterVar
>
empty
;
auto
nest
=
op
::
MakeLoopNest
(
auto
nest
=
op
::
MakeLoopNest
(
stage
,
dom_map
,
0
,
false
,
empty
,
&
vmap
,
de
l
_trivial_loop
);
stage
,
dom_map
,
0
,
false
,
empty
,
&
vmap
,
de
bug_keep
_trivial_loop
);
nest
[
begin_scan
].
push_back
(
init
);
nest
[
begin_scan
].
push_back
(
init
);
nest
.
push_back
(
nest
.
push_back
(
op
::
MakeIfNest
(
op
::
MakeIfNest
(
...
...
src/op/tensorize.cc
View file @
cfdc5119
...
@@ -370,14 +370,14 @@ Stmt TransformUpdate(const Stage& stage,
...
@@ -370,14 +370,14 @@ Stmt TransformUpdate(const Stage& stage,
Stmt
MakeTensorize
(
const
ComputeOpNode
*
self
,
Stmt
MakeTensorize
(
const
ComputeOpNode
*
self
,
const
Stage
&
stage
,
const
Stage
&
stage
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
{
bool
de
bug_keep
_trivial_loop
)
{
std
::
unordered_map
<
IterVar
,
Range
>
out_dom
;
std
::
unordered_map
<
IterVar
,
Range
>
out_dom
;
std
::
unordered_map
<
Tensor
,
Array
<
Range
>
>
in_region
;
std
::
unordered_map
<
Tensor
,
Array
<
Range
>
>
in_region
;
size_t
tloc
=
InferTensorizeRegion
(
self
,
stage
,
dom_map
,
&
out_dom
,
&
in_region
);
size_t
tloc
=
InferTensorizeRegion
(
self
,
stage
,
dom_map
,
&
out_dom
,
&
in_region
);
TensorIntrin
intrin
=
stage
->
iter_var_attrs
.
at
(
TensorIntrin
intrin
=
stage
->
iter_var_attrs
.
at
(
stage
->
leaf_iter_vars
[
tloc
])
->
tensor_intrin
;
stage
->
leaf_iter_vars
[
tloc
])
->
tensor_intrin
;
CHECK
(
intrin
.
defined
());
CHECK
(
intrin
.
defined
());
ComputeLoopNest
n
=
ComputeLoopNest
::
make
(
self
,
stage
,
dom_map
,
de
l
_trivial_loop
);
ComputeLoopNest
n
=
ComputeLoopNest
::
make
(
self
,
stage
,
dom_map
,
de
bug_keep
_trivial_loop
);
VerifyTensorizeLoopNest
(
self
,
stage
,
n
,
tloc
);
VerifyTensorizeLoopNest
(
self
,
stage
,
n
,
tloc
);
VerifyTensorizeBody
(
self
,
stage
,
out_dom
,
in_region
,
intrin
);
VerifyTensorizeBody
(
self
,
stage
,
out_dom
,
in_region
,
intrin
);
// Start bind data.
// Start bind data.
...
...
src/schedule/schedule_ops.cc
View file @
cfdc5119
...
@@ -23,8 +23,8 @@ using namespace ir;
...
@@ -23,8 +23,8 @@ using namespace ir;
Stmt
MakePipeline
(
const
Stage
&
s
,
Stmt
MakePipeline
(
const
Stage
&
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
Stmt
consumer
,
Stmt
consumer
,
bool
de
l
_trivial_loop
)
{
bool
de
bug_keep
_trivial_loop
)
{
Stmt
producer
=
s
->
op
->
BuildProvide
(
s
,
dom_map
,
de
l
_trivial_loop
);
Stmt
producer
=
s
->
op
->
BuildProvide
(
s
,
dom_map
,
de
bug_keep
_trivial_loop
);
if
(
producer
.
defined
())
{
if
(
producer
.
defined
())
{
producer
=
ProducerConsumer
::
make
(
s
->
op
,
true
,
producer
);
producer
=
ProducerConsumer
::
make
(
s
->
op
,
true
,
producer
);
}
}
...
@@ -58,9 +58,9 @@ class InjectAttach : public IRMutator {
...
@@ -58,9 +58,9 @@ class InjectAttach : public IRMutator {
InjectAttach
(
const
Stage
&
stage
,
InjectAttach
(
const
Stage
&
stage
,
const
Stage
&
attach_spec
,
const
Stage
&
attach_spec
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
de
l
_trivial_loop
)
bool
de
bug_keep
_trivial_loop
)
:
stage_
(
stage
),
attach_spec_
(
attach_spec
),
dom_map_
(
dom_map
),
:
stage_
(
stage
),
attach_spec_
(
attach_spec
),
dom_map_
(
dom_map
),
de
l_trivial_loop_
(
del
_trivial_loop
)
{}
de
bug_keep_trivial_loop_
(
debug_keep
_trivial_loop
)
{}
Stmt
Mutate
(
Stmt
stmt
)
final
{
Stmt
Mutate
(
Stmt
stmt
)
final
{
CHECK
(
stmt
.
defined
());
CHECK
(
stmt
.
defined
());
...
@@ -76,7 +76,7 @@ class InjectAttach : public IRMutator {
...
@@ -76,7 +76,7 @@ class InjectAttach : public IRMutator {
found_attach
=
true
;
found_attach
=
true
;
stmt
=
AttrStmt
::
make
(
stmt
=
AttrStmt
::
make
(
op
->
node
,
op
->
attr_key
,
op
->
value
,
op
->
node
,
op
->
attr_key
,
op
->
value
,
MakePipeline
(
stage_
,
dom_map_
,
op
->
body
,
de
l
_trivial_loop_
));
MakePipeline
(
stage_
,
dom_map_
,
op
->
body
,
de
bug_keep
_trivial_loop_
));
}
}
}
}
return
stmt
;
return
stmt
;
...
@@ -91,8 +91,9 @@ class InjectAttach : public IRMutator {
...
@@ -91,8 +91,9 @@ class InjectAttach : public IRMutator {
const
Stage
&
attach_spec_
;
const
Stage
&
attach_spec_
;
// domain map
// domain map
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map_
;
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map_
;
// whether delete trivial loops with extent of 1
// Whether keep trivial loops with extent of 1 during lowering.
bool
del_trivial_loop_
;
// This is a debug feature for dataflow/axis analysis
bool
debug_keep_trivial_loop_
;
};
};
// inject the operator's realization on the stmt.
// inject the operator's realization on the stmt.
...
@@ -102,9 +103,9 @@ class InjectScanStep : public IRMutator {
...
@@ -102,9 +103,9 @@ class InjectScanStep : public IRMutator {
const
Operation
&
scan_op
,
const
Operation
&
scan_op
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
bool
is_init
,
bool
is_init
,
bool
de
l
_trivial_loop
)
bool
de
bug_keep
_trivial_loop
)
:
stage_
(
stage
),
scan_op_
(
scan_op
),
:
stage_
(
stage
),
scan_op_
(
scan_op
),
dom_map_
(
dom_map
),
is_init_
(
is_init
),
de
l_trivial_loop_
(
del
_trivial_loop
)
{}
dom_map_
(
dom_map
),
is_init_
(
is_init
),
de
bug_keep_trivial_loop_
(
debug_keep
_trivial_loop
)
{}
Stmt
Mutate
(
Stmt
stmt
)
final
{
Stmt
Mutate
(
Stmt
stmt
)
final
{
CHECK
(
stmt
.
defined
());
CHECK
(
stmt
.
defined
());
...
@@ -118,7 +119,7 @@ class InjectScanStep : public IRMutator {
...
@@ -118,7 +119,7 @@ class InjectScanStep : public IRMutator {
found_attach
=
true
;
found_attach
=
true
;
stmt
=
AttrStmt
::
make
(
stmt
=
AttrStmt
::
make
(
op
->
node
,
op
->
attr_key
,
op
->
value
,
op
->
node
,
op
->
attr_key
,
op
->
value
,
MakePipeline
(
stage_
,
dom_map_
,
op
->
body
,
de
l
_trivial_loop_
));
MakePipeline
(
stage_
,
dom_map_
,
op
->
body
,
de
bug_keep
_trivial_loop_
));
}
}
}
}
return
stmt
;
return
stmt
;
...
@@ -135,8 +136,9 @@ class InjectScanStep : public IRMutator {
...
@@ -135,8 +136,9 @@ class InjectScanStep : public IRMutator {
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map_
;
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map_
;
// whether it is init.
// whether it is init.
bool
is_init_
;
bool
is_init_
;
// whether delete trivial loops with extent of 1
// Whether keep trivial loops with extent of 1 during lowering.
bool
del_trivial_loop_
;
// This is a debug feature for dataflow/axis analysis
bool
debug_keep_trivial_loop_
;
};
};
// Postprocessing of schedule op
// Postprocessing of schedule op
...
@@ -337,7 +339,7 @@ class SchedulePostProc : public IRMutator {
...
@@ -337,7 +339,7 @@ class SchedulePostProc : public IRMutator {
};
};
Stmt
ScheduleOps
(
Stmt
ScheduleOps
(
Schedule
sch
,
Map
<
IterVar
,
Range
>
dom_map_
,
bool
de
l
_trivial_loop
)
{
Schedule
sch
,
Map
<
IterVar
,
Range
>
dom_map_
,
bool
de
bug_keep
_trivial_loop
)
{
Stmt
body
=
Stmt
();
Stmt
body
=
Stmt
();
std
::
unordered_map
<
IterVar
,
Range
>
dom_map
=
as_unordered_map
(
dom_map_
);
std
::
unordered_map
<
IterVar
,
Range
>
dom_map
=
as_unordered_map
(
dom_map_
);
// scan init and scan updates
// scan init and scan updates
...
@@ -372,14 +374,14 @@ Stmt ScheduleOps(
...
@@ -372,14 +374,14 @@ Stmt ScheduleOps(
if
(
scan_init
.
count
(
s
->
op
))
{
if
(
scan_init
.
count
(
s
->
op
))
{
CHECK
(
body
.
defined
());
CHECK
(
body
.
defined
());
InjectScanStep
mu
(
s
,
scan_init
.
at
(
s
->
op
),
dom_map
,
true
,
de
l
_trivial_loop
);
InjectScanStep
mu
(
s
,
scan_init
.
at
(
s
->
op
),
dom_map
,
true
,
de
bug_keep
_trivial_loop
);
body
=
mu
.
Mutate
(
body
);
body
=
mu
.
Mutate
(
body
);
CHECK
(
mu
.
found_attach
)
CHECK
(
mu
.
found_attach
)
<<
"did not find attachment point for scan.init"
;
<<
"did not find attachment point for scan.init"
;
}
else
if
(
attach_spec
->
attach_type
==
kScanUpdate
)
{
}
else
if
(
attach_spec
->
attach_type
==
kScanUpdate
)
{
// Handle scan update
// Handle scan update
CHECK
(
body
.
defined
());
CHECK
(
body
.
defined
());
InjectScanStep
mu
(
s
,
attach_spec
->
attach_stage
->
op
,
dom_map
,
false
,
de
l
_trivial_loop
);
InjectScanStep
mu
(
s
,
attach_spec
->
attach_stage
->
op
,
dom_map
,
false
,
de
bug_keep
_trivial_loop
);
body
=
mu
.
Mutate
(
body
);
body
=
mu
.
Mutate
(
body
);
CHECK
(
mu
.
found_attach
)
CHECK
(
mu
.
found_attach
)
<<
"did not find attachment point for scan.update"
;
<<
"did not find attachment point for scan.update"
;
...
@@ -387,11 +389,11 @@ Stmt ScheduleOps(
...
@@ -387,11 +389,11 @@ Stmt ScheduleOps(
// do nothing
// do nothing
}
else
if
(
attach_spec
->
attach_type
==
kGroupRoot
)
{
}
else
if
(
attach_spec
->
attach_type
==
kGroupRoot
)
{
CHECK
(
!
s
->
group
.
defined
());
CHECK
(
!
s
->
group
.
defined
());
body
=
MakePipeline
(
s
,
dom_map
,
body
,
de
l
_trivial_loop
);
body
=
MakePipeline
(
s
,
dom_map
,
body
,
de
bug_keep
_trivial_loop
);
}
else
{
}
else
{
CHECK_EQ
(
attach_spec
->
attach_type
,
kScope
);
CHECK_EQ
(
attach_spec
->
attach_type
,
kScope
);
CHECK
(
body
.
defined
());
CHECK
(
body
.
defined
());
InjectAttach
mutator
(
s
,
attach_spec
,
dom_map
,
de
l
_trivial_loop
);
InjectAttach
mutator
(
s
,
attach_spec
,
dom_map
,
de
bug_keep
_trivial_loop
);
body
=
mutator
.
Mutate
(
body
);
body
=
mutator
.
Mutate
(
body
);
CHECK
(
mutator
.
found_attach
)
CHECK
(
mutator
.
found_attach
)
<<
"did not find attachment point for "
<<
s
<<
" in "
<<
"did not find attachment point for "
<<
s
<<
" in "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment