Unverified Commit 6c81d784 by Tianqi Chen Committed by GitHub

[ARITH] Canonicalize comparison to move constant to one side (#3467)

parent 79e071c9
...@@ -1187,6 +1187,12 @@ Mutate_(const LT* op, const Expr& self) { ...@@ -1187,6 +1187,12 @@ Mutate_(const LT* op, const Expr& self) {
TVM_TRY_RECURSIVE_REWRITE(z < min(x, y), z < x && z < y); TVM_TRY_RECURSIVE_REWRITE(z < min(x, y), z < x && z < y);
TVM_TRY_RECURSIVE_REWRITE(z < max(x, y), z < x || z < y); TVM_TRY_RECURSIVE_REWRITE(z < max(x, y), z < x || z < y);
TVM_TRY_RECURSIVE_REWRITE(x < c1 - y, x + y < c1);
TVM_TRY_RECURSIVE_REWRITE(x < c1 + y, x - y < c1);
TVM_TRY_RECURSIVE_REWRITE(c1 - y < x, c1 < x + y);
TVM_TRY_RECURSIVE_REWRITE(c1 + y < x, c1 < x - y);
TVM_TRY_REWRITE(x - c1 < 0, x < c1); TVM_TRY_REWRITE(x - c1 < 0, x < c1);
TVM_TRY_REWRITE(x + c1 < c2, x < c2 - c1); TVM_TRY_REWRITE(x + c1 < c2, x < c2 - c1);
} }
......
...@@ -166,12 +166,18 @@ def test_simplify_if_then_else(): ...@@ -166,12 +166,18 @@ def test_simplify_if_then_else():
tvm.if_then_else(24512 <= ((((x*4) + y) - 466036) % 24528), tvm.if_then_else(24512 <= ((((x*4) + y) - 466036) % 24528),
(((((x*4) + y) - 466036) % 24528) -24512) % 16, (((((x*4) + y) - 466036) % 24528) -24512) % 16,
x), y) x), y)
res2 = tvm.if_then_else((x * 4) >= 466036 - y,
tvm.if_then_else(24512 <= ((((x*4) + y) - 466036) % 24528),
(((((x*4) + y) - 466036) % 24528) -24512) % 16,
x), y)
expected = tvm.if_then_else( expected = tvm.if_then_else(
tvm.expr.LE(466036, (x * 4 + y)), tvm.expr.LE(466036, (x * 4 + y)),
tvm.if_then_else(tvm.expr.LE(24512, ((((x*4) + y) - 4) % 24528)), tvm.if_then_else(tvm.expr.LE(24512, ((((x*4) + y) - 4) % 24528)),
(((x*4) + y) - 4) % 16, (((x*4) + y) - 4) % 16,
x), y) x), y)
ck.verify(res, expected) ck.verify(res, expected)
ck.verify(res2, expected)
# can only simplify if condition # can only simplify if condition
res = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 100) % 3, (x + 100) % 3) res = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 100) % 3, (x + 100) % 3)
expected = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 1) % 3, (x + 100) % 3) expected = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 1) % 3, (x + 100) % 3)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment