Commit d8bd4762 by Jian Weng Committed by Tianqi Chen

[Hybrid Script] Support logical and/or; support 0 < a < 5 clause (#2264)

parent cb70da1b
......@@ -8,6 +8,8 @@ from .util import make_nop, halide_imm_types, is_docstring, _internal_assert
from .intrin import LOOP_INTRIN, MATH_INTRIN
from .var_decl import determine_variable_usage
from ..api import thread_axis
from ..api import all as _all
from ..api import any as _any
from .. import expr as _expr
from .. import make as _make
from .. import intrin
......@@ -47,6 +49,8 @@ class HybridParser(ast.NodeVisitor):
ast.LtE : operator.le,
ast.Eq : operator.eq,
ast.NotEq : operator.ne,
ast.And : _all,
ast.Or : _any,
}
......@@ -282,11 +286,31 @@ class HybridParser(ast.NodeVisitor):
def visit_Compare(self, node):
lhs = self.visit(node.left)
_internal_assert(len(node.ops) == 1, "Only one compare op is supported!")
_internal_assert(len(node.comparators) == 1, "Only one comparator is supported!")
rhs = self.visit(node.comparators[0])
return HybridParser._binop_maker[type(node.ops[0])](lhs, rhs)
_internal_assert(len(node.ops) == len(node.comparators),
"#compare ops != #comparators")
ops = [self.visit(node.left)]
ops += [self.visit(i) for i in node.comparators]
res = []
for i in range(len(node.ops)):
lhs = ops[i]
rhs = ops[i + 1]
res.append(HybridParser._binop_maker[type(node.ops[i])](lhs, rhs))
return _all(*res)
def visit_BoolOp(self, node):
n = len(node.values)
if n == 1:
_internal_assert(isinstance(node.op, ast.Not), \
"Unary is supposed to be not!")
return operator.not_(self.visit(node.values[0]))
elif n == 2:
_internal_assert(isinstance(node.op, (ast.And, ast.Or)), \
"Binary is supposed to be and/or!")
values = [self.visit(i) for i in node.values]
return HybridParser._binop_maker[type(node.op)](*values)
else:
raise ValueError("This Bool Op is not supported yet!")
def visit_UnaryOp(self, node):
......
......@@ -237,6 +237,30 @@ def test_if():
run_and_check(if_then_else, [a])
@script
def if_triple_condition(a):
b = output_tensor((10, ), 'int32')
for i in range(10):
if 0 <= i < 5:
b[i] = a[i]
else:
b[i] = a[i] + 1
return b
run_and_check(if_triple_condition, [a])
@script
def if_and(a):
b = output_tensor((10, ), 'int32')
for i in range(10):
if i >= 0 and i < 5:
b[i] = a[i]
else:
b[i] = a[i] + 1
return b
run_and_check(if_and, [a])
def test_bind():
if not tvm.gpu(0).exist:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment