wenyuanbo / tic / Commits

Commit 93949456
authored Dec 05, 2018 by Wuwei Lin
committed by Tianqi Chen, Dec 04, 2018
[RELAY][PASS] Check Positiveness in FoldScaleAxis (#2220)
parent 166936cd
Showing 4 changed files with 112 additions and 5 deletions.

python/tvm/relay/build_module.py                  +3   -2
src/relay/pass/fold_scale_axis.cc                 +58  -3
src/relay/pass/pattern_util.h                     +51  -0
tests/python/relay/test_pass_fold_scale_axis.py   +0   -0  (collapsed)
python/tvm/relay/build_module.py
@@ -150,13 +150,14 @@ def optimize(func, params=None):
     func = ir_pass.infer_type(func)
     func = ir_pass.combine_parallel_conv2d(func)
-    if cfg.pass_enabled("FoldConstant"):
-        func = ir_pass.fold_constant(func)
     if cfg.pass_enabled("FoldScaleAxis"):
         func = ir_pass.infer_type(func)
         func = ir_pass.backward_fold_scale_axis(func)
         func = ir_pass.infer_type(func)
         func = ir_pass.forward_fold_scale_axis(func)
+    if cfg.pass_enabled("FoldConstant"):
+        func = ir_pass.fold_constant(func)
     if cfg.pass_enabled("AlterOpLayout"):
...
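The reordering matters because FoldScaleAxis pushes the scale onto the constant weight path as explicit multiplications, which FoldConstant can then evaluate away. A hedged sketch of driving the same sequence by hand, using only the ir_pass names that appear in the hunk above (everything else is assumed era-appropriate Relay API, not part of the commit):

# Illustrative sketch, not from the commit.
from tvm.relay import ir_pass

def fold_scales_then_constants(func):
    func = ir_pass.infer_type(func)
    func = ir_pass.backward_fold_scale_axis(func)
    func = ir_pass.infer_type(func)
    func = ir_pass.forward_fold_scale_axis(func)
    # Constant folding last, so the scale multiplications that were
    # pushed onto the weights collapse into plain constants.
    return ir_pass.fold_constant(func)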
src/relay/pass/fold_scale_axis.cc
@@ -246,9 +246,44 @@ class ForwardPrep : private ExprVisitor {
 // Per operator defs for FScaleAxisForward
 //----------------------------------------------
+// Helper functions
+Expr GetForwardScale(const Expr& expr, AxesSet out) {
+  static const Op& multiply = Op::Get("multiply");
+  static const auto& fprep = Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
+
+  const CallNode* call = expr.as<CallNode>();
+  if (!call) return NullValue<Expr>();
+  auto f = fprep.get(call->op, nullptr);
+
+  if (call->op.same_as(multiply)) {
+    const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
+    const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
+    if (MatchBroadcastToLeftAxes(tlhs, trhs, out)) {
+      return call->args[1];
+    } else if (MatchBroadcastToLeftAxes(trhs, tlhs, out)) {
+      return call->args[0];
+    } else {
+      return NullValue<Expr>();
+    }
+  } else if (f != nullptr) {
+    Array<AxesSet> in_axes = f(GetRef<Call>(call), out);
+    for (size_t i = 0; i < call->args.size(); i++) {
+      auto scale = GetForwardScale(call->args[i], in_axes[i]);
+      if (scale.defined()) {
+        return scale;
+      }
+    }
+  }
+  return NullValue<Expr>();
+}
+
 // Intermediate operators
 Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) {
-  return {out};
+  Expr scale = GetForwardScale(call->args[0], out);
+  if (IsPositiveConstant(scale)) {
+    return {out};
+  }
+  return {NullValue<AxesSet>()};
 }

 Expr ReluForwardRewrite(const Call& ref_call,
...
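Why ReluForwardPrep now requires a provably non-negative scale: relu commutes with an elementwise scale only when the scale is non-negative, i.e. relu(x * s) == relu(x) * s fails for s < 0. A quick NumPy check of the identity (illustrative, not part of the commit):

import numpy as np

def relu(v):
    return np.maximum(v, 0)

x = np.array([-2.0, 3.0])
for s in (2.0, -2.0):
    # Folding moves the scale across the relu; both sides agree only
    # for the non-negative scale.
    print(s, np.allclose(relu(x * s), relu(x) * s))
# 2.0 True
# -2.0 False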
@@ -755,6 +790,22 @@ RELAY_REGISTER_OP("subtract")
 RELAY_REGISTER_OP("subtract")
 .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform",
                               AddSubBackwardTransform);

+// Find relu in the backward path between multiply and conv2d
+bool FindBackwardRelu(const Expr& expr) {
+  const CallNode* call = expr.as<CallNode>();
+  static const Op& conv2d = Op::Get("nn.conv2d");
+  static const Op& relu = Op::Get("nn.relu");
+  if (!call) return false;
+  if (call->op.same_as(relu)) return true;
+  if (call->op.same_as(conv2d)) return false;
+  for (size_t i = 0; i < call->args.size(); i++) {
+    if (FindBackwardRelu(call->args[i])) return true;
+  }
+  return false;
+}
+
 // Producer operators
 // Multiply produces the scale-axis pair.
 Expr MultiplyBackwardTransform(const Call& call,
...
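The same traversal, restated as a toy Python sketch (illustrative only; the Call class here is hypothetical, not the Relay AST): walk from the multiply toward the producing conv2d and report whether an nn.relu lies on the path, stopping at the conv2d because the scale is folded into its weights rather than past it.

# Toy analogue of FindBackwardRelu; Call here is hypothetical.
class Call:
    def __init__(self, op, args):
        self.op, self.args = op, args

def find_backward_relu(expr):
    if not isinstance(expr, Call):
        return False          # reached a leaf (var/constant)
    if expr.op == "nn.relu":
        return True           # relu sits between multiply and conv2d
    if expr.op == "nn.conv2d":
        return False          # stop: the fold targets this conv's weights
    return any(find_backward_relu(a) for a in expr.args)

conv = Call("nn.conv2d", ["x", "w"])
print(find_backward_relu(Call("nn.relu", [conv])))  # True
print(find_backward_relu(conv))                     # False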
@@ -770,12 +821,16 @@ Expr MultiplyBackwardTransform(const Call& call,
     // NOTE we won't recursively call mutating on scale part.
     // since there won't be scale chance within scale part.
     Expr rhs = call->args[1];
-    if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs)) {
+    if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs) &&
+        (!FindBackwardRelu(call->args[0]) ||
+         IsPositiveConstant(call->args[1]))) {
       return transformer->Transform(call->args[0], lhs_axes, rhs);
     }
   } else if (rhs_axes.defined() && rhs_axes.size() != 0) {
     Expr lhs = call->args[0];
-    if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs)) {
+    if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs) &&
+        (!FindBackwardRelu(call->args[1]) ||
+         IsPositiveConstant(call->args[0]))) {
       return transformer->Transform(call->args[1], rhs_axes, lhs);
     }
...
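Both branches now apply the same guard: the multiply may be folded backward unless a relu sits between it and the conv2d and the scale cannot be proven non-negative. A compact restatement (hypothetical names, mirroring the C++ above):

def can_fold_backward(relu_on_path, scale_is_nonneg_const):
    # Mirrors (!FindBackwardRelu(...) || IsPositiveConstant(...)).
    return (not relu_on_path) or scale_is_nonneg_const

assert can_fold_backward(False, False)     # no relu on the path: safe
assert can_fold_backward(True, True)       # relu, scale proven >= 0: safe
assert not can_fold_backward(True, False)  # relu, unproven scale: rejected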
src/relay/pass/pattern_util.h
@@ -190,6 +190,57 @@ Expr MakeConcatenate(Expr data, int axis);
 Expr MakeStridedSlice(Expr data,
                       Array<Integer> begin,
                       Array<Integer> end,
                       Array<Integer> strides);

+template<typename T>
+bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
+  CHECK_EQ(tensor->ctx.device_type, kDLCPU);
+  CHECK(tensor->strides == nullptr);
+  CHECK_EQ(tensor->byte_offset, 0);
+  const T* data = static_cast<const T*>(tensor->data);
+  int64_t num_elems = 1;
+  for (int i = 0; i < tensor->ndim; ++i) {
+    num_elems *= tensor->shape[i];
+  }
+  for (int64_t i = 0; i < num_elems; i++) {
+    if (*data < value) {
+      return false;
+    }
+    data++;
+  }
+  return true;
+}
+
+inline bool IsPositiveConstant(const Expr& expr) {
+  const auto* constant = expr.as<ConstantNode>();
+  if (!constant) return false;
+  const auto& tensor = constant->data;
+  const auto& dtype = tensor->dtype;
+
+  if (dtype.lanes != 1) {
+    // pass
+  } else if (dtype.code == kDLFloat && dtype.bits == 32) {
+    return IsNDArrayAllGreaterEqual<float>(tensor, 0);
+  } else if (dtype.code == kDLFloat && dtype.bits == 64) {
+    return IsNDArrayAllGreaterEqual<double>(tensor, 0);
+  } else if (dtype.code == kDLInt && dtype.bits == 8) {
+    return IsNDArrayAllGreaterEqual<int8_t>(tensor, 0);
+  } else if (dtype.code == kDLInt && dtype.bits == 32) {
+    return IsNDArrayAllGreaterEqual<int32_t>(tensor, 0);
+  } else if (dtype.code == kDLUInt && dtype.bits == 8) {
+    return IsNDArrayAllGreaterEqual<uint8_t>(tensor, 0);
+  } else if (dtype.code == kDLUInt && dtype.bits == 32) {
+    return IsNDArrayAllGreaterEqual<uint32_t>(tensor, 0);
+  }
+  LOG(WARNING) << "Unsupported data type (code = " << dtype.code
+               << ", bits = " << dtype.bits
+               << ", lanes = " << dtype.lanes << ")";
+  return false;
+}
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_PASS_PATTERN_UTIL_H_
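Despite its name, IsPositiveConstant actually tests non-negativity (every element >= 0), which is exactly what the relu identity requires. A NumPy analogue (illustrative, not part of the commit; the C++ walks the compact CPU buffer directly):

import numpy as np

SUPPORTED = (np.float32, np.float64, np.int8, np.int32, np.uint8, np.uint32)

def is_positive_constant(arr):
    # Mirrors the dtype dispatch above; unsupported dtypes fall
    # through to False, like the LOG(WARNING) path.
    if arr.dtype not in SUPPORTED:
        return False
    # IsNDArrayAllGreaterEqual<T>(tensor, 0): every element >= 0.
    return bool(np.all(arr >= 0))

print(is_positive_constant(np.array([0.5, 2.0], dtype=np.float32)))  # True
print(is_positive_constant(np.array([-1, 3], dtype=np.int32)))       # False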
tests/python/relay/test_pass_fold_scale_axis.py
This diff is collapsed.