Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
7b59b8ef
Commit
7b59b8ef
authored
Aug 04, 2018
by
Siva
Committed by
Tianqi Chen
Aug 04, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[NNVM][TENSORFLOW] Cleanup redundant code. (#1551)
parent
136061dc
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
70 additions
and
144 deletions
+70
-144
nnvm/python/nnvm/frontend/tensorflow.py
+70
-144
No files found.
nnvm/python/nnvm/frontend/tensorflow.py
View file @
7b59b8ef
...
@@ -168,81 +168,7 @@ def _pooling(name):
...
@@ -168,81 +168,7 @@ def _pooling(name):
custom_check
=
_dimension_constraint
())(
inputs
,
attr
)
custom_check
=
_dimension_constraint
())(
inputs
,
attr
)
return
_impl
return
_impl
def
_conv
():
def
_conv
(
opname
):
def
_impl
(
inputs
,
attr
,
params
):
attr
[
'data_format'
]
=
attr
[
'data_format'
]
.
decode
(
"utf-8"
)
# Extract kernel shape from params
conv_param_weights
=
params
[
inputs
[
1
]
.
list_output_names
()[
0
]]
if
attr
[
'data_format'
]
==
'NHWC'
:
attr
[
'kernel_shape'
]
=
(
conv_param_weights
.
shape
[
0
],
conv_param_weights
.
shape
[
1
])
attr
[
'channels'
]
=
conv_param_weights
.
shape
[
3
]
if
'dilations'
in
attr
:
attr
[
'dilations'
]
=
(
attr
[
'dilations'
][
0
],
attr
[
'dilations'
][
1
])
elif
attr
[
'data_format'
]
==
'NCHW'
:
attr
[
'kernel_shape'
]
=
(
conv_param_weights
.
shape
[
2
],
conv_param_weights
.
shape
[
3
])
attr
[
'channels'
]
=
conv_param_weights
.
shape
[
1
]
if
'dilations'
in
attr
:
attr
[
'dilations'
]
=
(
attr
[
'dilations'
][
2
],
attr
[
'dilations'
][
3
])
else
:
raise
TypeError
(
"Unsupported data format type : {}"
.
format
(
attr
[
'data_format'
]))
# Fix strides
attr
[
'strides'
]
=
(
attr
[
'strides'
][
1
],
attr
[
'strides'
][
2
])
# Fix padding
input_shapes
=
attr
[
'_input_shapes'
][
inputs
[
0
]]
attr
[
'padding'
]
=
attr
[
'padding'
]
.
decode
(
"utf-8"
)
if
attr
[
'padding'
]
==
'VALID'
:
attr
[
'padding'
]
=
[
0
,
0
]
elif
attr
[
'padding'
]
==
'SAME'
:
stride_h
,
stride_w
=
attr
[
'strides'
]
kernel_h
,
kernel_w
=
attr
[
'kernel_shape'
]
if
attr
[
'data_format'
]
==
'NHWC'
:
in_h
=
input_shapes
[
0
][
1
]
in_w
=
input_shapes
[
0
][
2
]
else
:
in_h
=
input_shapes
[
0
][
2
]
in_w
=
input_shapes
[
0
][
3
]
pad_v
=
_get_pad_pair
(
in_h
,
kernel_h
,
stride_h
)
pad_h
=
_get_pad_pair
(
in_w
,
kernel_w
,
stride_w
)
if
attr
[
'data_format'
]
==
'NHWC'
:
inputs
[
0
]
=
_sym
.
pad
(
data
=
inputs
[
0
],
pad_width
=
((
0
,
0
),
(
pad_v
[
0
],
pad_v
[
1
]),
(
pad_h
[
0
],
pad_h
[
1
]),
(
0
,
0
)))
else
:
inputs
[
0
]
=
_sym
.
pad
(
data
=
inputs
[
0
],
pad_width
=
((
0
,
0
),
(
0
,
0
),
(
pad_v
[
0
],
pad_v
[
1
]),
(
pad_h
[
0
],
pad_h
[
1
])))
attr
[
'padding'
]
=
[
0
,
0
]
else
:
raise
TypeError
(
"Unsupported padding type : {}"
.
format
(
attr
[
'padding'
]))
if
'kernel_layout'
not
in
attr
:
attr
[
'kernel_layout'
]
=
'HWIO'
if
attr
[
'data_format'
]
==
'NHWC'
else
'OIHW'
return
AttrCvt
(
op_name
=
_dimension_picker
(
'conv'
),
transforms
=
{
'kernel_shape'
:
'kernel_size'
,
'data_format'
:
'layout'
,
'dilations'
:
(
'dilation'
,
(
0
,
0
)),
'group'
:
(
'groups'
,
1
)},
extras
=
{
'use_bias'
:
len
(
inputs
)
==
3
},
custom_check
=
_dimension_constraint
())(
inputs
,
attr
)
return
_impl
def
_depthwise_conv
():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
attr
[
'data_format'
]
=
attr
[
'data_format'
]
.
decode
(
"utf-8"
)
attr
[
'data_format'
]
=
attr
[
'data_format'
]
.
decode
(
"utf-8"
)
input_shapes
=
attr
[
'_input_shapes'
][
inputs
[
0
]]
input_shapes
=
attr
[
'_input_shapes'
][
inputs
[
0
]]
...
@@ -253,24 +179,33 @@ def _depthwise_conv():
...
@@ -253,24 +179,33 @@ def _depthwise_conv():
if
attr
[
'data_format'
]
==
'NHWC'
:
if
attr
[
'data_format'
]
==
'NHWC'
:
kernel_h
,
kernel_w
,
_
,
depth_mult
=
conv_param_weights
.
shape
kernel_h
,
kernel_w
,
_
,
depth_mult
=
conv_param_weights
.
shape
attr
[
'kernel_shape'
]
=
(
conv_param_weights
.
shape
[
0
],
conv_param_weights
.
shape
[
1
])
attr
[
'kernel_shape'
]
=
(
conv_param_weights
.
shape
[
0
],
conv_param_weights
.
shape
[
1
])
if
opname
==
'conv'
:
attr
[
'channels'
]
=
conv_param_weights
.
shape
[
3
]
else
:
attr
[
'channels'
]
=
input_shapes
[
0
][
3
]
*
depth_mult
attr
[
'channels'
]
=
input_shapes
[
0
][
3
]
*
depth_mult
if
'dilations'
in
attr
:
if
'dilations'
in
attr
:
attr
[
'dilations'
]
=
(
attr
[
'dilations'
][
0
],
attr
[
'dilations'
][
1
])
attr
[
'dilations'
]
=
(
attr
[
'dilations'
][
0
],
attr
[
'dilations'
][
1
])
elif
attr
[
'data_format'
]
==
'NCHW'
:
elif
attr
[
'data_format'
]
==
'NCHW'
:
depth_mult
,
_
,
kernel_h
,
kernel_w
=
conv_param_weights
.
shape
depth_mult
,
_
,
kernel_h
,
kernel_w
=
conv_param_weights
.
shape
attr
[
'kernel_shape'
]
=
(
conv_param_weights
.
shape
[
2
],
conv_param_weights
.
shape
[
3
])
attr
[
'kernel_shape'
]
=
(
conv_param_weights
.
shape
[
2
],
conv_param_weights
.
shape
[
3
])
if
opname
==
'conv'
:
attr
[
'channels'
]
=
conv_param_weights
.
shape
[
1
]
else
:
attr
[
'channels'
]
=
input_shapes
[
0
][
1
]
*
depth_mult
attr
[
'channels'
]
=
input_shapes
[
0
][
1
]
*
depth_mult
if
'dilations'
in
attr
:
if
'dilations'
in
attr
:
attr
[
'dilations'
]
=
(
attr
[
'dilations'
][
2
],
attr
[
'dilations'
][
3
])
attr
[
'dilations'
]
=
(
attr
[
'dilations'
][
2
],
attr
[
'dilations'
][
3
])
else
:
else
:
raise
TypeError
(
"Unsupported data format type : {}"
.
format
(
attr
[
'data_format'
]))
raise
TypeError
(
"Unsupported data format type : {}"
.
format
(
attr
[
'data_format'
]))
# Fix strides
attr
[
'strides'
]
=
(
attr
[
'strides'
][
1
],
attr
[
'strides'
][
2
])
# Fix groups
if
opname
==
'depthwise'
:
attr
[
'groups'
]
=
attr
[
'channels'
]
attr
[
'groups'
]
=
attr
[
'channels'
]
# Fix strides
attr
[
'strides'
]
=
(
attr
[
'strides'
][
1
],
attr
[
'strides'
][
2
])
# Fix padding
# Fix padding
attr
[
'padding'
]
=
attr
[
'padding'
]
.
decode
(
"utf-8"
)
attr
[
'padding'
]
=
attr
[
'padding'
]
.
decode
(
"utf-8"
)
...
@@ -308,6 +243,9 @@ def _depthwise_conv():
...
@@ -308,6 +243,9 @@ def _depthwise_conv():
raise
TypeError
(
"Unsupported padding type : {}"
.
format
(
attr
[
'padding'
]))
raise
TypeError
(
"Unsupported padding type : {}"
.
format
(
attr
[
'padding'
]))
if
'kernel_layout'
not
in
attr
:
if
'kernel_layout'
not
in
attr
:
if
opname
==
'conv'
:
attr
[
'kernel_layout'
]
=
'HWIO'
if
attr
[
'data_format'
]
==
'NHWC'
else
'OIHW'
else
:
attr
[
'kernel_layout'
]
=
'HWOI'
if
attr
[
'data_format'
]
==
'NHWC'
else
'OIHW'
attr
[
'kernel_layout'
]
=
'HWOI'
if
attr
[
'data_format'
]
==
'NHWC'
else
'OIHW'
return
AttrCvt
(
return
AttrCvt
(
...
@@ -687,7 +625,7 @@ _convert_map = {
...
@@ -687,7 +625,7 @@ _convert_map = {
'CheckNumerics'
:
_check_numerics
(),
'CheckNumerics'
:
_check_numerics
(),
'Concat'
:
_concat
(),
'Concat'
:
_concat
(),
'ConcatV2'
:
_concatV2
(),
'ConcatV2'
:
_concatV2
(),
'Conv2D'
:
_conv
(),
'Conv2D'
:
_conv
(
'conv'
),
'DecodeJpeg'
:
_decode_image
(),
'DecodeJpeg'
:
_decode_image
(),
'ExpandDims'
:
_expand_dims
(),
'ExpandDims'
:
_expand_dims
(),
'Identity'
:
_identity
(),
'Identity'
:
_identity
(),
...
@@ -704,7 +642,7 @@ _convert_map = {
...
@@ -704,7 +642,7 @@ _convert_map = {
'Squeeze'
:
_squeeze
(),
'Squeeze'
:
_squeeze
(),
'FusedBatchNorm'
:
_fused_batch_norm
(),
'FusedBatchNorm'
:
_fused_batch_norm
(),
'Relu6'
:
_relu6
(),
'Relu6'
:
_relu6
(),
'DepthwiseConv2dNative'
:
_
depthwise_conv
(
),
'DepthwiseConv2dNative'
:
_
conv
(
'depthwise'
),
'Shape'
:
_shape
(),
'Shape'
:
_shape
(),
'Sigmoid'
:
AttrCvt
(
'sigmoid'
),
'Sigmoid'
:
AttrCvt
(
'sigmoid'
),
'Fill'
:
_fill
(),
'Fill'
:
_fill
(),
...
@@ -895,28 +833,6 @@ class RecurrentNetworks(object):
...
@@ -895,28 +833,6 @@ class RecurrentNetworks(object):
params
,
num_layers
)
params
,
num_layers
)
return
sym
return
sym
def
_parse_import_prerequisites
(
graph
):
""" Calculate the named preconditions from TensorFlow `graph`.
Return prerequisites for parsing:
a. Set of operator names which don't have their mapping in TVM, i.e.
which are not supported
"""
missing_operators
=
set
()
for
node
in
graph
.
node
:
if
node
.
op
==
"Placeholder"
:
pass
elif
node
.
op
==
"Const"
:
pass
else
:
if
any
([
node
.
op
in
t
for
t
in
[
_identity_list
,
_convert_map
,
_convert_map_rnn
]]):
pass
else
:
missing_operators
.
add
(
node
.
op
)
return
missing_operators
class
GraphProto
(
object
):
class
GraphProto
(
object
):
""" A helper class for handling nnvm graph copying from Tensorflow GraphDef.
""" A helper class for handling nnvm graph copying from Tensorflow GraphDef.
Definition:
Definition:
...
@@ -925,12 +841,8 @@ class GraphProto(object):
...
@@ -925,12 +841,8 @@ class GraphProto(object):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
_nodes
=
{}
self
.
_nodes
=
{}
self
.
_params
=
{}
self
.
_params
=
{}
self
.
_renames
=
{}
self
.
_replacements
=
{}
self
.
_output_shapes
=
{}
self
.
_output_shapes
=
{}
self
.
_num_input
=
0
self
.
_num_param
=
0
self
.
_num_param
=
0
self
.
_input_node
=
''
self
.
_num_rnn_layer
=
False
self
.
_num_rnn_layer
=
False
def
from_tensorflow
(
self
,
graph
):
def
from_tensorflow
(
self
,
graph
):
...
@@ -969,7 +881,7 @@ class GraphProto(object):
...
@@ -969,7 +881,7 @@ class GraphProto(object):
raise
ImportError
(
raise
ImportError
(
"Unable to import tensorflow which is required {}"
.
format
(
e
))
"Unable to import tensorflow which is required {}"
.
format
(
e
))
missing_operators
=
_parse_import_prerequisites
(
graph
)
missing_operators
=
self
.
_parse_import_prerequisites
(
graph
)
if
missing_operators
:
if
missing_operators
:
raise
NotImplementedError
(
\
raise
NotImplementedError
(
\
...
@@ -979,37 +891,13 @@ class GraphProto(object):
...
@@ -979,37 +891,13 @@ class GraphProto(object):
for
node
in
graph
.
node
:
for
node
in
graph
.
node
:
# Tensorflow doesn't have seperate list for params extraction.
# Tensorflow doesn't have seperate list for params extraction.
# Operator name 'Const' is treated as a parameter to build NNVM params dict.
# Operator name 'Const' is treated as a parameter to build NNVM params dict.
input_shapes
=
{}
input_shapes
=
{}
if
node
.
op
==
"Placeholder"
:
self
.
_input_node
=
node
.
name
self
.
_num_input
+=
1
try
:
self
.
_output_shapes
[
node
.
name
]
=
\
[
tensor_util
.
TensorShapeProtoToList
(
shape
)
\
for
shape
in
self
.
_parse_attr
(
node
.
attr
)[
'_output_shapes'
]]
self
.
_nodes
[
node
.
name
]
=
_sym
.
Variable
(
name
=
node
.
name
,
shape
=
self
.
_output_shapes
[
node
.
name
][
0
])
input_shapes
[
self
.
_nodes
[
node
.
name
]]
=
self
.
_output_shapes
[
node
.
name
]
except
KeyError
:
raise
NotImplementedError
(
\
"Please freeze the graph with add_shapes=True"
)
elif
node
.
op
==
"Const"
:
if
self
.
_input_node
==
''
:
self
.
_input_node
=
node
.
name
self
.
_num_input
+=
1
self
.
_nodes
[
node
.
name
]
=
_sym
.
Variable
(
name
=
node
.
name
)
else
:
# Rest all nodes are Param nodes, lets parse
self
.
_num_param
+=
1
for
key
,
value
in
node
.
attr
.
items
():
self
.
_parse_param
(
key
,
value
,
node
.
name
)
if
node
.
name
not
in
self
.
_nodes
:
raise
NotImplementedError
(
\
"Const {} couldn't be converted to Param."
.
format
(
node
.
name
))
attr
=
self
.
_parse_attr
(
node
.
attr
)
attr
=
self
.
_parse_attr
(
node
.
attr
)
#Variable converted to Const will not have only value attr
#Variable converted to Const will not have only value attr
if
'value'
in
attr
:
if
'value'
in
attr
and
node
.
op
==
'Const'
:
tensor_value
=
attr
[
'value'
]
tensor_value
=
attr
[
'value'
]
self
.
_output_shapes
[
node
.
name
]
=
\
self
.
_output_shapes
[
node
.
name
]
=
\
[
tensor_util
.
TensorShapeProtoToList
(
\
[
tensor_util
.
TensorShapeProtoToList
(
\
...
@@ -1017,20 +905,28 @@ class GraphProto(object):
...
@@ -1017,20 +905,28 @@ class GraphProto(object):
elif
'_output_shapes'
in
attr
:
elif
'_output_shapes'
in
attr
:
self
.
_output_shapes
[
node
.
name
]
=
\
self
.
_output_shapes
[
node
.
name
]
=
\
[
tensor_util
.
TensorShapeProtoToList
(
shape
)
\
[
tensor_util
.
TensorShapeProtoToList
(
shape
)
\
for
shape
in
self
.
_parse_attr
(
node
.
attr
)
[
'_output_shapes'
]]
for
shape
in
attr
[
'_output_shapes'
]]
else
:
else
:
raise
NotImplementedError
(
\
raise
NotImplementedError
(
\
"Please freeze the graph with add_shapes=True"
)
"Please freeze the graph with add_shapes=True"
)
else
:
attr
=
self
.
_parse_attr
(
node
.
attr
)
if
node
.
op
==
"Placeholder"
:
try
:
self
.
_nodes
[
node
.
name
]
=
_sym
.
Variable
(
name
=
node
.
name
,
self
.
_output_shapes
[
node
.
name
]
=
\
shape
=
self
.
_output_shapes
[
node
.
name
][
0
])
[
tensor_util
.
TensorShapeProtoToList
(
shape
)
\
for
shape
in
attr
[
'_output_shapes'
]]
#input_shapes[self._nodes[node.name]] = self._output_shapes[node.name]
except
KeyError
:
elif
node
.
op
==
"Const"
:
# All Const nodes are Param nodes, lets parse
self
.
_num_param
+=
1
for
key
,
value
in
node
.
attr
.
items
():
self
.
_parse_param
(
key
,
value
,
node
.
name
)
if
node
.
name
not
in
self
.
_nodes
:
raise
NotImplementedError
(
\
raise
NotImplementedError
(
\
"
Please freeze the graph with add_shapes=True"
)
"
Const {} couldn't be converted to Param."
.
format
(
node
.
name
)
)
attr
=
self
.
_parse_attr
(
node
.
attr
)
else
:
# Pass the parsed shapes instead
# Pass the parsed shapes instead
attr
[
"_output_shapes"
]
=
self
.
_output_shapes
[
node
.
name
]
attr
[
"_output_shapes"
]
=
self
.
_output_shapes
[
node
.
name
]
...
@@ -1045,10 +941,11 @@ class GraphProto(object):
...
@@ -1045,10 +941,11 @@ class GraphProto(object):
if
":"
in
node
.
input
[
0
]:
if
":"
in
node
.
input
[
0
]:
in_name
,
_
=
node
.
input
[
0
]
.
split
(
':'
)
in_name
,
_
=
node
.
input
[
0
]
.
split
(
':'
)
node
.
input
[
0
]
=
in_name
node
.
input
[
0
]
=
in_name
# Fill shapes for all inputs in a list
try
:
try
:
inputs
=
[
self
.
_nodes
[
i
]
for
i
in
node
.
input
]
inputs
=
[
self
.
_nodes
[
i
]
for
i
in
node
.
input
]
for
i
in
node
.
input
:
for
i
in
node
.
input
:
if
i
not
in
self
.
_params
:
input_shapes
[
self
.
_nodes
[
i
]]
=
self
.
_output_shapes
[
i
]
input_shapes
[
self
.
_nodes
[
i
]]
=
self
.
_output_shapes
[
i
]
attr
[
'_input_shapes'
]
=
input_shapes
attr
[
'_input_shapes'
]
=
input_shapes
except
KeyError
:
except
KeyError
:
...
@@ -1061,6 +958,7 @@ class GraphProto(object):
...
@@ -1061,6 +958,7 @@ class GraphProto(object):
# Assuming only one output.
# Assuming only one output.
self
.
_nodes
[
node
.
name
]
=
op
self
.
_nodes
[
node
.
name
]
=
op
node_output
=
op
node_output
=
op
# Assume the final node is the output node
# Assume the final node is the output node
out
=
node_output
out
=
node_output
...
@@ -1068,11 +966,32 @@ class GraphProto(object):
...
@@ -1068,11 +966,32 @@ class GraphProto(object):
if
self
.
_num_rnn_layer
:
if
self
.
_num_rnn_layer
:
out_rnn
=
_sym
.
concatenate
(
*
self
.
_out_rnn
,
axis
=
0
)
out_rnn
=
_sym
.
concatenate
(
*
self
.
_out_rnn
,
axis
=
0
)
out
=
[
out
,
out_rnn
]
out
=
[
out
,
out_rnn
]
if
isinstance
(
out
,
list
):
if
isinstance
(
out
,
list
):
out
=
_sym
.
Group
(
out
)
out
=
_sym
.
Group
(
out
)
return
out
,
self
.
_params
return
out
,
self
.
_params
def
_parse_import_prerequisites
(
self
,
graph
):
""" Calculate the named preconditions from TensorFlow `graph`.
Return prerequisites for parsing:
a. Set of operator names which don't have their mapping in TVM, i.e.
which are not supported
"""
missing_operators
=
set
()
for
node
in
graph
.
node
:
if
node
.
op
==
"Placeholder"
:
pass
elif
node
.
op
==
"Const"
:
pass
else
:
if
any
([
node
.
op
in
t
for
t
in
[
_identity_list
,
_convert_map
,
_convert_map_rnn
]]):
pass
else
:
missing_operators
.
add
(
node
.
op
)
return
missing_operators
def
_parse_param
(
self
,
key
,
value
,
name
):
def
_parse_param
(
self
,
key
,
value
,
name
):
try
:
try
:
from
tensorflow.python.framework
import
tensor_util
from
tensorflow.python.framework
import
tensor_util
...
@@ -1082,6 +1001,13 @@ class GraphProto(object):
...
@@ -1082,6 +1001,13 @@ class GraphProto(object):
if
key
==
'value'
:
if
key
==
'value'
:
np_array
=
tensor_util
.
MakeNdarray
(
value
.
tensor
)
np_array
=
tensor_util
.
MakeNdarray
(
value
.
tensor
)
if
np_array
.
dtype
==
np
.
dtype
(
object
):
# Object types are generally tensorflow DT_STRING (DecodeJpeg op).
# Just leave it as placeholder.
self
.
_nodes
[
name
]
=
_sym
.
Variable
(
name
=
name
)
return
array_ndim
=
len
(
np_array
.
shape
)
array_ndim
=
len
(
np_array
.
shape
)
if
array_ndim
==
0
:
if
array_ndim
==
0
:
new_array
=
np
.
empty
([
1
],
dtype
=
np_array
.
dtype
)
new_array
=
np
.
empty
([
1
],
dtype
=
np_array
.
dtype
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment