Commit dc30880b by Leandro Nunes Committed by Tianqi Chen

[FRONTEND][Keras] Add support for tf.Keras networks in Relay Keras frontend (#4630)

* Make Relay Keras frontend support networks created using
   Tensorflow (1.13) Keras implementation (tf.Keras)
 * Modify Keras frontend tests to run from a class rather than a
   function based script
 * Adjust Keras frontend tests to run with both 'Keras' and 'tf.Keras'
 * Change "TestKeras.test_forward_merge" to validate instances by
   class name rather than instance type
parent 3b67e8a8
......@@ -660,6 +660,9 @@ _convert_map = {
'Concatenate' : _convert_concat,
'BatchNormalization' : _convert_batchnorm,
# Specific tf.Keras terminology for batch normalization
'BatchNormalizationV1' : _convert_batchnorm,
'Add' : _convert_merge,
'Subtract' : _convert_merge,
'Multiply' : _convert_merge,
......@@ -742,7 +745,7 @@ def from_keras(model, shape=None):
Parameters
----------
model : keras.engine.training.Model
model : keras.engine.training.Model or tensorflow.keras.models.Model
The keras model to be converted.
shape: dict of str to int list/tuple
......@@ -756,25 +759,42 @@ def from_keras(model, shape=None):
params : dict of str to tvm.NDArray
The parameter dict to be used by Relay.
"""
def _check_model_is_tf_keras():
return type(model).__module__.startswith("tensorflow.python.keras")
def _convert_input_layer(keras_layer):
input_name = keras_layer.name
input_shape = shape[input_name] if shape is not None and input_name in shape else None
etab.set_expr(input_name, new_var(input_name, shape=input_shape))
is_tf_keras = _check_model_is_tf_keras()
if not is_tf_keras:
# Importing from Keras
try:
import keras
except ImportError:
raise ImportError('Keras must be installed')
assert isinstance(model, keras.engine.training.Model)
raise ImportError("Keras must be installed")
if keras.backend.backend() != 'tensorflow':
raise ValueError("Keras frontend currently supports tensorflow backend only.")
if keras.backend.image_data_format() != 'channels_last':
raise ValueError("Keras frontend currently supports data_format = channels_last only.")
_check_unsupported_layers(model)
expected_model_class = keras.engine.training.Model
input_layer_class = keras.engine.InputLayer
else:
# Importing from Tensorflow Keras (tf.keras)
try:
from tensorflow import keras as tf_keras
except ImportError:
raise ImportError("Tensorflow must be installed")
expected_model_class = tf_keras.models.Model
input_layer_class = tf_keras.layers.InputLayer
def _convert_input_layer(keras_layer):
input_name = keras_layer.name
input_shape = shape[input_name] if shape is not None and input_name in shape else None
etab.set_expr(input_name, new_var(input_name, shape=input_shape))
assert isinstance(model, expected_model_class)
etab = ExprTable()
for keras_layer in model.layers:
if isinstance(keras_layer, keras.engine.InputLayer):
if isinstance(keras_layer, input_layer_class):
_convert_input_layer(keras_layer)
else:
inbound_nodes = keras_layer.inbound_nodes if hasattr(keras_layer, 'inbound_nodes') \
......@@ -784,10 +804,13 @@ def from_keras(model, shape=None):
raise TypeError("Unknown layer type or unsupported Keras version : {}"
.format(keras_layer))
for node_idx, node in enumerate(inbound_nodes):
# If some nodes in imported model is not relevant to the current model,
# skip such layers. model._network_nodes contains keys of all nodes relevant
# to the current model.
if not model._node_key(keras_layer, node_idx) in model._network_nodes:
# If some nodes in imported model are not relevant to the current model,
# skip such layers.
# - In Keras, model._network_nodes contains keys of all nodes relevant to the
# current model;
# - In tf.Keras, this is already done as part of tensorflow.keras.network.get_config
if not is_tf_keras and \
not model._node_key(keras_layer, node_idx) in model._network_nodes:
continue
inexpr = []
# Since Keras allows creating multiple layers from the same name instance,
......@@ -797,7 +820,7 @@ def from_keras(model, shape=None):
# they are named uniquely to input_1, input_2, input_3... by default.
zip_node = zip(node.node_indices, node.tensor_indices, node.inbound_layers)
for n_idx, t_idx, inbound_layer in zip_node:
if isinstance(inbound_layer, keras.engine.InputLayer):
if isinstance(inbound_layer, input_layer_class):
expr_name = inbound_layer.name
_convert_input_layer(inbound_layer)
else:
......
......@@ -23,12 +23,35 @@ import keras
# prevent Keras from using up all gpu memory
import tensorflow as tf
from tensorflow import keras as tf_keras
from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.5
set_session(tf.Session(config=config))
def pytest_generate_tests(metafunc):
# This function generates the list of tests for pytest, based
# on scenatios that will change the parameters in which the
# tests use to run.
# https://docs.pytest.org/en/latest/example/parametrize.html
idlist = []
argvalues = []
for scenario in metafunc.cls.scenarios:
idlist.append(scenario[0])
items = scenario[1].items()
argnames = [x[0] for x in items]
argvalues.append([x[1] for x in items])
metafunc.parametrize(argnames, argvalues, ids=idlist, scope="class")
# Scenarios:
# - classic keras, using keras from "import keras"
# - tensorflow keras, using keras from "from tensorflow import keras as tf_keras"
using_classic_keras = ("keras", {"keras": keras})
using_tensorflow_keras = ("tf_keras", {"keras": tf_keras})
def verify_keras_frontend(keras_model, need_transpose=True):
# Keras frontend currently supports tensorflow backend only.
assert(keras.backend.backend() == 'tensorflow')
......@@ -72,7 +95,10 @@ def verify_keras_frontend(keras_model, need_transpose=True):
tvm.testing.assert_allclose(kout, tout, rtol=1e-5, atol=1e-5)
def test_forward_merge():
class TestKeras:
scenarios = [using_classic_keras, using_tensorflow_keras]
def test_forward_merge(self, keras):
data = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
y = keras.layers.Conv2D(8, (3, 3), padding="same")(x)
......@@ -84,14 +110,15 @@ def test_forward_merge():
keras.layers.Average(),
keras.layers.Concatenate()]
for merge_func in merge_funcs:
if isinstance(merge_func, (keras.layers.merge.Subtract, keras.layers.merge.Dot)):
class_name = type(merge_func).__name__
if class_name in ('Subtract', 'Dot'):
out = merge_func([x, y])
else:
out = merge_func([x, y, z])
keras_model = keras.models.Model(data, out)
verify_keras_frontend(keras_model)
def test_forward_merge_dot():
def test_forward_merge_dot(self, keras):
data1 = keras.layers.Input(shape=(2, 2))
data2 = keras.layers.Input(shape=(2, 2))
merge_funcs = [keras.layers.Dot(axes=[1, 2]),
......@@ -105,7 +132,7 @@ def test_forward_merge_dot():
keras_model = keras.models.Model([data1, data2], out)
verify_keras_frontend(keras_model)
def test_forward_activations():
def test_forward_activations(self, keras):
data = keras.layers.Input(shape=(32, 32, 3))
act_funcs = [keras.layers.Activation('softmax'),
keras.layers.Softmax(),
......@@ -138,7 +165,7 @@ def test_forward_activations():
verify_keras_frontend(keras_model)
def test_forward_dense():
def test_forward_dense(self, keras):
data = keras.layers.Input(shape=(32, 32, 1))
x = keras.layers.Flatten()(data)
x = keras.layers.Dropout(0.5)(x)
......@@ -146,13 +173,13 @@ def test_forward_dense():
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
def test_forward_permute():
def test_forward_permute(self, keras):
data = keras.layers.Input(shape=(2, 3, 4))
x = keras.layers.Permute([2, 3, 1])(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, need_transpose=False)
def test_forward_sequential():
def test_forward_sequential(self, keras):
keras_model = keras.models.Sequential([
keras.layers.Dense(16, input_dim=32, activation='relu'),
keras.layers.Dropout(0.5),
......@@ -163,7 +190,7 @@ def test_forward_sequential():
verify_keras_frontend(keras_model)
def test_forward_pool():
def test_forward_pool(self, keras):
data = keras.layers.Input(shape=(32, 32, 1))
# maxpool
x = keras.layers.MaxPooling2D((3, 3), strides=(1, 1), padding='same')(data)
......@@ -175,7 +202,7 @@ def test_forward_pool():
verify_keras_frontend(keras_model)
def test_forward_conv():
def test_forward_conv(self, keras):
data = keras.layers.Input(shape=(32, 32, 3))
conv_funcs = [keras.layers.Conv2D(filters=10, kernel_size=(3, 3),
strides=(2, 2), padding='same'),
......@@ -190,7 +217,7 @@ def test_forward_conv():
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
def test_forward_batch_norm():
def test_forward_batch_norm(self, keras):
data = keras.layers.Input(shape=(32, 32, 3))
batch_norm_funcs = [keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
center=True, scale=False,
......@@ -221,14 +248,14 @@ def test_forward_batch_norm():
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
def test_forward_upsample(interpolation='nearest'):
def test_forward_upsample(self, keras, interpolation='nearest'):
data = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.UpSampling2D(size=(3, 3), interpolation=interpolation)(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
def test_forward_reshape():
def test_forward_reshape(self, keras):
# input_shape len is 3, target_shape len is 3
data = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.Reshape(target_shape=(16, 64, 3))(data)
......@@ -261,7 +288,7 @@ def test_forward_reshape():
verify_keras_frontend(keras_model, need_transpose=False)
def test_forward_crop():
def test_forward_crop(self, keras):
data = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.Cropping2D(cropping=((1, 1), (1, 1)))(data)
x = keras.layers.Cropping2D(cropping=(1, 1))(x)
......@@ -274,7 +301,7 @@ def test_forward_crop():
verify_keras_frontend(keras_model)
def test_forward_multi_inputs():
def test_forward_multi_inputs(self, keras):
data1 = keras.layers.Input(shape=(32, 32, 3))
data2 = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.Conv2D(8, (3, 3), padding="same")(data1)
......@@ -285,7 +312,7 @@ def test_forward_multi_inputs():
verify_keras_frontend(keras_model)
def test_forward_multi_outputs():
def test_forward_multi_outputs(self, keras):
data = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
x = keras.layers.GlobalAveragePooling2D()(x)
......@@ -295,7 +322,7 @@ def test_forward_multi_outputs():
verify_keras_frontend(keras_model)
def test_forward_reuse_layers():
def test_forward_reuse_layers(self, keras):
# reuse conv2d
data = keras.layers.Input(shape=(32, 32, 3))
conv2d = keras.layers.Conv2D(8, (3, 3), padding="same")
......@@ -316,7 +343,7 @@ def test_forward_reuse_layers():
verify_keras_frontend(keras_model)
def test_forward_rnn():
def test_forward_rnn(self,keras):
data = keras.layers.Input(shape=(1, 32))
rnn_funcs = [keras.layers.LSTM(units=16, return_state=False,
recurrent_activation='sigmoid', activation='tanh'),
......@@ -330,49 +357,51 @@ def test_forward_rnn():
verify_keras_frontend(keras_model, need_transpose=False)
def test_forward_vgg16():
def test_forward_vgg16(self, keras):
keras_model = keras.applications.VGG16(include_top=True, weights='imagenet',
input_shape=(224, 224, 3), classes=1000)
verify_keras_frontend(keras_model)
def test_forward_xception():
def test_forward_xception(self, keras):
keras_model = keras.applications.Xception(include_top=True, weights='imagenet',
input_shape=(299, 299, 3), classes=1000)
verify_keras_frontend(keras_model)
def test_forward_resnet50():
def test_forward_resnet50(self, keras):
keras_model = keras.applications.ResNet50(include_top=True, weights='imagenet',
input_shape=(224, 224, 3), classes=1000)
verify_keras_frontend(keras_model)
def test_forward_mobilenet():
def test_forward_mobilenet(self, keras):
keras_model = keras.applications.MobileNet(include_top=True, weights='imagenet',
input_shape=(224, 224, 3), classes=1000)
verify_keras_frontend(keras_model)
if __name__ == '__main__':
test_forward_merge()
test_forward_merge_dot()
test_forward_activations()
test_forward_dense()
test_forward_permute()
test_forward_sequential()
test_forward_pool()
test_forward_conv()
test_forward_batch_norm()
test_forward_upsample(interpolation='nearest')
test_forward_upsample(interpolation='bilinear')
test_forward_reshape()
test_forward_crop()
test_forward_multi_inputs()
test_forward_multi_outputs()
test_forward_reuse_layers()
test_forward_rnn()
test_forward_vgg16()
test_forward_xception()
test_forward_resnet50()
test_forward_mobilenet()
for k in [keras, tf_keras]:
sut = TestKeras()
sut.test_forward_merge_dot(keras=k)
sut.test_forward_merge(keras=k)
sut.test_forward_activations(keras=k)
sut.test_forward_dense(keras=k)
sut.test_forward_permute(keras=k)
sut.test_forward_sequential(keras=k)
sut.test_forward_pool(keras=k)
sut.test_forward_conv(keras=k)
sut.test_forward_batch_norm(keras=k)
sut.test_forward_upsample(keras=k, interpolation='nearest')
sut.test_forward_upsample(keras=k, interpolation='bilinear')
sut.test_forward_reshape(keras=k)
sut.test_forward_crop(keras=k)
sut.test_forward_multi_inputs(keras=k)
sut.test_forward_multi_outputs(keras=k)
sut.test_forward_reuse_layers(keras=k)
sut.test_forward_rnn(keras=k)
sut.test_forward_vgg16(keras=k)
sut.test_forward_xception(keras=k)
sut.test_forward_resnet50(keras=k)
sut.test_forward_mobilenet(keras=k)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment