Commit dedcf82f by Yong Wu Committed by Yao Wang

[Relay][Keras] Permute, Softmax support (#3618)

parent e7fb2d4d
...@@ -115,6 +115,9 @@ def _convert_activation(inexpr, keras_layer, _): ...@@ -115,6 +115,9 @@ def _convert_activation(inexpr, keras_layer, _):
def _convert_advanced_activation(inexpr, keras_layer, etab): def _convert_advanced_activation(inexpr, keras_layer, etab):
act_type = type(keras_layer).__name__ act_type = type(keras_layer).__name__
if act_type == 'Softmax':
return _op.nn.softmax(inexpr, axis=1)
if act_type == 'ReLU': if act_type == 'ReLU':
if keras_layer.max_value: if keras_layer.max_value:
return _op.clip(inexpr, a_min=0., a_max=float(keras_layer.max_value)) return _op.clip(inexpr, a_min=0., a_max=float(keras_layer.max_value))
...@@ -160,6 +163,8 @@ def _convert_merge(inexpr, keras_layer, _): ...@@ -160,6 +163,8 @@ def _convert_merge(inexpr, keras_layer, _):
'Operator {} is not supported in frontend Keras.'.format(merge_type)) 'Operator {} is not supported in frontend Keras.'.format(merge_type))
return ret return ret
def _convert_permute(inexpr, keras_layer, _):
return _op.transpose(inexpr, axes=(0,) + keras_layer.dims)
def _convert_dense(inexpr, keras_layer, etab): def _convert_dense(inexpr, keras_layer, etab):
weightList = keras_layer.get_weights() weightList = keras_layer.get_weights()
...@@ -574,6 +579,7 @@ def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument ...@@ -574,6 +579,7 @@ def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument
_convert_map = { _convert_map = {
'Dense' : _convert_dense, 'Dense' : _convert_dense,
'Activation' : _convert_activation, 'Activation' : _convert_activation,
'Softmax' : _convert_advanced_activation,
'ReLU' : _convert_advanced_activation, 'ReLU' : _convert_advanced_activation,
'LeakyReLU' : _convert_advanced_activation, 'LeakyReLU' : _convert_advanced_activation,
'PReLU' : _convert_advanced_activation, 'PReLU' : _convert_advanced_activation,
...@@ -620,7 +626,7 @@ _convert_map = { ...@@ -620,7 +626,7 @@ _convert_map = {
'Average' : _convert_merge, 'Average' : _convert_merge,
'Maximum' : _convert_merge, 'Maximum' : _convert_merge,
# 'Dot' : _convert_merge, # 'Dot' : _convert_merge,
# 'Permute' : _convert_permute, 'Permute' : _convert_permute,
# 'Embedding' : _convert_embedding, # 'Embedding' : _convert_embedding,
# 'RepeatVector' : _convert_repeat_vector, # 'RepeatVector' : _convert_repeat_vector,
...@@ -632,11 +638,15 @@ _convert_map = { ...@@ -632,11 +638,15 @@ _convert_map = {
def _check_unsupported_layers(model): def _check_unsupported_layers(model):
missing_ops = set()
for layer in model.layers: for layer in model.layers:
op_name = type(layer).__name__ op_name = type(layer).__name__
if op_name not in _convert_map: if op_name not in _convert_map:
raise tvm.error.OpNotImplemented( missing_ops.add(op_name)
'Operator {} is not supported in frontend Keras.'.format(op_name))
if missing_ops:
raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_ops))
def keras_op_to_relay(inexpr, keras_layer, outname, etab): def keras_op_to_relay(inexpr, keras_layer, outname, etab):
......
...@@ -73,7 +73,7 @@ def verify_keras_frontend(keras_model, need_transpose=True): ...@@ -73,7 +73,7 @@ def verify_keras_frontend(keras_model, need_transpose=True):
def test_forward_merge(): def test_forward_merge():
data = keras.layers.Input(shape=(32,32,3)) data = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.Conv2D(8, (3, 3), padding="same")(data) x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
y = keras.layers.Conv2D(8, (3, 3), padding="same")(x) y = keras.layers.Conv2D(8, (3, 3), padding="same")(x)
z = keras.layers.Conv2D(8, (3, 3), padding="same")(y) z = keras.layers.Conv2D(8, (3, 3), padding="same")(y)
...@@ -93,7 +93,7 @@ def test_forward_merge(): ...@@ -93,7 +93,7 @@ def test_forward_merge():
def test_forward_activations(): def test_forward_activations():
data = keras.layers.Input(shape=(32,32,3)) data = keras.layers.Input(shape=(32, 32, 3))
act_funcs = [keras.layers.Activation('softmax'), act_funcs = [keras.layers.Activation('softmax'),
keras.layers.Activation('softplus'), keras.layers.Activation('softplus'),
keras.layers.Activation('relu'), keras.layers.Activation('relu'),
...@@ -103,6 +103,7 @@ def test_forward_activations(): ...@@ -103,6 +103,7 @@ def test_forward_activations():
keras.layers.Activation('tanh'), keras.layers.Activation('tanh'),
keras.layers.Activation('linear'), keras.layers.Activation('linear'),
keras.layers.Activation('selu'), keras.layers.Activation('selu'),
keras.layers.Softmax(),
keras.layers.ReLU(), keras.layers.ReLU(),
keras.layers.ReLU(max_value=6.), keras.layers.ReLU(max_value=6.),
keras.layers.LeakyReLU(alpha=0.3), keras.layers.LeakyReLU(alpha=0.3),
...@@ -116,13 +117,18 @@ def test_forward_activations(): ...@@ -116,13 +117,18 @@ def test_forward_activations():
def test_forward_dense(): def test_forward_dense():
data = keras.layers.Input(shape=(32,32,1)) data = keras.layers.Input(shape=(32, 32, 1))
x = keras.layers.Flatten()(data) x = keras.layers.Flatten()(data)
x = keras.layers.Dropout(0.5)(x) x = keras.layers.Dropout(0.5)(x)
x = keras.layers.Dense(10, activation='relu', kernel_initializer='uniform')(x) x = keras.layers.Dense(10, activation='relu', kernel_initializer='uniform')(x)
keras_model = keras.models.Model(data, x) keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_permute():
data = keras.layers.Input(shape=(2, 3, 4))
x = keras.layers.Permute([2, 3, 1])(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, need_transpose=False)
def test_forward_sequential(): def test_forward_sequential():
keras_model = keras.models.Sequential([ keras_model = keras.models.Sequential([
...@@ -136,7 +142,7 @@ def test_forward_sequential(): ...@@ -136,7 +142,7 @@ def test_forward_sequential():
def test_forward_pool(): def test_forward_pool():
data = keras.layers.Input(shape=(32,32,1)) data = keras.layers.Input(shape=(32, 32, 1))
# maxpool # maxpool
x = keras.layers.MaxPooling2D((3, 3), strides=(1, 1), padding='same')(data) x = keras.layers.MaxPooling2D((3, 3), strides=(1, 1), padding='same')(data)
keras_model = keras.models.Model(data, x) keras_model = keras.models.Model(data, x)
...@@ -148,14 +154,14 @@ def test_forward_pool(): ...@@ -148,14 +154,14 @@ def test_forward_pool():
def test_forward_conv(): def test_forward_conv():
data = keras.layers.Input(shape=(32,32,3)) data = keras.layers.Input(shape=(32, 32, 3))
conv_funcs = [keras.layers.Conv2D(filters=10, kernel_size=(3,3), conv_funcs = [keras.layers.Conv2D(filters=10, kernel_size=(3, 3),
strides=(2,2), padding='same'), strides=(2, 2), padding='same'),
keras.layers.Conv2D(filters=10, kernel_size=(3,3), keras.layers.Conv2D(filters=10, kernel_size=(3, 3),
dilation_rate=(2,2), padding='same'), dilation_rate=(2, 2), padding='same'),
keras.layers.DepthwiseConv2D(kernel_size=(3,3), padding='same'), keras.layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same'),
keras.layers.Conv2DTranspose(filters=10, kernel_size=(3,3), padding='valid'), keras.layers.Conv2DTranspose(filters=10, kernel_size=(3, 3), padding='valid'),
keras.layers.SeparableConv2D(filters=10, kernel_size=(3,3), padding='same')] keras.layers.SeparableConv2D(filters=10, kernel_size=(3, 3), padding='same')]
for conv_func in conv_funcs: for conv_func in conv_funcs:
x = conv_func(data) x = conv_func(data)
keras_model = keras.models.Model(data, x) keras_model = keras.models.Model(data, x)
...@@ -163,21 +169,21 @@ def test_forward_conv(): ...@@ -163,21 +169,21 @@ def test_forward_conv():
def test_forward_upsample(interpolation='nearest'): def test_forward_upsample(interpolation='nearest'):
data = keras.layers.Input(shape=(32,32,3)) data = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.UpSampling2D(size=(3,3), interpolation=interpolation)(data) x = keras.layers.UpSampling2D(size=(3, 3), interpolation=interpolation)(data)
keras_model = keras.models.Model(data, x) keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_reshape(): def test_forward_reshape():
data = keras.layers.Input(shape=(32,32,3)) data = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.Reshape(target_shape=(32,32,3))(data) x = keras.layers.Reshape(target_shape=(32, 32, 3))(data)
keras_model = keras.models.Model(data, x) keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_crop(): def test_forward_crop():
data = keras.layers.Input(shape=(32,32,3)) data = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.Cropping2D(cropping=((1, 1), (1, 1)))(data) x = keras.layers.Cropping2D(cropping=((1, 1), (1, 1)))(data)
x = keras.layers.Cropping2D(cropping=(1, 1))(x) x = keras.layers.Cropping2D(cropping=(1, 1))(x)
x = keras.layers.Cropping2D(cropping=1)(x) x = keras.layers.Cropping2D(cropping=1)(x)
...@@ -190,8 +196,8 @@ def test_forward_crop(): ...@@ -190,8 +196,8 @@ def test_forward_crop():
def test_forward_multi_inputs(): def test_forward_multi_inputs():
data1 = keras.layers.Input(shape=(32,32,3)) data1 = keras.layers.Input(shape=(32, 32, 3))
data2 = keras.layers.Input(shape=(32,32,3)) data2 = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.Conv2D(8, (3, 3), padding="same")(data1) x = keras.layers.Conv2D(8, (3, 3), padding="same")(data1)
y = keras.layers.Conv2D(8, (3, 3), padding="same")(data2) y = keras.layers.Conv2D(8, (3, 3), padding="same")(data2)
z = keras.layers.Average()([x, y]) z = keras.layers.Average()([x, y])
...@@ -201,7 +207,7 @@ def test_forward_multi_inputs(): ...@@ -201,7 +207,7 @@ def test_forward_multi_inputs():
def test_forward_multi_outputs(): def test_forward_multi_outputs():
data = keras.layers.Input(shape=(32,32,3)) data = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.Conv2D(8, (3, 3), padding="same")(data) x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
x = keras.layers.GlobalAveragePooling2D()(x) x = keras.layers.GlobalAveragePooling2D()(x)
y = keras.layers.Conv2D(8, (3, 3), padding="same")(data) y = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
...@@ -212,7 +218,7 @@ def test_forward_multi_outputs(): ...@@ -212,7 +218,7 @@ def test_forward_multi_outputs():
def test_forward_reuse_layers(): def test_forward_reuse_layers():
# reuse conv2d # reuse conv2d
data = keras.layers.Input(shape=(32,32,3)) data = keras.layers.Input(shape=(32, 32, 3))
conv2d = keras.layers.Conv2D(8, (3, 3), padding="same") conv2d = keras.layers.Conv2D(8, (3, 3), padding="same")
x = conv2d(data) x = conv2d(data)
y = conv2d(data) y = conv2d(data)
...@@ -221,7 +227,7 @@ def test_forward_reuse_layers(): ...@@ -221,7 +227,7 @@ def test_forward_reuse_layers():
keras_model = keras.models.Model(data, z) keras_model = keras.models.Model(data, z)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
# reuse add # reuse add
data = keras.layers.Input(shape=(32,32,3)) data = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.Conv2D(8, (3, 3), padding="same")(data) x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
add = keras.layers.Add() add = keras.layers.Add()
x = add([x, x]) x = add([x, x])
...@@ -232,7 +238,7 @@ def test_forward_reuse_layers(): ...@@ -232,7 +238,7 @@ def test_forward_reuse_layers():
def test_forward_rnn(): def test_forward_rnn():
data = keras.layers.Input(shape=(1,32)) data = keras.layers.Input(shape=(1, 32))
rnn_funcs = [keras.layers.LSTM(units=16, return_state=False, rnn_funcs = [keras.layers.LSTM(units=16, return_state=False,
recurrent_activation='sigmoid', activation='tanh'), recurrent_activation='sigmoid', activation='tanh'),
keras.layers.SimpleRNN(units=16, return_state=False, keras.layers.SimpleRNN(units=16, return_state=False,
...@@ -247,25 +253,25 @@ def test_forward_rnn(): ...@@ -247,25 +253,25 @@ def test_forward_rnn():
def test_forward_vgg16(): def test_forward_vgg16():
keras_model = keras.applications.VGG16(include_top=True, weights='imagenet', keras_model = keras.applications.VGG16(include_top=True, weights='imagenet',
input_shape=(224,224,3), classes=1000) input_shape=(224, 224, 3), classes=1000)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_xception(): def test_forward_xception():
keras_model = keras.applications.Xception(include_top=True, weights='imagenet', keras_model = keras.applications.Xception(include_top=True, weights='imagenet',
input_shape=(299,299,3), classes=1000) input_shape=(299, 299, 3), classes=1000)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_resnet50(): def test_forward_resnet50():
keras_model = keras.applications.ResNet50(include_top=True, weights='imagenet', keras_model = keras.applications.ResNet50(include_top=True, weights='imagenet',
input_shape=(224,224,3), classes=1000) input_shape=(224, 224, 3), classes=1000)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_mobilenet(): def test_forward_mobilenet():
keras_model = keras.applications.MobileNet(include_top=True, weights='imagenet', keras_model = keras.applications.MobileNet(include_top=True, weights='imagenet',
input_shape=(224,224,3), classes=1000) input_shape=(224, 224, 3), classes=1000)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
...@@ -273,6 +279,7 @@ if __name__ == '__main__': ...@@ -273,6 +279,7 @@ if __name__ == '__main__':
test_forward_merge() test_forward_merge()
test_forward_activations() test_forward_activations()
test_forward_dense() test_forward_dense()
test_forward_permute()
test_forward_sequential() test_forward_sequential()
test_forward_pool() test_forward_pool()
test_forward_conv() test_forward_conv()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment