Commit c1376a40 by lixiaoquan Committed by Tianqi Chen

[TensorFlow] Fix a bug output index is ignored (#3631)

Enhance test to cover this case
parent f1ede9a9
......@@ -2031,15 +2031,6 @@ class GraphProto(object):
# Pass the target layout
attr["_target_layout"] = layout
#ToDo: Some of the tensorflow operators internaly maintain
#execution layers and its output name will the layer number along with
#graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
#output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,
#the digit has to be ignored.
if ":" in node.input[0]:
in_name, _ = node.input[0].split(':')
node.input[0] = in_name
# Fill shapes for all inputs in a list
inputs = []
for i in node.input:
......
......@@ -836,9 +836,10 @@ def _test_split(in_shape, axis, num_or_size_splits, dtype):
in_data = tf.placeholder(dtype, in_shape, name="in_data")
num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list)\
else num_or_size_splits
tf.split(in_data, num_or_size_splits, axis=axis)
split = tf.split(in_data, num_or_size_splits, axis=axis)
relu = [tf.nn.relu(i) for i in split]
compare_tf_with_tvm([np_data], ['in_data:0'], [f'split:{n}' for n in range(num_split)])
compare_tf_with_tvm([np_data], ['in_data:0'], [n.name for n in relu])
# and now test together with concat
tf.reset_default_graph()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment