Commit 73dc5ac3 by Neo Chien Committed by Jared Roesch

Add not operator for the frontend/onnx.py (#3836)

parent 9d880bd3
...@@ -868,6 +868,15 @@ class Equal(Elemwise): ...@@ -868,6 +868,15 @@ class Equal(Elemwise):
""" """
name = 'equal' name = 'equal'
class Not(Elemwise):
""" Operator converter for Not.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
return _op.logical_not(inputs[0])
# compatible operators that do NOT require any conversion. # compatible operators that do NOT require any conversion.
_identity_list = [] _identity_list = []
...@@ -983,7 +992,8 @@ def _get_convert_map(opset): ...@@ -983,7 +992,8 @@ def _get_convert_map(opset):
'Pad': Pad.get_converter(opset), 'Pad': Pad.get_converter(opset),
'Shape': Shape.get_converter(opset), 'Shape': Shape.get_converter(opset),
'Sign': Sign.get_converter(opset), 'Sign': Sign.get_converter(opset),
'Equal': Equal.get_converter(opset) 'Equal': Equal.get_converter(opset),
'Not': Not.get_converter(opset)
} }
......
...@@ -1130,6 +1130,34 @@ def test_sign(): ...@@ -1130,6 +1130,34 @@ def test_sign():
'Sign', 'Sign',
{}) {})
def verify_not(indata, dtype):
x = indata.astype(dtype)
outdata = np.logical_not(x)
node = helper.make_node('Not', inputs=['in'], outputs=['out'],)
graph = helper.make_graph([node],
'not_test',
inputs=[helper.make_tensor_value_info("in", TensorProto.BOOL, list(x.shape))],
outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))])
model = helper.make_model(graph, producer_name='not_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [x], target, ctx, outdata.shape)
tvm.testing.assert_allclose(outdata, tvm_out)
def test_not():
# 2d
verify_not(indata=(np.random.randn(3, 4) > 0), dtype=bool)
# 3d
verify_not(indata=(np.random.randn(3, 4, 5) > 0), dtype=bool)
# 4d
verify_not(indata=(np.random.randn(3, 4, 5, 6) > 0), dtype=bool)
if __name__ == '__main__': if __name__ == '__main__':
test_flatten() test_flatten()
test_reshape() test_reshape()
...@@ -1173,3 +1201,4 @@ if __name__ == '__main__': ...@@ -1173,3 +1201,4 @@ if __name__ == '__main__':
test_inception() test_inception()
test_densenet() test_densenet()
test_sign() test_sign()
test_not()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment