Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
a53d8d01
Commit
a53d8d01
authored
Apr 01, 2018
by
Tianqi Chen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] Enhance scale fold axis (#424)
parent
89c124bc
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
389 additions
and
87 deletions
+389
-87
nnvm/src/compiler/fold_scale_axis.cc
+318
-77
nnvm/src/pass/plan_memory.cc
+1
-1
nnvm/tests/python/compiler/test_fold_axis.py
+70
-9
No files found.
nnvm/src/compiler/fold_scale_axis.cc
View file @
a53d8d01
...
...
@@ -18,12 +18,10 @@ namespace compiler {
enum
FoldScaleKind
{
// No folding is applied
kNone
,
// The folding decision is pending
// The folding decision is pending
, we can fold on a state.
kPending
,
// The original operator that contains the scale.
kProvider
,
// Pass through the scale to parent/child to the first axis.
kPassTroughFirst
,
// The final conumer of axis scale using multiply
// Likely be a conv or dense operator.
kMulConsumer
,
...
...
@@ -31,21 +29,23 @@ enum FoldScaleKind {
kDivConsumer
};
// Input fold information
struct
FoldScaleInput
{
uint32_t
index
;
int
axis
;
};
// The entry of folding chains on which
// we should perform folding on
struct
FoldChainEntry
{
struct
FoldChainInfo
{
// Entry kind
FoldScaleKind
kind
{
kNone
};
// The output axis to be folded
int
axis
{
0
};
// Source node in the fold chain
int
source
{
0
};
};
// The entry of folding chains on which
// we should perform folding on
struct
FoldChainEntry
{
// Fold information
FoldChainInfo
info
;
// Number of outgoing fork count
// in forward propagation.
int
fork_count
{
0
};
// Following field only used by provider.
// The input index
int
fold_input_index
{
1
};
...
...
@@ -55,12 +55,26 @@ struct FoldChainEntry {
// Try to pass axis scaling to backward,
// Given that we we know the status of current fold axis.
// return whether the forward signal is consumed.
using
FScaleAxisBackward
=
std
::
function
<
FoldScaleKind
(
const
NodeAttrs
&
attrs
,
int
axis
,
bool
(
const
NodeAttrs
&
attrs
,
const
std
::
vector
<
TShape
>&
in_shape
,
const
std
::
vector
<
TShape
>&
out_shape
,
const
FoldChainInfo
&
out_info
,
std
::
vector
<
FoldChainInfo
>*
in_info
)
>
;
// Try to pass axis scaling to forward,
// Given that we we know the status of one of its input to be pending
// also update other input info
// return whether the forward signal is consumed.
using
FScaleAxisForward
=
std
::
function
<
bool
(
const
NodeAttrs
&
attrs
,
const
std
::
vector
<
TShape
>&
in_shape
,
const
std
::
vector
<
TShape
>&
out_shape
,
std
::
vector
<
std
::
pair
<
uint32_t
,
int
>
>*
in_axis
)
>
;
std
::
vector
<
FoldChainInfo
>*
in_info
,
FoldChainInfo
*
out_info
)
>
;
// Detect if there is a scaling axis happening
bool
DetectScaleAxis
(
const
IndexedGraph
&
idx
,
...
...
@@ -99,15 +113,19 @@ bool DetectScaleAxis(const IndexedGraph& idx,
}
else
{
return
false
;
}
e
.
axis
=
axis
.
first
;
e
.
kind
=
kPending
;
e
.
source
=
nid
;
e
.
info
.
axis
=
axis
.
first
;
e
.
info
.
kind
=
kPending
;
e
.
info
.
source
=
nid
;
e
.
fork_count
=
1
;
// In the backward message passing
// We need to eagerly pass it to the input
// In the forward message passing
// we will "pull" the message from input.
if
(
!
is_forward
)
{
// pass message to another input
FoldChainEntry
&
enext
=
(
*
chain
)[
b
.
node_id
];
enext
.
axis
=
e
.
axis
;
enext
.
kind
=
kPending
;
enext
.
source
=
nid
;
enext
.
info
.
axis
=
e
.
info
.
axis
;
enext
.
info
.
kind
=
kPending
;
enext
.
info
.
source
=
nid
;
}
return
true
;
}
...
...
@@ -119,12 +137,16 @@ Graph FoldScaleAxis(Graph src) {
// Operator pattern
static
auto
&
fbackward
=
nnvm
::
Op
::
GetAttr
<
FScaleAxisBackward
>
(
"FScaleAxisBackward"
);
static
auto
&
fforward
=
nnvm
::
Op
::
GetAttr
<
FScaleAxisForward
>
(
"FScaleAxisForward"
);
const
IndexedGraph
&
idx
=
src
.
indexed_graph
();
const
ShapeVector
&
shape_vec
=
src
.
GetAttr
<
ShapeVector
>
(
"shape"
);
std
::
vector
<
uint32_t
>
ref_count
=
GetNodeRefCounts
(
idx
);
std
::
vector
<
FoldChainEntry
>
bwd_chain
(
idx
.
num_nodes
());
std
::
vector
<
FoldChainEntry
>
fwd_chain
(
idx
.
num_nodes
());
// shape hint for the inference.
std
::
vector
<
TShape
>
in_shape
,
out_shape
;
// perform backward folding.
for
(
uint32_t
i
=
idx
.
num_nodes
();
i
!=
0
;
--
i
)
{
uint32_t
nid
=
i
-
1
;
...
...
@@ -132,9 +154,10 @@ Graph FoldScaleAxis(Graph src) {
if
(
inode
.
source
->
is_variable
())
continue
;
if
(
DetectScaleAxis
(
idx
,
nid
,
shape_vec
,
ref_count
,
false
,
&
bwd_chain
))
continue
;
if
(
bwd_chain
[
nid
].
kind
!=
kPending
)
continue
;
if
(
bwd_chain
[
nid
].
info
.
kind
!=
kPending
)
continue
;
// if referred by multiple node, cannot do propagation
if
(
ref_count
[
nid
]
!=
1
||
!
fbackward
.
count
(
inode
.
source
->
op
()))
{
bwd_chain
[
nid
].
kind
=
kNone
;
continue
;
bwd_chain
[
nid
].
info
.
kind
=
kNone
;
continue
;
}
// get input shape and output shape.
in_shape
.
clear
();
out_shape
.
clear
();
...
...
@@ -144,58 +167,151 @@ Graph FoldScaleAxis(Graph src) {
for
(
uint32_t
i
=
0
;
i
<
inode
.
source
->
num_outputs
();
++
i
)
{
out_shape
.
push_back
(
shape_vec
[
idx
.
entry_id
(
nid
,
i
)]);
}
std
::
vector
<
std
::
pair
<
uint32_t
,
int
>
>
in_axis
;
FoldScaleKind
kind
=
fbackward
[
inode
.
source
->
op
()](
inode
.
source
->
attrs
,
bwd_chain
[
nid
].
axis
,
in_shape
,
out_shape
,
&
in_axis
);
bwd_chain
[
nid
].
kind
=
kind
;
if
(
kind
==
kNone
)
continue
;
CHECK_GE
(
in_axis
.
size
(),
1U
);
CHECK
(
kind
==
kPassTroughFirst
||
kind
==
kMulConsumer
);
std
::
vector
<
FoldChainInfo
>
in_info
(
in_shape
.
size
(),
FoldChainInfo
());
bool
consumed
=
fbackward
[
inode
.
source
->
op
()](
inode
.
source
->
attrs
,
in_shape
,
out_shape
,
bwd_chain
[
nid
].
info
,
&
in_info
);
CHECK_EQ
(
in_info
.
size
(),
in_shape
.
size
());
// propagate back.
bool
can_prop
=
true
;
for
(
size_t
i
=
0
;
i
<
in_
axis
.
size
();
++
i
)
{
const
IndexedGraph
::
NodeEntry
&
e
=
inode
.
inputs
[
i
n_axis
[
0
].
first
];
for
(
size_t
i
=
0
;
i
<
in_
info
.
size
();
++
i
)
{
const
IndexedGraph
::
NodeEntry
&
e
=
inode
.
inputs
[
i
];
if
(
ref_count
[
e
.
node_id
]
!=
1
||
idx
[
e
.
node_id
].
source
->
num_outputs
()
!=
1
)
{
can_prop
=
false
;
break
;
}
}
if
(
!
can_prop
)
continue
;
for
(
size_t
i
=
0
;
i
<
in_axis
.
size
();
++
i
)
{
const
IndexedGraph
::
NodeEntry
&
e
=
inode
.
inputs
[
in_axis
[
i
].
first
];
if
(
kind
==
kPassTroughFirst
&&
i
==
0
)
{
bwd_chain
[
e
.
node_id
].
kind
=
kPending
;
for
(
size_t
i
=
0
;
i
<
in_info
.
size
();
++
i
)
{
const
IndexedGraph
::
NodeEntry
&
e
=
inode
.
inputs
[
i
];
bwd_chain
[
e
.
node_id
].
info
=
in_info
[
i
];
}
// mark consumed by making the source as provider.
if
(
consumed
)
{
bwd_chain
[
bwd_chain
[
nid
].
info
.
source
].
info
.
kind
=
kProvider
;
}
}
// perform forward folding.
for
(
uint32_t
nid
=
0
;
nid
<
idx
.
num_nodes
();
++
nid
)
{
const
auto
&
inode
=
idx
[
nid
];
if
(
inode
.
source
->
is_variable
())
continue
;
// skip scales that are already folded in backward.
if
(
bwd_chain
[
nid
].
info
.
kind
==
kProvider
)
continue
;
if
(
DetectScaleAxis
(
idx
,
nid
,
shape_vec
,
ref_count
,
true
,
&
fwd_chain
))
continue
;
if
(
inode
.
source
->
num_outputs
()
!=
1
)
continue
;
// Do state update
// get input shape and output shape.
std
::
vector
<
FoldChainInfo
>
in_info
;
FoldChainInfo
out_info
;
int
num_inpending
=
0
;
in_shape
.
clear
();
out_shape
.
clear
();
for
(
const
IndexedGraph
::
NodeEntry
&
e
:
inode
.
inputs
)
{
in_shape
.
push_back
(
shape_vec
[
idx
.
entry_id
(
e
)]);
// input information
in_info
.
push_back
(
fwd_chain
[
e
.
node_id
].
info
);
if
(
fwd_chain
[
e
.
node_id
].
info
.
kind
==
kPending
)
{
++
num_inpending
;
}
}
for
(
uint32_t
i
=
0
;
i
<
inode
.
source
->
num_outputs
();
++
i
)
{
out_shape
.
push_back
(
shape_vec
[
idx
.
entry_id
(
nid
,
i
)]);
}
if
(
num_inpending
!=
1
||
!
fforward
.
count
(
inode
.
source
->
op
()))
continue
;
bool
consumed
=
fforward
[
inode
.
source
->
op
()](
inode
.
source
->
attrs
,
in_shape
,
out_shape
,
&
in_info
,
&
out_info
);
// update input info
for
(
size_t
i
=
0
;
i
<
in_info
.
size
();
++
i
)
{
fwd_chain
[
inode
.
inputs
[
i
].
node_id
].
info
=
in_info
[
i
];
}
if
(
consumed
)
{
fwd_chain
[
nid
].
info
=
out_info
;
for
(
size_t
i
=
0
;
i
<
in_info
.
size
();
++
i
)
{
if
(
in_info
[
i
].
kind
==
kPending
)
{
if
(
--
fwd_chain
[
in_info
[
i
].
source
].
fork_count
==
0
)
{
fwd_chain
[
in_info
[
i
].
source
].
info
.
kind
=
kProvider
;
}
}
}
}
else
{
bwd_chain
[
nid
].
kind
=
kNone
;
bwd_chain
[
e
.
node_id
].
kind
=
kMulConsumer
;
// can propagate condition
if
(
inode
.
source
->
num_outputs
()
==
1
)
{
fwd_chain
[
nid
].
info
=
out_info
;
if
(
out_info
.
kind
==
kPending
)
{
// When there is multiple reference to input
// every path have to be consumed
fwd_chain
[
out_info
.
source
].
fork_count
+=
ref_count
[
nid
]
-
1
;
}
bwd_chain
[
e
.
node_id
].
axis
=
in_axis
[
i
].
second
;
bwd_chain
[
e
.
node_id
].
source
=
bwd_chain
[
nid
].
source
;
}
if
(
kind
==
kMulConsumer
)
{
bwd_chain
[
bwd_chain
[
nid
].
source
].
kind
=
kProvider
;
}
}
auto
transform
=
[
&
](
uint32_t
nid
,
const
NodePtr
&
n
,
std
::
vector
<
NodeEntry
>*
ret
)
{
NodeEntry
rvalue
=
NodeEntry
{
n
,
0
,
0
};
{
// Backward chain
const
FoldChainEntry
&
e
=
bwd_chain
[
nid
];
if
(
e
.
kind
==
kMulConsumer
&&
bwd_chain
[
e
.
source
].
kind
==
kProvider
)
{
const
FoldChainEntry
&
se
=
bwd_chain
[
e
.
source
];
if
(
e
.
info
.
kind
==
kMulConsumer
&&
bwd_chain
[
e
.
info
.
source
].
info
.
kind
==
kProvider
)
{
const
FoldChainEntry
&
se
=
bwd_chain
[
e
.
info
.
source
];
CHECK_EQ
(
n
->
num_outputs
(),
1
);
NodeEntry
scale
=
ExpandBiasToMatchAxis
(
se
.
scale_entry
,
shape_vec
[
idx
.
entry_id
(
nid
,
0
)].
ndim
(),
shape_vec
[
idx
.
entry_id
(
se
.
scale_entry
)].
ndim
(),
e
.
axis
);
*
ret
=
{
MakeNode
(
"broadcast_mul"
,
n
->
attrs
.
name
+
"_sc"
,
{
NodeEntry
{
n
,
0
,
0
},
scale
})};
return
true
;
}
else
if
(
e
.
kind
==
kProvider
)
{
*
ret
=
{
n
->
inputs
[
e
.
fold_input_index
]};
return
true
;
}
else
{
e
.
info
.
axis
);
rvalue
=
MakeNode
(
"broadcast_mul"
,
n
->
attrs
.
name
+
"_sc"
,
{
rvalue
,
scale
});
}
else
if
(
e
.
info
.
kind
==
kProvider
)
{
rvalue
=
n
->
inputs
[
e
.
fold_input_index
];
}
}
// Note that the value might get transformed twice if it
// folds value from both fwd and backward chain.
{
// forward chain
const
FoldChainEntry
&
e
=
fwd_chain
[
nid
];
if
(
e
.
info
.
kind
==
kMulConsumer
&&
fwd_chain
[
e
.
info
.
source
].
info
.
kind
==
kProvider
)
{
const
FoldChainEntry
&
se
=
fwd_chain
[
e
.
info
.
source
];
CHECK_EQ
(
n
->
num_outputs
(),
1
);
NodeEntry
scale
=
ExpandBiasToMatchAxis
(
se
.
scale_entry
,
shape_vec
[
idx
.
entry_id
(
nid
,
0
)].
ndim
(),
shape_vec
[
idx
.
entry_id
(
se
.
scale_entry
)].
ndim
(),
e
.
info
.
axis
);
rvalue
=
MakeNode
(
"broadcast_mul"
,
n
->
attrs
.
name
+
"_sc"
,
{
rvalue
,
scale
});
}
else
if
(
e
.
info
.
kind
==
kDivConsumer
&&
fwd_chain
[
e
.
info
.
source
].
info
.
kind
==
kProvider
)
{
const
FoldChainEntry
&
se
=
fwd_chain
[
e
.
info
.
source
];
CHECK_EQ
(
n
->
num_outputs
(),
1
);
NodeEntry
scale
=
ExpandBiasToMatchAxis
(
se
.
scale_entry
,
shape_vec
[
idx
.
entry_id
(
nid
,
0
)].
ndim
(),
shape_vec
[
idx
.
entry_id
(
se
.
scale_entry
)].
ndim
(),
e
.
info
.
axis
);
rvalue
=
MakeNode
(
"broadcast_div"
,
n
->
attrs
.
name
+
"_sc"
,
{
rvalue
,
scale
});
}
else
if
(
e
.
info
.
kind
==
kProvider
)
{
rvalue
=
n
->
inputs
[
e
.
fold_input_index
];
}
}
if
(
rvalue
.
node
==
n
)
{
return
false
;
}
else
{
*
ret
=
{
rvalue
};
return
true
;
}
};
return
GraphTransform
(
src
,
transform
);
...
...
@@ -205,14 +321,24 @@ NNVM_REGISTER_PASS(FoldScaleAxis)
.
set_body
(
FoldScaleAxis
);
// property registration.
FoldScaleKind
ReluScaleAxisBackward
(
bool
ReluScaleAxisBackward
(
const
NodeAttrs
&
attrs
,
const
std
::
vector
<
TShape
>&
in_shape
,
const
std
::
vector
<
TShape
>&
out_shape
,
const
FoldChainInfo
&
out_info
,
std
::
vector
<
FoldChainInfo
>*
in_axis
)
{
(
*
in_axis
)[
0
]
=
out_info
;
return
false
;
}
bool
ReluScaleAxisForward
(
const
NodeAttrs
&
attrs
,
int
axis
,
const
std
::
vector
<
TShape
>&
in_shape
,
const
std
::
vector
<
TShape
>&
out_shape
,
std
::
vector
<
std
::
pair
<
uint32_t
,
int
>
>*
in_axis
)
{
in_axis
->
emplace_back
(
0
,
axis
);
return
kPassTroughFirst
;
std
::
vector
<
FoldChainInfo
>*
in_info
,
FoldChainInfo
*
out_info
)
{
*
out_info
=
(
*
in_info
)[
0
];
return
false
;
}
NNVM_REGISTER_OP
(
relu
)
...
...
@@ -221,21 +347,102 @@ NNVM_REGISTER_OP(relu)
NNVM_REGISTER_OP
(
leaky_relu
)
.
set_attr
<
FScaleAxisBackward
>
(
"FScaleAxisBackward"
,
ReluScaleAxisBackward
);
FoldScaleKind
BroadcastAddSubScaleAxisBackward
(
NNVM_REGISTER_OP
(
relu
)
.
set_attr
<
FScaleAxisForward
>
(
"FScaleAxisForward"
,
ReluScaleAxisForward
);
NNVM_REGISTER_OP
(
leaky_relu
)
.
set_attr
<
FScaleAxisForward
>
(
"FScaleAxisForward"
,
ReluScaleAxisForward
);
// property registration.
bool
Pool2DBackward
(
const
NodeAttrs
&
attrs
,
const
std
::
vector
<
TShape
>&
in_shape
,
const
std
::
vector
<
TShape
>&
out_shape
,
const
FoldChainInfo
&
out_info
,
std
::
vector
<
FoldChainInfo
>*
in_axis
)
{
using
top
::
Pool2DParam
;
const
Pool2DParam
&
param
=
nnvm
::
get
<
Pool2DParam
>
(
attrs
.
parsed
);
if
(
out_info
.
axis
==
1
&&
param
.
layout
==
top
::
kNCHW
)
{
(
*
in_axis
)[
0
]
=
out_info
;
}
return
false
;
}
bool
Pool2DForward
(
const
NodeAttrs
&
attrs
,
int
axis
,
const
std
::
vector
<
TShape
>&
in_shape
,
const
std
::
vector
<
TShape
>&
out_shape
,
std
::
vector
<
std
::
pair
<
uint32_t
,
int
>
>*
in_axis
)
{
std
::
vector
<
FoldChainInfo
>*
in_info
,
FoldChainInfo
*
out_info
)
{
using
top
::
Pool2DParam
;
const
Pool2DParam
&
param
=
nnvm
::
get
<
Pool2DParam
>
(
attrs
.
parsed
);
if
((
*
in_info
)[
0
].
axis
==
1
&&
param
.
layout
==
top
::
kNCHW
)
{
*
out_info
=
(
*
in_info
)[
0
];
}
return
false
;
}
NNVM_REGISTER_OP
(
max_pool2d
)
.
set_attr
<
FScaleAxisBackward
>
(
"FScaleAxisBackward"
,
Pool2DBackward
);
NNVM_REGISTER_OP
(
avg_pool2d
)
.
set_attr
<
FScaleAxisBackward
>
(
"FScaleAxisBackward"
,
Pool2DBackward
);
NNVM_REGISTER_OP
(
max_pool2d
)
.
set_attr
<
FScaleAxisForward
>
(
"FScaleAxisForward"
,
Pool2DForward
);
NNVM_REGISTER_OP
(
avg_pool2d
)
.
set_attr
<
FScaleAxisForward
>
(
"FScaleAxisForward"
,
Pool2DForward
);
bool
BroadcastAddSubScaleAxisBackward
(
const
NodeAttrs
&
attrs
,
const
std
::
vector
<
TShape
>&
in_shape
,
const
std
::
vector
<
TShape
>&
out_shape
,
const
FoldChainInfo
&
out_info
,
std
::
vector
<
FoldChainInfo
>*
in_axis
)
{
if
(
out_info
.
kind
!=
kPending
)
return
false
;
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
std
::
pair
<
int
,
int
>
m
=
MatchBroadcast1DAxis
(
out_shape
[
0
],
in_shape
[
i
]);
if
(
m
.
second
!=
-
1
&&
in_shape
[
1
-
i
]
==
out_shape
[
0
])
{
in_axis
->
emplace_back
(
i
,
axis
);
in_axis
->
emplace_back
(
1
-
i
,
m
.
second
);
return
kPassTroughFirst
;
std
::
pair
<
int
,
int
>
m
=
MatchBroadcast1DAxis
(
out_shape
[
0
],
in_shape
[
1
-
i
]);
if
(
m
.
second
!=
-
1
&&
in_shape
[
i
]
==
out_shape
[
0
]
&&
m
.
first
==
out_info
.
axis
)
{
(
*
in_axis
)[
i
].
kind
=
kPending
;
(
*
in_axis
)[
i
].
axis
=
out_info
.
axis
;
(
*
in_axis
)[
i
].
source
=
out_info
.
source
;
(
*
in_axis
)[
1
-
i
].
kind
=
kMulConsumer
;
(
*
in_axis
)[
1
-
i
].
axis
=
m
.
second
;
(
*
in_axis
)[
1
-
i
].
source
=
out_info
.
source
;
return
false
;
}
}
return
kNone
;
return
false
;
}
bool
BroadcastAddSubScaleAxisForward
(
const
NodeAttrs
&
attrs
,
const
std
::
vector
<
TShape
>&
in_shape
,
const
std
::
vector
<
TShape
>&
out_shape
,
std
::
vector
<
FoldChainInfo
>*
in_info
,
FoldChainInfo
*
out_info
)
{
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
if
((
*
in_info
)[
i
].
kind
==
kPending
)
{
std
::
pair
<
int
,
int
>
m
=
MatchBroadcast1DAxis
(
out_shape
[
0
],
in_shape
[
1
-
i
]);
if
(
m
.
second
!=
-
1
&&
in_shape
[
i
]
==
out_shape
[
0
]
&&
m
.
first
==
(
*
in_info
)[
i
].
axis
)
{
out_info
->
kind
=
kPending
;
out_info
->
axis
=
m
.
first
;
out_info
->
source
=
(
*
in_info
)[
i
].
source
;
(
*
in_info
)[
1
-
i
].
kind
=
kDivConsumer
;
(
*
in_info
)[
1
-
i
].
axis
=
m
.
second
;
(
*
in_info
)[
1
-
i
].
source
=
(
*
in_info
)[
i
].
source
;
return
false
;
}
}
}
return
false
;
}
NNVM_REGISTER_OP
(
broadcast_add
)
...
...
@@ -244,28 +451,62 @@ NNVM_REGISTER_OP(broadcast_add)
NNVM_REGISTER_OP
(
broadcast_sub
)
.
set_attr
<
FScaleAxisBackward
>
(
"FScaleAxisBackward"
,
BroadcastAddSubScaleAxisBackward
);
FoldScaleKind
Conv2DScaleAxisBackward
(
NNVM_REGISTER_OP
(
broadcast_add
)
.
set_attr
<
FScaleAxisForward
>
(
"FScaleAxisForward"
,
BroadcastAddSubScaleAxisForward
);
NNVM_REGISTER_OP
(
broadcast_sub
)
.
set_attr
<
FScaleAxisForward
>
(
"FScaleAxisForward"
,
BroadcastAddSubScaleAxisForward
);
bool
Conv2DScaleAxisBackward
(
const
NodeAttrs
&
attrs
,
int
axis
,
const
std
::
vector
<
TShape
>&
in_shape
,
const
std
::
vector
<
TShape
>&
out_shape
,
std
::
vector
<
std
::
pair
<
uint32_t
,
int
>
>*
in_axis
)
{
const
FoldChainInfo
&
out_info
,
std
::
vector
<
FoldChainInfo
>*
in_axis
)
{
using
top
::
Conv2DParam
;
const
Conv2DParam
&
param
=
nnvm
::
get
<
Conv2DParam
>
(
attrs
.
parsed
);
if
(
out_info
.
kind
!=
kPending
)
return
false
;
// only optimize for nchw for now
if
(
param
.
layout
==
top
::
kNCHW
)
{
in_axis
->
emplace_back
(
1
,
0
);
if
(
param
.
layout
==
top
::
kNCHW
&&
out_info
.
axis
==
1
)
{
(
*
in_axis
)[
1
].
kind
=
kMulConsumer
;
(
*
in_axis
)[
1
].
axis
=
0
;
(
*
in_axis
)[
1
].
source
=
out_info
.
source
;
if
(
param
.
use_bias
)
{
in_axis
->
emplace_back
(
2
,
0
);
(
*
in_axis
)[
2
].
kind
=
kMulConsumer
;
(
*
in_axis
)[
2
].
axis
=
0
;
(
*
in_axis
)[
2
].
source
=
out_info
.
source
;
}
return
true
;
}
else
{
return
false
;
}
return
kMulConsumer
;
}
bool
Conv2DScaleAxisForward
(
const
NodeAttrs
&
attrs
,
const
std
::
vector
<
TShape
>&
in_shape
,
const
std
::
vector
<
TShape
>&
out_shape
,
std
::
vector
<
FoldChainInfo
>*
in_info
,
FoldChainInfo
*
out_info
)
{
using
top
::
Conv2DParam
;
const
Conv2DParam
&
param
=
nnvm
::
get
<
Conv2DParam
>
(
attrs
.
parsed
);
if
((
*
in_info
)[
0
].
kind
!=
kPending
)
return
false
;
// only optimize for nchw for now
if
(
param
.
layout
==
top
::
kNCHW
&&
(
*
in_info
)[
0
].
axis
==
1
)
{
(
*
in_info
)[
1
].
kind
=
kMulConsumer
;
(
*
in_info
)[
1
].
axis
=
1
;
(
*
in_info
)[
1
].
source
=
(
*
in_info
)[
0
].
source
;
return
true
;
}
else
{
return
kNon
e
;
return
fals
e
;
}
}
NNVM_REGISTER_OP
(
conv2d
)
.
set_attr
<
FScaleAxisBackward
>
(
"FScaleAxisBackward"
,
Conv2DScaleAxisBackward
);
NNVM_REGISTER_OP
(
conv2d
)
.
set_attr
<
FScaleAxisForward
>
(
"FScaleAxisForward"
,
Conv2DScaleAxisForward
);
}
// namespace compiler
}
// namespace nnvm
nnvm/src/pass/plan_memory.cc
View file @
a53d8d01
...
...
@@ -196,7 +196,7 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
if
(
taken
[
kv
.
first
]
==
false
&&
sid_out
==
GraphAllocator
::
kBadStorageID
&&
sid_in
>=
0
&&
(
storage_ref_count
[
sid_in
]
==
1
&&
!
ignore_all_inputs
||
identity
[
ipair
])
&&
(
(
storage_ref_count
[
sid_in
]
==
1
&&
!
ignore_all_inputs
)
||
identity
[
ipair
])
&&
entry_ref_count
[
eid_out
]
>
0
&&
shape_vec
[
eid_out
].
Size
()
==
shape_vec
[
eid_in
].
Size
()
&&
dtype_vec
[
eid_out
]
==
dtype_vec
[
eid_in
])
{
...
...
nnvm/tests/python/compiler/test_fold_axis.py
View file @
a53d8d01
"""Unittest cases for fold_axis"""
import
nnvm
import
nnvm.testing.resnet
import
numpy
as
np
from
nnvm
import
symbol
as
sym
from
nnvm.compiler
import
graph_util
,
graph_attr
def
test_fold_axis_conv
():
def
before
(
x
,
conv_weight
,
conv_bias
,
scale
,
channels
):
def
before
(
x
,
conv_weight
,
conv_bias
,
in_scale
,
out_scale
,
channels
):
x
=
x
*
sym
.
expand_dims
(
in_scale
,
axis
=
1
,
num_newaxis
=
2
)
y
=
sym
.
conv2d
(
x
,
conv_weight
,
conv_bias
,
channels
=
channels
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
),
name
=
"conv"
)
y
=
sym
.
relu
(
y
)
y
=
y
*
sym
.
expand_dims
(
scale
,
axis
=
1
,
num_newaxis
=
2
)
y
=
y
*
sym
.
expand_dims
(
out_
scale
,
axis
=
1
,
num_newaxis
=
2
)
return
y
def
expected
(
x
,
conv_weight
,
conv_bias
,
scale
,
channels
):
conv_weight
=
conv_weight
*
sym
.
expand_dims
(
scale
,
axis
=
1
,
num_newaxis
=
3
)
conv_bias
=
conv_bias
*
scale
def
expected
(
x
,
conv_weight
,
conv_bias
,
in_scale
,
out_scale
,
channels
):
conv_weight
=
conv_weight
*
sym
.
expand_dims
(
out_scale
,
axis
=
1
,
num_newaxis
=
3
)
conv_weight
=
conv_weight
*
sym
.
expand_dims
(
in_scale
,
axis
=
1
,
num_newaxis
=
2
)
conv_bias
=
conv_bias
*
out_scale
y
=
sym
.
conv2d
(
x
,
conv_weight
,
conv_bias
,
...
...
@@ -32,10 +36,11 @@ def test_fold_axis_conv():
x
=
sym
.
Variable
(
"x"
)
+
1
weight
=
sym
.
Variable
(
"weight"
)
bias
=
sym
.
Variable
(
"bias"
)
scale
=
sym
.
Variable
(
"scale"
)
y1
=
before
(
x
,
weight
,
bias
,
scale
,
channels
)
y2
=
expected
(
x
,
weight
,
bias
,
scale
,
channels
)
ishape
=
{
"x"
:
shape
,
"scale"
:
(
channels
,)}
in_scale
=
sym
.
Variable
(
"in_scale"
)
out_scale
=
sym
.
Variable
(
"out_scale"
)
y1
=
before
(
x
,
weight
,
bias
,
in_scale
,
out_scale
,
channels
)
y2
=
expected
(
x
,
weight
,
bias
,
in_scale
,
out_scale
,
channels
)
ishape
=
{
"x"
:
shape
,
"out_scale"
:
(
channels
,),
"in_scale"
:
(
shape
[
1
],)}
g1
=
nnvm
.
graph
.
create
(
y1
)
g2
=
nnvm
.
graph
.
create
(
y2
)
graph_attr
.
set_shape_inputs
(
g1
,
ishape
)
...
...
@@ -45,5 +50,61 @@ def test_fold_axis_conv():
check
((
2
,
4
,
10
,
10
),
2
)
def
test_fold_fail
():
def
before
(
x
,
scale
,
channels
):
y
=
sym
.
conv2d
(
x
,
channels
=
channels
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
),
name
=
"conv"
)
y
=
y
*
sym
.
expand_dims
(
scale
,
axis
=
1
,
num_newaxis
=
1
)
return
y
# Before simplify
def
check
(
shape
,
channels
):
x
=
sym
.
Variable
(
"x"
)
bias
=
sym
.
Variable
(
"bias"
)
scale
=
sym
.
Variable
(
"scale"
)
y1
=
before
(
x
,
scale
,
channels
)
ishape
=
{
"x"
:
shape
,
"scale"
:
(
channels
,),
"bias"
:
(
channels
,)}
g1
=
nnvm
.
graph
.
create
(
y1
)
graph_attr
.
set_shape_inputs
(
g1
,
ishape
)
g2
=
g1
.
apply
(
"InferShape"
)
.
apply
(
"FoldScaleAxis"
)
# assert graph equals as expected
graph_util
.
check_graph_equal
(
g1
,
g2
)
check
((
2
,
10
,
10
,
10
),
10
)
def
test_fold_resnet
():
batch_size
=
1
num_classes
=
1000
image_shape
=
(
3
,
224
,
224
)
data_shape
=
(
batch_size
,)
+
image_shape
net
,
params
=
nnvm
.
testing
.
resnet
.
get_workload
(
batch_size
=
1
,
image_shape
=
image_shape
)
ishape
=
{
"data"
:
data_shape
}
graph
=
nnvm
.
graph
.
create
(
net
)
data
=
np
.
random
.
uniform
(
size
=
data_shape
)
.
astype
(
"float32"
)
# Initial pass do shape type inference
shape
,
_
=
graph_util
.
infer_shape
(
graph
,
**
ishape
)
ishape
.
update
(
zip
(
graph
.
index
.
input_names
,
shape
))
def
run_prune
(
graph
,
params
,
opt_level
):
# Apply optimization
with
nnvm
.
compiler
.
build_config
(
opt_level
=
0
):
graph
=
nnvm
.
compiler
.
optimize
(
graph
,
ishape
)
graph
,
params
=
nnvm
.
compiler
.
build_module
.
precompute_prune
(
graph
,
params
)
params
[
"data"
]
=
data
return
nnvm
.
compiler
.
build_module
.
_run_graph
(
graph
,
params
)
x
=
run_prune
(
graph
,
params
,
0
)
y
=
run_prune
(
graph
,
params
,
3
)
np
.
testing
.
assert_allclose
(
y
[
0
]
.
asnumpy
(),
x
[
0
]
.
asnumpy
())
if
__name__
==
"__main__"
:
test_fold_resnet
()
test_fold_axis_conv
()
test_fold_fail
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment