Commit 8c5b4909 by Yuwei Hu Committed by Tianqi Chen

[Keras] fix dropout bug (#399)

parent b7b74228
......@@ -355,7 +355,7 @@ def _convert_reshape(insym, keras_layer, _):
def _default_skip(insym, keras_layer, _): # pylint: disable=unused-argument
"""Layers that can be skipped because they are train time only."""
return
return insym
_convert_map = {
......
......@@ -40,7 +40,6 @@ def verify_keras_frontend(keras_model):
def test_forward_elemwise_add():
print("test_forward_elemwise_add")
r = []
data = keras.layers.Input(shape=(32,32,3))
x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
......@@ -48,26 +47,16 @@ def test_forward_elemwise_add():
x = keras.layers.Conv2D(8, (3, 3), padding="same")(x)
r.append(x)
x = keras.layers.Conv2D(8, (3, 3), padding="same")(x)
# add two symbols
y = keras.layers.add([keras.layers.add([x, r[0]]), r[1]])
y = keras.layers.GlobalAveragePooling2D()(y)
keras_model = keras.models.Model(data, y)
verify_keras_frontend(keras_model)
# add three symbols
y = keras.layers.add([x, r[0], r[1]])
y = keras.layers.GlobalAveragePooling2D()(y)
keras_model = keras.models.Model(data, y)
verify_keras_frontend(keras_model)
def test_forward_elementwise_add2():
data = keras.layers.Input(shape=(32,32,3))
r = keras.layers.Conv2D(10, (3, 3), padding="same")(data)
x = keras.layers.Conv2D(10, (3, 3), strides=(2, 2), padding="same")(data)
x = keras.layers.UpSampling2D()(x)
x = keras.layers.add([x, r])
x = keras.layers.GlobalAveragePooling2D()(x)
keras_model = keras.models.Model(data, x)
def test_forward_softrelu():
......@@ -92,6 +81,7 @@ def test_forward_dense():
data = keras.layers.Input(shape=(32,32,3))
x = keras.layers.MaxPooling2D(pool_size=(2,2))(data)
x = keras.layers.Flatten()(x)
x = keras.layers.Dropout(0.5)(x)
x = keras.layers.Dense(10, activation='relu', kernel_initializer='uniform')(x)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
......@@ -146,7 +136,6 @@ def test_forward_resnet50():
if __name__ == '__main__':
test_forward_elemwise_add()
test_forward_elementwise_add2()
test_forward_softrelu()
test_forward_leaky_relu()
test_forward_dense()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment