Commit 77869913 by Ashutosh Parkhi Committed by Tianqi Chen

tensorflow frontend supports user given outputs (#1913)

parent 2c231b5a
......@@ -1039,7 +1039,7 @@ class GraphProto(object):
self._num_param = 0
self._num_rnn_layer = False
def from_tensorflow(self, graph, layout="NHWC", shape=None):
def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct nnvm nodes from tensorflow graph definition - GraphDef.
Follow the tensorflow graph definition to parse and convert it to NNVM.
......@@ -1086,6 +1086,7 @@ class GraphProto(object):
raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators))
final_op = None
# Parse the nodes to re-create TF graph using Symbol API of NNVM
for node in graph.node:
# Tensorflow doesn't have seperate list for params extraction.
......@@ -1165,6 +1166,7 @@ class GraphProto(object):
# Assuming only one output.
self._nodes[node.name] = op
final_op = op
# Infer shapes if passed explicitely
node_output = self._nodes[node.name]
......@@ -1175,13 +1177,16 @@ class GraphProto(object):
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
self._output_shapes[node.name] = out_shapes
# Assume the final node is the output node
out = node_output
out = []
if outputs is None:
out.append(final_op)
else:
out = [self._nodes[out_name] for out_name in outputs]
#Add the RNN outputs also with 'head' nodes of the nnvm graph
if self._num_rnn_layer:
out_rnn = _sym.concatenate(*self._out_rnn, axis=0)
out = [out, out_rnn]
out.append(out_rnn)
if isinstance(out, list):
out = _sym.Group(out)
......@@ -1378,7 +1383,7 @@ class GraphProto(object):
return inputs
def from_tensorflow(graph, layout="NHWC", shape=None):
def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
""" Load tensorflow graph which is a python tensorflow graph object into nnvm graph.
The companion parameters will be handled automatically.
......@@ -1396,5 +1401,5 @@ def from_tensorflow(graph, layout="NHWC", shape=None):
Dict of converted parameters stored in tvm.ndarray format
"""
g = GraphProto()
sym, params = g.from_tensorflow(graph, layout, shape)
sym, params = g.from_tensorflow(graph, layout, shape, outputs)
return sym, params
......@@ -26,8 +26,15 @@ import nnvm.testing.tf
#######################################################################
# Generic run functions for TVM & tensorflow
# ------------------------------------------
def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm'):
def convert_to_list(x):
if not isinstance(x, list):
x = [x]
return x
def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm', out_names=None):
""" Generic function to compile on nnvm and execute on tvm """
input_data = convert_to_list(input_data)
input_node = convert_to_list(input_node)
layout = None
if target == "cuda":
......@@ -44,7 +51,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm'
shape_dict = {input_node: input_data.shape}
dtype_dict = {input_node: input_data.dtype}
sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict)
sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict, outputs=out_names)
graph, lib, params = nnvm.compiler.build(sym, target=target, target_host=target_host, shape=shape_dict,
dtype=dtype_dict, params=params)
......@@ -52,37 +59,34 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm'
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
# set inputs
if isinstance(input_data, list):
for i, e in enumerate(input_node):
m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype)))
else:
m.set_input(input_node, tvm.nd.array(input_data.astype(input_data.dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
if num_output > 1:
assert out_names is None or num_output == len(out_names),"out_names: {} num_output: {}".format(
out_names, num_output)
tvm_output_list = []
for i in range(0, num_output):
tvm_output = m.get_output(i)
tvm_output_list.append(tvm_output.asnumpy())
return tvm_output_list
else:
tvm_output = m.get_output(0)
return tvm_output.asnumpy()
def run_tf_graph(sess, input_data, input_node, output_node):
""" Generic function to execute tensorflow """
input_data = convert_to_list(input_data)
input_node = convert_to_list(input_node)
output_node = convert_to_list(output_node)
tensor = sess.graph.get_tensor_by_name(output_node)
tensor = [0] * len(output_node)
for i in range(len(output_node)):
tensor[i] = sess.graph.get_tensor_by_name(output_node[i])
if isinstance(input_data, list):
input_dict = {}
for i, e in enumerate(input_node):
input_dict[e] = input_data[i]
else:
input_dict = {input_node: input_data}
output_data = sess.run(tensor, input_dict)
return output_data
......@@ -91,14 +95,16 @@ def run_tf_graph(sess, input_data, input_node, output_node):
def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, no_gpu=False):
"""Generic function to generate and compare tensorflow and TVM output"""
out_node = out_name.split(':')[0] if ":" in out_name else out_name
out_name = convert_to_list(out_name)
out_node = [0]*len(out_name)
for i in range(len(out_name)):
out_node[i] = out_name[i].split(':')[0] if ":" in out_name[i] else out_name[i]
if isinstance(in_name, list):
in_data = convert_to_list(in_data)
in_name = convert_to_list(in_name)
in_node = [0]*len(in_name)
for i in range(len(in_name)):
in_node[i] = in_name[i].split(':')[0] if ":" in in_name[i] else in_name[i]
else:
in_node = in_name.split(':')[0] if ":" in in_name else in_name
with tf.Session() as sess:
if init_global_variables:
......@@ -106,9 +112,8 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
final_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
[out_node],
out_node,
)
tf_output = run_tf_graph(sess, in_data, in_name, out_name)
for device in ["llvm", "cuda"]:
......@@ -120,7 +125,10 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
continue
tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device)
tvm.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
# since the names from tensorflow and nnvm runs are not exactly same,
# first len(tf_output) will be compared
for i in range(len(tf_output)):
tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
sess.close()
......@@ -260,6 +268,7 @@ def test_forward_reshape():
_test_reshape(np.arange(6), [-1])
#######################################################################
#######################################################################
# Squeeze
# -------
......@@ -509,6 +518,35 @@ def test_forward_multi_input():
['in1:0', 'in2:0', 'in3:0', 'in4:0'], 'out:0')
#######################################################################
# Multi Output to Graph
# ---------------------
def test_forward_multi_output():
with tf.Graph().as_default():
in1 = tf.placeholder(tf.int32, shape=[3, 3], name='in1')
in2 = tf.placeholder(tf.int32, shape=[3, 3], name='in2')
in3 = tf.placeholder(tf.int32, shape=[3, 3], name='in3')
in4 = tf.placeholder(tf.int32, shape=[3, 3], name='in4')
out1 = tf.add(in1, in2, name='out1')
out2 = tf.subtract(in3, in4, name='out2')
in_data = np.arange(9, dtype='int32').reshape([3, 3])
in_data = [in_data] * 4
in_name = ['in1:0', 'in2:0', 'in3:0', 'in4:0']
out_name = ['out1:0', 'out2:0']
out_node = [out.strip(':0') for out in out_name]
in_node = [inp.strip(':0') for inp in in_name]
with tf.Session() as sess:
final_graph_def = tf.graph_util.convert_variables_to_constants(
sess, sess.graph.as_graph_def(add_shapes=True), out_node,)
tf_output = run_tf_graph(sess, in_data, in_name, out_name)
tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target='llvm',
out_names=out_node, num_output=2)
for i in range(len(tf_output)):
tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
#######################################################################
# Resize Bilinear
# ---------------
......@@ -580,7 +618,7 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
out_state_c = np.reshape(out_state_tup[0], (batch_size, num_hidden))
out_state_h = np.reshape(out_state_tup[1], (batch_size, num_hidden))
tvm_out = [out, out_state_c, out_state_h]
tvm.testing.assert_allclose(tf_out, tvm_out, rtol=1e-3, atol=1e-3)
tvm.testing.assert_allclose(tf_out[0], tvm_out[0], rtol=1e-3, atol=1e-3)
def test_forward_lstm():
'''test LSTM block cell'''
......@@ -653,7 +691,7 @@ def test_forward_inception_v3():
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input:0', 'InceptionV3/Predictions/Reshape_1:0')
tvm_output = run_tvm_graph(graph_def, data, 'input')
tvm.testing.assert_allclose(tf_output, tvm_output, rtol=1e-5, atol=1e-5)
tvm.testing.assert_allclose(tf_output[0], tvm_output[0], rtol=1e-5, atol=1e-5)
#######################################################################
# Inception V1
......@@ -689,7 +727,7 @@ def test_forward_inception_v1():
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'softmax:0')
tvm_output = run_tvm_graph(graph_def, tvm_data, 'DecodeJpeg/contents')
tvm.testing.assert_allclose(tf_output, tvm_output, rtol=1e-5, atol=1e-5)
tvm.testing.assert_allclose(tf_output[0], tvm_output[0], rtol=1e-5, atol=1e-5)
#######################################################################
# Mobilenet
......@@ -712,7 +750,7 @@ def test_forward_mobilenet():
graph_def = nnvm.testing.tf.AddShapesToGraphDef(sess, out_node)
tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'input')
tvm.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5)
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
#######################################################################
# ResnetV2
......@@ -731,7 +769,7 @@ def test_forward_resnetv2():
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input_tensor:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', tf_output.shape, 'float32')
tvm.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5)
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
#######################################################################
# PTB
......@@ -797,6 +835,7 @@ def test_forward_ptb():
state_output = model.get_output(1, tvm.nd.empty(out_state_shape,
"float32")).asnumpy()
sample = nnvm.testing.tf.pick_from_weight(tvm_output[0])
return sample, state_output
for x in data:
......@@ -942,7 +981,7 @@ def test_forward_leaky_relu():
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
tf.nn.leaky_relu(in1, alpha=0.4)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'LeakyRelu:0')
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'LeakyRelu/mul:0')
def test_forward_elu():
ishape = (1, 3, 10, 10)
......@@ -1042,6 +1081,7 @@ if __name__ == '__main__':
# General
test_forward_multi_input()
test_forward_multi_output()
test_forward_variable()
# End to End
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment