Unverified Commit 75e9f5dc by Josh Fromm Committed by GitHub

[Frontend][ONNX] LSTM Support (#4825)

* Initial version working and passing tests.

* WIP on supporting other activations.

* add support for multiple activation functions in lstm

* All tests working and code cleaned up.

* Undo import swap to avoid conflict with masahi.

* Added new tests and related bug fixes.

Co-authored-by: Matthew Brookhart <mbrookhart@octoml.ai>
parent d2cc2144
...@@ -32,6 +32,55 @@ from .common import infer_type, infer_value, infer_value_simulated, get_name ...@@ -32,6 +32,55 @@ from .common import infer_type, infer_value, infer_value_simulated, get_name
__all__ = ['from_onnx'] __all__ = ['from_onnx']
class onnx_input():
""" Dual purpose list or dictionary access object."""
def __init__(self):
self.input_keys = []
self.input_dict = {}
def __getitem__(self, item):
if isinstance(item, int):
return self.input_dict[self.input_keys[item]]
if isinstance(item, str):
if item not in self.input_keys:
return None
return self.input_dict[item]
if isinstance(item, slice):
keys = self.input_keys[item]
return [self.input_dict[key] for key in keys]
raise ValueError("Only integer, string, and slice accesses allowed.")
def __setitem__(self, item, value):
if isinstance(item, int):
self.input_dict[self.input_keys[item]] = value
elif isinstance(item, str):
if item not in self.input_dict:
self.input_keys.append(item)
self.input_dict[item] = value
else:
raise ValueError("Only integer and string indexed writes allowed.")
def keys(self):
return self.input_keys
def __len__(self):
return len(self.input_keys)
def __iter__(self):
self.n = 0
return self
def __next__(self):
if self.n < len(self.input_keys):
output = self.input_dict[self.input_keys[self.n]]
self.n += 1
return output
raise StopIteration
def get_numpy(tensor_proto): def get_numpy(tensor_proto):
"""Grab data in TensorProto and convert to numpy array.""" """Grab data in TensorProto and convert to numpy array."""
try: try:
...@@ -664,13 +713,24 @@ class Sum(OnnxOpConverter): ...@@ -664,13 +713,24 @@ class Sum(OnnxOpConverter):
return inputs[len(inputs) - 1] return inputs[len(inputs) - 1]
class Affine(OnnxOpConverter):
""" Operator converter for Affine transformation.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
alpha = _expr.const(attr.get('alpha', 1.0))
beta = _expr.const(attr.get('beta', 0.0))
return (alpha * inputs[0]) + beta
class ThresholdedRelu(OnnxOpConverter): class ThresholdedRelu(OnnxOpConverter):
""" Operator converter for ThresholdedRelu. """ Operator converter for ThresholdedRelu.
""" """
@classmethod @classmethod
def _impl_v1(cls, inputs, attr, params): def _impl_v1(cls, inputs, attr, params):
alpha = float(attr.get('alpha', 0.0)) alpha = float(attr.get('alpha', 1.0))
alpha_tensor = _op.full_like(inputs[0], fill_value=_expr.const(alpha)) alpha_tensor = _op.full_like(inputs[0], fill_value=_expr.const(alpha))
mask = _op.greater(inputs[0], alpha_tensor).astype("float32") mask = _op.greater(inputs[0], alpha_tensor).astype("float32")
return inputs[0] * mask return inputs[0] * mask
...@@ -893,7 +953,7 @@ class Maximum(OnnxOpConverter): ...@@ -893,7 +953,7 @@ class Maximum(OnnxOpConverter):
""" """
@classmethod @classmethod
def _impl_v1(cls, inputs, attr, params): def _impl_v1(cls, inputs, attr, params):
if not isinstance(inputs, list) or len(inputs) < 2: if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
raise ValueError("Expect minimum 2 inputs") raise ValueError("Expect minimum 2 inputs")
_max = inputs[0] _max = inputs[0]
for i in range(1, len(inputs)): for i in range(1, len(inputs)):
...@@ -905,7 +965,7 @@ class Minimum(OnnxOpConverter): ...@@ -905,7 +965,7 @@ class Minimum(OnnxOpConverter):
""" """
@classmethod @classmethod
def _impl_v1(cls, inputs, attr, params): def _impl_v1(cls, inputs, attr, params):
if not isinstance(inputs, list) or len(inputs) < 2: if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
raise ValueError("Expect minimum 2 inputs") raise ValueError("Expect minimum 2 inputs")
_min = inputs[0] _min = inputs[0]
for i in range(1, len(inputs)): for i in range(1, len(inputs)):
...@@ -917,7 +977,7 @@ class Mean(OnnxOpConverter): ...@@ -917,7 +977,7 @@ class Mean(OnnxOpConverter):
""" """
@classmethod @classmethod
def _impl_v1(cls, inputs, attr, params): def _impl_v1(cls, inputs, attr, params):
if not isinstance(inputs, list) or len(inputs) < 2: if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
raise ValueError("Expect minimum 2 inputs") raise ValueError("Expect minimum 2 inputs")
# avoid overflow # avoid overflow
concat = _op.concatenate([_op.expand_dims(x, axis=0) for x in inputs], axis=0) concat = _op.concatenate([_op.expand_dims(x, axis=0) for x in inputs], axis=0)
...@@ -1190,6 +1250,151 @@ class Expand(OnnxOpConverter): ...@@ -1190,6 +1250,151 @@ class Expand(OnnxOpConverter):
return _op.broadcast_to(inputs[0], shape=tuple(shape)) return _op.broadcast_to(inputs[0], shape=tuple(shape))
class LSTM(OnnxOpConverter):
""" Operator converter for LSTM.
"""
@classmethod
def _activation_helper(cls, activation, alpha, beta):
convert_map = _get_convert_map(1)
attrs = {}
if alpha is not None:
attrs['alpha'] = alpha
if beta is not None:
attrs['beta'] = beta
return lambda x: convert_map[activation.decode("utf-8")]([x], attrs, {})
@classmethod
def _activation_needs_alpha(cls, activation):
needs_alpha = [
"Affine",
"LeakyRelu",
"ThresholdedRelu",
"ScaledTanh",
"HardSigmoid",
"Elu",
]
return activation.decode("utf-8") in needs_alpha
@classmethod
def _activation_needs_beta(cls, activation):
needs_beta = [
"Affine",
"ScaledTanh",
"HardSigmoid",
]
return activation.decode("utf-8") in needs_beta
@classmethod
def _impl_v7(cls, inputs, attr, params):
# Unpack inputs, note that if optional and not provided then value will be None.
X = inputs[0]
W = inputs[1]
R = inputs[2]
B = inputs['B']
# Sequence length currently unused as it can be inferred from shapes.
#sequence_lens = inputs['sequence_lens']
h_0 = inputs['initial_h']
c_0 = inputs['initial_c']
P = inputs['P']
num_directions = infer_shape(W)[0]
W_dtype = infer_type(W).type_annotation.dtype
if num_directions != 1:
raise NotImplementedError("Bidirectional LSTMs not yet supported.")
# Remove num_directions axis from weights.
W = _op.squeeze(W, axis=[0])
R = _op.squeeze(R, axis=[0])
if B is not None:
B = _op.squeeze(B, axis=[0])
X_shape = infer_shape(X)
hidden_size = infer_shape(R)[-1]
batch_size = X_shape[1]
# Initialize state if not provided.
# Otherwise remove bidirectional axis.
if h_0 is None:
h_0 = _op.zeros((batch_size, hidden_size), W_dtype)
else:
h_0 = _op.squeeze(h_0, axis=[0])
if c_0 is None:
c_0 = _op.zeros((batch_size, hidden_size), W_dtype)
else:
c_0 = _op.squeeze(c_0, axis=[0])
if P is not None:
P = _op.squeeze(P, axis=[0])
p_i, p_o, p_f = _op.split(P, 3)
H_t = h_0
C_t = c_0
h_list = []
if 'activations' in attr:
activations = attr['activations']
if len(activations) != 3:
raise NotImplementedError("LSTM assumes 3 activation functions are provided")
alpha_loc = 0
alphas = attr.get('activation_alpha', [])
if isinstance(alphas, float):
alphas = [alphas]
beta_loc = 0
betas = attr.get('activation_beta', [])
if isinstance(betas, float):
betas = [betas]
acts = []
for i in range(3):
alpha = None
beta = None
activation = activations[i]
if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc:
alpha = alphas[alpha_loc]
alpha_loc += 1
if cls._activation_needs_beta(activation) and len(betas) > beta_loc:
beta = betas[beta_loc]
beta_loc += 1
acts.append(cls._activation_helper(activation, alpha, beta))
f_act, g_act, h_act = acts
else:
f_act = _op.sigmoid
g_act = _op.tanh
h_act = _op.tanh
X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0)
for step in X_steps:
step = _op.squeeze(step, axis=[0])
gates = _op.nn.dense(step, W) + _op.nn.dense(H_t, R)
if B is not None:
WB, RB = _op.split(B, 2)
gates += WB + RB
i, o, f, c = _op.split(gates, 4, axis=-1)
if P is not None:
i = f_act(i + p_i * C_t)
f = f_act(f + p_f * C_t)
else:
i = f_act(i)
f = f_act(f)
c = g_act(c)
C = f * C_t + i * c
if P is not None:
o = f_act(o + p_o * C)
else:
o = f_act(o)
H = o * h_act(C)
H_t = H
C_t = C
h_list.append(_op.expand_dims(H, axis=0))
# Concatenate outputs and add back in direction axis.
concatenated = _op.concatenate(h_list, 0)
output = _op.expand_dims(concatenated, axis=1)
H_t = _op.expand_dims(H_t, axis=0)
C_t = _op.expand_dims(C_t, axis=0)
return _expr.TupleWrapper(_expr.Tuple((output, H_t, C_t)), 3)
# compatible operators that do NOT require any conversion. # compatible operators that do NOT require any conversion.
_identity_list = [] _identity_list = []
...@@ -1203,7 +1408,7 @@ def _get_convert_map(opset): ...@@ -1203,7 +1408,7 @@ def _get_convert_map(opset):
return { return {
# defs/experimental # defs/experimental
'Identity': Renamer('copy'), 'Identity': Renamer('copy'),
# 'Affine' 'Affine': Affine.get_converter(opset),
'ThresholdedRelu': ThresholdedRelu.get_converter(opset), 'ThresholdedRelu': ThresholdedRelu.get_converter(opset),
'ScaledTanh': ScaledTanh.get_converter(opset), 'ScaledTanh': ScaledTanh.get_converter(opset),
'ParametricSoftplus': ParametricSoftPlus.get_converter(opset), 'ParametricSoftplus': ParametricSoftPlus.get_converter(opset),
...@@ -1281,6 +1486,8 @@ def _get_convert_map(opset): ...@@ -1281,6 +1486,8 @@ def _get_convert_map(opset):
'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']), 'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
'Flatten': Flatten.get_converter(opset), 'Flatten': Flatten.get_converter(opset),
'LRN': LRN.get_converter(opset), 'LRN': LRN.get_converter(opset),
# Recurrent Layers
'LSTM': LSTM.get_converter(opset),
# defs/reduction # defs/reduction
'ReduceMax': ReduceMax.get_converter(opset), 'ReduceMax': ReduceMax.get_converter(opset),
...@@ -1414,7 +1621,11 @@ class GraphProto(object): ...@@ -1414,7 +1621,11 @@ class GraphProto(object):
for node in graph.node: for node in graph.node:
op_name = node.op_type op_name = node.op_type
attr = self._parse_attr(node.attribute) attr = self._parse_attr(node.attribute)
inputs = [self._nodes[self._renames.get(i, i)] for i in node.input] # Create and populate onnx input object.
inputs = onnx_input()
for i in node.input:
if i != '':
inputs[i] = self._nodes[self._renames.get(i, i)]
if op_name == "Constant": if op_name == "Constant":
t_proto = self._parse_attr(node.attribute)["value"] t_proto = self._parse_attr(node.attribute)["value"]
self._num_param += 1 self._num_param += 1
......
...@@ -56,6 +56,12 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output ...@@ -56,6 +56,12 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output
# set inputs # set inputs
if isinstance(input_data, list): if isinstance(input_data, list):
for i, e in enumerate(input_names): for i, e in enumerate(input_names):
# Its possible for some onnx inputs to not be needed in the tvm
# module, confirm its present before setting.
try:
m.get_input(input_names[i])
except:
continue
m.set_input(input_names[i], tvm.nd.array( m.set_input(input_names[i], tvm.nd.array(
input_data[i].astype(input_data[i].dtype))) input_data[i].astype(input_data[i].dtype)))
else: else:
...@@ -1962,6 +1968,175 @@ def test_pooling(): ...@@ -1962,6 +1968,175 @@ def test_pooling():
auto_pad='SAME_UPPER') auto_pad='SAME_UPPER')
def verify_lstm(seq_length,
batch_size,
input_size,
hidden_size,
use_bias=False,
activations=None,
alphas=None,
betas=None,
use_initial_state=False,
use_peep=False):
x_np = np.random.uniform(size=(seq_length, batch_size, input_size)).astype('float32')
w_np = np.random.uniform(size=(1, 4 * hidden_size, input_size)).astype('float32')
r_np = np.random.uniform(size=(1, 4 * hidden_size, hidden_size)).astype('float32')
input_names = ["X", "W", "R"]
input_tensors = [
helper.make_tensor_value_info("X", TensorProto.FLOAT, list(x_np.shape)),
helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_np.shape)),
helper.make_tensor_value_info("R", TensorProto.FLOAT, list(r_np.shape))
]
input_values = [x_np, w_np, r_np]
if use_bias:
b_np = np.random.uniform(size=(1, 8 * hidden_size)).astype('float32')
input_names.append("B")
input_tensors.append(
helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 8 * hidden_size]))
input_values.append(b_np)
if use_initial_state:
assert use_bias == True, "Initial states must have bias specified."
sequence_np = np.repeat(seq_length, batch_size).astype('int32')
input_names.append("sequence_lens")
input_tensors.append(helper.make_tensor_value_info("sequence_lens", TensorProto.INT32, [batch_size]))
input_values.append(sequence_np)
initial_h_np = np.random.uniform(size=(1, batch_size, hidden_size)).astype('float32')
input_names.append("initial_h")
input_tensors.append(
helper.make_tensor_value_info("initial_h", TensorProto.FLOAT,
[1, batch_size, hidden_size]))
input_values.append(initial_h_np)
initial_c_np = np.random.uniform(size=(1, batch_size, hidden_size)).astype('float32')
input_names.append("initial_c")
input_tensors.append(
helper.make_tensor_value_info("initial_c", TensorProto.FLOAT,
[1, batch_size, hidden_size]))
input_values.append(initial_c_np)
if use_peep:
assert use_initial_state == True, "Peepholes require initial state to be specified."
p_np = np.random.uniform(size=(1, 3 * hidden_size)).astype('float32')
input_names.append("P")
input_tensors.append(
helper.make_tensor_value_info("P", TensorProto.FLOAT, [1, 3 * hidden_size]))
input_values.append(p_np)
Y_shape = [seq_length, 1, batch_size, hidden_size]
Y_h_shape = [1, batch_size, hidden_size]
Y_c_shape = [1, batch_size, hidden_size]
if activations is None:
lstm_node = helper.make_node(
'LSTM', inputs=input_names, outputs=["Y", "Y_h", "Y_c"], hidden_size=hidden_size)
elif alphas is None:
lstm_node = helper.make_node(
'LSTM',
inputs=input_names,
outputs=["Y", "Y_h", "Y_c"],
hidden_size=hidden_size,
activations=activations)
else:
lstm_node = helper.make_node(
'LSTM',
inputs=input_names,
outputs=["Y", "Y_h", "Y_c"],
hidden_size=hidden_size,
activations=activations,
activation_alpha=alphas,
activation_beta=betas)
graph = helper.make_graph([lstm_node],
"lstm_test",
inputs=input_tensors,
outputs=[
helper.make_tensor_value_info("Y", TensorProto.FLOAT,
list(Y_shape)),
helper.make_tensor_value_info("Y_h", TensorProto.FLOAT,
list(Y_h_shape)),
helper.make_tensor_value_info("Y_c", TensorProto.FLOAT,
list(Y_c_shape))
])
model = helper.make_model(graph, producer_name='lstm_test')
for target, ctx in ctx_list():
onnx_out = get_onnxruntime_output(model, input_values, 'float32')
tvm_out = get_tvm_output(
model,
input_values,
target,
ctx, [Y_shape, Y_h_shape, Y_c_shape],
output_dtype=['float32', 'float32', 'float32'])
for o_out, t_out in zip(onnx_out, tvm_out):
tvm.testing.assert_allclose(o_out, t_out, rtol=5e-3, atol=5e-3)
def test_lstm():
# No bias.
verify_lstm(seq_length=2, batch_size=1, input_size=16, hidden_size=32, use_bias=False)
# large batch.
verify_lstm(seq_length=4, batch_size=8, input_size=16, hidden_size=32, use_bias=True)
# Non power of two.
verify_lstm(seq_length=3, batch_size=3, input_size=16, hidden_size=40, use_bias=True)
# Long sequence.
verify_lstm(seq_length=8, batch_size=1, input_size=16, hidden_size=32, use_bias=True)
# Large hidden.
verify_lstm(seq_length=2, batch_size=1, input_size=16, hidden_size=128, use_bias=True)
# Large input.
verify_lstm(seq_length=2, batch_size=1, input_size=64, hidden_size=32, use_bias=True)
# Different activation testing.
# Default value hardsigmoid.
verify_lstm(
seq_length=2,
batch_size=1,
input_size=16,
hidden_size=32,
use_bias=False,
activations=['HardSigmoid', 'Tanh', 'Tanh'])
# Multiple parameterized activations.
verify_lstm(
seq_length=2,
batch_size=1,
input_size=16,
hidden_size=32,
use_bias=False,
activations=['HardSigmoid', 'LeakyRelu', 'Tanh'],
alphas=[2.0, 0.5],
betas=[.3])
# All parameterized with new Affine activation.
verify_lstm(
seq_length=2,
batch_size=1,
input_size=16,
hidden_size=32,
use_bias=False,
activations=['HardSigmoid', 'LeakyRelu', 'Affine'],
alphas=[2.0, 0.5, 0.8],
betas=[.3, 0.1])
# Testing with initial state and peepholes
verify_lstm(
seq_length=2,
batch_size=1,
input_size=16,
hidden_size=32,
use_bias=True,
use_initial_state=True)
verify_lstm(
seq_length=2,
batch_size=1,
input_size=16,
hidden_size=32,
use_bias=True,
use_initial_state=True,
use_peep=True)
if __name__ == '__main__': if __name__ == '__main__':
test_flatten() test_flatten()
test_reshape() test_reshape()
...@@ -2020,3 +2195,4 @@ if __name__ == '__main__': ...@@ -2020,3 +2195,4 @@ if __name__ == '__main__':
test_convtranspose() test_convtranspose()
test_unsqueeze_constant() test_unsqueeze_constant()
test_pooling() test_pooling()
test_lstm()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment