Commit a5eb4451 by Salem Derisavi Committed by Yizhi Liu

Enhanced simplification rules for Div by a positive constant (#2346)

* Enhanced simplification rules for Div by a positive constant

* Fixed my last commit to correctly interpret TVM's division as truncated division

* Fixed implementation of IntSet::can_prove_non_positive()

* addressed comments by @yzhliu

* addressed comments by @sgrechanik-h

* addressed more comments by @yzhliu
parent 6e5c65ac
...@@ -69,6 +69,10 @@ class IntSet : public NodeRef { ...@@ -69,6 +69,10 @@ class IntSet : public NodeRef {
bool can_prove_positive() const; bool can_prove_positive() const;
/*! \return Whether the set is proved to be smaller than 0 */ /*! \return Whether the set is proved to be smaller than 0 */
bool can_prove_negative() const; bool can_prove_negative() const;
/*! \return Whether the set is proved to be smaller than or equal to 0 */
bool can_prove_non_positive() const;
/*! \return Whether the set is proved to be larger than or equal to 0 */
bool can_prove_non_negative() const;
/*! \return The sign of the elements in the integer set */ /*! \return The sign of the elements in the integer set */
SignType sign_type() const; SignType sign_type() const;
/*! /*!
......
...@@ -481,41 +481,76 @@ class Canonical::Internal : public IRMutator { ...@@ -481,41 +481,76 @@ class Canonical::Internal : public IRMutator {
} }
return value; return value;
} }
// Detect if a = x * coeff + y, where y \in [0, coeff), x >= 0 // Detect if a = q * coeff + r, where r \in [0, coeff), coeff > 0
// return true if such detection is successful // (in Euclidean division)
// return false if it is not. // returns pair (q, r) if such detection is successful
// returns empty vector otherwise.
// Assumes that coeff is a constant integer
std::vector<ComExpr> TryLinearEquation(const ComExpr& a, std::vector<ComExpr> TryLinearEquation(const ComExpr& a,
const Expr& coeff) { const Expr& coeff) {
Type type = coeff.type(); Type type = coeff.type();
int64_t value = GetConstIntValue(coeff); int64_t value = GetConstIntValue(coeff);
CHECK_NE(value, 0);
if (value < 0) return {}; if (value < 0) return {};
auto xnode = make_node<ComExprNode>(); // Given that denominator (value variable) is positive, truncated division
auto ynode = make_node<ComExprNode>(); // (i.e., TVM's division semantics) is equivalent to Euclidean division if and only if
// numerator is non-negative or numerator is divisible by denominator (i.e., value)
IntSet numerator_int_set = EvalSet(Sum2Expr(a, type), var_range_);
bool numerator_is_non_neg = numerator_int_set.can_prove_non_negative();
// Try to separate terms of a into ones that can be proven to be
// divisible by coeff and ones that are not
// We will build q and r from divisible and non_divisible respectively
auto divisible = make_node<ComExprNode>();
auto non_divisible = make_node<ComExprNode>();
if (a->base % value == 0) { if (a->base % value == 0) {
xnode->base = a->base; divisible->base = a->base;
} else { } else {
ynode->base = a->base; non_divisible->base = a->base;
} }
for (const auto& e : a->elem) { for (const auto& e : a->elem) {
if (e.scale % value == 0) { if (e.scale % value == 0) {
xnode->elem.push_back(e); divisible->elem.push_back(e);
} else { } else {
ynode->elem.push_back(e); non_divisible->elem.push_back(e);
} }
} }
Expr yres = Sum2Expr(ComExpr(ynode), type); bool non_divisible_is_simplified = false;
IntSet yset = EvalSet(yres, var_range_); int64_t div_result;
// This relies on the integer division rounds down Expr non_divisible_res = Sum2Expr(ComExpr(non_divisible), type);
// Most cases it is good for integer division. // if non_divisible part consists of only an integer and numerator is non-negative,
if (yset.min().type() == type && // we can simply divide it by coeff
can_prove(yset.min() >= make_zero(type)) && if (is_const(non_divisible_res)) {
yset.max().type() == type && int64_t non_divisible_const = GetConstIntValue(non_divisible_res);
can_prove(yset.max() < coeff)) { if (numerator_is_non_neg || non_divisible_const == 0) {
xnode->base /= value; non_divisible_is_simplified = true;
for (auto &e : xnode->elem) { // We need to do an Euclidean division here because (a*b + c)/b == a + c/b
// holds true only if division is Euclidean
div_result = HalideIR::Internal::div_imp(non_divisible_const , value);
}
} else {
// If we can prove that non_divisible part lies within [0, coeff), then
// non_divisible itself will be our r
IntSet non_divisible_set = EvalSet(non_divisible_res, var_range_);
if (non_divisible_set.min().type() == type &&
non_divisible_set.max().type() == type) {
if ( (non_divisible_set.is_single_point() &&
can_prove(non_divisible_set.point_value() == 0)) ||
(numerator_is_non_neg &&
can_prove(non_divisible_set.min() >= make_zero(type)) &&
can_prove(non_divisible_set.max() < coeff)) ) {
non_divisible_is_simplified = true;
div_result = 0;
}
}
}
if (non_divisible_is_simplified) {
non_divisible->base -= div_result * value;
divisible->base /= value;
divisible->base += div_result;
for (auto& e : divisible->elem) {
e.scale /= value; e.scale /= value;
} }
return {ComExpr(xnode), ComExpr(ynode)}; return {ComExpr(divisible), ComExpr(non_divisible)};
} else { } else {
return {}; return {};
} }
...@@ -526,6 +561,12 @@ class Canonical::Internal : public IRMutator { ...@@ -526,6 +561,12 @@ class Canonical::Internal : public IRMutator {
if (pair.size() == 0) { if (pair.size() == 0) {
int64_t value = GetConstIntValue(v); int64_t value = GetConstIntValue(v);
auto n = make_node<ComExprNode>(); auto n = make_node<ComExprNode>();
// FIXME(derisavi) : The following can be done only for Euclidean division/mod.
// Therefore, it's only valid when truncated division/mod is equivalent to Euclidean one,
// that is, if and only if a and v are
// both negative or both positive or a is divisible by v.
// Extend the code to handle cases where the above condition is not satisfied, i.e.,
// a and v are of different signs and a is not divisible by v.
n->base = a->base % value; n->base = a->base % value;
for (auto e : a->elem) { for (auto e : a->elem) {
if (e.scale % value == 0) continue; if (e.scale % value == 0) continue;
......
...@@ -84,6 +84,25 @@ bool IntSet::can_prove_negative() const { ...@@ -84,6 +84,25 @@ bool IntSet::can_prove_negative() const {
return (s_int && is_negative_const(ir::Simplify(s_int->i.max))); return (s_int && is_negative_const(ir::Simplify(s_int->i.max)));
} }
bool IntSet::can_prove_non_positive() const {
if (const IntervalSet* s_int = (*this).as<IntervalSet>()) {
auto max = ir::Simplify(s_int->i.max);
return is_zero(max) || is_negative_const(max);
}
return false;
}
bool IntSet::can_prove_non_negative() const {
if (const IntervalSet* s_int = (*this).as<IntervalSet>()) {
// Any reason why we should or should not use can_prove() to implement
// these functions?
auto min = ir::Simplify(s_int->i.min);
return is_zero(min) || is_positive_const(min);
}
return false;
}
SignType IntSet::sign_type() const { SignType IntSet::sign_type() const {
if (can_prove_positive()) { if (can_prove_positive()) {
return kPositive; return kPositive;
......
...@@ -29,6 +29,24 @@ def test_simplify(): ...@@ -29,6 +29,24 @@ def test_simplify():
# assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / (-1)), # assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / (-1)),
# tvm.ir_pass.CanonicalSimplify(-n)) # tvm.ir_pass.CanonicalSimplify(-n))
def test_simplify_div():
    """Check CanonicalSimplify on divisions by a positive constant."""
    simplify = tvm.ir_pass.CanonicalSimplify
    x = tvm.var('x')

    # Every term of the numerator is a multiple of 16, so the division is
    # exact and the quotient equals 1 + 3*x for any x.
    exact_diff = simplify((16+48*x)/16 - (1 + (x*3)))
    assert exact_diff.value == 0

    # With no range for x, 17+48*x may be negative; under truncated division
    # (17+48*x)/16 != 1+3*x in that case, so the division must be kept as is.
    r = simplify((17+48*x)/16)
    assert r.b.value == 16
    numerator_diff = simplify(r.a - (17 + 48*x))
    assert numerator_diff.value == 0

    # Once x >= 0 is known, 17+48*x >= 0 and the quotient becomes 1 + 3*x.
    bounded_diff = simplify((17+48*x)/16 - (1 + (x*3)), {x: tvm.Range(0,10)})
    assert bounded_diff.value == 0

    # 47 is not a multiple of 16: the expression cannot be simplified for any
    # value of x, even when x is known to be non-negative.
    r = simplify((17+47*x)/16, {x: tvm.Range(0,10)})
    assert r.b.value == 16
    numerator_diff = simplify(r.a - (17+47*x))
    assert numerator_diff.value == 0

    # x in [4, 10) keeps 8*x - 17 non-negative, so the negative constant can
    # be folded into the quotient: (8*x - 17)/8 == x - 3.
    r = simplify((8*x - 17)/8, {x : tvm.Range(4,10)})
    folded_diff = simplify(r - (x-3))
    assert folded_diff.value == 0
def test_simplify_mod(): def test_simplify_mod():
"""Not yet working, mock design""" """Not yet working, mock design"""
...@@ -42,8 +60,12 @@ def test_simplify_mod(): ...@@ -42,8 +60,12 @@ def test_simplify_mod():
stmt = tvm.ir_pass.CanonicalSimplify(body) stmt = tvm.ir_pass.CanonicalSimplify(body)
diff = tvm.ir_pass.CanonicalSimplify(stmt.body.value.index - (1 + i) % 16) diff = tvm.ir_pass.CanonicalSimplify(stmt.body.value.index - (1 + i) % 16)
assert diff.value == 0 assert diff.value == 0
# if we can't prove that j+n*32 is non-negative, we can't prove that (j+n*32) % 16 is j%16
index = tvm.ir_pass.CanonicalSimplify( index = tvm.ir_pass.CanonicalSimplify(
(j + n * 32) % 16, {j: tvm.Range(0, 6)}) (j + n * 32) % 16, {j: tvm.Range(0, 6)})
assert index != j
index = tvm.ir_pass.CanonicalSimplify(
(j + n * 32) % 16, {j: tvm.Range(0, 6), n: tvm.Range(0, 10)})
assert index == j assert index == j
def test_simplify_minmax(): def test_simplify_minmax():
...@@ -79,6 +101,7 @@ def test_modular(): ...@@ -79,6 +101,7 @@ def test_modular():
assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0 assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0
if __name__ == "__main__": if __name__ == "__main__":
test_simplify_div()
test_simplify_mod() test_simplify_mod()
test_modular() test_modular()
test_simplify() test_simplify()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment