Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
629a293a
Unverified
Commit
629a293a
authored
Nov 14, 2018
by
Tianqi Chen
Committed by
GitHub
Nov 14, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][PASS] FuseOps, fix input fusion rule for conv2d (#2110)
parent
b2521604
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
5 deletions
+19
-5
src/relay/pass/fuse_ops.cc
+11
-4
tests/python/relay/test_pass_fuse_ops.py
+8
-1
No files found.
src/relay/pass/fuse_ops.cc
View file @
629a293a
...
@@ -464,14 +464,15 @@ class GraphPartitioner {
...
@@ -464,14 +464,15 @@ class GraphPartitioner {
return
true
;
return
true
;
}
}
/*!
/*!
* \brief Check all the node between src and sink satisfies fcond.
* \brief Check all the node and edge pattern
* between src and sink satisfies fcond.
*
*
* src
and sink are
not checked.
* src
is
not checked.
*
*
* \param src The source node.
* \param src The source node.
* \param sink The termination node.
* \param sink The termination node.
* \param fcond The condition to be checked.
* \param fcond The condition to be checked.
* \tparam F the condition function
.
* \tparam F the condition function
, with signature
* \note sink must be a post-dominator of src.
* \note sink must be a post-dominator of src.
*/
*/
template
<
typename
F
>
template
<
typename
F
>
...
@@ -596,7 +597,12 @@ class GraphPartitioner {
...
@@ -596,7 +597,12 @@ class GraphPartitioner {
}
}
}
}
}
else
if
(
group_node
->
pattern
<=
kBroadcast
)
{
}
else
if
(
group_node
->
pattern
<=
kBroadcast
)
{
// The fuse can be executed if all the intermediate ops are still broadcast.
// Pre-condition: can only be fused to parent which is injective or reduction.
if
(
dom_node
->
parent
!=
nullptr
&&
(
dom_node
->
pattern
<=
kInjective
||
dom_node
->
pattern
==
kCommReduce
))
{
// Check if all the intermediate ops are still broadcast.
// The final terminal node can already be fused to a OutEWiseFusable group.
auto
fcond
=
[](
OpPatternKind
kind
,
bool
is_sink
)
{
auto
fcond
=
[](
OpPatternKind
kind
,
bool
is_sink
)
{
if
(
!
is_sink
)
{
if
(
!
is_sink
)
{
return
kind
<=
kBroadcast
;
return
kind
<=
kBroadcast
;
...
@@ -609,6 +615,7 @@ class GraphPartitioner {
...
@@ -609,6 +615,7 @@ class GraphPartitioner {
if
(
CheckPath
(
graph_node
,
dom_node
->
parent
->
gnode
,
fcond
))
{
if
(
CheckPath
(
graph_node
,
dom_node
->
parent
->
gnode
,
fcond
))
{
CommitFuse
(
graph_node
,
dom_node
->
parent
->
gnode
);
CommitFuse
(
graph_node
,
dom_node
->
parent
->
gnode
);
}
}
}
}
else
if
(
group_node
->
pattern
==
kInjective
)
{
}
else
if
(
group_node
->
pattern
==
kInjective
)
{
// defer injective fusion to second phase.
// defer injective fusion to second phase.
// so conv2d always finishes fusing.
// so conv2d always finishes fusing.
...
...
tests/python/relay/test_pass_fuse_ops.py
View file @
629a293a
...
@@ -29,10 +29,12 @@ def test_fuse_simple():
...
@@ -29,10 +29,12 @@ def test_fuse_simple():
def
test_conv2d_fuse
():
def
test_conv2d_fuse
():
"""Test fusion case of conv2d"""
"""Test fusion case of conv2d"""
def
before
(
dshape
):
def
before
(
dshape
):
x
=
relay
.
var
(
"x"
,
shape
=
dshape
)
x
=
relay
.
var
(
"x"
,
shape
=
dshape
)
x
=
relay
.
add
(
x
,
relay
.
const
(
1
,
"float32"
))
y
=
relay
.
nn
.
conv2d
(
x
,
relay
.
var
(
"w1"
),
y
=
relay
.
nn
.
conv2d
(
x
,
relay
.
var
(
"w1"
),
kernel_size
=
(
3
,
3
),
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
),
padding
=
(
1
,
1
),
...
@@ -54,6 +56,10 @@ def test_conv2d_fuse():
...
@@ -54,6 +56,10 @@ def test_conv2d_fuse():
return
relay
.
Function
(
relay
.
ir_pass
.
free_vars
(
z
),
z
)
return
relay
.
Function
(
relay
.
ir_pass
.
free_vars
(
z
),
z
)
def
expected
(
dshape
):
def
expected
(
dshape
):
# segment 0
x
=
relay
.
var
(
"p0"
,
shape
=
dshape
)
y
=
relay
.
add
(
x
,
relay
.
const
(
1
,
"float32"
))
f0
=
relay
.
Function
([
x
],
y
)
# segment 1
# segment 1
x
=
relay
.
var
(
"p0"
,
shape
=
dshape
)
x
=
relay
.
var
(
"p0"
,
shape
=
dshape
)
w
=
relay
.
var
(
"p1"
)
w
=
relay
.
var
(
"p1"
)
...
@@ -84,7 +90,8 @@ def test_conv2d_fuse():
...
@@ -84,7 +90,8 @@ def test_conv2d_fuse():
f3
=
relay
.
Function
([
x
,
w
,
offset
],
z3
)
f3
=
relay
.
Function
([
x
,
w
,
offset
],
z3
)
# compose
# compose
x
=
relay
.
var
(
"x"
,
shape
=
dshape
)
x
=
relay
.
var
(
"x"
,
shape
=
dshape
)
y
=
relay
.
Call
(
f1
,
[
x
,
relay
.
var
(
"w1"
)])
y
=
relay
.
Call
(
f0
,
[
x
])
y
=
relay
.
Call
(
f1
,
[
y
,
relay
.
var
(
"w1"
)])
z2
=
relay
.
Call
(
f2
,
[
y
,
relay
.
var
(
"w3"
)])
z2
=
relay
.
Call
(
f2
,
[
y
,
relay
.
var
(
"w3"
)])
z3
=
relay
.
Call
(
f3
,
[
y
,
relay
.
var
(
"w2"
),
z2
])
z3
=
relay
.
Call
(
f3
,
[
y
,
relay
.
var
(
"w2"
),
z2
])
z
=
z3
z
=
z3
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment