Commit 83bac2d1 by Xingyu Zhou Committed by Yuwei Hu

[Relay][Frontend][Keras] batch_norm op params not handling well (#4310)

* Relay Keras frontent batch_norm op params not handeling well

* add unit test for Relay Frontend Keras batch_norm
parent 2571449e
...@@ -460,6 +460,11 @@ def _convert_batchnorm(inexpr, keras_layer, etab): ...@@ -460,6 +460,11 @@ def _convert_batchnorm(inexpr, keras_layer, etab):
moving_var = keras_layer.get_weights()[idx + 1] moving_var = keras_layer.get_weights()[idx + 1]
params['moving_mean'] = etab.new_const(moving_mean) params['moving_mean'] = etab.new_const(moving_mean)
params['moving_var'] = etab.new_const(moving_var) params['moving_var'] = etab.new_const(moving_var)
# in case beta or gamma is not defined
params['beta'] = etab.new_const(np.zeros(moving_mean.shape)) if \
'beta' not in params else params['beta']
params['gamma'] = etab.new_const(np.ones(moving_mean.shape)) if \
'gamma' not in params else params['gamma']
result, moving_mean, moving_var = _op.nn.batch_norm(inexpr, **params) result, moving_mean, moving_var = _op.nn.batch_norm(inexpr, **params)
return result return result
......
...@@ -190,6 +190,36 @@ def test_forward_conv(): ...@@ -190,6 +190,36 @@ def test_forward_conv():
keras_model = keras.models.Model(data, x) keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model)
def test_forward_batch_norm():
data = keras.layers.Input(shape=(32, 32, 3))
batch_norm_funcs = [keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
center=True, scale=False,
beta_initializer='zeros',
gamma_initializer='ones',
moving_mean_initializer='zeros',
moving_variance_initializer='ones'),
keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
center=True, scale=True,
beta_initializer='zeros',
gamma_initializer='ones',
moving_mean_initializer='zeros',
moving_variance_initializer='ones'),
keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
center=False, scale=True,
beta_initializer='zeros',
gamma_initializer='ones',
moving_mean_initializer='zeros',
moving_variance_initializer='ones'),
keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
center=False, scale=False,
beta_initializer='zeros',
gamma_initializer='ones',
moving_mean_initializer='zeros',
moving_variance_initializer='ones')]
for batch_norm_func in batch_norm_funcs:
x = batch_norm_func(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
def test_forward_upsample(interpolation='nearest'): def test_forward_upsample(interpolation='nearest'):
data = keras.layers.Input(shape=(32, 32, 3)) data = keras.layers.Input(shape=(32, 32, 3))
...@@ -333,6 +363,7 @@ if __name__ == '__main__': ...@@ -333,6 +363,7 @@ if __name__ == '__main__':
test_forward_sequential() test_forward_sequential()
test_forward_pool() test_forward_pool()
test_forward_conv() test_forward_conv()
test_forward_batch_norm()
test_forward_upsample(interpolation='nearest') test_forward_upsample(interpolation='nearest')
test_forward_upsample(interpolation='bilinear') test_forward_upsample(interpolation='bilinear')
test_forward_reshape() test_forward_reshape()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment