Commit fd6ad274 by MORITA Kazutaka Committed by Tianqi Chen

[FRONTEND][Keras] Fix softmax axis (#503)

parent 8534db36
......@@ -40,7 +40,7 @@ def _convert_activation(insym, keras_layer, _):
return _sym.__add_scalar__(_sym.__mul_scalar__(insym, \
scalar=alpha), scalar=beta)
elif act_type == 'softmax':
return _sym.softmax(insym)
return _sym.softmax(insym, axis=1)
elif act_type == 'sigmoid':
return _sym.sigmoid(insym)
elif act_type == 'tanh':
......
......@@ -59,6 +59,15 @@ def test_forward_elemwise_add():
verify_keras_frontend(keras_model)
def test_forward_softmax():
data = keras.layers.Input(shape=(32,32,3))
x = keras.layers.Activation('softmax')(data)
x = keras.layers.Concatenate()([x, x])
x = keras.layers.GlobalMaxPooling2D()(x)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
def test_forward_softrelu():
data = keras.layers.Input(shape=(32,32,3))
x = keras.layers.Activation('softplus')(data)
......@@ -145,6 +154,7 @@ def test_forward_resnet50():
if __name__ == '__main__':
test_forward_elemwise_add()
test_forward_softmax()
test_forward_softrelu()
test_forward_leaky_relu()
test_forward_dense()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment