Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
a3968975
Commit
a3968975
authored
Nov 14, 2018
by
Yizhi Liu
Committed by
Tianqi Chen
Nov 14, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Bugfix] Recover original layout when alter_layout function return None (#2101)
parent
629a293a
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
69 additions
and
12 deletions
+69
-12
nnvm/src/compiler/alter_op_layout.cc
+15
-10
nnvm/tests/python/compiler/test_alter_op_layout.py
+54
-2
No files found.
nnvm/src/compiler/alter_op_layout.cc
View file @
a3968975
...
...
@@ -46,7 +46,7 @@ Graph AlterOpLayout(const Graph& src) {
std
::
vector
<
std
::
vector
<
Layout
>
>
in_layouts_of_node
(
idx_graph
.
num_nodes
());
std
::
vector
<
std
::
vector
<
Layout
>
>
out_layouts_of_node
(
idx_graph
.
num_nodes
());
std
::
unordered_map
<
const
Node
*
,
uint32_t
>
new
_nodes
;
std
::
unordered_map
<
const
Node
*
,
uint32_t
>
unchanged
_nodes
;
if
(
src
.
HasAttr
(
"layout"
))
{
// record layouts so that LayoutTransform pass can fix layouts correctly,
...
...
@@ -56,10 +56,8 @@ Graph AlterOpLayout(const Graph& src) {
const
auto
&
layouts
=
src
.
GetAttr
<
std
::
vector
<
Layout
>
>
(
"layout"
);
for
(
uint32_t
nid
=
0
;
nid
<
idx_graph
.
num_nodes
();
++
nid
)
{
const
auto
&
inode
=
idx_graph
[
nid
];
if
(
falter_op_layout
.
count
(
inode
.
source
->
op
()))
{
// do not record input layouts of nodes that will be replaced.
continue
;
}
// record input layouts for all nodes,
// while replaced nodes will ignore the records here and have undefined input layouts.
std
::
vector
<
Layout
>
in_layout
;
for
(
const
auto
&
e
:
inode
.
inputs
)
{
in_layout
.
emplace_back
(
layouts
[
idx_graph
.
entry_id
(
e
)]);
...
...
@@ -80,7 +78,8 @@ Graph AlterOpLayout(const Graph& src) {
nnvm
::
compiler
::
FTVMAlterOpLayout
fn_alter_op_layout
=
falter_op_layout
.
get
(
n
->
op
(),
nullptr
);
if
(
fn_alter_op_layout
==
nullptr
)
{
new_nodes
[
n
.
get
()]
=
nid
;
// will restore the original input layouts later.
unchanged_nodes
[
n
.
get
()]
=
nid
;
return
false
;
}
...
...
@@ -106,7 +105,13 @@ Graph AlterOpLayout(const Graph& src) {
Symbol
op
;
bool
do_alter
=
fn_alter_op_layout
(
n
->
attrs
,
Symbol
::
CreateGroup
(
op_inputs
),
tensor_infos
,
&
op
);
if
(
do_alter
)
*
ret
=
op
.
outputs
;
if
(
do_alter
)
{
*
ret
=
op
.
outputs
;
}
else
{
// will restore the original input layouts later.
unchanged_nodes
[
n
.
get
()]
=
nid
;
}
return
do_alter
;
};
...
...
@@ -118,15 +123,15 @@ Graph AlterOpLayout(const Graph& src) {
std
::
vector
<
Layout
>
ret_layouts
(
ret_idx
.
num_node_entries
(),
Layout
::
Undef
());
for
(
uint32_t
nid
=
0
;
nid
<
ret_idx
.
num_nodes
();
++
nid
)
{
const
auto
&
inode
=
ret_idx
[
nid
];
if
(
new
_nodes
.
count
(
inode
.
source
))
{
if
(
unchanged
_nodes
.
count
(
inode
.
source
))
{
const
std
::
vector
<
Layout
>&
in_layouts
=
in_layouts_of_node
[
new
_nodes
[
inode
.
source
]];
in_layouts_of_node
[
unchanged
_nodes
[
inode
.
source
]];
for
(
uint32_t
i
=
0
;
i
<
inode
.
inputs
.
size
();
++
i
)
{
const
auto
&
e
=
inode
.
inputs
[
i
];
ret_layouts
[
ret_idx
.
entry_id
(
e
)]
=
in_layouts
[
i
];
}
const
std
::
vector
<
Layout
>&
out_layouts
=
out_layouts_of_node
[
new
_nodes
[
inode
.
source
]];
out_layouts_of_node
[
unchanged
_nodes
[
inode
.
source
]];
for
(
uint32_t
i
=
0
;
i
<
inode
.
source
->
num_outputs
();
++
i
)
{
ret_layouts
[
ret_idx
.
entry_id
(
nid
,
i
)]
=
out_layouts
[
i
];
}
...
...
nnvm/tests/python/compiler/test_alter_op_layout.py
View file @
a3968975
...
...
@@ -45,9 +45,61 @@ def test_alter_conv2d_layout():
# check copy layouts
for
node
in
[
"data"
,
"relu"
,
"flatten"
,
"softmax"
,
"conv_weight"
]:
assert
(
layouts
[
node
]
==
layouts_origin
[
node
])
assert
(
layouts
[
"conv_alter"
]
==
layouts_origin
[
"conv"
])
assert
layouts
[
node
]
==
layouts_origin
[
node
]
assert
layouts
[
"conv_alter"
]
==
layouts_origin
[
"conv"
]
def
test_consecutive_alter_layout
():
data
=
sym
.
Variable
(
"data"
,
shape
=
(
1
,
32
,
512
,
512
))
pool1
=
sym
.
global_avg_pool2d
(
data
,
name
=
"global_avg_pool2d_1"
,
layout
=
"NCHW"
)
pool2
=
sym
.
global_avg_pool2d
(
pool1
,
name
=
"global_avg_pool2d_2"
,
layout
=
"NCHW"
)
relu
=
sym
.
relu
(
pool2
,
name
=
"relu"
)
g
=
graph
.
create
(
relu
)
g
=
g
.
apply
(
"CorrectLayout"
)
g
=
graph_attr
.
set_dtype_inputs
(
g
,
"float32"
)
g
=
g
.
apply
([
"InferShape"
,
"InferType"
])
assert
g
.
json_attr
(
"layout"
)
==
[
'NCHW'
,
'NCHW'
,
'NCHW'
,
'NCHW'
]
@reg.register_alter_op_layout
(
"global_avg_pool2d"
,
level
=
100
)
def
alter_global_avg_pool2d_layout
(
attrs
,
inputs
,
tinfos
):
new_attrs
=
{
k
:
attrs
[
k
]
for
k
in
attrs
.
keys
()}
new_attrs
[
"layout"
]
=
"NCHW16c"
return
sym
.
global_avg_pool2d
(
inputs
[
0
],
**
new_attrs
)
g
=
g
.
apply
(
"AlterOpLayout"
)
# pool1 get replaced - output layout of pool1 is not recorded
# pool2 get replaced - input layout of pool2 is not recorded
# thus the second entry must be undefined - it can neither recover from pool1's output,
# nor from pool2's input.
assert
g
.
json_attr
(
"layout"
)
==
[
'NCHW'
,
'__undef__'
,
'NCHW'
,
'NCHW'
]
def
test_alter_func_return_none
():
data
=
sym
.
Variable
(
"data"
,
shape
=
(
1
,
32
,
512
,
512
))
pool1
=
sym
.
global_max_pool2d
(
data
,
name
=
"pool1"
,
layout
=
"NCHW"
)
pool2
=
sym
.
global_max_pool2d
(
pool1
,
name
=
"pool2"
,
layout
=
"NCHW"
)
relu
=
sym
.
relu
(
pool2
,
name
=
"relu"
)
g
=
graph
.
create
(
relu
)
g
=
g
.
apply
(
"CorrectLayout"
)
g
=
graph_attr
.
set_dtype_inputs
(
g
,
"float32"
)
g
=
g
.
apply
([
"InferShape"
,
"InferType"
])
assert
g
.
json_attr
(
"layout"
)
==
[
'NCHW'
,
'NCHW'
,
'NCHW'
,
'NCHW'
]
@reg.register_alter_op_layout
(
"global_max_pool2d"
,
level
=
100
)
def
alter_global_max_pool2d_layout
(
attrs
,
inputs
,
tinfos
):
return
None
g
=
g
.
apply
(
"AlterOpLayout"
)
# alter func return none, nothing get replaced,
# the layouts should remain the same
assert
g
.
json_attr
(
"layout"
)
==
[
'NCHW'
,
'NCHW'
,
'NCHW'
,
'NCHW'
]
if
__name__
==
"__main__"
:
test_alter_conv2d_layout
()
test_consecutive_alter_layout
()
test_alter_func_return_none
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment