Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
ca352770
Commit
ca352770
authored
Sep 05, 2019
by
雾雨魔理沙
Committed by
masahi
Sep 06, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] Fix operator fusion for multiple output (#3871)
* save * add test * refactor * fix indent * save * refactor
parent
57cd27f1
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
83 additions
and
77 deletions
+83
-77
src/relay/ir/pretty_printer.cc
+11
-7
src/relay/pass/fuse_ops.cc
+59
-70
tests/python/relay/test_pass_fuse_ops.py
+13
-0
No files found.
src/relay/ir/pretty_printer.cc
View file @
ca352770
...
...
@@ -304,14 +304,16 @@ class PrettyPrinter :
* \return The corresponding name.
*/
Doc
AllocTypeVar
(
const
TypeVar
&
var
)
{
if
(
memo_type_
.
count
(
var
))
{
Doc
val
=
memo_type_
[
var
];
val
<<
"-malformed-ir"
;
return
val
;
}
std
::
string
name
=
var
->
var
->
name_hint
;
if
(
name
.
length
()
==
0
||
!
std
::
isalpha
(
name
[
0
]))
{
name
=
"t"
+
name
;
}
Doc
val
=
GetUniqueName
(
"%"
+
name
);
if
(
memo_type_
.
count
(
var
))
{
val
<<
"-malformed-ir"
;
}
memo_type_
[
var
]
=
val
;
if
(
var
->
kind
!=
kType
)
{
val
<<
": "
<<
Print
(
var
->
kind
);
...
...
@@ -325,16 +327,18 @@ class PrettyPrinter :
* \return The corresponding name.
*/
Doc
AllocVar
(
const
Var
&
var
)
{
// still print if ir is malformed, but show the error.
if
(
memo_
.
count
(
var
))
{
Doc
val
=
memo_
[
var
];
val
<<
"-malformed-ir"
;
return
val
;
}
std
::
string
name
=
var
->
name_hint
();
// always make sure first name is alpha
if
(
name
.
length
()
==
0
||
!
std
::
isalpha
(
name
[
0
]))
{
name
=
"v"
+
name
;
}
Doc
val
=
GetUniqueName
(
"%"
+
name
);
// still print if ir is malformed, but show the error.
if
(
memo_
.
count
(
var
))
{
val
<<
"-malformed-ir"
;
}
memo_
[
var
]
=
val
;
if
(
var
->
type_annotation
.
defined
())
{
val
<<
": "
<<
Print
(
var
->
type_annotation
);
...
...
src/relay/pass/fuse_ops.cc
View file @
ca352770
...
...
@@ -18,7 +18,7 @@
*/
/*!
* Copyright (c) 201
8
by Contributors
* Copyright (c) 201
9
by Contributors
*
* \file src/tvm/relay/pass/fuse_ops.cc
*
...
...
@@ -247,11 +247,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
node
->
pattern
=
op_pattern
;
this
->
Update
(
call
->
op
,
nullptr
,
kOpaque
);
const
auto
*
rtype
=
call
->
checked_type
().
as
<
TensorTypeNode
>
();
// pass the
message
back to all the children it references.
// pass the
analysis
back to all the children it references.
for
(
size_t
i
=
0
;
i
<
call
->
args
.
size
();
++
i
)
{
const
auto
*
arg_type
=
call
->
args
[
i
]
->
checked_type
().
as
<
TensorTypeNode
>
();
// specifically check if result type
// specifically check if result type
is the same as arguments type
OpPatternKind
edge_pattern
=
op_pattern
;
if
(
edge_pattern
==
kBroadcast
&&
arg_type
!=
nullptr
&&
...
...
@@ -403,12 +403,12 @@ class DominatorTree {
return
rhs
;
}
/*!
* \brief Find the least common a
cen
stor of the two nodes.
* \brief Find the least common a
nce
stor of the two nodes.
* \param lhs The left node.
* \param rhs The right node.
* \param edge_pattern
* The combined edge pattern across all the parents.
* \return The least common ancestor of th
w
two.
* \return The least common ancestor of th
e
two.
*/
static
Node
*
LeastCommonAncestor
(
Node
*
lhs
,
...
...
@@ -436,17 +436,43 @@ class DominatorTree {
}
return
lhs
;
}
};
DominatorTree
DominatorTree
::
PostDom
(
common
::
Arena
*
arena
,
const
IndexedForwardGraph
&
graph
)
{
DominatorTree
tree
;
tree
.
nodes
.
resize
(
graph
.
post_dfs_order
.
size
(),
nullptr
);
// reverse topo order
for
(
size_t
i
=
graph
.
post_dfs_order
.
size
();
i
!=
0
;
--
i
)
{
size_t
index
=
i
-
1
;
/*!
* \brief Find the least common ancestor of a list of nodes.
* \param nodes the nodes.
* \param edge_pattern
* The combined edge pattern across all the parents.
* \return The least common ancestor of all nodes.
*/
Node
*
LeastCommonAncestor
(
const
LinkedList
<
IndexedForwardGraph
::
Edge
>&
input_nodes
,
OpPatternKind
*
edge_pattern
)
{
auto
link
=
input_nodes
.
head
;
if
(
link
==
nullptr
)
{
return
nullptr
;
}
auto
get_node
=
[
&
](
const
IndexedForwardGraph
::
Edge
&
edge
)
{
size_t
oindex
=
edge
.
node
->
index
;
CHECK_LT
(
oindex
,
nodes
.
size
());
Node
*
onode
=
nodes
[
oindex
];
CHECK
(
onode
!=
nullptr
);
return
onode
;
};
Node
*
parent
=
get_node
(
link
->
value
);
*
edge_pattern
=
CombinePattern
(
*
edge_pattern
,
link
->
value
.
pattern
);
link
=
link
->
next
;
for
(;
link
!=
nullptr
;
link
=
link
->
next
)
{
parent
=
LeastCommonAncestor
(
parent
,
get_node
(
link
->
value
),
edge_pattern
);
*
edge_pattern
=
CombinePattern
(
*
edge_pattern
,
link
->
value
.
pattern
);
}
return
parent
;
}
/*!
* \brief Convert the Node from an IndexedForwardGraph Node into DomaintorTree Node.
* \param arena The Arena.
* \param gnode An IndexedForwardGraph Node.
* \return The DominatorTree Node.
*/
Node
*
GetNode
(
common
::
Arena
*
arena
,
IndexedForwardGraph
::
Node
*
gnode
)
{
Node
*
tnode
=
arena
->
make
<
Node
>
();
auto
*
gnode
=
graph
.
post_dfs_order
[
index
];
tnode
->
gnode
=
gnode
;
if
(
gnode
->
extern_ref
)
{
tnode
->
depth
=
1
;
...
...
@@ -455,24 +481,24 @@ DominatorTree DominatorTree::PostDom(common::Arena* arena,
}
else
{
// find the LCAs of all outputs.
OpPatternKind
pattern
=
kElemWise
;
Node
*
parent
=
nullptr
;
for
(
auto
link
=
gnode
->
outputs
.
head
;
link
!=
nullptr
;
link
=
link
->
next
)
{
size_t
oindex
=
link
->
value
.
node
->
index
;
CHECK_LT
(
oindex
,
tree
.
nodes
.
size
());
Node
*
onode
=
tree
.
nodes
[
oindex
];
CHECK
(
onode
!=
nullptr
);
if
(
parent
!=
nullptr
)
{
parent
=
LeastCommonAncestor
(
parent
,
onode
,
&
pattern
);
}
else
{
parent
=
onode
;
}
pattern
=
CombinePattern
(
pattern
,
link
->
value
.
pattern
);
}
Node
*
parent
=
LeastCommonAncestor
(
gnode
->
outputs
,
&
pattern
);
tnode
->
depth
=
parent
?
parent
->
depth
+
1
:
1
;
tnode
->
parent
=
parent
;
tnode
->
pattern
=
pattern
;
}
tree
.
nodes
[
index
]
=
tnode
;
return
tnode
;
}
};
DominatorTree
DominatorTree
::
PostDom
(
common
::
Arena
*
arena
,
const
IndexedForwardGraph
&
graph
)
{
DominatorTree
tree
;
tree
.
nodes
.
resize
(
graph
.
post_dfs_order
.
size
(),
nullptr
);
// reverse topo order
for
(
size_t
i
=
graph
.
post_dfs_order
.
size
();
i
!=
0
;
--
i
)
{
size_t
index
=
i
-
1
;
tree
.
nodes
[
index
]
=
tree
.
GetNode
(
arena
,
graph
.
post_dfs_order
[
index
]);
}
return
tree
;
}
...
...
@@ -614,7 +640,7 @@ class GraphPartitioner {
// merge the current group to the parent if possible.
MergeFromTo
(
gnode
,
target
);
for
(
auto
link
=
src
->
outputs
.
head
;
link
!=
nullptr
;
link
=
link
->
next
)
{
CommitFuse_
(
link
->
value
.
node
,
sink
,
target
);
;
CommitFuse_
(
link
->
value
.
node
,
sink
,
target
);
}
}
/*!
...
...
@@ -851,7 +877,7 @@ class FuseMutator : private ExprMutator {
Expr
VisitExpr_
(
const
TupleNode
*
tuple
)
{
auto
*
ret_group
=
gmap_
.
at
(
tuple
)
->
FindRoot
();
if
(
ret_group
==
gmap_
.
at
(
tuple
)
)
{
if
(
ret_group
->
root_ref
==
tuple
)
{
return
ExprMutator
::
VisitExpr_
(
tuple
);
}
// This tuple is an intermediate node in the group
...
...
@@ -863,7 +889,7 @@ class FuseMutator : private ExprMutator {
auto
*
ret_group
=
gmap_
.
at
(
tuple_get
)
->
FindRoot
();
auto
new_tuple
=
GetNewArguments
({
tuple_get
->
tuple
},
ret_group
)[
0
];
auto
new_node
=
TupleGetItemNode
::
make
(
new_tuple
,
tuple_get
->
index
);
if
(
ret_group
==
gmap_
.
at
(
tuple_get
)
)
{
if
(
ret_group
->
root_ref
==
tuple_get
)
{
if
(
gmap_
.
at
(
tuple_get
->
tuple
.
get
())
->
FindRoot
()
!=
ret_group
)
{
// Isolated. This case occurs when tuple is created by an Opaque op
// e.g. multibox_transform_loc
...
...
@@ -922,45 +948,8 @@ class FuseMutator : private ExprMutator {
}
};
// Temporary solution, should be handled by implementing a "FunctionPass"
// which applies fusion to each function.
struct
GlobalVarLiveness
:
ExprVisitor
{
Module
module
;
std
::
set
<
GlobalVar
>
visited
;
explicit
GlobalVarLiveness
(
const
Module
&
mod
)
:
module
(
mod
),
visited
()
{}
void
VisitExpr_
(
const
GlobalVarNode
*
gvar_node
)
{
auto
gvar
=
GetRef
<
GlobalVar
>
(
gvar_node
);
if
(
visited
.
find
(
gvar
)
==
visited
.
end
())
{
visited
.
insert
(
gvar
);
this
->
VisitExpr
(
this
->
module
->
Lookup
(
gvar
));
}
}
};
std
::
set
<
GlobalVar
>
LiveGlobals
(
const
Module
&
mod
,
const
Expr
&
expr
)
{
auto
gvl
=
GlobalVarLiveness
(
mod
);
gvl
.
VisitExpr
(
expr
);
return
gvl
.
visited
;
}
Expr
FuseOps
(
const
Expr
&
expr
,
int
fuse_opt_level
,
const
Module
&
module
)
{
// First we convert all chains of fusable ops into
// abstracted functions which we mark as primtive
// then we convert these primtive functions into
// new operators.
if
(
!
module
.
defined
())
{
return
FuseMutator
().
Transform
(
expr
,
fuse_opt_level
);
}
else
{
auto
lgvs
=
LiveGlobals
(
module
,
expr
);
for
(
auto
lv
:
lgvs
)
{
auto
body
=
module
->
Lookup
(
lv
);
auto
e
=
FuseMutator
().
Transform
(
body
,
fuse_opt_level
);
module
->
Add
(
lv
,
Downcast
<
Function
>
(
e
),
true
);
}
return
FuseMutator
().
Transform
(
expr
,
fuse_opt_level
);
}
return
FuseMutator
().
Transform
(
expr
,
fuse_opt_level
);
}
namespace
transform
{
...
...
tests/python/relay/test_pass_fuse_ops.py
View file @
ca352770
...
...
@@ -541,6 +541,18 @@ def test_immutable():
assert
relay
.
analysis
.
alpha_equal
(
new_mod
,
expected
())
def
test_split
():
"""Test that the result is well formed."""
x
=
relay
.
var
(
"x"
,
shape
=
(
6
,
9
))
y
=
relay
.
split
(
x
,
3
)
.
astuple
()
a
=
relay
.
TupleGetItem
(
y
,
0
)
b
=
relay
.
TupleGetItem
(
y
,
1
)
c
=
relay
.
TupleGetItem
(
y
,
2
)
mod
=
relay
.
module
.
Module
()
mod
[
"main"
]
=
relay
.
Function
([
x
],
a
+
relay
.
RefRead
(
relay
.
RefCreate
(
b
))
+
c
)
mod
=
transform
.
FuseOps
()(
mod
)
if
__name__
==
"__main__"
:
test_fuse_simple
()
test_conv2d_fuse
()
...
...
@@ -555,3 +567,4 @@ if __name__ == "__main__":
test_inception_like
()
test_fuse_parallel_injective
()
test_immutable
()
test_split
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment