Commit f13421cf by Lianmin Zheng Committed by Tianqi Chen

fix keeping trivial loop (#982)

parent df1b4f64
...@@ -57,8 +57,10 @@ class InjectAttach : public IRMutator { ...@@ -57,8 +57,10 @@ class InjectAttach : public IRMutator {
public: public:
InjectAttach(const Stage& stage, InjectAttach(const Stage& stage,
const Stage& attach_spec, const Stage& attach_spec,
const std::unordered_map<IterVar, Range>& dom_map) const std::unordered_map<IterVar, Range>& dom_map,
: stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map) {} bool del_trivial_loop)
: stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map),
del_trivial_loop_(del_trivial_loop) {}
Stmt Mutate(Stmt stmt) final { Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined()); CHECK(stmt.defined());
...@@ -74,7 +76,7 @@ class InjectAttach : public IRMutator { ...@@ -74,7 +76,7 @@ class InjectAttach : public IRMutator {
found_attach = true; found_attach = true;
stmt = AttrStmt::make( stmt = AttrStmt::make(
op->node, op->attr_key, op->value, op->node, op->attr_key, op->value,
MakePipeline(stage_, dom_map_, op->body, true)); MakePipeline(stage_, dom_map_, op->body, del_trivial_loop_));
} }
} }
return stmt; return stmt;
...@@ -89,6 +91,8 @@ class InjectAttach : public IRMutator { ...@@ -89,6 +91,8 @@ class InjectAttach : public IRMutator {
const Stage& attach_spec_; const Stage& attach_spec_;
// domain map // domain map
const std::unordered_map<IterVar, Range>& dom_map_; const std::unordered_map<IterVar, Range>& dom_map_;
// whether delete trivial loops with extent of 1
bool del_trivial_loop_;
}; };
// inject the operator's realization on the stmt. // inject the operator's realization on the stmt.
...@@ -97,9 +101,10 @@ class InjectScanStep : public IRMutator { ...@@ -97,9 +101,10 @@ class InjectScanStep : public IRMutator {
InjectScanStep(const Stage& stage, InjectScanStep(const Stage& stage,
const Operation& scan_op, const Operation& scan_op,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool is_init) bool is_init,
bool del_trivial_loop)
: stage_(stage), scan_op_(scan_op), : stage_(stage), scan_op_(scan_op),
dom_map_(dom_map), is_init_(is_init) {} dom_map_(dom_map), is_init_(is_init), del_trivial_loop_(del_trivial_loop) {}
Stmt Mutate(Stmt stmt) final { Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined()); CHECK(stmt.defined());
...@@ -113,7 +118,7 @@ class InjectScanStep : public IRMutator { ...@@ -113,7 +118,7 @@ class InjectScanStep : public IRMutator {
found_attach = true; found_attach = true;
stmt = AttrStmt::make( stmt = AttrStmt::make(
op->node, op->attr_key, op->value, op->node, op->attr_key, op->value,
MakePipeline(stage_, dom_map_, op->body, true)); MakePipeline(stage_, dom_map_, op->body, del_trivial_loop_));
} }
} }
return stmt; return stmt;
...@@ -130,6 +135,8 @@ class InjectScanStep : public IRMutator { ...@@ -130,6 +135,8 @@ class InjectScanStep : public IRMutator {
const std::unordered_map<IterVar, Range>& dom_map_; const std::unordered_map<IterVar, Range>& dom_map_;
// whether it is init. // whether it is init.
bool is_init_; bool is_init_;
// whether delete trivial loops with extent of 1
bool del_trivial_loop_;
}; };
// Postprocessing of schedule op // Postprocessing of schedule op
...@@ -365,14 +372,14 @@ Stmt ScheduleOps( ...@@ -365,14 +372,14 @@ Stmt ScheduleOps(
if (scan_init.count(s->op)) { if (scan_init.count(s->op)) {
CHECK(body.defined()); CHECK(body.defined());
InjectScanStep mu(s, scan_init.at(s->op), dom_map, true); InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, del_trivial_loop);
body = mu.Mutate(body); body = mu.Mutate(body);
CHECK(mu.found_attach) CHECK(mu.found_attach)
<< "did not find attachment point for scan.init"; << "did not find attachment point for scan.init";
} else if (attach_spec->attach_type == kScanUpdate) { } else if (attach_spec->attach_type == kScanUpdate) {
// Handle scan update // Handle scan update
CHECK(body.defined()); CHECK(body.defined());
InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false); InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, del_trivial_loop);
body = mu.Mutate(body); body = mu.Mutate(body);
CHECK(mu.found_attach) CHECK(mu.found_attach)
<< "did not find attachment point for scan.update"; << "did not find attachment point for scan.update";
...@@ -384,7 +391,7 @@ Stmt ScheduleOps( ...@@ -384,7 +391,7 @@ Stmt ScheduleOps(
} else { } else {
CHECK_EQ(attach_spec->attach_type, kScope); CHECK_EQ(attach_spec->attach_type, kScope);
CHECK(body.defined()); CHECK(body.defined());
InjectAttach mutator(s, attach_spec, dom_map); InjectAttach mutator(s, attach_spec, dom_map, del_trivial_loop);
body = mutator.Mutate(body); body = mutator.Mutate(body);
CHECK(mutator.found_attach) CHECK(mutator.found_attach)
<< "did not find attachment point for " << s << " in " << "did not find attachment point for " << s << " in "
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment