Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
32276146
Commit
32276146
authored
Dec 29, 2019
by
zhuochen
Committed by
Tianqi Chen
Dec 28, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix tf.compat.v1 issue for tf verison <=1.12 (#4593)
parent
e6d9f89c
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
16 deletions
+24
-16
python/tvm/relay/testing/tf.py
+12
-8
tutorials/frontend/from_tensorflow.py
+12
-8
No files found.
python/tvm/relay/testing/tf.py
View file @
32276146
...
@@ -28,9 +28,13 @@ import numpy as np
...
@@ -28,9 +28,13 @@ import numpy as np
# Tensorflow imports
# Tensorflow imports
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.core.framework
import
graph_pb2
from
tensorflow.core.framework
import
graph_pb2
from
tvm.contrib.download
import
download_testdata
from
tvm.contrib.download
import
download_testdata
try
:
tf_compat_v1
=
tf
.
compat
.
v1
except
ImportError
:
tf_compat_v1
=
tf
######################################################################
######################################################################
# Some helper functions
# Some helper functions
# ---------------------
# ---------------------
...
@@ -80,7 +84,7 @@ def AddShapesToGraphDef(session, out_node):
...
@@ -80,7 +84,7 @@ def AddShapesToGraphDef(session, out_node):
"""
"""
graph_def
=
tf
.
compat
.
v1
.
graph_util
.
convert_variables_to_constants
(
graph_def
=
tf
_compat_
v1
.
graph_util
.
convert_variables_to_constants
(
session
,
session
,
session
.
graph
.
as_graph_def
(
add_shapes
=
True
),
session
.
graph
.
as_graph_def
(
add_shapes
=
True
),
[
out_node
],
[
out_node
],
...
@@ -112,13 +116,13 @@ class NodeLookup(object):
...
@@ -112,13 +116,13 @@ class NodeLookup(object):
dict from integer node ID to human-readable string.
dict from integer node ID to human-readable string.
"""
"""
if
not
tf
.
compat
.
v1
.
io
.
gfile
.
e
xists
(
uid_lookup_path
):
if
not
tf
_compat_v1
.
gfile
.
E
xists
(
uid_lookup_path
):
tf
.
logging
.
fatal
(
'File does not exist
%
s'
,
uid_lookup_path
)
tf
.
logging
.
fatal
(
'File does not exist
%
s'
,
uid_lookup_path
)
if
not
tf
.
compat
.
v1
.
io
.
gfile
.
e
xists
(
label_lookup_path
):
if
not
tf
_compat_v1
.
gfile
.
E
xists
(
label_lookup_path
):
tf
.
logging
.
fatal
(
'File does not exist
%
s'
,
label_lookup_path
)
tf
.
logging
.
fatal
(
'File does not exist
%
s'
,
label_lookup_path
)
# Loads mapping from string UID to human-readable string
# Loads mapping from string UID to human-readable string
proto_as_ascii_lines
=
tf
.
compat
.
v1
.
gfile
.
GFile
(
uid_lookup_path
)
.
readlines
()
proto_as_ascii_lines
=
tf
_compat_
v1
.
gfile
.
GFile
(
uid_lookup_path
)
.
readlines
()
uid_to_human
=
{}
uid_to_human
=
{}
p
=
re
.
compile
(
r'[n\d]*[ \S,]*'
)
p
=
re
.
compile
(
r'[n\d]*[ \S,]*'
)
for
line
in
proto_as_ascii_lines
:
for
line
in
proto_as_ascii_lines
:
...
@@ -129,7 +133,7 @@ class NodeLookup(object):
...
@@ -129,7 +133,7 @@ class NodeLookup(object):
# Loads mapping from string UID to integer node ID.
# Loads mapping from string UID to integer node ID.
node_id_to_uid
=
{}
node_id_to_uid
=
{}
proto_as_ascii
=
tf
.
compat
.
v1
.
gfile
.
GFile
(
label_lookup_path
)
.
readlines
()
proto_as_ascii
=
tf
_compat_
v1
.
gfile
.
GFile
(
label_lookup_path
)
.
readlines
()
for
line
in
proto_as_ascii
:
for
line
in
proto_as_ascii
:
if
line
.
startswith
(
' target_class:'
):
if
line
.
startswith
(
' target_class:'
):
target_class
=
int
(
line
.
split
(
': '
)[
1
])
target_class
=
int
(
line
.
split
(
': '
)[
1
])
...
@@ -209,7 +213,7 @@ def get_workload(model_path, model_sub_path=None):
...
@@ -209,7 +213,7 @@ def get_workload(model_path, model_sub_path=None):
path_model
=
download_testdata
(
model_url
,
model_path
,
module
=
'tf'
)
path_model
=
download_testdata
(
model_url
,
model_path
,
module
=
'tf'
)
# Creates graph from saved graph_def.pb.
# Creates graph from saved graph_def.pb.
with
tf
.
compat
.
v1
.
gfile
.
FastGFile
(
path_model
,
'rb'
)
as
f
:
with
tf
_compat_
v1
.
gfile
.
FastGFile
(
path_model
,
'rb'
)
as
f
:
graph_def
=
tf
.
GraphDef
()
graph_def
=
tf
.
GraphDef
()
graph_def
.
ParseFromString
(
f
.
read
())
graph_def
.
ParseFromString
(
f
.
read
())
graph
=
tf
.
import_graph_def
(
graph_def
,
name
=
''
)
graph
=
tf
.
import_graph_def
(
graph_def
,
name
=
''
)
...
@@ -299,7 +303,7 @@ def _create_ptb_vocabulary(data_dir):
...
@@ -299,7 +303,7 @@ def _create_ptb_vocabulary(data_dir):
file_name
=
'ptb.train.txt'
file_name
=
'ptb.train.txt'
def
_read_words
(
filename
):
def
_read_words
(
filename
):
"""Read the data for creating vocabulary"""
"""Read the data for creating vocabulary"""
with
tf
.
compat
.
v1
.
gfile
.
GFile
(
filename
,
"r"
)
as
f
:
with
tf
_compat_
v1
.
gfile
.
GFile
(
filename
,
"r"
)
as
f
:
return
f
.
read
()
.
encode
(
"utf-8"
)
.
decode
(
"utf-8"
)
.
replace
(
"
\n
"
,
"<eos>"
)
.
split
()
return
f
.
read
()
.
encode
(
"utf-8"
)
.
decode
(
"utf-8"
)
.
replace
(
"
\n
"
,
"<eos>"
)
.
split
()
def
_build_vocab
(
filename
):
def
_build_vocab
(
filename
):
...
...
tutorials/frontend/from_tensorflow.py
View file @
32276146
...
@@ -34,6 +34,10 @@ import os.path
...
@@ -34,6 +34,10 @@ import os.path
# Tensorflow imports
# Tensorflow imports
import
tensorflow
as
tf
import
tensorflow
as
tf
try
:
tf_compat_v1
=
tf
.
compat
.
v1
except
ImportError
:
tf_compat_v1
=
tf
# Tensorflow utility functions
# Tensorflow utility functions
import
tvm.relay.testing.tf
as
tf_testing
import
tvm.relay.testing.tf
as
tf_testing
...
@@ -89,14 +93,14 @@ label_path = download_testdata(label_map_url, label_map, module='data')
...
@@ -89,14 +93,14 @@ label_path = download_testdata(label_map_url, label_map, module='data')
# ------------
# ------------
# Creates tensorflow graph definition from protobuf file.
# Creates tensorflow graph definition from protobuf file.
with
tf
.
compat
.
v1
.
gfile
.
GFile
(
model_path
,
'rb'
)
as
f
:
with
tf
_compat_
v1
.
gfile
.
GFile
(
model_path
,
'rb'
)
as
f
:
graph_def
=
tf
.
compat
.
v1
.
GraphDef
()
graph_def
=
tf
_compat_
v1
.
GraphDef
()
graph_def
.
ParseFromString
(
f
.
read
())
graph_def
.
ParseFromString
(
f
.
read
())
graph
=
tf
.
import_graph_def
(
graph_def
,
name
=
''
)
graph
=
tf
.
import_graph_def
(
graph_def
,
name
=
''
)
# Call the utility to import the graph definition into default graph.
# Call the utility to import the graph definition into default graph.
graph_def
=
tf_testing
.
ProcessGraphDefParam
(
graph_def
)
graph_def
=
tf_testing
.
ProcessGraphDefParam
(
graph_def
)
# Add shapes to the graph.
# Add shapes to the graph.
with
tf
.
compat
.
v1
.
Session
()
as
sess
:
with
tf
_compat_
v1
.
Session
()
as
sess
:
graph_def
=
tf_testing
.
AddShapesToGraphDef
(
sess
,
'softmax'
)
graph_def
=
tf_testing
.
AddShapesToGraphDef
(
sess
,
'softmax'
)
######################################################################
######################################################################
...
@@ -187,8 +191,8 @@ for node_id in top_k:
...
@@ -187,8 +191,8 @@ for node_id in top_k:
def
create_graph
():
def
create_graph
():
"""Creates a graph from saved GraphDef file and returns a saver."""
"""Creates a graph from saved GraphDef file and returns a saver."""
# Creates graph from saved graph_def.pb.
# Creates graph from saved graph_def.pb.
with
tf
.
compat
.
v1
.
gfile
.
GFile
(
model_path
,
'rb'
)
as
f
:
with
tf
_compat_
v1
.
gfile
.
GFile
(
model_path
,
'rb'
)
as
f
:
graph_def
=
tf
.
compat
.
v1
.
GraphDef
()
graph_def
=
tf
_compat_
v1
.
GraphDef
()
graph_def
.
ParseFromString
(
f
.
read
())
graph_def
.
ParseFromString
(
f
.
read
())
graph
=
tf
.
import_graph_def
(
graph_def
,
name
=
''
)
graph
=
tf
.
import_graph_def
(
graph_def
,
name
=
''
)
# Call the utility to import the graph definition into default graph.
# Call the utility to import the graph definition into default graph.
...
@@ -206,14 +210,14 @@ def run_inference_on_image(image):
...
@@ -206,14 +210,14 @@ def run_inference_on_image(image):
-------
-------
Nothing
Nothing
"""
"""
if
not
tf
.
compat
.
v1
.
io
.
gfile
.
e
xists
(
image
):
if
not
tf
_compat_v1
.
gfile
.
E
xists
(
image
):
tf
.
logging
.
fatal
(
'File does not exist
%
s'
,
image
)
tf
.
logging
.
fatal
(
'File does not exist
%
s'
,
image
)
image_data
=
tf
.
compat
.
v1
.
gfile
.
GFile
(
image
,
'rb'
)
.
read
()
image_data
=
tf
_compat_
v1
.
gfile
.
GFile
(
image
,
'rb'
)
.
read
()
# Creates graph from saved GraphDef.
# Creates graph from saved GraphDef.
create_graph
()
create_graph
()
with
tf
.
compat
.
v1
.
Session
()
as
sess
:
with
tf
_compat_
v1
.
Session
()
as
sess
:
softmax_tensor
=
sess
.
graph
.
get_tensor_by_name
(
'softmax:0'
)
softmax_tensor
=
sess
.
graph
.
get_tensor_by_name
(
'softmax:0'
)
predictions
=
sess
.
run
(
softmax_tensor
,
predictions
=
sess
.
run
(
softmax_tensor
,
{
'DecodeJpeg/contents:0'
:
image_data
})
{
'DecodeJpeg/contents:0'
:
image_data
})
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment