Unverified Commit 6d32037c by Tianqi Chen Committed by GitHub

[SCHEDULE] Fix code lowering when loop condition depends on outer axis. (#2208)

parent 94acff30
...@@ -321,27 +321,32 @@ Stmt MakeComputeStmt(const ComputeOpNode* self, ...@@ -321,27 +321,32 @@ Stmt MakeComputeStmt(const ComputeOpNode* self,
source.push_back(stage->op.output(i)); source.push_back(stage->op.output(i));
} }
MakeReduction(self, source, &init, &provide); MakeReduction(self, source, &init, &provide);
init = op::Substitute(init, n.init_vmap);
init = MergeNest(n.init_nest, init); init = MergeNest(n.init_nest, init);
init = op::Substitute(init, n.init_vmap);
// common nest // common nest
std::vector<std::vector<Stmt> > common( std::vector<std::vector<Stmt> > common(
n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
std::vector<std::vector<Stmt> > reduce( std::vector<std::vector<Stmt> > reduce(
n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end()); n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
provide = op::Substitute(provide, n.main_vmap);
provide = MergeNest(reduce, provide); provide = MergeNest(reduce, provide);
if (debug_keep_trivial_loop) { if (debug_keep_trivial_loop) {
return MergeNest(common, provide); provide = MergeNest(common, provide);
} else { } else {
return MergeNest(common, Block::make(init, provide)); provide = MergeNest(common, Block::make(init, provide));
} }
// run substitution in the on the full nest, because loop condition
// could depend on outer loops.
return op::Substitute(provide, n.main_vmap);
} else { } else {
std::vector<Stmt> provides; std::vector<Stmt> provides;
for (size_t i = 0; i < self->body.size(); ++i) { for (size_t i = 0; i < self->body.size(); ++i) {
provides.emplace_back(MakeProvide(self, stage->op.output(i))); provides.emplace_back(MakeProvide(self, stage->op.output(i)));
} }
Stmt provide = op::Substitute(Block::make(provides), n.main_vmap); Stmt provide = Block::make(provides);
return MergeNest(n.main_nest, provide); provide = MergeNest(n.main_nest, provide);
// run substitution in the on the full nest, because loop condition
// could depend on outer loops.
return op::Substitute(provide, n.main_vmap);
} }
} }
......
...@@ -409,7 +409,18 @@ def test_schedule_tensor_compute3(): ...@@ -409,7 +409,18 @@ def test_schedule_tensor_compute3():
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_loop_dep_reduce():
X = tvm.placeholder(shape=(10,), name="x")
def f(n):
rv = tvm.reduce_axis((0, n))
return tvm.sum(X[rv], axis=rv)
Y = tvm.compute(X.shape, f, name="y")
s = tvm.create_schedule([Y.op])
f = tvm.build(s, [X, Y])
if __name__ == "__main__": if __name__ == "__main__":
test_loop_dep_reduce()
test_schedule_middle_cache() test_schedule_middle_cache()
test_inline_multi_reduce() test_inline_multi_reduce()
test_schedule_cache_relayout4() test_schedule_cache_relayout4()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment