Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
7b59b8ef
Commit
7b59b8ef
authored
Aug 04, 2018
by
Siva
Committed by
Tianqi Chen
Aug 04, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[NNVM][TENSORFLOW] Cleanup redundant code. (#1551)
parent
136061dc
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
83 additions
and
157 deletions
+83
-157
nnvm/python/nnvm/frontend/tensorflow.py
+83
-157
No files found.
nnvm/python/nnvm/frontend/tensorflow.py
View file @
7b59b8ef
...
...
@@ -168,81 +168,7 @@ def _pooling(name):
custom_check
=
_dimension_constraint
())(
inputs
,
attr
)
return
_impl
def
_conv
():
def
_impl
(
inputs
,
attr
,
params
):
attr
[
'data_format'
]
=
attr
[
'data_format'
]
.
decode
(
"utf-8"
)
# Extract kernel shape from params
conv_param_weights
=
params
[
inputs
[
1
]
.
list_output_names
()[
0
]]
if
attr
[
'data_format'
]
==
'NHWC'
:
attr
[
'kernel_shape'
]
=
(
conv_param_weights
.
shape
[
0
],
conv_param_weights
.
shape
[
1
])
attr
[
'channels'
]
=
conv_param_weights
.
shape
[
3
]
if
'dilations'
in
attr
:
attr
[
'dilations'
]
=
(
attr
[
'dilations'
][
0
],
attr
[
'dilations'
][
1
])
elif
attr
[
'data_format'
]
==
'NCHW'
:
attr
[
'kernel_shape'
]
=
(
conv_param_weights
.
shape
[
2
],
conv_param_weights
.
shape
[
3
])
attr
[
'channels'
]
=
conv_param_weights
.
shape
[
1
]
if
'dilations'
in
attr
:
attr
[
'dilations'
]
=
(
attr
[
'dilations'
][
2
],
attr
[
'dilations'
][
3
])
else
:
raise
TypeError
(
"Unsupported data format type : {}"
.
format
(
attr
[
'data_format'
]))
# Fix strides
attr
[
'strides'
]
=
(
attr
[
'strides'
][
1
],
attr
[
'strides'
][
2
])
# Fix padding
input_shapes
=
attr
[
'_input_shapes'
][
inputs
[
0
]]
attr
[
'padding'
]
=
attr
[
'padding'
]
.
decode
(
"utf-8"
)
if
attr
[
'padding'
]
==
'VALID'
:
attr
[
'padding'
]
=
[
0
,
0
]
elif
attr
[
'padding'
]
==
'SAME'
:
stride_h
,
stride_w
=
attr
[
'strides'
]
kernel_h
,
kernel_w
=
attr
[
'kernel_shape'
]
if
attr
[
'data_format'
]
==
'NHWC'
:
in_h
=
input_shapes
[
0
][
1
]
in_w
=
input_shapes
[
0
][
2
]
else
:
in_h
=
input_shapes
[
0
][
2
]
in_w
=
input_shapes
[
0
][
3
]
pad_v
=
_get_pad_pair
(
in_h
,
kernel_h
,
stride_h
)
pad_h
=
_get_pad_pair
(
in_w
,
kernel_w
,
stride_w
)
if
attr
[
'data_format'
]
==
'NHWC'
:
inputs
[
0
]
=
_sym
.
pad
(
data
=
inputs
[
0
],
pad_width
=
((
0
,
0
),
(
pad_v
[
0
],
pad_v
[
1
]),
(
pad_h
[
0
],
pad_h
[
1
]),
(
0
,
0
)))
else
:
inputs
[
0
]
=
_sym
.
pad
(
data
=
inputs
[
0
],
pad_width
=
((
0
,
0
),
(
0
,
0
),
(
pad_v
[
0
],
pad_v
[
1
]),
(
pad_h
[
0
],
pad_h
[
1
])))
attr
[
'padding'
]
=
[
0
,
0
]
else
:
raise
TypeError
(
"Unsupported padding type : {}"
.
format
(
attr
[
'padding'
]))
if
'kernel_layout'
not
in
attr
:
attr
[
'kernel_layout'
]
=
'HWIO'
if
attr
[
'data_format'
]
==
'NHWC'
else
'OIHW'
return
AttrCvt
(
op_name
=
_dimension_picker
(
'conv'
),
transforms
=
{
'kernel_shape'
:
'kernel_size'
,
'data_format'
:
'layout'
,
'dilations'
:
(
'dilation'
,
(
0
,
0
)),
'group'
:
(
'groups'
,
1
)},
extras
=
{
'use_bias'
:
len
(
inputs
)
==
3
},
custom_check
=
_dimension_constraint
())(
inputs
,
attr
)
return
_impl
def
_depthwise_conv
():
def
_conv
(
opname
):
def
_impl
(
inputs
,
attr
,
params
):
attr
[
'data_format'
]
=
attr
[
'data_format'
]
.
decode
(
"utf-8"
)
input_shapes
=
attr
[
'_input_shapes'
][
inputs
[
0
]]
...
...
@@ -253,24 +179,33 @@ def _depthwise_conv():
if
attr
[
'data_format'
]
==
'NHWC'
:
kernel_h
,
kernel_w
,
_
,
depth_mult
=
conv_param_weights
.
shape
attr
[
'kernel_shape'
]
=
(
conv_param_weights
.
shape
[
0
],
conv_param_weights
.
shape
[
1
])
attr
[
'channels'
]
=
input_shapes
[
0
][
3
]
*
depth_mult
if
opname
==
'conv'
:
attr
[
'channels'
]
=
conv_param_weights
.
shape
[
3
]
else
:
attr
[
'channels'
]
=
input_shapes
[
0
][
3
]
*
depth_mult
if
'dilations'
in
attr
:
attr
[
'dilations'
]
=
(
attr
[
'dilations'
][
0
],
attr
[
'dilations'
][
1
])
elif
attr
[
'data_format'
]
==
'NCHW'
:
depth_mult
,
_
,
kernel_h
,
kernel_w
=
conv_param_weights
.
shape
attr
[
'kernel_shape'
]
=
(
conv_param_weights
.
shape
[
2
],
conv_param_weights
.
shape
[
3
])
attr
[
'channels'
]
=
input_shapes
[
0
][
1
]
*
depth_mult
if
opname
==
'conv'
:
attr
[
'channels'
]
=
conv_param_weights
.
shape
[
1
]
else
:
attr
[
'channels'
]
=
input_shapes
[
0
][
1
]
*
depth_mult
if
'dilations'
in
attr
:
attr
[
'dilations'
]
=
(
attr
[
'dilations'
][
2
],
attr
[
'dilations'
][
3
])
else
:
raise
TypeError
(
"Unsupported data format type : {}"
.
format
(
attr
[
'data_format'
]))
if
opname
==
'depthwise'
:
attr
[
'groups'
]
=
attr
[
'channels'
]
# Fix strides
attr
[
'strides'
]
=
(
attr
[
'strides'
][
1
],
attr
[
'strides'
][
2
])
# Fix groups
attr
[
'groups'
]
=
attr
[
'channels'
]
# Fix padding
attr
[
'padding'
]
=
attr
[
'padding'
]
.
decode
(
"utf-8"
)
...
...
@@ -308,7 +243,10 @@ def _depthwise_conv():
raise
TypeError
(
"Unsupported padding type : {}"
.
format
(
attr
[
'padding'
]))
if
'kernel_layout'
not
in
attr
:
attr
[
'kernel_layout'
]
=
'HWOI'
if
attr
[
'data_format'
]
==
'NHWC'
else
'OIHW'
if
opname
==
'conv'
:
attr
[
'kernel_layout'
]
=
'HWIO'
if
attr
[
'data_format'
]
==
'NHWC'
else
'OIHW'
else
:
attr
[
'kernel_layout'
]
=
'HWOI'
if
attr
[
'data_format'
]
==
'NHWC'
else
'OIHW'
return
AttrCvt
(
op_name
=
_dimension_picker
(
'conv'
),
...
...
@@ -687,7 +625,7 @@ _convert_map = {
'CheckNumerics'
:
_check_numerics
(),
'Concat'
:
_concat
(),
'ConcatV2'
:
_concatV2
(),
'Conv2D'
:
_conv
(),
'Conv2D'
:
_conv
(
'conv'
),
'DecodeJpeg'
:
_decode_image
(),
'ExpandDims'
:
_expand_dims
(),
'Identity'
:
_identity
(),
...
...
@@ -704,7 +642,7 @@ _convert_map = {
'Squeeze'
:
_squeeze
(),
'FusedBatchNorm'
:
_fused_batch_norm
(),
'Relu6'
:
_relu6
(),
'DepthwiseConv2dNative'
:
_
depthwise_conv
(
),
'DepthwiseConv2dNative'
:
_
conv
(
'depthwise'
),
'Shape'
:
_shape
(),
'Sigmoid'
:
AttrCvt
(
'sigmoid'
),
'Fill'
:
_fill
(),
...
...
@@ -895,28 +833,6 @@ class RecurrentNetworks(object):
params
,
num_layers
)
return
sym
def
_parse_import_prerequisites
(
graph
):
""" Calculate the named preconditions from TensorFlow `graph`.
Return prerequisites for parsing:
a. Set of operator names which don't have their mapping in TVM, i.e.
which are not supported
"""
missing_operators
=
set
()
for
node
in
graph
.
node
:
if
node
.
op
==
"Placeholder"
:
pass
elif
node
.
op
==
"Const"
:
pass
else
:
if
any
([
node
.
op
in
t
for
t
in
[
_identity_list
,
_convert_map
,
_convert_map_rnn
]]):
pass
else
:
missing_operators
.
add
(
node
.
op
)
return
missing_operators
class
GraphProto
(
object
):
""" A helper class for handling nnvm graph copying from Tensorflow GraphDef.
Definition:
...
...
@@ -925,12 +841,8 @@ class GraphProto(object):
def
__init__
(
self
):
self
.
_nodes
=
{}
self
.
_params
=
{}
self
.
_renames
=
{}
self
.
_replacements
=
{}
self
.
_output_shapes
=
{}
self
.
_num_input
=
0
self
.
_num_param
=
0
self
.
_input_node
=
''
self
.
_num_rnn_layer
=
False
def
from_tensorflow
(
self
,
graph
):
...
...
@@ -969,7 +881,7 @@ class GraphProto(object):
raise
ImportError
(
"Unable to import tensorflow which is required {}"
.
format
(
e
))
missing_operators
=
_parse_import_prerequisites
(
graph
)
missing_operators
=
self
.
_parse_import_prerequisites
(
graph
)
if
missing_operators
:
raise
NotImplementedError
(
\
...
...
@@ -979,58 +891,42 @@ class GraphProto(object):
for
node
in
graph
.
node
:
# Tensorflow doesn't have seperate list for params extraction.
# Operator name 'Const' is treated as a parameter to build NNVM params dict.
input_shapes
=
{}
attr
=
self
.
_parse_attr
(
node
.
attr
)
#Variable converted to Const will not have only value attr
if
'value'
in
attr
and
node
.
op
==
'Const'
:
tensor_value
=
attr
[
'value'
]
self
.
_output_shapes
[
node
.
name
]
=
\
[
tensor_util
.
TensorShapeProtoToList
(
\
tensor_value
.
tensor_shape
)]
elif
'_output_shapes'
in
attr
:
self
.
_output_shapes
[
node
.
name
]
=
\
[
tensor_util
.
TensorShapeProtoToList
(
shape
)
\
for
shape
in
attr
[
'_output_shapes'
]]
else
:
raise
NotImplementedError
(
\
"Please freeze the graph with add_shapes=True"
)
if
node
.
op
==
"Placeholder"
:
self
.
_
input_node
=
node
.
name
self
.
_num_input
+=
1
self
.
_
nodes
[
node
.
name
]
=
_sym
.
Variable
(
name
=
node
.
name
,
shape
=
self
.
_output_shapes
[
node
.
name
][
0
])
try
:
self
.
_output_shapes
[
node
.
name
]
=
\
[
tensor_util
.
TensorShapeProtoToList
(
shape
)
\
for
shape
in
self
.
_parse_attr
(
node
.
attr
)[
'_output_shapes'
]]
self
.
_nodes
[
node
.
name
]
=
_sym
.
Variable
(
name
=
node
.
name
,
shape
=
self
.
_output_shapes
[
node
.
name
][
0
])
input_shapes
[
self
.
_nodes
[
node
.
name
]]
=
self
.
_output_shapes
[
node
.
name
]
except
KeyError
:
raise
NotImplementedError
(
\
"Please freeze the graph with add_shapes=True"
)
#input_shapes[self._nodes[node.name]] = self._output_shapes[node.name]
elif
node
.
op
==
"Const"
:
if
self
.
_input_node
==
''
:
self
.
_input_node
=
node
.
name
self
.
_num_input
+=
1
self
.
_nodes
[
node
.
name
]
=
_sym
.
Variable
(
name
=
node
.
name
)
else
:
# Rest all nodes are Param nodes, lets parse
self
.
_num_param
+=
1
for
key
,
value
in
node
.
attr
.
items
():
self
.
_parse_param
(
key
,
value
,
node
.
name
)
if
node
.
name
not
in
self
.
_nodes
:
raise
NotImplementedError
(
\
"Const {} couldn't be converted to Param."
.
format
(
node
.
name
))
attr
=
self
.
_parse_attr
(
node
.
attr
)
#Variable converted to Const will not have only value attr
if
'value'
in
attr
:
tensor_value
=
attr
[
'value'
]
self
.
_output_shapes
[
node
.
name
]
=
\
[
tensor_util
.
TensorShapeProtoToList
(
\
tensor_value
.
tensor_shape
)]
elif
'_output_shapes'
in
attr
:
self
.
_output_shapes
[
node
.
name
]
=
\
[
tensor_util
.
TensorShapeProtoToList
(
shape
)
\
for
shape
in
self
.
_parse_attr
(
node
.
attr
)[
'_output_shapes'
]]
else
:
# All Const nodes are Param nodes, lets parse
self
.
_num_param
+=
1
for
key
,
value
in
node
.
attr
.
items
():
self
.
_parse_param
(
key
,
value
,
node
.
name
)
if
node
.
name
not
in
self
.
_nodes
:
raise
NotImplementedError
(
\
"
Please freeze the graph with add_shapes=True"
)
else
:
"
Const {} couldn't be converted to Param."
.
format
(
node
.
name
)
)
attr
=
self
.
_parse_attr
(
node
.
attr
)
try
:
self
.
_output_shapes
[
node
.
name
]
=
\
[
tensor_util
.
TensorShapeProtoToList
(
shape
)
\
for
shape
in
attr
[
'_output_shapes'
]]
except
KeyError
:
raise
NotImplementedError
(
\
"Please freeze the graph with add_shapes=True"
)
else
:
# Pass the parsed shapes instead
attr
[
"_output_shapes"
]
=
self
.
_output_shapes
[
node
.
name
]
...
...
@@ -1045,11 +941,12 @@ class GraphProto(object):
if
":"
in
node
.
input
[
0
]:
in_name
,
_
=
node
.
input
[
0
]
.
split
(
':'
)
node
.
input
[
0
]
=
in_name
# Fill shapes for all inputs in a list
try
:
inputs
=
[
self
.
_nodes
[
i
]
for
i
in
node
.
input
]
for
i
in
node
.
input
:
if
i
not
in
self
.
_params
:
input_shapes
[
self
.
_nodes
[
i
]]
=
self
.
_output_shapes
[
i
]
input_shapes
[
self
.
_nodes
[
i
]]
=
self
.
_output_shapes
[
i
]
attr
[
'_input_shapes'
]
=
input_shapes
except
KeyError
:
# TODO: Need to find clean way to handle '^CheckNumerics'
...
...
@@ -1061,6 +958,7 @@ class GraphProto(object):
# Assuming only one output.
self
.
_nodes
[
node
.
name
]
=
op
node_output
=
op
# Assume the final node is the output node
out
=
node_output
...
...
@@ -1068,11 +966,32 @@ class GraphProto(object):
if
self
.
_num_rnn_layer
:
out_rnn
=
_sym
.
concatenate
(
*
self
.
_out_rnn
,
axis
=
0
)
out
=
[
out
,
out_rnn
]
if
isinstance
(
out
,
list
):
out
=
_sym
.
Group
(
out
)
return
out
,
self
.
_params
def
_parse_import_prerequisites
(
self
,
graph
):
""" Calculate the named preconditions from TensorFlow `graph`.
Return prerequisites for parsing:
a. Set of operator names which don't have their mapping in TVM, i.e.
which are not supported
"""
missing_operators
=
set
()
for
node
in
graph
.
node
:
if
node
.
op
==
"Placeholder"
:
pass
elif
node
.
op
==
"Const"
:
pass
else
:
if
any
([
node
.
op
in
t
for
t
in
[
_identity_list
,
_convert_map
,
_convert_map_rnn
]]):
pass
else
:
missing_operators
.
add
(
node
.
op
)
return
missing_operators
def
_parse_param
(
self
,
key
,
value
,
name
):
try
:
from
tensorflow.python.framework
import
tensor_util
...
...
@@ -1082,6 +1001,13 @@ class GraphProto(object):
if
key
==
'value'
:
np_array
=
tensor_util
.
MakeNdarray
(
value
.
tensor
)
if
np_array
.
dtype
==
np
.
dtype
(
object
):
# Object types are generally tensorflow DT_STRING (DecodeJpeg op).
# Just leave it as placeholder.
self
.
_nodes
[
name
]
=
_sym
.
Variable
(
name
=
name
)
return
array_ndim
=
len
(
np_array
.
shape
)
if
array_ndim
==
0
:
new_array
=
np
.
empty
([
1
],
dtype
=
np_array
.
dtype
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment