Commit 9e595b42 by Neo Chien Committed by Zhi

ONNX frontend operator support: And (#3878)

parent 224cc243
......@@ -877,6 +877,14 @@ class Not(Elemwise):
return _op.logical_not(inputs[0])
class And(Elemwise):
""" Operator converter for And.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
return _op.logical_and(inputs[0], inputs[1])
# compatible operators that do NOT require any conversion.
_identity_list = []
......@@ -993,7 +1001,8 @@ def _get_convert_map(opset):
'Shape': Shape.get_converter(opset),
'Sign': Sign.get_converter(opset),
'Equal': Equal.get_converter(opset),
'Not': Not.get_converter(opset)
'Not': Not.get_converter(opset),
'And': And.get_converter(opset)
}
......
......@@ -1158,6 +1158,53 @@ def test_not():
verify_not(indata=(np.random.randn(3, 4, 5, 6) > 0), dtype=bool)
def verify_and(indata, dtype):
x = indata[0].astype(dtype)
y = indata[1].astype(dtype)
outdata = np.logical_and(x, y)
node = helper.make_node('And', inputs=['in1', 'in2'], outputs=['out'], )
graph = helper.make_graph([node],
'and_test',
inputs=[helper.make_tensor_value_info("in1", TensorProto.BOOL, list(x.shape)),
helper.make_tensor_value_info("in2", TensorProto.BOOL, list(y.shape))],
outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))])
model = helper.make_model(graph, producer_name='and_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [x, y], target, ctx, outdata.shape)
tvm.testing.assert_allclose(outdata, tvm_out)
def test_and():
# 2d
x = (np.random.randn(3, 4) > 0)
y = (np.random.randn(3, 4) > 0)
verify_and(indata=[x, y], dtype=bool)
# 3d
x = (np.random.randn(3, 4, 5) > 0)
y = (np.random.randn(3, 4, 5) > 0)
verify_and(indata=[x, y], dtype=bool)
# 4d
x = (np.random.randn(3, 4, 5, 6) > 0)
y = (np.random.randn(3, 4, 5, 6) > 0)
verify_and(indata=[x, y], dtype=bool)
# 3d vs 1d
x = (np.random.randn(3, 4, 5) > 0)
y = (np.random.randn(5) > 0)
verify_and(indata=[x, y], dtype=bool)
# 3d vs 2d
x = (np.random.randn(3, 4, 5) > 0)
y = (np.random.randn(4, 5) > 0)
verify_and(indata=[x, y], dtype=bool)
if __name__ == '__main__':
test_flatten()
test_reshape()
......@@ -1202,3 +1249,4 @@ if __name__ == '__main__':
test_densenet()
test_sign()
test_not()
test_and()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment