Commit 9d6dbe34 by xqdan Committed by Tianqi Chen

[SCHEDULE]Improve bound deduce for loop partition (#743) (#755)

* [SCHEDULE]enable partition const loop with build flag (#719)

    * enable partition loop with build flag

    * add a testcase, and modify LoopPartition related cases

*     * add document for split_const_loop

* [IRbuild]Support automatically Name Loop Variable in IRBuilder (#719)

    * add idx_num in class

* using typical index [i, j, k] first, then i_suffix

* keep inputs names

* fix lint

* improve comment of name

* fix lint

* [SCHEDULE]Improve bound deduce for loop partition (#743)

    * add divided checking when deducing

    * related testcase

* fix

* * transform LE and GE first
* remove is_equal
* modify testcase for edge cases checking

* * fix comment

* * fix lint

* * apply transformation form LT -> LE, GT -> GE

* * fix lint

* simplify code and testcase

* add negative co-efficient case

* More complicated cases

* add testcase

* simplify testcase

* comment case for now

* fix testcase
parent 48240ef6
...@@ -128,13 +128,25 @@ class BoundDeducer: public IRVisitor { ...@@ -128,13 +128,25 @@ class BoundDeducer: public IRVisitor {
} }
// always use relax bound // always use relax bound
result = result / operand + (is_greater ? 1 : -1); bool divided = can_prove(result % operand == 0);
result = result / operand;
// since system will round down when not divided
// eg. 2/4 -> 0; -2/4 -> -1
// no need fix for !is_greater:
// eg. a <= 2/4 -> a <= 0
// eg. a <= 0/4 -> a <= 0
// so just fix for not divided and is_greater
// eg. a >= 2/4 -> a >= 0 + 1
// eg. a >= 0/4 -> a >= 0
if (is_greater && !divided) {
result += 1;
}
Visit(left ? op->a : op->b); Visit(left ? op->a : op->b);
} }
Expr result; Expr result;
bool is_greater{true}; bool is_greater{true};
bool is_equal{true};
bool success{true}; bool success{true};
private: private:
...@@ -178,22 +190,20 @@ void BoundDeducer::Init() { ...@@ -178,22 +190,20 @@ void BoundDeducer::Init() {
void BoundDeducer::Transform() { void BoundDeducer::Transform() {
if (const LT* op = expr_.as<LT>()) { if (const LT* op = expr_.as<LT>()) {
is_greater = false; is_greater = false;
is_equal = false;
expr_ = op->a; expr_ = op->a;
result = op->b; // a < b -> a <= b - 1
result = op->b - 1;
} else if (const LE* op = expr_.as<LE>()) { } else if (const LE* op = expr_.as<LE>()) {
is_greater = false; is_greater = false;
is_equal = true;
expr_ = op->a; expr_ = op->a;
result = op->b; result = op->b;
} else if (const GT* op = expr_.as<GT>()) { } else if (const GT* op = expr_.as<GT>()) {
is_greater = true; is_greater = true;
is_equal = false;
expr_ = op->a; expr_ = op->a;
result = op->b; // a > b -> a >= b + 1
result = op->b + 1;
} else if (const GE* op = expr_.as<GE>()) { } else if (const GE* op = expr_.as<GE>()) {
is_greater = true; is_greater = true;
is_equal = true;
expr_ = op->a; expr_ = op->a;
result = op->b; result = op->b;
} else { } else {
...@@ -237,9 +247,9 @@ IntSet DeduceBound(Expr v, Expr e, ...@@ -237,9 +247,9 @@ IntSet DeduceBound(Expr v, Expr e,
if (!d.success) return IntSet::nothing(); if (!d.success) return IntSet::nothing();
Expr min = Interval::neg_inf, max = Interval::pos_inf; Expr min = Interval::neg_inf, max = Interval::pos_inf;
if (d.is_greater) { if (d.is_greater) {
min = d.is_equal ? d.result : d.result + 1; min = d.result;
} else { } else {
max = d.is_equal ? d.result : d.result - 1; max = d.result;
} }
return IntSet::interval(min, max); return IntSet::interval(min, max);
} }
......
...@@ -25,12 +25,17 @@ def test_deduce(): ...@@ -25,12 +25,17 @@ def test_deduce():
e0 = (-b)*a+c-d e0 = (-b)*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = (d-c)/(-b)+(-1) ans0 = ((d - c) /(b*-1))
assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)
e0 = d*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = ((0-c)/d + 1)
assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)
e1 = (a*4+b < c) e1 = (a*4+b < c)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
ans1 = (c-b)/4+(-2) ans1 = (((c - b) + -1)/4)
assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1) assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)
e2 = (tvm.max(5, a * 4) < 0) e2 = (tvm.max(5, a * 4) < 0)
...@@ -59,14 +64,77 @@ def test_check(): ...@@ -59,14 +64,77 @@ def test_check():
# multiple compare operators # multiple compare operators
res2 = tvm.arith.DeduceBound(a, (a+b>3)>c , {b: b_s, c: c_s}, {}) res2 = tvm.arith.DeduceBound(a, (a+b>3)>c , {b: b_s, c: c_s}, {})
assert res1.is_nothing() assert res2.is_nothing()
# multiple target variable # multiple target variable
res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {}) res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {})
assert res1.is_nothing() assert res2.is_nothing()
def test_deduce_basic():
def test_basic(a1, a2, coff):
a = tvm.var('a')
b = tvm.var('b')
b_s = tvm.arith.intset_interval(a1, a2)
e0 = b + a*coff + 3
res1 = tvm.arith.DeduceBound(a, e0<17, {b: b_s}, {b: b_s})
[x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1
res1 = tvm.arith.DeduceBound(a, e0>17, {b: b_s}, {b: b_s})
[x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1
res1 = tvm.arith.DeduceBound(a, e0<=17, {b: b_s}, {b: b_s})
[x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1
res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s})
[x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1
test_basic(0, 4, 4)
test_basic(1, 5, 4)
test_basic(2, 6, 4)
test_basic(0, 4, -4)
test_basic(1, 5, -4)
test_basic(2, 6, -4)
def test_deduce_complex():
def test_complex(a1, a2, coff):
a = tvm.var('a')
b = tvm.var('b')
b_s = tvm.arith.intset_interval(a1, a2)
e0 = (b*3 + a* coff) * 4
res1 = tvm.arith.DeduceBound(a, e0<63, {b: b_s}, {b: b_s})
[t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1
res1 = tvm.arith.DeduceBound(a, e0<=63, {b: b_s}, {b: b_s})
[t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1
res1 = tvm.arith.DeduceBound(a, e0>63, {b: b_s}, {b: b_s})
[t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1
res1 = tvm.arith.DeduceBound(a, e0>=63, {b: b_s}, {b: b_s})
[t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1
test_complex(0, 4, 4)
test_complex(0, 4, -4)
test_complex(2, 6, 4)
test_complex(0, 4, -4)
test_complex(1, 5, -4)
test_complex(2, 6, -4)
if __name__ == "__main__": if __name__ == "__main__":
test_basic() test_basic()
test_vector() test_vector()
test_deduce() test_deduce()
test_check() test_check()
test_deduce_basic()
test_deduce_complex()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment