Commit d5b34220 by Leonardo lontra Committed by Yuwei Hu

[Relay][Frontend][keras] added interpolation method of Upsampling2D (#2854)

* [Relay][Frontend][keras] added interpolation method of Upsampling2D.

* added testcase

* small fixes
parent 340678de
...@@ -334,6 +334,14 @@ def _convert_upsample(inexpr, keras_layer, _): ...@@ -334,6 +334,14 @@ def _convert_upsample(inexpr, keras_layer, _):
raise TypeError("Unsupported upsampling type with different axes size : {}" raise TypeError("Unsupported upsampling type with different axes size : {}"
.format(keras_layer.size)) .format(keras_layer.size))
params = {'scale': h} params = {'scale': h}
if hasattr(keras_layer, 'interpolation'):
interpolation = keras_layer.interpolation
if interpolation == 'nearest':
params['method'] = 'NEAREST_NEIGHBOR'
else:
params['method'] = 'BILINEAR'
elif upsample_type == 'UpSampling3D': elif upsample_type == 'UpSampling3D':
h, w, d = keras_layer.size h, w, d = keras_layer.size
if h != w or w != d: if h != w or w != d:
......
...@@ -133,9 +133,9 @@ def test_forward_conv(): ...@@ -133,9 +133,9 @@ def test_forward_conv():
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_upsample(): def test_forward_upsample(interpolation='nearest'):
data = keras.layers.Input(shape=(32,32,3)) data = keras.layers.Input(shape=(32,32,3))
x = keras.layers.UpSampling2D(size=(3,3))(data) x = keras.layers.UpSampling2D(size=(3,3), interpolation=interpolation)(data)
keras_model = keras.models.Model(data, x) keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
...@@ -246,7 +246,8 @@ if __name__ == '__main__': ...@@ -246,7 +246,8 @@ if __name__ == '__main__':
test_forward_dense() test_forward_dense()
test_forward_pool() test_forward_pool()
test_forward_conv() test_forward_conv()
test_forward_upsample() test_forward_upsample(interpolation='nearest')
test_forward_upsample(interpolation='bilinear')
test_forward_reshape() test_forward_reshape()
test_forward_crop() test_forward_crop()
test_forward_multi_inputs() test_forward_multi_inputs()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment