Unverified commit 22db299b, authored by Samuel, committed by GitHub

[PYTORCH] Unary Ops (#5378)

parent c3511c5e
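
This change replaces the PyTorch frontend's single-purpose unary converters (_abs, _sqrt, _rsqrt, _ceil, _floor, _round, _neg, _tanh, _isfinite, _isnan) with one _unary(name) factory and registers nineteen aten:: unary ops through it, nine of them newly supported (cos, sin, tan, atan, log, exp, erf, trunc, sign). Unlike the old converters, _unary also routes its argument through _convert_elemwise_input, so non-tensor inputs get the same promotion the binary element-wise converters apply. The per-op forward tests are folded into a single test_forward_unary.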
@@ -132,12 +132,16 @@ def _elemwise(name):
         return get_relay_op(name)(data0, data1)
     return _impl

-def _abs():
+def _unary(name):
     def _impl(inputs, input_types):
-        data = inputs[0]
-        return _op.abs(data)
+        input_type = input_types[0]
+        data = _convert_elemwise_input(inputs[0], input_type)
+        return get_relay_op(name)(data)
     return _impl

 def _arange():
     def _impl(inputs, input_types):
         if len(inputs) == 5:
@@ -1254,26 +1258,6 @@ def _pad():
         return _op.nn.pad(data, pad_width, pad_value)
     return _impl

-def _sqrt():
-    def _impl(inputs, input_types):
-        data = inputs[0]
-        return _op.tensor.sqrt(data)
-    return _impl
-
-def _rsqrt():
-    def _impl(inputs, input_types):
-        data = inputs[0]
-        return _op.tensor.rsqrt(data)
-    return _impl
-
-def _ceil():
-    def _impl(inputs, input_types):
-        data = inputs[0]
-        return _op.ceil(data)
-    return _impl
-
 def _clamp():
     def _impl(inputs, input_types):
@@ -1284,20 +1268,6 @@ def _clamp():
     return _impl

-def _floor():
-    def _impl(inputs, input_types):
-        data = inputs[0]
-        return _op.floor(data)
-    return _impl
-
-def _round():
-    def _impl(inputs, input_types):
-        data = inputs[0]
-        return _op.round(data)
-    return _impl
-
 def _to():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1375,17 +1345,6 @@ def _expand_as():
         return inputs[0]
     return _impl

-def _neg():
-    def _impl(inputs, input_types):
-        data = inputs[0]
-        return _op.tensor.negative(data)
-    return _impl
-
-def _tanh():
-    def _impl(inputs, input_types):
-        data = inputs[0]
-        return _op.tensor.tanh(data)
-    return _impl
-
 def _Bool():
     def _impl(inputs, input_types):
@@ -1467,18 +1426,6 @@ def _logical_xor():
     return _impl

-def _isfinite():
-    def _impl(inputs, input_types):
-        return _op.isfinite(inputs[0])
-    return _impl
-
-def _isnan():
-    def _impl(inputs, input_types):
-        return _op.isnan(inputs[0])
-    return _impl
-
 def _list_getitem(prelude):
     def _impl(inputs, input_types):
         return prelude.nth(inputs[0], _wrap_const(inputs[1]))
@@ -1601,7 +1548,6 @@ def _get_convert_map(prelude):
         "aten::mul" : _elemwise("multiply"),
         "aten::mul_" : _elemwise("multiply"),
         "aten::pow" : _elemwise("power"),
-        "aten::abs" : _abs(),
         "aten::arange" : _arange(),
         "aten::div" : _elemwise("divide"),
         "aten::div_" : _elemwise("divide"),
@@ -1683,12 +1629,26 @@ def _get_convert_map(prelude):
         "aten::argmax" : _reduce("argmax"),
         "aten::std" : _std(),
         "aten::var" : _variance(),
-        "aten::sqrt" : _sqrt(),
-        "aten::rsqrt" : _rsqrt(),
-        "aten::ceil" : _ceil(),
+        "aten::abs" : _unary("abs"),
+        "aten::neg" : _unary("negative"),
+        "aten::cos" : _unary("cos"),
+        "aten::sin" : _unary("sin"),
+        "aten::tan" : _unary("tan"),
+        "aten::tanh" : _unary("tanh"),
+        "aten::atan" : _unary("atan"),
+        "aten::log" : _unary("log"),
+        "aten::exp" : _unary("exp"),
+        "aten::erf" : _unary("erf"),
+        "aten::trunc" : _unary("trunc"),
+        "aten::sign" : _unary("sign"),
+        "aten::sqrt" : _unary("sqrt"),
+        "aten::rsqrt" : _unary("rsqrt"),
+        "aten::ceil" : _unary("ceil"),
+        "aten::floor" : _unary("floor"),
+        "aten::round" : _unary("round"),
+        "aten::isfinite" : _unary("isfinite"),
+        "aten::isnan" : _unary("isnan"),
         "aten::clamp" : _clamp(),
-        "aten::floor" : _floor(),
-        "aten::round" : _round(),
         "aten::detach" : _identity(),
         "aten::upsample_bilinear2d" : _upsample("bilinear"),
         "aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
@@ -1703,12 +1663,8 @@ def _get_convert_map(prelude):
         "aten::logical_xor" : _logical_xor(),
         "aten::bitwise_not" : _bitwise_not(),
         "aten::bitwise_xor" : _bitwise_xor(),
-        "aten::isfinite" : _isfinite(),
-        "aten::isnan" : _isnan(),
         "aten::Bool" : _Bool(),
         "aten::Float" : _Float(),
-        "aten::neg" : _neg(),
-        "aten::tanh" : _tanh(),
         "aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(),
         "aten::adaptive_max_pool3d" : _adaptive_max_pool_3d(),
         "aten::mm" : _matmul(),
...
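
The hunks above cover the converter module; the hunks below update the forward tests. For reference, a minimal standalone sketch of the factory pattern that _unary introduces follows. The real get_relay_op and _convert_elemwise_input are TVM helpers; the stubs below are hypothetical stand-ins so the closure mechanics can run on their own.

import math

def get_relay_op(name):
    # Stand-in for TVM's operator lookup; the real helper returns a Relay op.
    ops = {"sqrt": math.sqrt, "negative": lambda x: -x}
    return ops[name]

def _convert_elemwise_input(data, input_type):
    # Stand-in: the real helper promotes non-tensor inputs to Relay constants.
    return float(data)

def _unary(name):
    # One factory replaces a dozen near-identical per-op converters: it
    # captures the Relay op name and returns an _impl with the common
    # (inputs, input_types) signature used by every converter in the map.
    def _impl(inputs, input_types):
        data = _convert_elemwise_input(inputs[0], input_types[0])
        return get_relay_op(name)(data)
    return _impl

convert_map = {
    "aten::sqrt" : _unary("sqrt"),
    "aten::neg" : _unary("negative"),
}
print(convert_map["aten::sqrt"]([9.0], ["float"]))  # -> 3.0
print(convert_map["aten::neg"]([2.5], ["float"]))   # -> -2.5

Capturing the op name in a closure keeps every converter on the same (inputs, input_types) calling convention, which is what lets the convert map dispatch unary, element-wise, and reduce converters uniformly.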
@@ -1497,30 +1497,6 @@ def test_forward_isinf():
     verify_model(IsInf1().float().eval(), input_data=input_data)

-def test_forward_rsqrt():
-    torch.set_grad_enabled(False)
-    input_shape = [1, 3, 10, 10]
-
-    class Rsqrt1(Module):
-        def forward(self, *args):
-            return torch.rsqrt(args[0])
-
-    input_data = torch.rand(input_shape).float()
-    verify_model(Rsqrt1().float().eval(), input_data=input_data)
-
-def test_forward_ceil():
-    torch.set_grad_enabled(False)
-    input_shape = [1, 3, 10, 10]
-
-    class Ceil1(Module):
-        def forward(self, *args):
-            return torch.ceil(args[0])
-
-    input_data = torch.rand(input_shape).float()
-    verify_model(Ceil1().float().eval(), input_data=input_data)
-
 def test_forward_clamp():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
@@ -1543,30 +1519,6 @@ def test_forward_clamp():
     verify_model(Clamp3().float().eval(), input_data=input_data)

-def test_forward_floor():
-    torch.set_grad_enabled(False)
-    input_shape = [1, 3, 10, 10]
-
-    class Floor1(Module):
-        def forward(self, *args):
-            return torch.floor(args[0])
-
-    input_data = torch.rand(input_shape).float()
-    verify_model(Floor1().float().eval(), input_data=input_data)
-
-def test_forward_round():
-    torch.set_grad_enabled(False)
-    input_shape = [1, 3, 10, 10]
-
-    class Round1(Module):
-        def forward(self, *args):
-            return torch.round(args[0])
-
-    input_data = torch.rand(input_shape).float()
-    verify_model(Round1().float().eval(), input_data=input_data)
-
 def test_forward_ones():
     torch.set_grad_enabled(False)
@@ -1849,6 +1801,93 @@ def test_forward_logical_xor():
     verify_model(LogicalXor2().float().eval(), input_data=[lhs])

+def test_forward_unary():
+    torch.set_grad_enabled(False)
+
+    class Sqrt1(Module):
+        def forward(self, *args):
+            return torch.sqrt(args[0])
+
+    class RSqrt1(Module):
+        def forward(self, *args):
+            return torch.rsqrt(args[0])
+
+    class Ceil1(Module):
+        def forward(self, *args):
+            return torch.ceil(args[0])
+
+    class Floor1(Module):
+        def forward(self, *args):
+            return torch.floor(args[0])
+
+    class Round1(Module):
+        def forward(self, *args):
+            return torch.round(args[0])
+
+    class Cos1(Module):
+        def forward(self, *args):
+            return torch.cos(args[0])
+
+    class Sin1(Module):
+        def forward(self, *args):
+            return torch.sin(args[0])
+
+    class Tan1(Module):
+        def forward(self, *args):
+            return torch.tan(args[0])
+
+    class Tanh1(Module):
+        def forward(self, *args):
+            return torch.tanh(args[0])
+
+    class ATanh1(Module):
+        def forward(self, *args):
+            return torch.atan(args[0])
+
+    class Log1(Module):
+        def forward(self, *args):
+            return torch.log(args[0])
+
+    class Exp1(Module):
+        def forward(self, *args):
+            return torch.exp(args[0])
+
+    class Erf1(Module):
+        def forward(self, *args):
+            return torch.erf(args[0])
+
+    class Trunc1(Module):
+        def forward(self, *args):
+            return torch.trunc(args[0])
+
+    class Sign1(Module):
+        def forward(self, *args):
+            return torch.sign(args[0])
+
+    class Neg1(Module):
+        def forward(self, *args):
+            return torch.neg(args[0])
+
+    input_shape = [1, 3, 10, 10]
+    input_data = torch.rand(input_shape).float()
+    verify_model(Sqrt1().float().eval(), input_data=input_data)
+    verify_model(RSqrt1().float().eval(), input_data=input_data)
+    verify_model(Ceil1().float().eval(), input_data=input_data)
+    verify_model(Floor1().float().eval(), input_data=input_data)
+    verify_model(Round1().float().eval(), input_data=input_data)
+    verify_model(Cos1().float().eval(), input_data=input_data)
+    verify_model(Sin1().float().eval(), input_data=input_data)
+    verify_model(Tan1().float().eval(), input_data=input_data)
+    verify_model(Tanh1().float().eval(), input_data=input_data)
+    verify_model(ATanh1().float().eval(), input_data=input_data)
+    verify_model(Log1().float().eval(), input_data=input_data)
+    verify_model(Exp1().float().eval(), input_data=input_data)
+    verify_model(Erf1().float().eval(), input_data=input_data)
+    verify_model(Trunc1().float().eval(), input_data=input_data)
+    verify_model(Sign1().float().eval(), input_data=input_data)
+    verify_model(Neg1().float().eval(), input_data=input_data)
+
 if __name__ == "__main__":
     # Single operator tests
     test_forward_add()
@@ -1907,12 +1946,8 @@ if __name__ == "__main__":
     test_forward_mean()
     test_forward_expand()
     test_forward_pow()
-    test_forward_abs()
-    test_forward_rsqrt()
-    test_forward_ceil()
+    test_forward_unary()
     test_forward_clamp()
-    test_forward_floor()
-    test_forward_round()
     test_forward_logical_not()
     test_forward_bitwise_not()
     test_forward_bitwise_xor()
...
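
A possible follow-up, not part of this commit: the sixteen one-op modules in test_forward_unary could be generated from the torch functions themselves. A sketch, assuming verify_model as defined earlier in this test file:

import torch
from torch.nn import Module

def make_unary_module(fn):
    # Build a one-op Module around any unary torch function.
    class Unary(Module):
        def forward(self, *args):
            return fn(args[0])
    return Unary()

unary_fns = [torch.sqrt, torch.rsqrt, torch.ceil, torch.floor, torch.round,
             torch.cos, torch.sin, torch.tan, torch.tanh, torch.atan,
             torch.log, torch.exp, torch.erf, torch.trunc, torch.sign,
             torch.neg]

def test_forward_unary_compact():
    torch.set_grad_enabled(False)
    input_data = torch.rand([1, 3, 10, 10]).float()
    for fn in unary_fns:
        # verify_model is the PyTorch-vs-TVM comparison helper used
        # throughout this test file.
        verify_model(make_unary_module(fn).float().eval(),
                     input_data=input_data)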