Commit 69aefaa3 by Haichen Shen Committed by Tianqi Chen

[LANG] Add all and any in the python API (#196)

* [LANG] Add all and any in the python API

* compatible with python3
parent 7cc92ace
......@@ -114,6 +114,53 @@ def var(name="tindex", dtype=int32):
return _api_internal._Var(name, dtype)
def any(*args):
"""Create a new experssion of the union of all conditions in the arguments
Parameters
----------
args : list
List of symbolic boolean expressions
Returns
-------
expr: Expr
Expression
"""
if not args:
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = _make.Or(args[0], args[1])
for i in range(2, len(args)):
ret = _make.Or(ret, args[i])
return ret
def all(*args):
"""Create a new experssion of the intersection of all conditions in the
arguments
Parameters
----------
args : list
List of symbolic boolean expressions
Returns
-------
expr: Expr
Expression
"""
if not args:
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = _make.And(args[0], args[1])
for i in range(2, len(args)):
ret = _make.And(ret, args[i])
return ret
def placeholder(shape, dtype=None, name="placeholder"):
"""Construct an empty tensor object.
......
......@@ -74,6 +74,13 @@ class ExprOp(object):
def __ge__(self, other):
return _make.GE(self, other)
def __nonzero__(self):
raise ValueError("Cannot use and / or / not operator to Expr, hint: " +
"use tvm.all / tvm.any instead")
def __bool__(self):
return self.__nonzero__()
def equal(self, other):
"""Build an equal check expression with other expr.
......
......@@ -50,7 +50,7 @@ def test_check():
assert res1.is_nothing()
# multiple compare operators
res2 = tvm.arith.DeduceBound(a, a+b>3>c , {b: b_s, c: c_s}, {})
res2 = tvm.arith.DeduceBound(a, (a+b>3)>c , {b: b_s, c: c_s}, {})
assert res1.is_nothing()
# multiple target variable
......
......@@ -81,6 +81,51 @@ def test_dtype():
y = tvm.var('y')
assert (x > y).dtype == 'uint1'
def test_any():
x = tvm.var('x')
y = tvm.var('y')
z = tvm.var('z')
try:
t = x or x
assert False
except ValueError:
pass
try:
tvm.any()
assert False
except ValueError:
pass
assert str(tvm.any(x < y)) == '(%s < %s)' % (x.name, y.name)
assert str(tvm.any(x < y, x > z)) == '((%s < %s) || (%s > %s))' % (
x.name, y.name, x.name, z.name)
assert str(tvm.any(x < y, y > z + 1, x < z * 2)) == \
'(((%s < %s) || (%s > (%s + 1))) || (%s < (%s*2)))' % (
x.name, y.name, y.name, z.name, x.name, z.name)
def test_all():
x = tvm.var('x')
y = tvm.var('y')
z = tvm.var('z')
try:
t = x and x
assert False
except ValueError:
pass
try:
tvm.all()
assert False
except ValueError:
pass
assert str(tvm.all(x < y)) == '(%s < %s)' % (x.name, y.name)
assert str(tvm.all(x < y, x > z)) == '((%s < %s) && (%s > %s))' % (
x.name, y.name, x.name, z.name)
assert str(tvm.all(x < y, y > z + 1, x < z * 2)) == \
'(((%s < %s) && (%s > (%s + 1))) && (%s < (%s*2)))' % (
x.name, y.name, y.name, z.name, x.name, z.name)
if __name__ == "__main__":
test_attr()
test_const()
......@@ -92,3 +137,5 @@ if __name__ == "__main__":
test_let()
test_dir()
test_dtype()
test_any()
test_all()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment