Commit c68945c5 by MORITA Kazutaka Committed by Tianqi Chen

[FRONTEND][Keras] fix reshape (#493)

parent 39cc9c12
...@@ -345,11 +345,12 @@ def _convert_concat(insym, keras_layer, _): ...@@ -345,11 +345,12 @@ def _convert_concat(insym, keras_layer, _):
def _convert_reshape(insym, keras_layer, _): def _convert_reshape(insym, keras_layer, _):
shape = keras_layer.shape if hasattr(keras_layer, 'shape') \ _check_data_format(keras_layer)
else keras_layer.target_shape if hasattr(keras_layer, 'target_shape') \ ch = keras_layer.input_shape[-1]
else None assert ch == keras_layer.target_shape[-1], \
if shape is None: "Only supports last dimension in target shape being equal to " \
raise TypeError("No shape attribute in reshape layer: {}".format(keras_layer)) "the channel number of input tensor."
shape = (-1, ch) + keras_layer.target_shape[:-1]
return _sym.reshape(insym, shape=shape) return _sym.reshape(insym, shape=shape)
......
...@@ -134,6 +134,14 @@ def test_forward_relu6(): ...@@ -134,6 +134,14 @@ def test_forward_relu6():
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_reshape():
data = keras.layers.Input(shape=(32,32,3))
x = keras.layers.Reshape(target_shape=(32,32,3))(data)
x = keras.layers.GlobalAveragePooling2D()(x)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
def test_forward_vgg16(): def test_forward_vgg16():
keras_model = keras.applications.vgg16.VGG16(include_top=True, weights=None, keras_model = keras.applications.vgg16.VGG16(include_top=True, weights=None,
input_shape=(224,224,3), classes=1000) input_shape=(224,224,3), classes=1000)
...@@ -162,6 +170,7 @@ if __name__ == '__main__': ...@@ -162,6 +170,7 @@ if __name__ == '__main__':
test_forward_separable_conv() test_forward_separable_conv()
test_forward_upsample() test_forward_upsample()
test_forward_relu6() test_forward_relu6()
test_forward_reshape()
test_forward_vgg16() test_forward_vgg16()
test_forward_xception() test_forward_xception()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment