Commit dc30880b by Leandro Nunes Committed by Tianqi Chen

[FRONTEND][Keras] Add support for tf.Keras networks in Relay Keras frontend (#4630)

* Make Relay Keras frontend support networks created using
   Tensorflow (1.13) Keras implementation (tf.Keras)
 * Modify Keras frontend tests to run from a class rather than a
   function based script
 * Adjust Keras frontend tests to run with both 'Keras' and 'tf.Keras'
 * Change "TestKeras.test_forward_merge" to validate instances by
   class name rather than instance type
parent 3b67e8a8
...@@ -660,6 +660,9 @@ _convert_map = { ...@@ -660,6 +660,9 @@ _convert_map = {
'Concatenate' : _convert_concat, 'Concatenate' : _convert_concat,
'BatchNormalization' : _convert_batchnorm, 'BatchNormalization' : _convert_batchnorm,
# Specific tf.Keras terminology for batch normalization
'BatchNormalizationV1' : _convert_batchnorm,
'Add' : _convert_merge, 'Add' : _convert_merge,
'Subtract' : _convert_merge, 'Subtract' : _convert_merge,
'Multiply' : _convert_merge, 'Multiply' : _convert_merge,
...@@ -742,7 +745,7 @@ def from_keras(model, shape=None): ...@@ -742,7 +745,7 @@ def from_keras(model, shape=None):
Parameters Parameters
---------- ----------
model : keras.engine.training.Model model : keras.engine.training.Model or tensorflow.keras.models.Model
The keras model to be converted. The keras model to be converted.
shape: dict of str to int list/tuple shape: dict of str to int list/tuple
...@@ -756,25 +759,42 @@ def from_keras(model, shape=None): ...@@ -756,25 +759,42 @@ def from_keras(model, shape=None):
params : dict of str to tvm.NDArray params : dict of str to tvm.NDArray
The parameter dict to be used by Relay. The parameter dict to be used by Relay.
""" """
def _check_model_is_tf_keras():
return type(model).__module__.startswith("tensorflow.python.keras")
def _convert_input_layer(keras_layer):
input_name = keras_layer.name
input_shape = shape[input_name] if shape is not None and input_name in shape else None
etab.set_expr(input_name, new_var(input_name, shape=input_shape))
is_tf_keras = _check_model_is_tf_keras()
if not is_tf_keras:
# Importing from Keras
try: try:
import keras import keras
except ImportError: except ImportError:
raise ImportError('Keras must be installed') raise ImportError("Keras must be installed")
assert isinstance(model, keras.engine.training.Model)
if keras.backend.backend() != 'tensorflow': if keras.backend.backend() != 'tensorflow':
raise ValueError("Keras frontend currently supports tensorflow backend only.") raise ValueError("Keras frontend currently supports tensorflow backend only.")
if keras.backend.image_data_format() != 'channels_last': if keras.backend.image_data_format() != 'channels_last':
raise ValueError("Keras frontend currently supports data_format = channels_last only.") raise ValueError("Keras frontend currently supports data_format = channels_last only.")
_check_unsupported_layers(model) expected_model_class = keras.engine.training.Model
input_layer_class = keras.engine.InputLayer
else:
# Importing from Tensorflow Keras (tf.keras)
try:
from tensorflow import keras as tf_keras
except ImportError:
raise ImportError("Tensorflow must be installed")
expected_model_class = tf_keras.models.Model
input_layer_class = tf_keras.layers.InputLayer
def _convert_input_layer(keras_layer): assert isinstance(model, expected_model_class)
input_name = keras_layer.name
input_shape = shape[input_name] if shape is not None and input_name in shape else None
etab.set_expr(input_name, new_var(input_name, shape=input_shape))
etab = ExprTable() etab = ExprTable()
for keras_layer in model.layers: for keras_layer in model.layers:
if isinstance(keras_layer, keras.engine.InputLayer): if isinstance(keras_layer, input_layer_class):
_convert_input_layer(keras_layer) _convert_input_layer(keras_layer)
else: else:
inbound_nodes = keras_layer.inbound_nodes if hasattr(keras_layer, 'inbound_nodes') \ inbound_nodes = keras_layer.inbound_nodes if hasattr(keras_layer, 'inbound_nodes') \
...@@ -784,10 +804,13 @@ def from_keras(model, shape=None): ...@@ -784,10 +804,13 @@ def from_keras(model, shape=None):
raise TypeError("Unknown layer type or unsupported Keras version : {}" raise TypeError("Unknown layer type or unsupported Keras version : {}"
.format(keras_layer)) .format(keras_layer))
for node_idx, node in enumerate(inbound_nodes): for node_idx, node in enumerate(inbound_nodes):
# If some nodes in imported model is not relevant to the current model, # If some nodes in imported model are not relevant to the current model,
# skip such layers. model._network_nodes contains keys of all nodes relevant # skip such layers.
# to the current model. # - In Keras, model._network_nodes contains keys of all nodes relevant to the
if not model._node_key(keras_layer, node_idx) in model._network_nodes: # current model;
# - In tf.Keras, this is already done as part of tensorflow.keras.network.get_config
if not is_tf_keras and \
not model._node_key(keras_layer, node_idx) in model._network_nodes:
continue continue
inexpr = [] inexpr = []
# Since Keras allows creating multiple layers from the same name instance, # Since Keras allows creating multiple layers from the same name instance,
...@@ -797,7 +820,7 @@ def from_keras(model, shape=None): ...@@ -797,7 +820,7 @@ def from_keras(model, shape=None):
# they are named uniquely to input_1, input_2, input_3... by default. # they are named uniquely to input_1, input_2, input_3... by default.
zip_node = zip(node.node_indices, node.tensor_indices, node.inbound_layers) zip_node = zip(node.node_indices, node.tensor_indices, node.inbound_layers)
for n_idx, t_idx, inbound_layer in zip_node: for n_idx, t_idx, inbound_layer in zip_node:
if isinstance(inbound_layer, keras.engine.InputLayer): if isinstance(inbound_layer, input_layer_class):
expr_name = inbound_layer.name expr_name = inbound_layer.name
_convert_input_layer(inbound_layer) _convert_input_layer(inbound_layer)
else: else:
......
...@@ -23,12 +23,35 @@ import keras ...@@ -23,12 +23,35 @@ import keras
# prevent Keras from using up all gpu memory # prevent Keras from using up all gpu memory
import tensorflow as tf import tensorflow as tf
from tensorflow import keras as tf_keras
from keras.backend.tensorflow_backend import set_session from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto() config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.5 config.gpu_options.per_process_gpu_memory_fraction = 0.5
set_session(tf.Session(config=config)) set_session(tf.Session(config=config))
def pytest_generate_tests(metafunc):
# This function generates the list of tests for pytest, based
# on scenatios that will change the parameters in which the
# tests use to run.
# https://docs.pytest.org/en/latest/example/parametrize.html
idlist = []
argvalues = []
for scenario in metafunc.cls.scenarios:
idlist.append(scenario[0])
items = scenario[1].items()
argnames = [x[0] for x in items]
argvalues.append([x[1] for x in items])
metafunc.parametrize(argnames, argvalues, ids=idlist, scope="class")
# Scenarios:
# - classic keras, using keras from "import keras"
# - tensorflow keras, using keras from "from tensorflow import keras as tf_keras"
using_classic_keras = ("keras", {"keras": keras})
using_tensorflow_keras = ("tf_keras", {"keras": tf_keras})
def verify_keras_frontend(keras_model, need_transpose=True): def verify_keras_frontend(keras_model, need_transpose=True):
# Keras frontend currently supports tensorflow backend only. # Keras frontend currently supports tensorflow backend only.
assert(keras.backend.backend() == 'tensorflow') assert(keras.backend.backend() == 'tensorflow')
...@@ -72,7 +95,10 @@ def verify_keras_frontend(keras_model, need_transpose=True): ...@@ -72,7 +95,10 @@ def verify_keras_frontend(keras_model, need_transpose=True):
tvm.testing.assert_allclose(kout, tout, rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(kout, tout, rtol=1e-5, atol=1e-5)
def test_forward_merge(): class TestKeras:
scenarios = [using_classic_keras, using_tensorflow_keras]
def test_forward_merge(self, keras):
data = keras.layers.Input(shape=(32, 32, 3)) data = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.Conv2D(8, (3, 3), padding="same")(data) x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
y = keras.layers.Conv2D(8, (3, 3), padding="same")(x) y = keras.layers.Conv2D(8, (3, 3), padding="same")(x)
...@@ -84,14 +110,15 @@ def test_forward_merge(): ...@@ -84,14 +110,15 @@ def test_forward_merge():
keras.layers.Average(), keras.layers.Average(),
keras.layers.Concatenate()] keras.layers.Concatenate()]
for merge_func in merge_funcs: for merge_func in merge_funcs:
if isinstance(merge_func, (keras.layers.merge.Subtract, keras.layers.merge.Dot)): class_name = type(merge_func).__name__
if class_name in ('Subtract', 'Dot'):
out = merge_func([x, y]) out = merge_func([x, y])
else: else:
out = merge_func([x, y, z]) out = merge_func([x, y, z])
keras_model = keras.models.Model(data, out) keras_model = keras.models.Model(data, out)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_merge_dot(): def test_forward_merge_dot(self, keras):
data1 = keras.layers.Input(shape=(2, 2)) data1 = keras.layers.Input(shape=(2, 2))
data2 = keras.layers.Input(shape=(2, 2)) data2 = keras.layers.Input(shape=(2, 2))
merge_funcs = [keras.layers.Dot(axes=[1, 2]), merge_funcs = [keras.layers.Dot(axes=[1, 2]),
...@@ -105,7 +132,7 @@ def test_forward_merge_dot(): ...@@ -105,7 +132,7 @@ def test_forward_merge_dot():
keras_model = keras.models.Model([data1, data2], out) keras_model = keras.models.Model([data1, data2], out)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_activations(): def test_forward_activations(self, keras):
data = keras.layers.Input(shape=(32, 32, 3)) data = keras.layers.Input(shape=(32, 32, 3))
act_funcs = [keras.layers.Activation('softmax'), act_funcs = [keras.layers.Activation('softmax'),
keras.layers.Softmax(), keras.layers.Softmax(),
...@@ -138,7 +165,7 @@ def test_forward_activations(): ...@@ -138,7 +165,7 @@ def test_forward_activations():
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_dense(): def test_forward_dense(self, keras):
data = keras.layers.Input(shape=(32, 32, 1)) data = keras.layers.Input(shape=(32, 32, 1))
x = keras.layers.Flatten()(data) x = keras.layers.Flatten()(data)
x = keras.layers.Dropout(0.5)(x) x = keras.layers.Dropout(0.5)(x)
...@@ -146,13 +173,13 @@ def test_forward_dense(): ...@@ -146,13 +173,13 @@ def test_forward_dense():
keras_model = keras.models.Model(data, x) keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_permute(): def test_forward_permute(self, keras):
data = keras.layers.Input(shape=(2, 3, 4)) data = keras.layers.Input(shape=(2, 3, 4))
x = keras.layers.Permute([2, 3, 1])(data) x = keras.layers.Permute([2, 3, 1])(data)
keras_model = keras.models.Model(data, x) keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, need_transpose=False) verify_keras_frontend(keras_model, need_transpose=False)
def test_forward_sequential(): def test_forward_sequential(self, keras):
keras_model = keras.models.Sequential([ keras_model = keras.models.Sequential([
keras.layers.Dense(16, input_dim=32, activation='relu'), keras.layers.Dense(16, input_dim=32, activation='relu'),
keras.layers.Dropout(0.5), keras.layers.Dropout(0.5),
...@@ -163,7 +190,7 @@ def test_forward_sequential(): ...@@ -163,7 +190,7 @@ def test_forward_sequential():
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_pool(): def test_forward_pool(self, keras):
data = keras.layers.Input(shape=(32, 32, 1)) data = keras.layers.Input(shape=(32, 32, 1))
# maxpool # maxpool
x = keras.layers.MaxPooling2D((3, 3), strides=(1, 1), padding='same')(data) x = keras.layers.MaxPooling2D((3, 3), strides=(1, 1), padding='same')(data)
...@@ -175,7 +202,7 @@ def test_forward_pool(): ...@@ -175,7 +202,7 @@ def test_forward_pool():
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_conv(): def test_forward_conv(self, keras):
data = keras.layers.Input(shape=(32, 32, 3)) data = keras.layers.Input(shape=(32, 32, 3))
conv_funcs = [keras.layers.Conv2D(filters=10, kernel_size=(3, 3), conv_funcs = [keras.layers.Conv2D(filters=10, kernel_size=(3, 3),
strides=(2, 2), padding='same'), strides=(2, 2), padding='same'),
...@@ -190,7 +217,7 @@ def test_forward_conv(): ...@@ -190,7 +217,7 @@ def test_forward_conv():
keras_model = keras.models.Model(data, x) keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_batch_norm(): def test_forward_batch_norm(self, keras):
data = keras.layers.Input(shape=(32, 32, 3)) data = keras.layers.Input(shape=(32, 32, 3))
batch_norm_funcs = [keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001, batch_norm_funcs = [keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
center=True, scale=False, center=True, scale=False,
...@@ -221,14 +248,14 @@ def test_forward_batch_norm(): ...@@ -221,14 +248,14 @@ def test_forward_batch_norm():
keras_model = keras.models.Model(data, x) keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_upsample(interpolation='nearest'): def test_forward_upsample(self, keras, interpolation='nearest'):
data = keras.layers.Input(shape=(32, 32, 3)) data = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.UpSampling2D(size=(3, 3), interpolation=interpolation)(data) x = keras.layers.UpSampling2D(size=(3, 3), interpolation=interpolation)(data)
keras_model = keras.models.Model(data, x) keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_reshape(): def test_forward_reshape(self, keras):
# input_shape len is 3, target_shape len is 3 # input_shape len is 3, target_shape len is 3
data = keras.layers.Input(shape=(32, 32, 3)) data = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.Reshape(target_shape=(16, 64, 3))(data) x = keras.layers.Reshape(target_shape=(16, 64, 3))(data)
...@@ -261,7 +288,7 @@ def test_forward_reshape(): ...@@ -261,7 +288,7 @@ def test_forward_reshape():
verify_keras_frontend(keras_model, need_transpose=False) verify_keras_frontend(keras_model, need_transpose=False)
def test_forward_crop(): def test_forward_crop(self, keras):
data = keras.layers.Input(shape=(32, 32, 3)) data = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.Cropping2D(cropping=((1, 1), (1, 1)))(data) x = keras.layers.Cropping2D(cropping=((1, 1), (1, 1)))(data)
x = keras.layers.Cropping2D(cropping=(1, 1))(x) x = keras.layers.Cropping2D(cropping=(1, 1))(x)
...@@ -274,7 +301,7 @@ def test_forward_crop(): ...@@ -274,7 +301,7 @@ def test_forward_crop():
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_multi_inputs(): def test_forward_multi_inputs(self, keras):
data1 = keras.layers.Input(shape=(32, 32, 3)) data1 = keras.layers.Input(shape=(32, 32, 3))
data2 = keras.layers.Input(shape=(32, 32, 3)) data2 = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.Conv2D(8, (3, 3), padding="same")(data1) x = keras.layers.Conv2D(8, (3, 3), padding="same")(data1)
...@@ -285,7 +312,7 @@ def test_forward_multi_inputs(): ...@@ -285,7 +312,7 @@ def test_forward_multi_inputs():
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_multi_outputs(): def test_forward_multi_outputs(self, keras):
data = keras.layers.Input(shape=(32, 32, 3)) data = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.Conv2D(8, (3, 3), padding="same")(data) x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
x = keras.layers.GlobalAveragePooling2D()(x) x = keras.layers.GlobalAveragePooling2D()(x)
...@@ -295,7 +322,7 @@ def test_forward_multi_outputs(): ...@@ -295,7 +322,7 @@ def test_forward_multi_outputs():
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_reuse_layers(): def test_forward_reuse_layers(self, keras):
# reuse conv2d # reuse conv2d
data = keras.layers.Input(shape=(32, 32, 3)) data = keras.layers.Input(shape=(32, 32, 3))
conv2d = keras.layers.Conv2D(8, (3, 3), padding="same") conv2d = keras.layers.Conv2D(8, (3, 3), padding="same")
...@@ -316,7 +343,7 @@ def test_forward_reuse_layers(): ...@@ -316,7 +343,7 @@ def test_forward_reuse_layers():
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_rnn(): def test_forward_rnn(self,keras):
data = keras.layers.Input(shape=(1, 32)) data = keras.layers.Input(shape=(1, 32))
rnn_funcs = [keras.layers.LSTM(units=16, return_state=False, rnn_funcs = [keras.layers.LSTM(units=16, return_state=False,
recurrent_activation='sigmoid', activation='tanh'), recurrent_activation='sigmoid', activation='tanh'),
...@@ -330,49 +357,51 @@ def test_forward_rnn(): ...@@ -330,49 +357,51 @@ def test_forward_rnn():
verify_keras_frontend(keras_model, need_transpose=False) verify_keras_frontend(keras_model, need_transpose=False)
def test_forward_vgg16(): def test_forward_vgg16(self, keras):
keras_model = keras.applications.VGG16(include_top=True, weights='imagenet', keras_model = keras.applications.VGG16(include_top=True, weights='imagenet',
input_shape=(224, 224, 3), classes=1000) input_shape=(224, 224, 3), classes=1000)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_xception(): def test_forward_xception(self, keras):
keras_model = keras.applications.Xception(include_top=True, weights='imagenet', keras_model = keras.applications.Xception(include_top=True, weights='imagenet',
input_shape=(299, 299, 3), classes=1000) input_shape=(299, 299, 3), classes=1000)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_resnet50(): def test_forward_resnet50(self, keras):
keras_model = keras.applications.ResNet50(include_top=True, weights='imagenet', keras_model = keras.applications.ResNet50(include_top=True, weights='imagenet',
input_shape=(224, 224, 3), classes=1000) input_shape=(224, 224, 3), classes=1000)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_mobilenet(): def test_forward_mobilenet(self, keras):
keras_model = keras.applications.MobileNet(include_top=True, weights='imagenet', keras_model = keras.applications.MobileNet(include_top=True, weights='imagenet',
input_shape=(224, 224, 3), classes=1000) input_shape=(224, 224, 3), classes=1000)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
if __name__ == '__main__': if __name__ == '__main__':
test_forward_merge() for k in [keras, tf_keras]:
test_forward_merge_dot() sut = TestKeras()
test_forward_activations() sut.test_forward_merge_dot(keras=k)
test_forward_dense() sut.test_forward_merge(keras=k)
test_forward_permute() sut.test_forward_activations(keras=k)
test_forward_sequential() sut.test_forward_dense(keras=k)
test_forward_pool() sut.test_forward_permute(keras=k)
test_forward_conv() sut.test_forward_sequential(keras=k)
test_forward_batch_norm() sut.test_forward_pool(keras=k)
test_forward_upsample(interpolation='nearest') sut.test_forward_conv(keras=k)
test_forward_upsample(interpolation='bilinear') sut.test_forward_batch_norm(keras=k)
test_forward_reshape() sut.test_forward_upsample(keras=k, interpolation='nearest')
test_forward_crop() sut.test_forward_upsample(keras=k, interpolation='bilinear')
test_forward_multi_inputs() sut.test_forward_reshape(keras=k)
test_forward_multi_outputs() sut.test_forward_crop(keras=k)
test_forward_reuse_layers() sut.test_forward_multi_inputs(keras=k)
test_forward_rnn() sut.test_forward_multi_outputs(keras=k)
test_forward_vgg16() sut.test_forward_reuse_layers(keras=k)
test_forward_xception() sut.test_forward_rnn(keras=k)
test_forward_resnet50() sut.test_forward_vgg16(keras=k)
test_forward_mobilenet() sut.test_forward_xception(keras=k)
sut.test_forward_resnet50(keras=k)
sut.test_forward_mobilenet(keras=k)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment