Commit c4927378 by libing4752 Committed by Tianqi Chen

modified schedule_dataflow_rewrite.cc to fix Stale Tensor during Dataflow Rewrite #738 (#747)

* modified schedule_dataflow_rewrite.cc to fix losing tensor problem

* modified schedule_dataflow_rewrite.cc for lint scan

* modified schedule_dataflow_rewrite.cc for lint scan

* using tensor's value_index to index output of stage op
parent d8be197d
...@@ -86,7 +86,9 @@ Tensor Schedule::cache_read(const Tensor& tensor, ...@@ -86,7 +86,9 @@ Tensor Schedule::cache_read(const Tensor& tensor,
return tensor(Array<Expr>(i.begin(), i.end())); return tensor(Array<Expr>(i.begin(), i.end()));
}, os.str()); }, os.str());
std::unordered_map<Tensor, Tensor> vsub; std::unordered_map<Tensor, Tensor> vsub;
vsub[tensor] = cache; Stage s = operator[](tensor->op);
Tensor sugar_tensor = s->op.output(tensor->value_index);
vsub[sugar_tensor] = cache;
std::unordered_map<Tensor, Tensor> vmap; std::unordered_map<Tensor, Tensor> vmap;
for (Operation op : readers) { for (Operation op : readers) {
......
...@@ -182,6 +182,25 @@ def test_schedule_cache(): ...@@ -182,6 +182,25 @@ def test_schedule_cache():
bounds = tvm.schedule.InferBound(s) bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_middle_cache():
m = tvm.var('m')
n = tvm.var('n')
A = tvm.placeholder((m, n), name='A')
B = tvm.placeholder((m, n), name='B')
C = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='C')
D = tvm.compute((m, n), lambda i, j: C(i , j) , name='D')
s = tvm.create_schedule(D.op)
AA = s.cache_read(A, "local", readers=[C])
BB = s.cache_read(B, "local", readers=[C])
CC = s.cache_read(C, "local", readers=[D])
DD = s.cache_write(D, "local")
#s[AA].compute_at(s[CC], CC.op.axis[0])
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_cache_relayout1(): def test_schedule_cache_relayout1():
m = tvm.var('m') m = tvm.var('m')
...@@ -231,6 +250,7 @@ def test_schedule_cache_relayout3(): ...@@ -231,6 +250,7 @@ def test_schedule_cache_relayout3():
if __name__ == "__main__": if __name__ == "__main__":
test_schedule_middle_cache()
test_inline_multi_reduce() test_inline_multi_reduce()
test_schedule_cache_relayout3() test_schedule_cache_relayout3()
test_schedule_cache_relayout2() test_schedule_cache_relayout2()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment