Commit f347b525 by Yong Wu Committed by Yizhi Liu

Get tags of saved model automatically

Remove exception trail in tf parser error message

Fix lint

Fix comments
parent 916576c0
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from __future__ import print_function from __future__ import print_function
import warnings
# Numpy support # Numpy support
import numpy as np import numpy as np
...@@ -303,7 +304,8 @@ def _conv(opname): ...@@ -303,7 +304,8 @@ def _conv(opname):
def _decode_image(): def _decode_image():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
# Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer. # Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
print("DecodeJpeg: It's a pass through, please handle preprocessing before input") warnings.warn("DecodeJpeg: It's a pass through, "
"please handle preprocessing before input")
return inputs[0] return inputs[0]
return _impl return _impl
...@@ -938,8 +940,6 @@ _convert_map = { ...@@ -938,8 +940,6 @@ _convert_map = {
'Split' : _split(False), 'Split' : _split(False),
'SplitV' : _split(True), 'SplitV' : _split(True),
'Unpack' : _unpack(), 'Unpack' : _unpack(),
'QueueDequeueManyV2' : _undef(),
'FIFOQueueV2' : _undef(),
} }
# _convert_map_rnn defines maps of rnn operator name to # _convert_map_rnn defines maps of rnn operator name to
...@@ -1184,42 +1184,57 @@ class GraphProto(object): ...@@ -1184,42 +1184,57 @@ class GraphProto(object):
if missing_operators: if missing_operators:
raise NotImplementedError( \ raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators)) "The following operators are not implemented: {}".format(missing_operators))
for node in graph.node: for node in graph.node:
if node.op == 'Placeholder': if node.op == 'Placeholder':
self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(node.attr['shape'].shape) if shape and node.name in shape:
self._input_shapes[node.name][0] = 1 self._input_shapes[node.name] = list(shape[node.name])
continue
self._input_shapes[node.name] = \
tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
for idx, dim in enumerate(self._input_shapes[node.name]):
if dim < 0:
self._input_shapes[node.name][idx] = 1
warnings.warn("Use 1 instead of -1 in shape of operator %s."
% node.name)
# Ignore user's input shape for Non placeholder
elif node.op == 'Const': elif node.op == 'Const':
tensor_value = node.attr['value'].tensor tensor_value = node.attr['value'].tensor
self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape) self._input_shapes[node.name] = \
tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
if shape and node.name in shape:
warnings.warn("Ignore the passed shape. "
"Shape in graphdef will be used for operator %s." % node.name)
final_op = None final_op = None
# Parse the nodes to re-create TF graph using Symbol API of NNVM # Parse the nodes to re-create TF graph using Symbol API of NNVM
for node in graph.node: for node in graph.node:
# Tensorflow doesn't have seperate list for params extraction. # Tensorflow doesn't have separate list for params extraction.
# Operator name 'Const' is treated as a parameter to build NNVM params dict. # Operator name 'Const' is treated as a parameter to build NNVM params dict.
input_shapes = {} input_shapes = {}
input_0d_mismatch = set() input_0d_mismatch = set()
attr = self._parse_attr(node.attr) attr = self._parse_attr(node.attr)
#Variable converted to Const will not have only value attr # Variable converted to Const will not have only value attr
if 'value' in attr and node.op == 'Const': if 'value' in attr and node.op == 'Const':
self._output_shapes[node.name] = [self._input_shapes[node.name]] self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif node.op == 'Placeholder':
self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif shape and node.name in shape: elif shape and node.name in shape:
# Give priority to user argument. # Give priority to user argument.
self._output_shapes[node.name] = [shape[node.name]] self._output_shapes[node.name] = [shape[node.name]]
elif node.op == 'Placeholder':
self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif '_output_shapes' in attr: elif '_output_shapes' in attr:
self._output_shapes[node.name] = \ self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(tshape) \ [tensor_util.TensorShapeProtoToList(tshape) \
for tshape in attr['_output_shapes']] for tshape in attr['_output_shapes']]
elif shape: else:
# Keep the list indexable to avoid key error. # Keep the list indexable to avoid key error.
# Actual value will be filled after node creation. # Actual value will be filled after node creation.
# Will infer shapes if the graph is not frozen with add_shapes=True
self._output_shapes[node.name] = [None] self._output_shapes[node.name] = [None]
else:
self._output_shapes[node.name] = None
self._outputs_are_0d[node.name] = [ \ self._outputs_are_0d[node.name] = [ \
not tshape if isinstance(tshape, list) else False \ not tshape if isinstance(tshape, list) else False \
for tshape in self._output_shapes[node.name]] for tshape in self._output_shapes[node.name]]
...@@ -1241,7 +1256,7 @@ class GraphProto(object): ...@@ -1241,7 +1256,7 @@ class GraphProto(object):
else: else:
# Pass the parsed shapes instead # Pass the parsed shapes instead
output_shapes = self._output_shapes[node.name] attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]
# Pass the node name too in attr # Pass the node name too in attr
attr["_node_name"] = node.name attr["_node_name"] = node.name
...@@ -1282,7 +1297,7 @@ class GraphProto(object): ...@@ -1282,7 +1297,7 @@ class GraphProto(object):
inputs = self._fix_extranodes(node.op, attr, inputs) inputs = self._fix_extranodes(node.op, attr, inputs)
op = self._convert_operator(node.op, inputs, attr, graph) op = self._convert_operator(node.op, inputs, attr, graph)
# Check is op is converted to param # Check if op is converted to param
if isinstance(op, np.ndarray): if isinstance(op, np.ndarray):
self._params[node.name] = tvm.nd.array(op) self._params[node.name] = tvm.nd.array(op)
op = _sym.Variable(name=node.name, op = _sym.Variable(name=node.name,
...@@ -1291,19 +1306,25 @@ class GraphProto(object): ...@@ -1291,19 +1306,25 @@ class GraphProto(object):
# Assuming only one output. # Assuming only one output.
self._nodes[node.name] = op self._nodes[node.name] = op
final_op = op final_op = op
# Infer shapes even without specifying "add_shapes=True"
if output_shapes == [None]:
g = _graph.create(final_op)
self._output_shapes[node.name] = \
list(graph_util.infer_shape(g, **self._input_shapes))[-1]
if self._output_shapes[node.name] and shape and node.name in shape:
assert self._output_shapes[node.name] == list(shape[node.name])
# Infer shapes if passed explicitely # Infer shapes if passed explicitely
node_output = self._nodes[node.name] node_output = self._nodes[node.name]
if shape: if shape and (not self._output_shapes[node.name][0]
or -1 in self._output_shapes[node.name][0]):
g = _graph.create(node_output) g = _graph.create(node_output)
shape_dict = {k: v.shape for k, v in self._params.items()} shape_dict = {k: v.shape for k, v in self._params.items()}
shape_dict.update(shape) shape_dict.update(shape)
_, out_shapes = graph_util.infer_shape(g, **shape_dict) _, out_shapes = graph_util.infer_shape(g, **shape_dict)
self._output_shapes[node.name] = out_shapes self._output_shapes[node.name] = out_shapes
elif output_shapes == None:
g = _graph.create(node_output)
self._output_shapes[node.name] = list(graph_util.infer_shape(g, **self._input_shapes))[-1]
else:
self._output_shapes[node.name] = output_shapes
out = [] out = []
if outputs is None: if outputs is None:
......
...@@ -2,32 +2,13 @@ ...@@ -2,32 +2,13 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from __future__ import print_function from __future__ import print_function
import os import os
from tensorflow.core.framework import graph_pb2
try: from tvm.contrib import util
from tensorflow.core.framework import graph_pb2
except ImportError as e:
from nnvm.frontend.protobuf import graph_pb2
try:
from tempfile import TemporaryDirectory
except ImportError:
import tempfile
import shutil
class TemporaryDirectory(object):
def __enter__(self):
self.name = tempfile.mkdtemp()
return self.name
def __exit__(self, exc, value, tb):
shutil.rmtree(self.name)
class TFParser(object): class TFParser(object):
"""A Wrapper to handle tensorflow models parsing """A Wrapper to handle tensorflow models parsing
Works w/o installing tensorflow, TensorFlow is needed
Protocol Buffer is needed
``` ```
parser = TfParser(model_dir) parser = TfParser(model_dir)
graph = parser.parse() graph = parser.parse()
...@@ -39,7 +20,7 @@ class TFParser(object): ...@@ -39,7 +20,7 @@ class TFParser(object):
""" """
def __init__(self, model_dir): def __init__(self, model_dir):
self._tmp_dir = TemporaryDirectory() self._tmp_dir = util.tempdir()
self._model_dir = model_dir self._model_dir = model_dir
self._graph = graph_pb2.GraphDef() self._graph = graph_pb2.GraphDef()
...@@ -51,21 +32,6 @@ class TFParser(object): ...@@ -51,21 +32,6 @@ class TFParser(object):
"""Get Graph""" """Get Graph"""
return self._graph return self._graph
def _output_graph(self):
import logging
logging.basicConfig(level=logging.DEBUG)
for node in self._get_graph().node:
logging.info("Name: {}".format(node.name))
logging.info("\top: {}".format(node.op))
for input in node.input:
logging.info("\t\tinput: {}".format(input))
logging.info("\t\tdevice: {}".format(node.device))
logging.info("\t\tAttrValue: ")
for key in node.attr.keys():
logging.info("\t\t\tkey: {} => value: {}"
.format(key, node.attr[key]))
logging.info(node.attr['shape'].shape)
def _load_pb_file(self): def _load_pb_file(self):
"""Load single pb file""" """Load single pb file"""
graph = self._get_graph() graph = self._get_graph()
...@@ -73,19 +39,30 @@ class TFParser(object): ...@@ -73,19 +39,30 @@ class TFParser(object):
graph.ParseFromString(f.read()) graph.ParseFromString(f.read())
return graph return graph
def _get_output_names(self, model_path): def _get_tag_set(self):
"""Return the tag set of saved model, multiple metagraphs are not supported"""
try:
from tensorflow.contrib.saved_model.python.saved_model import reader
except ImportError:
raise ImportError(
"InputConfiguration: Unable to import saved_model.reader which is "
"required to get tag set from saved model.")
tag_sets = reader.get_saved_model_tag_sets(self._model_dir)
return tag_sets[0]
def _get_output_names(self):
"""Return the concatenated output names""" """Return the concatenated output names"""
try: try:
import tensorflow as tf import tensorflow as tf
except ImportError as e: except ImportError:
raise ImportError( raise ImportError(
"InputConfiguration: Unable to import tensorflow which is " "InputConfiguration: Unable to import tensorflow which is "
"required to restore from saved model. {}".format(e)) "required to restore from saved model.")
tags = self._get_tag_set()
with tf.Session() as sess: with tf.Session() as sess:
meta_graph_def = tf.saved_model.loader.load(sess, meta_graph_def = tf.saved_model.loader.load(sess,
[tf.saved_model.tag_constants.SERVING], tags,
model_path) self._model_dir)
output_names = set() output_names = set()
for k in meta_graph_def.signature_def.keys(): for k in meta_graph_def.signature_def.keys():
outputs_tensor_info = meta_graph_def.signature_def[k].outputs outputs_tensor_info = meta_graph_def.signature_def[k].outputs
...@@ -97,19 +74,18 @@ class TFParser(object): ...@@ -97,19 +74,18 @@ class TFParser(object):
def _load_saved_model(self): def _load_saved_model(self):
"""Load the tensorflow saved model.""" """Load the tensorflow saved model."""
try: try:
import tensorflow as tf
from tensorflow.python.tools import freeze_graph from tensorflow.python.tools import freeze_graph
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import graph_util from tensorflow.python.framework import graph_util
except ImportError as e: except ImportError:
raise ImportError( raise ImportError(
"InputConfiguration: Unable to import tensorflow which is " "InputConfiguration: Unable to import tensorflow which is "
"required to restore from saved model. {}".format(e)) "required to restore from saved model.")
saved_model_dir = self._model_dir saved_model_dir = self._model_dir
output_graph_filename = os.path.join(self._tmp_dir.name, "neo_frozen_model.pb") output_graph_filename = self._tmp_dir.relpath("tf_frozen_model.pb")
input_saved_model_dir = saved_model_dir input_saved_model_dir = saved_model_dir
output_node_names = self._get_output_names(self._model_dir) output_node_names = self._get_output_names()
input_binary = False input_binary = False
input_saver_def_path = False input_saver_def_path = False
...@@ -119,7 +95,7 @@ class TFParser(object): ...@@ -119,7 +95,7 @@ class TFParser(object):
input_meta_graph = False input_meta_graph = False
checkpoint_path = None checkpoint_path = None
input_graph_filename = None input_graph_filename = None
saved_model_tags = tf.saved_model.tag_constants.SERVING saved_model_tags = ",".join(self._get_tag_set())
freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path, freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
input_binary, checkpoint_path, output_node_names, input_binary, checkpoint_path, output_node_names,
...@@ -145,6 +121,7 @@ class TFParser(object): ...@@ -145,6 +121,7 @@ class TFParser(object):
file. file.
""" """
graph = None graph = None
if os.path.isdir(self._model_dir): if os.path.isdir(self._model_dir):
ckpt = os.path.join(self._model_dir, "checkpoint") ckpt = os.path.join(self._model_dir, "checkpoint")
if not os.path.isfile(ckpt): if not os.path.isfile(ckpt):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment