wenyuanbo / tic · Commits

Commit 35af4c8b
Authored Dec 30, 2019 by Animesh Jain
Committed by Yizhi Liu, Dec 30, 2019
Parent: 55bd786f

[Relay][Convert Layout] Handling batch norm layout change. (#4600)

Showing 3 changed files with 79 additions and 1 deletion:

    src/relay/op/nn/nn.cc                                +29  -0
    src/relay/pass/convert_layout.cc                      +1  -1
    tests/python/relay/test_pass_convert_op_layout.py    +49  -0
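In short: nn.batch_norm gets a BatchNormInferCorrectLayout function and is registered with it, so the ConvertLayout pass can now push a layout change through batch_norm instead of stopping at it. A minimal invocation sketch, mirroring the new test below; it assumes the 2019-era API names (relay.Module, transform.PassContext) and is not part of the commit:

    from tvm import relay
    from tvm.relay import analysis, transform

    # NHWC conv2d feeding downstream ops; ConvertLayout should move it to NCHW.
    x = relay.var("x", shape=(1, 56, 56, 64))
    w = relay.var("w", shape=(3, 3, 64, 64))
    y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1),
                        data_layout='NHWC', kernel_layout='HWIO')
    func = relay.Function(analysis.free_vars(y), y)

    mod = relay.Module.from_expr(func)   # relay.Module: 2019-era module API
    seq = transform.Sequential([transform.InferType(),
                                transform.ConvertLayout('NCHW')])
    with transform.PassContext(opt_level=3):
        mod = seq(mod)
    print(mod)  # conv2d now runs in NCHW, with layout_transform ops at the boundaries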
src/relay/op/nn/nn.cc

@@ -617,6 +617,34 @@ The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input
 // batch_norm
 TVM_REGISTER_NODE_TYPE(BatchNormAttrs);
 
+Array<Array<Layout>> BatchNormInferCorrectLayout(const Attrs& attrs,
+                                                 const Array<Layout>& new_in_layouts,
+                                                 const Array<Layout>& old_in_layouts,
+                                                 const Array<Array<IndexExpr>>& old_in_shapes) {
+  BatchNormAttrs* param = const_cast<BatchNormAttrs*>(attrs.as<BatchNormAttrs>());
+
+  size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size()
+                                : static_cast<size_t>(param->axis);
+
+  Layout ret = Layout::Undef();
+
+  // If new_in_layouts are defined, this code tries to modify the layout.
+  if (new_in_layouts.defined() && old_in_layouts.defined()) {
+    // Get the new C axis. Extract the dim in old layout. Find the index of that dim in next layout.
+    const auto& bn_dim = old_in_layouts[0][axis];
+    auto new_index = new_in_layouts[0].IndexOf(bn_dim);
+    param->axis = new_index;
+    ret = new_in_layouts[0];
+  } else if (old_in_layouts.defined()) {
+    ret = old_in_layouts[0];
+  }
+
+  // BN has 5 inputs, 3 outputs. The last 4 inputs and last 2 outputs have "C" layout.
+  Layout c_layout = Layout("C");
+
+  return Array<Array<Layout>>{{ret, c_layout, c_layout, c_layout, c_layout},
+                              {ret, c_layout, c_layout}};
+}
+
 bool BatchNormRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
...
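The heart of the new rule is a pure index remap: take the dimension letter at the old batch_norm axis in the old layout, then find that letter's position in the new layout. A plain-Python illustration of just that arithmetic (illustrative sketch only, not TVM code):

    # The same remapping BatchNormInferCorrectLayout performs with
    # Layout indexing and Layout::IndexOf, written with plain strings.
    old_layout, new_layout = "NHWC", "NCHW"
    old_axis = 3                         # batch_norm axis in the old layout
    bn_dim = old_layout[old_axis]        # 'C'
    new_axis = new_layout.index(bn_dim)  # 1: the channel axis after conversion
    assert (bn_dim, new_axis) == ('C', 1)

This also explains the returned layout arrays: the data input and the data output follow ret, while the four statistics inputs (gamma, beta, moving_mean, moving_var) and the two statistics outputs keep the one-dimensional "C" layout, since they are vectors over the channel dimension whatever the data layout is.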
@@ -708,6 +736,7 @@ axis to be the last item in the input shape.
     .add_argument("beta", "Tensor", "The beta offset factor.")
     .add_argument("moving_mean", "Tensor", "Running mean of input.")
     .add_argument("moving_var", "Tensor", "Running variance of input.")
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", BatchNormInferCorrectLayout)
     .set_support_level(1)
     .add_type_rel("BatchNorm", BatchNormRel);
...
src/relay/pass/convert_layout.cc

@@ -134,7 +134,7 @@ Pass ConvertLayout(const std::string& desired_layout) {
   };
   return CreateFunctionPass(
       pass_func, 3, "ConvertLayout",
-      {ir::StringImm::make("InferType"), ir::StringImm::make("SimplifyInference"),
+      {ir::StringImm::make("InferType"),
        ir::StringImm::make("CanonicalizeOps")});
 }
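The one-line change above drops SimplifyInference from ConvertLayout's required passes. SimplifyInference rewrites nn.batch_norm into explicit elementwise arithmetic, so running it as a prerequisite would erase the very operator the new layout rule handles. A sketch of that effect, again assuming the 2019-era API names (relay.Module, relay.transform):

    from tvm import relay
    from tvm.relay import analysis, transform

    x = relay.var("x", shape=(1, 64, 56, 56))
    dtype = "float32"
    beta = relay.var("beta", relay.TensorType((64,), dtype))
    gamma = relay.var("gamma", relay.TensorType((64,), dtype))
    moving_mean = relay.var("moving_mean", relay.TensorType((64,), dtype))
    moving_var = relay.var("moving_var", relay.TensorType((64,), dtype))
    y = relay.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1)
    func = relay.Function(analysis.free_vars(y[0]), y[0])

    mod = relay.Module.from_expr(func)
    mod = transform.InferType()(mod)         # SimplifyInference needs types
    mod = transform.SimplifyInference()(mod)
    print(mod)  # no nn.batch_norm left: only elementwise scale/shift arithmetic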
tests/python/relay/test_pass_convert_op_layout.py

@@ -349,6 +349,54 @@ def test_scalar_convert_layout():
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_conv_bn_convert_layout():
+    """ Check that layout transforms are propagated through bn. """
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+        y = relay.nn.conv2d(x, weight,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout='NHWC',
+                            kernel_layout='HWIO')
+        dtype = "float32"
+        beta = relay.var("beta", relay.TensorType((64,), dtype))
+        gamma = relay.var("gamma", relay.TensorType((64,), dtype))
+        moving_mean = relay.var("moving_mean", relay.TensorType((64,), dtype))
+        moving_var = relay.var("moving_var", relay.TensorType((64,), dtype))
+        y = relay.nn.batch_norm(y, gamma, beta, moving_mean, moving_var, axis=3)
+        y = relay.nn.relu(y[0])
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        w = relay.var("weight", shape=(3, 3, 64, 64))
+        x = relay.layout_transform(x, 'NHWC', 'NCHW')
+        w = relay.layout_transform(w, 'HWIO', 'OIHW')
+        y = relay.nn.conv2d(x, w,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        dtype = "float32"
+        beta = relay.var("beta", relay.TensorType((64,), dtype))
+        gamma = relay.var("gamma", relay.TensorType((64,), dtype))
+        moving_mean = relay.var("moving_mean", relay.TensorType((64,), dtype))
+        moving_var = relay.var("moving_var", relay.TensorType((64,), dtype))
+        y = relay.nn.batch_norm(y, gamma, beta, moving_mean, moving_var, axis=1)
+        y = relay.nn.relu(y[0])
+        y = relay.layout_transform(y, "NCHW", "NHWC")
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
 if __name__ == "__main__":
     test_no_convert_layout()
     test_conv_convert_layout()
...
@@ -358,3 +406,4 @@ if __name__ == "__main__":
     test_bn_convert_layout()
     test_resnet_convert_layout()
     test_scalar_convert_layout()
+    test_conv_bn_convert_layout()
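The test follows the file's usual pattern: before() builds the NHWC graph with batch_norm on axis=3, expected() hand-writes what ConvertLayout should produce (layout_transform ops at the graph boundaries, an NCHW conv2d, and the batch_norm axis rewritten to 1), and alpha_equal compares the two up to variable renaming. As the new __main__ entry shows, the file can still be run directly:

    python tests/python/relay/test_pass_convert_op_layout.py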