Commit 9a2f01ab by Tianqi Chen Committed by GitHub

[PYTHON] Improve equal sugar (#564)

* [PYTHON] Improve equal sugar

* fix comment
parent 60510a47
......@@ -16,7 +16,7 @@ For example, you can use addexp.a to get the left operand of an Add node.
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node
from ._ffi.node import NodeBase, NodeGeneric, register_node
from . import make as _make
from . import _api_internal
......@@ -89,10 +89,10 @@ class ExprOp(object):
return _make.LE(self, other)
def __eq__(self, other):
return self.equal(other)
return EqualOp(self, other)
def __ne__(self, other):
return _make.NE(self, other)
return NotEqualOp(self, other)
def __gt__(self, other):
return _make.GT(self, other)
......@@ -138,12 +138,71 @@ class ExprOp(object):
return _make.static_cast(dtype, self)
class EqualOp(NodeGeneric, ExprOp):
"""Deferred equal operator.
This is used to support sugar that a == b can either
mean NodeBase.same_as or NodeBase.equal.
Parameters
----------
a : Expr
Left operand.
b : Expr
Right operand.
"""
def __init__(self, a, b):
self.a = a
self.b = b
def __nonzero__(self):
return self.a.same_as(self.b)
def __bool__(self):
return self.__nonzero__()
def asnode(self):
"""Convert node."""
return _make.EQ(self.a, self.b)
class NotEqualOp(NodeGeneric, ExprOp):
"""Deferred NE operator.
This is used to support sugar that a != b can either
mean not NodeBase.same_as or make.NE.
Parameters
----------
a : Expr
Left operand.
b : Expr
Right operand.
"""
def __init__(self, a, b):
self.a = a
self.b = b
def __nonzero__(self):
return not self.a.same_as(self.b)
def __bool__(self):
return self.__nonzero__()
def asnode(self):
"""Convert node."""
return _make.NE(self.a, self.b)
class Expr(ExprOp, NodeBase):
"""Base class of all tvm Expressions"""
# In Python3, We have to explicity tell interpreter to retain __hash__ if we overide __eq__
# https://docs.python.org/3.1/reference/datamodel.html#object.__hash__
__hash__ = NodeBase.__hash__
class ConstExpr(Expr):
pass
......@@ -215,19 +274,11 @@ class Max(BinaryOpExpr):
@register_node
class EQ(CmpExpr):
def __nonzero__(self):
return self.a.same_as(self.b)
def __bool__(self):
return self.__nonzero__()
pass
@register_node
class NE(CmpExpr):
def __nonzero__(self):
return not self.a.same_as(self.b)
def __bool__(self):
return self.__nonzero__()
pass
@register_node
class LT(CmpExpr):
......
......@@ -31,6 +31,7 @@ def test_if():
A[0] = A[i] + 2
body = ib.get()
assert A == A
assert isinstance(body, tvm.stmt.For)
body = body.body
assert isinstance(body, tvm.stmt.IfThenElse)
......@@ -42,6 +43,7 @@ def test_prefetch():
A = tvm.placeholder((10, 20), name="A")
ib = tvm.ir_builder.create()
n = tvm.var("n")
with ib.for_range(0, n, name="i") as i:
ib.emit(
tvm.make.Prefetch(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment