Commit 3d83156c by Yuwei Hu Committed by Tianqi Chen

[Relay][Keras] force const dtype to be float32 (#2376)

* [Relay][Keras] force const dtype to be float32

* fix pylint
parent a7d39d7b
......@@ -28,7 +28,7 @@ def _get_pad_pair(input1d, kernel1d, stride1d):
def _get_elu(inexpr, alpha):
"""A helper method for elu."""
return _op.negative(alpha) * _op.nn.relu(_expr.const(1.) - \
return _op.negative(alpha) * _op.nn.relu(_expr.const(1., dtype='float32') - \
_op.exp(inexpr)) + _op.nn.relu(inexpr)
......@@ -69,7 +69,7 @@ def _convert_activation(inexpr, keras_layer, _):
elif act_type == 'relu':
return _op.nn.relu(inexpr)
elif act_type == 'softplus':
return _op.log(_op.add(_op.exp(inexpr), _expr.const(1.)))
return _op.log(_op.add(_op.exp(inexpr), _expr.const(1., dtype='float32')))
elif act_type == 'elu':
alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') else 1.
alpha = _expr.const(alpha, dtype='float32')
......@@ -86,10 +86,10 @@ def _convert_activation(inexpr, keras_layer, _):
elif act_type == 'relu6':
return _op.clip(inexpr, a_min=0., a_max=6.)
elif act_type == 'softsign':
return inexpr / (_expr.const(1.) + _op.abs(inexpr))
return inexpr / (_expr.const(1., dtype='float32') + _op.abs(inexpr))
elif act_type == 'hard_sigmoid':
transformX = (_expr.const(0.2) * inexpr) + _expr.const(0.5)
return _op.clip(transformX, a_min=0., a_max=1.)
x = (_expr.const(0.2, dtype='float32') * inexpr) + _expr.const(0.5, dtype='float32')
return _op.clip(x, a_min=0., a_max=1.)
else:
raise TypeError("Unsupported activation type : {}".format(act_type))
......@@ -522,7 +522,7 @@ def _convert_gru(inexpr, keras_layer, etab):
recurrent_h = _op.nn.dense(rec_act_r * h_tm1_op, rec_weights[1], units=units)
act_hh = _convert_activation(x_h + recurrent_h, keras_layer, None)
# previous and candidate state mixed by update gate
output = rec_act_z * h_tm1_op + (_expr.const(1.) - rec_act_z) * act_hh
output = rec_act_z * h_tm1_op + (_expr.const(1., dtype='float32') - rec_act_z) * act_hh
out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
output = _op.reshape(output, newshape=out_shape)
return [output, output]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment