Commit 58b2395d by MORITA Kazutaka Committed by Tianqi Chen

[NNVM][KERAS] Fixed padding in pooling (#1635)

parent d90c1e45
......@@ -269,14 +269,12 @@ def _convert_pooling(insym, keras_layer, symtab):
'padding': [0, 0]}
if keras_layer.padding == 'valid':
pass
# we insert a separate pad operator
elif keras_layer.padding == 'same':
in_h = keras_layer.input_shape[1]
in_w = keras_layer.input_shape[2]
pad_t, pad_b = _get_pad_pair(in_h, pool_h, stride_h)
pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w)
insym = _sym.pad(data=insym, pad_width=(
(0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r)))
params['padding'] = [pad_t, pad_l, pad_b, pad_r]
else:
raise TypeError("Unsupported padding type : {}".format(keras_layer.padding))
if pool_type == 'MaxPooling2D':
......
......@@ -38,7 +38,7 @@ def verify_keras_frontend(keras_model):
out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
return out.asnumpy()
xs = [np.random.uniform(size=shape) for shape in in_shapes]
xs = [np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in in_shapes]
keras_out = get_keras_output(xs)
for target, ctx in ctx_list():
tvm_out = get_tvm_output([x.transpose([0,3,1,2]) for x in xs], target, ctx)
......@@ -74,6 +74,18 @@ def test_forward_dense():
verify_keras_frontend(keras_model)
def test_forward_pool():
data = keras.layers.Input(shape=(2,2,1))
# maxpool
x = keras.layers.MaxPooling2D((3, 3), strides=(1, 1), padding='same')(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
# avgpool
y = keras.layers.AveragePooling2D((3, 3), strides=(1, 1), padding='same')(data)
keras_model = keras.models.Model(data, y)
verify_keras_frontend(keras_model)
def test_forward_transpose_conv():
data = keras.layers.Input(shape=(32,32,3))
x = keras.layers.Conv2D(filters=10, kernel_size=(3,3), strides=(2,2), padding='same')(data)
......@@ -206,6 +218,7 @@ if __name__ == '__main__':
test_forward_elemwise_add()
test_forward_activations()
test_forward_dense()
test_forward_pool()
test_forward_transpose_conv()
test_forward_separable_conv()
test_forward_upsample()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment