Commit d58b733a by Kimish Patel Committed by ziheng

[FIX] Fix for a specific case when loop partitioning with indivisible (#4243)

factors and the resulting nested loop is broken.
This happens because we create zero-extent loops, which are only fixed
up afterwards. However, the unroll pass breaks on the zero-extent
loop.
parent 2c5c4da6
...@@ -513,17 +513,19 @@ Stmt LoopPartitioner::TryPartition(const Node* node, ...@@ -513,17 +513,19 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
bool pre_stmt_recurse = true; bool pre_stmt_recurse = true;
if (middle_interval_i->HasLowerBound()) { if (middle_interval_i->HasLowerBound()) {
body_begin = ir::Simplify(middle_interval.min()); body_begin = ir::Simplify(middle_interval.min());
Expr cond = (body_begin - min >= 0); if (!analyzer_.CanProve(body_begin == min)) {
if (!analyzer_.CanProve(cond)) { Expr cond = (body_begin - min >= 0);
LOG(WARNING) << "Cannot prove: " << cond if (!analyzer_.CanProve(cond)) {
<< ", when generating the pre doubt loop"; LOG(WARNING) << "Cannot prove: " << cond
body_begin = Max::make(body_begin, min); << ", when generating the pre doubt loop";
// stop recursing on this interval if we can't prove it has non-negative length body_begin = Max::make(body_begin, min);
pre_stmt_recurse = false; // stop recursing on this interval if we can't prove it has non-negative length
} pre_stmt_recurse = false;
if (!partition_thread_scope) { }
Stmt pre_body = Substitute(body, {{Var{var}, var + min}}); if (!partition_thread_scope) {
pre_stmt = MakeFor(node, body_begin - min, pre_body); Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
pre_stmt = MakeFor(node, body_begin - min, pre_body);
}
} }
} else { } else {
body_begin = min; body_begin = min;
...@@ -536,19 +538,21 @@ Stmt LoopPartitioner::TryPartition(const Node* node, ...@@ -536,19 +538,21 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
bool post_stmt_recurse = true; bool post_stmt_recurse = true;
if (middle_interval_i->HasUpperBound()) { if (middle_interval_i->HasUpperBound()) {
post_doubt_begin = ir::Simplify(middle_interval.max() + 1); post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
// require the extent to be non-negative if (!analyzer_.CanProve(middle_interval.max() == max)) {
Expr cond = (max - post_doubt_begin + 1 >= 0); // require the extent to be non-negative
if (!analyzer_.CanProve(cond)) { Expr cond = (max - post_doubt_begin + 1 >= 0);
LOG(WARNING) << "Cannot prove: " << cond if (!analyzer_.CanProve(cond)) {
<< ", when generating the post doubt loop"; LOG(WARNING) << "Cannot prove: " << cond
post_doubt_begin = Min::make(post_doubt_begin, max+1); << ", when generating the post doubt loop";
// stop recursing on this interval if we can't prove it has non-negative length post_doubt_begin = Min::make(post_doubt_begin, max+1);
post_stmt_recurse = false; // stop recursing on this interval if we can't prove it has non-negative length
} post_stmt_recurse = false;
if (!partition_thread_scope) { }
Stmt post_body = if (!partition_thread_scope) {
Substitute(body, {{Var{var}, var + post_doubt_begin}}); Stmt post_body =
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body); Substitute(body, {{Var{var}, var + post_doubt_begin}});
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
}
} }
} else { } else {
post_doubt_begin = max + 1; post_doubt_begin = max + 1;
......
...@@ -365,6 +365,27 @@ def test_conv_tiling(): ...@@ -365,6 +365,27 @@ def test_conv_tiling():
stmt = tvm.ir_pass.Simplify(stmt) stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse)))) assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
def test_multilevel_splitting_with_indivisble_factors():
    """Lower a two-level split with indivisible factors plus an unrolled inner axis.

    A 130-element ReLU is split by 8 and the outer axis is split again by 16;
    neither factor divides its extent evenly. The schedule is lowered with
    ``partition_const_loop=True`` and the lowered body is expected to contain
    exactly 10 ``Max`` expressions.
    """
    import topi

    data = tvm.placeholder((130,), dtype="float32")
    out = topi.nn.relu(data)
    sched = tvm.create_schedule(out.op)
    stage = sched[out]

    # Split twice, keep the loop order outer-outer -> outer-inner -> inner,
    # and unroll the innermost axis.
    (axis,) = stage.op.axis
    outer, inner = stage.split(axis, factor=8)
    outer_outer, outer_inner = stage.split(outer, factor=16)
    stage.reorder(outer_outer, outer_inner, inner)
    stage.unroll(inner)

    # Lowering must succeed with constant-loop partitioning enabled.
    with tvm.build_config(partition_const_loop=True):
        lowered_body = tvm.lower(sched, [data, out]).body
        max_hits = collect_visit(
            lowered_body, lambda node: isinstance(node, tvm.expr.Max))
        assert max_hits.count(True) == 10
def test_double_splitting_with_indivisible_factors(): def test_double_splitting_with_indivisible_factors():
m = 48 m = 48
dtype="float32" dtype="float32"
...@@ -443,4 +464,5 @@ if __name__ == "__main__": ...@@ -443,4 +464,5 @@ if __name__ == "__main__":
test_cce_loop_3() test_cce_loop_3()
test_conv_tiling() test_conv_tiling()
test_double_splitting_with_indivisible_factors() test_double_splitting_with_indivisible_factors()
test_multilevel_splitting_with_indivisble_factors()
test_simple_rfactor() test_simple_rfactor()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment