Unverified Commit 046e4ff0 by Tianqi Chen Committed by GitHub

[ARITH] RewriteSimplifier: min/max, logical, select (#2768)

parent 6c60b8d3
......@@ -75,8 +75,29 @@ class RewriteSimplifier::Impl : public IRMutator {
Expr Mutate_(const Mul* op, const Expr& self) final;
Expr Mutate_(const Div* op, const Expr& self) final;
Expr Mutate_(const Mod* op, const Expr& self) final;
Expr Mutate_(const Min* op, const Expr& self) final;
Expr Mutate_(const Max* op, const Expr& self) final;
Expr Mutate_(const EQ* op, const Expr& self) final;
Expr Mutate_(const NE* op, const Expr& self) final;
Expr Mutate_(const LT* op, const Expr& self) final;
Expr Mutate_(const LE* op, const Expr& self) final;
Expr Mutate_(const GT* op, const Expr& self) final;
Expr Mutate_(const GE* op, const Expr& self) final;
Expr Mutate_(const And* op, const Expr& self) final;
Expr Mutate_(const Or* op, const Expr& self) final;
Expr Mutate_(const Not* op, const Expr& self) final;
Expr Mutate_(const Select* op, const Expr& self) final;
Expr Mutate_(const Ramp* op, const Expr& self) final;
private:
/*! \brief internal structure for comparison. */
enum CompareResult {
kUnknown,
kEQ,
kGT,
kLT,
kNE
};
// reference to the main analyzer
Analyzer* parent_;
// counter to record recursive rewrite depth.
......@@ -92,12 +113,36 @@ class RewriteSimplifier::Impl : public IRMutator {
// Whether x == val
bool CanProveEqual(const Expr& x, int64_t val) {
// TODO(tqchen) refer back to super-analyzer.
Expr res = Mutate(x);
if (const auto* ptr = res.as<ir::IntImm>()) {
return ptr->value == val;
return TryCompare(x, val) == kEQ;
}
// try to prove x equals val
CompareResult TryCompare(const Expr& x, int64_t val) {
Expr diff = Mutate(x);
if (const auto* ptr = diff.as<IntImm>()) {
if (ptr->value == val) {
return kEQ;
} else if (ptr->value > val) {
return kGT;
} else if (ptr->value < val) {
return kLT;
}
}
if (val == 0) {
ModularSet dmod = parent_->modular_set(diff);
if (dmod->base != 0) {
return kNE;
}
return false;
}
ConstIntBound dbound = parent_->const_int_bound(diff);
if (dbound->min_value > val) {
return kGT;
}
if (dbound->max_value < val) {
return kLT;
}
return kUnknown;
}
// Recursive rewrite x
// we limit maximum depth of recursive rewrite allowed to
// avoid infinite loop
......@@ -557,7 +602,7 @@ Mutate_(const Mod* op, const Expr& self) {
// Pattern var to match any expression
PVar<Expr> x, y, z, b1;
// Pattern var match IntImm
PVar<Integer> c1, c2, c3;
PVar<Integer> c1, c2;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
......@@ -626,6 +671,608 @@ Mutate_(const Mod* op, const Expr& self) {
return ret;
}
Expr RewriteSimplifier::Impl::
Mutate_(const Min* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<Min>();
Expr const_res = TryConstFold<Min>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
PVar<Expr> x, y, z, s1, s2;
// Pattern var match IntImm
PVar<Integer> c1, c2;
PVar<int> lanes;
// vector rule
if (op->type.lanes() != 1) {
TVM_TRY_REWRITE(min(broadcast(x, lanes), broadcast(y, lanes)),
broadcast(min(x, y), lanes));
TVM_TRY_REWRITE(min(min(x, broadcast(y, lanes)), broadcast(z, lanes)),
min(x, broadcast(min(y, z), lanes)));
}
if (IsIndexType(op->type)) {
TVM_TRY_REWRITE(min(x, x), x);
// constant int bound
ConstIntBound a_bound = parent_->const_int_bound(op->a);
ConstIntBound b_bound = parent_->const_int_bound(op->b);
if (a_bound->max_value <= b_bound->min_value) {
return op->a;
}
if (b_bound->max_value <= a_bound->min_value) {
return op->b;
}
// constant comparison
if (min(x + c1, x + c2).Match(ret)) {
if (c1.Eval()->value < c2.Eval()->value) {
return (x + c1).Eval();
} else {
return (x + c2).Eval();
}
}
if (min(x + c1, x).Match(ret) ||
min(x, x + c1).Match(ret)) {
if (c1.Eval()->value < 0) {
return (x + c1).Eval();
} else {
return x.Eval();
}
}
if (min(c1 - x, c2 - x).Match(ret)) {
if (c1.Eval()->value < c2.Eval()->value) {
return (c1 - x).Eval();
} else {
return (c2 - x).Eval();
}
}
// Divide up rounding
TVM_TRY_REWRITE_IF(min(((x + c1) / c2) * c2, x), x,
c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(((x + c1) / c2) * c2, max(x, c2)), max(x, c2),
c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value &&
CanProveGreaterEqual(x.Eval(), 0));
TVM_TRY_REWRITE_IF(min(x, ((x + c1) / c2) * c2), x,
c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(max(x, c2), ((x + c1) / c2) * c2), max(x, c2),
c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value &&
CanProveGreaterEqual(x.Eval(), 0));
TVM_TRY_REWRITE(min(max(x, y), min(x, y)), min(x, y));
TVM_TRY_REWRITE(min(max(x, y), min(y, x)), min(x, y));
TVM_TRY_REWRITE(min(min(x, y), max(x, y)), min(x, y));
TVM_TRY_REWRITE(min(min(x, y), max(y, x)), min(x, y));
TVM_TRY_REWRITE(min(max(x, y), x), x);
TVM_TRY_REWRITE(min(max(x, y), y), y);
TVM_TRY_REWRITE(min(min(x, y), x), min(x, y));
TVM_TRY_REWRITE(min(min(x, y), y), min(x, y));
TVM_TRY_REWRITE(min(x, max(x, y)), x);
TVM_TRY_REWRITE(min(y, max(x, y)), y);
TVM_TRY_REWRITE(min(x, min(x, y)), min(x, y));
TVM_TRY_REWRITE(min(y, min(x, y)), min(x, y));
TVM_TRY_REWRITE(min(min(min(x, y), z), y), min(min(x, y), z));
TVM_TRY_REWRITE(min(min(min(min(x, y), z), s1), y), min(min(min(x, y), z), s1));
TVM_TRY_REWRITE(min(min(min(min(min(x, y), z), s1), s2), y),
min(min(min(min(x, y), z), s1), s2));
TVM_TRY_REWRITE(min(max(x, y), max(x, z)), max(min(y, z), x));
TVM_TRY_REWRITE(min(max(x, y), max(z, x)), max(min(y, z), x));
TVM_TRY_REWRITE(min(max(y, x), max(x, z)), max(min(y, z), x));
TVM_TRY_REWRITE(min(max(y, x), max(z, x)), max(min(y, z), x));
TVM_TRY_REWRITE(min(min(x, y), min(x, z)), min(min(y, z), x));
TVM_TRY_REWRITE(min(min(x, y), min(z, x)), min(min(y, z), x));
TVM_TRY_REWRITE(min(min(y, x), min(x, z)), min(min(y, z), x));
TVM_TRY_REWRITE(min(min(y, x), min(z, x)), min(min(y, z), x));
TVM_TRY_REWRITE(min(y + x, z + x), min(y, z) + x);
TVM_TRY_REWRITE(min(y + x, x + z), min(y, z) + x);
TVM_TRY_REWRITE(min(x + y, x + z), min(y, z) + x);
TVM_TRY_REWRITE(min(x + y, z + x), min(y, z) + x);
// sub distribution
TVM_TRY_REWRITE(min(y - x, z - x), min(y, z) - x);
TVM_TRY_REWRITE(min(x - y, x - z), x - max(y, z));
// constant folding rule.
TVM_TRY_REWRITE(min(min(x, c1), c2), min(x, min(c1, c2)));
// scaling rule
if (min(x / c1, y / c1).Match(ret)) {
if (c1.Eval()->value > 0) {
return (min(x, y) / c1).Eval();
} else {
return (max(x, y) / c1).Eval();
}
}
if (min(x * c1, y * c1).Match(ret)) {
if (c1.Eval()->value > 0) {
return (min(x, y) * c1).Eval();
} else {
return (max(x, y) * c1).Eval();
}
}
if (min(x * c1, c2).Match(ret)) {
int64_t c1val = c1.Eval()->value;
int64_t c2val = c2.Eval()->value;
if (c2val % c1val == 0) {
if (c2val / c1val >= 0) {
return (min(x, c2val / c1val) * c1val).Eval();
} else {
return (max(x, c2val / c1val) * c1val).Eval();
}
}
}
// canonicalization
TVM_TRY_RECURSIVE_REWRITE(min(min(x, c1), y), min(min(x, y), c1));
TVM_TRY_RECURSIVE_REWRITE(min(c1 - x, c2), c1 - max(x, c2 - c1));
}
// condition rules.
TVM_TRY_REWRITE(min(select(x, y, z), select(x, s1, s2)),
select(x, min(y, s1), min(z, s2)));
return ret;
}
Expr RewriteSimplifier::Impl::
Mutate_(const Max* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<Max>();
Expr const_res = TryConstFold<Max>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
PVar<Expr> x, y, z, s1, s2;
// Pattern var match IntImm
PVar<Integer> c1, c2;
PVar<int> lanes;
// vector rule
if (op->type.lanes() != 1) {
TVM_TRY_REWRITE(max(broadcast(x, lanes), broadcast(y, lanes)),
broadcast(max(x, y), lanes));
TVM_TRY_REWRITE(max(max(x, broadcast(y, lanes)), broadcast(z, lanes)),
max(x, broadcast(max(y, z), lanes)));
}
if (IsIndexType(op->type)) {
TVM_TRY_REWRITE(max(x, x), x);
// constant int bound
ConstIntBound a_bound = parent_->const_int_bound(op->a);
ConstIntBound b_bound = parent_->const_int_bound(op->b);
if (a_bound->min_value >= b_bound->max_value) {
return op->a;
}
if (b_bound->min_value >= a_bound->max_value) {
return op->b;
}
// constant comparison
if (max(x + c1, x + c2).Match(ret)) {
if (c1.Eval()->value > c2.Eval()->value) {
return (x + c1).Eval();
} else {
return (x + c2).Eval();
}
}
if (max(x + c1, x).Match(ret) ||
max(x, x + c1).Match(ret)) {
if (c1.Eval()->value > 0) {
return (x + c1).Eval();
} else {
return x.Eval();
}
}
if (max(c1 - x, c2 - x).Match(ret)) {
if (c1.Eval()->value > c2.Eval()->value) {
return (c1 - x).Eval();
} else {
return (c2 - x).Eval();
}
}
// Divide up rounding
TVM_TRY_REWRITE_IF(max(((x + c1) / c2) * c2, x), ((x + c1) / c2) * c2,
c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(max(x, ((x + c1) / c2) * c2), ((x + c1) / c2) * c2,
c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE(max(min(x, y), max(x, y)), max(x, y));
TVM_TRY_REWRITE(max(min(x, y), max(y, x)), max(x, y));
TVM_TRY_REWRITE(max(max(x, y), min(x, y)), max(x, y));
TVM_TRY_REWRITE(max(max(x, y), min(y, x)), max(x, y));
TVM_TRY_REWRITE(max(min(x, y), x), x);
TVM_TRY_REWRITE(max(min(x, y), y), y);
TVM_TRY_REWRITE(max(max(x, y), x), max(x, y));
TVM_TRY_REWRITE(max(max(x, y), y), max(x, y));
TVM_TRY_REWRITE(max(x, min(x, y)), x);
TVM_TRY_REWRITE(max(y, min(x, y)), y);
TVM_TRY_REWRITE(max(x, max(x, y)), max(x, y));
TVM_TRY_REWRITE(max(y, max(x, y)), max(x, y));
TVM_TRY_REWRITE(max(max(max(x, y), z), y), max(max(x, y), z));
TVM_TRY_REWRITE(max(max(max(max(x, y), z), s1), y), max(max(max(x, y), z), s1));
TVM_TRY_REWRITE(max(max(max(max(max(x, y), z), s1), s2), y),
max(max(max(max(x, y), z), s1), s2));
// max/max cancelation
TVM_TRY_REWRITE(max(max(x, y), max(x, z)), max(max(y, z), x));
TVM_TRY_REWRITE(max(max(x, y), max(z, x)), max(max(y, z), x));
TVM_TRY_REWRITE(max(max(y, x), max(x, z)), max(max(y, z), x));
TVM_TRY_REWRITE(max(max(y, x), max(z, x)), max(max(y, z), x));
// max/min distribution
TVM_TRY_REWRITE(max(min(x, y), min(x, z)), min(max(y, z), x));
TVM_TRY_REWRITE(max(min(x, y), min(z, x)), min(max(y, z), x));
TVM_TRY_REWRITE(max(min(y, x), min(x, z)), min(max(y, z), x));
TVM_TRY_REWRITE(max(min(y, x), min(z, x)), min(max(y, z), x));
// add distribution
TVM_TRY_REWRITE(max(y + x, z + x), max(y, z) + x);
TVM_TRY_REWRITE(max(y + x, x + z), max(y, z) + x);
TVM_TRY_REWRITE(max(x + y, x + z), max(y, z) + x);
TVM_TRY_REWRITE(max(x + y, z + x), max(y, z) + x);
// sub distribution
TVM_TRY_REWRITE(max(y - x, z - x), max(y, z) - x);
TVM_TRY_REWRITE(max(x - y, x - z), x - min(y, z));
// constant folding rule.
TVM_TRY_REWRITE(max(max(x, c1), c2), max(x, max(c1, c2)));
// scaling rule
if (max(x / c1, y / c1).Match(ret)) {
if (c1.Eval()->value > 0) {
return (max(x, y) / c1).Eval();
} else {
return (min(x, y) / c1).Eval();
}
}
if (max(x * c1, y * c1).Match(ret)) {
if (c1.Eval()->value > 0) {
return (max(x, y) * c1).Eval();
} else {
return (min(x, y) * c1).Eval();
}
}
if (max(x * c1, c2).Match(ret)) {
int64_t c1val = c1.Eval()->value;
int64_t c2val = c2.Eval()->value;
if (c2val % c1val == 0) {
if (c2val / c1val >= 0) {
return (max(x, c2val / c1val) * c1val).Eval();
} else {
return (min(x, c2val / c1val) * c1val).Eval();
}
}
}
// canonicalization
TVM_TRY_RECURSIVE_REWRITE(max(max(x, c1), y), max(max(x, y), c1));
TVM_TRY_RECURSIVE_REWRITE(max(c1 - x, c2), c1 - min(x, c2 - c1));
}
// condition rules.
TVM_TRY_REWRITE(max(select(x, y, z), select(x, s1, s2)),
select(x, max(y, s1), max(z, s2)));
return ret;
}
Expr RewriteSimplifier::Impl::
Mutate_(const EQ* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<EQ>();
Expr const_res = TryConstFold<EQ>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
PVar<Expr> x, y;
// Pattern var match IntImm
PVar<Integer> c1;
PVar<int> lanes;
// vector rule
if (op->type.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes),
broadcast(x == y, lanes));
}
if (IsIndexType(op->a.type())) {
CompareResult result = TryCompare(op->a - op->b, 0);
if (result != kUnknown) {
if (result == kEQ) {
return make_const(op->type, true);
} else {
return make_const(op->type, false);
}
}
TVM_TRY_REWRITE(x - c1 == 0, x == c1);
TVM_TRY_REWRITE(c1 - x == 0, x == c1);
TVM_TRY_REWRITE(x + c1 == 0, x == 0 - c1);
TVM_TRY_REWRITE(x * y == 0, x == 0 || y == 0);
}
return ret;
}
Expr RewriteSimplifier::Impl::
Mutate_(const NE* op, const Expr& self) {
return Mutate(Not::make(op->a == op->b));
}
Expr RewriteSimplifier::Impl::
Mutate_(const LE* op, const Expr& self) {
return Mutate(Not::make(op->b < op->a));
}
Expr RewriteSimplifier::Impl::
Mutate_(const GT* op, const Expr& self) {
return Mutate(op->b < op->a);
}
Expr RewriteSimplifier::Impl::
Mutate_(const GE* op, const Expr& self) {
return Mutate(Not::make(op->a < op->b));
}
Expr RewriteSimplifier::Impl::
Mutate_(const LT* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<LT>();
Expr const_res = TryConstFold<LT>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
PVar<Expr> x, y, z, s1, s2;
// Pattern var match IntImm
PVar<Integer> c1, c2;
PVar<int> lanes;
// vector rule
if (op->type.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) < broadcast(y, lanes),
broadcast(x < y, lanes));
TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes),
broadcast(x < y, lanes));
}
if (IsIndexType(op->a.type())) {
CompareResult result = TryCompare(op->a - op->b, 0);
if (result == kLT) {
return make_const(op->type, true);
}
if (result == kEQ || result == kGT) {
return make_const(op->type, false);
}
TVM_TRY_REWRITE(x + y < x + z, y < z);
TVM_TRY_REWRITE(x + y < z + x, y < z);
TVM_TRY_REWRITE(y + x < x + z, y < z);
TVM_TRY_REWRITE(y + x < z + x, y < z);
TVM_TRY_REWRITE(y - x < z - x, y < z);
TVM_TRY_REWRITE(x - y < x - z, z < y);
TVM_TRY_REWRITE(x < x + z, 0 < z);
TVM_TRY_REWRITE(x < z + x, 0 < z);
TVM_TRY_REWRITE(x < x - z, z < 0);
TVM_TRY_REWRITE(c1 < x + c2, c1 - c2 < x);
TVM_TRY_REWRITE(c1 < c2 - x, x < c2 - c1);
TVM_TRY_REWRITE_IF(x * c1 < y * c1, x < y,
c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(x * c1 < y * c1, y < x,
c1.Eval()->value < 0);
// require c1 > 0 to work for any div mode
TVM_TRY_REWRITE_IF(x * c2 < c1, x < (c1 - 1) / c2 + 1,
c1.Eval()->value > 0 &&
c2.Eval()->value > 0);
TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2,
c1.Eval()->value > 0 &&
c2.Eval()->value > 0);
TVM_TRY_REWRITE_IF(c1 < x * c2, c1 / c2 < x,
c1.Eval()->value >= 0 &&
c2.Eval()->value > 0);
TVM_TRY_REWRITE_IF(c1 < x / c2, (c1 + 1) * c2 - 1 < x,
c1.Eval()->value >= 0 &&
c2.Eval()->value > 0);
// division related simplificationx
// invariance for any div mod: x - (x / c1) * c1 == x % c1
TVM_TRY_REWRITE_IF((x / c1) * c1 < x, 0 < x % c1,
c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF((x / c1) * c1 < x + y, 0 < x % c1 + y,
c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF((x / c1) * c1 < x - y, y < x % c1,
c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(((x + c2)/ c1) * c1 < x,
c2 < (x + c2) % c1,
c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(((x + c2)/ c1) * c1 < x + y,
c2 < (x + c2) % c1 + y,
c1.Eval()->value > 0);
TVM_TRY_REWRITE_IF(((x + c2)/ c1) * c1 < x - y,
y < (x + c2) % c1 + (0 - c2),
c1.Eval()->value > 0);
// canonicalization rule
TVM_TRY_RECURSIVE_REWRITE(min(x, y) < z, x < z || y < z);
TVM_TRY_RECURSIVE_REWRITE(max(x, y) < z, x < z && y < z);
TVM_TRY_RECURSIVE_REWRITE(z < min(x, y), z < x && z < y);
TVM_TRY_RECURSIVE_REWRITE(z < max(x, y), z < x || z < y);
TVM_TRY_REWRITE(x - c1 < 0, x < c1);
TVM_TRY_REWRITE(x + c1 < c2, x < c2 - c1);
}
return ret;
}
Expr RewriteSimplifier::Impl::
Mutate_(const Not* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<Not>();
Expr const_res = TryConstFold<Not>(op->a);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
PVar<Expr> x, y;
PVar<int> lanes;
if (op->type.lanes() != 1) {
TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes));
}
TVM_TRY_REWRITE(!(!x), x);
TVM_TRY_REWRITE(!(x <= y), y < x);
TVM_TRY_REWRITE(!(x >= y), x < y);
TVM_TRY_REWRITE(!(x < y), y <= x);
TVM_TRY_REWRITE(!(x > y), x <= y);
TVM_TRY_REWRITE(!(x == y), x != y);
TVM_TRY_REWRITE(!(x != y), x == y);
TVM_TRY_RECURSIVE_REWRITE(!(x || y), (!x) && (!y));
TVM_TRY_RECURSIVE_REWRITE(!(x && y), (!x) || (!y));
return ret;
}
Expr RewriteSimplifier::Impl::
Mutate_(const And* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<And>();
Expr const_res = TryConstFold<And>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
PVar<Expr> x, y;
// Pattern var match IntImm
PVar<Integer> c1, c2;
PVar<int> lanes;
if (op->type.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes),
broadcast(x && y, lanes));
}
auto cfalse = PConst<Expr>(make_const(op->type, false));
TVM_TRY_REWRITE(x == y && x != y, cfalse);
TVM_TRY_REWRITE(x != y && x == y, cfalse);
TVM_TRY_REWRITE(x && !x, cfalse);
TVM_TRY_REWRITE(x <= y && y < x, cfalse);
TVM_TRY_REWRITE(y < x && y <= x, cfalse);
TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse,
c2.Eval()->value + 1 >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse,
c2.Eval()->value + 1 >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(x < c1 && c2 <= x, cfalse,
c2.Eval()->value >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 <= x && x < c1, cfalse,
c2.Eval()->value >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(x <= c1 && c2 < x, cfalse,
c2.Eval()->value >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x && x <= c1, cfalse,
c2.Eval()->value >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(x <= c1 && c2 <= x, cfalse,
c2.Eval()->value > c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 <= x && x <= c1, cfalse,
c2.Eval()->value > c1.Eval()->value);
TVM_TRY_REWRITE(x == c1 && x != c2, x == c1 && c1 != c2);
TVM_TRY_REWRITE(x != c2 && x == c1, x == c1 && c1 != c2);
return ret;
}
Expr RewriteSimplifier::Impl::
Mutate_(const Or* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<Or>();
Expr const_res = TryConstFold<Or>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
PVar<Expr> x, y;
// Pattern var match IntImm
PVar<Integer> c1, c2;
PVar<int> lanes;
if (op->type.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes),
broadcast(x || y, lanes));
}
auto ctrue = PConst<Expr>(make_const(op->type, true));
TVM_TRY_REWRITE(x == y || x != y, ctrue);
TVM_TRY_REWRITE(x != y || x == y, ctrue);
TVM_TRY_REWRITE(x || !x, ctrue);
TVM_TRY_REWRITE(x <= y || y < x, ctrue);
TVM_TRY_REWRITE(y < x || y <= x, ctrue);
TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue,
c2.Eval()->value < c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue,
c2.Eval()->value < c1.Eval()->value);
TVM_TRY_REWRITE_IF(x <= c1 || c2 < x, ctrue,
c2.Eval()->value <= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x || x <= c1, ctrue,
c2.Eval()->value <= c1.Eval()->value);
TVM_TRY_REWRITE_IF(x < c1 || c2 <= x, ctrue,
c2.Eval()->value <= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 <= x || x < c1, ctrue,
c2.Eval()->value <= c1.Eval()->value);
TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue,
c2.Eval()->value <= c1.Eval()->value + 1);
TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue,
c2.Eval()->value <= c1.Eval()->value + 1);
TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2);
TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2);
return ret;
}
Expr RewriteSimplifier::Impl::
Mutate_(const Ramp* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<Ramp>();
if (is_zero(op->stride)) {
return Broadcast::make(op->base, op->lanes);
}
return ret;
}
Expr RewriteSimplifier::Impl::
Mutate_(const Select* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<Select>();
if (is_zero(op->condition)) {
return op->false_value;
}
if (is_one(op->condition)) {
return op->true_value;
}
// Pattern var to match any expression
PVar<Expr> x, y;
TVM_TRY_REWRITE(select(x, y, y), y);
return ret;
}
Expr RewriteSimplifier::operator()(const Expr& expr) {
return impl_->PostOrderSimplify(expr);
......
......@@ -6,8 +6,7 @@ class RewriteChecker:
def verify(self, data, expected):
res = self.analyzer.rewrite_simplify(data)
assert tvm.ir_pass.Equal(res, expected), "data={}, res={}, expected={}".format(
data, res, expected)
assert tvm.ir_pass.Equal(res, expected), "data={}, res={}, expected={}".format(data, res, expected)
def test_vector_simplify():
......@@ -62,20 +61,57 @@ def test_vector_simplify():
ck.verify(tvm.expr.Ramp(x * 8 + 1, 15, 4) % 8,
tvm.expr.Ramp(1, 15, 4) % 8)
# Min/Max rules
vx = tvm.var("vx", dtype="int32x2")
vc = tvm.var("vc", dtype="uint1")
ck.verify(tvm.min(y.astype("int32x2"), x.astype("int32x2")),
tvm.min(y, x).astype("int32x2"))
ck.verify(tvm.min(tvm.min(vx, y.astype("int32x2")), x.astype("int32x2")),
tvm.min(vx, tvm.min(y, x).astype("int32x2")))
ck.verify(tvm.max(y.astype("int32x2"), x.astype("int32x2")),
tvm.max(y, x).astype("int32x2"))
ck.verify(tvm.max(tvm.max(vx, y.astype("int32x2")), x.astype("int32x2")),
tvm.max(vx, tvm.max(y, x).astype("int32x2")))
## Logical rules
ck.verify(y.astype("int32x2").equal(x.astype("int32x2")),
(y.equal(x)).astype("uint1x2"))
ck.verify(tvm.expr.NE(y.astype("int32x2"), (x.astype("int32x2"))),
(tvm.expr.NE(y, x)).astype("uint1x2"))
ck.verify(y.astype("int32x2") > x.astype("int32x2"),
(x < y).astype("uint1x2"))
ck.verify(y.astype("int32x2") >= x.astype("int32x2"),
(x <= y).astype("uint1x2"))
ck.verify(y.astype("int32x2") < x.astype("int32x2"),
(y < x).astype("uint1x2"))
ck.verify(y.astype("int32x2") <= x.astype("int32x2"),
(y <= x).astype("uint1x2"))
ck.verify(tvm.expr.And(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")),
(tvm.expr.And(y <= x, vc)).astype("uint1x2"))
ck.verify(tvm.expr.Or(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")),
(tvm.expr.Or(y <= x, vc)).astype("uint1x2"))
def test_select_simplify():
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
# Add rules
ck.verify(tvm.expr.Select(x > 0, y, 0) + tvm.expr.Select(x > 0, 1, z),
tvm.expr.Select(x > 0, y + 1, z))
ck.verify(tvm.expr.Select(x > 0, y, 1) - tvm.expr.Select(x > 0, 1, z),
tvm.expr.Select(x > 0, y + (-1), 1 - z))
ck.verify(tvm.expr.Select(x > 0, y, z) - y,
tvm.expr.Select(x > 0, 0, z - y))
ck.verify(tvm.expr.Select(x > 0, y, z) - z,
tvm.expr.Select(x > 0, y - z, 0))
ck.verify(tvm.expr.Select(x < 0, y, 0) + tvm.expr.Select(x < 0, 1, z),
tvm.expr.Select(x < 0, y + 1, z))
ck.verify(tvm.expr.Select(x < 0, y, 1) - tvm.expr.Select(x < 0, 1, z),
tvm.expr.Select(x < 0, y + (-1), 1 - z))
ck.verify(tvm.expr.Select(x < 0, y, z) - y,
tvm.expr.Select(x < 0, 0, z - y))
ck.verify(tvm.expr.Select(x < 0, y, z) - z,
tvm.expr.Select(x < 0, y - z, 0))
ck.verify(tvm.min(tvm.expr.Select(x < 0, y, 0), tvm.expr.Select(x < 0, 1, z)),
tvm.expr.Select(x < 0, tvm.min(y, 1), tvm.min(0, z)))
ck.verify(tvm.max(tvm.expr.Select(x < 0, y, 0), tvm.expr.Select(x < 0, 1, z)),
tvm.expr.Select(x < 0, tvm.max(y, 1), tvm.max(0, z)))
ck.verify(tvm.expr.Select(x * 3 + 1 != 0, y, z), y)
ck.verify(tvm.expr.Select(x * 3 + 1 == 0, y, z), z)
ck.verify(tvm.expr.Select(x > 0, y + 1, y + 1), y + 1)
def test_add_index_simplify():
......@@ -242,11 +278,231 @@ def test_mod_index_simplify():
ck.verify((x* 10 + 1 + y * 2 + 2) % 2, 1)
def test_min_index_simplify():
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
# const int bound
ck.verify(tvm.min(x % 2, y % 2 + 10), x % 2)
ck.verify(tvm.min(x + 1, x + 10), x + 1)
ck.verify(tvm.min(x + 111, x + 10), x + 10)
ck.verify(tvm.min(x + 1, x), x)
ck.verify(tvm.min(x, x + 2), x)
ck.verify(tvm.min(1 - x, 2 - x), 1 - x)
ck.verify(tvm.min(3 - x, 2 - x), 2 - x)
ck.verify(tvm.min((x + 3) / 4 * 4, x), x)
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000))
ck.verify(tvm.min((x + 3) / 4 * 4, tvm.max(x, 4)), tvm.max(x, 4))
ck.verify(tvm.min(x, (x + 3) / 4 * 4), x)
ck.verify(tvm.min(tvm.max(x, 4), (x + 3) / 4 * 4), tvm.max(x, 4))
ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True)
ck.verify(tvm.min(tvm.max(x, y), tvm.min(x, y)), tvm.min(x, y))
ck.verify(tvm.min(tvm.max(x, y), tvm.min(y, x)), tvm.min(x, y))
ck.verify(tvm.min(tvm.max(x, y), x), x)
ck.verify(tvm.min(tvm.max(y, x), x), x)
ck.verify(tvm.min(tvm.min(x, y), x), tvm.min(x, y))
ck.verify(tvm.min(tvm.min(x, y), y), tvm.min(x, y))
ck.verify(tvm.min(x, tvm.max(x, y)), x)
ck.verify(tvm.min(x, tvm.max(y, x)), x)
ck.verify(tvm.min(x, tvm.min(x, y)), tvm.min(x, y))
ck.verify(tvm.min(y, tvm.min(x, y)), tvm.min(x, y))
ck.verify(tvm.min(tvm.min(tvm.min(x, y), z), y),
tvm.min(tvm.min(x, y), z))
ck.verify(tvm.min(tvm.min(tvm.min(tvm.min(x, y), z), x * 2), y),
tvm.min(tvm.min(tvm.min(x, y), z), x * 2))
ck.verify(tvm.min(tvm.min(tvm.min(tvm.min(tvm.min(x, y), z), x * 2), z * 2), y),
tvm.min(tvm.min(tvm.min(tvm.min(x, y), z), x * 2), z * 2))
ck.verify(tvm.min(tvm.max(x, y), tvm.max(x, z)), tvm.max(tvm.min(y, z), x))
ck.verify(tvm.min(tvm.max(x, y), tvm.max(z, x)), tvm.max(tvm.min(y, z), x))
ck.verify(tvm.min(tvm.max(y, x), tvm.max(x, z)), tvm.max(tvm.min(y, z), x))
ck.verify(tvm.min(tvm.max(y, x), tvm.max(z, x)), tvm.max(tvm.min(y, z), x))
ck.verify(tvm.min(y + x, z + x), tvm.min(y, z) + x)
ck.verify(tvm.min(y + x, x + z), tvm.min(y, z) + x)
ck.verify(tvm.min(x + y, z + x), tvm.min(y, z) + x)
ck.verify(tvm.min(x + y, x + z), tvm.min(y, z) + x)
ck.verify(tvm.min(x - y, x - z), x - tvm.max(y, z))
ck.verify(tvm.min(y - x, z - x), tvm.min(y, z) - x)
ck.verify(tvm.min(tvm.min(x, 1), 10), tvm.min(x, 1))
ck.verify(tvm.min(tvm.min(x, 11), 10), tvm.min(x, 10))
ck.verify(tvm.min(x / 10, y / 10), tvm.min(x, y) / 10)
ck.verify(tvm.min(x / (-10), y / (-10)), tvm.max(x, y) / (-10))
ck.verify(tvm.min(x * 3, 9), tvm.min(x, 3) * 3)
def test_max_index_simplify():
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
# const int bound
ck.verify(tvm.max(x % 2, y % 2 + 10), y % 2 + 10)
ck.verify(tvm.max(x + 1, x + 10), x + 10)
ck.verify(tvm.max(x + 111, x + 10), x + 111)
ck.verify(tvm.max(x + 1, x), x + 1)
ck.verify(tvm.max(x, x + 2), x + 2)
ck.verify(tvm.max(1 - x, 2 - x), 2 - x)
ck.verify(tvm.max(3 - x, 2 - x), 3 - x)
ck.verify(tvm.max((x + 3) / 4 * 4, x), (x + 3) / 4 * 4)
ck.verify(tvm.max(tvm.min(x, y), tvm.max(x, y)), tvm.max(x, y))
ck.verify(tvm.max(tvm.min(x, y), tvm.max(y, x)), tvm.max(x, y))
ck.verify(tvm.max(tvm.min(x, y), x), x)
ck.verify(tvm.max(tvm.min(y, x), x), x)
ck.verify(tvm.max(tvm.max(x, y), x), tvm.max(x, y))
ck.verify(tvm.max(tvm.max(x, y), y), tvm.max(x, y))
ck.verify(tvm.max(x, tvm.min(x, y)), x)
ck.verify(tvm.max(x, tvm.min(y, x)), x)
ck.verify(tvm.max(x, tvm.max(x, y)), tvm.max(x, y))
ck.verify(tvm.max(y, tvm.max(x, y)), tvm.max(x, y))
ck.verify(tvm.max(tvm.max(tvm.max(x, y), z), y),
tvm.max(tvm.max(x, y), z))
ck.verify(tvm.max(tvm.max(tvm.max(tvm.max(x, y), z), x * 2), y),
tvm.max(tvm.max(tvm.max(x, y), z), x * 2))
ck.verify(tvm.max(tvm.max(tvm.max(tvm.max(tvm.max(x, y), z), x * 2), z * 2), y),
tvm.max(tvm.max(tvm.max(tvm.max(x, y), z), x * 2), z * 2))
ck.verify(tvm.max(tvm.min(x, y), tvm.min(x, z)), tvm.min(tvm.max(y, z), x))
ck.verify(tvm.max(tvm.min(x, y), tvm.min(z, x)), tvm.min(tvm.max(y, z), x))
ck.verify(tvm.max(tvm.min(y, x), tvm.min(x, z)), tvm.min(tvm.max(y, z), x))
ck.verify(tvm.max(tvm.min(y, x), tvm.min(z, x)), tvm.min(tvm.max(y, z), x))
ck.verify(tvm.max(y + x, z + x), tvm.max(y, z) + x)
ck.verify(tvm.max(y + x, x + z), tvm.max(y, z) + x)
ck.verify(tvm.max(x + y, z + x), tvm.max(y, z) + x)
ck.verify(tvm.max(x + y, x + z), tvm.max(y, z) + x)
ck.verify(tvm.max(x - y, x - z), x - tvm.min(y, z))
ck.verify(tvm.max(y - x, z - x), tvm.max(y, z) - x)
ck.verify(tvm.max(tvm.max(x, 1), 10), tvm.max(x, 10))
ck.verify(tvm.max(tvm.max(x, 11), 10), tvm.max(x, 11))
ck.verify(tvm.max(x / 10, y / 10), tvm.max(x, y) / 10)
ck.verify(tvm.max(x / (-10), y / (-10)), tvm.min(x, y) / (-10))
ck.verify(tvm.max(x * 3, 9), tvm.max(x, 3) * 3)
def test_cmp_simplify():
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
# const int bound
ck.verify((x % 2 + 10).equal(0), tvm.const(0, "bool"))
ck.verify(tvm.expr.NE(x % 2 + 10, 0), tvm.const(1, "bool"))
ck.verify(x % 2 + 10 > 1, tvm.const(1, "bool"))
ck.verify(x % 2 + 10 <= 1, tvm.const(0, "bool"))
ck.verify(x * 3 + 10 == 0, tvm.const(0, "bool"))
ck.verify(x * 3 + 10 != 0, tvm.const(1, "bool"))
# canonicalization
ck.verify((x - 10).equal(0), x.equal(10))
ck.verify((10 - x).equal(0), x.equal(10))
ck.verify((x * y).equal(0), tvm.expr.Or(x.equal(0), y.equal(0)))
# cmp bound
ck.verify(x + y < x + z, y < z)
ck.verify(x + y < z + x, y < z)
ck.verify(y + x < x + z, y < z)
ck.verify(y + x < z + x, y < z)
ck.verify(y - x < z - x, y < z)
ck.verify(x - y < x - z, z < y)
ck.verify(x < z + x, tvm.expr.LT(0, z))
ck.verify(x < x + z, tvm.expr.LT(0, z))
ck.verify(100 < x + 1, tvm.expr.LT(99, x))
ck.verify(1 < 100 - x, tvm.expr.LT(x, 99))
ck.verify(x * 3 < y * 3, x < y)
ck.verify(x * (-3) < y * (-3), y < x)
ck.verify(x * 3 >= y * 3, y <= x)
ck.verify(x * 4 >= 2, tvm.expr.LE(1, x))
ck.verify(x * 2 >= 50, tvm.expr.LE(25, x))
ck.verify(x / 2 < 3, x < 6)
ck.verify(x * 4 <= 2, x <= 0)
ck.verify(3 < x / 2, tvm.expr.LT(7, x))
ck.verify(x / 4 * 4 < x, tvm.expr.LT(0, x % 4))
ck.verify(x / 4 * 4 >= x, tvm.expr.LE(x % 4, 0))
ck.verify(x / 4 * 4 < x + y, tvm.expr.LT(0, x % 4 + y))
ck.verify(x / 4 * 4 < x - y, tvm.expr.LT(y, x % 4))
ck.verify((x + 2) / 4 * 4 >= x, tvm.expr.LE((x + 2) % 4, 2))
ck.verify((x + 2) / 4 * 4 >= x + y, tvm.expr.LE((x + 2) % 4 + y, 2))
ck.verify((x + 2) / 4 * 4 >= x - y, tvm.expr.LE((x + 2) % 4 + (-2), y))
ck.verify(tvm.min(x, 11) < 10, x < 10)
ck.verify(tvm.min(x, 8) < 10, tvm.const(1, "bool"))
ck.verify(tvm.max(8, x) > 10, tvm.expr.LT(10, x))
ck.verify(x + 1 < tvm.max(8, x), x < 7)
def test_logical_simplify():
ck = RewriteChecker()
x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
ck.verify(tvm.expr.And(tvm.expr.EQ(x, y), tvm.expr.NE(x, y)),
tvm.const(False, "bool"))
ck.verify(tvm.expr.And(tvm.expr.NE(x, y), tvm.expr.EQ(x, y)),
tvm.const(False, "bool"))
ck.verify(tvm.expr.And(x > 1, tvm.expr.Not(x > 1)), tvm.const(False, "bool"))
ck.verify(tvm.expr.And(x <= y, y < x), tvm.const(False, "bool"))
ck.verify(tvm.expr.And(y < x, y <= x), tvm.const(False, "bool"))
ck.verify(tvm.expr.And(x < 1, 0 < x), tvm.const(False, "bool"))
ck.verify(tvm.expr.And(x < 0, 1 < x), tvm.const(False, "bool"))
ck.verify(tvm.expr.And(x < 1, 1 <= x), tvm.const(False, "bool"))
ck.verify(tvm.expr.And(x <= 1, 1 < x), tvm.const(False, "bool"))
ck.verify(tvm.expr.And(1 <= x, x < 1), tvm.const(False, "bool"))
ck.verify(tvm.expr.And(1 < x, x <= 1), tvm.const(False, "bool"))
ck.verify(tvm.expr.And(x <= 1, 2 <= x), tvm.const(False, "bool"))
ck.verify(tvm.expr.And(2 <= x, x <= 1), tvm.const(False, "bool"))
ck.verify(tvm.expr.And(x == 1, x != 2), x == 1)
ck.verify(tvm.expr.Or(tvm.expr.EQ(x, y), tvm.expr.NE(x, y)),
tvm.const(True, "bool"))
ck.verify(tvm.expr.Or(tvm.expr.NE(x, y), tvm.expr.EQ(x, y)),
tvm.const(True, "bool"))
ck.verify(tvm.expr.Or(x > y, tvm.expr.Not(x < y)), tvm.const(True, "bool"))
ck.verify(tvm.expr.Or(x <= y, y < x), tvm.const(True, "bool"))
ck.verify(tvm.expr.Or(y < x, y <= x), tvm.const(True, "bool"))
ck.verify(tvm.expr.Or(x < 1, 0 < x), tvm.const(True, "bool"))
ck.verify(tvm.expr.Or(0 < x, x < 1), tvm.const(True, "bool"))
ck.verify(tvm.expr.Or(x < 1, 1 <= x), tvm.const(True, "bool"))
ck.verify(tvm.expr.Or(x <= 1, 1 < x), tvm.const(True, "bool"))
ck.verify(tvm.expr.Or(1 <= x, x < 1), tvm.const(True, "bool"))
ck.verify(tvm.expr.Or(1 < x, x <= 1), tvm.const(True, "bool"))
ck.verify(tvm.expr.Or(x <= 1, 2 <= x), tvm.const(True, "bool"))
ck.verify(tvm.expr.Or(2 <= x, x <= 1), tvm.const(True, "bool"))
ck.verify(tvm.expr.Or(x != 1, x == 2), x != 1)
if __name__ == "__main__":
test_mod_index_simplify()
test_cmp_simplify()
test_vector_simplify()
test_add_index_simplify()
test_sub_index_simplify()
test_mul_index_simplify()
test_div_index_simplify()
test_max_index_simplify()
test_min_index_simplify()
test_mod_index_simplify()
test_select_simplify()
test_logical_simplify()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment