Unverified Commit 989b4819 by Samuel Committed by GitHub

[PYTORCH]celu, gelu, selu activations (#5263)

parent d2de35eb
......@@ -216,15 +216,44 @@ def _prelu():
def _leaky_relu():
def _impl(inputs, input_types):
data = inputs[0]
alpha = int(inputs[1])
alpha = float(inputs[1])
return _op.nn.leaky_relu(data, alpha)
return _impl
def _elu():
def _impl(inputs, input_types):
data = inputs[0]
alpha = _expr.const(int(inputs[1]), dtype='float32')
return alpha * _op.nn.relu(alpha - _op.exp(data)) + _op.nn.relu(data)
alpha = _expr.const(float(inputs[1]))
return alpha * _op.nn.relu(_expr.const(1.0) - _op.exp(data)) + _op.nn.relu(data)
return _impl
def _celu():
def _impl(inputs, input_types):
data = inputs[0]
alpha = _expr.const(float(inputs[1]))
return alpha * _op.nn.relu(_expr.const(1.0) - _op.exp(data / alpha)) + _op.nn.relu(data)
return _impl
def _gelu():
def _impl(inputs, input_types):
import math
data = inputs[0]
def _pow3(x):
return x * x * x
return _expr.const(0.5) * data * (_expr.const(1.0) +
_op.tanh(_expr.const(math.sqrt(2.0 / math.pi)) *
(data + _expr.const(0.044715) * _pow3(data))))
return _impl
def _selu():
def _impl(inputs, input_types):
data = inputs[0]
# https://pytorch.org/docs/stable/nn.html#selu
alpha = _expr.const(-1.6732632423543772848170429916717)
gamma = _expr.const(1.0507009873554804934193349852946)
return gamma * (alpha * _op.nn.relu(_expr.const(1.0)
- _op.exp(data)) + _op.nn.relu(data))
return _impl
def _log_sigmoid():
......@@ -1066,6 +1095,9 @@ _convert_map = {
"aten::prelu" : _prelu(),
"aten::leaky_relu" : _leaky_relu(),
"aten::elu" : _elu(),
"aten::celu" : _celu(),
"aten::gelu" : _gelu(),
"aten::selu" : _selu(),
"aten::log_sigmoid" : _log_sigmoid(),
"aten::adaptive_avg_pool2d" : _adaptive_avg_pool_2d(),
"aten::adaptive_max_pool2d" : _adaptive_max_pool_2d(),
......
......@@ -353,16 +353,43 @@ def test_forward_prelu():
def test_forward_leakyrelu():
torch.set_grad_enabled(False)
input_shape = [10, 10]
input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.LeakyReLU().eval(), input_data=input_data)
verify_model(torch.nn.LeakyReLU(negative_slope=0.05).eval(), input_data=input_data)
verify_model(torch.nn.LeakyReLU(negative_slope=1.0).eval(), input_data=input_data)
verify_model(torch.nn.LeakyReLU(negative_slope=1.25).eval(), input_data=input_data)
def test_forward_elu():
torch.set_grad_enabled(False)
input_shape = [10, 10]
input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.ELU().eval(), input_data=input_data)
verify_model(torch.nn.ELU(alpha=0.3).eval(), input_data=input_data)
verify_model(torch.nn.ELU(alpha=1.0).eval(), input_data=input_data)
verify_model(torch.nn.ELU(alpha=1.3).eval(), input_data=input_data)
def test_forward_celu():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.CELU().eval(), input_data=input_data)
verify_model(torch.nn.CELU(alpha=0.3).eval(), input_data=input_data)
verify_model(torch.nn.CELU(alpha=1.0).eval(), input_data=input_data)
verify_model(torch.nn.CELU(alpha=1.3).eval(), input_data=input_data)
def test_forward_gelu():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.GELU().eval(), input_data=input_data)
def test_forward_selu():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.SELU().eval(), input_data=input_data)
def test_forward_log_sigmoid():
torch.set_grad_enabled(False)
input_shape = [10, 10]
......@@ -1131,6 +1158,9 @@ if __name__ == "__main__":
test_forward_prelu()
test_forward_leakyrelu()
test_forward_elu()
test_forward_celu()
test_forward_gelu()
test_forward_selu()
test_forward_log_sigmoid()
test_forward_adaptiveavgpool()
test_forward_maxpool2d()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment