Commit bbaee69b by Mr You Committed by Tianqi Chen

Update schedule_dataflow_rewrite.cc (#2934)

parent 7afbca56
......@@ -603,8 +603,8 @@ void InjectInline(ScheduleNode* sch) {
if (!op.same_as(s->op)) {
for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
repl[s->op.output(idx)] = op.output(idx);
s->op = op;
}
s->op = op;
}
} else {
Operation op = s->op->ReplaceInputs(s->op, repl);
......
......@@ -459,6 +459,26 @@ def test_reduction_and_dummy_fuse_split():
f(*args)
assert np.all(args[0].asnumpy() == n)
def test_schedule_compute_inline():
shape = [10, 1024]
A = tvm.placeholder(shape, name="A")
B = tvm.placeholder(shape, name="B")
C = tvm.compute(shape, lambda *index:A(*index)+ B(*index), name = "C")
def _compute(*index) :
return C(*index) , C(*index) * B(*index)
F,E = tvm.compute(shape, _compute, name = "F")
s = tvm.create_schedule([F.op, E.op])
AL = s.cache_read(A, "local", [C])
BL = s.cache_read(B, "local", [C,E])
CL = s.cache_write(C, "local")
FL, EL = s.cache_write([F, E], "local")
s[C].compute_inline()
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
if __name__ == "__main__":
test_loop_dep_reduce()
test_loop_dep_reduce_cache_write()
......@@ -483,3 +503,4 @@ if __name__ == "__main__":
test_schedule_tensor_compute2()
test_schedule_tensor_compute3()
test_reduction_and_dummy_fuse_split()
test_schedule_compute_inline()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment