Unverified Commit b796c13c by Samuel Committed by GitHub

[KERAS]Upsample3d & ZeroPadding3d op (#5125)

* [KERAS]upsampling3d and zeropadding3d op

* [KERAS]upsampling3d and zeropadding3d test case

* Review comments updated
parent 3c2aa1aa
......@@ -569,14 +569,6 @@ def _convert_upsample(inexpr, keras_layer, etab):
params['method'] = 'nearest_neighbor'
else:
params['method'] = 'bilinear'
elif upsample_type == 'UpSampling3D':
h, w, d = keras_layer.size
if h != w or w != d:
raise tvm.error.OpAttributeInvalid(
'Height, width, and depth must all be equal for operator Upsample.')
params['scale_h'] = h
params['scale_w'] = h
else:
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend Keras.'.format(upsample_type))
......@@ -585,6 +577,18 @@ def _convert_upsample(inexpr, keras_layer, etab):
return out
def _convert_upsample3d(inexpr, keras_layer, etab):
_check_data_format(keras_layer)
params = {}
d, h, w = keras_layer.size
params['scale_d'] = d
params['scale_h'] = h
params['scale_w'] = w
params['layout'] = etab.data_layout
out = _op.nn.upsampling3d(inexpr, **params)
return out
def _convert_cropping(inexpr, keras_layer, _):
_check_data_format(keras_layer)
crop_type = type(keras_layer).__name__
......@@ -663,6 +667,36 @@ def _convert_padding(inexpr, keras_layer, etab):
return _op.nn.pad(data=inexpr, pad_width=((0, 0), (0, 0), (top, bottom), (left, right)))
return _op.nn.pad(data=inexpr, pad_width=((0, 0), (top, bottom), (left, right), (0, 0)))
def _convert_padding3d(inexpr, keras_layer, etab):
_check_data_format(keras_layer)
padding = keras_layer.padding
d_pad = h_pad = w_pad = [0, 0]
# padding can be 'int' or 'tuple of 3 ints' or 'tuple of 3 tuples of 2 ints' or 'tuple
# of 3 tuples of 2 ints different values'. In all these scenarios keras will send 3
# tuples of 2 ints.
if isinstance(padding, tuple) and isinstance(padding[0], tuple):
d_pad = padding[0]
h_pad = padding[1]
w_pad = padding[2]
else:
msg = 'Value {} in attribute "padding" of operator ZeroPadding3D is ' \
'not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(str(padding)))
if etab.data_layout == 'NCDHW':
out = _op.nn.pad(data=inexpr, pad_width=((0, 0), (0, 0),
(d_pad[0], d_pad[1]),
(h_pad[0], h_pad[1]),
(w_pad[0], w_pad[1])))
else:
out = _op.nn.pad(data=inexpr, pad_width=((0, 0),
(d_pad[0], d_pad[1]),
(h_pad[0], h_pad[1]),
(w_pad[0], w_pad[1]),
(0, 0)))
return out
def _convert_concat(inexpr, keras_layer, etab):
_check_data_format(keras_layer)
......@@ -858,7 +892,8 @@ _convert_map = {
'AveragePooling3D' : _convert_pooling3d,
# 'GlobalMaxPooling3D' : _convert_pooling3d,
# 'GlobalAveragePooling3D' : _convert_pooling3d,
# 'UpSampling3D' : _convert_upsample3d,
'UpSampling3D' : _convert_upsample3d,
'ZeroPadding3D' : _convert_padding3d,
'SimpleRNN' : _convert_simple_rnn,
'LSTM' : _convert_lstm,
......
......@@ -443,6 +443,28 @@ class TestKeras:
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, layout='NDHWC')
def test_forward_upsample3d(self, keras):
data = keras.layers.Input(shape=(32, 32, 32, 3))
x = keras.layers.UpSampling3D(size=(2, 3, 4))(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, layout='NDHWC')
def test_forward_zero_padding3d(self, keras):
data = keras.layers.Input(shape=(32, 32, 32, 3))
pad_funcs = [# Integer
keras.layers.ZeroPadding3D(padding=2),
# tuple of 3 ints
keras.layers.ZeroPadding3D(padding=(1, 2, 3)),
# tuple of 3 tuples of 2 ints
keras.layers.ZeroPadding3D(padding=((1,1), (2,2), (2,2))),
# tuple of 3 tuples of 2 ints different values
keras.layers.ZeroPadding3D(padding=((1,2), (2,3), (3,2))),
]
for pad_func in pad_funcs:
x = pad_func(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, layout='NDHWC')
if __name__ == '__main__':
for k in [keras, tf_keras]:
sut = TestKeras()
......@@ -472,3 +494,6 @@ if __name__ == '__main__':
sut.test_forward_mobilenet(keras=k, layout='NHWC')
sut.test_forward_conv3d(keras=k)
sut.test_forward_pool3d(keras=k)
sut.test_forward_upsample3d(keras=k)
sut.test_forward_zero_padding3d(keras=k)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment