Commit a7873b0a, authored by Umang Yadav and committed by ziheng

[RELAY/PASS] Fix the extent for the post_stmt in the loop partition (#3734)

parent 59cf5735
...@@ -492,7 +492,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, ...@@ -492,7 +492,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
std::tie(middle_interval, cond_set) = std::tie(middle_interval, cond_set) =
GetIntervalAndCondset(finder.partitions, for_interval, false); GetIntervalAndCondset(finder.partitions, for_interval, false);
if (middle_interval.is_nothing()) if (middle_interval.is_nothing())
// we couldn't find an interval in which the condintions are provably true or false // we couldn't find an interval in which the conditions are provably true or false
// Therefore, we can't partition the loop based on those conds // Therefore, we can't partition the loop based on those conds
return Stmt(); return Stmt();
cond_value = false; cond_value = false;
...@@ -513,46 +513,42 @@ Stmt LoopPartitioner::TryPartition(const Node* node, ...@@ -513,46 +513,42 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
bool pre_stmt_recurse = true; bool pre_stmt_recurse = true;
if (middle_interval_i->HasLowerBound()) { if (middle_interval_i->HasLowerBound()) {
body_begin = ir::Simplify(middle_interval.min()); body_begin = ir::Simplify(middle_interval.min());
if (!analyzer_.CanProve(body_begin == min)) { Expr cond = (body_begin - min >= 0);
Expr cond = (body_begin - min >= 0); if (!analyzer_.CanProve(cond)) {
if (!analyzer_.CanProve(cond)) { LOG(WARNING) << "Cannot prove: " << cond
LOG(WARNING) << "Cannot prove: " << cond << ", when generating the pre doubt loop";
<< ", when generating the pre doubt loop"; body_begin = Max::make(body_begin, min);
body_begin = Max::make(body_begin, min); // stop recursing on this interval if we can't prove it has non-negative length
// stop recursing on this interval if we can't prove it has non-negative length pre_stmt_recurse = false;
pre_stmt_recurse = false; }
} if (!partition_thread_scope) {
if (!partition_thread_scope) { Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
Stmt pre_body = Substitute(body, {{Var{var}, var + min}}); pre_stmt = MakeFor(node, body_begin - min, pre_body);
pre_stmt = MakeFor(node, body_begin - min, pre_body);
}
} }
} else { } else {
body_begin = min; body_begin = min;
} }
// Calculating post-subrange and generating code for it. // Calculating post-subrange and generating code for it.
// post-subrange = [post_doubt_begin, max] // post-subrange = [post_doubt_begin, max+1)
Expr post_doubt_begin; Expr post_doubt_begin;
Stmt post_stmt; Stmt post_stmt;
bool post_stmt_recurse = true; bool post_stmt_recurse = true;
if (middle_interval_i->HasUpperBound()) { if (middle_interval_i->HasUpperBound()) {
post_doubt_begin = ir::Simplify(middle_interval.max() + 1); post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
if (!analyzer_.CanProve(middle_interval.max() == max)) { // require the extent to be non-negative
// require the extent to be non-negative Expr cond = (max - post_doubt_begin + 1 >= 0);
Expr cond = (max - post_doubt_begin + 1 >= 0); if (!analyzer_.CanProve(cond)) {
if (!analyzer_.CanProve(cond)) { LOG(WARNING) << "Cannot prove: " << cond
LOG(WARNING) << "Cannot prove: " << cond << ", when generating the post doubt loop";
<< ", when generating the post doubt loop"; post_doubt_begin = Min::make(post_doubt_begin, max+1);
post_doubt_begin = Min::make(post_doubt_begin, max); // stop recursing on this interval if we can't prove it has non-negative length
// stop recursing on this interval if we can't prove it has non-negative length post_stmt_recurse = false;
post_stmt_recurse = false; }
} if (!partition_thread_scope) {
if (!partition_thread_scope) { Stmt post_body =
Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
Substitute(body, {{Var{var}, var + post_doubt_begin}}); post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
}
} }
} else { } else {
post_doubt_begin = max + 1; post_doubt_begin = max + 1;
......
...@@ -37,6 +37,7 @@ def lower(sch, args): ...@@ -37,6 +37,7 @@ def lower(sch, args):
bounds = tvm.schedule.InferBound(sch) bounds = tvm.schedule.InferBound(sch)
stmt = tvm.schedule.ScheduleOps(sch, bounds) stmt = tvm.schedule.ScheduleOps(sch, bounds)
stmt = tvm.ir_pass.LoopPartition(stmt, True) stmt = tvm.ir_pass.LoopPartition(stmt, True)
stmt = tvm.ir_pass.RemoveNoOp(stmt)
stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 64, True) stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 64, True)
stmt = tvm.ir_pass.CanonicalSimplify(stmt) stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.VectorizeLoop(stmt) stmt = tvm.ir_pass.VectorizeLoop(stmt)
......
...@@ -69,8 +69,16 @@ def test_ewise(): ...@@ -69,8 +69,16 @@ def test_ewise():
foo(a, b) foo(a, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
for device in get_all_backend(): check_device('llvm')
check_device(device) check_device('cuda')
check_device('opencl')
check_device('metal')
check_device('rocm')
check_device('vulkan')
check_device('nvptx')
check_device('llvm -device=arm-cpu')
check_device('opencl -device=mali')
check_device('aocl_sw_emu')
def test_isnan( def test_isnan(
low, low,
...@@ -109,8 +117,16 @@ def test_ewise(): ...@@ -109,8 +117,16 @@ def test_ewise():
foo(a, b) foo(a, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
for device in get_all_backend(): check_device('llvm')
check_device(device) check_device('cuda')
check_device('opencl')
check_device('metal')
check_device('rocm')
check_device('vulkan')
check_device('nvptx')
check_device('llvm -device=arm-cpu')
check_device('opencl -device=mali')
check_device('aocl_sw_emu')
test_apply(topi.floor, "floor", np.floor, -100, 100) test_apply(topi.floor, "floor", np.floor, -100, 100)
test_apply(topi.ceil, "ceil", np.ceil, -100, 100) test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment