Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
40bc10f3
Commit
40bc10f3
authored
Sep 21, 2017
by
Tianqi Chen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] SimplifyBatchNorm->SimplifyInference, remove dropout (#24)
parent
215693df
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
20 additions
and
11 deletions
+20
-11
nnvm/python/nnvm/compiler/build_module.py
+4
-7
nnvm/python/nnvm/top/tensor.py
+5
-0
nnvm/src/compiler/simplify_inference.cc
+9
-3
nnvm/tests/python/compiler/test_simplify_inference.py
+2
-1
No files found.
nnvm/python/nnvm/compiler/build_module.py
View file @
40bc10f3
...
...
@@ -8,7 +8,7 @@ from .. import graph as _graph
from
..
import
runtime
OPT_PASS_LEVEL
=
{
"Simplify
BatchNorm
Inference"
:
2
,
"SimplifyInference"
:
2
,
"PrecomputePrune"
:
2
,
"OpFusion"
:
1
}
...
...
@@ -115,12 +115,9 @@ def optimize(graph, shape, dtype="float32"):
"""
# pylint: disable=unused-argument
cfg
=
BuildConfig
.
current
graph
=
graph_attr
.
set_shape_inputs
(
graph
,
shape
)
graph
=
graph
.
apply
(
"InferShape"
)
if
graph
.
json_attr
(
"shape_num_unknown_nodes"
):
raise
ValueError
(
"InferShape fails.."
)
if
cfg
.
opt_level
>=
OPT_PASS_LEVEL
[
"SimplifyBatchNormInference"
]:
graph
=
graph
.
apply
(
"SimplifyBatchNormInference"
)
if
cfg
.
opt_level
>=
OPT_PASS_LEVEL
[
"SimplifyInference"
]:
graph
=
graph_attr
.
set_shape_inputs
(
graph
,
shape
)
graph
=
graph
.
apply
([
"InferShape"
,
"SimplifyInference"
])
return
graph
...
...
nnvm/python/nnvm/top/tensor.py
View file @
40bc10f3
...
...
@@ -44,6 +44,11 @@ def _compute_binary(f):
_fschedule_broadcast
=
tvm
.
convert
(
_schedule_broadcast
)
# copy
reg
.
register_compute
(
"copy"
,
_compute_unary
(
topi
.
identity
))
reg
.
register_pattern
(
"copy"
,
OpPattern
.
ELEM_WISE
)
reg
.
register_schedule
(
"copy"
,
_fschedule_broadcast
)
# exp
reg
.
register_compute
(
"exp"
,
_compute_unary
(
topi
.
exp
))
reg
.
register_pattern
(
"exp"
,
OpPattern
.
ELEM_WISE
)
...
...
nnvm/src/compiler/simplify_
batch_norm
.cc
→
nnvm/src/compiler/simplify_
inference
.cc
View file @
40bc10f3
...
...
@@ -22,6 +22,7 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
nnvm
::
NodeEntry
moving_mean
,
nnvm
::
NodeEntry
moving_var
,
TShape
dshape
)
{
CHECK_NE
(
dshape
.
ndim
(),
0
);
CHECK
(
attrs
.
op
);
static
const
Op
*
bn_op
=
Op
::
Get
(
"batch_norm"
);
CHECK
(
attrs
.
op
==
bn_op
);
...
...
@@ -76,13 +77,14 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
return
{
out
,
undef
,
undef
};
}
Graph
Simplify
BatchNorm
Inference
(
nnvm
::
Graph
src
)
{
Graph
SimplifyInference
(
nnvm
::
Graph
src
)
{
// Get attributes from the graph
const
IndexedGraph
&
idx
=
src
.
indexed_graph
();
const
ShapeVector
&
shape_vec
=
src
.
GetAttr
<
ShapeVector
>
(
"shape"
);
auto
transform
=
[
&
](
uint32_t
nid
,
const
Node
*
n
,
std
::
vector
<
NodeEntry
>*
ret
)
{
if
(
n
->
is_variable
())
return
false
;
static
const
Op
*
bn_op
=
Op
::
Get
(
"batch_norm"
);
static
const
Op
*
dropout_op
=
Op
::
Get
(
"dropout"
);
if
(
n
->
op
()
==
bn_op
)
{
*
ret
=
BatchNormToInferUnpack
(
n
->
attrs
,
...
...
@@ -93,6 +95,10 @@ Graph SimplifyBatchNormInference(nnvm::Graph src) {
n
->
inputs
[
4
],
shape_vec
[
idx
.
entry_id
(
nid
,
0
)]);
return
true
;
}
else
if
(
n
->
op
()
==
dropout_op
)
{
NodeEntry
undef
=
MakeNode
(
"__undef__"
,
"undef"
,
{});
*
ret
=
{
n
->
inputs
[
0
],
undef
};
return
true
;
}
else
{
return
false
;
}
...
...
@@ -100,8 +106,8 @@ Graph SimplifyBatchNormInference(nnvm::Graph src) {
return
GraphTransform
(
src
,
transform
);
}
NNVM_REGISTER_PASS
(
Simplify
BatchNorm
Inference
)
.
set_body
(
Simplify
BatchNorm
Inference
);
NNVM_REGISTER_PASS
(
SimplifyInference
)
.
set_body
(
SimplifyInference
);
}
// namespace compiler
}
// namespace nnvm
nnvm/tests/python/compiler/test_simplify_
batchnorm
.py
→
nnvm/tests/python/compiler/test_simplify_
inference
.py
View file @
40bc10f3
...
...
@@ -30,12 +30,13 @@ def test_simplify_batchnorm():
for
i
in
range
(
nstep
):
y1
=
sym
.
batch_norm
(
y1
+
1
,
gamma
,
beta
,
moving_mean
,
moving_var
,
epsilon
=
eps
,
axis
=
axis
)
y1
=
sym
.
dropout
(
y1
)
y2
=
simple_bn
(
y2
+
1
,
gamma
,
beta
,
moving_mean
,
moving_var
,
epsilon
=
eps
,
axis
=
axis
,
shape
=
ishape
[
"x"
])
g
=
nnvm
.
graph
.
create
(
y1
)
g2
=
nnvm
.
graph
.
create
(
y2
)
graph_attr
.
set_shape_inputs
(
g
,
ishape
)
g1
=
g
.
apply
(
"InferShape"
)
.
apply
(
"Simplify
BatchNorm
Inference"
)
g1
=
g
.
apply
(
"InferShape"
)
.
apply
(
"SimplifyInference"
)
# Some prints for debug
# print(g1.ir())
# assert graph equals as expected
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment