Unverified Commit 8703d9fb by Tianqi Chen Committed by GitHub

[ARITH] Bugfix min/max const canonicalize rule (#3386)

parent 56397826
Subproject commit fbe142b267a8edd1f1188fa2140d88f7ae308661
Subproject commit 3943914eed66470bd010df581e29e4dca4f7df6f
......@@ -813,7 +813,9 @@ Mutate_(const Min* op, const Expr& self) {
// canonicalization
TVM_TRY_RECURSIVE_REWRITE(min(min(x, c1), y), min(min(x, y), c1));
TVM_TRY_RECURSIVE_REWRITE(min(c1 - x, c2), c1 - max(x, c2 - c1));
TVM_TRY_RECURSIVE_REWRITE_IF(
min(c1 - x, c2), c1 - max(x, c1 - c2),
c2.Eval()->value != 0);
}
// condition rules.
......@@ -961,7 +963,8 @@ Mutate_(const Max* op, const Expr& self) {
// canonicalization
TVM_TRY_RECURSIVE_REWRITE(max(max(x, c1), y), max(max(x, y), c1));
TVM_TRY_RECURSIVE_REWRITE(max(c1 - x, c2), c1 - min(x, c2 - c1));
TVM_TRY_RECURSIVE_REWRITE_IF(
max(c1 - x, c2), c1 - min(x, c1 - c2), c2.Eval()->value != 0);
}
// condition rules.
......
......@@ -392,6 +392,7 @@ def test_min_index_simplify():
ck.verify(tvm.min(x / 10, y / 10), tvm.min(x, y) / 10)
ck.verify(tvm.min(x / (-10), y / (-10)), tvm.max(x, y) / (-10))
ck.verify(tvm.min(x * 3, 9), tvm.min(x, 3) * 3)
ck.verify(tvm.min(3 - x, 2), 3 - tvm.max(x, 1))
def test_max_index_simplify():
......@@ -448,6 +449,7 @@ def test_max_index_simplify():
ck.verify(tvm.max(x / 10, y / 10), tvm.max(x, y) / 10)
ck.verify(tvm.max(x / (-10), y / (-10)), tvm.min(x, y) / (-10))
ck.verify(tvm.max(x * 3, 9), tvm.max(x, 3) * 3)
ck.verify(tvm.max(3 - x, 1), 3 - tvm.min(x, 2))
def test_cmp_simplify():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment