Unverified Commit 75e9f5dc by Josh Fromm Committed by GitHub

[Frontend][ONNX] LSTM Support (#4825)

* Initial version working and passing tests.

* WIP on supporting other activations.

* add support for multiple activation functions in lstm

* All tests working and code cleaned up.

* Undo import swap to avoid conflict with masahi.

* Added new tests and related bug fixes.

Co-authored-by: Matthew Brookhart <mbrookhart@octoml.ai>
parent d2cc2144
......@@ -56,6 +56,12 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output
# set inputs
if isinstance(input_data, list):
for i, e in enumerate(input_names):
# It's possible for some ONNX inputs to not be needed in the TVM
# module; confirm the input is present before setting it.
try:
m.get_input(input_names[i])
except:
continue
m.set_input(input_names[i], tvm.nd.array(
input_data[i].astype(input_data[i].dtype)))
else:
......@@ -1962,6 +1968,175 @@ def test_pooling():
auto_pad='SAME_UPPER')
def verify_lstm(seq_length,
                batch_size,
                input_size,
                hidden_size,
                use_bias=False,
                activations=None,
                alphas=None,
                betas=None,
                use_initial_state=False,
                use_peep=False):
    """Build a single-layer ONNX LSTM node with random weights and compare
    TVM's outputs (Y, Y_h, Y_c) against onnxruntime's.

    Parameters
    ----------
    seq_length, batch_size, input_size, hidden_size : int
        Shape parameters of the LSTM.
    use_bias : bool
        Attach the optional bias input ``B`` of shape [1, 8 * hidden_size].
    activations : list of str, optional
        Names of the f, g, h activation functions for the node's
        ``activations`` attribute.
    alphas, betas : list of float, optional
        Parameters for parameterized activations. Only emitted when
        ``alphas`` is given (matching the original attribute-selection
        logic).
    use_initial_state : bool
        Attach ``sequence_lens``, ``initial_h`` and ``initial_c`` inputs.
        Requires ``use_bias`` since ONNX optional inputs are positional.
    use_peep : bool
        Attach the peephole weight input ``P``; requires
        ``use_initial_state`` for the same positional reason.
    """
    # Required inputs: data X plus input-to-hidden (W) and hidden-to-hidden
    # (R) weight matrices, each with a leading num_directions=1 axis.
    x_np = np.random.uniform(size=(seq_length, batch_size, input_size)).astype('float32')
    w_np = np.random.uniform(size=(1, 4 * hidden_size, input_size)).astype('float32')
    r_np = np.random.uniform(size=(1, 4 * hidden_size, hidden_size)).astype('float32')
    input_names = ["X", "W", "R"]
    input_tensors = [
        helper.make_tensor_value_info("X", TensorProto.FLOAT, list(x_np.shape)),
        helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_np.shape)),
        helper.make_tensor_value_info("R", TensorProto.FLOAT, list(r_np.shape))
    ]
    input_values = [x_np, w_np, r_np]

    if use_bias:
        b_np = np.random.uniform(size=(1, 8 * hidden_size)).astype('float32')
        input_names.append("B")
        input_tensors.append(
            helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 8 * hidden_size]))
        input_values.append(b_np)

    if use_initial_state:
        # ONNX optional inputs are ordered, so B must be present before
        # sequence_lens / initial_h / initial_c can be supplied.
        assert use_bias, "Initial states must have bias specified."
        sequence_np = np.repeat(seq_length, batch_size).astype('int32')
        input_names.append("sequence_lens")
        input_tensors.append(
            helper.make_tensor_value_info("sequence_lens", TensorProto.INT32, [batch_size]))
        input_values.append(sequence_np)

        initial_h_np = np.random.uniform(size=(1, batch_size, hidden_size)).astype('float32')
        input_names.append("initial_h")
        input_tensors.append(
            helper.make_tensor_value_info("initial_h", TensorProto.FLOAT,
                                          [1, batch_size, hidden_size]))
        input_values.append(initial_h_np)

        initial_c_np = np.random.uniform(size=(1, batch_size, hidden_size)).astype('float32')
        input_names.append("initial_c")
        input_tensors.append(
            helper.make_tensor_value_info("initial_c", TensorProto.FLOAT,
                                          [1, batch_size, hidden_size]))
        input_values.append(initial_c_np)

    if use_peep:
        # Peepholes (P) come after the initial-state inputs in the ONNX
        # input ordering.
        assert use_initial_state, "Peepholes require initial state to be specified."
        p_np = np.random.uniform(size=(1, 3 * hidden_size)).astype('float32')
        input_names.append("P")
        input_tensors.append(
            helper.make_tensor_value_info("P", TensorProto.FLOAT, [1, 3 * hidden_size]))
        input_values.append(p_np)

    Y_shape = [seq_length, 1, batch_size, hidden_size]
    Y_h_shape = [1, batch_size, hidden_size]
    Y_c_shape = [1, batch_size, hidden_size]

    # Build the node attributes once instead of triplicating the
    # make_node call: activations are optional, and alpha/beta parameters
    # are only attached when alphas is provided (betas may be None-free
    # only in that case, matching the original branch structure).
    extra_attrs = {}
    if activations is not None:
        extra_attrs['activations'] = activations
        if alphas is not None:
            extra_attrs['activation_alpha'] = alphas
            extra_attrs['activation_beta'] = betas
    lstm_node = helper.make_node(
        'LSTM',
        inputs=input_names,
        outputs=["Y", "Y_h", "Y_c"],
        hidden_size=hidden_size,
        **extra_attrs)

    graph = helper.make_graph([lstm_node],
                              "lstm_test",
                              inputs=input_tensors,
                              outputs=[
                                  helper.make_tensor_value_info("Y", TensorProto.FLOAT,
                                                                list(Y_shape)),
                                  helper.make_tensor_value_info("Y_h", TensorProto.FLOAT,
                                                                list(Y_h_shape)),
                                  helper.make_tensor_value_info("Y_c", TensorProto.FLOAT,
                                                                list(Y_c_shape))
                              ])
    model = helper.make_model(graph, producer_name='lstm_test')

    # Run the same model through onnxruntime and TVM on every available
    # target/context and require the three outputs to agree.
    for target, ctx in ctx_list():
        onnx_out = get_onnxruntime_output(model, input_values, 'float32')
        tvm_out = get_tvm_output(
            model,
            input_values,
            target,
            ctx, [Y_shape, Y_h_shape, Y_c_shape],
            output_dtype=['float32', 'float32', 'float32'])
        for o_out, t_out in zip(onnx_out, tvm_out):
            # Loose tolerances: random weights plus differing kernel math.
            tvm.testing.assert_allclose(o_out, t_out, rtol=5e-3, atol=5e-3)
def test_lstm():
    """Exercise the ONNX LSTM importer across shapes, activation
    functions, and optional initial-state / peephole inputs."""
    # Shape-only variations:
    # (seq_length, batch_size, input_size, hidden_size, use_bias)
    shape_cases = [
        (2, 1, 16, 32, False),   # no bias
        (4, 8, 16, 32, True),    # large batch
        (3, 3, 16, 40, True),    # non power of two
        (8, 1, 16, 32, True),    # long sequence
        (2, 1, 16, 128, True),   # large hidden
        (2, 1, 64, 32, True),    # large input
    ]
    for seq_len, batch, in_size, hidden, bias in shape_cases:
        verify_lstm(seq_length=seq_len,
                    batch_size=batch,
                    input_size=in_size,
                    hidden_size=hidden,
                    use_bias=bias)

    # Activation variations, all on the same small shape.
    # Default-valued HardSigmoid.
    verify_lstm(seq_length=2,
                batch_size=1,
                input_size=16,
                hidden_size=32,
                use_bias=False,
                activations=['HardSigmoid', 'Tanh', 'Tanh'])
    # Multiple parameterized activations.
    verify_lstm(seq_length=2,
                batch_size=1,
                input_size=16,
                hidden_size=32,
                use_bias=False,
                activations=['HardSigmoid', 'LeakyRelu', 'Tanh'],
                alphas=[2.0, 0.5],
                betas=[.3])
    # All three parameterized, including the Affine activation.
    verify_lstm(seq_length=2,
                batch_size=1,
                input_size=16,
                hidden_size=32,
                use_bias=False,
                activations=['HardSigmoid', 'LeakyRelu', 'Affine'],
                alphas=[2.0, 0.5, 0.8],
                betas=[.3, 0.1])

    # Optional initial state, then initial state plus peepholes.
    verify_lstm(seq_length=2,
                batch_size=1,
                input_size=16,
                hidden_size=32,
                use_bias=True,
                use_initial_state=True)
    verify_lstm(seq_length=2,
                batch_size=1,
                input_size=16,
                hidden_size=32,
                use_bias=True,
                use_initial_state=True,
                use_peep=True)
if __name__ == '__main__':
test_flatten()
test_reshape()
......@@ -2020,3 +2195,4 @@ if __name__ == '__main__':
test_convtranspose()
test_unsqueeze_constant()
test_pooling()
test_lstm()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment