Unverified Commit 0c38b916 by ANSHUMAN TRIPATHY Committed by GitHub

Duplicate likely nodes added when loop axis split unevenly (#5084)

* [TE][Schedule] Duplicate likely nodes removed

* [1] Test case added

* [2] Lint error fixed

* [3] Review comments handled

* [4] Review comments handled
parent 3aabbd9c
...@@ -556,6 +556,14 @@ void PassUpBoundCheck(const Stage& s, ...@@ -556,6 +556,14 @@ void PassUpBoundCheck(const Stage& s,
} }
} }
bool IsRangeSame(const Range input_1, const Range input_2) {
arith::Analyzer analyzer;
if (input_1.same_as(input_2)) return true;
return (analyzer.CanProve(input_1->min == input_2->min)
&& analyzer.CanProve(input_1->extent == input_2->extent));
}
std::vector<PrimExpr> MakeBoundCheck( std::vector<PrimExpr> MakeBoundCheck(
const Stage& stage, const Stage& stage,
const Map<IterVar, Range>& dom_map, const Map<IterVar, Range>& dom_map,
...@@ -593,7 +601,7 @@ std::vector<PrimExpr> MakeBoundCheck( ...@@ -593,7 +601,7 @@ std::vector<PrimExpr> MakeBoundCheck(
if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue; if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
Range dom = dom_map.at(iv); Range dom = dom_map.at(iv);
CHECK(iv->dom.defined()); CHECK(iv->dom.defined());
if (!skip_ivar_domain && !iv->dom.same_as(dom)) { if (!skip_ivar_domain && !IsRangeSame(iv->dom, dom)) {
PrimExpr value = value_map.at(iv) - iv->dom->min; PrimExpr value = value_map.at(iv) - iv->dom->min;
IntSet s = EvalSet(value, iset_dmap); IntSet s = EvalSet(value, iset_dmap);
PrimExpr vmin = s.min(); PrimExpr vmin = s.min();
......
...@@ -40,6 +40,19 @@ def test_dependent_output_shape(): ...@@ -40,6 +40,19 @@ def test_dependent_output_shape():
s = te.create_schedule(B.op) s = te.create_schedule(B.op)
mod = tvm.build(s, [A, B, x]) mod = tvm.build(s, [A, B, x])
def test_split_uneven_unique_likely():
a = te.placeholder((16, 16),)
b = te.placeholder((16, 16),)
c = te.compute((16, 16), lambda x, y: a[x, y] + b[x, y])
x, y = c.op.axis
sch = te.create_schedule(c.op)
xo, xi = sch[c].split(x, 5)
stmt = tvm.lower(sch, [a, b, c], simple_mode=True)
assert isinstance(stmt.body.body.body.body, tvm.tir.stmt.IfThenElse)
assert str(stmt.body.body.body.body).count("likely") == 1
if __name__ == "__main__": if __name__ == "__main__":
test_lower_rfactor() test_lower_rfactor()
test_dependent_output_shape() test_dependent_output_shape()
test_split_uneven_unique_likely()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment