Commit f713aa9c by Siju Committed by Tianqi Chen

[NNVM][KERAS]LSTMCell support (#1686)

parent fb570e5a
......@@ -395,6 +395,40 @@ def _convert_reshape(insym, keras_layer, _):
shape = (-1, ch) + keras_layer.target_shape[:-1]
return _sym.reshape(insym, shape=shape)
def _convert_lstm(insym, keras_layer, symtab):
_check_data_format(keras_layer)
if not isinstance(insym, list):
buffer = np.zeros((1, keras_layer.units), 'float32')
c_sym = symtab.new_const(buffer)
h_sym = symtab.new_const(buffer)
insym = [insym, h_sym, c_sym]
in_data = insym[0]
in_state_h = insym[1]
in_state_c = insym[2]
weightList = keras_layer.get_weights()
kernel_wt = symtab.new_const(weightList[0].transpose([1, 0]))
recurrent_wt = symtab.new_const(weightList[1].transpose([1, 0]))
in_bias = symtab.new_const(weightList[2])
units = list(weightList[0].shape)[1]
in_data = _sym.flatten(in_data)
ixh1 = _sym.dense(in_data, kernel_wt, use_bias=False, units=units)
ixh2 = _sym.dense(in_state_h, recurrent_wt, in_bias, use_bias=True, units=units)
gate = ixh1 + ixh2
gates = _sym.split(gate, indices_or_sections=4, axis=1)
in_gate = _sym.sigmoid(gates[0])
in_transform = _sym.sigmoid(gates[1])
next_c = in_transform * in_state_c + in_gate * _sym.tanh(gates[2])
out_gate = _sym.sigmoid(gates[3])
next_h = out_gate * _sym.tanh(next_c)
out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
out = _sym.reshape(next_h, shape=out_shape)
return [out, next_h, next_c]
def _default_skip(insym, keras_layer, _): # pylint: disable=unused-argument
"""Layers that can be skipped because they are train time only."""
......@@ -442,7 +476,7 @@ _convert_map = {
# 'Conv1D' : _convert_convolution1d,
# 'GRU' : _convert_gru,
# 'LSTM' : _convert_lstm,
'LSTM' : _convert_lstm,
# 'SimpleRNN' : _convert_simple_rnn,
# 'Bidirectional' : _convert_bidirectional,
# 'TimeDistributed' : _default_skip,
......@@ -466,6 +500,11 @@ def _check_unsupported_layers(model):
if type(layer).__name__ not in _convert_map:
raise ValueError("Keras layer {} not supported.".format(type(layer).__name__))
def _as_list(arr):
"""Force being a list, ignore if already is."""
if isinstance(arr, list):
return arr
return [arr]
def keras_op_to_nnvm(insym, keras_layer, outname, symtab):
"""Convert keras layer to nnvm symbol, and update symtab.
......@@ -486,9 +525,12 @@ def keras_op_to_nnvm(insym, keras_layer, outname, symtab):
"""
if type(keras_layer).__name__ not in _convert_map:
raise NotImplementedError("{} is not supported".format((type(keras_layer).__name__)))
ret = _convert_map[type(keras_layer).__name__](insym, keras_layer, symtab)
symtab.set_var(outname, ret)
outs = _convert_map[type(keras_layer).__name__](insym, keras_layer, symtab)
outs = _as_list(outs)
for t_idx, out in enumerate(outs):
name = outname + ":" + str(t_idx)
symtab.set_var(name, out)
def from_keras(model):
"""Convert keras model to NNVM format.
......@@ -529,7 +571,7 @@ def from_keras(model):
if inbound_nodes is None:
raise TypeError("Unknown layer type or unsupported Keras version : {}"
.format(keras_layer))
for my_idx, node in enumerate(inbound_nodes):
for node_idx, node in enumerate(inbound_nodes):
insym = []
# Since Keras allows creating multiple layers from the same name instance,
......@@ -537,17 +579,25 @@ def from_keras(model):
# The one exception is InputLayer. Changing input variable names after conversion
# would confuse users, so we should keep them as far as possible. Fortunately,
# they are named uniquely to input_1, input_2, input_3 ... by default.
for pred_idx, pred in zip(node.node_indices, node.inbound_layers):
if isinstance(pred, keras.engine.InputLayer):
sym = symtab.get_var(pred.name, must_contain=True)
zip_node = zip(node.node_indices, node.tensor_indices, node.inbound_layers)
for n_idx, t_idx, layer in zip_node:
if isinstance(layer, keras.engine.InputLayer):
sym = symtab.get_var(layer.name, must_contain=True)
else:
sym = symtab.get_var(pred.name + ':' + str(pred_idx), must_contain=True)
sym_name = layer.name + ':' + str(n_idx) + ':' + str(t_idx)
sym = symtab.get_var(sym_name, must_contain=True)
insym.append(sym)
if len(insym) == 1:
insym = insym[0]
keras_op_to_nnvm(insym, keras_layer, keras_layer.name + ':' + str(my_idx), symtab)
keras_op_to_nnvm(insym, keras_layer, keras_layer.name + ':' + str(node_idx), symtab)
#model._output_coordinates contains out_node(oc[0]), node_index(oc[1]) and tensor index(oc[2])
#Get all output nodes in symtab using the name made from above values. The out symbols
#were added to symtab in keras_op_to_nnvm using this name. For multiple outputs, make a list
#with these output symbols and Group them.
outsym = [symtab.get_var(oc[0].name + ":" + str(oc[1]) + ":" + str(oc[2]))
for oc in model._output_coordinates]
outsym = [symtab.get_var(layer.name + ':0') for layer in model._output_layers]
tvmparams = {k:tvm.nd.array(np.array(v, dtype=np.float32)) for k, v in symtab.params.items()}
return _sym.Group(outsym), tvmparams
......@@ -13,16 +13,22 @@ config.gpu_options.per_process_gpu_memory_fraction = 0.5
set_session(tf.Session(config=config))
def verify_keras_frontend(keras_model):
def verify_keras_frontend(keras_model, need_transpose=True):
# Keras frontend currently supports tensorflow backend only.
assert(keras.backend.backend() == 'tensorflow')
in_shapes = []
for layer in keras_model._input_layers:
in_shapes.append(tuple(dim.value if dim.value is not None else 1 for dim in layer.input.shape))
#keras_model._output_coordinates contains the output_node, node_index and tensor_index
#get the outshapes from combining output node and tensor index
out_shapes = []
for layer in keras_model._output_layers:
out_shapes.append(tuple(dim.value if dim.value is not None else 1 for dim in layer.output.shape))
for layer, node_index, tensor_index in keras_model._output_coordinates:
layer_out = layer.output
if isinstance(layer.output, list):#if multiple outputs are there
layer_out = layer.output[tensor_index]
out_shapes.append(tuple(dim.value if dim.value is not None else 1 for dim in layer_out.shape))
def get_keras_output(xs, dtype='float32'):
return keras_model.predict(xs)
......@@ -46,14 +52,13 @@ def verify_keras_frontend(keras_model):
keras_out = get_keras_output(xs)
for target, ctx in ctx_list():
tvm_out = get_tvm_output([x.transpose([0,3,1,2]) for x in xs], target, ctx)
tvm_out = get_tvm_output([x.transpose([0,3,1,2]) for x in xs ] if need_transpose else xs, target, ctx)
if isinstance (keras_out, list):
for kout, tout in zip(keras_out, tvm_out):
np.testing.assert_allclose(kout, tout.reshape(kout.shape), rtol=1e-5, atol=1e-5)
else:
np.testing.assert_allclose(keras_out, tvm_out.reshape(keras_out.shape), rtol=1e-5, atol=1e-5)
def test_forward_elemwise_add():
r = []
data = keras.layers.Input(shape=(32,32,3))
......@@ -231,6 +236,33 @@ def test_forward_reuse_layers():
keras_model = keras.models.Model(data, z)
verify_keras_frontend(keras_model)
def _test_LSTM(inputs, hidden, return_state=True):
data = keras.layers.Input(shape=(1, inputs))
lstm_out = keras.layers.LSTM(hidden,
return_state=return_state,
recurrent_activation='sigmoid',
activation='tanh')
x = lstm_out(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, need_transpose=False)
def _test_LSTM_MultiLayer(inputs, hidden):
inputs = keras.layers.Input(shape=(1, inputs))
layer = keras.layers.LSTM(hidden, return_state=True, return_sequences=True,
recurrent_activation='sigmoid',
activation='tanh')
outputs = layer(inputs)
output, state = outputs[0], outputs[1:]
output = keras.layers.LSTM(hidden, recurrent_activation='sigmoid',
activation='tanh')(output, initial_state=state)
keras_model = keras.models.Model(inputs, output)
verify_keras_frontend(keras_model, need_transpose=False)
def test_forward_LSTM():
_test_LSTM(8, 8, return_state=True)
_test_LSTM(4, 4, return_state=False)
_test_LSTM_MultiLayer(4, 4)
if __name__ == '__main__':
test_forward_elemwise_add()
......@@ -249,3 +281,4 @@ if __name__ == '__main__':
test_forward_multi_inputs()
test_forward_multi_outputs()
test_forward_reuse_layers()
test_forward_LSTM()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment