Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
d5103bbc
Unverified
Commit
d5103bbc
authored
Oct 29, 2018
by
Tianqi Chen
Committed by
GitHub
Oct 29, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][PASS] FoldScaleAxis Backward (#2024)
parent
25e4dc51
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
667 additions
and
35 deletions
+667
-35
include/tvm/relay/expr_functor.h
+3
-3
python/tvm/relay/ir_pass.py
+29
-0
src/relay/ir/expr_functor.cc
+8
-4
src/relay/pass/fold_scale_axis.cc
+433
-22
src/relay/pass/pattern_util.h
+21
-2
tests/python/relay/test_pass_fold_scale_axis.py
+173
-4
No files found.
include/tvm/relay/expr_functor.h
View file @
d5103bbc
...
@@ -135,9 +135,9 @@ class ExprVisitor
...
@@ -135,9 +135,9 @@ class ExprVisitor
void
VisitExpr_
(
const
TupleGetItemNode
*
op
)
override
;
void
VisitExpr_
(
const
TupleGetItemNode
*
op
)
override
;
virtual
void
VisitType
(
const
Type
&
t
);
virtual
void
VisitType
(
const
Type
&
t
);
pr
ivate
:
pr
otected
:
//
internal visited flag.
//
Internal visiting counter
std
::
unordered_
set
<
const
Node
*>
visited
_
;
std
::
unordered_
map
<
const
Node
*
,
size_t
>
visit_counter
_
;
};
};
/*!
/*!
...
...
python/tvm/relay/ir_pass.py
View file @
d5103bbc
...
@@ -31,6 +31,29 @@ def infer_type(expr, env=None):
...
@@ -31,6 +31,29 @@ def infer_type(expr, env=None):
return
_ir_pass
.
infer_type
(
expr
,
env
)
return
_ir_pass
.
infer_type
(
expr
,
env
)
def
backward_fold_scale_axis
(
expr
):
"""Backward fold axis scaling into weights of conv2d/dense.
Parameters
----------
expr : tvm.relay.Expr
The input expression, we expect that expr's types
should be fully inferred by infer_type.
Returns
-------
folded_expr : tvm.relay.Expr
The folded expression after transformation.
Note
----
It is recommended to call backward_fold_scale_axis
before using forward_fold_scale_axis.
As backward folding targets common conv-bn pattern.
"""
return
_ir_pass
.
backward_fold_scale_axis
(
expr
)
def
forward_fold_scale_axis
(
expr
):
def
forward_fold_scale_axis
(
expr
):
"""Fold the scaling of axis into weights of conv2d/dense.
"""Fold the scaling of axis into weights of conv2d/dense.
...
@@ -44,6 +67,12 @@ def forward_fold_scale_axis(expr):
...
@@ -44,6 +67,12 @@ def forward_fold_scale_axis(expr):
-------
-------
folded_expr : tvm.relay.Expr
folded_expr : tvm.relay.Expr
The folded expression after transformation.
The folded expression after transformation.
Note
----
It is recommended to call backward_fold_scale_axis
before using forward_fold_scale_axis.
As backward folding targets common conv-bn pattern.
"""
"""
return
_ir_pass
.
forward_fold_scale_axis
(
expr
)
return
_ir_pass
.
forward_fold_scale_axis
(
expr
)
...
...
src/relay/ir/expr_functor.cc
View file @
d5103bbc
...
@@ -160,10 +160,14 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
...
@@ -160,10 +160,14 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
Type
ExprMutator
::
VisitType
(
const
Type
&
t
)
{
return
t
;
}
Type
ExprMutator
::
VisitType
(
const
Type
&
t
)
{
return
t
;
}
void
ExprVisitor
::
VisitExpr
(
const
Expr
&
expr
)
{
void
ExprVisitor
::
VisitExpr
(
const
Expr
&
expr
)
{
if
(
visited_
.
count
(
expr
.
get
()))
return
;
auto
it
=
visit_counter_
.
find
(
expr
.
get
());
using
TParent
=
ExprFunctor
<
void
(
const
Expr
&
)
>
;
if
(
it
!=
visit_counter_
.
end
())
{
TParent
::
VisitExpr
(
expr
);
++
it
->
second
;
visited_
.
insert
(
expr
.
get
());
}
else
{
using
TParent
=
ExprFunctor
<
void
(
const
Expr
&
)
>
;
TParent
::
VisitExpr
(
expr
);
visit_counter_
.
insert
({
expr
.
get
(),
1
});
}
}
}
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
VarNode
*
op
)
{
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
VarNode
*
op
)
{
...
...
src/relay/pass/fold_scale_axis.cc
View file @
d5103bbc
...
@@ -24,9 +24,9 @@ namespace fold_scale_axis {
...
@@ -24,9 +24,9 @@ namespace fold_scale_axis {
using
runtime
::
TypedPackedFunc
;
using
runtime
::
TypedPackedFunc
;
// FoldScaleAxis
Foward
algorithm:
// FoldScaleAxis algorithm:
//
//
// The general idea is t
hat we
transform Expr to tuple of
// The general idea is t
o
transform Expr to tuple of
// (value, axes, scale), where the final result satiesfies:
// (value, axes, scale), where the final result satiesfies:
//
//
// result = value
// result = value
...
@@ -41,9 +41,14 @@ using runtime::TypedPackedFunc;
...
@@ -41,9 +41,14 @@ using runtime::TypedPackedFunc;
// we run a backward "preparation phase", which propagates the demand
// we run a backward "preparation phase", which propagates the demand
// of the potential axes scaling back to its input.
// of the potential axes scaling back to its input.
//
//
//
The
folding process is done in two steps:
//
Forward
folding process is done in two steps:
// - Prepare phase: backward propagation of demand.
// - Prepare phase: backward propagation of demand.
// - Transform phase: forward transformation,
// - Transform phase: forward transformation,
//
// Similarly, backward folding process is done in two steps:
// - Prepare phase: forward propagation of demand.
// - Transform phase: transformation by push down the axes scale signal to inputs.
//
/*!
/*!
* \brief sorted array axis, can also be nullptr.
* \brief sorted array axis, can also be nullptr.
...
@@ -99,7 +104,7 @@ ValueType GetFunc(const OpMap<ValueType>& op_map,
...
@@ -99,7 +104,7 @@ ValueType GetFunc(const OpMap<ValueType>& op_map,
}
}
/*!
/*!
* \brief Preparation function for
for
pass scale forward.
* \brief Preparation function for pass scale forward.
* \param call The call node.
* \param call The call node.
* \param out_scale_axes Possible scaling on axes of the output.
* \param out_scale_axes Possible scaling on axes of the output.
* \return The result scaling on axes of the input.
* \return The result scaling on axes of the input.
...
@@ -144,7 +149,7 @@ using FForwardTransform = TypedPackedFunc<
...
@@ -144,7 +149,7 @@ using FForwardTransform = TypedPackedFunc<
//----------------------------------------------
//----------------------------------------------
// Generic Visitors for FScaleAxisForward
// Generic Visitors for FScaleAxisForward
//----------------------------------------------
//----------------------------------------------
class
F
ScaleAxisF
orwardPrep
:
private
ExprVisitor
{
class
ForwardPrep
:
private
ExprVisitor
{
public
:
public
:
std
::
unordered_map
<
const
Node
*
,
AxesSet
>
std
::
unordered_map
<
const
Node
*
,
AxesSet
>
Prepare
(
const
Expr
&
body
)
{
Prepare
(
const
Expr
&
body
)
{
...
@@ -255,12 +260,12 @@ class FScaleAxisForwardPrep : private ExprVisitor {
...
@@ -255,12 +260,12 @@ class FScaleAxisForwardPrep : private ExprVisitor {
}
}
};
};
class
F
ScaleAxisForwardTransform
:
private
ExprMutator
{
class
F
orwardTransformer
:
private
ExprMutator
{
public
:
public
:
// Transform expression.
// Transform expression.
Expr
Transform
(
Expr
expr
)
{
Expr
Fold
(
Expr
expr
)
{
expected_scale_axes_
=
expected_scale_axes_
=
F
ScaleAxisF
orwardPrep
().
Prepare
(
expr
);
ForwardPrep
().
Prepare
(
expr
);
return
this
->
Mutate
(
expr
);
return
this
->
Mutate
(
expr
);
}
}
...
@@ -346,13 +351,13 @@ Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) {
...
@@ -346,13 +351,13 @@ Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) {
}
}
STuple
ReluForwardTransform
(
const
Call
&
ref_call
,
STuple
ReluForwardTransform
(
const
Call
&
ref_call
,
const
AxesSet
&
expected_axes
,
const
AxesSet
&
expected_axes
,
const
Array
<
STuple
>&
sargs
)
{
const
Array
<
STuple
>&
sargs
)
{
if
(
!
sargs
[
0
]
->
axes
.
defined
())
return
STuple
();
if
(
!
sargs
[
0
]
->
axes
.
defined
())
return
STuple
();
// return transformed conv2d
// return transformed conv2d
auto
rnode
=
make_node
<
STupleNode
>
();
auto
rnode
=
make_node
<
STupleNode
>
();
rnode
->
value
=
CallNode
::
make
(
rnode
->
value
=
CallNode
::
make
(
ref_call
->
op
,
{
sargs
[
0
]
->
value
},
ref_call
->
attrs
,
{}
);
ref_call
->
op
,
{
sargs
[
0
]
->
value
},
ref_call
->
attrs
,
ref_call
->
type_args
);
rnode
->
scale
=
sargs
[
0
]
->
scale
;
rnode
->
scale
=
sargs
[
0
]
->
scale
;
rnode
->
axes
=
sargs
[
0
]
->
axes
;
rnode
->
axes
=
sargs
[
0
]
->
axes
;
return
STuple
(
rnode
);
return
STuple
(
rnode
);
...
@@ -474,8 +479,6 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
...
@@ -474,8 +479,6 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
Layout
weight_layout
(
param
->
weight_layout
);
Layout
weight_layout
(
param
->
weight_layout
);
int
c_big_axis
=
data_layout
.
indexof
(
'C'
);
int
c_big_axis
=
data_layout
.
indexof
(
'C'
);
int
c_small_axis
=
data_layout
.
indexof
(
'c'
);
int
c_small_axis
=
data_layout
.
indexof
(
'c'
);
const
auto
*
tdata
=
call
->
args
[
0
]
->
type_as
<
TensorTypeNode
>
();
CHECK
(
tdata
)
<<
"require checked type"
;
CHECK_GE
(
c_big_axis
,
0
);
CHECK_GE
(
c_big_axis
,
0
);
AxesSet
data_axes
=
NullValue
<
AxesSet
>
();
AxesSet
data_axes
=
NullValue
<
AxesSet
>
();
...
@@ -486,8 +489,7 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
...
@@ -486,8 +489,7 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
//
//
// only handle depthwise or full conv2d.
// only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool
is_depthwise_conv2d
=
bool
is_depthwise_conv2d
=
IsDepthwiseConv2D
(
call
,
param
,
weight_layout
);
is_const_int
(
tdata
->
shape
[
c_big_axis
],
param
->
groups
);
if
(
weight_layout
.
indexof
(
'i'
)
<
0
&&
if
(
weight_layout
.
indexof
(
'i'
)
<
0
&&
c_small_axis
<
0
&&
c_small_axis
<
0
&&
(
param
->
groups
==
1
||
is_depthwise_conv2d
))
{
(
param
->
groups
==
1
||
is_depthwise_conv2d
))
{
...
@@ -515,18 +517,24 @@ STuple Conv2DForwardTransform(const Call& ref_call,
...
@@ -515,18 +517,24 @@ STuple Conv2DForwardTransform(const Call& ref_call,
CHECK_EQ
(
weight_layout
.
indexof
(
'i'
),
-
1
);
CHECK_EQ
(
weight_layout
.
indexof
(
'i'
),
-
1
);
CHECK
(
sdata
->
axes
.
size
()
==
1
&&
CHECK
(
sdata
->
axes
.
size
()
==
1
&&
c_big_axis
==
sdata
->
axes
[
0
]
->
value
);
c_big_axis
==
sdata
->
axes
[
0
]
->
value
);
int
big_oc_axis
=
weight_layout
.
indexof
(
'O'
);
int
big_ic_axis
=
weight_layout
.
indexof
(
'I'
);
int
big_ic_axis
=
weight_layout
.
indexof
(
'I'
);
const
auto
*
tdata
=
ref_call
->
args
[
0
]
->
type_as
<
TensorTypeNode
>
();
// Check it must be depthwise or full conv2d.
// Check it must be depthwise or full conv2d.
bool
is_depthwise_conv2d
=
bool
is_depthwise_conv2d
=
IsDepthwiseConv2D
(
ref_call
,
param
,
weight_layout
);
is_const_int
(
tdata
->
shape
[
c_big_axis
],
param
->
groups
);
CHECK
(
param
->
groups
==
1
||
is_depthwise_conv2d
);
CHECK
(
param
->
groups
==
1
||
is_depthwise_conv2d
);
Expr
weight
=
sargs
[
1
]
->
value
;
// match the ic_axis
// match the ic_axis
Expr
scale
=
ExpandBiasToMatchAxis
(
if
(
is_depthwise_conv2d
)
{
sdata
->
scale
,
weight_layout
.
ndim
(),
{
big_ic_axis
});
Expr
scale
=
ExpandBiasToMatchAxis
(
Expr
weight
=
Multiply
(
sargs
[
1
]
->
value
,
scale
);
sdata
->
scale
,
weight_layout
.
ndim
(),
{
big_oc_axis
});
weight
=
Multiply
(
weight
,
scale
);
}
else
{
Expr
scale
=
ExpandBiasToMatchAxis
(
sdata
->
scale
,
weight_layout
.
ndim
(),
{
big_ic_axis
});
weight
=
Multiply
(
weight
,
scale
);
}
// return transformed conv2d
// return transformed conv2d
auto
rnode
=
make_node
<
STupleNode
>
();
auto
rnode
=
make_node
<
STupleNode
>
();
rnode
->
value
=
CallNode
::
make
(
rnode
->
value
=
CallNode
::
make
(
...
@@ -542,13 +550,416 @@ RELAY_REGISTER_OP("nn.conv2d")
...
@@ -542,13 +550,416 @@ RELAY_REGISTER_OP("nn.conv2d")
Expr
ForwardFoldScaleAxis
(
Expr
data
)
{
Expr
ForwardFoldScaleAxis
(
Expr
data
)
{
return
F
ScaleAxisForwardTransform
().
Transform
(
data
);
return
F
orwardTransformer
().
Fold
(
data
);
}
}
// Expose the FoldScaleAxisFoward
// Expose the FoldScaleAxisFoward
TVM_REGISTER_API
(
"relay._ir_pass.forward_fold_scale_axis"
)
TVM_REGISTER_API
(
"relay._ir_pass.forward_fold_scale_axis"
)
.
set_body_typed
<
Expr
(
Expr
)
>
(
ForwardFoldScaleAxis
);
.
set_body_typed
<
Expr
(
Expr
)
>
(
ForwardFoldScaleAxis
);
//----------------------------------------
// Implement backward transformations.
//----------------------------------------
class
BackwardTransformer
;
/*!
* \brief Preparation function for for pass scale backward.
* \param call The call node.
* \param in_scale_axes Allowed input scaling.
* \return The result scaling on axes of the input.
*/
using
FBackwardPrep
=
TypedPackedFunc
<
AxesSet
(
const
Call
&
call
,
const
Array
<
AxesSet
>&
in_scale_axes
)
>
;
using
FBackwardTransform
=
TypedPackedFunc
<
Expr
(
const
Call
&
call
,
const
AxesSet
&
axes
,
const
Expr
&
scale
,
const
BackwardTransformer
&
transformer
)
>
;
//----------------------------------------------
// Generic Visitors for FScaleAxisBackward
//----------------------------------------------
/*!
* \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression.
* \return The reference count mapping.
*/
std
::
unordered_map
<
const
Node
*
,
size_t
>
GetExprRefCount
(
const
Expr
&
body
)
{
class
ExprRefCounter
:
private
ExprVisitor
{
public
:
std
::
unordered_map
<
const
Node
*
,
size_t
>
Get
(
const
Expr
&
body
)
{
this
->
VisitExpr
(
body
);
return
std
::
move
(
this
->
visit_counter_
);
}
};
return
ExprRefCounter
().
Get
(
body
);
}
class
BackwardPrep
:
private
ExprVisitor
{
public
:
// The message on each node.
std
::
unordered_map
<
const
Node
*
,
AxesSet
>
Prepare
(
const
Expr
&
body
)
{
ref_counter_
=
GetExprRefCount
(
body
);
this
->
VisitExpr
(
body
);
return
std
::
move
(
message_
);
}
private
:
// The message on each node.
std
::
unordered_map
<
const
Node
*
,
AxesSet
>
message_
;
// reference counter of an internal expr
std
::
unordered_map
<
const
Node
*
,
size_t
>
ref_counter_
;
// Visit the expression.
void
VisitExpr_
(
const
CallNode
*
call
)
{
ExprVisitor
::
VisitExpr_
(
call
);
static
const
auto
&
fprep
=
Op
::
GetAttr
<
FBackwardPrep
>
(
"FScaleAxisBackwardPrep"
);
auto
f
=
GetFunc
(
fprep
,
call
->
op
);
if
(
f
==
nullptr
)
return
;
auto
rit
=
ref_counter_
.
find
(
call
);
CHECK
(
rit
!=
ref_counter_
.
end
());
// We only allow propagation of scale backward
// if the expression is only referred by a single parent.
if
(
rit
->
second
!=
1
)
return
;
Array
<
AxesSet
>
in_axes
;
for
(
Expr
arg
:
call
->
args
)
{
auto
it
=
message_
.
find
(
arg
.
get
());
if
(
it
!=
message_
.
end
())
{
in_axes
.
push_back
(
it
->
second
);
}
else
{
in_axes
.
push_back
(
NullValue
<
AxesSet
>
());
}
}
AxesSet
out_axes
=
f
(
GetRef
<
Call
>
(
call
),
in_axes
);
if
(
out_axes
.
defined
())
{
message_
[
call
]
=
out_axes
;
}
}
};
class
BackwardTransformerNode
:
public
Node
,
private
ExprMutator
{
public
:
// Run forward transform.
Expr
Fold
(
Expr
expr
)
{
expected_scale_axes_
=
BackwardPrep
().
Prepare
(
expr
);
return
this
->
Mutate
(
expr
);
}
/*!
* \brief Transform the expr to consider the scaling.
*
* \param expr The input expression.
* \param axes The axes to scale.
* \param scale The scale applied to the axes.
* \return The result of transformation.
*/
Expr
Transform
(
const
Expr
&
expr
,
AxesSet
axes
,
Expr
scale
)
{
// NOTE: the result of Transform is not memoized.
// However, in the current rule, Transform will
// only be called to expr that is referred once.
if
(
const
CallNode
*
call_node
=
expr
.
as
<
CallNode
>
())
{
return
Transform
(
call_node
,
axes
,
scale
);
}
else
{
CHECK
(
!
axes
.
defined
())
<<
"outstanding scale"
;
return
ExprMutator
::
VisitExpr
(
expr
);
}
}
/*!
* \brief Normal way of mutating call node.
* \param call_node The call node to be mutated.
* \return the result of the call Mutation.
*/
Expr
NormalCallTransform
(
const
CallNode
*
call_node
)
{
return
ExprMutator
::
VisitExpr_
(
call_node
);
}
/*!
* \brief Get the expected axes on expr.
* \param expr The expresison.
* \return The expected axes.
*/
AxesSet
GetExpectedAxes
(
const
Expr
&
expr
)
const
{
auto
it
=
expected_scale_axes_
.
find
(
expr
.
get
());
if
(
it
!=
expected_scale_axes_
.
end
())
return
it
->
second
;
return
NullValue
<
AxesSet
>
();
}
// solver is not serializable.
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{}
static
constexpr
const
char
*
_type_key
=
"relay.fold_scale_axis.FBackwardTransformer"
;
TVM_DECLARE_NODE_TYPE_INFO
(
BackwardTransformerNode
,
Node
);
private
:
// Valid axes on each node.
std
::
unordered_map
<
const
Node
*
,
AxesSet
>
expected_scale_axes_
;
// Override mutation of call.
Expr
VisitExpr_
(
const
CallNode
*
call_node
)
final
{
return
Transform
(
call_node
,
NullValue
<
AxesSet
>
(),
NullValue
<
Expr
>
());
}
// Transform of CallNode.
Expr
Transform
(
const
CallNode
*
call_node
,
AxesSet
axes
,
Expr
scale
);
};
class
BackwardTransformer
:
public
NodeRef
{
public
:
BackwardTransformer
()
{}
explicit
BackwardTransformer
(
::
tvm
::
NodePtr
<::
tvm
::
Node
>
n
)
:
NodeRef
(
n
)
{
}
BackwardTransformerNode
*
operator
->
()
const
{
return
static_cast
<
BackwardTransformerNode
*>
(
node_
.
get
());
}
using
ContainerType
=
BackwardTransformerNode
;
};
Expr
BackwardTransformerNode
::
Transform
(
const
CallNode
*
call_node
,
AxesSet
axes
,
Expr
scale
)
{
static
const
auto
&
ftransform
=
Op
::
GetAttr
<
FBackwardTransform
>
(
"FScaleAxisBackwardTransform"
);
auto
f
=
GetFunc
(
ftransform
,
call_node
->
op
);
if
(
f
!=
nullptr
)
{
return
f
(
GetRef
<
Call
>
(
call_node
),
axes
,
scale
,
GetRef
<
BackwardTransformer
>
(
this
));
}
else
{
CHECK
(
!
axes
.
defined
())
<<
"outstanding scale"
;
return
NormalCallTransform
(
call_node
);
}
}
//----------------------------------------------
// Per operator defs for FScaleAxisForward
//----------------------------------------------
// Intermediate operators
AxesSet
ReluBackwardPrep
(
const
Call
&
call
,
const
Array
<
AxesSet
>&
in_axes
)
{
return
in_axes
[
0
];
}
Expr
ReluBackwardTransform
(
const
Call
&
call
,
const
AxesSet
&
axes
,
const
Expr
&
scale
,
const
BackwardTransformer
&
transformer
)
{
if
(
!
axes
.
defined
())
{
return
transformer
->
NormalCallTransform
(
call
.
operator
->
());
}
Expr
input
=
transformer
->
Transform
(
call
->
args
[
0
],
axes
,
scale
);
return
CallNode
::
make
(
call
->
op
,
{
input
},
call
->
attrs
,
call
->
type_args
);
}
RELAY_REGISTER_OP
(
"nn.relu"
)
.
set_attr
<
FBackwardPrep
>
(
"FScaleAxisBackwardPrep"
,
ReluBackwardPrep
);
RELAY_REGISTER_OP
(
"nn.relu"
)
.
set_attr
<
FBackwardTransform
>
(
"FScaleAxisBackwardTransform"
,
ReluBackwardTransform
);
RELAY_REGISTER_OP
(
"nn.leaky_relu"
)
.
set_attr
<
FBackwardPrep
>
(
"FScaleAxisBackwardPrep"
,
ReluBackwardPrep
);
RELAY_REGISTER_OP
(
"nn.leaky_relu"
)
.
set_attr
<
FBackwardTransform
>
(
"FScaleAxisBackwardTransform"
,
ReluBackwardTransform
);
// AddSub
AxesSet
AddSubBackwardPrep
(
const
Call
&
call
,
const
Array
<
AxesSet
>&
in_axes
)
{
const
auto
*
tlhs
=
call
->
args
[
0
]
->
type_as
<
TensorTypeNode
>
();
const
auto
*
trhs
=
call
->
args
[
1
]
->
type_as
<
TensorTypeNode
>
();
AttrsEqual
equal
;
if
(
in_axes
[
0
].
defined
()
&&
MatchBroadcastToLeftAxes
(
tlhs
,
trhs
,
in_axes
[
0
]))
{
return
in_axes
[
0
];
}
else
if
(
in_axes
[
1
].
defined
()
&&
MatchBroadcastToLeftAxes
(
trhs
,
tlhs
,
in_axes
[
1
]))
{
return
in_axes
[
1
];
}
else
if
(
in_axes
[
0
].
defined
()
&&
in_axes
[
1
].
defined
()
&&
equal
(
in_axes
[
0
],
in_axes
[
1
])
&&
equal
(
tlhs
->
shape
,
trhs
->
shape
))
{
// add of two elements.
return
in_axes
[
0
];
}
else
{
return
NullValue
<
AxesSet
>
();
}
}
Expr
AddSubBackwardTransform
(
const
Call
&
call
,
const
AxesSet
&
axes
,
const
Expr
&
scale
,
const
BackwardTransformer
&
transformer
)
{
const
auto
*
tlhs
=
call
->
args
[
0
]
->
type_as
<
TensorTypeNode
>
();
const
auto
*
trhs
=
call
->
args
[
1
]
->
type_as
<
TensorTypeNode
>
();
if
(
!
axes
.
defined
())
{
return
transformer
->
NormalCallTransform
(
call
.
operator
->
());
}
AxesSet
lhs_axes
=
transformer
->
GetExpectedAxes
(
call
->
args
[
0
]);
AxesSet
rhs_axes
=
transformer
->
GetExpectedAxes
(
call
->
args
[
1
]);
AttrsEqual
equal
;
if
(
lhs_axes
.
defined
()
&&
rhs_axes
.
defined
())
{
CHECK
(
equal
(
lhs_axes
,
rhs_axes
));
CHECK
(
equal
(
axes
,
lhs_axes
));
Expr
lhs
=
transformer
->
Transform
(
call
->
args
[
0
],
axes
,
scale
);
Expr
rhs
=
transformer
->
Transform
(
call
->
args
[
1
],
axes
,
scale
);
return
CallNode
::
make
(
call
->
op
,
{
lhs
,
rhs
},
call
->
attrs
,
call
->
type_args
);
}
else
if
(
lhs_axes
.
defined
())
{
CHECK
(
equal
(
axes
,
lhs_axes
));
Expr
lhs
=
transformer
->
Transform
(
call
->
args
[
0
],
axes
,
scale
);
Expr
rhs
=
transformer
->
Transform
(
call
->
args
[
1
],
NullValue
<
AxesSet
>
(),
NullValue
<
Expr
>
());
Expr
rhs_scale
=
ExpandBiasToMatchAxis
(
scale
,
tlhs
->
shape
.
size
(),
axes
);
rhs
=
Multiply
(
rhs
,
rhs_scale
);
return
CallNode
::
make
(
call
->
op
,
{
lhs
,
rhs
},
call
->
attrs
,
call
->
type_args
);
}
else
if
(
rhs_axes
.
defined
())
{
CHECK
(
equal
(
axes
,
rhs_axes
));
Expr
lhs
=
transformer
->
Transform
(
call
->
args
[
0
],
NullValue
<
AxesSet
>
(),
NullValue
<
Expr
>
());
Expr
rhs
=
transformer
->
Transform
(
call
->
args
[
1
],
axes
,
scale
);
Expr
lhs_scale
=
ExpandBiasToMatchAxis
(
scale
,
trhs
->
shape
.
size
(),
axes
);
lhs
=
Multiply
(
lhs
,
lhs_scale
);
return
CallNode
::
make
(
call
->
op
,
{
lhs
,
rhs
},
call
->
attrs
,
call
->
type_args
);
}
else
{
LOG
(
FATAL
)
<<
"outstanding scale"
;
return
Expr
();
}
}
RELAY_REGISTER_OP
(
"add"
)
.
set_attr
<
FBackwardPrep
>
(
"FScaleAxisBackwardPrep"
,
AddSubBackwardPrep
);
RELAY_REGISTER_OP
(
"add"
)
.
set_attr
<
FBackwardTransform
>
(
"FScaleAxisBackwardTransform"
,
AddSubBackwardTransform
);
RELAY_REGISTER_OP
(
"subtract"
)
.
set_attr
<
FBackwardPrep
>
(
"FScaleAxisBackwardPrep"
,
AddSubBackwardPrep
);
RELAY_REGISTER_OP
(
"subtract"
)
.
set_attr
<
FBackwardTransform
>
(
"FScaleAxisBackwardTransform"
,
AddSubBackwardTransform
);
// Producer operators
// Multiply produces the scale-axis pair.
Expr
MultiplyBackwardTransform
(
const
Call
&
call
,
const
AxesSet
&
axes
,
const
Expr
&
scale
,
const
BackwardTransformer
&
transformer
)
{
CHECK
(
!
axes
.
defined
())
<<
"outstanding scale"
;
const
auto
*
tlhs
=
call
->
args
[
0
]
->
type_as
<
TensorTypeNode
>
();
const
auto
*
trhs
=
call
->
args
[
1
]
->
type_as
<
TensorTypeNode
>
();
AxesSet
lhs_axes
=
transformer
->
GetExpectedAxes
(
call
->
args
[
0
]);
AxesSet
rhs_axes
=
transformer
->
GetExpectedAxes
(
call
->
args
[
1
]);
if
(
lhs_axes
.
defined
())
{
// NOTE we won't recursively call mutating on scale part.
// since there won't be scale chance within scale part.
Expr
rhs
=
call
->
args
[
1
];
if
(
MatchBroadcastToLeftAxes
(
tlhs
,
trhs
,
lhs_axes
,
&
rhs
))
{
return
transformer
->
Transform
(
call
->
args
[
0
],
lhs_axes
,
rhs
);
}
}
else
if
(
rhs_axes
.
defined
())
{
Expr
lhs
=
call
->
args
[
0
];
if
(
MatchBroadcastToLeftAxes
(
trhs
,
tlhs
,
rhs_axes
,
&
lhs
))
{
return
transformer
->
Transform
(
call
->
args
[
1
],
rhs_axes
,
lhs
);
}
}
return
transformer
->
NormalCallTransform
(
call
.
operator
->
());
}
RELAY_REGISTER_OP
(
"multiply"
)
.
set_attr
<
FBackwardTransform
>
(
"FScaleAxisBackwardTransform"
,
MultiplyBackwardTransform
);
// Consumer operators
// Conv2D send out requirement of axis folding.
AxesSet
Conv2DBackwardPrep
(
const
Call
&
call
,
const
Array
<
AxesSet
>&
in_axes
)
{
const
auto
*
param
=
call
->
attrs
.
as
<
Conv2DAttrs
>
();
CHECK
(
param
!=
nullptr
);
Layout
out_layout
(
param
->
out_layout
);
if
(
!
out_layout
.
defined
())
{
out_layout
=
Layout
(
param
->
data_layout
);
}
Layout
weight_layout
(
param
->
weight_layout
);
int
c_big_axis
=
out_layout
.
indexof
(
'C'
);
int
c_small_axis
=
out_layout
.
indexof
(
'c'
);
CHECK_GE
(
c_big_axis
,
0
);
// For now, we only support simple pattern (no folded weight/data)
// More general layout can be supported under the current framework.
// By using a unified layout transformation.
// We only need to change the Prep and Mutate function.
//
// only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool
is_depthwise_conv2d
=
IsDepthwiseConv2D
(
call
,
param
,
weight_layout
);
if
(
weight_layout
.
indexof
(
'o'
)
<
0
&&
weight_layout
.
indexof
(
'i'
)
<
0
&&
c_small_axis
<
0
&&
(
param
->
groups
==
1
||
is_depthwise_conv2d
))
{
return
{
c_big_axis
};
}
else
{
return
NullValue
<
AxesSet
>
();
}
}
// Conv2D consumes the scale axis during transformation.
Expr
Conv2DBackwardTransform
(
const
Call
&
call
,
const
AxesSet
&
axes
,
const
Expr
&
scale
,
const
BackwardTransformer
&
transformer
)
{
if
(
!
axes
.
defined
())
{
return
transformer
->
NormalCallTransform
(
call
.
operator
->
());
}
const
auto
*
param
=
call
->
attrs
.
as
<
Conv2DAttrs
>
();
CHECK
(
param
!=
nullptr
);
Layout
out_layout
(
param
->
out_layout
);
if
(
!
out_layout
.
defined
())
{
out_layout
=
Layout
(
param
->
data_layout
);
}
Layout
weight_layout
(
param
->
weight_layout
);
int
c_big_axis
=
out_layout
.
indexof
(
'C'
);
CHECK_GE
(
c_big_axis
,
0
);
// For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout
CHECK_EQ
(
weight_layout
.
indexof
(
'o'
),
-
1
);
CHECK_EQ
(
weight_layout
.
indexof
(
'i'
),
-
1
);
CHECK
(
axes
.
size
()
==
1
&&
c_big_axis
==
axes
[
0
]
->
value
);
int
big_oc_axis
=
weight_layout
.
indexof
(
'O'
);
// Check it must be depthwise or full conv2d.
bool
is_depthwise_conv2d
=
IsDepthwiseConv2D
(
call
,
param
,
weight_layout
);
CHECK
(
param
->
groups
==
1
||
is_depthwise_conv2d
);
Expr
data
=
transformer
->
Transform
(
call
->
args
[
0
],
NullValue
<
AxesSet
>
(),
NullValue
<
Expr
>
());
Expr
weight
=
transformer
->
Transform
(
call
->
args
[
1
],
NullValue
<
AxesSet
>
(),
NullValue
<
Expr
>
());
// scale on input for deptwise.
Expr
wscale
=
ExpandBiasToMatchAxis
(
scale
,
weight_layout
.
ndim
(),
{
big_oc_axis
});
weight
=
Multiply
(
weight
,
wscale
);
return
CallNode
::
make
(
call
->
op
,
{
data
,
weight
},
call
->
attrs
,
call
->
type_args
);
}
RELAY_REGISTER_OP
(
"nn.conv2d"
)
.
set_attr
<
FBackwardPrep
>
(
"FScaleAxisBackwardPrep"
,
Conv2DBackwardPrep
);
RELAY_REGISTER_OP
(
"nn.conv2d"
)
.
set_attr
<
FBackwardTransform
>
(
"FScaleAxisBackwardTransform"
,
Conv2DBackwardTransform
);
Expr
BackwardFoldScaleAxis
(
Expr
data
)
{
return
make_node
<
BackwardTransformerNode
>
()
->
Fold
(
data
);
}
TVM_REGISTER_API
(
"relay._ir_pass.backward_fold_scale_axis"
)
.
set_body_typed
<
Expr
(
Expr
)
>
(
BackwardFoldScaleAxis
);
}
// namespace fold_scale_axis
}
// namespace fold_scale_axis
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
src/relay/pass/pattern_util.h
View file @
d5103bbc
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include <tvm/relay/op.h>
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/attrs/transform.h>
#include "../op/nn/layout.h"
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
...
@@ -100,11 +101,31 @@ inline Expr ExpandBiasToMatchAxis(Expr bias,
...
@@ -100,11 +101,31 @@ inline Expr ExpandBiasToMatchAxis(Expr bias,
return
bias
;
return
bias
;
}
}
/*!
* \brief Check if the call is depthwise conv2d.
*
* \param call The conv2d call.
* \param param The conv2d attributes.
* \return Whether it is depthwise_conv2d.
*/
inline
bool
IsDepthwiseConv2D
(
const
Call
&
call
,
const
Conv2DAttrs
*
param
,
const
Layout
&
weight_layout
)
{
static
const
Layout
kOIHW
(
"OIHW"
);
auto
wshape
=
ConvertLayout
(
call
->
args
[
1
]
->
type_as
<
TensorTypeNode
>
()
->
shape
,
weight_layout
,
kOIHW
);
return
is_const_int
(
wshape
[
0
],
param
->
groups
)
&&
is_const_int
(
wshape
[
1
],
1
);
}
inline
Expr
Multiply
(
Expr
lhs
,
Expr
rhs
)
{
inline
Expr
Multiply
(
Expr
lhs
,
Expr
rhs
)
{
static
const
Op
&
op
=
Op
::
Get
(
"multiply"
);
static
const
Op
&
op
=
Op
::
Get
(
"multiply"
);
return
CallNode
::
make
(
op
,
{
lhs
,
rhs
},
Attrs
(),
{});
return
CallNode
::
make
(
op
,
{
lhs
,
rhs
},
Attrs
(),
{});
}
}
inline
Expr
Divide
(
Expr
lhs
,
Expr
rhs
)
{
inline
Expr
Divide
(
Expr
lhs
,
Expr
rhs
)
{
static
const
Op
&
op
=
Op
::
Get
(
"divide"
);
static
const
Op
&
op
=
Op
::
Get
(
"divide"
);
return
CallNode
::
make
(
op
,
{
lhs
,
rhs
},
Attrs
(),
{});
return
CallNode
::
make
(
op
,
{
lhs
,
rhs
},
Attrs
(),
{});
...
@@ -116,8 +137,6 @@ inline Expr ReshapeLike(Expr lhs, Expr rhs) {
...
@@ -116,8 +137,6 @@ inline Expr ReshapeLike(Expr lhs, Expr rhs) {
return
CallNode
::
make
(
op
,
{
lhs
,
rhs
},
Attrs
(),
{});
return
CallNode
::
make
(
op
,
{
lhs
,
rhs
},
Attrs
(),
{});
}
}
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
tests/python/relay/test_pass_fold_scale_axis.py
View file @
d5103bbc
...
@@ -62,14 +62,14 @@ def test_fold_fwd_dual_path():
...
@@ -62,14 +62,14 @@ def test_fold_fwd_dual_path():
channels
=
channels
,
channels
=
channels
,
kernel_size
=
(
3
,
3
),
kernel_size
=
(
3
,
3
),
data_layout
=
"NHWC"
,
data_layout
=
"NHWC"
,
weight_layout
=
"HW
OI
"
,
weight_layout
=
"HW
IO
"
,
groups
=
channels
,
groups
=
channels
,
padding
=
(
1
,
1
))
padding
=
(
1
,
1
))
y2
=
relay
.
nn
.
conv2d
(
x
,
conv_weight
,
y2
=
relay
.
nn
.
conv2d
(
x
,
conv_weight
,
channels
=
channels
,
channels
=
channels
,
kernel_size
=
(
3
,
3
),
kernel_size
=
(
3
,
3
),
data_layout
=
"NHWC"
,
data_layout
=
"NHWC"
,
weight_layout
=
"HW
OI
"
,
weight_layout
=
"HW
IO
"
,
groups
=
channels
,
groups
=
channels
,
padding
=
(
1
,
1
))
padding
=
(
1
,
1
))
z
=
relay
.
add
(
y1
,
y2
)
z
=
relay
.
add
(
y1
,
y2
)
...
@@ -85,7 +85,7 @@ def test_fold_fwd_dual_path():
...
@@ -85,7 +85,7 @@ def test_fold_fwd_dual_path():
channels
=
channels
,
channels
=
channels
,
kernel_size
=
(
3
,
3
),
kernel_size
=
(
3
,
3
),
data_layout
=
"NHWC"
,
data_layout
=
"NHWC"
,
weight_layout
=
"HW
OI
"
,
weight_layout
=
"HW
IO
"
,
groups
=
channels
,
groups
=
channels
,
padding
=
(
1
,
1
))
padding
=
(
1
,
1
))
y2
=
relay
.
nn
.
conv2d
(
x
,
y2
=
relay
.
nn
.
conv2d
(
x
,
...
@@ -93,7 +93,7 @@ def test_fold_fwd_dual_path():
...
@@ -93,7 +93,7 @@ def test_fold_fwd_dual_path():
channels
=
channels
,
channels
=
channels
,
kernel_size
=
(
3
,
3
),
kernel_size
=
(
3
,
3
),
data_layout
=
"NHWC"
,
data_layout
=
"NHWC"
,
weight_layout
=
"HW
OI
"
,
weight_layout
=
"HW
IO
"
,
groups
=
channels
,
groups
=
channels
,
padding
=
(
1
,
1
))
padding
=
(
1
,
1
))
z
=
relay
.
add
(
y1
,
y2
)
z
=
relay
.
add
(
y1
,
y2
)
...
@@ -147,7 +147,176 @@ def test_fold_fwd_fail():
...
@@ -147,7 +147,176 @@ def test_fold_fwd_fail():
check
((
2
,
11
,
10
,
4
),
4
)
check
((
2
,
11
,
10
,
4
),
4
)
def
test_fold_bwd_simple
():
"""Simple testcase."""
def
before
(
x
,
conv_weight
,
out_bias
,
out_scale
,
channels
):
args
=
[
x
,
conv_weight
,
out_bias
,
out_scale
]
out_scale
=
relay
.
expand_dims
(
out_scale
,
axis
=
1
,
num_newaxis
=
2
)
out_bias
=
relay
.
expand_dims
(
out_bias
,
axis
=
1
,
num_newaxis
=
2
)
y
=
relay
.
nn
.
conv2d
(
x
,
conv_weight
,
channels
=
channels
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
))
y
=
relay
.
add
(
y
,
out_bias
)
y
=
relay
.
nn
.
relu
(
y
)
y
=
relay
.
multiply
(
y
,
out_scale
)
return
relay
.
Function
(
args
,
y
)
def
expected
(
x
,
conv_weight
,
out_bias
,
out_scale
,
channels
):
# use a fixed order of args so alpha equal check can pass
args
=
[
x
,
conv_weight
,
out_bias
,
out_scale
]
out_scale
=
relay
.
expand_dims
(
out_scale
,
axis
=
1
,
num_newaxis
=
2
)
out_bias
=
relay
.
expand_dims
(
out_bias
,
axis
=
1
,
num_newaxis
=
2
)
squeezed_scale
=
relay
.
squeeze
(
out_scale
,
axis
=
[
1
,
2
])
conv_weight
=
relay
.
multiply
(
conv_weight
,
relay
.
expand_dims
(
squeezed_scale
,
axis
=
1
,
num_newaxis
=
3
))
y
=
relay
.
nn
.
conv2d
(
x
,
conv_weight
,
channels
=
channels
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
))
out_bias
=
relay
.
multiply
(
out_bias
,
relay
.
expand_dims
(
squeezed_scale
,
axis
=
1
,
num_newaxis
=
2
))
y
=
relay
.
add
(
y
,
out_bias
)
y
=
relay
.
nn
.
relu
(
y
)
return
relay
.
Function
(
args
,
y
)
def
check
(
shape
,
channels
):
x
=
relay
.
var
(
"x"
,
shape
=
shape
)
in_channels
=
shape
[
1
]
weight
=
relay
.
var
(
"weight"
)
out_bias
=
relay
.
var
(
"out_bias"
,
shape
=
(
channels
,))
out_scale
=
relay
.
var
(
"out_scale"
,
shape
=
(
channels
,))
y1
=
before
(
x
,
weight
,
out_bias
,
out_scale
,
channels
)
y1
=
relay
.
ir_pass
.
infer_type
(
y1
)
type_dict
=
{
x
.
name_hint
:
x
.
checked_type
for
x
in
y1
.
params
}
weight
=
relay
.
var
(
"weight"
,
type_dict
[
"weight"
])
y1_folded
=
relay
.
ir_pass
.
backward_fold_scale_axis
(
y1
)
y1_expected
=
expected
(
x
,
weight
,
out_bias
,
out_scale
,
channels
)
assert
relay
.
ir_pass
.
alpha_equal
(
y1_folded
,
y1_expected
)
check
((
2
,
4
,
10
,
10
),
8
)
def
test_fold_bwd_dual_path
():
"""Dual path testcase."""
def
before
(
x
,
conv_weight
,
out_bias
,
out_scale
,
channels
):
args
=
[
x
,
conv_weight
,
out_bias
,
out_scale
]
out_scale
=
relay
.
expand_dims
(
out_scale
,
axis
=
1
,
num_newaxis
=
2
)
out_bias
=
relay
.
expand_dims
(
out_bias
,
axis
=
1
,
num_newaxis
=
2
)
y1
=
relay
.
nn
.
conv2d
(
x
,
conv_weight
,
channels
=
channels
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
))
y1
=
relay
.
nn
.
relu
(
y1
)
y2
=
relay
.
nn
.
conv2d
(
x
,
conv_weight
,
channels
=
channels
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
))
y2
=
relay
.
nn
.
relu
(
y2
)
y
=
relay
.
add
(
y1
,
y2
)
y
=
relay
.
multiply
(
y
,
out_scale
)
return
relay
.
Function
(
args
,
y
)
def
expected
(
x
,
conv_weight
,
out_bias
,
out_scale
,
channels
):
# use a fixed order of args so alpha equal check can pass
args
=
[
x
,
conv_weight
,
out_bias
,
out_scale
]
out_scale
=
relay
.
expand_dims
(
out_scale
,
axis
=
1
,
num_newaxis
=
2
)
out_bias
=
relay
.
expand_dims
(
out_bias
,
axis
=
1
,
num_newaxis
=
2
)
squeezed_scale
=
relay
.
squeeze
(
out_scale
,
axis
=
[
1
,
2
])
def
fold_conv_weight
():
return
relay
.
multiply
(
conv_weight
,
relay
.
expand_dims
(
squeezed_scale
,
axis
=
1
,
num_newaxis
=
3
))
y1
=
relay
.
nn
.
conv2d
(
x
,
fold_conv_weight
(),
channels
=
channels
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
))
y1
=
relay
.
nn
.
relu
(
y1
)
y2
=
relay
.
nn
.
conv2d
(
x
,
fold_conv_weight
(),
channels
=
channels
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
))
y2
=
relay
.
nn
.
relu
(
y2
)
y
=
relay
.
add
(
y1
,
y2
)
return
relay
.
Function
(
args
,
y
)
def
check
(
shape
,
channels
):
x
=
relay
.
var
(
"x"
,
shape
=
shape
)
in_channels
=
shape
[
1
]
weight
=
relay
.
var
(
"weight"
)
out_bias
=
relay
.
var
(
"out_bias"
,
shape
=
(
channels
,))
out_scale
=
relay
.
var
(
"out_scale"
,
shape
=
(
channels
,))
y1
=
before
(
x
,
weight
,
out_bias
,
out_scale
,
channels
)
y1
=
relay
.
ir_pass
.
infer_type
(
y1
)
type_dict
=
{
x
.
name_hint
:
x
.
checked_type
for
x
in
y1
.
params
}
weight
=
relay
.
var
(
"weight"
,
type_dict
[
"weight"
])
y1_folded
=
relay
.
ir_pass
.
backward_fold_scale_axis
(
y1
)
y1_expected
=
expected
(
x
,
weight
,
out_bias
,
out_scale
,
channels
)
assert
relay
.
ir_pass
.
alpha_equal
(
y1_folded
,
y1_expected
)
check
((
2
,
4
,
10
,
10
),
8
)
def
test_fold_bwd_fail
():
"""Dual path testcase."""
def
fail1
(
x
,
conv_weight
,
out_bias
,
out_scale
,
channels
):
args
=
[
x
,
conv_weight
,
out_bias
,
out_scale
]
out_scale
=
relay
.
expand_dims
(
out_scale
,
axis
=
1
,
num_newaxis
=
2
)
out_bias
=
relay
.
expand_dims
(
out_bias
,
axis
=
1
,
num_newaxis
=
2
)
y1
=
relay
.
nn
.
conv2d
(
x
,
conv_weight
,
channels
=
channels
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
))
y1
=
relay
.
nn
.
relu
(
y1
)
y2
=
relay
.
nn
.
conv2d
(
x
,
conv_weight
,
channels
=
channels
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
),
out_layout
=
"CNHW"
)
# fold will fail because the axis from two path
# differs from each other.
y2
=
relay
.
nn
.
relu
(
y2
)
y
=
relay
.
add
(
y1
,
y2
)
y
=
relay
.
multiply
(
y
,
out_scale
)
return
relay
.
Function
(
args
,
y
)
def
fail2
(
x
,
conv_weight
,
out_bias
,
out_scale
,
channels
):
args
=
[
x
,
conv_weight
,
out_bias
,
out_scale
]
out_scale
=
relay
.
expand_dims
(
out_scale
,
axis
=
1
,
num_newaxis
=
2
)
out_bias
=
relay
.
expand_dims
(
out_bias
,
axis
=
1
,
num_newaxis
=
2
)
y1
=
relay
.
nn
.
conv2d
(
x
,
conv_weight
,
channels
=
channels
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
))
y2
=
relay
.
nn
.
relu
(
y1
)
# fold will fail because y1 is referred also by y2
y1
=
relay
.
multiply
(
y1
,
out_scale
)
y
=
relay
.
add
(
y1
,
y2
)
return
relay
.
Function
(
args
,
y
)
def
check
(
shape
,
channels
,
fbefore
):
x
=
relay
.
var
(
"x"
,
shape
=
shape
)
in_channels
=
shape
[
1
]
weight
=
relay
.
var
(
"weight"
)
out_bias
=
relay
.
var
(
"out_bias"
,
shape
=
(
channels
,))
out_scale
=
relay
.
var
(
"out_scale"
,
shape
=
(
channels
,))
y1
=
fbefore
(
x
,
weight
,
out_bias
,
out_scale
,
channels
)
y1
=
relay
.
ir_pass
.
infer_type
(
y1
)
y1_folded
=
relay
.
ir_pass
.
backward_fold_scale_axis
(
y1
)
assert
relay
.
ir_pass
.
alpha_equal
(
y1_folded
,
y1
)
check
((
4
,
4
,
10
,
10
),
4
,
fail1
)
check
((
4
,
4
,
10
,
10
),
4
,
fail2
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_fold_fwd_simple
()
test_fold_fwd_simple
()
test_fold_fwd_dual_path
()
test_fold_fwd_dual_path
()
test_fold_fwd_fail
()
test_fold_fwd_fail
()
test_fold_bwd_simple
()
test_fold_bwd_dual_path
()
test_fold_bwd_fail
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment