Commit f13421cf by Lianmin Zheng Committed by Tianqi Chen

fix keeping trivial loop (#982)

parent df1b4f64
......@@ -57,8 +57,10 @@ class InjectAttach : public IRMutator {
public:
InjectAttach(const Stage& stage,
const Stage& attach_spec,
const std::unordered_map<IterVar, Range>& dom_map)
: stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map) {}
const std::unordered_map<IterVar, Range>& dom_map,
bool del_trivial_loop)
: stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map),
del_trivial_loop_(del_trivial_loop) {}
Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined());
......@@ -74,7 +76,7 @@ class InjectAttach : public IRMutator {
found_attach = true;
stmt = AttrStmt::make(
op->node, op->attr_key, op->value,
MakePipeline(stage_, dom_map_, op->body, true));
MakePipeline(stage_, dom_map_, op->body, del_trivial_loop_));
}
}
return stmt;
......@@ -89,6 +91,8 @@ class InjectAttach : public IRMutator {
const Stage& attach_spec_;
// domain map
const std::unordered_map<IterVar, Range>& dom_map_;
// whether delete trivial loops with extent of 1
bool del_trivial_loop_;
};
// inject the operator's realization on the stmt.
......@@ -97,9 +101,10 @@ class InjectScanStep : public IRMutator {
InjectScanStep(const Stage& stage,
const Operation& scan_op,
const std::unordered_map<IterVar, Range>& dom_map,
bool is_init)
bool is_init,
bool del_trivial_loop)
: stage_(stage), scan_op_(scan_op),
dom_map_(dom_map), is_init_(is_init) {}
dom_map_(dom_map), is_init_(is_init), del_trivial_loop_(del_trivial_loop) {}
Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined());
......@@ -113,7 +118,7 @@ class InjectScanStep : public IRMutator {
found_attach = true;
stmt = AttrStmt::make(
op->node, op->attr_key, op->value,
MakePipeline(stage_, dom_map_, op->body, true));
MakePipeline(stage_, dom_map_, op->body, del_trivial_loop_));
}
}
return stmt;
......@@ -130,6 +135,8 @@ class InjectScanStep : public IRMutator {
const std::unordered_map<IterVar, Range>& dom_map_;
// whether it is init.
bool is_init_;
// whether delete trivial loops with extent of 1
bool del_trivial_loop_;
};
// Postprocessing of schedule op
......@@ -365,14 +372,14 @@ Stmt ScheduleOps(
if (scan_init.count(s->op)) {
CHECK(body.defined());
InjectScanStep mu(s, scan_init.at(s->op), dom_map, true);
InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, del_trivial_loop);
body = mu.Mutate(body);
CHECK(mu.found_attach)
<< "did not find attachment point for scan.init";
} else if (attach_spec->attach_type == kScanUpdate) {
// Handle scan update
CHECK(body.defined());
InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false);
InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, del_trivial_loop);
body = mu.Mutate(body);
CHECK(mu.found_attach)
<< "did not find attachment point for scan.update";
......@@ -384,7 +391,7 @@ Stmt ScheduleOps(
} else {
CHECK_EQ(attach_spec->attach_type, kScope);
CHECK(body.defined());
InjectAttach mutator(s, attach_spec, dom_map);
InjectAttach mutator(s, attach_spec, dom_map, del_trivial_loop);
body = mutator.Mutate(body);
CHECK(mutator.found_attach)
<< "did not find attachment point for " << s << " in "
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment