Commit a610edee by Sergei Grechanik Committed by Tianqi Chen

[ARITH] RewriteSimplifier: improved cmp simplification (#2851)

parent ce3decb1
......@@ -96,6 +96,8 @@ class RewriteSimplifier::Impl : public IRMutator {
kEQ,
kGT,
kLT,
kGE,
kLE,
kNE
};
// reference to the main analyzer
......@@ -140,6 +142,12 @@ class RewriteSimplifier::Impl : public IRMutator {
if (dbound->max_value < val) {
return kLT;
}
if (dbound->min_value >= val) {
return kGE;
}
if (dbound->max_value <= val) {
return kLE;
}
return kUnknown;
}
......@@ -994,12 +1002,10 @@ Mutate_(const EQ* op, const Expr& self) {
if (IsIndexType(op->a.type())) {
CompareResult result = TryCompare(op->a - op->b, 0);
if (result != kUnknown) {
if (result == kEQ) {
return make_const(op->type, true);
} else {
return make_const(op->type, false);
}
if (result == kEQ) {
return make_const(op->type, true);
} else if (result == kNE || result == kGT || result == kLT) {
return make_const(op->type, false);
}
TVM_TRY_REWRITE(x - c1 == 0, x == c1);
TVM_TRY_REWRITE(c1 - x == 0, x == c1);
......@@ -1055,7 +1061,7 @@ Mutate_(const LT* op, const Expr& self) {
if (result == kLT) {
return make_const(op->type, true);
}
if (result == kEQ || result == kGT) {
if (result == kEQ || result == kGT || result == kGE) {
return make_const(op->type, false);
}
......
......@@ -450,6 +450,21 @@ def test_cmp_simplify():
ck.verify(tvm.max(8, x) > 10, tvm.expr.LT(10, x))
ck.verify(x + 1 < tvm.max(8, x), x < 7)
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 10), override=True)
ck.analyzer.update(y, tvm.arith.ConstIntBound(-10, 0), override=True)
ck.analyzer.update(z, tvm.arith.ConstIntBound(-5, 5), override=True)
ck.verify(x < 11, tvm.const(1, "bool"))
ck.verify(x <= 10, tvm.const(1, "bool"))
ck.verify(z <= 5, tvm.const(1, "bool"))
ck.verify(x + y <= 10, tvm.const(1, "bool"))
ck.verify(x + y >= -10, tvm.const(1, "bool"))
ck.verify(z - 5 <= y + 10, tvm.const(1, "bool"))
ck.verify(tvm.all(x > -1, z <= x + 5), tvm.const(1, "bool"))
ck.verify(x*y <= 0, tvm.const(1, "bool"))
ck.verify((x + 1)*(y - 1) < 0, tvm.const(1, "bool"))
ck.verify(y*y >= 0, tvm.const(1, "bool"))
def test_logical_simplify():
ck = RewriteChecker()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment