Commit cfdc5119 by Lianmin Zheng Committed by Tianqi Chen

delete init part when keeping trivial loop (#1031)

parent ca9ec009
...@@ -117,13 +117,13 @@ class OperationNode : public FunctionBaseNode { ...@@ -117,13 +117,13 @@ class OperationNode : public FunctionBaseNode {
* \brief Build the statement that provide the output tensors. * \brief Build the statement that provide the output tensors.
* \param stage The schedule stage of the op. * \param stage The schedule stage of the op.
* \param dom_map The domain map of all iteration domains. * \param dom_map The domain map of all iteration domains.
* \param del_trivial_loop Whether eliminate trivial loop with extent of 1 * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
* \return A statement that add production and wraps consumer. * \return A statement that add production and wraps consumer.
*/ */
virtual Stmt BuildProvide( virtual Stmt BuildProvide(
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const = 0; bool debug_keep_trivial_loop) const = 0;
static constexpr const char* _type_key = "Operation"; static constexpr const char* _type_key = "Operation";
...@@ -163,7 +163,7 @@ class PlaceholderOpNode : public OperationNode { ...@@ -163,7 +163,7 @@ class PlaceholderOpNode : public OperationNode {
Stmt BuildProvide( Stmt BuildProvide(
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const final; bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name); v->Visit("name", &name);
...@@ -215,7 +215,7 @@ class ComputeOpNode : public OperationNode { ...@@ -215,7 +215,7 @@ class ComputeOpNode : public OperationNode {
Stmt BuildProvide( Stmt BuildProvide(
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const final; bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name); v->Visit("name", &name);
...@@ -287,7 +287,7 @@ class ScanOpNode : public OperationNode { ...@@ -287,7 +287,7 @@ class ScanOpNode : public OperationNode {
Stmt BuildProvide( Stmt BuildProvide(
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const final; bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name); v->Visit("name", &name);
...@@ -351,7 +351,7 @@ class ExternOpNode : public OperationNode { ...@@ -351,7 +351,7 @@ class ExternOpNode : public OperationNode {
Stmt BuildProvide( Stmt BuildProvide(
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const final; bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name); v->Visit("name", &name);
......
...@@ -29,10 +29,13 @@ Map<IterVar, Range> InferBound(const Schedule& sch); ...@@ -29,10 +29,13 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
* *
* \param s The schedule to be realized * \param s The schedule to be realized
* \param dom_map The domain of each iter vars. * \param dom_map The domain of each iter vars.
* \param del_trivial_loop Whether delete trivial loops with extent of 1 * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 during lowering.
* This is a debug feature for dataflow/axis analysis.
* Note: If this is true, The lowered IR may be incorrect,
* because we will also delete the init part of reduction
* \return the result Stmt * \return the result Stmt
*/ */
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool del_trivial_loop); Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop);
/*! /*!
* \brief To automatically inline the element-wise operations. * \brief To automatically inline the element-wise operations.
......
...@@ -27,7 +27,7 @@ TVM_REGISTER_API("schedule.AutoInlineInjective") ...@@ -27,7 +27,7 @@ TVM_REGISTER_API("schedule.AutoInlineInjective")
TVM_REGISTER_API("schedule.ScheduleOps") TVM_REGISTER_API("schedule.ScheduleOps")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 2) if (args.size() == 2)
*ret = ScheduleOps(args[0], args[1], true); *ret = ScheduleOps(args[0], args[1], false);
else else
*ret = ScheduleOps(args[0], args[1], args[2]); *ret = ScheduleOps(args[0], args[1], args[2]);
}); });
......
...@@ -349,7 +349,7 @@ Stmt BuildStmt(Schedule sch, ...@@ -349,7 +349,7 @@ Stmt BuildStmt(Schedule sch,
// Phase 0 // Phase 0
auto bounds = schedule::InferBound(sch); auto bounds = schedule::InferBound(sch);
auto stmt = schedule::ScheduleOps(sch, bounds, true); auto stmt = schedule::ScheduleOps(sch, bounds, false);
stmt = ir::InjectPrefetch(stmt); stmt = ir::InjectPrefetch(stmt);
// Phase 1 // Phase 1
......
...@@ -296,9 +296,9 @@ Stmt MakeProvide(const ComputeOpNode* op, ...@@ -296,9 +296,9 @@ Stmt MakeProvide(const ComputeOpNode* op,
Stmt MakeComputeStmt(const ComputeOpNode* self, Stmt MakeComputeStmt(const ComputeOpNode* self,
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) { bool debug_keep_trivial_loop) {
// grab the nest structure // grab the nest structure
ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, del_trivial_loop); ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
// Normal loop structure // Normal loop structure
n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates)); n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates)); n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates));
...@@ -319,7 +319,11 @@ Stmt MakeComputeStmt(const ComputeOpNode* self, ...@@ -319,7 +319,11 @@ Stmt MakeComputeStmt(const ComputeOpNode* self,
n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end()); n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
provide = op::Substitute(provide, n.main_vmap); provide = op::Substitute(provide, n.main_vmap);
provide = MergeNest(reduce, provide); provide = MergeNest(reduce, provide);
if (debug_keep_trivial_loop) {
return MergeNest(common, provide);
} else {
return MergeNest(common, Block::make(init, provide)); return MergeNest(common, Block::make(init, provide));
}
} else { } else {
std::vector<Stmt> provides; std::vector<Stmt> provides;
for (size_t i = 0; i < self->body.size(); ++i) { for (size_t i = 0; i < self->body.size(); ++i) {
...@@ -379,16 +383,16 @@ ComputeType DetectComputeType(const ComputeOpNode* self, ...@@ -379,16 +383,16 @@ ComputeType DetectComputeType(const ComputeOpNode* self,
Stmt ComputeOpNode::BuildProvide( Stmt ComputeOpNode::BuildProvide(
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const { bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this); CHECK_EQ(stage->op.operator->(), this);
ComputeType ctype = DetectComputeType(this, stage); ComputeType ctype = DetectComputeType(this, stage);
if (ctype == ComputeType::kCrossThreadReduction) { if (ctype == ComputeType::kCrossThreadReduction) {
// specially handle cross thread reduction. // specially handle cross thread reduction.
return MakeCrossThreadReduction(this, stage, dom_map, del_trivial_loop); return MakeCrossThreadReduction(this, stage, dom_map, debug_keep_trivial_loop);
} else if (ctype == ComputeType::kTensorize) { } else if (ctype == ComputeType::kTensorize) {
return MakeTensorize(this, stage, dom_map, del_trivial_loop); return MakeTensorize(this, stage, dom_map, debug_keep_trivial_loop);
} else { } else {
return MakeComputeStmt(this, stage, dom_map, del_trivial_loop); return MakeComputeStmt(this, stage, dom_map, debug_keep_trivial_loop);
} }
} }
...@@ -396,12 +400,13 @@ ComputeLoopNest ComputeLoopNest::make( ...@@ -396,12 +400,13 @@ ComputeLoopNest ComputeLoopNest::make(
const ComputeOpNode* self, const ComputeOpNode* self,
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) { bool debug_keep_trivial_loop) {
CHECK_EQ(stage->op.operator->(), self); CHECK_EQ(stage->op.operator->(), self);
ComputeLoopNest ret; ComputeLoopNest ret;
// make main loop nest // make main loop nest
ret.main_nest = op::MakeLoopNest( ret.main_nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap, del_trivial_loop); stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap,
debug_keep_trivial_loop);
ret.main_predicates = schedule::MakeBoundCheck( ret.main_predicates = schedule::MakeBoundCheck(
stage, dom_map, ret.main_vmap, false, stage, dom_map, ret.main_vmap, false,
std::unordered_set<IterVar>()); std::unordered_set<IterVar>());
...@@ -443,7 +448,7 @@ ComputeLoopNest ComputeLoopNest::make( ...@@ -443,7 +448,7 @@ ComputeLoopNest ComputeLoopNest::make(
} }
ret.init_nest = op::MakeLoopNest( ret.init_nest = op::MakeLoopNest(
stage, dom_map, begin_loop, true, stage, dom_map, begin_loop, true,
skip_iter, &(ret.init_vmap), del_trivial_loop); skip_iter, &(ret.init_vmap), debug_keep_trivial_loop);
ret.init_predicates = schedule::MakeBoundCheck( ret.init_predicates = schedule::MakeBoundCheck(
stage, dom_map, ret.init_vmap, true, skip_iter); stage, dom_map, ret.init_vmap, true, skip_iter);
for (auto& e : ret.init_predicates) { for (auto& e : ret.init_predicates) {
......
...@@ -37,14 +37,14 @@ struct ComputeLoopNest { ...@@ -37,14 +37,14 @@ struct ComputeLoopNest {
* \param self The pointer to compute op. * \param self The pointer to compute op.
* \param stage The scxhedule stage. * \param stage The scxhedule stage.
* \param dom_map The domain map. * \param dom_map The domain map.
* \param del_trivial_loop Whether eliminate trivial loops with extent of 1 * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
* \return The constructed loop nest * \return The constructed loop nest
*/ */
static ComputeLoopNest make( static ComputeLoopNest make(
const ComputeOpNode* self, const ComputeOpNode* self,
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop); bool debug_keep_trivial_loop);
}; };
/*! /*!
...@@ -52,27 +52,27 @@ struct ComputeLoopNest { ...@@ -52,27 +52,27 @@ struct ComputeLoopNest {
* \param self The pointer to ComputeOpNode * \param self The pointer to ComputeOpNode
* \param stage The schedule stage. * \param stage The schedule stage.
* \param dom_map The domain map. * \param dom_map The domain map.
* \param del_trivial_loop Wheter eliminate trivial loops with extent of 1 * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
* \return The created statement. * \return The created statement.
*/ */
Stmt MakeCrossThreadReduction( Stmt MakeCrossThreadReduction(
const ComputeOpNode* self, const ComputeOpNode* self,
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop); bool debug_keep_trivial_loop);
/*! /*!
* \brief Build body of compute for tensorization. * \brief Build body of compute for tensorization.
* \param self The pointer to ComputeOpNode * \param self The pointer to ComputeOpNode
* \param stage The schedule stage. * \param stage The schedule stage.
* \param dom_map The domain map. * \param dom_map The domain map.
* \param del_trivial_loop Wheter eliminate trivial loops with extent of 1 * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
* \return The created statement. * \return The created statement.
*/ */
Stmt MakeTensorize(const ComputeOpNode* self, Stmt MakeTensorize(const ComputeOpNode* self,
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop); bool debug_keep_trivial_loop);
} // namespace tvm } // namespace tvm
#endif // TVM_OP_COMPUTE_OP_H_ #endif // TVM_OP_COMPUTE_OP_H_
...@@ -14,14 +14,14 @@ Stmt MakeCrossThreadReduction( ...@@ -14,14 +14,14 @@ Stmt MakeCrossThreadReduction(
const ComputeOpNode* self, const ComputeOpNode* self,
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) { bool debug_keep_trivial_loop) {
Array<Expr> args; Array<Expr> args;
for (IterVar iv : self->axis) { for (IterVar iv : self->axis) {
args.push_back(iv->var); args.push_back(iv->var);
} }
std::unordered_map<IterVar, Expr> value_map; std::unordered_map<IterVar, Expr> value_map;
auto nest = op::MakeLoopNest( auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map, del_trivial_loop); stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map, debug_keep_trivial_loop);
auto conds = schedule::MakeBoundCheck( auto conds = schedule::MakeBoundCheck(
stage, dom_map, value_map, false, stage, dom_map, value_map, false,
std::unordered_set<IterVar>()); std::unordered_set<IterVar>());
......
...@@ -129,7 +129,7 @@ Stmt ExternOpNode::BuildRealize( ...@@ -129,7 +129,7 @@ Stmt ExternOpNode::BuildRealize(
Stmt ExternOpNode::BuildProvide( Stmt ExternOpNode::BuildProvide(
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const { bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this); CHECK_EQ(stage->op.operator->(), this);
Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body); Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);
auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) { auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
......
...@@ -24,7 +24,7 @@ MakeLoopNest(const Stage& stage, ...@@ -24,7 +24,7 @@ MakeLoopNest(const Stage& stage,
bool new_loop_var, bool new_loop_var,
const std::unordered_set<IterVar>& skip_iter, const std::unordered_set<IterVar>& skip_iter,
std::unordered_map<IterVar, Expr>* p_value_map, std::unordered_map<IterVar, Expr>* p_value_map,
bool del_trivial_loop) { bool debug_keep_trivial_loop) {
auto leaf_iter_vars = stage->leaf_iter_vars; auto leaf_iter_vars = stage->leaf_iter_vars;
Stmt no_op = Evaluate::make(0); Stmt no_op = Evaluate::make(0);
// create the loop nest // create the loop nest
...@@ -76,7 +76,7 @@ MakeLoopNest(const Stage& stage, ...@@ -76,7 +76,7 @@ MakeLoopNest(const Stage& stage,
AttrStmt::make(iv, ir::attr::pragma_scope, p, no_op)); AttrStmt::make(iv, ir::attr::pragma_scope, p, no_op));
} }
} }
if (del_trivial_loop && is_one(dom->extent)) { if (!debug_keep_trivial_loop && is_one(dom->extent)) {
nest[i + 1].emplace_back( nest[i + 1].emplace_back(
LetStmt::make(var, dom->min, no_op)); LetStmt::make(var, dom->min, no_op));
value_map[iv] = dom->min; value_map[iv] = dom->min;
...@@ -131,7 +131,7 @@ MakeLoopNest(const Stage& stage, ...@@ -131,7 +131,7 @@ MakeLoopNest(const Stage& stage,
// annotate the extent of the IterVar // annotate the extent of the IterVar
nest[i + 1].emplace_back( nest[i + 1].emplace_back(
AttrStmt::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op)); AttrStmt::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op));
if (del_trivial_loop && is_one(dom->extent)) { if (!debug_keep_trivial_loop && is_one(dom->extent)) {
value_map[iv] = dom->min; value_map[iv] = dom->min;
} else { } else {
value_map[iv] = var; value_map[iv] = var;
......
...@@ -29,7 +29,7 @@ using ir::MergeNest; ...@@ -29,7 +29,7 @@ using ir::MergeNest;
* \param new_loop_var Whether create new loop variable. * \param new_loop_var Whether create new loop variable.
* \param skip_iter Whether skip certain iteration. * \param skip_iter Whether skip certain iteration.
* \param p_value_map The result value of each IterVar. * \param p_value_map The result value of each IterVar.
* \param del_trivial_loop Whether eliminate trivial loops with extent of 1 * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
*/ */
std::vector<std::vector<Stmt> > std::vector<std::vector<Stmt> >
MakeLoopNest(const Stage& stage, MakeLoopNest(const Stage& stage,
...@@ -38,7 +38,7 @@ MakeLoopNest(const Stage& stage, ...@@ -38,7 +38,7 @@ MakeLoopNest(const Stage& stage,
bool new_loop_var, bool new_loop_var,
const std::unordered_set<IterVar>& skip_iter, const std::unordered_set<IterVar>& skip_iter,
std::unordered_map<IterVar, Expr>* p_value_map, std::unordered_map<IterVar, Expr>* p_value_map,
bool del_trivial_loop); bool debug_keep_trivial_loop);
/*! /*!
* \brief Create a nest of if checking the predicates. * \brief Create a nest of if checking the predicates.
......
...@@ -79,7 +79,7 @@ Stmt PlaceholderOpNode::BuildRealize( ...@@ -79,7 +79,7 @@ Stmt PlaceholderOpNode::BuildRealize(
Stmt PlaceholderOpNode::BuildProvide( Stmt PlaceholderOpNode::BuildProvide(
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const { bool debug_keep_trivial_loop) const {
return Stmt(); return Stmt();
} }
} // namespace tvm } // namespace tvm
...@@ -253,7 +253,7 @@ Stmt ScanOpNode::BuildRealize( ...@@ -253,7 +253,7 @@ Stmt ScanOpNode::BuildRealize(
Stmt ScanOpNode::BuildProvide( Stmt ScanOpNode::BuildProvide(
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const { bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this); CHECK_EQ(stage->op.operator->(), this);
Stmt provide = AttrStmt::make( Stmt provide = AttrStmt::make(
stage->op, attr::scan_update_scope, this->scan_axis->var, stage->op, attr::scan_update_scope, this->scan_axis->var,
...@@ -271,7 +271,7 @@ Stmt ScanOpNode::BuildProvide( ...@@ -271,7 +271,7 @@ Stmt ScanOpNode::BuildProvide(
std::unordered_map<IterVar, Expr> vmap; std::unordered_map<IterVar, Expr> vmap;
std::unordered_set<IterVar> empty; std::unordered_set<IterVar> empty;
auto nest = op::MakeLoopNest( auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, empty, &vmap, del_trivial_loop); stage, dom_map, 0, false, empty, &vmap, debug_keep_trivial_loop);
nest[begin_scan].push_back(init); nest[begin_scan].push_back(init);
nest.push_back( nest.push_back(
op::MakeIfNest( op::MakeIfNest(
......
...@@ -370,14 +370,14 @@ Stmt TransformUpdate(const Stage& stage, ...@@ -370,14 +370,14 @@ Stmt TransformUpdate(const Stage& stage,
Stmt MakeTensorize(const ComputeOpNode* self, Stmt MakeTensorize(const ComputeOpNode* self,
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) { bool debug_keep_trivial_loop) {
std::unordered_map<IterVar, Range> out_dom; std::unordered_map<IterVar, Range> out_dom;
std::unordered_map<Tensor, Array<Range> > in_region; std::unordered_map<Tensor, Array<Range> > in_region;
size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region); size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region);
TensorIntrin intrin = stage->iter_var_attrs.at( TensorIntrin intrin = stage->iter_var_attrs.at(
stage->leaf_iter_vars[tloc])->tensor_intrin; stage->leaf_iter_vars[tloc])->tensor_intrin;
CHECK(intrin.defined()); CHECK(intrin.defined());
ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, del_trivial_loop); ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
VerifyTensorizeLoopNest(self, stage, n, tloc); VerifyTensorizeLoopNest(self, stage, n, tloc);
VerifyTensorizeBody(self, stage, out_dom, in_region, intrin); VerifyTensorizeBody(self, stage, out_dom, in_region, intrin);
// Start bind data. // Start bind data.
......
...@@ -23,8 +23,8 @@ using namespace ir; ...@@ -23,8 +23,8 @@ using namespace ir;
Stmt MakePipeline(const Stage& s, Stmt MakePipeline(const Stage& s,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
Stmt consumer, Stmt consumer,
bool del_trivial_loop) { bool debug_keep_trivial_loop) {
Stmt producer = s->op->BuildProvide(s, dom_map, del_trivial_loop); Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
if (producer.defined()) { if (producer.defined()) {
producer = ProducerConsumer::make(s->op, true, producer); producer = ProducerConsumer::make(s->op, true, producer);
} }
...@@ -58,9 +58,9 @@ class InjectAttach : public IRMutator { ...@@ -58,9 +58,9 @@ class InjectAttach : public IRMutator {
InjectAttach(const Stage& stage, InjectAttach(const Stage& stage,
const Stage& attach_spec, const Stage& attach_spec,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) bool debug_keep_trivial_loop)
: stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map), : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map),
del_trivial_loop_(del_trivial_loop) {} debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
Stmt Mutate(Stmt stmt) final { Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined()); CHECK(stmt.defined());
...@@ -76,7 +76,7 @@ class InjectAttach : public IRMutator { ...@@ -76,7 +76,7 @@ class InjectAttach : public IRMutator {
found_attach = true; found_attach = true;
stmt = AttrStmt::make( stmt = AttrStmt::make(
op->node, op->attr_key, op->value, op->node, op->attr_key, op->value,
MakePipeline(stage_, dom_map_, op->body, del_trivial_loop_)); MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
} }
} }
return stmt; return stmt;
...@@ -91,8 +91,9 @@ class InjectAttach : public IRMutator { ...@@ -91,8 +91,9 @@ class InjectAttach : public IRMutator {
const Stage& attach_spec_; const Stage& attach_spec_;
// domain map // domain map
const std::unordered_map<IterVar, Range>& dom_map_; const std::unordered_map<IterVar, Range>& dom_map_;
// whether delete trivial loops with extent of 1 // Whether keep trivial loops with extent of 1 during lowering.
bool del_trivial_loop_; // This is a debug feature for dataflow/axis analysis
bool debug_keep_trivial_loop_;
}; };
// inject the operator's realization on the stmt. // inject the operator's realization on the stmt.
...@@ -102,9 +103,9 @@ class InjectScanStep : public IRMutator { ...@@ -102,9 +103,9 @@ class InjectScanStep : public IRMutator {
const Operation& scan_op, const Operation& scan_op,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool is_init, bool is_init,
bool del_trivial_loop) bool debug_keep_trivial_loop)
: stage_(stage), scan_op_(scan_op), : stage_(stage), scan_op_(scan_op),
dom_map_(dom_map), is_init_(is_init), del_trivial_loop_(del_trivial_loop) {} dom_map_(dom_map), is_init_(is_init), debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
Stmt Mutate(Stmt stmt) final { Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined()); CHECK(stmt.defined());
...@@ -118,7 +119,7 @@ class InjectScanStep : public IRMutator { ...@@ -118,7 +119,7 @@ class InjectScanStep : public IRMutator {
found_attach = true; found_attach = true;
stmt = AttrStmt::make( stmt = AttrStmt::make(
op->node, op->attr_key, op->value, op->node, op->attr_key, op->value,
MakePipeline(stage_, dom_map_, op->body, del_trivial_loop_)); MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
} }
} }
return stmt; return stmt;
...@@ -135,8 +136,9 @@ class InjectScanStep : public IRMutator { ...@@ -135,8 +136,9 @@ class InjectScanStep : public IRMutator {
const std::unordered_map<IterVar, Range>& dom_map_; const std::unordered_map<IterVar, Range>& dom_map_;
// whether it is init. // whether it is init.
bool is_init_; bool is_init_;
// whether delete trivial loops with extent of 1 // Whether keep trivial loops with extent of 1 during lowering.
bool del_trivial_loop_; // This is a debug feature for dataflow/axis analysis
bool debug_keep_trivial_loop_;
}; };
// Postprocessing of schedule op // Postprocessing of schedule op
...@@ -337,7 +339,7 @@ class SchedulePostProc : public IRMutator { ...@@ -337,7 +339,7 @@ class SchedulePostProc : public IRMutator {
}; };
Stmt ScheduleOps( Stmt ScheduleOps(
Schedule sch, Map<IterVar, Range> dom_map_, bool del_trivial_loop) { Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_trivial_loop) {
Stmt body = Stmt(); Stmt body = Stmt();
std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_); std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_);
// scan init and scan updates // scan init and scan updates
...@@ -372,14 +374,14 @@ Stmt ScheduleOps( ...@@ -372,14 +374,14 @@ Stmt ScheduleOps(
if (scan_init.count(s->op)) { if (scan_init.count(s->op)) {
CHECK(body.defined()); CHECK(body.defined());
InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, del_trivial_loop); InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop);
body = mu.Mutate(body); body = mu.Mutate(body);
CHECK(mu.found_attach) CHECK(mu.found_attach)
<< "did not find attachment point for scan.init"; << "did not find attachment point for scan.init";
} else if (attach_spec->attach_type == kScanUpdate) { } else if (attach_spec->attach_type == kScanUpdate) {
// Handle scan update // Handle scan update
CHECK(body.defined()); CHECK(body.defined());
InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, del_trivial_loop); InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop);
body = mu.Mutate(body); body = mu.Mutate(body);
CHECK(mu.found_attach) CHECK(mu.found_attach)
<< "did not find attachment point for scan.update"; << "did not find attachment point for scan.update";
...@@ -387,11 +389,11 @@ Stmt ScheduleOps( ...@@ -387,11 +389,11 @@ Stmt ScheduleOps(
// do nothing // do nothing
} else if (attach_spec->attach_type == kGroupRoot) { } else if (attach_spec->attach_type == kGroupRoot) {
CHECK(!s->group.defined()); CHECK(!s->group.defined());
body = MakePipeline(s, dom_map, body, del_trivial_loop); body = MakePipeline(s, dom_map, body, debug_keep_trivial_loop);
} else { } else {
CHECK_EQ(attach_spec->attach_type, kScope); CHECK_EQ(attach_spec->attach_type, kScope);
CHECK(body.defined()); CHECK(body.defined());
InjectAttach mutator(s, attach_spec, dom_map, del_trivial_loop); InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop);
body = mutator.Mutate(body); body = mutator.Mutate(body);
CHECK(mutator.found_attach) CHECK(mutator.found_attach)
<< "did not find attachment point for " << s << " in " << "did not find attachment point for " << s << " in "
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment