Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
ca352770
Commit
ca352770
authored
Sep 05, 2019
by
雾雨魔理沙
Committed by
masahi
Sep 06, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] Fix operator fusion for multiple output (#3871)
* save * add test * refactor * fix indent * save * refactor
parent
57cd27f1
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
83 additions
and
77 deletions
+83
-77
src/relay/ir/pretty_printer.cc
+11
-7
src/relay/pass/fuse_ops.cc
+59
-70
tests/python/relay/test_pass_fuse_ops.py
+13
-0
No files found.
src/relay/ir/pretty_printer.cc
View file @
ca352770
...
@@ -304,14 +304,16 @@ class PrettyPrinter :
...
@@ -304,14 +304,16 @@ class PrettyPrinter :
* \return The corresponding name.
* \return The corresponding name.
*/
*/
Doc
AllocTypeVar
(
const
TypeVar
&
var
)
{
Doc
AllocTypeVar
(
const
TypeVar
&
var
)
{
if
(
memo_type_
.
count
(
var
))
{
Doc
val
=
memo_type_
[
var
];
val
<<
"-malformed-ir"
;
return
val
;
}
std
::
string
name
=
var
->
var
->
name_hint
;
std
::
string
name
=
var
->
var
->
name_hint
;
if
(
name
.
length
()
==
0
||
!
std
::
isalpha
(
name
[
0
]))
{
if
(
name
.
length
()
==
0
||
!
std
::
isalpha
(
name
[
0
]))
{
name
=
"t"
+
name
;
name
=
"t"
+
name
;
}
}
Doc
val
=
GetUniqueName
(
"%"
+
name
);
Doc
val
=
GetUniqueName
(
"%"
+
name
);
if
(
memo_type_
.
count
(
var
))
{
val
<<
"-malformed-ir"
;
}
memo_type_
[
var
]
=
val
;
memo_type_
[
var
]
=
val
;
if
(
var
->
kind
!=
kType
)
{
if
(
var
->
kind
!=
kType
)
{
val
<<
": "
<<
Print
(
var
->
kind
);
val
<<
": "
<<
Print
(
var
->
kind
);
...
@@ -325,16 +327,18 @@ class PrettyPrinter :
...
@@ -325,16 +327,18 @@ class PrettyPrinter :
* \return The corresponding name.
* \return The corresponding name.
*/
*/
Doc
AllocVar
(
const
Var
&
var
)
{
Doc
AllocVar
(
const
Var
&
var
)
{
// still print if ir is malformed, but show the error.
if
(
memo_
.
count
(
var
))
{
Doc
val
=
memo_
[
var
];
val
<<
"-malformed-ir"
;
return
val
;
}
std
::
string
name
=
var
->
name_hint
();
std
::
string
name
=
var
->
name_hint
();
// always make sure first name is alpha
// always make sure first name is alpha
if
(
name
.
length
()
==
0
||
!
std
::
isalpha
(
name
[
0
]))
{
if
(
name
.
length
()
==
0
||
!
std
::
isalpha
(
name
[
0
]))
{
name
=
"v"
+
name
;
name
=
"v"
+
name
;
}
}
Doc
val
=
GetUniqueName
(
"%"
+
name
);
Doc
val
=
GetUniqueName
(
"%"
+
name
);
// still print if ir is malformed, but show the error.
if
(
memo_
.
count
(
var
))
{
val
<<
"-malformed-ir"
;
}
memo_
[
var
]
=
val
;
memo_
[
var
]
=
val
;
if
(
var
->
type_annotation
.
defined
())
{
if
(
var
->
type_annotation
.
defined
())
{
val
<<
": "
<<
Print
(
var
->
type_annotation
);
val
<<
": "
<<
Print
(
var
->
type_annotation
);
...
...
src/relay/pass/fuse_ops.cc
View file @
ca352770
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
*/
*/
/*!
/*!
* Copyright (c) 201
8
by Contributors
* Copyright (c) 201
9
by Contributors
*
*
* \file src/tvm/relay/pass/fuse_ops.cc
* \file src/tvm/relay/pass/fuse_ops.cc
*
*
...
@@ -247,11 +247,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
...
@@ -247,11 +247,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
node
->
pattern
=
op_pattern
;
node
->
pattern
=
op_pattern
;
this
->
Update
(
call
->
op
,
nullptr
,
kOpaque
);
this
->
Update
(
call
->
op
,
nullptr
,
kOpaque
);
const
auto
*
rtype
=
call
->
checked_type
().
as
<
TensorTypeNode
>
();
const
auto
*
rtype
=
call
->
checked_type
().
as
<
TensorTypeNode
>
();
// pass the
message
back to all the children it references.
// pass the
analysis
back to all the children it references.
for
(
size_t
i
=
0
;
i
<
call
->
args
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
call
->
args
.
size
();
++
i
)
{
const
auto
*
arg_type
=
const
auto
*
arg_type
=
call
->
args
[
i
]
->
checked_type
().
as
<
TensorTypeNode
>
();
call
->
args
[
i
]
->
checked_type
().
as
<
TensorTypeNode
>
();
// specifically check if result type
// specifically check if result type
is the same as arguments type
OpPatternKind
edge_pattern
=
op_pattern
;
OpPatternKind
edge_pattern
=
op_pattern
;
if
(
edge_pattern
==
kBroadcast
&&
if
(
edge_pattern
==
kBroadcast
&&
arg_type
!=
nullptr
&&
arg_type
!=
nullptr
&&
...
@@ -403,12 +403,12 @@ class DominatorTree {
...
@@ -403,12 +403,12 @@ class DominatorTree {
return
rhs
;
return
rhs
;
}
}
/*!
/*!
* \brief Find the least common a
cen
stor of the two nodes.
* \brief Find the least common a
nce
stor of the two nodes.
* \param lhs The left node.
* \param lhs The left node.
* \param rhs The right node.
* \param rhs The right node.
* \param edge_pattern
* \param edge_pattern
* The combined edge pattern across all the parents.
* The combined edge pattern across all the parents.
* \return The least common ancestor of th
w
two.
* \return The least common ancestor of th
e
two.
*/
*/
static
Node
*
LeastCommonAncestor
(
static
Node
*
LeastCommonAncestor
(
Node
*
lhs
,
Node
*
lhs
,
...
@@ -436,17 +436,43 @@ class DominatorTree {
...
@@ -436,17 +436,43 @@ class DominatorTree {
}
}
return
lhs
;
return
lhs
;
}
}
};
/*!
* \brief Find the least common ancestor of a list of nodes.
DominatorTree
DominatorTree
::
PostDom
(
common
::
Arena
*
arena
,
* \param nodes the nodes.
const
IndexedForwardGraph
&
graph
)
{
* \param edge_pattern
DominatorTree
tree
;
* The combined edge pattern across all the parents.
tree
.
nodes
.
resize
(
graph
.
post_dfs_order
.
size
(),
nullptr
);
* \return The least common ancestor of all nodes.
// reverse topo order
*/
for
(
size_t
i
=
graph
.
post_dfs_order
.
size
();
i
!=
0
;
--
i
)
{
Node
*
LeastCommonAncestor
(
const
LinkedList
<
IndexedForwardGraph
::
Edge
>&
input_nodes
,
size_t
index
=
i
-
1
;
OpPatternKind
*
edge_pattern
)
{
auto
link
=
input_nodes
.
head
;
if
(
link
==
nullptr
)
{
return
nullptr
;
}
auto
get_node
=
[
&
](
const
IndexedForwardGraph
::
Edge
&
edge
)
{
size_t
oindex
=
edge
.
node
->
index
;
CHECK_LT
(
oindex
,
nodes
.
size
());
Node
*
onode
=
nodes
[
oindex
];
CHECK
(
onode
!=
nullptr
);
return
onode
;
};
Node
*
parent
=
get_node
(
link
->
value
);
*
edge_pattern
=
CombinePattern
(
*
edge_pattern
,
link
->
value
.
pattern
);
link
=
link
->
next
;
for
(;
link
!=
nullptr
;
link
=
link
->
next
)
{
parent
=
LeastCommonAncestor
(
parent
,
get_node
(
link
->
value
),
edge_pattern
);
*
edge_pattern
=
CombinePattern
(
*
edge_pattern
,
link
->
value
.
pattern
);
}
return
parent
;
}
/*!
* \brief Convert the Node from an IndexedForwardGraph Node into DomaintorTree Node.
* \param arena The Arena.
* \param gnode An IndexedForwardGraph Node.
* \return The DominatorTree Node.
*/
Node
*
GetNode
(
common
::
Arena
*
arena
,
IndexedForwardGraph
::
Node
*
gnode
)
{
Node
*
tnode
=
arena
->
make
<
Node
>
();
Node
*
tnode
=
arena
->
make
<
Node
>
();
auto
*
gnode
=
graph
.
post_dfs_order
[
index
];
tnode
->
gnode
=
gnode
;
tnode
->
gnode
=
gnode
;
if
(
gnode
->
extern_ref
)
{
if
(
gnode
->
extern_ref
)
{
tnode
->
depth
=
1
;
tnode
->
depth
=
1
;
...
@@ -455,24 +481,24 @@ DominatorTree DominatorTree::PostDom(common::Arena* arena,
...
@@ -455,24 +481,24 @@ DominatorTree DominatorTree::PostDom(common::Arena* arena,
}
else
{
}
else
{
// find the LCAs of all outputs.
// find the LCAs of all outputs.
OpPatternKind
pattern
=
kElemWise
;
OpPatternKind
pattern
=
kElemWise
;
Node
*
parent
=
nullptr
;
Node
*
parent
=
LeastCommonAncestor
(
gnode
->
outputs
,
&
pattern
);
for
(
auto
link
=
gnode
->
outputs
.
head
;
link
!=
nullptr
;
link
=
link
->
next
)
{
size_t
oindex
=
link
->
value
.
node
->
index
;
CHECK_LT
(
oindex
,
tree
.
nodes
.
size
());
Node
*
onode
=
tree
.
nodes
[
oindex
];
CHECK
(
onode
!=
nullptr
);
if
(
parent
!=
nullptr
)
{
parent
=
LeastCommonAncestor
(
parent
,
onode
,
&
pattern
);
}
else
{
parent
=
onode
;
}
pattern
=
CombinePattern
(
pattern
,
link
->
value
.
pattern
);
}
tnode
->
depth
=
parent
?
parent
->
depth
+
1
:
1
;
tnode
->
depth
=
parent
?
parent
->
depth
+
1
:
1
;
tnode
->
parent
=
parent
;
tnode
->
parent
=
parent
;
tnode
->
pattern
=
pattern
;
tnode
->
pattern
=
pattern
;
}
}
tree
.
nodes
[
index
]
=
tnode
;
return
tnode
;
}
};
DominatorTree
DominatorTree
::
PostDom
(
common
::
Arena
*
arena
,
const
IndexedForwardGraph
&
graph
)
{
DominatorTree
tree
;
tree
.
nodes
.
resize
(
graph
.
post_dfs_order
.
size
(),
nullptr
);
// reverse topo order
for
(
size_t
i
=
graph
.
post_dfs_order
.
size
();
i
!=
0
;
--
i
)
{
size_t
index
=
i
-
1
;
tree
.
nodes
[
index
]
=
tree
.
GetNode
(
arena
,
graph
.
post_dfs_order
[
index
]);
}
}
return
tree
;
return
tree
;
}
}
...
@@ -614,7 +640,7 @@ class GraphPartitioner {
...
@@ -614,7 +640,7 @@ class GraphPartitioner {
// merge the current group to the parent if possible.
// merge the current group to the parent if possible.
MergeFromTo
(
gnode
,
target
);
MergeFromTo
(
gnode
,
target
);
for
(
auto
link
=
src
->
outputs
.
head
;
link
!=
nullptr
;
link
=
link
->
next
)
{
for
(
auto
link
=
src
->
outputs
.
head
;
link
!=
nullptr
;
link
=
link
->
next
)
{
CommitFuse_
(
link
->
value
.
node
,
sink
,
target
);
;
CommitFuse_
(
link
->
value
.
node
,
sink
,
target
);
}
}
}
}
/*!
/*!
...
@@ -851,7 +877,7 @@ class FuseMutator : private ExprMutator {
...
@@ -851,7 +877,7 @@ class FuseMutator : private ExprMutator {
Expr
VisitExpr_
(
const
TupleNode
*
tuple
)
{
Expr
VisitExpr_
(
const
TupleNode
*
tuple
)
{
auto
*
ret_group
=
gmap_
.
at
(
tuple
)
->
FindRoot
();
auto
*
ret_group
=
gmap_
.
at
(
tuple
)
->
FindRoot
();
if
(
ret_group
==
gmap_
.
at
(
tuple
)
)
{
if
(
ret_group
->
root_ref
==
tuple
)
{
return
ExprMutator
::
VisitExpr_
(
tuple
);
return
ExprMutator
::
VisitExpr_
(
tuple
);
}
}
// This tuple is an intermediate node in the group
// This tuple is an intermediate node in the group
...
@@ -863,7 +889,7 @@ class FuseMutator : private ExprMutator {
...
@@ -863,7 +889,7 @@ class FuseMutator : private ExprMutator {
auto
*
ret_group
=
gmap_
.
at
(
tuple_get
)
->
FindRoot
();
auto
*
ret_group
=
gmap_
.
at
(
tuple_get
)
->
FindRoot
();
auto
new_tuple
=
GetNewArguments
({
tuple_get
->
tuple
},
ret_group
)[
0
];
auto
new_tuple
=
GetNewArguments
({
tuple_get
->
tuple
},
ret_group
)[
0
];
auto
new_node
=
TupleGetItemNode
::
make
(
new_tuple
,
tuple_get
->
index
);
auto
new_node
=
TupleGetItemNode
::
make
(
new_tuple
,
tuple_get
->
index
);
if
(
ret_group
==
gmap_
.
at
(
tuple_get
)
)
{
if
(
ret_group
->
root_ref
==
tuple_get
)
{
if
(
gmap_
.
at
(
tuple_get
->
tuple
.
get
())
->
FindRoot
()
!=
ret_group
)
{
if
(
gmap_
.
at
(
tuple_get
->
tuple
.
get
())
->
FindRoot
()
!=
ret_group
)
{
// Isolated. This case occurs when tuple is created by an Opaque op
// Isolated. This case occurs when tuple is created by an Opaque op
// e.g. multibox_transform_loc
// e.g. multibox_transform_loc
...
@@ -922,45 +948,8 @@ class FuseMutator : private ExprMutator {
...
@@ -922,45 +948,8 @@ class FuseMutator : private ExprMutator {
}
}
};
};
// Temporary solution, should be handled by implementing a "FunctionPass"
// which applies fusion to each function.
struct
GlobalVarLiveness
:
ExprVisitor
{
Module
module
;
std
::
set
<
GlobalVar
>
visited
;
explicit
GlobalVarLiveness
(
const
Module
&
mod
)
:
module
(
mod
),
visited
()
{}
void
VisitExpr_
(
const
GlobalVarNode
*
gvar_node
)
{
auto
gvar
=
GetRef
<
GlobalVar
>
(
gvar_node
);
if
(
visited
.
find
(
gvar
)
==
visited
.
end
())
{
visited
.
insert
(
gvar
);
this
->
VisitExpr
(
this
->
module
->
Lookup
(
gvar
));
}
}
};
std
::
set
<
GlobalVar
>
LiveGlobals
(
const
Module
&
mod
,
const
Expr
&
expr
)
{
auto
gvl
=
GlobalVarLiveness
(
mod
);
gvl
.
VisitExpr
(
expr
);
return
gvl
.
visited
;
}
Expr
FuseOps
(
const
Expr
&
expr
,
int
fuse_opt_level
,
const
Module
&
module
)
{
Expr
FuseOps
(
const
Expr
&
expr
,
int
fuse_opt_level
,
const
Module
&
module
)
{
// First we convert all chains of fusable ops into
return
FuseMutator
().
Transform
(
expr
,
fuse_opt_level
);
// abstracted functions which we mark as primtive
// then we convert these primtive functions into
// new operators.
if
(
!
module
.
defined
())
{
return
FuseMutator
().
Transform
(
expr
,
fuse_opt_level
);
}
else
{
auto
lgvs
=
LiveGlobals
(
module
,
expr
);
for
(
auto
lv
:
lgvs
)
{
auto
body
=
module
->
Lookup
(
lv
);
auto
e
=
FuseMutator
().
Transform
(
body
,
fuse_opt_level
);
module
->
Add
(
lv
,
Downcast
<
Function
>
(
e
),
true
);
}
return
FuseMutator
().
Transform
(
expr
,
fuse_opt_level
);
}
}
}
namespace
transform
{
namespace
transform
{
...
...
tests/python/relay/test_pass_fuse_ops.py
View file @
ca352770
...
@@ -541,6 +541,18 @@ def test_immutable():
...
@@ -541,6 +541,18 @@ def test_immutable():
assert
relay
.
analysis
.
alpha_equal
(
new_mod
,
expected
())
assert
relay
.
analysis
.
alpha_equal
(
new_mod
,
expected
())
def
test_split
():
"""Test that the result is well formed."""
x
=
relay
.
var
(
"x"
,
shape
=
(
6
,
9
))
y
=
relay
.
split
(
x
,
3
)
.
astuple
()
a
=
relay
.
TupleGetItem
(
y
,
0
)
b
=
relay
.
TupleGetItem
(
y
,
1
)
c
=
relay
.
TupleGetItem
(
y
,
2
)
mod
=
relay
.
module
.
Module
()
mod
[
"main"
]
=
relay
.
Function
([
x
],
a
+
relay
.
RefRead
(
relay
.
RefCreate
(
b
))
+
c
)
mod
=
transform
.
FuseOps
()(
mod
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_fuse_simple
()
test_fuse_simple
()
test_conv2d_fuse
()
test_conv2d_fuse
()
...
@@ -555,3 +567,4 @@ if __name__ == "__main__":
...
@@ -555,3 +567,4 @@ if __name__ == "__main__":
test_inception_like
()
test_inception_like
()
test_fuse_parallel_injective
()
test_fuse_parallel_injective
()
test_immutable
()
test_immutable
()
test_split
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment