Commit 05f4362e by MORITA Kazutaka Committed by Tianqi Chen

[NNVM][Keras] allow only tensorflow backend (#1392)

parent 0c506f9a
...@@ -489,6 +489,8 @@ def from_keras(model): ...@@ -489,6 +489,8 @@ def from_keras(model):
raise ImportError('Keras must be installed') raise ImportError('Keras must be installed')
assert isinstance(model, keras.engine.training.Model) assert isinstance(model, keras.engine.training.Model)
if keras.backend.backend() != 'tensorflow':
raise ValueError("Keras frontend currently supports tensorflow backend only.")
if keras.backend.image_data_format() != 'channels_last': if keras.backend.image_data_format() != 'channels_last':
raise ValueError("Keras frontend currently supports data_format = channels_last only.") raise ValueError("Keras frontend currently supports data_format = channels_last only.")
_check_unsupported_layers(model) _check_unsupported_layers(model)
......
...@@ -14,6 +14,9 @@ set_session(tf.Session(config=config)) ...@@ -14,6 +14,9 @@ set_session(tf.Session(config=config))
def verify_keras_frontend(keras_model): def verify_keras_frontend(keras_model):
# Keras frontend currently supports tensorflow backend only.
assert(keras.backend.backend() == 'tensorflow')
in_shapes = [] in_shapes = []
for layer in keras_model._input_layers: for layer in keras_model._input_layers:
in_shapes.append(tuple(dim.value if dim.value is not None else 1 for dim in layer.input.shape)) in_shapes.append(tuple(dim.value if dim.value is not None else 1 for dim in layer.input.shape))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment