Commit d56c777a by Lianmin Zheng Committed by Tianqi Chen

support to keep trivial loops with extent of 1 (#877)

parent b21aee7d
...@@ -117,11 +117,13 @@ class OperationNode : public FunctionBaseNode { ...@@ -117,11 +117,13 @@ class OperationNode : public FunctionBaseNode {
* \brief Build the statement that provide the output tensors. * \brief Build the statement that provide the output tensors.
* \param stage The schedule stage of the op. * \param stage The schedule stage of the op.
* \param dom_map The domain map of all iteration domains. * \param dom_map The domain map of all iteration domains.
* \param del_trivial_loop Whether to eliminate trivial loops with extent of 1
* \return A statement that add production and wraps consumer. * \return A statement that add production and wraps consumer.
*/ */
virtual Stmt BuildProvide( virtual Stmt BuildProvide(
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const = 0; const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const = 0;
static constexpr const char* _type_key = "Operation"; static constexpr const char* _type_key = "Operation";
...@@ -160,7 +162,8 @@ class PlaceholderOpNode : public OperationNode { ...@@ -160,7 +162,8 @@ class PlaceholderOpNode : public OperationNode {
const Stmt& body) const final; const Stmt& body) const final;
Stmt BuildProvide( Stmt BuildProvide(
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const final; const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name); v->Visit("name", &name);
...@@ -211,7 +214,8 @@ class ComputeOpNode : public OperationNode { ...@@ -211,7 +214,8 @@ class ComputeOpNode : public OperationNode {
const Stmt& body) const final; const Stmt& body) const final;
Stmt BuildProvide( Stmt BuildProvide(
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const final; const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name); v->Visit("name", &name);
...@@ -282,7 +286,8 @@ class ScanOpNode : public OperationNode { ...@@ -282,7 +286,8 @@ class ScanOpNode : public OperationNode {
const Stmt& body) const final; const Stmt& body) const final;
Stmt BuildProvide( Stmt BuildProvide(
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const final; const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name); v->Visit("name", &name);
...@@ -345,7 +350,8 @@ class ExternOpNode : public OperationNode { ...@@ -345,7 +350,8 @@ class ExternOpNode : public OperationNode {
const Stmt& body) const final; const Stmt& body) const final;
Stmt BuildProvide( Stmt BuildProvide(
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const final; const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name); v->Visit("name", &name);
......
...@@ -29,9 +29,10 @@ Map<IterVar, Range> InferBound(const Schedule& sch); ...@@ -29,9 +29,10 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
* *
* \param s The schedule to be realized * \param s The schedule to be realized
* \param dom_map The domain of each iter vars. * \param dom_map The domain of each iter vars.
* \param del_trivial_loop Whether to delete trivial loops with extent of 1
* \return the result Stmt * \return the result Stmt
*/ */
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map); Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool del_trivial_loop);
/*! /*!
* \brief To automatically inline the element-wise operations. * \brief To automatically inline the element-wise operations.
......
...@@ -24,6 +24,14 @@ TVM_REGISTER_API("schedule.AutoInlineInjective") ...@@ -24,6 +24,14 @@ TVM_REGISTER_API("schedule.AutoInlineInjective")
AutoInlineInjective(args[0]); AutoInlineInjective(args[0]);
}); });
// Hand-written registration of "schedule.ScheduleOps" so that the third
// argument (del_trivial_loop) may be omitted by the caller.
TVM_REGISTER_API("schedule.ScheduleOps")
.set_body([](TVMArgs args, TVMRetValue* ret) {
// Two-argument call: del_trivial_loop is passed as true.
if (args.size() == 2)
*ret = ScheduleOps(args[0], args[1], true);
// Any other argument count: the third argument is forwarded as del_trivial_loop.
else
*ret = ScheduleOps(args[0], args[1], args[2]);
});
#define REGISTER_SCHEDULE_PASS1(PassName) \ #define REGISTER_SCHEDULE_PASS1(PassName) \
TVM_REGISTER_API("schedule."#PassName) \ TVM_REGISTER_API("schedule."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
...@@ -43,7 +51,6 @@ REGISTER_SCHEDULE_PASS2(PostDFSOrder); ...@@ -43,7 +51,6 @@ REGISTER_SCHEDULE_PASS2(PostDFSOrder);
REGISTER_SCHEDULE_PASS1(CreateAttachPath); REGISTER_SCHEDULE_PASS1(CreateAttachPath);
REGISTER_SCHEDULE_PASS1(ScanGetBody); REGISTER_SCHEDULE_PASS1(ScanGetBody);
REGISTER_SCHEDULE_PASS1(ScanFixPointAnalysis); REGISTER_SCHEDULE_PASS1(ScanFixPointAnalysis);
REGISTER_SCHEDULE_PASS2(ScheduleOps);
} // namespace schedule } // namespace schedule
} // namespace tvm } // namespace tvm
...@@ -211,7 +211,7 @@ Stmt BuildStmt(Schedule sch, ...@@ -211,7 +211,7 @@ Stmt BuildStmt(Schedule sch,
// Phase 0 // Phase 0
auto bounds = schedule::InferBound(sch); auto bounds = schedule::InferBound(sch);
auto stmt = schedule::ScheduleOps(sch, bounds); auto stmt = schedule::ScheduleOps(sch, bounds, true);
stmt = ir::InjectPrefetch(stmt); stmt = ir::InjectPrefetch(stmt);
// Phase 1 // Phase 1
......
...@@ -305,9 +305,10 @@ Stmt MakeProvide(const ComputeOpNode* op, ...@@ -305,9 +305,10 @@ Stmt MakeProvide(const ComputeOpNode* op,
Stmt MakeComputeStmt(const ComputeOpNode* self, Stmt MakeComputeStmt(const ComputeOpNode* self,
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) { const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) {
// grab the nest structure // grab the nest structure
ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map); ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, del_trivial_loop);
// Normal loop structure // Normal loop structure
n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates)); n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates)); n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates));
...@@ -387,28 +388,30 @@ ComputeType DetectComputeType(const ComputeOpNode* self, ...@@ -387,28 +388,30 @@ ComputeType DetectComputeType(const ComputeOpNode* self,
// implement the provide utility. // implement the provide utility.
Stmt ComputeOpNode::BuildProvide( Stmt ComputeOpNode::BuildProvide(
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const { const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this); CHECK_EQ(stage->op.operator->(), this);
ComputeType ctype = DetectComputeType(this, stage); ComputeType ctype = DetectComputeType(this, stage);
if (ctype == ComputeType::kCrossThreadReduction) { if (ctype == ComputeType::kCrossThreadReduction) {
// specially handle cross thread reduction. // specially handle cross thread reduction.
return MakeCrossThreadReduction(this, stage, dom_map); return MakeCrossThreadReduction(this, stage, dom_map, del_trivial_loop);
} else if (ctype == ComputeType::kTensorize) { } else if (ctype == ComputeType::kTensorize) {
return MakeTensorize(this, stage, dom_map); return MakeTensorize(this, stage, dom_map, del_trivial_loop);
} else { } else {
return MakeComputeStmt(this, stage, dom_map); return MakeComputeStmt(this, stage, dom_map, del_trivial_loop);
} }
} }
ComputeLoopNest ComputeLoopNest::make( ComputeLoopNest ComputeLoopNest::make(
const ComputeOpNode* self, const ComputeOpNode* self,
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) { const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) {
CHECK_EQ(stage->op.operator->(), self); CHECK_EQ(stage->op.operator->(), self);
ComputeLoopNest ret; ComputeLoopNest ret;
// make main loop nest // make main loop nest
ret.main_nest = op::MakeLoopNest( ret.main_nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap); stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap, del_trivial_loop);
ret.main_predicates = schedule::MakeBoundCheck( ret.main_predicates = schedule::MakeBoundCheck(
stage, dom_map, ret.main_vmap, false, stage, dom_map, ret.main_vmap, false,
std::unordered_set<IterVar>()); std::unordered_set<IterVar>());
...@@ -450,7 +453,7 @@ ComputeLoopNest ComputeLoopNest::make( ...@@ -450,7 +453,7 @@ ComputeLoopNest ComputeLoopNest::make(
} }
ret.init_nest = op::MakeLoopNest( ret.init_nest = op::MakeLoopNest(
stage, dom_map, begin_loop, true, stage, dom_map, begin_loop, true,
skip_iter, &(ret.init_vmap)); skip_iter, &(ret.init_vmap), del_trivial_loop);
ret.init_predicates = schedule::MakeBoundCheck( ret.init_predicates = schedule::MakeBoundCheck(
stage, dom_map, ret.init_vmap, true, skip_iter); stage, dom_map, ret.init_vmap, true, skip_iter);
for (auto& e : ret.init_predicates) { for (auto& e : ret.init_predicates) {
......
...@@ -37,12 +37,14 @@ struct ComputeLoopNest { ...@@ -37,12 +37,14 @@ struct ComputeLoopNest {
* \param self The pointer to compute op. * \param self The pointer to compute op.
* \param stage The schedule stage. * \param stage The schedule stage.
* \param dom_map The domain map. * \param dom_map The domain map.
* \param del_trivial_loop Whether to eliminate trivial loops with extent of 1
* \return The constructed loop nest * \return The constructed loop nest
*/ */
static ComputeLoopNest make( static ComputeLoopNest make(
const ComputeOpNode* self, const ComputeOpNode* self,
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map); const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop);
}; };
/*! /*!
...@@ -50,23 +52,27 @@ struct ComputeLoopNest { ...@@ -50,23 +52,27 @@ struct ComputeLoopNest {
* \param self The pointer to ComputeOpNode * \param self The pointer to ComputeOpNode
* \param stage The schedule stage. * \param stage The schedule stage.
* \param dom_map The domain map. * \param dom_map The domain map.
* \param del_trivial_loop Whether to eliminate trivial loops with extent of 1
* \return The created statement. * \return The created statement.
*/ */
Stmt MakeCrossThreadReduction( Stmt MakeCrossThreadReduction(
const ComputeOpNode* self, const ComputeOpNode* self,
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map); const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop);
/*! /*!
* \brief Build body of compute for tensorization. * \brief Build body of compute for tensorization.
* \param self The pointer to ComputeOpNode * \param self The pointer to ComputeOpNode
* \param stage The schedule stage. * \param stage The schedule stage.
* \param dom_map The domain map. * \param dom_map The domain map.
* \param del_trivial_loop Whether to eliminate trivial loops with extent of 1
* \return The created statement. * \return The created statement.
*/ */
Stmt MakeTensorize(const ComputeOpNode* self, Stmt MakeTensorize(const ComputeOpNode* self,
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map); const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop);
} // namespace tvm } // namespace tvm
#endif // TVM_OP_COMPUTE_OP_H_ #endif // TVM_OP_COMPUTE_OP_H_
...@@ -13,14 +13,15 @@ using namespace ir; ...@@ -13,14 +13,15 @@ using namespace ir;
Stmt MakeCrossThreadReduction( Stmt MakeCrossThreadReduction(
const ComputeOpNode* self, const ComputeOpNode* self,
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) { const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) {
Array<Expr> args; Array<Expr> args;
for (IterVar iv : self->axis) { for (IterVar iv : self->axis) {
args.push_back(iv->var); args.push_back(iv->var);
} }
std::unordered_map<IterVar, Expr> value_map; std::unordered_map<IterVar, Expr> value_map;
auto nest = op::MakeLoopNest( auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map); stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map, del_trivial_loop);
auto conds = schedule::MakeBoundCheck( auto conds = schedule::MakeBoundCheck(
stage, dom_map, value_map, false, stage, dom_map, value_map, false,
std::unordered_set<IterVar>()); std::unordered_set<IterVar>());
......
...@@ -128,7 +128,8 @@ Stmt ExternOpNode::BuildRealize( ...@@ -128,7 +128,8 @@ Stmt ExternOpNode::BuildRealize(
Stmt ExternOpNode::BuildProvide( Stmt ExternOpNode::BuildProvide(
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const { const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this); CHECK_EQ(stage->op.operator->(), this);
Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body); Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);
auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) { auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
......
...@@ -23,7 +23,8 @@ MakeLoopNest(const Stage& stage, ...@@ -23,7 +23,8 @@ MakeLoopNest(const Stage& stage,
size_t begin_iter_pos, size_t begin_iter_pos,
bool new_loop_var, bool new_loop_var,
const std::unordered_set<IterVar>& skip_iter, const std::unordered_set<IterVar>& skip_iter,
std::unordered_map<IterVar, Expr>* p_value_map) { std::unordered_map<IterVar, Expr>* p_value_map,
bool del_trivial_loop) {
auto leaf_iter_vars = stage->leaf_iter_vars; auto leaf_iter_vars = stage->leaf_iter_vars;
Stmt no_op = Evaluate::make(0); Stmt no_op = Evaluate::make(0);
// create the loop nest // create the loop nest
...@@ -75,7 +76,7 @@ MakeLoopNest(const Stage& stage, ...@@ -75,7 +76,7 @@ MakeLoopNest(const Stage& stage,
AttrStmt::make(iv, ir::attr::pragma_scope, p, no_op)); AttrStmt::make(iv, ir::attr::pragma_scope, p, no_op));
} }
} }
if (is_one(dom->extent)) { if (del_trivial_loop && is_one(dom->extent)) {
nest[i + 1].emplace_back( nest[i + 1].emplace_back(
LetStmt::make(var, dom->min, no_op)); LetStmt::make(var, dom->min, no_op));
value_map[iv] = dom->min; value_map[iv] = dom->min;
...@@ -130,7 +131,7 @@ MakeLoopNest(const Stage& stage, ...@@ -130,7 +131,7 @@ MakeLoopNest(const Stage& stage,
// annotate the extent of the IterVar // annotate the extent of the IterVar
nest[i + 1].emplace_back( nest[i + 1].emplace_back(
AttrStmt::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op)); AttrStmt::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op));
if (is_one(dom->extent)) { if (del_trivial_loop && is_one(dom->extent)) {
value_map[iv] = dom->min; value_map[iv] = dom->min;
} else { } else {
value_map[iv] = var; value_map[iv] = var;
......
...@@ -29,6 +29,7 @@ using ir::MergeNest; ...@@ -29,6 +29,7 @@ using ir::MergeNest;
* \param new_loop_var Whether create new loop variable. * \param new_loop_var Whether create new loop variable.
* \param skip_iter Whether skip certain iteration. * \param skip_iter Whether skip certain iteration.
* \param p_value_map The result value of each IterVar. * \param p_value_map The result value of each IterVar.
* \param del_trivial_loop Whether to eliminate trivial loops with extent of 1
*/ */
std::vector<std::vector<Stmt> > std::vector<std::vector<Stmt> >
MakeLoopNest(const Stage& stage, MakeLoopNest(const Stage& stage,
...@@ -36,7 +37,8 @@ MakeLoopNest(const Stage& stage, ...@@ -36,7 +37,8 @@ MakeLoopNest(const Stage& stage,
size_t begin_iter_pos, size_t begin_iter_pos,
bool new_loop_var, bool new_loop_var,
const std::unordered_set<IterVar>& skip_iter, const std::unordered_set<IterVar>& skip_iter,
std::unordered_map<IterVar, Expr>* p_value_map); std::unordered_map<IterVar, Expr>* p_value_map,
bool del_trivial_loop);
/*! /*!
* \brief Create a nest of if checking the predicates. * \brief Create a nest of if checking the predicates.
......
...@@ -78,7 +78,8 @@ Stmt PlaceholderOpNode::BuildRealize( ...@@ -78,7 +78,8 @@ Stmt PlaceholderOpNode::BuildRealize(
Stmt PlaceholderOpNode::BuildProvide( Stmt PlaceholderOpNode::BuildProvide(
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const { const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const {
return Stmt(); return Stmt();
} }
} // namespace tvm } // namespace tvm
...@@ -252,7 +252,8 @@ Stmt ScanOpNode::BuildRealize( ...@@ -252,7 +252,8 @@ Stmt ScanOpNode::BuildRealize(
Stmt ScanOpNode::BuildProvide( Stmt ScanOpNode::BuildProvide(
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const { const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this); CHECK_EQ(stage->op.operator->(), this);
Stmt provide = AttrStmt::make( Stmt provide = AttrStmt::make(
stage->op, attr::scan_update_scope, this->scan_axis->var, stage->op, attr::scan_update_scope, this->scan_axis->var,
...@@ -270,7 +271,7 @@ Stmt ScanOpNode::BuildProvide( ...@@ -270,7 +271,7 @@ Stmt ScanOpNode::BuildProvide(
std::unordered_map<IterVar, Expr> vmap; std::unordered_map<IterVar, Expr> vmap;
std::unordered_set<IterVar> empty; std::unordered_set<IterVar> empty;
auto nest = op::MakeLoopNest( auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, empty, &vmap); stage, dom_map, 0, false, empty, &vmap, del_trivial_loop);
nest[begin_scan].push_back(init); nest[begin_scan].push_back(init);
nest.push_back( nest.push_back(
op::MakeIfNest( op::MakeIfNest(
......
...@@ -369,14 +369,15 @@ Stmt TransformUpdate(const Stage& stage, ...@@ -369,14 +369,15 @@ Stmt TransformUpdate(const Stage& stage,
Stmt MakeTensorize(const ComputeOpNode* self, Stmt MakeTensorize(const ComputeOpNode* self,
const Stage& stage, const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) { const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop) {
std::unordered_map<IterVar, Range> out_dom; std::unordered_map<IterVar, Range> out_dom;
std::unordered_map<Tensor, Array<Range> > in_region; std::unordered_map<Tensor, Array<Range> > in_region;
size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region); size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region);
TensorIntrin intrin = stage->iter_var_attrs.at( TensorIntrin intrin = stage->iter_var_attrs.at(
stage->leaf_iter_vars[tloc])->tensor_intrin; stage->leaf_iter_vars[tloc])->tensor_intrin;
CHECK(intrin.defined()); CHECK(intrin.defined());
ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map); ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, del_trivial_loop);
VerifyTensorizeLoopNest(self, stage, n, tloc); VerifyTensorizeLoopNest(self, stage, n, tloc);
VerifyTensorizeBody(self, stage, out_dom, in_region, intrin); VerifyTensorizeBody(self, stage, out_dom, in_region, intrin);
// Start bind data. // Start bind data.
......
...@@ -22,8 +22,9 @@ using namespace ir; ...@@ -22,8 +22,9 @@ using namespace ir;
Stmt MakePipeline(const Stage& s, Stmt MakePipeline(const Stage& s,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
Stmt consumer) { Stmt consumer,
Stmt producer = s->op->BuildProvide(s, dom_map); bool del_trivial_loop) {
Stmt producer = s->op->BuildProvide(s, dom_map, del_trivial_loop);
if (producer.defined()) { if (producer.defined()) {
producer = ProducerConsumer::make(s->op, true, producer); producer = ProducerConsumer::make(s->op, true, producer);
} }
...@@ -68,7 +69,7 @@ class InjectAttach : public IRMutator { ...@@ -68,7 +69,7 @@ class InjectAttach : public IRMutator {
found_attach = true; found_attach = true;
stmt = AttrStmt::make( stmt = AttrStmt::make(
op->node, op->attr_key, op->value, op->node, op->attr_key, op->value,
MakePipeline(stage_, dom_map_, op->body)); MakePipeline(stage_, dom_map_, op->body, true));
} }
} }
return stmt; return stmt;
...@@ -107,7 +108,7 @@ class InjectScanStep : public IRMutator { ...@@ -107,7 +108,7 @@ class InjectScanStep : public IRMutator {
found_attach = true; found_attach = true;
stmt = AttrStmt::make( stmt = AttrStmt::make(
op->node, op->attr_key, op->value, op->node, op->attr_key, op->value,
MakePipeline(stage_, dom_map_, op->body)); MakePipeline(stage_, dom_map_, op->body, true));
} }
} }
return stmt; return stmt;
...@@ -324,7 +325,7 @@ class SchedulePostProc : public IRMutator { ...@@ -324,7 +325,7 @@ class SchedulePostProc : public IRMutator {
}; };
Stmt ScheduleOps( Stmt ScheduleOps(
Schedule sch, Map<IterVar, Range> dom_map_) { Schedule sch, Map<IterVar, Range> dom_map_, bool del_trivial_loop) {
Stmt body = Stmt(); Stmt body = Stmt();
std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_); std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_);
// scan init and scan updates // scan init and scan updates
...@@ -374,7 +375,7 @@ Stmt ScheduleOps( ...@@ -374,7 +375,7 @@ Stmt ScheduleOps(
// do nothing // do nothing
} else if (attach_spec->attach_type == kGroupRoot) { } else if (attach_spec->attach_type == kGroupRoot) {
CHECK(!s->group.defined()); CHECK(!s->group.defined());
body = MakePipeline(s, dom_map, body); body = MakePipeline(s, dom_map, body, del_trivial_loop);
} else { } else {
CHECK_EQ(attach_spec->attach_type, kScope); CHECK_EQ(attach_spec->attach_type, kScope);
CHECK(body.defined()); CHECK(body.defined());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment