Commit dd9589ec by Siva Committed by Tianqi Chen

[FRONTEND][TENSORFLOW] Helper function to add shapes into the graph. Use tmp…

[FRONTEND][TENSORFLOW] Helper function to add shapes into the graph. Use tmp folder for model files and clean it. (#1697)
parent 3d0be3b0
......@@ -8,6 +8,7 @@ import re
import os.path
import collections
import numpy as np
from tvm.contrib import util
# Tensorflow imports
import tensorflow as tf
......@@ -43,6 +44,31 @@ def ProcessGraphDefParam(graph_def):
raise TypeError('graph_def must be a GraphDef proto.')
return graph_def
def AddShapesToGraphDef(out_node):
""" Add shapes attribute to nodes of the graph.
Input graph here is the default graph in context.
Parameters
----------
out_node: String
Final output node of the graph.
Returns
-------
graph_def : Obj
tensorflow graph definition with shapes attribute added to nodes.
"""
with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
[out_node],
)
return graph_def
class NodeLookup(object):
"""Converts integer node ID's to human readable labels."""
......@@ -128,13 +154,18 @@ def get_workload(model_path):
model_url = os.path.join(repo_base, model_path)
from mxnet.gluon.utils import download
download(model_url, model_name)
temp = util.tempdir()
path_model = temp.relpath(model_name)
download(model_url, path_model)
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
with tf.gfile.FastGFile(path_model, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
temp.remove()
return graph_def
#######################################################################
......
......@@ -62,7 +62,6 @@ download(model_url, model_name)
download(map_proto_url, map_proto)
download(lable_map_url, lable_map)
######################################################################
# Import model
# ------------
......@@ -74,7 +73,8 @@ with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
graph = tf.import_graph_def(graph_def, name='')
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
# Add shapes to the graph.
graph_def = nnvm.testing.tf.AddShapesToGraphDef('softmax')
######################################################################
# Decode image
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment