Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
87a37684
Commit
87a37684
authored
Dec 18, 2018
by
Tianqi Chen
Committed by
ziheng
Dec 18, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] Avoid recursion in FoldScaleAxis (#2299)
* [PASS] Avoid recursion in FoldScaleAxis * remove GetForwardScale
parent
e9e12f03
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
91 additions
and
113 deletions
+91
-113
src/relay/pass/fold_scale_axis.cc
+12
-58
src/relay/pass/pass_util.h
+9
-0
src/relay/pass/pattern_util.h
+0
-51
src/relay/pass/util.cc
+62
-0
tests/python/relay/test_pass_fold_scale_axis.py
+8
-4
No files found.
src/relay/pass/fold_scale_axis.cc
View file @
87a37684
...
...
@@ -246,44 +246,9 @@ class ForwardPrep : private ExprVisitor {
// Per operator defs for FScaleAxisForward
//----------------------------------------------
// Helper functions
Expr
GetForwardScale
(
const
Expr
&
expr
,
AxesSet
out
)
{
static
const
Op
&
multiply
=
Op
::
Get
(
"multiply"
);
static
const
auto
&
fprep
=
Op
::
GetAttr
<
FForwardPrep
>
(
"FScaleAxisForwardPrep"
);
const
CallNode
*
call
=
expr
.
as
<
CallNode
>
();
if
(
!
call
)
return
NullValue
<
Expr
>
();
auto
f
=
fprep
.
get
(
call
->
op
,
nullptr
);
if
(
call
->
op
.
same_as
(
multiply
))
{
const
auto
*
tlhs
=
call
->
args
[
0
]
->
type_as
<
TensorTypeNode
>
();
const
auto
*
trhs
=
call
->
args
[
1
]
->
type_as
<
TensorTypeNode
>
();
if
(
MatchBroadcastToLeftAxes
(
tlhs
,
trhs
,
out
))
{
return
call
->
args
[
1
];
}
else
if
(
MatchBroadcastToLeftAxes
(
trhs
,
tlhs
,
out
))
{
return
call
->
args
[
0
];
}
else
{
return
NullValue
<
Expr
>
();
}
}
else
if
(
f
!=
nullptr
)
{
Array
<
AxesSet
>
in_axes
=
f
(
GetRef
<
Call
>
(
call
),
out
);
for
(
size_t
i
=
0
;
i
<
call
->
args
.
size
();
i
++
)
{
auto
scale
=
GetForwardScale
(
call
->
args
[
i
],
in_axes
[
i
]);
if
(
scale
.
defined
())
{
return
scale
;
}
}
}
return
NullValue
<
Expr
>
();
}
// Intermediate operators
Array
<
AxesSet
>
ReluForwardPrep
(
const
Call
&
call
,
AxesSet
out
)
{
Expr
scale
=
GetForwardScale
(
call
->
args
[
0
],
out
);
if
(
IsPositiveConstant
(
scale
))
{
return
{
out
};
}
return
{
NullValue
<
AxesSet
>
()};
}
Expr
ReluForwardRewrite
(
const
Call
&
ref_call
,
...
...
@@ -391,16 +356,21 @@ Expr MultiplyForwardRewrite(const Call& ref_call,
Expr
lhs
=
new_args
[
0
];
Expr
rhs
=
new_args
[
1
];
auto
rnode
=
make_node
<
ScaledExprNode
>
();
if
(
MatchBroadcastToLeftAxes
(
tlhs
,
trhs
,
expected_out_axes
,
&
rhs
))
{
if
(
MatchBroadcastToLeftAxes
(
tlhs
,
trhs
,
expected_out_axes
,
&
rhs
)
&&
IsAllPositiveConstant
(
rhs
))
{
rnode
->
value
=
lhs
;
rnode
->
scale
=
rhs
;
rnode
->
axes
=
expected_out_axes
;
}
else
if
(
MatchBroadcastToLeftAxes
(
trhs
,
tlhs
,
expected_out_axes
,
&
lhs
))
{
return
Expr
(
rnode
);
}
else
if
(
MatchBroadcastToLeftAxes
(
trhs
,
tlhs
,
expected_out_axes
,
&
lhs
)
&&
IsAllPositiveConstant
(
lhs
))
{
rnode
->
value
=
rhs
;
rnode
->
scale
=
lhs
;
rnode
->
axes
=
expected_out_axes
;
}
return
Expr
(
rnode
);
}
else
{
return
Expr
();
}
}
RELAY_REGISTER_OP
(
"multiply"
)
...
...
@@ -790,22 +760,6 @@ RELAY_REGISTER_OP("subtract")
RELAY_REGISTER_OP
(
"subtract"
)
.
set_attr
<
FBackwardTransform
>
(
"FScaleAxisBackwardTransform"
,
AddSubBackwardTransform
);
// Find relu in the backward path between multiply and conv2d
bool
FindBackwardRelu
(
const
Expr
&
expr
)
{
const
CallNode
*
call
=
expr
.
as
<
CallNode
>
();
static
const
Op
&
conv2d
=
Op
::
Get
(
"nn.conv2d"
);
static
const
Op
&
relu
=
Op
::
Get
(
"nn.relu"
);
if
(
!
call
)
return
false
;
if
(
call
->
op
.
same_as
(
relu
))
return
true
;
if
(
call
->
op
.
same_as
(
conv2d
))
return
false
;
for
(
size_t
i
=
0
;
i
<
call
->
args
.
size
();
i
++
)
{
if
(
FindBackwardRelu
(
call
->
args
[
i
]))
return
true
;
}
return
false
;
}
// Producer operators
// Multiply produces the scale-axis pair.
Expr
MultiplyBackwardTransform
(
const
Call
&
call
,
...
...
@@ -821,16 +775,16 @@ Expr MultiplyBackwardTransform(const Call& call,
// NOTE we won't recursively call mutating on scale part.
// since there won't be scale chance within scale part.
Expr
rhs
=
call
->
args
[
1
];
// Only propagate positive scaling.
if
(
MatchBroadcastToLeftAxes
(
tlhs
,
trhs
,
lhs_axes
,
&
rhs
)
&&
(
!
FindBackwardRelu
(
call
->
args
[
0
])
||
IsPositiveConstant
(
call
->
args
[
1
])))
{
IsAllPositiveConstant
(
rhs
))
{
return
transformer
->
Transform
(
call
->
args
[
0
],
lhs_axes
,
rhs
);
}
}
else
if
(
rhs_axes
.
defined
()
&&
rhs_axes
.
size
()
!=
0
)
{
// Only propagate positive scaling.
Expr
lhs
=
call
->
args
[
0
];
if
(
MatchBroadcastToLeftAxes
(
trhs
,
tlhs
,
rhs_axes
,
&
lhs
)
&&
(
!
FindBackwardRelu
(
call
->
args
[
1
])
||
IsPositiveConstant
(
call
->
args
[
0
])))
{
IsAllPositiveConstant
(
lhs
))
{
return
transformer
->
Transform
(
call
->
args
[
1
],
rhs_axes
,
lhs
);
}
}
...
...
src/relay/pass/pass_util.h
View file @
87a37684
...
...
@@ -22,6 +22,15 @@ namespace relay {
std
::
unordered_map
<
const
Node
*
,
size_t
>
GetExprRefCount
(
const
Expr
&
body
);
/*!
* \brief Check if expr is positive constant.
* \param expr The expression to be checked.
* \return Whether all elements of expr is positive constant.
*/
bool
IsAllPositiveConstant
(
const
Expr
&
expr
);
/*!
* \brief Substitute var with subst.
* \param type The type to be substituted.
...
...
src/relay/pass/pattern_util.h
View file @
87a37684
...
...
@@ -190,57 +190,6 @@ Expr MakeConcatenate(Expr data, int axis);
Expr
MakeStridedSlice
(
Expr
data
,
Array
<
Integer
>
begin
,
Array
<
Integer
>
end
,
Array
<
Integer
>
strides
);
template
<
typename
T
>
bool
IsNDArrayAllGreaterEqual
(
const
runtime
::
NDArray
&
tensor
,
T
value
)
{
CHECK_EQ
(
tensor
->
ctx
.
device_type
,
kDLCPU
);
CHECK
(
tensor
->
strides
==
nullptr
);
CHECK_EQ
(
tensor
->
byte_offset
,
0
);
const
T
*
data
=
static_cast
<
const
T
*>
(
tensor
->
data
);
int64_t
num_elems
=
1
;
for
(
int
i
=
0
;
i
<
tensor
->
ndim
;
++
i
)
{
num_elems
*=
tensor
->
shape
[
i
];
}
for
(
int64_t
i
=
0
;
i
<
num_elems
;
i
++
)
{
if
(
*
data
<
value
)
{
return
false
;
}
data
++
;
}
return
true
;
}
inline
bool
IsPositiveConstant
(
const
Expr
&
expr
)
{
const
auto
*
constant
=
expr
.
as
<
ConstantNode
>
();
if
(
!
constant
)
return
false
;
const
auto
&
tensor
=
constant
->
data
;
const
auto
&
dtype
=
tensor
->
dtype
;
if
(
dtype
.
lanes
!=
1
)
{
// pass
}
else
if
(
dtype
.
code
==
kDLFloat
&&
dtype
.
bits
==
32
)
{
return
IsNDArrayAllGreaterEqual
<
float
>
(
tensor
,
0
);
}
else
if
(
dtype
.
code
==
kDLFloat
&&
dtype
.
bits
==
64
)
{
return
IsNDArrayAllGreaterEqual
<
double
>
(
tensor
,
0
);
}
else
if
(
dtype
.
code
==
kDLInt
&&
dtype
.
bits
==
8
)
{
return
IsNDArrayAllGreaterEqual
<
int8_t
>
(
tensor
,
0
);
}
else
if
(
dtype
.
code
==
kDLInt
&&
dtype
.
bits
==
32
)
{
return
IsNDArrayAllGreaterEqual
<
int32_t
>
(
tensor
,
0
);
}
else
if
(
dtype
.
code
==
kDLUInt
&&
dtype
.
bits
==
8
)
{
return
IsNDArrayAllGreaterEqual
<
uint8_t
>
(
tensor
,
0
);
}
else
if
(
dtype
.
code
==
kDLUInt
&&
dtype
.
bits
==
32
)
{
return
IsNDArrayAllGreaterEqual
<
uint32_t
>
(
tensor
,
0
);
}
LOG
(
WARNING
)
<<
"Unsupported data type (code = "
<<
dtype
.
code
<<
", bits = "
<<
dtype
.
bits
<<
", lanes = "
<<
dtype
.
lanes
<<
")"
;
return
false
;
}
}
// namespace relay
}
// namespace tvm
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
src/relay/pass/util.cc
View file @
87a37684
...
...
@@ -146,5 +146,67 @@ GetExprRefCount(const Expr& body) {
return
ExprRefCounter
().
Get
(
body
);
}
template
<
typename
T
>
bool
IsNDArrayAllGreaterEqual
(
const
runtime
::
NDArray
&
tensor
,
T
value
)
{
CHECK_EQ
(
tensor
->
ctx
.
device_type
,
kDLCPU
);
CHECK
(
tensor
->
strides
==
nullptr
);
CHECK_EQ
(
tensor
->
byte_offset
,
0
);
const
T
*
data
=
static_cast
<
const
T
*>
(
tensor
->
data
);
int64_t
num_elems
=
1
;
for
(
int
i
=
0
;
i
<
tensor
->
ndim
;
++
i
)
{
num_elems
*=
tensor
->
shape
[
i
];
}
for
(
int64_t
i
=
0
;
i
<
num_elems
;
i
++
)
{
if
(
*
data
<
value
)
{
return
false
;
}
data
++
;
}
return
true
;
}
bool
IsAllPositiveConstant
(
const
Expr
&
expr
)
{
// peel through a few common transform ops.
static
const
auto
&
expand_dims
=
Op
::
Get
(
"expand_dims"
);
static
const
auto
&
reshape
=
Op
::
Get
(
"reshape"
);
static
const
auto
&
transpose
=
Op
::
Get
(
"transpose"
);
static
const
auto
&
squeeze
=
Op
::
Get
(
"squeeze"
);
if
(
const
auto
*
constant
=
expr
.
as
<
ConstantNode
>
())
{
const
auto
&
tensor
=
constant
->
data
;
const
auto
&
dtype
=
tensor
->
dtype
;
if
(
dtype
.
lanes
!=
1
)
{
return
false
;
}
else
if
(
dtype
.
code
==
kDLFloat
&&
dtype
.
bits
==
32
)
{
return
IsNDArrayAllGreaterEqual
<
float
>
(
tensor
,
0
);
}
else
if
(
dtype
.
code
==
kDLFloat
&&
dtype
.
bits
==
64
)
{
return
IsNDArrayAllGreaterEqual
<
double
>
(
tensor
,
0
);
}
else
if
(
dtype
.
code
==
kDLInt
&&
dtype
.
bits
==
8
)
{
return
IsNDArrayAllGreaterEqual
<
int8_t
>
(
tensor
,
0
);
}
else
if
(
dtype
.
code
==
kDLInt
&&
dtype
.
bits
==
32
)
{
return
IsNDArrayAllGreaterEqual
<
int32_t
>
(
tensor
,
0
);
}
else
if
(
dtype
.
code
==
kDLUInt
&&
dtype
.
bits
==
8
)
{
return
IsNDArrayAllGreaterEqual
<
uint8_t
>
(
tensor
,
0
);
}
else
if
(
dtype
.
code
==
kDLUInt
&&
dtype
.
bits
==
32
)
{
return
IsNDArrayAllGreaterEqual
<
uint32_t
>
(
tensor
,
0
);
}
else
{
return
false
;
}
}
else
if
(
const
auto
*
op
=
expr
.
as
<
CallNode
>
())
{
// tail recursion.
if
(
op
->
op
.
same_as
(
expand_dims
)
||
op
->
op
.
same_as
(
reshape
)
||
op
->
op
.
same_as
(
transpose
)
||
op
->
op
.
same_as
(
squeeze
))
{
return
IsAllPositiveConstant
(
op
->
args
[
0
]);
}
else
{
return
false
;
}
}
else
{
return
false
;
}
}
}
// namespace relay
}
// namespace tvm
tests/python/relay/test_pass_fold_scale_axis.py
View file @
87a37684
from
tvm
import
relay
import
numpy
as
np
def
_get_positive_scale
(
size
):
return
np
.
random
.
uniform
(
0.5
,
1
,
size
=
size
)
.
astype
(
'float32'
)
def
test_fold_fwd_simple
():
"""Simple testcase."""
...
...
@@ -14,6 +17,7 @@ def test_fold_fwd_simple():
channels
=
channels
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
))
return
relay
.
Function
(
args
,
y
)
def
expected
(
x
,
conv_weight
,
in_bias
,
in_scale
,
channels
):
...
...
@@ -37,14 +41,14 @@ def test_fold_fwd_simple():
in_channels
=
shape
[
1
]
weight
=
relay
.
var
(
"weight"
)
in_bias
=
relay
.
var
(
"in_bias"
,
shape
=
(
in_channels
,))
in_scale
=
relay
.
const
(
np
.
random
.
uniform
(
size
=
(
in_channels
,
1
,
1
))
.
astype
(
'float32'
))
in_scale
=
relay
.
const
(
_get_positive_scale
((
in_channels
,
1
,
1
)))
y1
=
before
(
x
,
weight
,
in_bias
,
in_scale
,
channels
)
y1
=
relay
.
ir_pass
.
infer_type
(
y1
)
type_dict
=
{
x
.
name_hint
:
x
.
checked_type
for
x
in
y1
.
params
}
weight
=
relay
.
var
(
"weight"
,
type_dict
[
"weight"
])
y1_folded
=
relay
.
ir_pass
.
forward_fold_scale_axis
(
y1
)
y1_expected
=
expected
(
x
,
weight
,
in_bias
,
in_scale
,
channels
)
y1_folded
=
relay
.
ir_pass
.
infer_type
(
y1_folded
)
y1_expected
=
relay
.
ir_pass
.
infer_type
(
y1_expected
)
assert
relay
.
ir_pass
.
alpha_equal
(
y1_folded
,
y1_expected
)
...
...
@@ -107,7 +111,7 @@ def test_fold_fwd_dual_path():
assert
in_channels
==
channels
weight
=
relay
.
var
(
"weight"
)
in_bias
=
relay
.
var
(
"in_bias"
,
shape
=
(
in_channels
,))
in_scale
=
relay
.
const
(
np
.
random
.
uniform
(
size
=
(
in_channels
,))
.
astype
(
"float32"
))
in_scale
=
relay
.
const
(
_get_positive_scale
(
in_channels
,
))
y1
=
before
(
x
,
weight
,
in_bias
,
in_scale
,
channels
)
y1
=
relay
.
ir_pass
.
infer_type
(
y1
)
y1_folded
=
relay
.
ir_pass
.
forward_fold_scale_axis
(
y1
)
...
...
@@ -141,7 +145,7 @@ def test_fold_fwd_fail():
assert
in_channels
==
channels
weight
=
relay
.
var
(
"weight"
)
in_bias
=
relay
.
var
(
"in_bias"
,
shape
=
(
in_channels
,))
in_scale
=
relay
.
const
(
np
.
random
.
uniform
(
size
=
(
in_channels
,))
.
astype
(
"float32"
))
in_scale
=
relay
.
const
(
_get_positive_scale
(
size
=
(
in_channels
,)
))
y1
=
before
(
x
,
weight
,
in_bias
,
in_scale
,
channels
)
y1
=
relay
.
ir_pass
.
infer_type
(
y1
)
y1_folded
=
relay
.
ir_pass
.
forward_fold_scale_axis
(
y1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment