Commit dfe4c466 by Nick Hynes Committed by MORITA Kazutaka

[Relay] Allow converting keras.layers.Sequential (#2842)

* Allow converting keras.layers.Sequential

* Use existing new_var function

* Only update expr when missing

* Add test
parent 9ace1cb6
...@@ -255,7 +255,8 @@ class ExprTable(object): ...@@ -255,7 +255,8 @@ class ExprTable(object):
def set_expr(self, name, expr): def set_expr(self, name, expr):
assert isinstance(expr, _expr.Expr) assert isinstance(expr, _expr.Expr)
self.exprs[name] = expr if name not in self.exprs:
self.exprs[name] = expr
def set_padding(self, paddings): def set_padding(self, paddings):
self.paddings = paddings self.paddings = paddings
......
...@@ -7,7 +7,7 @@ from .. import ir_pass ...@@ -7,7 +7,7 @@ from .. import ir_pass
from .. import expr as _expr from .. import expr as _expr
from .. import op as _op from .. import op as _op
from ... import nd as _nd from ... import nd as _nd
from .common import ExprTable from .common import ExprTable, new_var
__all__ = ['from_keras'] __all__ = ['from_keras']
...@@ -661,12 +661,15 @@ def from_keras(model, shape=None): ...@@ -661,12 +661,15 @@ def from_keras(model, shape=None):
raise ValueError("Keras frontend currently supports data_format = channels_last only.") raise ValueError("Keras frontend currently supports data_format = channels_last only.")
_check_unsupported_layers(model) _check_unsupported_layers(model)
def _convert_input_layer(keras_layer):
input_name = keras_layer.name
input_shape = shape[input_name] if shape is not None and input_name in shape else None
etab.set_expr(input_name, new_var(input_name, shape=input_shape))
etab = ExprTable() etab = ExprTable()
for keras_layer in model.layers: for keras_layer in model.layers:
if isinstance(keras_layer, keras.engine.InputLayer): if isinstance(keras_layer, keras.engine.InputLayer):
input_name = keras_layer.name _convert_input_layer(keras_layer)
input_shape = shape[input_name] if shape is not None and input_name in shape else None
etab.set_expr(input_name, _expr.var(input_name, shape=input_shape))
else: else:
inbound_nodes = keras_layer.inbound_nodes if hasattr(keras_layer, 'inbound_nodes') \ inbound_nodes = keras_layer.inbound_nodes if hasattr(keras_layer, 'inbound_nodes') \
else keras_layer._inbound_nodes if hasattr(keras_layer, '_inbound_nodes') \ else keras_layer._inbound_nodes if hasattr(keras_layer, '_inbound_nodes') \
...@@ -690,6 +693,7 @@ def from_keras(model, shape=None): ...@@ -690,6 +693,7 @@ def from_keras(model, shape=None):
for n_idx, t_idx, inbound_layer in zip_node: for n_idx, t_idx, inbound_layer in zip_node:
if isinstance(inbound_layer, keras.engine.InputLayer): if isinstance(inbound_layer, keras.engine.InputLayer):
expr_name = inbound_layer.name expr_name = inbound_layer.name
_convert_input_layer(inbound_layer)
else: else:
expr_name = inbound_layer.name + ':' + str(n_idx) + ':' + str(t_idx) expr_name = inbound_layer.name + ':' + str(n_idx) + ':' + str(t_idx)
expr = etab.get_expr(expr_name) expr = etab.get_expr(expr_name)
......
...@@ -106,6 +106,17 @@ def test_forward_dense(): ...@@ -106,6 +106,17 @@ def test_forward_dense():
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_sequential():
keras_model = keras.models.Sequential([
keras.layers.Dense(16, input_dim=32, activation='relu'),
keras.layers.Dropout(0.5),
keras.layers.Dense(8, activation='relu'),
keras.layers.Dropout(0.5),
keras.layers.Dense(1, activation='sigmoid')
])
verify_keras_frontend(keras_model)
def test_forward_pool(): def test_forward_pool():
data = keras.layers.Input(shape=(32,32,1)) data = keras.layers.Input(shape=(32,32,1))
# maxpool # maxpool
...@@ -244,6 +255,7 @@ if __name__ == '__main__': ...@@ -244,6 +255,7 @@ if __name__ == '__main__':
test_forward_merge() test_forward_merge()
test_forward_activations() test_forward_activations()
test_forward_dense() test_forward_dense()
test_forward_sequential()
test_forward_pool() test_forward_pool()
test_forward_conv() test_forward_conv()
test_forward_upsample(interpolation='nearest') test_forward_upsample(interpolation='nearest')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment