Commit 32276146 by zhuochen Committed by Tianqi Chen

fix tf.compat.v1 issue for tf version <=1.12 (#4593)

parent e6d9f89c
...@@ -28,9 +28,13 @@ import numpy as np ...@@ -28,9 +28,13 @@ import numpy as np
# Tensorflow imports # Tensorflow imports
import tensorflow as tf import tensorflow as tf
from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import graph_pb2
from tvm.contrib.download import download_testdata from tvm.contrib.download import download_testdata
try:
tf_compat_v1 = tf.compat.v1
except ImportError:
tf_compat_v1 = tf
###################################################################### ######################################################################
# Some helper functions # Some helper functions
# --------------------- # ---------------------
...@@ -80,7 +84,7 @@ def AddShapesToGraphDef(session, out_node): ...@@ -80,7 +84,7 @@ def AddShapesToGraphDef(session, out_node):
""" """
graph_def = tf.compat.v1.graph_util.convert_variables_to_constants( graph_def = tf_compat_v1.graph_util.convert_variables_to_constants(
session, session,
session.graph.as_graph_def(add_shapes=True), session.graph.as_graph_def(add_shapes=True),
[out_node], [out_node],
...@@ -112,13 +116,13 @@ class NodeLookup(object): ...@@ -112,13 +116,13 @@ class NodeLookup(object):
dict from integer node ID to human-readable string. dict from integer node ID to human-readable string.
""" """
if not tf.compat.v1.io.gfile.exists(uid_lookup_path): if not tf_compat_v1.gfile.Exists(uid_lookup_path):
tf.logging.fatal('File does not exist %s', uid_lookup_path) tf.logging.fatal('File does not exist %s', uid_lookup_path)
if not tf.compat.v1.io.gfile.exists(label_lookup_path): if not tf_compat_v1.gfile.Exists(label_lookup_path):
tf.logging.fatal('File does not exist %s', label_lookup_path) tf.logging.fatal('File does not exist %s', label_lookup_path)
# Loads mapping from string UID to human-readable string # Loads mapping from string UID to human-readable string
proto_as_ascii_lines = tf.compat.v1.gfile.GFile(uid_lookup_path).readlines() proto_as_ascii_lines = tf_compat_v1.gfile.GFile(uid_lookup_path).readlines()
uid_to_human = {} uid_to_human = {}
p = re.compile(r'[n\d]*[ \S,]*') p = re.compile(r'[n\d]*[ \S,]*')
for line in proto_as_ascii_lines: for line in proto_as_ascii_lines:
...@@ -129,7 +133,7 @@ class NodeLookup(object): ...@@ -129,7 +133,7 @@ class NodeLookup(object):
# Loads mapping from string UID to integer node ID. # Loads mapping from string UID to integer node ID.
node_id_to_uid = {} node_id_to_uid = {}
proto_as_ascii = tf.compat.v1.gfile.GFile(label_lookup_path).readlines() proto_as_ascii = tf_compat_v1.gfile.GFile(label_lookup_path).readlines()
for line in proto_as_ascii: for line in proto_as_ascii:
if line.startswith(' target_class:'): if line.startswith(' target_class:'):
target_class = int(line.split(': ')[1]) target_class = int(line.split(': ')[1])
...@@ -209,7 +213,7 @@ def get_workload(model_path, model_sub_path=None): ...@@ -209,7 +213,7 @@ def get_workload(model_path, model_sub_path=None):
path_model = download_testdata(model_url, model_path, module='tf') path_model = download_testdata(model_url, model_path, module='tf')
# Creates graph from saved graph_def.pb. # Creates graph from saved graph_def.pb.
with tf.compat.v1.gfile.FastGFile(path_model, 'rb') as f: with tf_compat_v1.gfile.FastGFile(path_model, 'rb') as f:
graph_def = tf.GraphDef() graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read()) graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='') graph = tf.import_graph_def(graph_def, name='')
...@@ -299,7 +303,7 @@ def _create_ptb_vocabulary(data_dir): ...@@ -299,7 +303,7 @@ def _create_ptb_vocabulary(data_dir):
file_name = 'ptb.train.txt' file_name = 'ptb.train.txt'
def _read_words(filename): def _read_words(filename):
"""Read the data for creating vocabulary""" """Read the data for creating vocabulary"""
with tf.compat.v1.gfile.GFile(filename, "r") as f: with tf_compat_v1.gfile.GFile(filename, "r") as f:
return f.read().encode("utf-8").decode("utf-8").replace("\n", "<eos>").split() return f.read().encode("utf-8").decode("utf-8").replace("\n", "<eos>").split()
def _build_vocab(filename): def _build_vocab(filename):
......
...@@ -34,6 +34,10 @@ import os.path ...@@ -34,6 +34,10 @@ import os.path
# Tensorflow imports # Tensorflow imports
import tensorflow as tf import tensorflow as tf
try:
tf_compat_v1 = tf.compat.v1
except ImportError:
tf_compat_v1 = tf
# Tensorflow utility functions # Tensorflow utility functions
import tvm.relay.testing.tf as tf_testing import tvm.relay.testing.tf as tf_testing
...@@ -89,14 +93,14 @@ label_path = download_testdata(label_map_url, label_map, module='data') ...@@ -89,14 +93,14 @@ label_path = download_testdata(label_map_url, label_map, module='data')
# ------------ # ------------
# Creates tensorflow graph definition from protobuf file. # Creates tensorflow graph definition from protobuf file.
with tf.compat.v1.gfile.GFile(model_path, 'rb') as f: with tf_compat_v1.gfile.GFile(model_path, 'rb') as f:
graph_def = tf.compat.v1.GraphDef() graph_def = tf_compat_v1.GraphDef()
graph_def.ParseFromString(f.read()) graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='') graph = tf.import_graph_def(graph_def, name='')
# Call the utility to import the graph definition into default graph. # Call the utility to import the graph definition into default graph.
graph_def = tf_testing.ProcessGraphDefParam(graph_def) graph_def = tf_testing.ProcessGraphDefParam(graph_def)
# Add shapes to the graph. # Add shapes to the graph.
with tf.compat.v1.Session() as sess: with tf_compat_v1.Session() as sess:
graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax') graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax')
###################################################################### ######################################################################
...@@ -187,8 +191,8 @@ for node_id in top_k: ...@@ -187,8 +191,8 @@ for node_id in top_k:
def create_graph(): def create_graph():
"""Creates a graph from saved GraphDef file and returns a saver.""" """Creates a graph from saved GraphDef file and returns a saver."""
# Creates graph from saved graph_def.pb. # Creates graph from saved graph_def.pb.
with tf.compat.v1.gfile.GFile(model_path, 'rb') as f: with tf_compat_v1.gfile.GFile(model_path, 'rb') as f:
graph_def = tf.compat.v1.GraphDef() graph_def = tf_compat_v1.GraphDef()
graph_def.ParseFromString(f.read()) graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='') graph = tf.import_graph_def(graph_def, name='')
# Call the utility to import the graph definition into default graph. # Call the utility to import the graph definition into default graph.
...@@ -206,14 +210,14 @@ def run_inference_on_image(image): ...@@ -206,14 +210,14 @@ def run_inference_on_image(image):
------- -------
Nothing Nothing
""" """
if not tf.compat.v1.io.gfile.exists(image): if not tf_compat_v1.gfile.Exists(image):
tf.logging.fatal('File does not exist %s', image) tf.logging.fatal('File does not exist %s', image)
image_data = tf.compat.v1.gfile.GFile(image, 'rb').read() image_data = tf_compat_v1.gfile.GFile(image, 'rb').read()
# Creates graph from saved GraphDef. # Creates graph from saved GraphDef.
create_graph() create_graph()
with tf.compat.v1.Session() as sess: with tf_compat_v1.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name('softmax:0') softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
predictions = sess.run(softmax_tensor, predictions = sess.run(softmax_tensor,
{'DecodeJpeg/contents:0': image_data}) {'DecodeJpeg/contents:0': image_data})
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment