Commit 2ebf1bd1 by Alexander Pivovarov Committed by Zhi

Add more cases to keras _convert_reshape (#3846)

parent ec7790e3
......@@ -490,11 +490,26 @@ def _convert_concat(inexpr, keras_layer, _):
def _convert_reshape(inexpr, keras_layer, _):
_check_data_format(keras_layer)
ch = keras_layer.input_shape[-1]
assert ch == keras_layer.target_shape[-1], \
inshape = keras_layer.input_shape # includes batch
tshape = keras_layer.target_shape # no batch
if len(inshape) == 3 and len(tshape) == 1:
# (?, a, b) -> (-1, ab)
shape = (-1, tshape[0])
elif len(inshape) in [2, 3] and len(tshape) == 2:
# (?, cc) -> (-1, c, c)
# (?, a, b) -> (-1, c, c)
assert tshape[0] == tshape[1], \
"Only supports square target shapes, but got {}".format(tshape)
shape = (-1, ) + tshape
else:
# (?, h, w, c) -> (-1, c, H, W)
# (?, h, w, c) -> (-1, c, hw)
# (?, hw, c) -> (-1, c, h, w)
ch = inshape[-1]
assert ch == tshape[-1], \
"Only supports last dimension in target shape being equal to " \
"the channel number of input tensor."
shape = (-1, ch) + keras_layer.target_shape[:-1]
shape = (-1, ch) + tshape[:-1]
return _op.reshape(inexpr, newshape=shape)
......
......@@ -193,10 +193,36 @@ def test_forward_upsample(interpolation='nearest'):
def test_forward_reshape():
# input_shape len is 3, target_shape len is 3
data = keras.layers.Input(shape=(32, 32, 3))
x = keras.layers.Reshape(target_shape=(32, 32, 3))(data)
x = keras.layers.Reshape(target_shape=(16, 64, 3))(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
# input_shape len is 3, target_shape len is 2
data = keras.layers.Input(shape=(32, 8, 3))
x = keras.layers.Reshape(target_shape=(256, 3))(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
# input_shape len is 2, target_shape len is 3
data = keras.layers.Input(shape=(256, 3))
x = keras.layers.Reshape(target_shape=(8, 32, 3))(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
# input_shape len is 2, target_shape len is 1
data = keras.layers.Input(shape=(2, 8))
x = keras.layers.Reshape(target_shape=(16,))(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, need_transpose=False)
# input_shape len is 1, target_shape len is 2
data = keras.layers.Input(shape=(16,))
x = keras.layers.Reshape(target_shape=(4, 4))(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, need_transpose=False)
# input_shape len is 2, target_shape len is 2
data = keras.layers.Input(shape=(2, 8))
x = keras.layers.Reshape(target_shape=(4, 4))(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, need_transpose=False)
def test_forward_crop():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment