Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
267f0294
Commit
267f0294
authored
Apr 01, 2018
by
kun-zh
Committed by
Tianqi Chen
Mar 31, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix an issue in ReplaceDataFlow for issue 1043 (#1062)
parent
d39ac773
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
4 deletions
+32
-4
src/schedule/schedule_dataflow_rewrite.cc
+18
-4
tests/python/unittest/test_pass_storage_rewrite.py
+14
-0
No files found.
src/schedule/schedule_dataflow_rewrite.cc
View file @
267f0294
...
...
@@ -57,13 +57,21 @@ Expr InjectPredicate(const Array<Expr>& predicates,
// Replace data flow appears in all stages given the tensor change.
// Also update vmap if subsequent dataflow need to be replaced.
// Need to keep an update to the date transitive closure property on the vmap by a reverse map.
void
ReplaceDataFlow
(
const
Array
<
Stage
>&
stages
,
std
::
unordered_map
<
Tensor
,
Tensor
>*
vmap
)
{
std
::
unordered_map
<
Tensor
,
Tensor
>*
vmap
,
std
::
unordered_map
<
Tensor
,
Tensor
>*
rvmap
)
{
for
(
Stage
s
:
stages
)
{
Operation
op
=
s
->
op
->
ReplaceInputs
(
s
->
op
,
*
vmap
);
if
(
!
op
.
same_as
(
s
->
op
))
{
for
(
int
i
=
0
;
i
<
op
->
num_outputs
();
++
i
)
{
auto
it
=
rvmap
->
find
(
s
->
op
.
output
(
i
));
if
(
it
!=
rvmap
->
end
())
{
(
*
vmap
)[
it
->
second
]
=
op
.
output
(
i
);
}
else
{
(
*
vmap
)[
s
->
op
.
output
(
i
)]
=
op
.
output
(
i
);
(
*
rvmap
)[
op
.
output
(
i
)]
=
s
->
op
.
output
(
i
);
}
}
s
->
op
=
op
;
}
...
...
@@ -91,6 +99,7 @@ Tensor Schedule::cache_read(const Tensor& tensor,
vsub
[
sugar_tensor
]
=
cache
;
std
::
unordered_map
<
Tensor
,
Tensor
>
vmap
;
std
::
unordered_map
<
Tensor
,
Tensor
>
rvmap
;
for
(
Operation
op
:
readers
)
{
Stage
s
=
operator
[](
op
);
Operation
repl_op
=
s
->
op
->
ReplaceInputs
(
s
->
op
,
vsub
);
...
...
@@ -98,9 +107,10 @@ Tensor Schedule::cache_read(const Tensor& tensor,
<<
"Cannot find "
<<
tensor
<<
" in the inputs of "
<<
s
->
op
;
vmap
[
s
->
op
.
output
(
0
)]
=
repl_op
.
output
(
0
);
rvmap
[
repl_op
.
output
(
0
)]
=
s
->
op
.
output
(
0
);
s
->
op
=
repl_op
;
}
ReplaceDataFlow
((
*
this
)
->
stages
,
&
vmap
);
ReplaceDataFlow
((
*
this
)
->
stages
,
&
vmap
,
&
rvmap
);
ArrayNode
*
stages
=
(
*
this
)
->
stages
.
CopyOnWrite
();
Stage
op_stage
=
operator
[](
tensor
->
op
);
size_t
pos
=
FindNodeRef
(
stages
,
op_stage
);
...
...
@@ -197,8 +207,10 @@ Tensor CacheWriteWithReLayout(Schedule sch,
{
cache_tensor
(
args
)});
// The replace of the dataflow
std
::
unordered_map
<
Tensor
,
Tensor
>
vmap
;
std
::
unordered_map
<
Tensor
,
Tensor
>
rvmap
;
vmap
[
orig_stage
->
op
.
output
(
0
)]
=
orig_new_op
.
output
(
0
);
ReplaceDataFlow
(
sch
->
stages
,
&
vmap
);
rvmap
[
orig_new_op
.
output
(
0
)]
=
orig_stage
->
op
.
output
(
0
);
ReplaceDataFlow
(
sch
->
stages
,
&
vmap
,
&
rvmap
);
// mutate orig stage
orig_stage
->
op
=
orig_new_op
;
orig_stage
->
all_iter_vars
=
orig_stage
->
op
->
root_iter_vars
();
...
...
@@ -583,10 +595,12 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
},
reduce_stage
->
op
->
name
+
".repl"
);
std
::
unordered_map
<
Tensor
,
Tensor
>
vmap
;
std
::
unordered_map
<
Tensor
,
Tensor
>
rvmap
;
for
(
int
idx
=
0
;
idx
<
size
;
++
idx
)
{
vmap
[
old_tensors
[
idx
]]
=
repl_tensors
[
idx
];
rvmap
[
repl_tensors
[
idx
]]
=
old_tensors
[
idx
];
}
ReplaceDataFlow
((
*
this
)
->
stages
,
&
vmap
);
ReplaceDataFlow
((
*
this
)
->
stages
,
&
vmap
,
&
rvmap
);
// revamp the reduction stage.
reduce_stage
->
op
=
repl_tensors
[
0
]
->
op
;
reduce_stage
->
all_iter_vars
=
repl_tensors
[
0
]
->
op
->
root_iter_vars
();
...
...
tests/python/unittest/test_pass_storage_rewrite.py
View file @
267f0294
...
...
@@ -442,6 +442,19 @@ def test_reuse_small_buffer():
tvm
.
ir_pass
.
PostOrderVisit
(
body
,
verify
)
assert
num_alloc
[
0
]
==
1
def
test_replace_dataflow
():
shape
=
(
255
,)
A
=
tvm
.
placeholder
(
shape
,
name
=
"A"
)
B
=
tvm
.
compute
(
shape
,
lambda
i
:
A
[
i
]
+
A
[
i
],
name
=
"B"
)
C
=
tvm
.
compute
(
shape
,
lambda
i
:
A
[
i
]
+
B
[
i
],
name
=
"C"
)
D
=
tvm
.
compute
(
shape
,
lambda
i
:
A
[
i
]
+
C
[
i
],
name
=
"D"
)
E
=
tvm
.
compute
(
shape
,
lambda
i
:
A
[
i
]
+
D
[
i
],
name
=
"E"
)
s
=
tvm
.
create_schedule
(
E
.
op
)
s
.
cache_read
(
A
,
"local"
,
[
B
,
C
,
D
,
E
])
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
assert
isinstance
(
bounds
,
tvm
.
container
.
Map
)
if
__name__
==
"__main__"
:
test_alloc_seq
()
...
...
@@ -456,3 +469,4 @@ if __name__ == "__main__":
test_alloc_seq_type
()
test_alloc_seq_type2
()
test_reuse_small_buffer
()
test_replace_dataflow
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment