Commit d58b733a by Kimish Patel Committed by ziheng

[FIX] Fix for a specific case when loop partitioning with indivisble (#4243)

factors and resulting nested loop is broken.
This is due to the fact that we are creating zero extent loops which
are fixed afterwards. However unroll pass breaks due to the zero extent
loop.
parent 2c5c4da6
......@@ -513,17 +513,19 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
bool pre_stmt_recurse = true;
if (middle_interval_i->HasLowerBound()) {
body_begin = ir::Simplify(middle_interval.min());
Expr cond = (body_begin - min >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the pre doubt loop";
body_begin = Max::make(body_begin, min);
// stop recursing on this interval if we can't prove it has non-negative length
pre_stmt_recurse = false;
}
if (!partition_thread_scope) {
Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
pre_stmt = MakeFor(node, body_begin - min, pre_body);
if (!analyzer_.CanProve(body_begin == min)) {
Expr cond = (body_begin - min >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the pre doubt loop";
body_begin = Max::make(body_begin, min);
// stop recursing on this interval if we can't prove it has non-negative length
pre_stmt_recurse = false;
}
if (!partition_thread_scope) {
Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
pre_stmt = MakeFor(node, body_begin - min, pre_body);
}
}
} else {
body_begin = min;
......@@ -536,19 +538,21 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
bool post_stmt_recurse = true;
if (middle_interval_i->HasUpperBound()) {
post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
// require the extent to be non-negative
Expr cond = (max - post_doubt_begin + 1 >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the post doubt loop";
post_doubt_begin = Min::make(post_doubt_begin, max+1);
// stop recursing on this interval if we can't prove it has non-negative length
post_stmt_recurse = false;
}
if (!partition_thread_scope) {
Stmt post_body =
Substitute(body, {{Var{var}, var + post_doubt_begin}});
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
if (!analyzer_.CanProve(middle_interval.max() == max)) {
// require the extent to be non-negative
Expr cond = (max - post_doubt_begin + 1 >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the post doubt loop";
post_doubt_begin = Min::make(post_doubt_begin, max+1);
// stop recursing on this interval if we can't prove it has non-negative length
post_stmt_recurse = false;
}
if (!partition_thread_scope) {
Stmt post_body =
Substitute(body, {{Var{var}, var + post_doubt_begin}});
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
}
}
} else {
post_doubt_begin = max + 1;
......
......@@ -365,6 +365,27 @@ def test_conv_tiling():
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
def test_multilevel_splitting_with_indivisble_factors():
import topi
A = tvm.placeholder((130,), dtype="float32")
B = topi.nn.relu(A)
s = tvm.create_schedule(B.op)
(y,) = s[B].op.axis
(yo, yi) = s[B].split(y, factor=8)
(yoo, yoi) = s[B].split(yo, factor=16)
s[B].reorder(yoo, yoi, yi)
s[B].unroll(yi)
## But this does the right thing.
with tvm.build_config(partition_const_loop=True):
lowered_body = tvm.lower(s, [A, B]).body
def visit_stmt(op):
return(isinstance(op, tvm.expr.Max))
num_max = collect_visit(lowered_body, visit_stmt)
assert num_max.count(True) == 10
def test_double_splitting_with_indivisible_factors():
m = 48
dtype="float32"
......@@ -443,4 +464,5 @@ if __name__ == "__main__":
test_cce_loop_3()
test_conv_tiling()
test_double_splitting_with_indivisible_factors()
test_multilevel_splitting_with_indivisble_factors()
test_simple_rfactor()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment