Commit 267f0294 by kun-zh Committed by Tianqi Chen

Fix an issue in ReplaceDataFlow for issue 1043 (#1062)

parent d39ac773
......@@ -57,13 +57,21 @@ Expr InjectPredicate(const Array<Expr>& predicates,
// Replace the data flow that appears in all stages, given the tensor change.
// Also update vmap if subsequent dataflow needs to be replaced.
// The transitive closure property of vmap must be kept up to date; this is done with a reverse map.
void ReplaceDataFlow(const Array<Stage>& stages,
std::unordered_map<Tensor, Tensor>* vmap) {
std::unordered_map<Tensor, Tensor>* vmap,
std::unordered_map<Tensor, Tensor>* rvmap) {
for (Stage s : stages) {
Operation op = s->op->ReplaceInputs(s->op, *vmap);
if (!op.same_as(s->op)) {
for (int i = 0; i < op->num_outputs(); ++i) {
auto it = rvmap->find(s->op.output(i));
if (it != rvmap->end()) {
(*vmap)[it->second] = op.output(i);
} else {
(*vmap)[s->op.output(i)] = op.output(i);
(*rvmap)[op.output(i)] = s->op.output(i);
}
}
s->op = op;
}
......@@ -91,6 +99,7 @@ Tensor Schedule::cache_read(const Tensor& tensor,
vsub[sugar_tensor] = cache;
std::unordered_map<Tensor, Tensor> vmap;
std::unordered_map<Tensor, Tensor> rvmap;
for (Operation op : readers) {
Stage s = operator[](op);
Operation repl_op = s->op->ReplaceInputs(s->op, vsub);
......@@ -98,9 +107,10 @@ Tensor Schedule::cache_read(const Tensor& tensor,
<< "Cannot find " << tensor
<< " in the inputs of " << s->op;
vmap[s->op.output(0)] = repl_op.output(0);
rvmap[repl_op.output(0)] = s->op.output(0);
s->op = repl_op;
}
ReplaceDataFlow((*this)->stages, &vmap);
ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
ArrayNode* stages = (*this)->stages.CopyOnWrite();
Stage op_stage = operator[](tensor->op);
size_t pos = FindNodeRef(stages, op_stage);
......@@ -197,8 +207,10 @@ Tensor CacheWriteWithReLayout(Schedule sch,
{cache_tensor(args)});
// The replace of the dataflow
std::unordered_map<Tensor, Tensor> vmap;
std::unordered_map<Tensor, Tensor> rvmap;
vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
ReplaceDataFlow(sch->stages, &vmap);
rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
ReplaceDataFlow(sch->stages, &vmap, &rvmap);
// mutate orig stage
orig_stage->op = orig_new_op;
orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
......@@ -583,10 +595,12 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
}, reduce_stage->op->name + ".repl");
std::unordered_map<Tensor, Tensor> vmap;
std::unordered_map<Tensor, Tensor> rvmap;
for (int idx = 0; idx < size; ++idx) {
vmap[old_tensors[idx]] = repl_tensors[idx];
rvmap[repl_tensors[idx]] = old_tensors[idx];
}
ReplaceDataFlow((*this)->stages, &vmap);
ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
// revamp the reduction stage.
reduce_stage->op = repl_tensors[0]->op;
reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars();
......
......@@ -442,6 +442,19 @@ def test_reuse_small_buffer():
tvm.ir_pass.PostOrderVisit(body, verify)
assert num_alloc[0] == 1
def test_replace_dataflow():
    """Run cache_read over a chain of dependent readers and infer bounds.

    Stages B, C, D and E each read placeholder A; C, D and E additionally
    read the stage built just before them. A is cached into "local" scope
    for all four readers at once, and InferBound must then return a Map.
    """
    extent = (255,)
    source = tvm.placeholder(extent, name = "A")

    def add_source_to(previous, label):
        # One link of the chain: source[i] + previous[i], named `label`.
        return tvm.compute(extent, lambda i: source[i] + previous[i], name = label)

    # The first reader adds A to itself; every later reader adds A to its predecessor.
    readers = [add_source_to(source, "B")]
    for label in ("C", "D", "E"):
        readers.append(add_source_to(readers[-1], label))

    sched = tvm.create_schedule(readers[-1].op)
    sched.cache_read(source, "local", readers)
    inferred = tvm.schedule.InferBound(sched)
    assert isinstance(inferred, tvm.container.Map)
if __name__ == "__main__":
test_alloc_seq()
......@@ -456,3 +469,4 @@ if __name__ == "__main__":
test_alloc_seq_type()
test_alloc_seq_type2()
test_reuse_small_buffer()
test_replace_dataflow()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment