Unverified Commit eadc4e38 by Tianqi Chen Committed by GitHub

[ARITH] Bugfix div subtract rewrite rule (#3504)

parent f9788871
...@@ -342,13 +342,16 @@ Mutate_(const Sub* op, const Expr& self) { ...@@ -342,13 +342,16 @@ Mutate_(const Sub* op, const Expr& self) {
c1.Eval()->value != 0 && c1.Eval()->value != 0 &&
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
// Proof in the case of floordiv, need positive condition.
// let x = a * c3 + r
// (x + c1) / c3 - x / c3 => (r + c1) / c3
TVM_TRY_REWRITE_IF((x + c1) / c3 - (x + c2) / c3, TVM_TRY_REWRITE_IF((x + c1) / c3 - (x + c2) / c3,
((x + (c1 % c3)) % c3 + (c1 - c2)) / c3, ((x + ((c2 % c3) + c3) % c3) % c3 + (c1 - c2)) / c3,
CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) && CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) &&
c1.Eval()->value >= c2.Eval()->value && c1.Eval()->value >= c2.Eval()->value &&
c3.Eval()->value > 0); c3.Eval()->value > 0);
TVM_TRY_REWRITE_IF((x + c1) / c3 - x / c3, TVM_TRY_REWRITE_IF((x + c1) / c3 - x / c3,
((x + (c1 % c3)) % c3 + c1) / c3, (x % c3 + c1) / c3,
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(x.Eval(), 0) &&
c1.Eval()->value >= 0 && c1.Eval()->value >= 0 &&
c3.Eval()->value > 0); c3.Eval()->value > 0);
......
...@@ -236,7 +236,9 @@ def test_sub_index_simplify(): ...@@ -236,7 +236,9 @@ def test_sub_index_simplify():
# div pattern # div pattern
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True) ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
ck.verify(x - (x / 3) * 3, x % 3) ck.verify(x - (x / 3) * 3, x % 3)
ck.verify((x + 5) / 3 - x / 3, (((x + 2) % 3) + 5)/ 3)
ck.verify((x + 5) / 3 - x / 3, ((x % 3) + 5)/ 3)
ck.verify((x + 5) / 3 - (x + 1) / 3, (((x + 1) % 3) + 4)/ 3)
ck.verify(y - (y / (-5)) * (-5), y % 5) ck.verify(y - (y / (-5)) * (-5), y % 5)
ck.verify((y / 3) * 3 - y, 0 - y % 3) ck.verify((y / 3) * 3 - y, 0 - y % 3)
...@@ -258,6 +260,7 @@ def test_sub_index_simplify(): ...@@ -258,6 +260,7 @@ def test_sub_index_simplify():
ck.verify(6 * ((y + z) / 3) - y * 2, (z - (y + z) % 3) * 2) ck.verify(6 * ((y + z) / 3) - y * 2, (z - (y + z) % 3) * 2)
ck.verify(((y - z) / 3) * 6 - 2 * y, (0 - (y - z) % 3 - z) * 2) ck.verify(((y - z) / 3) * 6 - 2 * y, (0 - (y - z) % 3 - z) * 2)
def test_mul_index_simplify(): def test_mul_index_simplify():
ck = RewriteChecker() ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment