Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
51e2e31f
Commit
51e2e31f
authored
Apr 20, 2019
by
Yong Wu
Committed by
MORITA Kazutaka
Apr 21, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Frontend][TF] Fix Placeholder issue (#2834)
* [Frontend][TF] Fix Placeholder issue * Add test cases
parent
7e34988e
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
71 additions
and
29 deletions
+71
-29
nnvm/python/nnvm/frontend/tensorflow.py
+12
-13
nnvm/tests/python/frontend/tensorflow/test_forward.py
+24
-0
python/tvm/relay/frontend/tensorflow.py
+13
-16
tests/python/frontend/tensorflow/test_forward.py
+22
-0
No files found.
nnvm/python/nnvm/frontend/tensorflow.py
View file @
51e2e31f
...
@@ -126,7 +126,7 @@ def _argx(func, func_name):
...
@@ -126,7 +126,7 @@ def _argx(func, func_name):
def
_elemwise
(
name
):
def
_elemwise
(
name
):
def
_impl
(
inputs
,
attr
,
*
args
):
def
_impl
(
inputs
,
attr
,
*
args
):
assert
len
(
inputs
)
==
2
,
"
Math op take 2 inputs, {} given"
.
format
(
len
(
inputs
))
assert
len
(
inputs
)
==
2
,
"
{} take 2 inputs, {} given"
.
format
(
name
,
len
(
inputs
))
op_name
=
_math_name_picker
(
name
)(
attr
)
op_name
=
_math_name_picker
(
name
)(
attr
)
return
get_nnvm_op
(
op_name
)(
*
inputs
)
return
get_nnvm_op
(
op_name
)(
*
inputs
)
return
_impl
return
_impl
...
@@ -1217,9 +1217,10 @@ class GraphProto(object):
...
@@ -1217,9 +1217,10 @@ class GraphProto(object):
for
node
in
graph
.
node
:
for
node
in
graph
.
node
:
if
node
.
op
==
'Placeholder'
:
if
node
.
op
==
'Placeholder'
:
# Give priority to user argument.
if
shape
and
node
.
name
in
shape
:
if
shape
and
node
.
name
in
shape
:
self
.
_input_shapes
[
node
.
name
]
=
list
(
shape
[
node
.
name
])
self
.
_input_shapes
[
node
.
name
]
=
list
(
shape
[
node
.
name
])
continue
else
:
self
.
_input_shapes
[
node
.
name
]
=
\
self
.
_input_shapes
[
node
.
name
]
=
\
tensor_util
.
TensorShapeProtoToList
(
node
.
attr
[
'shape'
]
.
shape
)
tensor_util
.
TensorShapeProtoToList
(
node
.
attr
[
'shape'
]
.
shape
)
for
idx
,
dim
in
enumerate
(
self
.
_input_shapes
[
node
.
name
]):
for
idx
,
dim
in
enumerate
(
self
.
_input_shapes
[
node
.
name
]):
...
@@ -1228,6 +1229,13 @@ class GraphProto(object):
...
@@ -1228,6 +1229,13 @@ class GraphProto(object):
warnings
.
warn
(
"Use 1 instead of -1 in shape of operator
%
s."
warnings
.
warn
(
"Use 1 instead of -1 in shape of operator
%
s."
%
node
.
name
)
%
node
.
name
)
self
.
_nodes
[
node
.
name
]
=
_sym
.
Variable
(
name
=
node
.
name
,
shape
=
self
.
_input_shapes
[
node
.
name
])
self
.
_output_shapes
[
node
.
name
]
=
[
self
.
_input_shapes
[
node
.
name
]]
self
.
_outputs_are_0d
[
node
.
name
]
=
[
\
not
tshape
if
isinstance
(
tshape
,
list
)
else
False
\
for
tshape
in
self
.
_output_shapes
[
node
.
name
]]
# Ignore user's input shape for Non placeholder
# Ignore user's input shape for Non placeholder
elif
node
.
op
==
'Const'
:
elif
node
.
op
==
'Const'
:
tensor_value
=
node
.
attr
[
'value'
]
.
tensor
tensor_value
=
node
.
attr
[
'value'
]
.
tensor
...
@@ -1250,11 +1258,6 @@ class GraphProto(object):
...
@@ -1250,11 +1258,6 @@ class GraphProto(object):
# Variable converted to Const will not have only value attr
# Variable converted to Const will not have only value attr
if
'value'
in
attr
and
node
.
op
==
'Const'
:
if
'value'
in
attr
and
node
.
op
==
'Const'
:
self
.
_output_shapes
[
node
.
name
]
=
[
self
.
_input_shapes
[
node
.
name
]]
self
.
_output_shapes
[
node
.
name
]
=
[
self
.
_input_shapes
[
node
.
name
]]
elif
shape
and
node
.
name
in
shape
:
# Give priority to user argument.
self
.
_output_shapes
[
node
.
name
]
=
[
shape
[
node
.
name
]]
elif
node
.
op
==
'Placeholder'
:
self
.
_output_shapes
[
node
.
name
]
=
[
self
.
_input_shapes
[
node
.
name
]]
elif
'_output_shapes'
in
attr
:
elif
'_output_shapes'
in
attr
:
self
.
_output_shapes
[
node
.
name
]
=
\
self
.
_output_shapes
[
node
.
name
]
=
\
[
tensor_util
.
TensorShapeProtoToList
(
tshape
)
\
[
tensor_util
.
TensorShapeProtoToList
(
tshape
)
\
...
@@ -1269,11 +1272,7 @@ class GraphProto(object):
...
@@ -1269,11 +1272,7 @@ class GraphProto(object):
not
tshape
if
isinstance
(
tshape
,
list
)
else
False
\
not
tshape
if
isinstance
(
tshape
,
list
)
else
False
\
for
tshape
in
self
.
_output_shapes
[
node
.
name
]]
for
tshape
in
self
.
_output_shapes
[
node
.
name
]]
if
node
.
op
==
"Placeholder"
:
if
node
.
op
==
"Const"
:
self
.
_nodes
[
node
.
name
]
=
_sym
.
Variable
(
name
=
node
.
name
,
shape
=
self
.
_input_shapes
[
node
.
name
])
elif
node
.
op
==
"Const"
:
# All Const nodes are Param nodes, lets parse
# All Const nodes are Param nodes, lets parse
self
.
_num_param
+=
1
self
.
_num_param
+=
1
for
key
,
value
in
node
.
attr
.
items
():
for
key
,
value
in
node
.
attr
.
items
():
...
@@ -1284,7 +1283,7 @@ class GraphProto(object):
...
@@ -1284,7 +1283,7 @@ class GraphProto(object):
attr
=
self
.
_parse_attr
(
node
.
attr
)
attr
=
self
.
_parse_attr
(
node
.
attr
)
el
se
:
el
if
node
.
op
!=
"Placeholder"
:
# Pass the parsed shapes instead
# Pass the parsed shapes instead
attr
[
"_output_shapes"
]
=
output_shapes
=
self
.
_output_shapes
[
node
.
name
]
attr
[
"_output_shapes"
]
=
output_shapes
=
self
.
_output_shapes
[
node
.
name
]
...
...
nnvm/tests/python/frontend/tensorflow/test_forward.py
View file @
51e2e31f
...
@@ -941,6 +941,29 @@ def test_forward_resnetv2():
...
@@ -941,6 +941,29 @@ def test_forward_resnetv2():
tvm
.
testing
.
assert_allclose
(
np
.
squeeze
(
tvm_output
[
0
]),
np
.
squeeze
(
tf_output
[
0
]),
rtol
=
1e-5
,
atol
=
1e-5
)
tvm
.
testing
.
assert_allclose
(
np
.
squeeze
(
tvm_output
[
0
]),
np
.
squeeze
(
tf_output
[
0
]),
rtol
=
1e-5
,
atol
=
1e-5
)
#######################################################################
#######################################################################
# Placeholder
# -----------
def
test_forward_placeholder
():
'''test a simple pb with Placeholder node in the end of GraphDef'''
with
tf
.
Graph
()
.
as_default
():
graph_def
=
tf_testing
.
get_workload
(
"Custom/placeholder.pb"
)
# Call the utility to import the graph definition into default graph.
graph_def
=
tf_testing
.
ProcessGraphDefParam
(
graph_def
)
data
=
np
.
random
.
uniform
(
size
=
(
1
,
224
,
224
,
3
))
.
astype
(
'float32'
)
out_node
=
'mul'
with
tf
.
Session
()
as
sess
:
# Add shapes to the graph.
graph_def
=
tf_testing
.
AddShapesToGraphDef
(
sess
,
out_node
)
tf_output
=
run_tf_graph
(
sess
,
data
,
'Placeholder:0'
,
out_node
+
':0'
)
tvm_output
=
run_tvm_graph
(
graph_def
,
data
,
'Placeholder'
)
print
(
"tf_output is {}
\n
tvm_output is {}"
.
format
(
tf_output
,
tvm_output
))
tvm
.
testing
.
assert_allclose
(
np
.
squeeze
(
tvm_output
[
0
]),
np
.
squeeze
(
tf_output
[
0
]),
rtol
=
1e-5
,
atol
=
1e-5
)
#######################################################################
# PTB
# PTB
# ---
# ---
dir
(
tf
.
contrib
)
dir
(
tf
.
contrib
)
...
@@ -1261,6 +1284,7 @@ if __name__ == '__main__':
...
@@ -1261,6 +1284,7 @@ if __name__ == '__main__':
test_forward_inception_v1
()
test_forward_inception_v1
()
test_forward_mobilenet
()
test_forward_mobilenet
()
test_forward_resnetv2
()
test_forward_resnetv2
()
test_forward_placeholder
()
test_forward_ptb
()
test_forward_ptb
()
# RNN
# RNN
...
...
python/tvm/relay/frontend/tensorflow.py
View file @
51e2e31f
...
@@ -239,7 +239,7 @@ def _argx(func, func_name):
...
@@ -239,7 +239,7 @@ def _argx(func, func_name):
def
_elemwise
(
name
):
def
_elemwise
(
name
):
def
_impl
(
inputs
,
attr
,
*
args
):
def
_impl
(
inputs
,
attr
,
*
args
):
assert
len
(
inputs
)
==
2
,
"
Math op take 2 inputs, {} given"
.
format
(
len
(
inputs
))
assert
len
(
inputs
)
==
2
,
"
{} take 2 inputs, {} given"
.
format
(
name
,
len
(
inputs
))
return
_get_relay_op
(
name
)(
*
inputs
)
return
_get_relay_op
(
name
)(
*
inputs
)
return
_impl
return
_impl
...
@@ -1704,9 +1704,10 @@ class GraphProto(object):
...
@@ -1704,9 +1704,10 @@ class GraphProto(object):
node_name_prefix
=
node
.
name
.
rsplit
(
'/'
,
1
)[
0
]
node_name_prefix
=
node
.
name
.
rsplit
(
'/'
,
1
)[
0
]
control_flow_node_map
[
node_name_prefix
]
.
add
(
node
.
op
)
control_flow_node_map
[
node_name_prefix
]
.
add
(
node
.
op
)
if
node
.
op
==
'Placeholder'
:
if
node
.
op
==
'Placeholder'
:
# Give priority to user argument.
if
shape
and
node
.
name
in
shape
:
if
shape
and
node
.
name
in
shape
:
self
.
_input_shapes
[
node
.
name
]
=
list
(
shape
[
node
.
name
])
self
.
_input_shapes
[
node
.
name
]
=
list
(
shape
[
node
.
name
])
continue
else
:
self
.
_input_shapes
[
node
.
name
]
=
\
self
.
_input_shapes
[
node
.
name
]
=
\
tensor_util
.
TensorShapeProtoToList
(
node
.
attr
[
'shape'
]
.
shape
)
tensor_util
.
TensorShapeProtoToList
(
node
.
attr
[
'shape'
]
.
shape
)
for
idx
,
dim
in
enumerate
(
self
.
_input_shapes
[
node
.
name
]):
for
idx
,
dim
in
enumerate
(
self
.
_input_shapes
[
node
.
name
]):
...
@@ -1715,6 +1716,12 @@ class GraphProto(object):
...
@@ -1715,6 +1716,12 @@ class GraphProto(object):
warnings
.
warn
(
"Use 1 instead of -1 in shape of operator
%
s."
warnings
.
warn
(
"Use 1 instead of -1 in shape of operator
%
s."
%
node
.
name
)
%
node
.
name
)
self
.
_output_shapes
[
node
.
name
]
=
[
self
.
_input_shapes
[
node
.
name
]]
attr
=
self
.
_parse_attr
(
node
.
attr
)
self
.
_nodes
[
node
.
name
]
=
[
_expr
.
var
(
node
.
name
,
shape
=
self
.
_input_shapes
[
node
.
name
],
dtype
=
attr
[
'dtype'
]
.
name
)]
# Ignore user's input shape for Non placeholder
# Ignore user's input shape for Non placeholder
elif
node
.
op
==
'Const'
:
elif
node
.
op
==
'Const'
:
tensor_value
=
node
.
attr
[
'value'
]
.
tensor
tensor_value
=
node
.
attr
[
'value'
]
.
tensor
...
@@ -1736,11 +1743,6 @@ class GraphProto(object):
...
@@ -1736,11 +1743,6 @@ class GraphProto(object):
# Variable converted to Const will not have only value attr
# Variable converted to Const will not have only value attr
if
'value'
in
attr
and
node
.
op
==
'Const'
:
if
'value'
in
attr
and
node
.
op
==
'Const'
:
self
.
_output_shapes
[
node
.
name
]
=
[
self
.
_input_shapes
[
node
.
name
]]
self
.
_output_shapes
[
node
.
name
]
=
[
self
.
_input_shapes
[
node
.
name
]]
elif
shape
and
node
.
name
in
shape
:
# Give priority to user argument.
self
.
_output_shapes
[
node
.
name
]
=
[
shape
[
node
.
name
]]
elif
node
.
op
==
'Placeholder'
:
self
.
_output_shapes
[
node
.
name
]
=
[
self
.
_input_shapes
[
node
.
name
]]
elif
'_output_shapes'
in
attr
:
elif
'_output_shapes'
in
attr
:
self
.
_output_shapes
[
node
.
name
]
=
\
self
.
_output_shapes
[
node
.
name
]
=
\
[
tensor_util
.
TensorShapeProtoToList
(
tshape
)
\
[
tensor_util
.
TensorShapeProtoToList
(
tshape
)
\
...
@@ -1755,13 +1757,7 @@ class GraphProto(object):
...
@@ -1755,13 +1757,7 @@ class GraphProto(object):
not
shape
if
isinstance
(
tshape
,
list
)
else
False
\
not
shape
if
isinstance
(
tshape
,
list
)
else
False
\
for
tshape
in
self
.
_output_shapes
[
node
.
name
]]
for
tshape
in
self
.
_output_shapes
[
node
.
name
]]
if
node
.
op
==
"Placeholder"
:
if
node
.
op
==
"Const"
:
self
.
_output_shapes
[
node
.
name
]
=
[
self
.
_input_shapes
[
node
.
name
]]
self
.
_nodes
[
node
.
name
]
=
[
_expr
.
var
(
node
.
name
,
shape
=
self
.
_input_shapes
[
node
.
name
],
dtype
=
attr
[
'dtype'
]
.
name
)]
elif
node
.
op
==
"Const"
:
# All Const nodes are Param nodes, lets parse
# All Const nodes are Param nodes, lets parse
self
.
_num_param
+=
1
self
.
_num_param
+=
1
for
key
,
value
in
node
.
attr
.
items
():
for
key
,
value
in
node
.
attr
.
items
():
...
@@ -1772,7 +1768,7 @@ class GraphProto(object):
...
@@ -1772,7 +1768,7 @@ class GraphProto(object):
attr
=
self
.
_parse_attr
(
node
.
attr
)
attr
=
self
.
_parse_attr
(
node
.
attr
)
el
se
:
el
if
node
.
op
!=
"Placeholder"
:
# Pass the parsed shapes instead
# Pass the parsed shapes instead
attr
[
"_output_shapes"
]
=
output_shapes
=
self
.
_output_shapes
[
node
.
name
]
attr
[
"_output_shapes"
]
=
output_shapes
=
self
.
_output_shapes
[
node
.
name
]
...
@@ -1816,7 +1812,8 @@ class GraphProto(object):
...
@@ -1816,7 +1812,8 @@ class GraphProto(object):
input_shapes
[
in_sym
[
0
]]
=
input_shape
input_shapes
[
in_sym
[
0
]]
=
input_shape
# This means the node is 1d in Relay and 0d in TF.
# This means the node is 1d in Relay and 0d in TF.
# See `_expand_dims_0d_aware`.
# See `_expand_dims_0d_aware`.
if
self
.
_outputs_are_0d
[
node_name
][
tensor_slot
]
and
input_shape
:
if
node_name
in
self
.
_outputs_are_0d
\
and
self
.
_outputs_are_0d
[
node_name
][
tensor_slot
]
and
input_shape
:
input_0d_mismatch
.
add
(
in_sym
[
0
])
input_0d_mismatch
.
add
(
in_sym
[
0
])
attr
[
'_input_shapes'
]
=
input_shapes
attr
[
'_input_shapes'
]
=
input_shapes
...
...
tests/python/frontend/tensorflow/test_forward.py
View file @
51e2e31f
...
@@ -1134,6 +1134,27 @@ def test_forward_resnetv2():
...
@@ -1134,6 +1134,27 @@ def test_forward_resnetv2():
tvm
.
testing
.
assert_allclose
(
np
.
squeeze
(
tvm_output
[
0
]),
np
.
squeeze
(
tf_output
[
0
]),
rtol
=
1e-5
,
atol
=
1e-5
)
tvm
.
testing
.
assert_allclose
(
np
.
squeeze
(
tvm_output
[
0
]),
np
.
squeeze
(
tf_output
[
0
]),
rtol
=
1e-5
,
atol
=
1e-5
)
#######################################################################
#######################################################################
# Placeholder
# -----------
def
test_forward_placeholder
():
'''test a simple pb with Placeholder node in the end of GraphDef'''
with
tf
.
Graph
()
.
as_default
():
graph_def
=
tf_testing
.
get_workload
(
"Custom/placeholder.pb"
)
# Call the utility to import the graph definition into default graph.
graph_def
=
tf_testing
.
ProcessGraphDefParam
(
graph_def
)
data
=
np
.
random
.
uniform
(
size
=
(
1
,
224
,
224
,
3
))
.
astype
(
'float32'
)
out_node
=
'mul'
with
tf
.
Session
()
as
sess
:
# Add shapes to the graph.
graph_def
=
tf_testing
.
AddShapesToGraphDef
(
sess
,
out_node
)
tf_output
=
run_tf_graph
(
sess
,
data
,
'Placeholder:0'
,
out_node
+
':0'
)
tvm_output
=
run_tvm_graph
(
graph_def
,
data
,
'Placeholder'
)
print
(
"tf_output is {}
\n
tvm_output is {}"
.
format
(
tf_output
,
tvm_output
))
tvm
.
testing
.
assert_allclose
(
np
.
squeeze
(
tvm_output
[
0
]),
np
.
squeeze
(
tf_output
[
0
]),
rtol
=
1e-5
,
atol
=
1e-5
)
#######################################################################
# PTB
# PTB
# ---
# ---
dir
(
tf
.
contrib
)
dir
(
tf
.
contrib
)
...
@@ -1514,6 +1535,7 @@ if __name__ == '__main__':
...
@@ -1514,6 +1535,7 @@ if __name__ == '__main__':
test_forward_inception_v1
()
test_forward_inception_v1
()
test_forward_mobilenet
()
test_forward_mobilenet
()
test_forward_resnetv2
()
test_forward_resnetv2
()
test_forward_placeholder
()
test_forward_ptb
()
test_forward_ptb
()
# RNN
# RNN
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment