Commit a5eb4451 by Salem Derisavi Committed by Yizhi Liu

Enhanced simplification rules for Div by a positive constant (#2346)

* Enhanced simplification rules for Div by a positive constant

* Fixed my last commit to correctly interpret TVM's division as truncated division

* Fixed implemenation of IntSet::can_prove_non_positive()

* addressed comments by @yzhliu

* addressed comments by @sgrechanik-h

* addressed more comments by @yzhliu
parent 6e5c65ac
......@@ -69,6 +69,10 @@ class IntSet : public NodeRef {
bool can_prove_positive() const;
/*! \return Whether the set is proved to be smaller than 0 */
bool can_prove_negative() const;
/*! \return Whether the set is proved to be smaller than or equal to 0 */
bool can_prove_non_positive() const;
/*! \return Whether the set is proved to be larger than or equal to 0 */
bool can_prove_non_negative() const;
/*! \return The sign of the elements in the integer set */
SignType sign_type() const;
/*!
......
......@@ -481,41 +481,76 @@ class Canonical::Internal : public IRMutator {
}
return value;
}
// Detect if a = x * coeff + y, where y \in [0, coeff), x >= 0
// return true if such detection is successful
// return false if it is not.
// Detect if a = q * coeff + r, where r \in [0, coeff), coeff > 0
// (in Euclidean division)
// returns pair (q, r) if such detection is successful
// returns empty vector otherwise.
// Assumes that coeff is a constant integer
std::vector<ComExpr> TryLinearEquation(const ComExpr& a,
const Expr& coeff) {
Type type = coeff.type();
int64_t value = GetConstIntValue(coeff);
CHECK_NE(value, 0);
if (value < 0) return {};
auto xnode = make_node<ComExprNode>();
auto ynode = make_node<ComExprNode>();
// Given that denominator (value variable) is positive, truncated division
// (i.e., TVM's division semantics) is equivalent to Euclidean division if and only if
// numerator is non-negative or numerator is divisible by denominator (i.e., value)
IntSet numerator_int_set = EvalSet(Sum2Expr(a, type), var_range_);
bool numerator_is_non_neg = numerator_int_set.can_prove_non_negative();
// Try to separate terms of a into ones that can be proven to be
// divisible by coeff and ones that are not
// We will build q and r from divisible and non_divisible respectively
auto divisible = make_node<ComExprNode>();
auto non_divisible = make_node<ComExprNode>();
if (a->base % value == 0) {
xnode->base = a->base;
divisible->base = a->base;
} else {
ynode->base = a->base;
non_divisible->base = a->base;
}
for (const auto& e : a->elem) {
if (e.scale % value == 0) {
xnode->elem.push_back(e);
divisible->elem.push_back(e);
} else {
ynode->elem.push_back(e);
non_divisible->elem.push_back(e);
}
}
Expr yres = Sum2Expr(ComExpr(ynode), type);
IntSet yset = EvalSet(yres, var_range_);
// This relies on the integer division rounds down
// Most cases it is good for integer division.
if (yset.min().type() == type &&
can_prove(yset.min() >= make_zero(type)) &&
yset.max().type() == type &&
can_prove(yset.max() < coeff)) {
xnode->base /= value;
for (auto &e : xnode->elem) {
bool non_divisible_is_simplified = false;
int64_t div_result;
Expr non_divisible_res = Sum2Expr(ComExpr(non_divisible), type);
// if non_divisible part consists of only an integer and numerator is non-negative,
// we can simply divide it by coeff
if (is_const(non_divisible_res)) {
int64_t non_divisible_const = GetConstIntValue(non_divisible_res);
if (numerator_is_non_neg || non_divisible_const == 0) {
non_divisible_is_simplified = true;
// We need to do an Euclidean division here because (a*b + c)/b == a + c/b
// holds true only if division is Euclidean
div_result = HalideIR::Internal::div_imp(non_divisible_const , value);
}
} else {
// If we can prove that non_divisible part lies within [0, coeff), then
// non_divisible itself will be our r
IntSet non_divisible_set = EvalSet(non_divisible_res, var_range_);
if (non_divisible_set.min().type() == type &&
non_divisible_set.max().type() == type) {
if ( (non_divisible_set.is_single_point() &&
can_prove(non_divisible_set.point_value() == 0)) ||
(numerator_is_non_neg &&
can_prove(non_divisible_set.min() >= make_zero(type)) &&
can_prove(non_divisible_set.max() < coeff)) ) {
non_divisible_is_simplified = true;
div_result = 0;
}
}
}
if (non_divisible_is_simplified) {
non_divisible->base -= div_result * value;
divisible->base /= value;
divisible->base += div_result;
for (auto& e : divisible->elem) {
e.scale /= value;
}
return {ComExpr(xnode), ComExpr(ynode)};
return {ComExpr(divisible), ComExpr(non_divisible)};
} else {
return {};
}
......@@ -526,6 +561,12 @@ class Canonical::Internal : public IRMutator {
if (pair.size() == 0) {
int64_t value = GetConstIntValue(v);
auto n = make_node<ComExprNode>();
// FIXME(derisavi) : The following can be done only for Euclidean division/mod.
// Therefore, it's only valid when truncated division/mod is equivalent to Euclidean one,
// that is, if and only if a and v are
// both negative or both positive or a is divisible by v.
// Extend the code to handle cases where the above condition is not satisfied, i.e.,
// a and v are of different signs and a is not divisible by v.
n->base = a->base % value;
for (auto e : a->elem) {
if (e.scale % value == 0) continue;
......
......@@ -84,6 +84,25 @@ bool IntSet::can_prove_negative() const {
return (s_int && is_negative_const(ir::Simplify(s_int->i.max)));
}
bool IntSet::can_prove_non_positive() const {
if (const IntervalSet* s_int = (*this).as<IntervalSet>()) {
auto max = ir::Simplify(s_int->i.max);
return is_zero(max) || is_negative_const(max);
}
return false;
}
bool IntSet::can_prove_non_negative() const {
if (const IntervalSet* s_int = (*this).as<IntervalSet>()) {
// Any reason why we should or should not use can_prove() to implement
// these functions?
auto min = ir::Simplify(s_int->i.min);
return is_zero(min) || is_positive_const(min);
}
return false;
}
SignType IntSet::sign_type() const {
if (can_prove_positive()) {
return kPositive;
......
......@@ -29,6 +29,24 @@ def test_simplify():
# assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / (-1)),
# tvm.ir_pass.CanonicalSimplify(-n))
def test_simplify_div():
x = tvm.var('x')
assert tvm.ir_pass.CanonicalSimplify((16+48*x)/16 - (1 + (x*3))).value == 0
# (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0
# (17+48*x)/16 != 1+3*x
r = tvm.ir_pass.CanonicalSimplify((17+48*x)/16)
assert r.b.value == 16
assert tvm.ir_pass.CanonicalSimplify(r.a - (17 + 48*x)).value == 0
# However, when x >= 0, then 17+48*x >= 0 and (17+48*x)/16 can be simplified
assert tvm.ir_pass.CanonicalSimplify((17+48*x)/16 - (1 + (x*3)), {x: tvm.Range(0,10)}).value == 0
# Trying expressions that are not simplifiable for any values of the variables
r = tvm.ir_pass.CanonicalSimplify((17+47*x)/16, {x: tvm.Range(0,10)})
assert r.b.value == 16
assert tvm.ir_pass.CanonicalSimplify(r.a - (17+47*x)).value == 0
r = tvm.ir_pass.CanonicalSimplify((8*x - 17)/8, {x : tvm.Range(4,10)})
assert tvm.ir_pass.CanonicalSimplify(r - (x-3)).value == 0
def test_simplify_mod():
"""Not yet working, mock design"""
......@@ -42,8 +60,12 @@ def test_simplify_mod():
stmt = tvm.ir_pass.CanonicalSimplify(body)
diff = tvm.ir_pass.CanonicalSimplify(stmt.body.value.index - (1 + i) % 16)
assert diff.value == 0
# if we can't prove that j+n*32 is non-negative, we can't prove that (j+n*32) % 16 is j%16
index = tvm.ir_pass.CanonicalSimplify(
(j + n * 32) % 16, {j: tvm.Range(0, 6)})
assert index != j
index = tvm.ir_pass.CanonicalSimplify(
(j + n * 32) % 16, {j: tvm.Range(0, 6), n: tvm.Range(0, 10)})
assert index == j
def test_simplify_minmax():
......@@ -79,6 +101,7 @@ def test_modular():
assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0
if __name__ == "__main__":
test_simplify_div()
test_simplify_mod()
test_modular()
test_simplify()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment