Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
10b7757a
Commit
10b7757a
authored
Jul 05, 2018
by
Albin Joy
Committed by
Tianqi Chen
Jul 04, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[NNVM][TENSORFLOW] Fixed variable ops shape parsing issue (#1381)
parent
2fa0eca1
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
52 additions
and
4 deletions
+52
-4
nnvm/python/nnvm/frontend/tensorflow.py
+11
-4
nnvm/tests/python/frontend/tensorflow/test_forward.py
+41
-0
No files found.
nnvm/python/nnvm/frontend/tensorflow.py
View file @
10b7757a
...
@@ -593,11 +593,18 @@ class GraphProto(object):
...
@@ -593,11 +593,18 @@ class GraphProto(object):
raise
NotImplementedError
(
\
raise
NotImplementedError
(
\
"Const {} couldn't be converted to Param."
.
format
(
node
.
name
))
"Const {} couldn't be converted to Param."
.
format
(
node
.
name
))
try
:
attr
=
self
.
_parse_attr
(
node
.
attr
)
#Variable converted to Const will not have only value attr
if
'value'
in
attr
:
tensor_value
=
attr
[
'value'
]
self
.
_output_shapes
[
node
.
name
]
=
\
self
.
_output_shapes
[
node
.
name
]
=
\
[
tensor_util
.
TensorShapeProtoToList
(
shape
)
\
[
tensor_util
.
TensorShapeProtoToList
(
\
for
shape
in
self
.
_parse_attr
(
node
.
attr
)[
'_output_shapes'
]]
tensor_value
.
tensor_shape
)]
except
KeyError
:
elif
'_output_shapes'
in
attr
:
self
.
_output_shapes
[
node
.
name
]
=
\
[
tensor_util
.
TensorShapeProtoToList
(
shape
)
\
for
shape
in
self
.
_parse_attr
(
node
.
attr
)[
'_output_shapes'
]]
else
:
raise
NotImplementedError
(
\
raise
NotImplementedError
(
\
"Please freeze the graph with add_shapes=True"
)
"Please freeze the graph with add_shapes=True"
)
else
:
else
:
...
...
nnvm/tests/python/frontend/tensorflow/test_forward.py
View file @
10b7757a
...
@@ -14,6 +14,8 @@ from tensorflow.python.ops import nn_ops
...
@@ -14,6 +14,8 @@ from tensorflow.python.ops import nn_ops
from
tensorflow.python.ops
import
array_ops
from
tensorflow.python.ops
import
array_ops
from
tensorflow.python.ops
import
gen_array_ops
from
tensorflow.python.ops
import
gen_array_ops
from
tensorflow.python.ops
import
math_ops
from
tensorflow.python.ops
import
math_ops
from
tensorflow.python.ops
import
variable_scope
from
tensorflow.python.ops
import
variables
from
tensorflow.core.framework
import
graph_pb2
from
tensorflow.core.framework
import
graph_pb2
import
nnvm.testing.tf
import
nnvm.testing.tf
...
@@ -393,6 +395,44 @@ def test_forward_sigmoid():
...
@@ -393,6 +395,44 @@ def test_forward_sigmoid():
_test_sigmoid
(
np
.
random
.
uniform
(
size
=
(
3
,
4
,
4
,
3
))
.
astype
(
'float32'
))
_test_sigmoid
(
np
.
random
.
uniform
(
size
=
(
3
,
4
,
4
,
3
))
.
astype
(
'float32'
))
#######################################################################
# Variable
# --------
def
_test_variable
(
data
):
tf
.
reset_default_graph
()
input_op
=
array_ops
.
placeholder
(
shape
=
data
.
shape
,
dtype
=
data
.
dtype
)
input_tensor
=
array_ops
.
reshape
(
input_op
,
data
.
shape
)
size
=
input_tensor
.
shape
.
dims
[
1
]
with
variable_scope
.
variable_scope
(
"linear"
,
reuse
=
None
):
w
=
variable_scope
.
get_variable
(
"w"
,
shape
=
[
size
,
size
],
dtype
=
input_tensor
.
dtype
)
# pylint: disable=unused-variable
output_op
=
math_ops
.
matmul
(
input_tensor
,
w
)
# pylint: enable=unused-variable
with
tf
.
Session
()
as
sess
:
sess
.
run
(
variables
.
global_variables_initializer
())
final_graph_def
=
tf
.
graph_util
.
convert_variables_to_constants
(
sess
,
sess
.
graph
.
as_graph_def
(
add_shapes
=
True
),
[
'MatMul'
],
)
tf_output
=
run_tf_graph
(
sess
,
data
,
'Placeholder:0'
,
'MatMul:0'
)
tvm_output
=
run_tvm_graph
(
final_graph_def
,
data
,
"Placeholder"
,
tf_output
.
shape
,
data
.
dtype
)
np
.
testing
.
assert_allclose
(
tf_output
,
tvm_output
,
atol
=
1e-5
,
rtol
=
1e-5
)
sess
.
close
()
def
test_forward_variable
():
"""Variable type op test"""
_test_variable
(
np
.
random
.
uniform
(
size
=
(
32
,
100
))
.
astype
(
'float32'
))
#######################################################################
#######################################################################
# Multi Input to graph
# Multi Input to graph
# --------------------
# --------------------
...
@@ -503,3 +543,4 @@ if __name__ == '__main__':
...
@@ -503,3 +543,4 @@ if __name__ == '__main__':
test_forward_inception_v3
()
test_forward_inception_v3
()
test_forward_inception_v1
()
test_forward_inception_v1
()
test_forward_mobilenet
()
test_forward_mobilenet
()
test_forward_variable
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment