Commit 6268e183 by Tianqi Chen Committed by GitHub

[SCHEDULE] Fix the scan schedule with rewriting (#80)

parent 54593ca1
......@@ -237,6 +237,13 @@ class SchedulePostProc : public IRMutator {
void Init(const Schedule& sch) {
for (Stage s : sch->stages) {
// This must be checked for all ops, including scan.
if (!s->op.same_as(s->origin_op)) {
Tensor target = s->origin_op.output(0);
AddReplace(s->op.output(0), target,
target, s->origin_op);
}
// Specially add replacements for scan op.
if (s->op.as<ScanOpNode>()) {
const ScanOpNode* scan = s->op.as<ScanOpNode>();
for (size_t i = 0; i < scan->update.size(); ++i) {
......@@ -245,10 +252,6 @@ class SchedulePostProc : public IRMutator {
AddReplace(scan->update[i], t);
AddReplace(scan->state_placeholder[i], t);
}
} else if (!s->op.same_as(s->origin_op)) {
Tensor target = s->origin_op.output(0);
AddReplace(s->op.output(0), target,
target, s->origin_op);
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment