Commit ab858e3f by Wei Chen Committed by Tianqi Chen

[PYTHON] Improve equality wrapper (#567)

use `object.__eq__`(default object identity comparison) as default
implementation of same_as. This should be OK since `EqualOp` and
`NotEqualOp` are pure Python object, `object.__eq__` is sufficient.
parent 9a2f01ab
...@@ -152,6 +152,9 @@ class EqualOp(NodeGeneric, ExprOp): ...@@ -152,6 +152,9 @@ class EqualOp(NodeGeneric, ExprOp):
b : Expr b : Expr
Right operand. Right operand.
""" """
# This class is not manipulated by C++. So use python's identity check function is sufficient
same_as = object.__eq__
def __init__(self, a, b): def __init__(self, a, b):
self.a = a self.a = a
self.b = b self.b = b
...@@ -181,6 +184,9 @@ class NotEqualOp(NodeGeneric, ExprOp): ...@@ -181,6 +184,9 @@ class NotEqualOp(NodeGeneric, ExprOp):
b : Expr b : Expr
Right operand. Right operand.
""" """
# This class is not manipulated by C++. So use python's identity check function is sufficient
same_as = object.__eq__
def __init__(self, a, b): def __init__(self, a, b):
self.a = a self.a = a
self.b = b self.b = b
......
...@@ -134,6 +134,14 @@ def test_bitwise(): ...@@ -134,6 +134,14 @@ def test_bitwise():
assert str(~x) == 'bitwise_not(x)' assert str(~x) == 'bitwise_not(x)'
def test_equality():
a = tvm.var('a')
b = tvm.var('b')
c = (a == b)
assert not c
d = (c != c)
assert not d
if __name__ == "__main__": if __name__ == "__main__":
test_cast() test_cast()
test_attr() test_attr()
...@@ -148,3 +156,4 @@ if __name__ == "__main__": ...@@ -148,3 +156,4 @@ if __name__ == "__main__":
test_any() test_any()
test_all() test_all()
test_bitwise() test_bitwise()
test_equality()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment