Commit 674feba0 by Jon Soifer Committed by Thierry Moreau

[Relay][Frontend][ONNX] Add Sign and Equal operators to ONNX frontend (#3760)

* [Relay][Frontend][ONNX] Add Sign and Equal operators to ONNX frontend

* Dummy change to retrigger integration test
parent 7eb1f353
...@@ -850,6 +850,18 @@ class ConstantFill(OnnxOpConverter): ...@@ -850,6 +850,18 @@ class ConstantFill(OnnxOpConverter):
shape = shape + attr.pop('extra_shape') shape = shape + attr.pop('extra_shape')
return _op.full(inputs[0], shape) return _op.full(inputs[0], shape)
class Sign(OnnxOpConverter):
""" Operator converter for Sign.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
return _op.sign(inputs[0])
class Equal(Elemwise):
""" Operator converter for Equal.
"""
name = 'equal'
# compatible operators that do NOT require any conversion. # compatible operators that do NOT require any conversion.
_identity_list = [] _identity_list = []
...@@ -964,6 +976,8 @@ def _get_convert_map(opset): ...@@ -964,6 +976,8 @@ def _get_convert_map(opset):
'Unsqueeze': Unsqueeze.get_converter(opset), 'Unsqueeze': Unsqueeze.get_converter(opset),
'Pad': Pad.get_converter(opset), 'Pad': Pad.get_converter(opset),
'Shape': Shape.get_converter(opset), 'Shape': Shape.get_converter(opset),
'Sign': Sign.get_converter(opset),
'Equal': Equal.get_converter(opset)
} }
......
...@@ -962,6 +962,7 @@ def test_binary_ops(): ...@@ -962,6 +962,7 @@ def test_binary_ops():
verify_binary_ops("Sum", x, y, x + y, broadcast=None) verify_binary_ops("Sum", x, y, x + y, broadcast=None)
verify_binary_ops("Greater", x, y, x > y, broadcast=True) verify_binary_ops("Greater", x, y, x > y, broadcast=True)
verify_binary_ops("Less", x, y, x < y, broadcast=True) verify_binary_ops("Less", x, y, x < y, broadcast=True)
verify_binary_ops("Equal", x, y, x == y, broadcast=True)
def test_single_ops(): def test_single_ops():
in_shape = (1, 2, 3, 3) in_shape = (1, 2, 3, 3)
...@@ -1116,6 +1117,15 @@ def test_inception(): ...@@ -1116,6 +1117,15 @@ def test_inception():
# def test_shufflenetv2(): # def test_shufflenetv2():
# check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224)) # check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224))
def test_sign():
def Sign_x(x):
return np.sign(x)
_test_onnx_op_elementwise((3, 4, 5, 6),
Sign_x,
{},
'float32',
'Sign',
{})
if __name__ == '__main__': if __name__ == '__main__':
test_flatten() test_flatten()
...@@ -1159,3 +1169,4 @@ if __name__ == '__main__': ...@@ -1159,3 +1169,4 @@ if __name__ == '__main__':
test_resnet() test_resnet()
test_inception() test_inception()
test_densenet() test_densenet()
test_sign()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment