Unverified Commit 6d32037c by Tianqi Chen Committed by GitHub

[SCHEDULE] Fix code lowering when loop condition depends on outer axis. (#2208)

parent 94acff30
......@@ -321,27 +321,32 @@ Stmt MakeComputeStmt(const ComputeOpNode* self,
source.push_back(stage->op.output(i));
}
MakeReduction(self, source, &init, &provide);
init = op::Substitute(init, n.init_vmap);
init = MergeNest(n.init_nest, init);
init = op::Substitute(init, n.init_vmap);
// common nest
std::vector<std::vector<Stmt> > common(
n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
std::vector<std::vector<Stmt> > reduce(
n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
provide = op::Substitute(provide, n.main_vmap);
provide = MergeNest(reduce, provide);
if (debug_keep_trivial_loop) {
return MergeNest(common, provide);
provide = MergeNest(common, provide);
} else {
return MergeNest(common, Block::make(init, provide));
provide = MergeNest(common, Block::make(init, provide));
}
// run substitution in the on the full nest, because loop condition
// could depend on outer loops.
return op::Substitute(provide, n.main_vmap);
} else {
std::vector<Stmt> provides;
for (size_t i = 0; i < self->body.size(); ++i) {
provides.emplace_back(MakeProvide(self, stage->op.output(i)));
}
Stmt provide = op::Substitute(Block::make(provides), n.main_vmap);
return MergeNest(n.main_nest, provide);
Stmt provide = Block::make(provides);
provide = MergeNest(n.main_nest, provide);
// run substitution in the on the full nest, because loop condition
// could depend on outer loops.
return op::Substitute(provide, n.main_vmap);
}
}
......
......@@ -409,7 +409,18 @@ def test_schedule_tensor_compute3():
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_loop_dep_reduce():
X = tvm.placeholder(shape=(10,), name="x")
def f(n):
rv = tvm.reduce_axis((0, n))
return tvm.sum(X[rv], axis=rv)
Y = tvm.compute(X.shape, f, name="y")
s = tvm.create_schedule([Y.op])
f = tvm.build(s, [X, Y])
if __name__ == "__main__":
test_loop_dep_reduce()
test_schedule_middle_cache()
test_inline_multi_reduce()
test_schedule_cache_relayout4()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment