Commit 83bac2d1 by Xingyu Zhou Committed by Yuwei Hu

[Relay][Frontend][Keras] batch_norm op params not handling well (#4310)

* Relay Keras frontent batch_norm op params not handeling well

* add unit test for Relay Frontend Keras batch_norm
parent 2571449e
......@@ -460,6 +460,11 @@ def _convert_batchnorm(inexpr, keras_layer, etab):
moving_var = keras_layer.get_weights()[idx + 1]
params['moving_mean'] = etab.new_const(moving_mean)
params['moving_var'] = etab.new_const(moving_var)
# in case beta or gamma is not defined
params['beta'] = etab.new_const(np.zeros(moving_mean.shape)) if \
'beta' not in params else params['beta']
params['gamma'] = etab.new_const(np.ones(moving_mean.shape)) if \
'gamma' not in params else params['gamma']
result, moving_mean, moving_var = _op.nn.batch_norm(inexpr, **params)
return result
......
......@@ -190,6 +190,36 @@ def test_forward_conv():
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
def test_forward_batch_norm():
data = keras.layers.Input(shape=(32, 32, 3))
batch_norm_funcs = [keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
center=True, scale=False,
beta_initializer='zeros',
gamma_initializer='ones',
moving_mean_initializer='zeros',
moving_variance_initializer='ones'),
keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
center=True, scale=True,
beta_initializer='zeros',
gamma_initializer='ones',
moving_mean_initializer='zeros',
moving_variance_initializer='ones'),
keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
center=False, scale=True,
beta_initializer='zeros',
gamma_initializer='ones',
moving_mean_initializer='zeros',
moving_variance_initializer='ones'),
keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
center=False, scale=False,
beta_initializer='zeros',
gamma_initializer='ones',
moving_mean_initializer='zeros',
moving_variance_initializer='ones')]
for batch_norm_func in batch_norm_funcs:
x = batch_norm_func(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
def test_forward_upsample(interpolation='nearest'):
data = keras.layers.Input(shape=(32, 32, 3))
......@@ -333,6 +363,7 @@ if __name__ == '__main__':
test_forward_sequential()
test_forward_pool()
test_forward_conv()
test_forward_batch_norm()
test_forward_upsample(interpolation='nearest')
test_forward_upsample(interpolation='bilinear')
test_forward_reshape()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment