Commit a610edee by Sergei Grechanik Committed by Tianqi Chen

[ARITH] RewriteSimplifier: improved cmp simplification (#2851)

parent ce3decb1
...@@ -96,6 +96,8 @@ class RewriteSimplifier::Impl : public IRMutator { ...@@ -96,6 +96,8 @@ class RewriteSimplifier::Impl : public IRMutator {
kEQ, kEQ,
kGT, kGT,
kLT, kLT,
kGE,
kLE,
kNE kNE
}; };
// reference to the main analyzer // reference to the main analyzer
...@@ -140,6 +142,12 @@ class RewriteSimplifier::Impl : public IRMutator { ...@@ -140,6 +142,12 @@ class RewriteSimplifier::Impl : public IRMutator {
if (dbound->max_value < val) { if (dbound->max_value < val) {
return kLT; return kLT;
} }
if (dbound->min_value >= val) {
return kGE;
}
if (dbound->max_value <= val) {
return kLE;
}
return kUnknown; return kUnknown;
} }
...@@ -994,12 +1002,10 @@ Mutate_(const EQ* op, const Expr& self) { ...@@ -994,12 +1002,10 @@ Mutate_(const EQ* op, const Expr& self) {
if (IsIndexType(op->a.type())) { if (IsIndexType(op->a.type())) {
CompareResult result = TryCompare(op->a - op->b, 0); CompareResult result = TryCompare(op->a - op->b, 0);
if (result != kUnknown) { if (result == kEQ) {
if (result == kEQ) { return make_const(op->type, true);
return make_const(op->type, true); } else if (result == kNE || result == kGT || result == kLT) {
} else { return make_const(op->type, false);
return make_const(op->type, false);
}
} }
TVM_TRY_REWRITE(x - c1 == 0, x == c1); TVM_TRY_REWRITE(x - c1 == 0, x == c1);
TVM_TRY_REWRITE(c1 - x == 0, x == c1); TVM_TRY_REWRITE(c1 - x == 0, x == c1);
...@@ -1055,7 +1061,7 @@ Mutate_(const LT* op, const Expr& self) { ...@@ -1055,7 +1061,7 @@ Mutate_(const LT* op, const Expr& self) {
if (result == kLT) { if (result == kLT) {
return make_const(op->type, true); return make_const(op->type, true);
} }
if (result == kEQ || result == kGT) { if (result == kEQ || result == kGT || result == kGE) {
return make_const(op->type, false); return make_const(op->type, false);
} }
......
...@@ -450,6 +450,21 @@ def test_cmp_simplify(): ...@@ -450,6 +450,21 @@ def test_cmp_simplify():
ck.verify(tvm.max(8, x) > 10, tvm.expr.LT(10, x)) ck.verify(tvm.max(8, x) > 10, tvm.expr.LT(10, x))
ck.verify(x + 1 < tvm.max(8, x), x < 7) ck.verify(x + 1 < tvm.max(8, x), x < 7)
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 10), override=True)
ck.analyzer.update(y, tvm.arith.ConstIntBound(-10, 0), override=True)
ck.analyzer.update(z, tvm.arith.ConstIntBound(-5, 5), override=True)
ck.verify(x < 11, tvm.const(1, "bool"))
ck.verify(x <= 10, tvm.const(1, "bool"))
ck.verify(z <= 5, tvm.const(1, "bool"))
ck.verify(x + y <= 10, tvm.const(1, "bool"))
ck.verify(x + y >= -10, tvm.const(1, "bool"))
ck.verify(z - 5 <= y + 10, tvm.const(1, "bool"))
ck.verify(tvm.all(x > -1, z <= x + 5), tvm.const(1, "bool"))
ck.verify(x*y <= 0, tvm.const(1, "bool"))
ck.verify((x + 1)*(y - 1) < 0, tvm.const(1, "bool"))
ck.verify(y*y >= 0, tvm.const(1, "bool"))
def test_logical_simplify(): def test_logical_simplify():
ck = RewriteChecker() ck = RewriteChecker()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment