Unverified Commit c97e41b0 by Samuel Committed by GitHub

[FRONTEND][KERAS]Max_pool3d and Averagepool3d operator support (#5085)

* [KERAS]Pool3d support added

* Keras pool3d testcase added
parent 4683c3f5
...@@ -510,6 +510,43 @@ def _convert_pooling(inexpr, keras_layer, etab): ...@@ -510,6 +510,43 @@ def _convert_pooling(inexpr, keras_layer, etab):
raise tvm.error.OpNotImplemented( raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend Keras.'.format(keras_layer)) 'Operator {} is not supported for frontend Keras.'.format(keras_layer))
def _convert_pooling3d(inexpr, keras_layer, etab):
_check_data_format(keras_layer)
pool_type = type(keras_layer).__name__
if pool_type not in ['MaxPooling3D', 'AveragePooling3D']:
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend Keras.'.format(keras_layer))
pool_d1, pool_d2, pool_d3 = keras_layer.pool_size
stride_d1, stride_d2, stride_d3 = keras_layer.strides
params = {'pool_size': [pool_d1, pool_d2, pool_d3],
'strides': [stride_d1, stride_d2, stride_d3],
'padding': [0, 0, 0],
'layout': etab.data_layout}
if keras_layer.padding == 'valid':
pass
elif keras_layer.padding == 'same':
in_d1 = keras_layer.input_shape[1]
in_d2 = keras_layer.input_shape[2]
in_d3 = keras_layer.input_shape[3]
pad_d1 = _get_pad_pair(in_d1, pool_d1, stride_d1)
pad_d2 = _get_pad_pair(in_d2, pool_d2, stride_d2)
pad_d3 = _get_pad_pair(in_d3, pool_d3, stride_d3)
params['padding'] = [pad_d1[0], pad_d2[0], pad_d3[0], pad_d1[1], pad_d2[1], pad_d3[1]]
else:
raise tvm.error.OpAttributeUnImplemented(
'Padding with {} is not supported in operator Pooling3D.'.format(keras_layer.padding))
out = _op.transpose(inexpr, axes=(0, 4, 1, 2, 3))
params['layout'] = "NCDHW"
if pool_type == 'MaxPooling3D':
out = _op.nn.max_pool3d(out, **params)
elif pool_type == 'AveragePooling3D':
out = _op.nn.avg_pool3d(out, **params)
return _op.transpose(out, axes=(0, 2, 3, 4, 1))
def _convert_upsample(inexpr, keras_layer, etab): def _convert_upsample(inexpr, keras_layer, etab):
_check_data_format(keras_layer) _check_data_format(keras_layer)
...@@ -817,8 +854,8 @@ _convert_map = { ...@@ -817,8 +854,8 @@ _convert_map = {
'Conv3D' : _convert_convolution3d, 'Conv3D' : _convert_convolution3d,
# 'Conv3DTranspose' : _convert_convolution3d, # 'Conv3DTranspose' : _convert_convolution3d,
# 'SeparableConv3D' : _convert_convolution3d, # 'SeparableConv3D' : _convert_convolution3d,
# 'MaxPooling3D' : _convert_pooling3d, 'MaxPooling3D' : _convert_pooling3d,
# 'AveragePooling3D' : _convert_pooling3d, 'AveragePooling3D' : _convert_pooling3d,
# 'GlobalMaxPooling3D' : _convert_pooling3d, # 'GlobalMaxPooling3D' : _convert_pooling3d,
# 'GlobalAveragePooling3D' : _convert_pooling3d, # 'GlobalAveragePooling3D' : _convert_pooling3d,
# 'UpSampling3D' : _convert_upsample3d, # 'UpSampling3D' : _convert_upsample3d,
......
...@@ -421,6 +421,28 @@ class TestKeras: ...@@ -421,6 +421,28 @@ class TestKeras:
keras_model = keras.models.Model(data, x) keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, layout='NDHWC') verify_keras_frontend(keras_model, layout='NDHWC')
def test_forward_pool3d(self, keras):
data = keras.layers.Input(shape=(32, 32, 32, 1))
pool_funcs = [# maxpool
keras.layers.MaxPooling3D(pool_size=(2, 2, 2),
strides=(1, 1, 1),
padding='same'),
keras.layers.MaxPooling3D(pool_size=(3, 3, 3),
strides=(2, 2, 2),
padding='valid'),
# avgpool
keras.layers.AveragePooling3D(pool_size=(3, 3, 3),
strides=(2, 2, 2),
padding='same'),
keras.layers.AveragePooling3D(pool_size=(2, 2, 2),
strides=(1, 1, 1),
padding='valid'),
]
for pool_func in pool_funcs:
x = pool_func(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, layout='NDHWC')
if __name__ == '__main__': if __name__ == '__main__':
for k in [keras, tf_keras]: for k in [keras, tf_keras]:
sut = TestKeras() sut = TestKeras()
...@@ -449,3 +471,4 @@ if __name__ == '__main__': ...@@ -449,3 +471,4 @@ if __name__ == '__main__':
sut.test_forward_mobilenet(keras=k) sut.test_forward_mobilenet(keras=k)
sut.test_forward_mobilenet(keras=k, layout='NHWC') sut.test_forward_mobilenet(keras=k, layout='NHWC')
sut.test_forward_conv3d(keras=k) sut.test_forward_conv3d(keras=k)
sut.test_forward_pool3d(keras=k)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment