Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
ea9c1c59
Commit
ea9c1c59
authored
Apr 09, 2017
by
Tianqi Chen
Committed by
GitHub
Apr 09, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[SCHEDULE] More reliable bound inference on threading. (#84)
parent
3ac94439
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
126 additions
and
130 deletions
+126
-130
src/schedule/bound.cc
+94
-101
src/schedule/schedule_dataflow_rewrite.cc
+2
-28
tests/python/unittest/test_schedule_bound_inference.py
+30
-1
No files found.
src/schedule/bound.cc
View file @
ea9c1c59
...
...
@@ -15,33 +15,56 @@
namespace
tvm
{
namespace
schedule
{
using
runtime
::
ThreadScope
;
using
runtime
::
StorageScope
;
/*! \brief The graph context used during bound inference. */
struct
GraphContext
{
/*! \brief The feed graph */
FeedGraph
feed_graph
;
/*! \brief Attachment path */
AttachPath
attach_path
;
/*! \brief The bind map */
std
::
unordered_map
<
IterVar
,
IterVar
>
bind_map
;
/*! \brief map from op to stage */
std
::
unordered_map
<
const
Node
*
,
Stage
>
op2stage_
;
};
// check if scope
inline
bool
ScopeRelax
(
const
IterVar
&
ivar
,
const
std
::
unordered_map
<
IterVar
,
IterVar
>&
bind_map
,
const
std
::
string
&
scope
)
{
using
runtime
::
ThreadScope
;
using
runtime
::
StorageScope
;
auto
it
=
bind_map
.
find
(
ivar
);
IterVar
iv
=
ivar
;
if
(
it
!=
bind_map
.
end
())
{
iv
=
it
->
second
;
bool
NeedRelax
(
const
IterVar
&
iv
,
bool
found_attach
,
const
std
::
unordered_map
<
IterVar
,
IterVar
>&
bind_map
,
const
runtime
::
StorageScope
&
scope
)
{
auto
it
=
bind_map
.
find
(
iv
);
const
std
::
string
&
tag
=
(
it
!=
bind_map
.
end
()
?
it
->
second
->
thread_tag
:
iv
->
thread_tag
);
if
(
tag
.
length
()
==
0
||
tag
==
"pipeline"
)
{
return
!
found_attach
;
}
if
(
iv
->
thread_tag
.
length
()
==
0
)
return
false
;
if
(
scope
.
length
()
==
0
)
return
false
;
return
scope
.
rank
<=
ThreadScope
::
make
(
tag
).
rank
;
}
return
StorageScope
::
make
(
scope
).
rank
<=
ThreadScope
::
make
(
iv
->
thread_tag
).
rank
;
// infer storage scope, if not given
StorageScope
InferStorageScope
(
const
Stage
&
stage
,
const
GraphContext
&
ctx
)
{
if
(
stage
->
scope
.
length
()
!=
0
)
{
return
StorageScope
::
make
(
stage
->
scope
);
}
int
max_rank
=
0
;
for
(
IterVar
iv
:
ctx
.
attach_path
.
at
(
stage
->
op
))
{
auto
it
=
ctx
.
bind_map
.
find
(
iv
);
const
std
::
string
&
tag
=
(
it
!=
ctx
.
bind_map
.
end
()
?
it
->
second
->
thread_tag
:
iv
->
thread_tag
);
if
(
tag
!=
"pipeline"
&&
tag
.
length
()
!=
0
)
{
max_rank
=
std
::
max
(
max_rank
,
ThreadScope
::
make
(
tag
).
rank
+
1
);
}
}
StorageScope
s
;
s
.
rank
=
max_rank
;
return
s
;
}
void
InferRootBound
(
const
Stage
&
stage
,
const
GraphContext
&
ctx
,
const
AttachPath
&
attach_path
,
const
std
::
unordered_map
<
IterVar
,
IterVar
>&
bind_map
,
std
::
unordered_map
<
IterVar
,
Range
>*
rmap
)
{
CHECK_NE
(
stage
->
attach_type
,
kInline
)
<<
"call schedule.normalize before scheduleops"
;
...
...
@@ -59,73 +82,78 @@ void InferRootBound(const Stage& stage,
}
return
;
}
// parent stage, if any
Stage
parent
;
Stage
attach_spec
=
stage
.
GetAttachSpec
();
if
(
attach_spec
->
attach_type
==
kScope
||
attach_spec
->
attach_type
==
kScanUpdate
)
{
parent
=
attach_spec
->
attach_stage
;
}
// The tensor domain.
std
::
unordered_map
<
Tensor
,
TensorDom
>
tmap
;
//
consumers other than parent
//
The consumers of the op.
std
::
unordered_set
<
Operation
>
consumers
;
// initialize the result
bool
direct_consume_by_parent
=
false
;
for
(
int
i
=
0
;
i
<
stage
->
op
->
num_outputs
();
++
i
)
{
Tensor
t
=
stage
->
op
.
output
(
i
);
tmap
.
emplace
(
t
,
TensorDom
(
static_cast
<
int
>
(
t
.
ndim
())));
auto
it
=
ctx
.
feed_graph
.
find
(
t
);
if
(
it
!=
ctx
.
feed_graph
.
end
())
{
for
(
const
Operation
&
op
:
it
->
second
)
{
if
(
!
parent
.
defined
()
||
op
!=
parent
->
op
)
{
consumers
.
insert
(
op
);
}
else
{
direct_consume_by_parent
=
true
;
}
consumers
.
insert
(
op
);
}
}
else
{
LOG
(
INFO
)
<<
"not in feed graph consumer = "
<<
stage
->
op
;
}
}
//
The relax set
// Thie specifieds the iteration variables that need to be relaxed
//
from the already inferred bound
s.
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
relax_set
;
for
(
IterVar
iv
:
attach_path
.
at
(
stage
->
op
))
{
if
(
ScopeRelax
(
iv
,
bind_map
,
stage
->
scope
))
{
relax_set
[
iv
->
var
.
get
()]
=
IntSet
::
range
(
rmap
->
at
(
iv
));
}
}
if
(
direct_consume_by_parent
)
{
// Bound inference logics in parent.
//
storage scope.
runtime
::
StorageScope
scope
=
InferStorageScope
(
stage
,
ctx
);
//
Bound prop by other consumer
s.
// - Compute bound by relaxation rules: NeedRelax
// - For normal index, use relative location of loop nest./
// - For thread index, use the thread scope.
//
Array
<
IterVar
>
stage_attach
=
ctx
.
attach_path
.
at
(
stage
->
op
);
// The parent set.
for
(
const
Operation
&
op
:
consumers
)
{
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
relax_set
;
std
::
unordered_map
<
IterVar
,
IntSet
>
up_state
;
bool
fix_value
=
true
;
for
(
auto
iv
:
parent
->
leaf_iter_vars
)
{
bool
found_attach
=
false
;
CHECK
(
ctx
.
op2stage_
.
count
(
op
.
get
()));
const
Stage
&
op_stage
=
ctx
.
op2stage_
.
at
(
op
.
get
());
// Consumer nest
for
(
size_t
i
=
op_stage
->
leaf_iter_vars
.
size
();
i
!=
0
;
--
i
)
{
IterVar
iv
=
op_stage
->
leaf_iter_vars
[
i
-
1
];
if
(
stage_attach
.
size
()
!=
0
&&
iv
==
stage_attach
[
0
])
{
found_attach
=
true
;
}
auto
it
=
rmap
->
find
(
iv
);
CHECK
(
it
!=
rmap
->
end
());
Range
vrange
=
it
->
second
;
CHECK
(
is_zero
(
vrange
->
min
))
<<
"InferBound requires every leaf iter var's min equals 0, "
<<
" call schedule.normalize to achieve this. "
<<
" stage="
<<
parent
<<
", vrange="
<<
vrange
->
min
;
// special optimization to remove trivial loop
const
Range
&
vrange
=
it
->
second
;
if
(
is_one
(
vrange
->
extent
))
{
up_state
[
iv
]
=
IntSet
::
single_point
(
vrange
->
min
);
}
else
if
(
fix_value
&&
!
ScopeRelax
(
iv
,
bind_map
,
stage
->
scope
))
{
}
else
if
(
!
NeedRelax
(
iv
,
found_attach
,
ctx
.
bind_map
,
scope
))
{
CHECK
(
is_zero
(
vrange
->
min
))
<<
"InferBound requires every leaf iter var's min equals 0, "
<<
" call schedule.normalize to achieve this. "
;
up_state
[
iv
]
=
IntSet
::
single_point
(
iv
->
var
);
}
else
{
up_state
[
iv
]
=
IntSet
::
range
(
vrange
);
}
if
(
attach_spec
->
attach_ivar
==
iv
)
{
fix_value
=
false
;
}
// Consumer's attach nest
for
(
IterVar
iv
:
ctx
.
attach_path
.
at
(
op
))
{
if
(
stage_attach
.
size
()
!=
0
&&
iv
==
stage_attach
[
0
])
{
found_attach
=
true
;
}
Range
vrange
=
rmap
->
at
(
iv
);
CHECK
(
is_zero
(
vrange
->
min
))
<<
"InferBound requires every leaf iter var's min equals 0, "
<<
"call schedule.normalize to achieve this."
;
if
(
NeedRelax
(
iv
,
found_attach
,
ctx
.
bind_map
,
scope
))
{
relax_set
[
iv
->
var
.
get
()]
=
IntSet
::
range
(
vrange
);
}
}
// get the bound of the root IterVars given current location.
PassUpDomain
(
parent
,
*
rmap
,
&
up_state
);
CHECK
(
found_attach
||
stage_attach
.
size
()
==
0
)
<<
"Invalid Schedule, cannot find the producer "
<<
stage
->
op
<<
" along the loop nest specified by compute_at of consumer "
<<
op
;
// Get the domain of the consumer
PassUpDomain
(
op_stage
,
*
rmap
,
&
up_state
);
// Relax if needed.
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
dom_map
;
for
(
auto
iv
:
parent
->
op
->
root_iter_vars
())
{
for
(
auto
iv
:
op
->
root_iter_vars
())
{
Range
r
;
if
(
up_state
.
count
(
iv
))
{
r
=
up_state
.
at
(
iv
).
cover_range
(
iv
->
dom
);
...
...
@@ -138,70 +166,35 @@ void InferRootBound(const Stage& stage,
dom_map
[
iv
->
var
.
get
()]
=
IntSet
::
range
(
r
);
}
}
// prop from parent.
parent
->
op
->
PropBoundToInputs
(
parent
->
op
,
dom_map
,
&
tmap
);
}
// Bound prop by other consumers.
// To explain the the general logic, consider the example:
//
// for (i_outer, 0, 10) {
// producer
//
// for (i_inner, 0, 4) {
// consumer op
// }
// }
// - Get domain of each of consumer op, say [i_inner + i_outer*8, extent=4)
// - We need to relax it since the producer is attached at i_outer
// - Consumer's path is [i_inner, i_outer], then [i_inner] need to be relaxed
// - Traverse attach_path, relax until reaching the producer's attachment point.
for
(
const
Operation
&
op
:
consumers
)
{
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
dom_map
;
bool
found
=
false
;
Array
<
IterVar
>
attach
=
attach_path
.
at
(
stage
->
op
);
for
(
IterVar
iv
:
attach_path
.
at
(
op
))
{
if
(
attach
.
size
()
!=
0
&&
iv
==
attach
[
0
])
{
found
=
true
;
break
;
}
Range
vrange
=
rmap
->
at
(
iv
);
CHECK
(
is_zero
(
vrange
->
min
))
<<
"InferBound requires every leaf iter var's min equals 0, "
<<
"call schedule.normalize to achieve this."
;
relax_set
[
iv
->
var
.
get
()]
=
IntSet
::
range
(
vrange
);
}
CHECK
(
found
||
attach
.
size
()
==
0
)
<<
"Invalid Schedule, cannot find the producer "
<<
stage
->
op
<<
" along the loop nest specified by compute_at of consumer "
<<
op
;
for
(
auto
iv
:
op
->
root_iter_vars
())
{
Range
r
=
rmap
->
at
(
iv
);
dom_map
[
iv
->
var
.
get
()]
=
EvalSet
(
r
,
relax_set
);
}
op
->
PropBoundToInputs
(
op
,
dom_map
,
&
tmap
);
}
stage
->
op
->
GatherBound
(
stage
->
op
,
tmap
,
rmap
);
}
Map
<
IterVar
,
Range
>
InferBound
(
const
Schedule
&
sch
)
{
// Prepare context
GraphContext
ctx
;
Array
<
Operation
>
roots
;
for
(
Operation
op
:
sch
->
outputs
)
{
roots
.
push_back
(
sch
->
stage_map
[
op
]
->
op
);
}
std
::
unordered_map
<
IterVar
,
IterVar
>
bind_map
;
ctx
.
feed_graph
=
CreateFeedGraph
(
CreateReadGraph
(
roots
));
for
(
Stage
stage
:
sch
->
stages
)
{
for
(
auto
kv
:
stage
->
iter_var_attrs
)
{
if
(
kv
.
second
->
bind_thread
.
defined
())
{
CHECK
(
!
bind_map
.
count
(
kv
.
first
));
bind_map
[
kv
.
first
]
=
kv
.
second
->
bind_thread
;
CHECK
(
!
ctx
.
bind_map
.
count
(
kv
.
first
));
ctx
.
bind_map
[
kv
.
first
]
=
kv
.
second
->
bind_thread
;
}
}
ctx
.
op2stage_
[
stage
->
op
.
get
()]
=
stage
;
}
GraphContext
ctx
;
ctx
.
feed_graph
=
CreateFeedGraph
(
CreateReadGraph
(
roots
));
AttachPath
attach_path
=
CreateAttachPath
(
sch
);
ctx
.
attach_path
=
CreateAttachPath
(
sch
);
// Run inference.
std
::
unordered_map
<
IterVar
,
Range
>
ret
;
for
(
size_t
i
=
sch
->
stages
.
size
();
i
!=
0
;
--
i
)
{
const
Stage
&
stage
=
sch
->
stages
[
i
-
1
];
InferRootBound
(
stage
,
ctx
,
attach_path
,
bind_map
,
&
ret
);
InferRootBound
(
stage
,
ctx
,
&
ret
);
// pass down to get bound of all iter vars.
PassDownDomain
(
stage
,
&
ret
);
for
(
IterVar
iv
:
stage
->
env_threads
)
{
...
...
src/schedule/schedule_dataflow_rewrite.cc
View file @
ea9c1c59
...
...
@@ -154,24 +154,9 @@ Tensor Schedule::cache_write(const Tensor& tensor,
void
RebaseNonZeroMinLoop
(
const
Schedule
&
sch
)
{
std
::
unordered_map
<
IterVar
,
IterVar
>
rebase_map
;
std
::
unordered_map
<
const
Node
*
,
int
>
attach_mark
;
for
(
Stage
s
:
sch
->
stages
)
{
if
(
s
->
attach_type
==
kScope
)
{
attach_mark
[
s
->
attach_stage
.
get
()]
=
1
;
}
if
(
s
->
op
.
as
<
ScanOpNode
>
())
{
attach_mark
[
s
.
get
()]
=
1
;
}
}
for
(
Stage
s
:
sch
->
groups
)
{
if
(
s
->
attach_type
==
kScope
)
{
attach_mark
[
s
->
attach_stage
.
get
()]
=
1
;
}
}
if
(
s
->
attach_type
==
kInlinedAlready
)
continue
;
for
(
Stage
s
:
sch
->
stages
)
{
if
(
!
attach_mark
.
count
(
s
.
get
()))
continue
;
auto
root_iter_vars
=
s
->
op
->
root_iter_vars
();
ArrayNode
*
leaf_vars
=
s
->
leaf_iter_vars
.
CopyOnWrite
();
for
(
IterVar
iv
:
root_iter_vars
)
{
...
...
@@ -201,16 +186,6 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
}
}
void
SetScanAttach
(
const
Schedule
&
sch
)
{
// NOLINT(*)
for
(
Stage
stage
:
sch
->
stages
)
{
if
(
stage
->
attach_type
==
kScanUpdate
)
{
const
Stage
&
parent
=
stage
->
attach_stage
;
stage
->
attach_ivar
=
parent
->
leaf_iter_vars
[
parent
->
leaf_iter_vars
.
size
()
-
1
];
}
}
}
void
InjectInline
(
ScheduleNode
*
sch
)
{
sch
->
InvalidateCache
();
std
::
vector
<
Expr
>
new_body
(
sch
->
stages
.
size
());
...
...
@@ -262,9 +237,8 @@ void InjectInline(ScheduleNode* sch) {
}
void
Schedule
::
normalize
()
{
RebaseNonZeroMinLoop
(
*
this
);
SetScanAttach
(
*
this
);
InjectInline
(
operator
->
());
RebaseNonZeroMinLoop
(
*
this
);
}
// Handle reduction factor.
...
...
tests/python/unittest/test_schedule_bound_inference.py
View file @
ea9c1c59
...
...
@@ -148,7 +148,37 @@ def test_bound_nest_group():
assert
bounds
[
x1
.
op
.
axis
[
0
]]
.
extent
.
value
==
1
assert
bounds
[
x1
.
op
.
axis
[
1
]]
.
extent
==
n
def
test_bound_nest_thread
():
m
=
tvm
.
Var
(
'm'
)
A
=
tvm
.
placeholder
((
m
),
name
=
'A'
)
A1
=
tvm
.
compute
((
m
,),
lambda
i
:
A
[
i
],
name
=
'A1'
)
A2
=
tvm
.
compute
((
m
,),
lambda
i
:
A1
[
i
]
+
2
,
name
=
'A2'
)
A3
=
tvm
.
compute
((
m
,),
lambda
i
:
A2
[
i
]
+
3
,
name
=
'A3'
)
s
=
tvm
.
Schedule
(
A3
.
op
)
s
[
A2
]
.
set_scope
(
"shared"
)
s
[
A1
]
.
set_scope
(
"local"
)
block_x
=
tvm
.
thread_axis
(
"blockIdx.x"
)
thread_x
=
tvm
.
thread_axis
(
"threadIdx.x"
)
bx
,
tx
=
s
[
A3
]
.
split
(
A3
.
op
.
axis
[
0
],
factor
=
32
)
s
[
A3
]
.
bind
(
bx
,
block_x
)
s
[
A3
]
.
bind
(
tx
,
thread_x
)
s
[
A2
]
.
compute_at
(
s
[
A3
],
tx
)
_
,
xi
=
s
[
A2
]
.
split
(
A2
.
op
.
axis
[
0
],
nparts
=
1
)
s
[
A2
]
.
bind
(
xi
,
thread_x
)
s
[
A1
]
.
compute_at
(
s
[
A3
],
tx
)
s
.
normalize
()
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
assert
(
bounds
[
A1
.
op
.
axis
[
0
]]
.
extent
.
value
==
1
)
assert
(
bounds
[
A2
.
op
.
axis
[
0
]]
.
extent
.
value
==
32
)
assert
(
bounds
[
A3
.
op
.
axis
[
0
]]
.
extent
==
m
)
if
__name__
==
"__main__"
:
test_bound_nest_thread
()
test_bound1
()
test_bound_nest_group
()
test_bound_group_schedule
()
test_bound_scan
()
...
...
@@ -156,5 +186,4 @@ if __name__ == "__main__":
test_bound_rfactor
()
test_bound_blur
()
test_bound_conv1d
()
test_bound1
()
test_bound2
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment