Commit d5b34220 by Leonardo lontra Committed by Yuwei Hu

[Relay][Frontend][keras] added interpolation method of Upsampling2D (#2854)

* [Relay][Frontend][keras] added interpolation method of Upsampling2D.

* added testcase

* small fixes
parent 340678de
......@@ -334,6 +334,14 @@ def _convert_upsample(inexpr, keras_layer, _):
raise TypeError("Unsupported upsampling type with different axes size : {}"
.format(keras_layer.size))
params = {'scale': h}
if hasattr(keras_layer, 'interpolation'):
interpolation = keras_layer.interpolation
if interpolation == 'nearest':
params['method'] = 'NEAREST_NEIGHBOR'
else:
params['method'] = 'BILINEAR'
elif upsample_type == 'UpSampling3D':
h, w, d = keras_layer.size
if h != w or w != d:
......
......@@ -133,9 +133,9 @@ def test_forward_conv():
verify_keras_frontend(keras_model)
def test_forward_upsample():
def test_forward_upsample(interpolation='nearest'):
data = keras.layers.Input(shape=(32,32,3))
x = keras.layers.UpSampling2D(size=(3,3))(data)
x = keras.layers.UpSampling2D(size=(3,3), interpolation=interpolation)(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
......@@ -246,7 +246,8 @@ if __name__ == '__main__':
test_forward_dense()
test_forward_pool()
test_forward_conv()
test_forward_upsample()
test_forward_upsample(interpolation='nearest')
test_forward_upsample(interpolation='bilinear')
test_forward_reshape()
test_forward_crop()
test_forward_multi_inputs()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment