Unverified Commit cf0a7e28 by zhengdi Committed by GitHub

[FRONTEND][TENSORFLOW] support multiply outputs (#4980)

* [FRONTEND][TENSORFLOW] support multiply outputs

* [TENSORFLOW][TEST] add tf_testing.AddShapesToGraphDef test

* update frontend test

* retrigger CI
parent ba477865
...@@ -66,6 +66,11 @@ def ProcessGraphDefParam(graph_def): ...@@ -66,6 +66,11 @@ def ProcessGraphDefParam(graph_def):
return graph_def return graph_def
def convert_to_list(x):
if not isinstance(x, list):
x = [x]
return x
def AddShapesToGraphDef(session, out_node): def AddShapesToGraphDef(session, out_node):
""" Add shapes attribute to nodes of the graph. """ Add shapes attribute to nodes of the graph.
Input graph here is the default graph in context. Input graph here is the default graph in context.
...@@ -74,7 +79,7 @@ def AddShapesToGraphDef(session, out_node): ...@@ -74,7 +79,7 @@ def AddShapesToGraphDef(session, out_node):
---------- ----------
session : tf.Session session : tf.Session
Tensorflow session Tensorflow session
out_node : String out_node : String or List
Final output node of the graph. Final output node of the graph.
Returns Returns
...@@ -87,7 +92,7 @@ def AddShapesToGraphDef(session, out_node): ...@@ -87,7 +92,7 @@ def AddShapesToGraphDef(session, out_node):
graph_def = tf_compat_v1.graph_util.convert_variables_to_constants( graph_def = tf_compat_v1.graph_util.convert_variables_to_constants(
session, session,
session.graph.as_graph_def(add_shapes=True), session.graph.as_graph_def(add_shapes=True),
[out_node], convert_to_list(out_node),
) )
return graph_def return graph_def
......
...@@ -171,11 +171,7 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, ...@@ -171,11 +171,7 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
with tf.Session() as sess: with tf.Session() as sess:
if init_global_variables: if init_global_variables:
sess.run(variables.global_variables_initializer()) sess.run(variables.global_variables_initializer())
final_graph_def = tf.graph_util.convert_variables_to_constants( final_graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
sess,
sess.graph.as_graph_def(add_shapes=True),
out_node,
)
tf_output = run_tf_graph(sess, in_data, in_name, out_name) tf_output = run_tf_graph(sess, in_data, in_name, out_name)
for device in ["llvm", "cuda"]: for device in ["llvm", "cuda"]:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment