Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
00d509d4
Commit
00d509d4
authored
Dec 13, 2018
by
Alexey Romanov
Committed by
Tianqi Chen
Dec 13, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[FRONTEND][TENSORFLOW] Support Unstack and Split (#2105)
parent
4bbf96e4
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
136 additions
and
63 deletions
+136
-63
nnvm/python/nnvm/frontend/tensorflow.py
+88
-18
nnvm/tests/python/frontend/tensorflow/test_forward.py
+48
-45
No files found.
nnvm/python/nnvm/frontend/tensorflow.py
View file @
00d509d4
...
...
@@ -36,6 +36,7 @@ class AttrCvt(object):
self
.
_ignores
.
append
(
'_node_name'
)
self
.
_ignores
.
append
(
'is_training'
)
self
.
_ignores
.
append
(
'_target_layout'
)
self
.
_ignores
.
append
(
'_input_0d_mismatch'
)
# Retain the names
try
:
attrs
[
'name'
]
=
attrs
[
'_node_name'
]
...
...
@@ -319,8 +320,7 @@ def _expand_dims():
dim_input
=
inputs
.
pop
(
1
)
axis
=
params
[
dim_input
.
list_output_names
()[
0
]]
params
.
pop
(
dim_input
.
list_output_names
()[
0
])
return
AttrCvt
(
op_name
=
"expand_dims"
,
ignores
=
[
'Tdim'
],
extras
=
{
'axis'
:
axis
.
asnumpy
()[
0
]})(
inputs
,
attr
)
return
_expand_dims_0d_aware
(
inputs
[
0
],
attr
,
axis
=
axis
.
asnumpy
()[
0
])
return
_impl
def
_resize_bilinear
():
...
...
@@ -383,7 +383,7 @@ def _concat():
def
_pack
():
def
_impl
(
inputs
,
attr
,
params
):
axis
=
int
(
attr
[
"axis"
])
inputs_reshaped
=
[
_
sym
.
expand_dims
(
i
,
axis
=
axis
,
num_newaxis
=
1
)
for
i
in
inputs
]
inputs_reshaped
=
[
_
expand_dims_0d_aware
(
i
,
attr
,
axis
=
axis
,
num_newaxis
=
1
)
for
i
in
inputs
]
return
_sym
.
concatenate
(
*
inputs_reshaped
,
axis
=
axis
,
name
=
attr
[
"_node_name"
])
return
_impl
...
...
@@ -787,15 +787,64 @@ def _broadcast(name):
)(
inputs
,
attr
)
return
_impl
def
_split
():
def
_split
(
has_size_vector
):
# TF documentation https://www.tensorflow.org/api_docs/python/tf/split
def
_impl
(
inputs
,
attr
,
params
):
axis
=
params
.
pop
(
inputs
[
0
]
.
list_output_names
()[
0
])
return
AttrCvt
(
op_name
=
"split"
,
ignores
=
[
'T'
],
transforms
=
{
'num_split'
:
'indices_or_sections'
},
extras
=
{
'axis'
:
axis
.
asnumpy
()[
0
]})(
inputs
[
1
],
attr
)
try
:
# order and number of inputs are different:
# if has_size_vector:
# https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split-v
# else:
# https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split
# in addition, `axis` and `num_or_size_splits` can be tensors in TensorFlow,
# we can only support constants
if
has_size_vector
:
input_node_index
=
0
input_axis_index
=
2
size_splits_input_name
=
inputs
[
1
]
.
list_output_names
()[
0
]
size_splits
=
params
[
size_splits_input_name
]
.
asnumpy
()
section_beginnings
=
np
.
cumsum
(
size_splits
)[:
-
1
]
indices_or_sections
=
tuple
(
section_beginnings
)
else
:
input_node_index
=
1
input_axis_index
=
0
indices_or_sections
=
attr
[
'num_split'
]
input_node
=
inputs
[
input_node_index
]
axis_input_name
=
inputs
[
input_axis_index
]
.
list_output_names
()[
0
]
axis_input_value
=
params
[
axis_input_name
]
.
asnumpy
()[
0
]
except
(
IndexError
,
KeyError
):
raise
TypeError
(
\
"Unsupported argument for split: `axis` and `num_or_size_splits` "
\
"should be constants"
)
return
_sym
.
split
(
input_node
,
indices_or_sections
=
indices_or_sections
,
axis
=
axis_input_value
)
return
_impl
def
_unpack
():
def
_impl
(
inputs
,
attr
,
params
):
input_node
=
inputs
[
0
]
axis
=
attr
[
'axis'
]
input_shape
=
attr
[
'_input_shapes'
][
input_node
][
0
]
axis_length
=
input_shape
[
axis
]
if
axis_length
<
0
:
raise
TypeError
(
"Unstack with unknown axis length"
)
splitted
=
_sym
.
split
(
input_node
,
indices_or_sections
=
axis_length
,
axis
=
axis
,
name
=
attr
.
get
(
'_node_name'
,
'unstack'
))
return
_sym
.
Group
([
_sym
.
squeeze
(
split_item
,
axis
=
axis
)
for
split_item
in
splitted
])
return
_impl
def
_expand_dims_0d_aware
(
data
,
attr
,
axis
,
num_newaxis
=
1
):
if
data
in
attr
[
'_input_0d_mismatch'
]:
return
data
if
num_newaxis
==
1
else
\
_sym
.
expand_dims
(
data
,
axis
=
axis
,
num_newaxis
=
num_newaxis
-
1
)
return
_sym
.
expand_dims
(
data
,
axis
=
axis
,
num_newaxis
=
num_newaxis
)
# compatible operators that do NOT require any conversion.
_identity_list
=
[]
...
...
@@ -863,7 +912,9 @@ _convert_map = {
'GreaterEqual'
:
_broadcast
(
'greater_equal'
),
'Equal'
:
_broadcast
(
'equal'
),
'NotEqual'
:
_broadcast
(
'not_equal'
),
'Split'
:
_split
(),
'Split'
:
_split
(
False
),
'SplitV'
:
_split
(
True
),
'Unpack'
:
_unpack
(),
}
# _convert_map_rnn defines maps of rnn operator name to
...
...
@@ -1059,6 +1110,7 @@ class GraphProto(object):
self
.
_output_shapes
=
{}
self
.
_num_param
=
0
self
.
_num_rnn_layer
=
False
self
.
_outputs_are_0d
=
{}
def
from_tensorflow
(
self
,
graph
,
layout
=
"NHWC"
,
shape
=
None
,
outputs
=
None
):
"""Construct nnvm nodes from tensorflow graph definition - GraphDef.
...
...
@@ -1114,6 +1166,7 @@ class GraphProto(object):
# Operator name 'Const' is treated as a parameter to build NNVM params dict.
input_shapes
=
{}
input_0d_mismatch
=
set
()
attr
=
self
.
_parse_attr
(
node
.
attr
)
#Variable converted to Const will not have only value attr
...
...
@@ -1133,6 +1186,9 @@ class GraphProto(object):
else
:
raise
NotImplementedError
(
\
"Please freeze the graph with add_shapes=True"
)
self
.
_outputs_are_0d
[
node
.
name
]
=
[
\
not
shape
if
isinstance
(
shape
,
list
)
else
False
\
for
shape
in
self
.
_output_shapes
[
node
.
name
]]
if
node
.
op
==
"Placeholder"
:
self
.
_nodes
[
node
.
name
]
=
_sym
.
Variable
(
name
=
node
.
name
,
...
...
@@ -1162,11 +1218,13 @@ class GraphProto(object):
# Fill shapes for all inputs in a list
inputs
=
[]
for
i
in
node
.
input
:
#ToDo: Some of the tensorflow operators internaly maintain
#execution layers and its output name will the layer number along with
#graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
#output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,
#the digit has to be ignored.
# Some TensorFlow operators internally maintain execution layers
# and their output name includes the layer number along with
# graph node name. E.g. the node name is 'Model/RNN/cell_0/RnnCell', but the
# output tensor name is 'Model/RNN/cell_0/RnnCell:0'. In this case,
# the number has to be ignored for single-output nodes.
# On the other hand, for multi-output nodes the number is the output index,
# and the lack of the number implies 0.
tensor_name
=
i
.
split
(
':'
)
node_name
=
tensor_name
[
0
]
if
node_name
in
self
.
_nodes
:
...
...
@@ -1174,12 +1232,18 @@ class GraphProto(object):
if
len
(
in_sym
.
list_output_names
())
>
1
:
tensor_slot
=
int
(
tensor_name
[
1
])
if
len
(
tensor_name
)
>
1
else
0
in_sym
=
in_sym
[
tensor_slot
]
input_shape
=
(
self
.
_output_shapes
[
node_name
])
[
tensor_slot
]
input_shape
=
self
.
_output_shapes
[
node_name
]
[
tensor_slot
]
else
:
tensor_slot
=
0
input_shape
=
self
.
_output_shapes
[
node_name
][
0
]
inputs
.
append
(
in_sym
)
input_shapes
[
in_sym
]
=
[
input_shape
]
# This means the node is 1d in NNVM and 0d in TF.
# See `_expand_dims_0d_aware`.
if
self
.
_outputs_are_0d
[
node_name
][
tensor_slot
]
and
input_shape
:
input_0d_mismatch
.
add
(
in_sym
)
attr
[
'_input_shapes'
]
=
input_shapes
attr
[
'_input_0d_mismatch'
]
=
input_0d_mismatch
inputs
=
self
.
_fix_extranodes
(
node
.
op
,
attr
,
inputs
)
op
=
self
.
_convert_operator
(
node
.
op
,
inputs
,
attr
,
graph
)
...
...
@@ -1207,7 +1271,13 @@ class GraphProto(object):
if
outputs
is
None
:
out
.
append
(
final_op
)
else
:
out
=
[
self
.
_nodes
[
out_name
]
for
out_name
in
outputs
]
for
out_name
in
outputs
:
if
":"
in
out_name
:
out_name
,
out_num
=
out_name
.
split
(
":"
)
out_num
=
int
(
out_num
)
out
.
append
(
self
.
_nodes
[
out_name
][
out_num
])
else
:
out
.
append
(
self
.
_nodes
[
out_name
])
#Add the RNN outputs also with 'head' nodes of the nnvm graph
if
self
.
_num_rnn_layer
:
...
...
@@ -1215,7 +1285,7 @@ class GraphProto(object):
out
.
append
(
out_rnn
)
if
isinstance
(
out
,
list
):
out
=
_sym
.
Group
(
out
)
out
=
_sym
.
Group
(
out
)
if
len
(
out
)
>
1
else
out
[
0
]
return
out
,
self
.
_params
...
...
nnvm/tests/python/frontend/tensorflow/test_forward.py
View file @
00d509d4
...
...
@@ -124,7 +124,8 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
if
no_gpu
and
device
==
'cuda'
:
continue
tvm_output
=
run_tvm_graph
(
final_graph_def
,
in_data
,
in_node
,
target
=
device
)
tvm_output
=
run_tvm_graph
(
final_graph_def
,
in_data
,
in_node
,
num_output
=
len
(
out_node
),
target
=
device
,
out_names
=
out_name
)
# since the names from tensorflow and nnvm runs are not exactly same,
# first len(tf_output) will be compared
for
i
in
range
(
len
(
tf_output
)):
...
...
@@ -506,14 +507,24 @@ def test_forward_gather():
# Split
# -----
def
_test_split
(
in_shape
,
axis
,
num_split
,
dtype
):
def
_test_split
(
in_shape
,
axis
,
num_or_size_splits
,
dtype
):
np_data
=
np
.
random
.
uniform
(
-
5
,
5
,
size
=
in_shape
)
.
astype
(
dtype
)
""" One iteration of a Split """
tf
.
reset_default_graph
()
in_data
=
tf
.
placeholder
(
dtype
,
in_shape
,
name
=
"in_data"
)
num_split
=
len
(
num_or_size_splits
)
if
isinstance
(
num_or_size_splits
,
list
)
else
num_or_size_splits
tf
.
split
(
in_data
,
num_or_size_splits
,
axis
=
axis
)
with
tf
.
Graph
()
.
as_default
():
in_data
=
tf
.
placeholder
(
dtype
,
in_shape
,
name
=
"in_data"
)
tf
.
split
(
in_data
,
num_split
,
axis
)
np_data
=
np
.
random
.
uniform
(
size
=
in_shape
)
.
astype
(
dtype
)
compare_tf_with_tvm
(
np_data
,
'in_data:0'
,
'split:0'
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
[
f
'split:{n}'
for
n
in
range
(
num_split
)])
# and now test together with concat
tf
.
reset_default_graph
()
in_data
=
tf
.
placeholder
(
dtype
,
in_shape
,
name
=
"in_data"
)
splitted
=
tf
.
split
(
in_data
,
num_or_size_splits
,
axis
=
axis
)
tf
.
concat
(
splitted
,
axis
)
compare_tf_with_tvm
([
np_data
],
'in_data:0'
,
'concat:0'
)
def
test_forward_split
():
'''test split layer'''
...
...
@@ -523,11 +534,11 @@ def test_forward_split():
_test_split
((
6
,),
0
,
3
,
'float32'
)
# rank 2
_test_split
((
6
,
2
),
0
,
3
,
'float32'
)
_test_split
((
2
,
6
),
1
,
3
,
'float32'
)
_test_split
((
2
,
6
),
1
,
6
,
'float32'
)
# rank 3
_test_split
((
6
,
2
,
4
),
0
,
3
,
'floa
t32'
)
_test_split
((
6
,
2
,
4
),
0
,
2
,
'in
t32'
)
_test_split
((
2
,
6
,
4
),
1
,
3
,
'float32'
)
_test_split
((
2
,
4
,
6
),
2
,
3
,
'float32'
)
_test_split
((
2
,
4
,
6
),
2
,
1
,
'float32'
)
# rank 4
_test_split
((
6
,
1
,
3
,
5
),
0
,
3
,
'float32'
)
_test_split
((
1
,
6
,
3
,
5
),
1
,
3
,
'float32'
)
...
...
@@ -538,45 +549,37 @@ def test_forward_split():
_test_split
((
1
,
6
,
3
,
5
),
-
3
,
3
,
'float32'
)
_test_split
((
1
,
3
,
6
,
5
),
-
2
,
3
,
'float32'
)
_test_split
((
1
,
3
,
5
,
6
),
-
1
,
3
,
'float32'
)
# size_splits list
_test_split
((
6
,),
0
,
[
1
,
2
,
3
],
'int32'
)
_test_split
((
3
,
6
,
4
),
-
2
,
[
1
,
4
,
1
],
'float32'
)
#######################################################################
#
Split followed by concat
# -------
-----------------
#
Unstack
# -------
def
_test_
split_concat
(
in_shape
,
axis
,
num_split
,
dtype
):
""" One iteration of a split_concat pair"""
def
_test_
unstack
(
ip_shape
,
axis
,
dtype
):
np_data
=
np
.
random
.
uniform
(
-
5
,
5
,
size
=
ip_shape
)
.
astype
(
dtype
)
with
tf
.
Graph
()
.
as_default
():
in_data
=
tf
.
placeholder
(
dtype
,
in_shape
,
name
=
"in_data"
)
splitted
=
tf
.
split
(
in_data
,
num_split
,
axis
)
tf
.
concat
(
splitted
,
axis
)
np_data
=
np
.
random
.
uniform
(
size
=
in_shape
)
.
astype
(
dtype
)
compare_tf_with_tvm
(
np_data
,
'in_data:0'
,
'concat:0'
)
def
test_forward_split_concat
():
'''test split followed by concat layers'''
# rank 1
_test_split_concat
((
3
,),
0
,
1
,
'float32'
)
_test_split_concat
((
3
,),
0
,
3
,
'float32'
)
_test_split_concat
((
6
,),
0
,
3
,
'float32'
)
# rank 2
_test_split_concat
((
6
,
2
),
0
,
3
,
'float32'
)
_test_split_concat
((
2
,
6
),
1
,
3
,
'float32'
)
# rank 3
_test_split_concat
((
6
,
2
,
4
),
0
,
3
,
'float32'
)
_test_split_concat
((
2
,
6
,
4
),
1
,
3
,
'float32'
)
_test_split_concat
((
2
,
4
,
6
),
2
,
3
,
'float32'
)
# rank 4
_test_split
((
6
,
1
,
3
,
5
),
0
,
3
,
'float32'
)
_test_split
((
1
,
6
,
3
,
5
),
1
,
3
,
'float32'
)
_test_split
((
1
,
3
,
6
,
5
),
2
,
3
,
'float32'
)
_test_split
((
1
,
3
,
5
,
6
),
3
,
3
,
'float32'
)
# split along negative axis
_test_split
((
6
,
1
,
3
,
5
),
-
4
,
3
,
'float32'
)
_test_split
((
1
,
6
,
3
,
5
),
-
3
,
3
,
'float32'
)
_test_split
((
1
,
3
,
6
,
5
),
-
2
,
3
,
'float32'
)
_test_split
((
1
,
3
,
5
,
6
),
-
1
,
3
,
'float32'
)
tf
.
reset_default_graph
()
in_data
=
tf
.
placeholder
(
dtype
,
ip_shape
,
name
=
"in_data"
)
tf
.
unstack
(
in_data
,
axis
=
axis
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
[
f
'unstack:{n}'
for
n
in
range
(
ip_shape
[
axis
])])
tf
.
reset_default_graph
()
in_data
=
tf
.
placeholder
(
dtype
,
ip_shape
,
name
=
"in_data"
)
tf
.
stack
(
tf
.
unstack
(
in_data
,
axis
=
axis
),
axis
=
axis
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
'stack:0'
)
def
test_forward_unstack
():
'''test unstack layer'''
_test_unstack
((
6
,),
0
,
'int32'
)
_test_unstack
((
2
,
6
),
1
,
'float64'
)
# negative axis
_test_unstack
((
1
,
4
),
-
1
,
'int32'
)
_test_unstack
((
3
,
6
,
4
),
-
2
,
'float32'
)
#######################################################################
...
...
@@ -1139,7 +1142,7 @@ if __name__ == '__main__':
test_forward_gather
()
test_forward_stridedslice
()
test_forward_split
()
test_forward_
split_concat
()
test_forward_
unstack
()
# Activations
test_forward_sigmoid
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment