Commit 9a2f01ab by Tianqi Chen Committed by GitHub

[PYTHON] Improve equal sugar (#564)

* [PYTHON] Improve equal sugar

* fix comment
parent 60510a47
...@@ -16,7 +16,7 @@ For example, you can use addexp.a to get the left operand of an Add node. ...@@ -16,7 +16,7 @@ For example, you can use addexp.a to get the left operand of an Add node.
""" """
# pylint: disable=missing-docstring # pylint: disable=missing-docstring
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node from ._ffi.node import NodeBase, NodeGeneric, register_node
from . import make as _make from . import make as _make
from . import _api_internal from . import _api_internal
...@@ -89,10 +89,10 @@ class ExprOp(object): ...@@ -89,10 +89,10 @@ class ExprOp(object):
return _make.LE(self, other) return _make.LE(self, other)
def __eq__(self, other): def __eq__(self, other):
return self.equal(other) return EqualOp(self, other)
def __ne__(self, other): def __ne__(self, other):
return _make.NE(self, other) return NotEqualOp(self, other)
def __gt__(self, other): def __gt__(self, other):
return _make.GT(self, other) return _make.GT(self, other)
...@@ -138,12 +138,71 @@ class ExprOp(object): ...@@ -138,12 +138,71 @@ class ExprOp(object):
return _make.static_cast(dtype, self) return _make.static_cast(dtype, self)
class EqualOp(NodeGeneric, ExprOp):
"""Deferred equal operator.
This is used to support sugar that a == b can either
mean NodeBase.same_as or NodeBase.equal.
Parameters
----------
a : Expr
Left operand.
b : Expr
Right operand.
"""
def __init__(self, a, b):
self.a = a
self.b = b
def __nonzero__(self):
return self.a.same_as(self.b)
def __bool__(self):
return self.__nonzero__()
def asnode(self):
"""Convert node."""
return _make.EQ(self.a, self.b)
class NotEqualOp(NodeGeneric, ExprOp):
"""Deferred NE operator.
This is used to support sugar that a != b can either
mean not NodeBase.same_as or make.NE.
Parameters
----------
a : Expr
Left operand.
b : Expr
Right operand.
"""
def __init__(self, a, b):
self.a = a
self.b = b
def __nonzero__(self):
return not self.a.same_as(self.b)
def __bool__(self):
return self.__nonzero__()
def asnode(self):
"""Convert node."""
return _make.NE(self.a, self.b)
class Expr(ExprOp, NodeBase): class Expr(ExprOp, NodeBase):
"""Base class of all tvm Expressions""" """Base class of all tvm Expressions"""
# In Python3, We have to explicity tell interpreter to retain __hash__ if we overide __eq__ # In Python3, We have to explicity tell interpreter to retain __hash__ if we overide __eq__
# https://docs.python.org/3.1/reference/datamodel.html#object.__hash__ # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__
__hash__ = NodeBase.__hash__ __hash__ = NodeBase.__hash__
class ConstExpr(Expr): class ConstExpr(Expr):
pass pass
...@@ -215,19 +274,11 @@ class Max(BinaryOpExpr): ...@@ -215,19 +274,11 @@ class Max(BinaryOpExpr):
@register_node @register_node
class EQ(CmpExpr): class EQ(CmpExpr):
def __nonzero__(self): pass
return self.a.same_as(self.b)
def __bool__(self):
return self.__nonzero__()
@register_node @register_node
class NE(CmpExpr): class NE(CmpExpr):
def __nonzero__(self): pass
return not self.a.same_as(self.b)
def __bool__(self):
return self.__nonzero__()
@register_node @register_node
class LT(CmpExpr): class LT(CmpExpr):
......
...@@ -31,6 +31,7 @@ def test_if(): ...@@ -31,6 +31,7 @@ def test_if():
A[0] = A[i] + 2 A[0] = A[i] + 2
body = ib.get() body = ib.get()
assert A == A
assert isinstance(body, tvm.stmt.For) assert isinstance(body, tvm.stmt.For)
body = body.body body = body.body
assert isinstance(body, tvm.stmt.IfThenElse) assert isinstance(body, tvm.stmt.IfThenElse)
...@@ -42,6 +43,7 @@ def test_prefetch(): ...@@ -42,6 +43,7 @@ def test_prefetch():
A = tvm.placeholder((10, 20), name="A") A = tvm.placeholder((10, 20), name="A")
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
n = tvm.var("n") n = tvm.var("n")
with ib.for_range(0, n, name="i") as i: with ib.for_range(0, n, name="i") as i:
ib.emit( ib.emit(
tvm.make.Prefetch( tvm.make.Prefetch(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment