Unverified commit 5b37d4c1 by Samuel, committed by GitHub

[PYTORCH] Abs, Arange, Softplus ops (#5295)

* [PYTORCH] Abs, Arange, Softplus ops

* Review comments updated
parent 403929f9
@@ -57,6 +57,33 @@ def _elemwise(name):
        return get_relay_op(name)(data0, data1)
    return _impl

def _abs():
    def _impl(inputs, input_types):
        data = inputs[0]
        return _op.abs(data)
    return _impl

def _arange():
    def _impl(inputs, input_types):
        if len(inputs) == 5:
            dtype = "float" if "float" in input_types[0:1] else _convert_dtype_value(inputs[1])
            start = _create_typed_const(0, dtype)
            stop = _create_typed_const(inputs[0], dtype)
            step = _create_typed_const(1, dtype)
        elif len(inputs) == 7:
            dtype = "float" if "float" in input_types[0:3] else _convert_dtype_value(inputs[3])
            start = _create_typed_const(inputs[0], dtype)
            stop = _create_typed_const(inputs[1], dtype)
            step = _create_typed_const(inputs[2], dtype)
        else:
            msg = "Unknown number of arguments (%d) to parse." % (len(inputs))
            raise AssertionError(msg)
        return _op.transform.arange(start=start,
                                    stop=stop,
                                    step=step,
                                    dtype=_convert_data_type(dtype))
    return _impl
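# The 5- vs 7-input branches above mirror TorchScript's two arange overloads.
# A minimal sketch (assuming a recent PyTorch; not part of this patch) to
# inspect both lowered forms:
import torch

@torch.jit.script
def arange_end(n: int):
    return torch.arange(n)        # aten::arange(end, dtype, layout, device, pin_memory): 5 inputs

@torch.jit.script
def arange_start(a: int, b: int, s: int):
    return torch.arange(a, b, s)  # aten::arange(start, end, step, dtype, layout, device, pin_memory): 7 inputs

print(arange_end.graph)
print(arange_start.graph)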

def _squeeze():
    def _impl(inputs, input_types):
        data = inputs[0]
@@ -732,6 +759,13 @@ def _sigmoid():
        return _op.tensor.sigmoid(data)
    return _impl

def _softplus():
    def _impl(inputs, input_types):
        data = inputs[0]
        beta = _expr.const(float(inputs[1]))
        return _op.log(_op.exp(data * beta) + _expr.const(1.)) / beta
    return _impl
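# The lowering above implements the closed form softplus(x) = log(1 + exp(beta*x)) / beta.
# Note it ignores the threshold argument (inputs[2]), which PyTorch uses only to
# switch to a linear approximation for numerical stability. A quick sanity check
# (a sketch, assuming inputs small enough that the threshold never triggers):
import torch

x = torch.linspace(-3, 3, 7)
beta = 1.5
ref = torch.nn.functional.softplus(x, beta=beta)
manual = torch.log(torch.exp(x * beta) + 1.0) / beta
assert torch.allclose(ref, manual, atol=1e-6)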

def _avg_pool2d():
    def _impl(inputs, input_types):
        data = inputs[0]
@@ -1044,6 +1078,21 @@ def _Float():
    return _impl

# Helper functions for operator implementation
def _convert_dtype_value(val):
    convert_torch_dtype_map = {7: "torch.float64",
                               6: "torch.float32",
                               5: "torch.float16",
                               4: "torch.int64",
                               3: "torch.int32",
                               2: "torch.int16",
                               1: "torch.int8",
                               0: "torch.uint8",
                               None: "torch.int64"}  # Default is torch.int64
    if val in convert_torch_dtype_map:
        return convert_torch_dtype_map[val]
    else:
        msg = "Torch data type value %d is not handled yet." % (val)
        raise NotImplementedError(msg)
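# The integer keys mirror PyTorch's at::ScalarType enum, which is how dtype
# arguments appear as constants in TorchScript graphs. A sketch (assuming a
# recent PyTorch) to see one such constant:
import torch

@torch.jit.script
def ones_f32():
    return torch.ones(3, dtype=torch.float32)

print(ones_f32.graph)  # the dtype appears as an int constant, e.g. prim::Constant[value=6]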

def _convert_data_type(input_type):
    if input_type in ["double", "torch.float64"]:
@@ -1118,6 +1167,8 @@ _convert_map = {
    "aten::pow" : _elemwise("power"),
    "aten::div" : _elemwise("divide"),
    "aten::div_" : _elemwise("divide"),
    "aten::abs" : _abs(),
    "aten::arange" : _arange(),
    "aten::ones" : _ones(),
    "aten::zeros" : _zeros(),
    "aten::reciprocal" : _reciprocal(),
@@ -1167,6 +1218,7 @@ _convert_map = {
    "aten::clone" : _clone(),
    "aten::log_softmax" : _log_softmax(),
    "aten::sigmoid" : _sigmoid(),
    "aten::softplus" : _softplus(),
    "aten::avg_pool2d" : _avg_pool2d(),
    "aten::avg_pool3d" : _avg_pool3d(),
    "aten::dropout" : _dropout(),
@@ -375,6 +375,54 @@ def test_forward_squeeze():
    verify_model(Squeeze1().float().eval(), input_data=input_data)
    verify_model(Squeeze2().float().eval(), input_data=input_data)

def test_forward_arange():
    torch.set_grad_enabled(False)

    class Arange1(Module):
        def forward(self, *args):
            return torch.arange(5)

    class Arange2(Module):
        def forward(self, *args):
            return torch.arange(2.5)

    class Arange3(Module):
        def forward(self, *args):
            return torch.arange(1, 4)

    class Arange4(Module):
        def forward(self, *args):
            return torch.arange(1, 2.5, 0.5)

    class Arange5(Module):
        def forward(self, *args):
            return torch.arange(1, 2, 1, dtype=torch.int32)

    class Arange6(Module):
        def forward(self, *args):
            return torch.arange(start=1, end=6, step=2)

    class Arange7(Module):
        def forward(self, *args):
            return torch.arange(1, 4, dtype=torch.float32)

    class Arange8(Module):
        def forward(self, *args):
            return torch.arange(1, 2, 1, dtype=torch.int16)

    verify_model(Arange1().float().eval())
    verify_model(Arange2().float().eval())
    verify_model(Arange3().float().eval())
    verify_model(Arange4().float().eval())
    verify_model(Arange5().float().eval())
    verify_model(Arange6().float().eval())
    verify_model(Arange7().float().eval())
    verify_model(Arange8().float().eval())

def test_forward_abs():
    torch.set_grad_enabled(False)
    input_shape = [2, 1, 10, 1, 10]

    class Abs1(Module):
        def forward(self, *args):
            return args[0].abs()

    input_data = torch.rand(input_shape).float()
    verify_model(Abs1().float().eval(), input_data=input_data)

def test_forward_concatenate():
    torch.set_grad_enabled(False)
    input_shape = [1, 3, 10, 10]
@@ -445,6 +493,20 @@ def test_forward_selu():
    input_data = torch.rand(input_shape).float()
    verify_model(torch.nn.SELU().eval(), input_data=input_data)

def test_forward_softplus():
    torch.set_grad_enabled(False)
    input_shape = [1, 3, 10, 10]
    input_data = torch.rand(input_shape).float()
    verify_model(torch.nn.Softplus().eval(), input_data=input_data)
    verify_model(torch.nn.Softplus(beta=1.5, threshold=20).eval(), input_data=input_data)
    verify_model(torch.nn.Softplus(beta=5, threshold=10).eval(), input_data=input_data)

def test_forward_softsign():
    torch.set_grad_enabled(False)
    input_shape = [1, 3, 10, 10]
    input_data = torch.rand(input_shape).float()
    verify_model(torch.nn.Softsign().eval(), input_data=input_data)

def test_forward_log_sigmoid():
    torch.set_grad_enabled(False)
    input_shape = [10, 10]
@@ -1254,6 +1316,8 @@ if __name__ == "__main__":
    test_forward_view()
    test_forward_select()
    test_forward_clone()
    test_forward_softplus()
    test_forward_softsign()
    test_forward_logsoftmax()
    test_forward_sigmoid()
    test_forward_dense()
@@ -1264,6 +1328,8 @@ if __name__ == "__main__":
    test_forward_mean()
    test_forward_expand()
    test_forward_pow()
    test_forward_abs()
    test_forward_arange()
    test_forward_chunk()
    test_forward_split()
    test_upsample()