Commit d90c1e45 by MORITA Kazutaka Committed by Tianqi Chen

[NNVM][KERAS] Add cropping support (#1636)

parent b11f2a04
......@@ -311,6 +311,21 @@ def _convert_upsample(insym, keras_layer, _):
return _sym.upsampling(insym, **params)
def _convert_cropping(insym, keras_layer, _):
_check_data_format(keras_layer)
crop_type = type(keras_layer).__name__
if crop_type == "Cropping1D":
raise NotImplementedError("Cropping1D not implemented")
elif crop_type == "Cropping2D":
(_, in_h, in_w, _) = keras_layer.input_shape
((crop_t, crop_b), (crop_l, crop_r)) = keras_layer.cropping
else:
raise TypeError("Unrecognized cropping type : {}".format(crop_type))
int32_max = np.iinfo(np.int32).max
return _sym.strided_slice(insym, begin=[0, 0, crop_t, crop_l],
end=[int32_max, int32_max, in_h-crop_b, in_w-crop_r])
def _convert_batchnorm(insym, keras_layer, symtab):
params = {'scale': False,
'center': False,
......@@ -409,6 +424,7 @@ _convert_map = {
'Multiply' : _convert_merge,
'ZeroPadding2D' : _convert_padding,
'UpSampling2D' : _convert_upsample,
'Cropping2D' : _convert_cropping,
# 'ZeroPadding1D' : _convert_padding,
# 'AveragePooling1D' : _convert_pooling,
......@@ -416,7 +432,6 @@ _convert_map = {
# 'GlobalAveragePooling1D' : _convert_pooling,
# 'GlobalMaxPooling1D' : _convert_pooling,
# 'Cropping1D' : _convert_cropping,
# 'Cropping2D' : _convert_cropping,
# 'UpSampling1D' : _convert_upsample,
# 'UpSampling3D' : _convert_upsample,
# 'Conv1D' : _convert_convolution1d,
......
......@@ -110,6 +110,20 @@ def test_forward_reshape():
verify_keras_frontend(keras_model)
def test_forward_crop():
data = keras.layers.Input(shape=(32,32,3))
x = keras.layers.Cropping2D(cropping=((1, 1), (1, 1)))(data)
x = keras.layers.Cropping2D(cropping=(1, 1))(x)
x = keras.layers.Cropping2D(cropping=1)(x)
x = keras.layers.Cropping2D(cropping=((0, 1), (1, 0)))(x)
x = keras.layers.Cropping2D(cropping=(1, 0))(x)
x = keras.layers.Cropping2D(cropping=0)(x)
x = keras.layers.Add()([x, x])
x = keras.layers.GlobalAveragePooling2D()(x)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
def test_forward_vgg16():
keras_model = keras.applications.vgg16.VGG16(include_top=True, weights=None,
input_shape=(224,224,3), classes=1000)
......@@ -196,6 +210,7 @@ if __name__ == '__main__':
test_forward_separable_conv()
test_forward_upsample()
test_forward_reshape()
test_forward_crop()
test_forward_vgg16()
test_forward_xception()
test_forward_resnet50()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment