Commit bea0b00f by Albin Joy, committed by Tianqi Chen

[NNVM][TENSORFLOW]Fix lstm testcase to support get_output without size input (#1731)

* [NNVM][TENSORFLOW]Fix lstm testcase issue to support get_output without size input

* Removed redundant code

* Enabled inceptionV1 testcase
parent 2475556a
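In short: the graph runtime's get_output(i) returns the output array itself when no destination buffer is passed, so the test helper no longer needs TensorFlow's output shapes and dtypes just to pre-allocate buffers. A minimal sketch of the two calling styles, with m and i as in the helper changed below:

    # before: pre-allocate a destination buffer from the known shape and dtype
    tvm_output = m.get_output(i, tvm.nd.empty(output_shape[i], output_dtype[i]))

    # after: let the runtime hand back the output array directly
    tvm_output = m.get_output(i)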
@@ -26,7 +26,7 @@ import nnvm.testing.tf
#######################################################################
# Generic run functions for TVM & tensorflow
# ------------------------------------------
-def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype, target='llvm'):
+def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm'):
""" Generic function to compile on nnvm and execute on tvm """
layout = None
@@ -62,10 +62,10 @@ def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype,
# execute
m.run()
# get outputs
-if isinstance(output_shape, list) and isinstance(output_dtype, list):
+if num_output > 1:
tvm_output_list = []
-for i, s in enumerate(output_shape):
-tvm_output = m.get_output(i, tvm.nd.empty((s), output_dtype[i]))
+for i in range(0, num_output):
+tvm_output = m.get_output(i)
tvm_output_list.append(tvm_output.asnumpy())
return tvm_output_list
else:
@@ -119,8 +119,7 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
if no_gpu and device == 'cuda':
continue
-tvm_output = run_tvm_graph(final_graph_def, in_data,
-in_node, tf_output.shape, tf_output.dtype, target=device)
+tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device)
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
sess.close()
@@ -572,14 +571,12 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
graph_def, tf_out = _get_tensorflow_output()
tvm_output = run_tvm_graph(graph_def, [input_data, in_state_c, in_state_h],
['root/Placeholder', 'root/lstm_cell/LSTMBlockCell_c',
-'root/lstm_cell/LSTMBlockCell_h'],
-[tf_out[0].shape, (2, batch_size, num_hidden)],
-[tf_out[0].dtype, tf_out[1].dtype])
+'root/lstm_cell/LSTMBlockCell_h'], num_output=2)
assert isinstance(tvm_output, list)
out = tvm_output[0]
out_state = tvm_output[1]
-out_state_tup = np.split(out_state, indices_or_sections=2, axis=0)
+out_state_tup = np.split(out_state, indices_or_sections=2, axis=1)
out_state_c = np.reshape(out_state_tup[0], (batch_size, num_hidden))
out_state_h = np.reshape(out_state_tup[1], (batch_size, num_hidden))
tvm_out = [out, out_state_c, out_state_h]
@@ -587,7 +584,6 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
def test_forward_lstm():
'''test LSTM block cell'''
-return
_test_lstm_cell(1, 2, 1, 0.0, 'float32')
@@ -656,7 +652,7 @@ def test_forward_inception_v3():
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input:0', 'InceptionV3/Predictions/Reshape_1:0')
-tvm_output = run_tvm_graph(graph_def, data, 'input', tf_output.shape, 'float32')
+tvm_output = run_tvm_graph(graph_def, data, 'input')
np.testing.assert_allclose(tf_output, tvm_output, rtol=1e-5, atol=1e-5)
#######################################################################
@@ -692,7 +688,7 @@ def test_forward_inception_v1():
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'softmax:0')
-tvm_output = run_tvm_graph(graph_def, tvm_data, 'DecodeJpeg/contents', tf_output.shape, 'float32')
+tvm_output = run_tvm_graph(graph_def, tvm_data, 'DecodeJpeg/contents')
np.testing.assert_allclose(tf_output, tvm_output, rtol=1e-5, atol=1e-5)
#######################################################################
@@ -710,7 +706,7 @@ def test_forward_mobilenet():
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')
-tvm_output = run_tvm_graph(graph_def, data, 'input', tf_output.shape, 'float32')
+tvm_output = run_tvm_graph(graph_def, data, 'input')
np.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5)
#######################################################################
@@ -1029,7 +1025,7 @@ if __name__ == '__main__':
test_forward_ptb()
# RNN
-#test_forward_lstm()
+test_forward_lstm()
# Elementwise
test_forward_ceil()
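For reference, a minimal usage sketch of the updated helper, with names taken from the hunks above; the default num_output=1 returns a single numpy array, while num_output > 1 returns a list with one array per output index:

    # single-output model: rely on the default num_output=1
    tvm_output = run_tvm_graph(graph_def, data, 'input')

    # multi-output case (LSTM test): ask for both outputs and get a list back
    tvm_output = run_tvm_graph(graph_def, [input_data, in_state_c, in_state_h],
                               ['root/Placeholder', 'root/lstm_cell/LSTMBlockCell_c',
                                'root/lstm_cell/LSTMBlockCell_h'], num_output=2)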