Commit 1ed28aeb, authored Aug 09, 2018 by masahi, committed by Tianqi Chen, Aug 08, 2018
[NNVM] Enhance operator fusion for more element wise patterns (#1548)
parent 0241fdc5
Showing 4 changed files with 152 additions and 8 deletions.
nnvm/src/compiler/graph_fuse.cc              +97 −0
nnvm/tests/python/compiler/test_op_fusion.py +43 −1
topi/python/topi/arm_cpu/conv2d.py           +1 −4
topi/python/topi/util.py                     +11 −3
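For context, the fusion pattern this commit enables is a parent op whose output fans out to several element-wise branches that later meet at a single element-wise op. The exponential linear unit after a conv2d is one such case. A minimal NNVM snippet of that shape, mirroring the new test added below (the channel count is arbitrary):

    import nnvm.symbol as sym

    data = sym.Variable("data")
    conv = sym.conv2d(data=data, kernel_size=(3, 3), channels=16, padding=(1, 1))
    # ELU decomposed into element-wise ops: the exp branch and the two relu
    # branches fan out from conv and meet again at the final add.
    out = -0.5 * sym.relu(1 - sym.exp(conv)) + sym.relu(conv)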
nnvm/src/compiler/graph_fuse.cc
@@ -161,6 +161,103 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
      }
    }
  }

  /*
   The above algorithm will not fuse a node whose output is fed to more than one
   child node. This is because in general, it does not make sense to fuse multiple
   children branches with their parent, as in the following example.

            conv2d
            /  |  \
           /   |   \
         op    op    op
          |    |     |
          |    |     |

   However, when all children branches meet at a certain node, there is a possibility
   for further operator fusion. For example, all nodes in the following subgraph can
   be fused into a single node, if the three 'in-between' nodes and the bottom node
   are all element-wise operations.

            conv2d
            /  |  \
           /   |   \
         op    op    op
          \    |    /
           \   |   /
         elemwise add
               |

   This pattern is not uncommon. For example, it arises when a conv2d op is followed
   by an exponential linear unit. If bias add and batch normalization are also
   present, they can be fused as well.

   In fact, the above fusion algorithm already fuses the three in-between nodes and
   the element-wise add node in the figure above. The following code fuses the conv2d
   node with the already fused children nodes. The following patterns are supported:

   * Any number of child nodes from the top node
   * The path from the top node to the bottom node can contain any number of
     element-wise ops.

   The only restriction is that in-between nodes cannot have more than one child.

   The overview of the algorithm below is as follows:

   1. Check if all children nodes are fused into a single op by the existing fusion
      algorithm
   2. Fuse the parent node to the children nodes, and update its group id to be the
      children's group id
   3. If the parent node originally belongs to another group (for example,
      conv + batch norm), propagate the new group id to a grandparent and upward
  */
  if (opt_level >= 1) {
    std::vector<std::vector<uint32_t> > children_group_ids(idx.num_nodes());
    std::vector<std::vector<uint32_t> > node_ids_per_group(idx.num_nodes());
    for (uint32_t nid = idx.num_nodes() - 1; nid != 0; --nid) {
      const auto& inode = idx[nid];
      if (inode.source->is_variable()) continue;
      CHECK_NE(group_vec[nid], -1);
      node_ids_per_group[group_vec[nid]].push_back(nid);
      if (inode.inputs.size() != 1) continue;
      const uint32_t parent_nid = inode.inputs[0].node_id;
      // if parent node has more than one child, record each child's group id.
      if (ref_count[parent_nid] > 1) children_group_ids[parent_nid].push_back(group_vec[nid]);
    }

    std::vector<int> new_group_id(idx.num_nodes(), -1);
    for (uint32_t nid = idx.num_nodes() - 1; nid != 0; --nid) {
      if (new_group_id[group_vec[nid]] != -1) {
        // propagate new group id from child
        group_vec[nid] = new_group_id[group_vec[nid]];
      }
      TOpPattern pt = op_pattern.get(idx[nid].source->op(), kOpaque);
      if (pt == kOpaque) continue;
      const auto& group_ids = children_group_ids[nid];
      if (group_ids.size() <= 1) continue;
      const uint32_t child_group_id = group_ids[0];
      const auto& children_node_ids = node_ids_per_group[child_group_id];

      auto is_same_group_id = [child_group_id](uint32_t id) {
        return id == child_group_id;
      };
      auto is_fusible_pattern = [&idx](uint32_t child_nid) {
        TOpPattern child_pt = op_pattern.get(idx[child_nid].source->op(), kOpaque);
        return child_pt <= kBroadcast;
      };
      // fuse this node with children if
      // all children belong to the same group and
      // all nodes in the group are element wise or broadcast op.
      const bool can_be_fused =
          std::all_of(group_ids.begin(), group_ids.end(), is_same_group_id) &&
          std::all_of(children_node_ids.begin(), children_node_ids.end(), is_fusible_pattern);

      if (can_be_fused) {
        new_group_id[group_vec[nid]] = child_group_id;
        group_vec[nid] = child_group_id;
        for (uint32_t nid2 : node_ids_per_group[child_group_id]) {
          pattern_vec[nid2] = pattern_vec[nid];
          master_vec[nid2] = master_vec[nid];
        }
      }
    }
  }

  g.attrs["group_root"] = std::make_shared<any>(std::move(group_vec));
  g.attrs["group_master"] = std::make_shared<any>(std::move(master_vec));
  g.attrs["pattern"] = std::make_shared<any>(std::move(pattern_vec));
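To make the group bookkeeping above concrete, here is a minimal standalone sketch of the same diamond check in plain Python. The names, the dict-based graph encoding, and the integer pattern levels are illustrative only, not the NNVM data structures, and step 3 (propagating the new group id upward, e.g. through a preceding batch_norm) is omitted:

    # Sketch only: illustrative encodings, not the NNVM graph structures.
    def fuse_diamond(parents, group, pattern, K_BROADCAST=1, K_OPAQUE=9):
        """parents: child node id -> parent node id
        group: node id -> fusion group id
        pattern: node id -> op pattern level (smaller = more fusible)"""
        children_group_ids = {}
        for child, parent in parents.items():
            children_group_ids.setdefault(parent, []).append(group[child])
        for parent, gids in children_group_ids.items():
            # Step 1: multiple children, all fused into one group whose members
            # are all element-wise/broadcast ops, under a fusible parent.
            if len(gids) > 1 and len(set(gids)) == 1 and pattern[parent] != K_OPAQUE:
                members = [n for n, g in group.items() if g == gids[0]]
                if all(pattern[n] <= K_BROADCAST for n in members):
                    # Step 2: pull the parent into its children's group.
                    group[parent] = gids[0]
        return group

    # Diamond: node 0 (conv2d-like) feeds nodes 1-3 (element-wise), meeting at node 4.
    parents = {1: 0, 2: 0, 3: 0}
    group = {0: 0, 1: 1, 2: 1, 3: 1, 4: 1}    # children already share group 1
    pattern = {0: 4, 1: 0, 2: 0, 3: 0, 4: 0}  # 4 ~ conv-like, 0 ~ element-wise
    print(fuse_diamond(parents, group, pattern))  # node 0 joins group 1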
nnvm/tests/python/compiler/test_op_fusion.py
@@ -5,7 +5,7 @@ import topi.testing
 from tvm.contrib import graph_runtime
 from nnvm import symbol as sym
 from nnvm.compiler import graph_util, graph_attr
-from nnvm.testing import ctx_list
+from nnvm.testing import ctx_list, utils

 def test_ewise_injective():
     x = sym.Variable("x")
@@ -77,7 +77,49 @@ def test_injective_reduce_injective():
        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)


def build_and_run(sym, params, data, out_shape, target, ctx, opt_level=2):
    with nnvm.compiler.build_config(opt_level=opt_level):
        graph, lib, params = nnvm.compiler.build(sym, target, shape={"data": data.shape},
                                                 params=params)

    module = graph_runtime.create(graph, lib, ctx)
    module.set_input(**params)
    module.set_input("data", data)
    module.run()
    out = module.get_output(0, tvm.nd.empty(out_shape))
    return out.asnumpy(), graph


def test_fuse_conv2d_elu():
    def elu(data):
        return -0.5 * sym.relu(1 - sym.exp(data)) + sym.relu(data)

    def get_sym(out_channel):
        data = sym.Variable(name="data")
        data = sym.conv2d(data=data, kernel_size=(3, 3), channels=out_channel,
                          padding=(1, 1), layout="NCHW", kernel_layout="OIHW",
                          use_bias=True)
        data = sym.batch_norm(data)
        data = elu(data)
        return data

    in_channel = 8
    out_channel = 16
    size = 64
    dshape = (1, in_channel, size, size)
    oshape = (1, out_channel, size, size)
    data = np.random.uniform(-1, 1, dshape).astype(np.float32)

    for target, ctx in ctx_list():
        sym1 = get_sym(out_channel)
        sym2 = get_sym(out_channel)
        _, params1 = utils.create_workload(sym1, 1, dshape[1:], seed=0)
        _, params2 = utils.create_workload(sym2, 1, dshape[1:], seed=0)
        output1, g1 = build_and_run(sym1, params1, data, oshape, target, ctx, opt_level=2)
        output2, g2 = build_and_run(sym2, params2, data, oshape, target, ctx, opt_level=0)
        np.testing.assert_allclose(output1, output2, rtol=1e-5, atol=1e-5)
        # data, conv weight, bias, batch norm gamma, batch norm beta, conv op
        assert g1.index.num_nodes == 6


if __name__ == "__main__":
    test_injective_reduce_injective()
    test_ewise_injective()
    test_conv_ewise_injective()
    test_fuse_conv2d_elu()
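With TVM and NNVM built and importable, running `python nnvm/tests/python/compiler/test_op_fusion.py` executes all four tests above. `test_fuse_conv2d_elu` compares an opt_level=2 build against an unfused opt_level=0 build, and the `num_nodes == 6` assertion holds only if the conv2d + batch_norm + ELU subgraph collapses into a single fused op alongside the five variable nodes.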
topi/python/topi/arm_cpu/conv2d.py
@@ -39,11 +39,10 @@ def decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype):
 def schedule_conv2d_nchw_arm_cpu(cfg, outs):
     """TOPI schedule callback"""
     s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []

     def _callback(op):
         # schedule conv2d
-        if 'spatial_conv_output' in op.tag:
+        if 'spatial_conv_output' in op.tag and op not in scheduled_ops:
             output = op.output(0)
             conv = op.input_tensors[0]
@@ -65,8 +64,6 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs):
             output = op.output(0)
             _schedule_winograd(cfg, s, output, outs[0])

+        scheduled_ops.append(op)

     traverse_inline(s, outs[0].op, _callback)
     return s
topi/python/topi/util.py
@@ -5,26 +5,34 @@ import tvm
 from . import tag

-def traverse_inline(s, op, callback):
+def traverse_inline(s, final_op, callback):
     """Traverse computation graph and do auto inline

     Parameters
     ----------
     s: schedule
         The schedule
-    op: Operation
+    final_op: Operation
         The final output operator.
     callback: callable
         The callback function on each op
     """
+    visited = set()
+
+    def _traverse(op):
+        if op in visited:
+            return
+        visited.add(op)
         if tag.is_injective(op.tag):
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
                 if tensor.op.input_tensors:
-                    traverse_inline(s, tensor.op, callback)
+                    _traverse(tensor.op)
         callback(op)
+
+    _traverse(final_op)
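As a usage sketch of the updated helper: a TOPI-style schedule function built on traverse_inline. The tag 'my_conv_output', the function name, and the schedule body are hypothetical; only traverse_inline itself comes from this file:

    import tvm
    from topi.util import traverse_inline

    def schedule_my_conv(outs):
        s = tvm.create_schedule([x.op for x in outs])

        def _callback(op):
            # traverse_inline now visits every op reachable from outs[0].op at
            # most once (the new visited set), inlining injective ops on the way.
            if 'my_conv_output' in op.tag:
                conv = op.output(0)
                # ... apply tiling / vectorization to s[conv] here ...

        traverse_inline(s, outs[0].op, _callback)
        return s

The visited set matters with the enhanced fusion: once a parent's output fans out to several branches, the same op can be reached through more than one consumer, and the old self-recursive version would have invoked the callback repeatedly.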
def prod(x):
    """Get the product of every item in the tuple.
...