Commit fde9b570 by Wei Chen Committed by Tianqi Chen

Add same_as to NodeBase (#550)

* Add same_as to NodeBase

1. Most class inherited from NodeBase(Schedule, Stage, etc) still have
the convenience of using '==' for object identity. And this is the right
behavior for non-Expr classes.
2. subclasses of ExprOp now create EQ expression when '==' is used.

`__nonzero__` and `__bool__` in EQ and NE is a comprise that in some cases
object identity semantics is still useful, like in unit test. For instance:
````
assert a == b
````

"a == b" will create EQ expression, assert then calls `__nonzero__` of the
result expression. `Expr.__nonzero__` throws exception since it prohibits
evaluating IR expression.

More complex case like:
````
assert a in b # b is dict
````

it will call `__eq__` on a and all keys of b, then `__bool__` on the result
expression. This could not easily be done by same_as.

* Retain __hash__ from NodeBase in Python3
parent ed783689
......@@ -40,9 +40,7 @@ class NodeBase(_NodeBase):
return _api_internal._raw_ptr(self)
def __eq__(self, other):
if not isinstance(other, NodeBase):
return False
return self.__hash__() == other.__hash__()
return self.same_as(other)
def __ne__(self, other):
return not self.__eq__(other)
......@@ -67,6 +65,12 @@ class NodeBase(_NodeBase):
else:
self.handle = None
def same_as(self, other):
"""check object identity equality"""
if not isinstance(other, NodeBase):
return False
return self.__hash__() == other.__hash__()
def register_node(type_key=None):
"""register node type
......
......@@ -138,9 +138,11 @@ class ExprOp(object):
return _make.static_cast(dtype, self)
class Expr(NodeBase, ExprOp):
class Expr(ExprOp, NodeBase):
"""Base class of all tvm Expressions"""
pass
# In Python3, We have to explicity tell interpreter to retain __hash__ if we overide __eq__
# https://docs.python.org/3.1/reference/datamodel.html#object.__hash__
__hash__ = NodeBase.__hash__
class ConstExpr(Expr):
pass
......@@ -213,11 +215,19 @@ class Max(BinaryOpExpr):
@register_node
class EQ(CmpExpr):
pass
def __nonzero__(self):
return self.a.same_as(self.b)
def __bool__(self):
return self.__nonzero__()
@register_node
class NE(CmpExpr):
pass
def __nonzero__(self):
return not self.a.same_as(self.b)
def __bool__(self):
return self.__nonzero__()
@register_node
class LT(CmpExpr):
......
......@@ -90,7 +90,7 @@ class IRBuilder(object):
n = tvm.var("n")
A = ib.allocate("float32", n, name="A")
with ib.for_range(0, n, name="i") as i:
with ib.if_scope((i % 2).equal(0)):
with ib.if_scope((i % 2) == 0):
A[i] = A[i] + 1
# The result stmt.
stmt = ib.get()
......
......@@ -25,7 +25,7 @@ def test_if():
n = tvm.var("n")
A = ib.pointer("float32", name="A")
with ib.for_range(0, n, name="i") as i:
with ib.if_scope((i % 2).equal(0)):
with ib.if_scope((i % 2) == 0):
A[i] = A[i] + 1
with ib.else_scope():
A[0] = A[i] + 2
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment