Commit f01cc0e6 by Sergei Grechanik Committed by Tianqi Chen

[TVM] Eagerer const folding for logic ops (#1907)

parent b4946e77
...@@ -310,20 +310,26 @@ Expr operator!=(Expr a, Expr b) { ...@@ -310,20 +310,26 @@ Expr operator!=(Expr a, Expr b) {
Expr operator&&(Expr a, Expr b) { Expr operator&&(Expr a, Expr b) {
using ir::UIntImm; using ir::UIntImm;
if (a.type().is_bool() && b.type().is_bool()) {
const UIntImm* pa = a.as<UIntImm>(); const UIntImm* pa = a.as<UIntImm>();
const UIntImm* pb = b.as<UIntImm>(); const UIntImm* pb = b.as<UIntImm>();
if (pa && pb) { if (pa && pa->value) return b;
return UIntImm::make(UInt(1), pa->value && pb->value); if (pa && !pa->value) return a;
if (pb && pb->value) return a;
if (pb && !pb->value) return b;
} }
return ir::And::make(a, b); return ir::And::make(a, b);
} }
Expr operator||(Expr a, Expr b) { Expr operator||(Expr a, Expr b) {
using ir::UIntImm; using ir::UIntImm;
if (a.type().is_bool() && b.type().is_bool()) {
const UIntImm* pa = a.as<UIntImm>(); const UIntImm* pa = a.as<UIntImm>();
const UIntImm* pb = b.as<UIntImm>(); const UIntImm* pb = b.as<UIntImm>();
if (pa && pb) { if (pa && pa->value) return a;
return UIntImm::make(UInt(1), pa->value || pb->value); if (pa && !pa->value) return b;
if (pb && pb->value) return b;
if (pb && !pb->value) return a;
} }
return ir::Or::make(a, b); return ir::Or::make(a, b);
} }
......
...@@ -30,6 +30,44 @@ def test_const_fold2(): ...@@ -30,6 +30,44 @@ def test_const_fold2():
assert (1 * x).same_as(x) assert (1 * x).same_as(x)
assert isinstance((1 / x), tvm.expr.Div) assert isinstance((1 / x), tvm.expr.Div)
def test_const_fold3():
def check_throws(f):
try:
f()
except tvm.TVMError:
pass
else:
raise AssertionError("Should have raised an exception but didn't.")
# Test that using ints with logic operations is forbidden
x = tvm.var("x")
for val in [0, 1]:
for func in [tvm.all, tvm.any]:
check_throws(lambda: func(tvm.const(val, 'uint1'), x))
check_throws(lambda: func(x, tvm.const(val, 'uint1')))
# Test const folding when both arguments are const
for tvm_func, py_func in [(tvm.all, lambda a, b: a and b), (tvm.any, lambda a, b: a or b)]:
for v1 in [0, 1]:
for v2 in [0, 1]:
assert tvm.ir_pass.Equal(tvm_func(tvm.const(v1, 'uint1'), tvm.const(v2, 'uint1')),
tvm.const(py_func(v1, v2), 'uint1'))
x = tvm.var("x", 'uint1')
true = tvm.const(1, 'uint1')
false = tvm.const(0, 'uint1')
assert tvm.all(x, true).same_as(x)
assert tvm.all(true, x).same_as(x)
assert tvm.any(x, false).same_as(x)
assert tvm.any(false, x).same_as(x)
assert tvm.all(x, false).same_as(false)
assert tvm.all(false, x).same_as(false)
assert tvm.any(x, true).same_as(true)
assert tvm.any(true, x).same_as(true)
if __name__ == "__main__": if __name__ == "__main__":
test_const_fold() test_const_fold()
test_const_fold2() test_const_fold2()
test_const_fold3()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment