Commit dfe4c466 by Nick Hynes Committed by MORITA Kazutaka

[Relay] Allow converting keras.layers.Sequential (#2842)

* Allow converting keras.layers.Sequential

* Use existing new_var function

* Only update expr when missing

* Add test
parent 9ace1cb6
......@@ -255,7 +255,8 @@ class ExprTable(object):
def set_expr(self, name, expr):
assert isinstance(expr, _expr.Expr)
self.exprs[name] = expr
if name not in self.exprs:
self.exprs[name] = expr
def set_padding(self, paddings):
self.paddings = paddings
......
......@@ -7,7 +7,7 @@ from .. import ir_pass
from .. import expr as _expr
from .. import op as _op
from ... import nd as _nd
from .common import ExprTable
from .common import ExprTable, new_var
__all__ = ['from_keras']
......@@ -661,12 +661,15 @@ def from_keras(model, shape=None):
raise ValueError("Keras frontend currently supports data_format = channels_last only.")
_check_unsupported_layers(model)
def _convert_input_layer(keras_layer):
input_name = keras_layer.name
input_shape = shape[input_name] if shape is not None and input_name in shape else None
etab.set_expr(input_name, new_var(input_name, shape=input_shape))
etab = ExprTable()
for keras_layer in model.layers:
if isinstance(keras_layer, keras.engine.InputLayer):
input_name = keras_layer.name
input_shape = shape[input_name] if shape is not None and input_name in shape else None
etab.set_expr(input_name, _expr.var(input_name, shape=input_shape))
_convert_input_layer(keras_layer)
else:
inbound_nodes = keras_layer.inbound_nodes if hasattr(keras_layer, 'inbound_nodes') \
else keras_layer._inbound_nodes if hasattr(keras_layer, '_inbound_nodes') \
......@@ -690,6 +693,7 @@ def from_keras(model, shape=None):
for n_idx, t_idx, inbound_layer in zip_node:
if isinstance(inbound_layer, keras.engine.InputLayer):
expr_name = inbound_layer.name
_convert_input_layer(inbound_layer)
else:
expr_name = inbound_layer.name + ':' + str(n_idx) + ':' + str(t_idx)
expr = etab.get_expr(expr_name)
......
......@@ -106,6 +106,17 @@ def test_forward_dense():
verify_keras_frontend(keras_model)
def test_forward_sequential():
keras_model = keras.models.Sequential([
keras.layers.Dense(16, input_dim=32, activation='relu'),
keras.layers.Dropout(0.5),
keras.layers.Dense(8, activation='relu'),
keras.layers.Dropout(0.5),
keras.layers.Dense(1, activation='sigmoid')
])
verify_keras_frontend(keras_model)
def test_forward_pool():
data = keras.layers.Input(shape=(32,32,1))
# maxpool
......@@ -244,6 +255,7 @@ if __name__ == '__main__':
test_forward_merge()
test_forward_activations()
test_forward_dense()
test_forward_sequential()
test_forward_pool()
test_forward_conv()
test_forward_upsample(interpolation='nearest')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment