Commit 6292c781 by Siju, committed by Tianqi Chen

[NNVM]Keras SimpleRnn and GRU support (#1729)

parent 7c3ec7df
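
For context, a minimal usage sketch of the feature this commit enables. It is not part of the change itself; the model shape and layer sizes are illustrative, and it assumes keras and nnvm are importable and that nnvm.frontend.from_keras returns a graph symbol plus its parameters.

import keras
import nnvm

# A tiny Keras model using one of the newly supported recurrent layers.
data = keras.layers.Input(shape=(1, 8))
out = keras.layers.GRU(16, recurrent_activation='sigmoid', activation='tanh')(data)
keras_model = keras.models.Model(data, out)

# Convert the Keras graph into an NNVM symbol and parameter dict;
# both can then be passed on to nnvm.compiler.build(...).
sym, params = nnvm.frontend.from_keras(keras_model)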
@@ -28,6 +28,10 @@ def _get_elu(insym, alpha):
     """
     return -alpha * _sym.relu(1 - _sym.exp(insym)) + _sym.relu(insym)
 
+def _convert_recurrent_activation(insym, keras_layer):
+    act_type = keras_layer.recurrent_activation.__name__
+    return _convert_activation(insym, act_type, None)
+
 def _convert_activation(insym, keras_layer, _):
     if isinstance(keras_layer, str):
         act_type = keras_layer
@@ -420,16 +424,96 @@ def _convert_lstm(insym, keras_layer, symtab):
     ixh2 = _sym.dense(in_state_h, recurrent_wt, in_bias, use_bias=True, units=units)
     gate = ixh1 + ixh2
     gates = _sym.split(gate, indices_or_sections=4, axis=1)
-    in_gate = _sym.sigmoid(gates[0])
-    in_transform = _sym.sigmoid(gates[1])
-    next_c = in_transform * in_state_c + in_gate * _sym.tanh(gates[2])
-    out_gate = _sym.sigmoid(gates[3])
-    next_h = out_gate * _sym.tanh(next_c)
+    in_gate = _convert_recurrent_activation(gates[0], keras_layer)
+    in_transform = _convert_recurrent_activation(gates[1], keras_layer)
+    next_c = in_transform * in_state_c + in_gate * _convert_activation(gates[2], keras_layer, None)
+    out_gate = _convert_recurrent_activation(gates[3], keras_layer)
+    next_h = out_gate * _convert_activation(next_c, keras_layer, None)
     out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
     out = _sym.reshape(next_h, shape=out_shape)
     return [out, next_h, next_c]
+
+def _convert_simple_rnn(insym, keras_layer, symtab):
+    _check_data_format(keras_layer)
+    if not isinstance(insym, list):
+        buffer = np.zeros((1, keras_layer.units), 'float32')
+        prev_sym = symtab.new_const(buffer)
+        insym = [insym, prev_sym]
+    in_data = insym[0]
+    prev_sym = insym[1]
+    weightList = keras_layer.get_weights()
+    kernel_wt = symtab.new_const(weightList[0].transpose([1, 0]))
+    recurrent_wt = symtab.new_const(weightList[1].transpose([1, 0]))
+    in_bias = symtab.new_const(weightList[2])
+    units = list(weightList[0].shape)[1]
+    in_data = _sym.flatten(in_data)
+    ixh = _sym.dense(in_data, kernel_wt, in_bias, use_bias=True, units=units)
+    prev_sym = _sym.flatten(prev_sym)
+    ixh2 = _sym.dense(prev_sym, recurrent_wt, use_bias=False, units=units)
+    output = ixh + ixh2
+    output = _convert_activation(output, keras_layer, None)
+    out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
+    output = _sym.reshape(output, shape=out_shape)
+    return [output, output]
+
+def _convert_gru(insym, keras_layer, symtab):
+    _check_data_format(keras_layer)
+    if not isinstance(insym, list):
+        buffer = np.zeros((1, keras_layer.units), 'float32')
+        h_tm1 = symtab.new_const(buffer)
+        insym = [insym, h_tm1]
+    in_data = insym[0]
+    h_tm1_sym = insym[1]
+    weightList = keras_layer.get_weights()
+    kernel_wt = symtab.new_const(weightList[0].transpose([1, 0]))
+    recurrent_wt = symtab.new_const(weightList[1].transpose([1, 0]))
+    in_bias = symtab.new_const(weightList[2])
+    units = list(weightList[0].shape)[1]
+    in_data = _sym.flatten(in_data)
+    matrix_x = _sym.dense(in_data, kernel_wt, in_bias, use_bias=True, units=units)
+    # inputs projected by all gate matrices at once
+    split_indices = [keras_layer.units, 2 * keras_layer.units]
+    gates = _sym.split(matrix_x, indices_or_sections=split_indices, axis=1)
+    x_z = gates[0]
+    x_r = gates[1]
+    x_h = gates[2]
+    # hidden state projected separately for update/reset and new
+    units = 2 * keras_layer.units
+    split_indices = [units]
+    rec_wts = _sym.split(recurrent_wt, indices_or_sections=split_indices, axis=0)
+    h_tm1_sym = _sym.flatten(h_tm1_sym)
+    matrix_inner = _sym.dense(h_tm1_sym, rec_wts[0], use_bias=False, units=units)
+    split_indices = [keras_layer.units]
+    recurrent = _sym.split(matrix_inner, indices_or_sections=split_indices, axis=1)
+    recurrent_z = recurrent[0]
+    recurrent_r = recurrent[1]
+    rec_act_z = _convert_recurrent_activation(x_z + recurrent_z, keras_layer)
+    rec_act_r = _convert_recurrent_activation(x_r + recurrent_r, keras_layer)
+    units = keras_layer.units
+    recurrent_h = _sym.dense(rec_act_r * h_tm1_sym, rec_wts[1], use_bias=False, units=units)
+    act_hh = _convert_activation(x_h + recurrent_h, keras_layer, None)
+    # previous and candidate state mixed by update gate
+    output = rec_act_z * h_tm1_sym + (1 - rec_act_z) * act_hh
+    out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
+    output = _sym.reshape(output, shape=out_shape)
+    return [output, output]
 
 def _default_skip(insym, keras_layer, _): # pylint: disable=unused-argument
     """Layers that can be skipped because they are train time only."""
     return insym
@@ -475,9 +559,9 @@ _convert_map = {
     # 'UpSampling3D' : _convert_upsample,
     # 'Conv1D' : _convert_convolution1d,
-    # 'GRU' : _convert_gru,
+    'SimpleRNN' : _convert_simple_rnn,
     'LSTM' : _convert_lstm,
-    # 'SimpleRNN' : _convert_simple_rnn,
+    'GRU' : _convert_gru,
     # 'Bidirectional' : _convert_bidirectional,
     # 'TimeDistributed' : _default_skip,
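
For reference, a minimal NumPy sketch of the recurrences the new _convert_simple_rnn and _convert_gru converters reproduce. This is not part of the commit; it assumes the Keras 2 weight layout used above (kernel of shape (input_dim, units) or (input_dim, 3*units), recurrent kernel (units, units) or (units, 3*units), gate order z, r, h, bias applied only to the input projection), and the helper names are illustrative.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def simple_rnn_step(x, h_tm1, kernel, recurrent_kernel, bias):
    # h_t = activation(x . W + h_tm1 . U + b); the default activation is tanh
    return np.tanh(x @ kernel + h_tm1 @ recurrent_kernel + bias)

def gru_step(x, h_tm1, kernel, recurrent_kernel, bias, units):
    # input projected by all gate matrices at once; bias only on this projection
    matrix_x = x @ kernel + bias
    x_z, x_r, x_h = np.split(matrix_x, 3, axis=-1)
    # hidden state projected by the update/reset columns only
    matrix_inner = h_tm1 @ recurrent_kernel[:, :2 * units]
    recurrent_z, recurrent_r = np.split(matrix_inner, 2, axis=-1)
    z = sigmoid(x_z + recurrent_z)   # update gate
    r = sigmoid(x_r + recurrent_r)   # reset gate
    # candidate state uses the reset-scaled previous state
    hh = np.tanh(x_h + (r * h_tm1) @ recurrent_kernel[:, 2 * units:])
    # previous and candidate state mixed by the update gate
    return z * h_tm1 + (1 - z) * hh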
@@ -254,6 +254,58 @@ def test_forward_LSTM():
     _test_LSTM(4, 4, return_state=False)
     _test_LSTM_MultiLayer(4, 4)
 
+def _test_RNN(inputs, units):
+    data = keras.layers.Input(shape=(1, inputs))
+    rnn_out = keras.layers.SimpleRNN(units, return_state=True,
+                                     activation='tanh')
+    x = rnn_out(data)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model, need_transpose=False)
+
+def _test_RNN_MultiLayer(inputs, units):
+    inputs = keras.layers.Input(shape=(1, inputs))
+    layer = keras.layers.SimpleRNN(units, return_state=True, return_sequences=True,
+                                   activation='tanh')
+    outputs = layer(inputs)
+    output, state = outputs[0], outputs[1:]
+    output = keras.layers.SimpleRNN(units, activation='tanh')(output, initial_state=state)
+    keras_model = keras.models.Model(inputs, output)
+    verify_keras_frontend(keras_model, need_transpose=False)
+
+def test_forward_RNN():
+    _test_RNN(2, 4)
+    _test_RNN(4, 3)
+    _test_RNN_MultiLayer(4, 12)
+
+def _test_GRU(inputs, units):
+    data = keras.layers.Input(shape=(1, inputs))
+    gru_out = keras.layers.GRU(units,
+                               return_state=True,
+                               recurrent_activation='sigmoid',
+                               activation='tanh')
+    x = gru_out(data)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model, need_transpose=False)
+
+def _test_GRU_MultiLayer(inputs, units):
+    inputs = keras.layers.Input(shape=(1, inputs))
+    layer = keras.layers.GRU(units,
+                             return_state=True,
+                             return_sequences=True,
+                             recurrent_activation='sigmoid',
+                             activation='tanh')
+    outputs = layer(inputs)
+    output, state = outputs[0], outputs[1:]
+    output = keras.layers.GRU(units, recurrent_activation='sigmoid',
+                              activation='tanh')(output, initial_state=state)
+    keras_model = keras.models.Model(inputs, output)
+    verify_keras_frontend(keras_model, need_transpose=False)
+
+def test_forward_GRU():
+    _test_GRU(2, 4)
+    _test_GRU(4, 3)
+    _test_GRU_MultiLayer(4, 4)
 
 if __name__ == '__main__':
     test_forward_elemwise_add()
     test_forward_activations()
@@ -272,3 +324,5 @@ if __name__ == '__main__':
     test_forward_multi_outputs()
     test_forward_reuse_layers()
     test_forward_LSTM()
+    test_forward_RNN()
+    test_forward_GRU()