Commit f347b525 by Yong Wu Committed by Yizhi Liu

Get tags of saved model automatically

Remove exception trail in tf parser error message

Fix lint

Fix comments
parent 916576c0
......@@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs
from __future__ import print_function
import warnings
# Numpy support
import numpy as np
......@@ -303,7 +304,8 @@ def _conv(opname):
def _decode_image():
def _impl(inputs, attr, params):
# Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
print("DecodeJpeg: It's a pass through, please handle preprocessing before input")
warnings.warn("DecodeJpeg: It's a pass through, "
"please handle preprocessing before input")
return inputs[0]
return _impl
......@@ -938,8 +940,6 @@ _convert_map = {
'Split' : _split(False),
'SplitV' : _split(True),
'Unpack' : _unpack(),
'QueueDequeueManyV2' : _undef(),
'FIFOQueueV2' : _undef(),
}
# _convert_map_rnn defines maps of rnn operator name to
......@@ -1184,42 +1184,57 @@ class GraphProto(object):
if missing_operators:
raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators))
for node in graph.node:
if node.op == 'Placeholder':
self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
self._input_shapes[node.name][0] = 1
if shape and node.name in shape:
self._input_shapes[node.name] = list(shape[node.name])
continue
self._input_shapes[node.name] = \
tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
for idx, dim in enumerate(self._input_shapes[node.name]):
if dim < 0:
self._input_shapes[node.name][idx] = 1
warnings.warn("Use 1 instead of -1 in shape of operator %s."
% node.name)
# Ignore user's input shape for Non placeholder
elif node.op == 'Const':
tensor_value = node.attr['value'].tensor
self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
self._input_shapes[node.name] = \
tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
if shape and node.name in shape:
warnings.warn("Ignore the passed shape. "
"Shape in graphdef will be used for operator %s." % node.name)
final_op = None
# Parse the nodes to re-create TF graph using Symbol API of NNVM
for node in graph.node:
# Tensorflow doesn't have seperate list for params extraction.
# Tensorflow doesn't have separate list for params extraction.
# Operator name 'Const' is treated as a parameter to build NNVM params dict.
input_shapes = {}
input_0d_mismatch = set()
attr = self._parse_attr(node.attr)
#Variable converted to Const will not have only value attr
# Variable converted to Const will not have only value attr
if 'value' in attr and node.op == 'Const':
self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif node.op == 'Placeholder':
self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif shape and node.name in shape:
# Give priority to user argument.
self._output_shapes[node.name] = [shape[node.name]]
elif node.op == 'Placeholder':
self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif '_output_shapes' in attr:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(tshape) \
for tshape in attr['_output_shapes']]
elif shape:
else:
# Keep the list indexable to avoid key error.
# Actual value will be filled after node creation.
# Will infer shapes if the graph is not frozen with add_shapes=True
self._output_shapes[node.name] = [None]
else:
self._output_shapes[node.name] = None
self._outputs_are_0d[node.name] = [ \
not tshape if isinstance(tshape, list) else False \
for tshape in self._output_shapes[node.name]]
......@@ -1241,7 +1256,7 @@ class GraphProto(object):
else:
# Pass the parsed shapes instead
output_shapes = self._output_shapes[node.name]
attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]
# Pass the node name too in attr
attr["_node_name"] = node.name
......@@ -1282,7 +1297,7 @@ class GraphProto(object):
inputs = self._fix_extranodes(node.op, attr, inputs)
op = self._convert_operator(node.op, inputs, attr, graph)
# Check is op is converted to param
# Check if op is converted to param
if isinstance(op, np.ndarray):
self._params[node.name] = tvm.nd.array(op)
op = _sym.Variable(name=node.name,
......@@ -1291,19 +1306,25 @@ class GraphProto(object):
# Assuming only one output.
self._nodes[node.name] = op
final_op = op
# Infer shapes even without specifying "add_shapes=True"
if output_shapes == [None]:
g = _graph.create(final_op)
self._output_shapes[node.name] = \
list(graph_util.infer_shape(g, **self._input_shapes))[-1]
if self._output_shapes[node.name] and shape and node.name in shape:
assert self._output_shapes[node.name] == list(shape[node.name])
# Infer shapes if passed explicitely
node_output = self._nodes[node.name]
if shape:
if shape and (not self._output_shapes[node.name][0]
or -1 in self._output_shapes[node.name][0]):
g = _graph.create(node_output)
shape_dict = {k: v.shape for k, v in self._params.items()}
shape_dict.update(shape)
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
self._output_shapes[node.name] = out_shapes
elif output_shapes == None:
g = _graph.create(node_output)
self._output_shapes[node.name] = list(graph_util.infer_shape(g, **self._input_shapes))[-1]
else:
self._output_shapes[node.name] = output_shapes
out = []
if outputs is None:
......
......@@ -2,32 +2,13 @@
from __future__ import absolute_import as _abs
from __future__ import print_function
import os
try:
from tensorflow.core.framework import graph_pb2
except ImportError as e:
from nnvm.frontend.protobuf import graph_pb2
try:
from tempfile import TemporaryDirectory
except ImportError:
import tempfile
import shutil
class TemporaryDirectory(object):
def __enter__(self):
self.name = tempfile.mkdtemp()
return self.name
def __exit__(self, exc, value, tb):
shutil.rmtree(self.name)
from tensorflow.core.framework import graph_pb2
from tvm.contrib import util
class TFParser(object):
"""A Wrapper to handle tensorflow models parsing
Works w/o installing tensorflow,
Protocol Buffer is needed
TensorFlow is needed
```
parser = TfParser(model_dir)
graph = parser.parse()
......@@ -39,7 +20,7 @@ class TFParser(object):
"""
def __init__(self, model_dir):
self._tmp_dir = TemporaryDirectory()
self._tmp_dir = util.tempdir()
self._model_dir = model_dir
self._graph = graph_pb2.GraphDef()
......@@ -51,21 +32,6 @@ class TFParser(object):
"""Get Graph"""
return self._graph
def _output_graph(self):
import logging
logging.basicConfig(level=logging.DEBUG)
for node in self._get_graph().node:
logging.info("Name: {}".format(node.name))
logging.info("\top: {}".format(node.op))
for input in node.input:
logging.info("\t\tinput: {}".format(input))
logging.info("\t\tdevice: {}".format(node.device))
logging.info("\t\tAttrValue: ")
for key in node.attr.keys():
logging.info("\t\t\tkey: {} => value: {}"
.format(key, node.attr[key]))
logging.info(node.attr['shape'].shape)
def _load_pb_file(self):
"""Load single pb file"""
graph = self._get_graph()
......@@ -73,19 +39,30 @@ class TFParser(object):
graph.ParseFromString(f.read())
return graph
def _get_output_names(self, model_path):
def _get_tag_set(self):
"""Return the tag set of saved model, multiple metagraphs are not supported"""
try:
from tensorflow.contrib.saved_model.python.saved_model import reader
except ImportError:
raise ImportError(
"InputConfiguration: Unable to import saved_model.reader which is "
"required to get tag set from saved model.")
tag_sets = reader.get_saved_model_tag_sets(self._model_dir)
return tag_sets[0]
def _get_output_names(self):
"""Return the concatenated output names"""
try:
import tensorflow as tf
except ImportError as e:
except ImportError:
raise ImportError(
"InputConfiguration: Unable to import tensorflow which is "
"required to restore from saved model. {}".format(e))
"required to restore from saved model.")
tags = self._get_tag_set()
with tf.Session() as sess:
meta_graph_def = tf.saved_model.loader.load(sess,
[tf.saved_model.tag_constants.SERVING],
model_path)
tags,
self._model_dir)
output_names = set()
for k in meta_graph_def.signature_def.keys():
outputs_tensor_info = meta_graph_def.signature_def[k].outputs
......@@ -97,19 +74,18 @@ class TFParser(object):
def _load_saved_model(self):
"""Load the tensorflow saved model."""
try:
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import graph_util
except ImportError as e:
except ImportError:
raise ImportError(
"InputConfiguration: Unable to import tensorflow which is "
"required to restore from saved model. {}".format(e))
"required to restore from saved model.")
saved_model_dir = self._model_dir
output_graph_filename = os.path.join(self._tmp_dir.name, "neo_frozen_model.pb")
output_graph_filename = self._tmp_dir.relpath("tf_frozen_model.pb")
input_saved_model_dir = saved_model_dir
output_node_names = self._get_output_names(self._model_dir)
output_node_names = self._get_output_names()
input_binary = False
input_saver_def_path = False
......@@ -119,7 +95,7 @@ class TFParser(object):
input_meta_graph = False
checkpoint_path = None
input_graph_filename = None
saved_model_tags = tf.saved_model.tag_constants.SERVING
saved_model_tags = ",".join(self._get_tag_set())
freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
input_binary, checkpoint_path, output_node_names,
......@@ -145,6 +121,7 @@ class TFParser(object):
file.
"""
graph = None
if os.path.isdir(self._model_dir):
ckpt = os.path.join(self._model_dir, "checkpoint")
if not os.path.isfile(ckpt):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment