Unverified Commit cf0a7e28 by zhengdi Committed by GitHub

[FRONTEND][TENSORFLOW] support multiply outputs (#4980)

* [FRONTEND][TENSORFLOW] support multiply outputs

* [TENSORFLOW][TEST] add tf_testing.AddShapesToGraphDef test

* update frontend test

* retrigger CI
parent ba477865
......@@ -66,6 +66,11 @@ def ProcessGraphDefParam(graph_def):
return graph_def
def convert_to_list(x):
if not isinstance(x, list):
x = [x]
return x
def AddShapesToGraphDef(session, out_node):
""" Add shapes attribute to nodes of the graph.
Input graph here is the default graph in context.
......@@ -74,7 +79,7 @@ def AddShapesToGraphDef(session, out_node):
----------
session : tf.Session
Tensorflow session
out_node : String
out_node : String or List
Final output node of the graph.
Returns
......@@ -87,7 +92,7 @@ def AddShapesToGraphDef(session, out_node):
graph_def = tf_compat_v1.graph_util.convert_variables_to_constants(
session,
session.graph.as_graph_def(add_shapes=True),
[out_node],
convert_to_list(out_node),
)
return graph_def
......
......@@ -171,11 +171,7 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
with tf.Session() as sess:
if init_global_variables:
sess.run(variables.global_variables_initializer())
final_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
out_node,
)
final_graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
tf_output = run_tf_graph(sess, in_data, in_name, out_name)
for device in ["llvm", "cuda"]:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment