Commit bea0b00f by Albin Joy Committed by Tianqi Chen

[NNVM][TENSORFLOW]Fix lstm testcase to support get_output without size input (#1731)

* [NNVM][TENSORFLOW]Fix lstm testcase issue to support get_output without size input

* removed redundant

* Enabled inceptionV1 testcase
parent 2475556a
...@@ -26,7 +26,7 @@ import nnvm.testing.tf ...@@ -26,7 +26,7 @@ import nnvm.testing.tf
####################################################################### #######################################################################
# Generic run functions for TVM & tensorflow # Generic run functions for TVM & tensorflow
# ------------------------------------------ # ------------------------------------------
def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype, target='llvm'): def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm'):
""" Generic function to compile on nnvm and execute on tvm """ """ Generic function to compile on nnvm and execute on tvm """
layout = None layout = None
...@@ -62,10 +62,10 @@ def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype, ...@@ -62,10 +62,10 @@ def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype,
# execute # execute
m.run() m.run()
# get outputs # get outputs
if isinstance(output_shape, list) and isinstance(output_dtype, list): if num_output > 1:
tvm_output_list = [] tvm_output_list = []
for i, s in enumerate(output_shape): for i in range(0, num_output):
tvm_output = m.get_output(i, tvm.nd.empty((s), output_dtype[i])) tvm_output = m.get_output(i)
tvm_output_list.append(tvm_output.asnumpy()) tvm_output_list.append(tvm_output.asnumpy())
return tvm_output_list return tvm_output_list
else: else:
...@@ -119,8 +119,7 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, ...@@ -119,8 +119,7 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
if no_gpu and device == 'cuda': if no_gpu and device == 'cuda':
continue continue
tvm_output = run_tvm_graph(final_graph_def, in_data, tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device)
in_node, tf_output.shape, tf_output.dtype, target=device)
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5) np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
sess.close() sess.close()
...@@ -572,14 +571,12 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype): ...@@ -572,14 +571,12 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
graph_def, tf_out = _get_tensorflow_output() graph_def, tf_out = _get_tensorflow_output()
tvm_output = run_tvm_graph(graph_def, [input_data, in_state_c, in_state_h], tvm_output = run_tvm_graph(graph_def, [input_data, in_state_c, in_state_h],
['root/Placeholder', 'root/lstm_cell/LSTMBlockCell_c', ['root/Placeholder', 'root/lstm_cell/LSTMBlockCell_c',
'root/lstm_cell/LSTMBlockCell_h'], 'root/lstm_cell/LSTMBlockCell_h'], num_output=2)
[tf_out[0].shape, (2, batch_size, num_hidden)],
[tf_out[0].dtype, tf_out[1].dtype])
assert isinstance(tvm_output, list) assert isinstance(tvm_output, list)
out = tvm_output[0] out = tvm_output[0]
out_state = tvm_output[1] out_state = tvm_output[1]
out_state_tup = np.split(out_state, indices_or_sections=2, axis=0) out_state_tup = np.split(out_state, indices_or_sections=2, axis=1)
out_state_c = np.reshape(out_state_tup[0], (batch_size, num_hidden)) out_state_c = np.reshape(out_state_tup[0], (batch_size, num_hidden))
out_state_h = np.reshape(out_state_tup[1], (batch_size, num_hidden)) out_state_h = np.reshape(out_state_tup[1], (batch_size, num_hidden))
tvm_out = [out, out_state_c, out_state_h] tvm_out = [out, out_state_c, out_state_h]
...@@ -587,7 +584,6 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype): ...@@ -587,7 +584,6 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
def test_forward_lstm(): def test_forward_lstm():
'''test LSTM block cell''' '''test LSTM block cell'''
return
_test_lstm_cell(1, 2, 1, 0.0, 'float32') _test_lstm_cell(1, 2, 1, 0.0, 'float32')
...@@ -656,7 +652,7 @@ def test_forward_inception_v3(): ...@@ -656,7 +652,7 @@ def test_forward_inception_v3():
with tf.Session() as sess: with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input:0', 'InceptionV3/Predictions/Reshape_1:0') tf_output = run_tf_graph(sess, data, 'input:0', 'InceptionV3/Predictions/Reshape_1:0')
tvm_output = run_tvm_graph(graph_def, data, 'input', tf_output.shape, 'float32') tvm_output = run_tvm_graph(graph_def, data, 'input')
np.testing.assert_allclose(tf_output, tvm_output, rtol=1e-5, atol=1e-5) np.testing.assert_allclose(tf_output, tvm_output, rtol=1e-5, atol=1e-5)
####################################################################### #######################################################################
...@@ -692,7 +688,7 @@ def test_forward_inception_v1(): ...@@ -692,7 +688,7 @@ def test_forward_inception_v1():
with tf.Session() as sess: with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'softmax:0') tf_output = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'softmax:0')
tvm_output = run_tvm_graph(graph_def, tvm_data, 'DecodeJpeg/contents', tf_output.shape, 'float32') tvm_output = run_tvm_graph(graph_def, tvm_data, 'DecodeJpeg/contents')
np.testing.assert_allclose(tf_output, tvm_output, rtol=1e-5, atol=1e-5) np.testing.assert_allclose(tf_output, tvm_output, rtol=1e-5, atol=1e-5)
####################################################################### #######################################################################
...@@ -710,7 +706,7 @@ def test_forward_mobilenet(): ...@@ -710,7 +706,7 @@ def test_forward_mobilenet():
with tf.Session() as sess: with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0') tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'input', tf_output.shape, 'float32') tvm_output = run_tvm_graph(graph_def, data, 'input')
np.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5) np.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5)
####################################################################### #######################################################################
...@@ -1029,7 +1025,7 @@ if __name__ == '__main__': ...@@ -1029,7 +1025,7 @@ if __name__ == '__main__':
test_forward_ptb() test_forward_ptb()
# RNN # RNN
#test_forward_lstm() test_forward_lstm()
# Elementwise # Elementwise
test_forward_ceil() test_forward_ceil()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment