Commit fde9b570 by Wei Chen Committed by Tianqi Chen

Add same_as to NodeBase (#550)

* Add same_as to NodeBase

1. Most class inherited from NodeBase(Schedule, Stage, etc) still have
the convenience of using '==' for object identity. And this is the right
behavior for non-Expr classes.
2. subclasses of ExprOp now create EQ expression when '==' is used.

`__nonzero__` and `__bool__` in EQ and NE is a comprise that in some cases
object identity semantics is still useful, like in unit test. For instance:
````
assert a == b
````

"a == b" will create EQ expression, assert then calls `__nonzero__` of the
result expression. `Expr.__nonzero__` throws exception since it prohibits
evaluating IR expression.

More complex case like:
````
assert a in b # b is dict
````

it will call `__eq__` on a and all keys of b, then `__bool__` on the result
expression. This could not easily be done by same_as.

* Retain __hash__ from NodeBase in Python3
parent ed783689
...@@ -40,9 +40,7 @@ class NodeBase(_NodeBase): ...@@ -40,9 +40,7 @@ class NodeBase(_NodeBase):
return _api_internal._raw_ptr(self) return _api_internal._raw_ptr(self)
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, NodeBase): return self.same_as(other)
return False
return self.__hash__() == other.__hash__()
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
...@@ -67,6 +65,12 @@ class NodeBase(_NodeBase): ...@@ -67,6 +65,12 @@ class NodeBase(_NodeBase):
else: else:
self.handle = None self.handle = None
def same_as(self, other):
"""check object identity equality"""
if not isinstance(other, NodeBase):
return False
return self.__hash__() == other.__hash__()
def register_node(type_key=None): def register_node(type_key=None):
"""register node type """register node type
......
...@@ -138,9 +138,11 @@ class ExprOp(object): ...@@ -138,9 +138,11 @@ class ExprOp(object):
return _make.static_cast(dtype, self) return _make.static_cast(dtype, self)
class Expr(NodeBase, ExprOp): class Expr(ExprOp, NodeBase):
"""Base class of all tvm Expressions""" """Base class of all tvm Expressions"""
pass # In Python3, We have to explicity tell interpreter to retain __hash__ if we overide __eq__
# https://docs.python.org/3.1/reference/datamodel.html#object.__hash__
__hash__ = NodeBase.__hash__
class ConstExpr(Expr): class ConstExpr(Expr):
pass pass
...@@ -213,11 +215,19 @@ class Max(BinaryOpExpr): ...@@ -213,11 +215,19 @@ class Max(BinaryOpExpr):
@register_node @register_node
class EQ(CmpExpr): class EQ(CmpExpr):
pass def __nonzero__(self):
return self.a.same_as(self.b)
def __bool__(self):
return self.__nonzero__()
@register_node @register_node
class NE(CmpExpr): class NE(CmpExpr):
pass def __nonzero__(self):
return not self.a.same_as(self.b)
def __bool__(self):
return self.__nonzero__()
@register_node @register_node
class LT(CmpExpr): class LT(CmpExpr):
......
...@@ -90,7 +90,7 @@ class IRBuilder(object): ...@@ -90,7 +90,7 @@ class IRBuilder(object):
n = tvm.var("n") n = tvm.var("n")
A = ib.allocate("float32", n, name="A") A = ib.allocate("float32", n, name="A")
with ib.for_range(0, n, name="i") as i: with ib.for_range(0, n, name="i") as i:
with ib.if_scope((i % 2).equal(0)): with ib.if_scope((i % 2) == 0):
A[i] = A[i] + 1 A[i] = A[i] + 1
# The result stmt. # The result stmt.
stmt = ib.get() stmt = ib.get()
......
...@@ -25,7 +25,7 @@ def test_if(): ...@@ -25,7 +25,7 @@ def test_if():
n = tvm.var("n") n = tvm.var("n")
A = ib.pointer("float32", name="A") A = ib.pointer("float32", name="A")
with ib.for_range(0, n, name="i") as i: with ib.for_range(0, n, name="i") as i:
with ib.if_scope((i % 2).equal(0)): with ib.if_scope((i % 2) == 0):
A[i] = A[i] + 1 A[i] = A[i] + 1
with ib.else_scope(): with ib.else_scope():
A[0] = A[i] + 2 A[0] = A[i] + 2
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment