Commit cf3f5bce
Authored Dec 02, 2018 by Wuwei Lin; committed by Tianqi Chen, Dec 01, 2018
[RELAY][PASS] Memorize FoldScaleAxis backward transform result (#2214)
Parent: 1a9df7be
Showing 2 changed files with 96 additions and 8 deletions (+96 -8):

    src/relay/pass/fold_scale_axis.cc                 +20 -8
    tests/python/relay/test_pass_fold_scale_axis.py   +76 -0
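In plain terms: the backward FoldScaleAxis pass rewrites expressions recursively, and before this patch its Transform results were not cached; the old comment justified that by assuming each expression is referred to only once. With shared subexpressions (one producer feeding several consumers) that assumption breaks. Below is a minimal sketch of the caching pattern the patch introduces, in plain Python with hypothetical Node/transform names, not the TVM classes:

    # Minimal sketch of the memoization pattern added to BackwardTransformerNode.
    # `Node` and `transform` are hypothetical stand-ins, not TVM API.
    class Node:
        def __init__(self, op, *args):
            self.op = op
            self.args = args

    memo = {}  # maps a visited node to its rewritten result, like memo_ in the pass

    def transform(node):
        if id(node) in memo:           # a previous consumer already rewrote this
            return memo[id(node)]      # shared subexpression: reuse the result
        new_node = Node(node.op, *(transform(a) for a in node.args))
        memo[id(node)] = new_node
        return new_node

Without the cache, a diamond (one node with two consumers, as in the new test_fold_bwd_dual_consumer below) is transformed once per path, and nested diamonds make the work grow with the number of paths; the patch consults memo_ both in NormalCallTransform and in the op-specific FScaleAxisBackwardTransform path, so each call node is rewritten at most once.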
src/relay/pass/fold_scale_axis.cc
@@ -556,9 +556,7 @@ class BackwardTransformerNode :
    * \return The result of transformation.
    */
   Expr Transform(const Expr& expr, AxesSet axes, Expr scale) {
-    // NOTE: the result of Transform is not memoized.
-    // However, in the current rule, Transform will
-    // only be called to expr that is referred once.
+    // NOTE: the result of Transform is memoized.
     if (const CallNode* call_node = expr.as<CallNode>()) {
       return Transform(call_node, axes, scale);
     } else {
@@ -572,7 +570,14 @@ class BackwardTransformerNode :
    * \return the result of the call Mutation.
    */
   Expr NormalCallTransform(const CallNode* call_node) {
-    return ExprMutator::VisitExpr_(call_node);
+    const Call call = GetRef<Call>(call_node);
+    const auto it = memo_.find(call);
+    if (it != memo_.end()) {
+      return it->second;
+    }
+    Expr new_expr = ExprMutator::VisitExpr_(call_node);
+    memo_[call] = new_expr;
+    return new_expr;
   }
   /*!
    * \brief Get the expected axes on expr.
@@ -620,10 +625,17 @@ Expr BackwardTransformerNode::Transform(
       Op::GetAttr<FBackwardTransform>("FScaleAxisBackwardTransform");
   auto f = ftransform.get(call_node->op, nullptr);
   if (f != nullptr) {
-    return f(GetRef<Call>(call_node),
-             axes,
-             scale,
-             GetRef<BackwardTransformer>(this));
+    const Call call = GetRef<Call>(call_node);
+    const auto it = memo_.find(call);
+    if (it != memo_.end()) {
+      return it->second;
+    }
+    Expr new_expr = f(GetRef<Call>(call_node),
+                      axes,
+                      scale,
+                      GetRef<BackwardTransformer>(this));
+    memo_[call] = new_expr;
+    return new_expr;
   } else {
     CHECK(!axes.defined()) << "outstanding scale";
     return NormalCallTransform(call_node);
tests/python/relay/test_pass_fold_scale_axis.py
@@ -268,6 +268,81 @@ def test_fold_bwd_dual_path():
     check((2, 4, 10, 10), 8)
 
 
+def test_fold_bwd_dual_consumer():
+    def before(x, conv_weight, out_bias, out_scale, channels):
+        args = [x, conv_weight, out_bias, out_scale]
+        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        y0 = relay.nn.conv2d(x, conv_weight,
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        y0 = relay.multiply(y0, out_scale)
+        y0 = relay.nn.relu(y0)
+
+        y1 = relay.nn.conv2d(y0, conv_weight,
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        y1 = relay.multiply(y1, out_scale)
+        y1 = relay.nn.relu(y1)
+
+        y2 = relay.nn.conv2d(y0, conv_weight,
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        y2 = relay.multiply(y2, out_scale)
+        y2 = relay.nn.relu(y2)
+
+        y = relay.add(y1, y2)
+        return relay.Function(args, y)
+
+    def expected(x, conv_weight, out_bias, out_scale, channels):
+        # use a fixed order of args so alpha equal check can pass
+        args = [x, conv_weight, out_bias, out_scale]
+        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+
+        def fold_conv_weight():
+            squeezed_scale = relay.squeeze(out_scale, axis=[1, 2])
+            return relay.multiply(
+                conv_weight,
+                relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
+        y0 = relay.nn.conv2d(x, fold_conv_weight(),
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        y0 = relay.nn.relu(y0)
+        y1 = relay.nn.conv2d(y0, fold_conv_weight(),
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        y1 = relay.nn.relu(y1)
+        y2 = relay.nn.conv2d(y0, fold_conv_weight(),
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        y2 = relay.nn.relu(y2)
+        y = relay.add(y1, y2)
+        return relay.Function(args, y)
+
+    def check(shape, channels):
+        x = relay.var("x", shape=shape)
+        in_channels = shape[1]
+        weight = relay.var("weight")
+        out_bias = relay.var("out_bias", shape=(channels,))
+        out_scale = relay.var("out_scale", shape=(channels,))
+
+        y1 = before(x, weight, out_bias, out_scale, channels)
+        y1 = relay.ir_pass.infer_type(y1)
+        type_dict = {x.name_hint: x.checked_type for x in y1.params}
+        weight = relay.var("weight", type_dict["weight"])
+        y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
+        y1_expected = expected(x, weight, out_bias, out_scale, channels)
+
+        y1_folded = relay.ir_pass.infer_type(y1_folded)
+        y1_expected = relay.ir_pass.infer_type(y1_expected)
+        assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
+
+    check((2, 4, 10, 10), 4)
+
+
 def test_fold_bwd_fail():
     """Dual path testcase."""
     def fail1(x, conv_weight, out_bias, out_scale, channels):
@@ -327,4 +402,5 @@ if __name__ == "__main__":
     test_fold_fwd_fail()
     test_fold_bwd_simple()
     test_fold_bwd_dual_path()
+    test_fold_bwd_dual_consumer()
     test_fold_bwd_fail()
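For quick orientation, here is the essence of the new test condensed into a few lines (a sketch only, using the same pre-0.5 relay.ir_pass API that this test file targets): y0 is scaled and consumed by two downstream convolutions, so the pass must fold out_scale into all three conv weights while transforming the shared y0 subtree only once.

    from tvm import relay

    # Condensed repro of the dual-consumer shape exercised above (sketch).
    x = relay.var("x", shape=(2, 4, 10, 10))
    w = relay.var("weight")
    s = relay.var("out_scale", shape=(4,))
    s4 = relay.expand_dims(s, axis=1, num_newaxis=2)

    def scaled_conv(data):
        y = relay.nn.conv2d(data, w, channels=4,
                            kernel_size=(3, 3), padding=(1, 1))
        return relay.nn.relu(relay.multiply(y, s4))

    y0 = scaled_conv(x)                              # shared producer
    y = relay.add(scaled_conv(y0), scaled_conv(y0))  # two consumers of y0
    f = relay.ir_pass.infer_type(relay.Function([x, w, s], y))
    folded = relay.ir_pass.backward_fold_scale_axis(f)  # rewrites y0 once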