Commit 202570e4 by libing4752, committed by Tianqi Chen

enhance cache write to support multiple tensors generated by ONE computeOp (#1042)

parent 239227d4
@@ -301,6 +301,23 @@ class Schedule : public NodeRef {
    * User can further call compute_inline to inline the original layout and keep
    * the data stored in the transformed layout.
    *
+   * \param tensor The tensors to be produced.
+   * \param scope The scope of the storage.
+   * \return The created tensors.
+   */
+  EXPORT Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
+  /*!
+   * \brief Create a cache write tensor for producing tensor.
+   *  The tensor will take over the body of the original tensor op.
+   *
+   *  This function can be used to do data layout transformation.
+   *  If there is a split/fuse/reorder on the data parallel axis of the tensor
+   *  before cache_write is called, the intermediate cache stores
+   *  the data in the layout given by the iteration order of the leaf axes.
+   *  The data will be transformed back to the original layout in the original tensor.
+   *  User can further call compute_inline to inline the original layout and keep
+   *  the data stored in the transformed layout.
+   *
    * \param tensor The tensor to be produced.
    * \param scope The scope of the storage.
    * \return The created tensor.
...
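The relayout behaviour described in the doc comment above is easiest to see from the Python side. Below is a minimal sketch of that pattern (tensor names, shapes, and the split factor are illustrative and not part of this commit): split/reorder a stage before cache_write, then inline the original stage so the data stays in the transformed layout.

import tvm

n = tvm.var("n")
A = tvm.placeholder((n, 32), name="A")
B = tvm.compute((n, 32), lambda i, j: A[i, j] + 1, name="B")
C = tvm.compute((n, 32), lambda i, j: B[i, j] * 2, name="C")

s = tvm.create_schedule(C.op)
# Split/reorder the data parallel axes of B before cache_write is called ...
xo, xi = s[B].split(B.op.axis[0], factor=8)
s[B].reorder(xi, xo)
# ... so the cache stage stores B's data in the iteration order of the leaf axes.
BL = s.cache_write(B, "local")
# The original B stage now only copies back to the original layout; inlining it
# keeps the data stored only in the transformed layout.
s[B].compute_inline()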
@@ -292,8 +292,8 @@ class Schedule(NodeBase):
         Parameters
         ----------
-        tensor : Tensor
-            The tensor to be feed to.
+        tensor : Tensor, list or tuple
+            The tensors to be fed to. All the tensors must be produced by one computeOp
         scope : str
             The scope of cached
...
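For reference, a short sketch of how the updated Python API is meant to be called (the tensors and scope below are illustrative). The new list/tuple form requires every tensor to come from the same multi-output ComputeOp, mirroring the checks added on the C++ side.

import tvm

n = 1024
A = tvm.placeholder((n,), name="A")
# One ComputeOp producing two tensors.
B0, B1 = tvm.compute((n,), lambda i: (A[i] + 1, A[i] * 2), name="B")

s = tvm.create_schedule([B0.op, B1.op])
# New list form added by this commit: cache both outputs of the ComputeOp together.
B0_cache, B1_cache = s.cache_write([B0, B1], "local")
# The single-tensor form is unchanged and still only accepts a tensor
# produced by a single-output ComputeOp.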
@@ -425,8 +425,13 @@ TVM_REGISTER_API("_ScheduleCacheRead")
 TVM_REGISTER_API("_ScheduleCacheWrite")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = args[0].operator Schedule()
-        .cache_write(args[1], args[2]);
+    if (args[1].IsNodeType<Tensor>()) {
+      *ret = args[0].operator Schedule()
+          .cache_write(args[1].operator Tensor(), args[2]);
+    } else {
+      *ret = args[0].operator Schedule()
+          .cache_write(args[1].operator Array<Tensor>(), args[2]);
+    }
   });
 TVM_REGISTER_API("_ScheduleRFactor")
...
@@ -78,6 +78,13 @@ void ReplaceDataFlow(const Array<Stage>& stages,
   }
 }
 
+inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
+  return (a->combiner.same_as(b->combiner)) &&
+         (a->source.same_as(b->source)) &&
+         (a->axis.same_as(b->axis)) &&
+         (a->condition.same_as(b->condition));
+}
+
 Tensor Schedule::cache_read(const Tensor& tensor,
                             const std::string& scope,
                             const Array<Operation>& readers) {
@@ -128,15 +135,15 @@ Tensor Schedule::cache_read(const Tensor& tensor,
   return cache;
 }
 
 // Cache write and relayout the data according to loop pattern
-Tensor CacheWriteWithReLayout(Schedule sch,
-                              const Tensor& tensor,
-                              const std::string& scope) {
+Array<Tensor> CacheWriteWithReLayout(Schedule sch,
+                                     const Array<Tensor>& tensor_array,
+                                     const std::string& scope) {
+  size_t tensor_size = tensor_array.size();
   sch->InvalidateCache();
+  Tensor tensor = tensor_array[0];
   Stage orig_stage = sch[tensor->op];
   const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();
   std::unordered_set<IterVar> red_axis;
   for (IterVar iv : compute->reduce_axis) {
     red_axis.insert(iv);
@@ -182,9 +189,34 @@ Tensor CacheWriteWithReLayout(Schedule sch,
       vsub[iv->var.get()] = value_map.at(iv);
     }
   }
-  Expr body = VarReplacer(vsub).Mutate(compute->body[tensor->value_index]);
-  body = InjectPredicate(predicates, body);
-  body = VarReplacer(vsub2newvar).Mutate(body);
+
+  Expr body;
+  Array<Expr> body_list;
+  const ir::Reduce* first_reduce = nullptr;
+  for (auto cbody : compute->body) {
+    body = VarReplacer(vsub).Mutate(cbody);
+    body = InjectPredicate(predicates, body);
+    body = VarReplacer(vsub2newvar).Mutate(body);
+    // Reduce nodes in ONE computeOp must be the same except value_index
+    // This is right only if the original body ensures Reduce nodes are the same
+    if (body->is_type<ir::Reduce>()) {
+      const ir::Reduce* reduce_body = body.as<ir::Reduce>();
+      if (first_reduce != nullptr) {
+        CHECK(ReduceEqual(reduce_body, first_reduce));
+        body = ir::Reduce::make(first_reduce->combiner,
+                                first_reduce->source,
+                                first_reduce->axis,
+                                first_reduce->condition,
+                                reduce_body->value_index);
+      } else {
+        first_reduce = reduce_body;
+      }
+    } else {
+      CHECK(first_reduce == nullptr)
+          << "cannot mix reduce and other nodes in ONE compute body";
+    }
+    body_list.push_back(body);
+  }
   // The reader args
   Array<Expr> args;
   {
@@ -200,16 +232,25 @@ Tensor CacheWriteWithReLayout(Schedule sch,
     }
   }
   Operation cache_op = ComputeOpNode::make(
-      compute->name + "." + scope, compute->tag, new_axis, {body});
-  Tensor cache_tensor = cache_op.output(0);
+      compute->name + "." + scope, compute->tag, new_axis, body_list);
+  Array<Tensor> cache_tensor_list;
+  Array<Expr> cache_expr_list;
+  for (size_t i = 0; i < tensor_size; i++) {
+    Tensor cache_tensor = cache_op.output(i);
+    cache_tensor_list.push_back(cache_tensor);
+    cache_expr_list.push_back(cache_tensor(args));
+  }
   Operation orig_new_op = ComputeOpNode::make(
-      compute->name, compute->tag, compute->axis,
-      {cache_tensor(args)});
+      compute->name, compute->tag, compute->axis, cache_expr_list);
   // The replace of the dataflow
   std::unordered_map<Tensor, Tensor> vmap;
   std::unordered_map<Tensor, Tensor> rvmap;
   vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
   rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
+  for (size_t i = 0; i < tensor_size; i++) {
+    vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
+    rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
+  }
   ReplaceDataFlow(sch->stages, &vmap, &rvmap);
   // mutate orig stage
   orig_stage->op = orig_new_op;
@@ -230,7 +271,26 @@ Tensor CacheWriteWithReLayout(Schedule sch,
   if (cache_stage->group.defined()) {
     ++cache_stage->group->num_child_stages;
   }
-  return cache_tensor;
+  return cache_tensor_list;
+}
+
+Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
+                                    const std::string& scope) {
+  (*this)->InvalidateCache();
+  CHECK(tensor_array.size() > 0)
+      << "size of tensor_array must be greater than 0";
+  Tensor tensor = tensor_array[0];
+  Stage orig_stage = operator[](tensor->op);
+  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
+  CHECK(static_cast<size_t>(compute->num_outputs()) == tensor_array.size())
+      << "size of input tensor list must be same as number of stage outputs";
+  for (size_t i = 1; i < tensor_array.size(); i++) {
+    Stage tmp_stage = operator[](tensor_array[i]->op);
+    CHECK(orig_stage.same_as(tmp_stage))
+        << "Input tensor list must be generated by ONE computeOp";
+  }
+  return CacheWriteWithReLayout(*this, tensor_array, scope);
 }
 
 Tensor Schedule::cache_write(const Tensor& tensor,
@@ -243,7 +303,7 @@ Tensor Schedule::cache_write(const Tensor& tensor,
   CHECK_EQ(compute->num_outputs(), 1)
       << "cache write only support single output ComputeOp";
-  return CacheWriteWithReLayout(*this, tensor, scope);
+  return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
 }
 
 void RebaseNonZeroMinLoop(const Schedule& sch) {
@@ -289,13 +349,6 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
   }
 }
 
-inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
-  return (a->combiner.same_as(b->combiner)) &&
-         (a->source.same_as(b->source)) &&
-         (a->axis.same_as(b->axis)) &&
-         (a->condition.same_as(b->condition));
-}
-
 void InjectInline(ScheduleNode* sch) {
   sch->InvalidateCache();
...
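The ReduceEqual helper and the checks in the loop above exist because a multi-output ComputeOp with a reduction, such as an argmax built from a commutative reducer, carries one Reduce node per output body, and those nodes must be identical except for value_index. A sketch of that case (the reducer, tensor names, and shapes are illustrative, written in the style of the existing Python tests), which the new Array<Tensor> overload is meant to accept:

import tvm

# argmax as a two-value commutative reducer: carries (index, value) pairs.
def fcombine(x, y):
    lhs = tvm.select(x[1] >= y[1], x[0], y[0])
    rhs = tvm.select(x[1] >= y[1], x[1], y[1])
    return lhs, rhs

def fidentity(t0, t1):
    return tvm.const(-1, t0), tvm.min_value(t1)

argmax = tvm.comm_reducer(fcombine, fidentity, name="argmax")

m = tvm.var("m")
n = tvm.var("n")
idx = tvm.placeholder((m, n), name="idx", dtype="int32")
val = tvm.placeholder((m, n), name="val", dtype="float32")
k = tvm.reduce_axis((0, n), "k")
# One ComputeOp, two outputs; both bodies share the same Reduce node
# and differ only in value_index.
T0, T1 = tvm.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="T")

s = tvm.create_schedule([T0.op, T1.op])
# Both tensors come from ONE ComputeOp, so caching them together is allowed.
T0_cache, T1_cache = s.cache_write([T0, T1], "local")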
@@ -39,6 +39,49 @@ def test_exp():
     check_device("vulkan")
 
+def test_multiple_cache_write():
+    # graph
+    n = tvm.convert(1024)
+    A0 = tvm.placeholder((n,), name='A0', dtype="float32")
+    A1 = tvm.placeholder((n,), name='A1', dtype="float32")
+    B0, B1 = tvm.compute((n,),
+                         lambda *i: (A0(*i) + A1(*i), A0(*i) * A1(*i)),
+                         name='B')
+    C = tvm.compute((n,), lambda *i: B0(*i) + B1(*i),
+                    name='C')
+    s = tvm.create_schedule(C.op)
+    # create iter var and assign them tags.
+    num_thread = 8
+    B0_cache, B1_cache = s.cache_write([B0, B1], "local")
+    bx, tx = s[C].split(C.op.axis[0], factor=num_thread)
+    s[B0].compute_at(s[C], bx)
+    s[B0_cache].compute_at(s[C], bx)
+    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
+    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+    # one line to build the function.
+    def check_device(device, host="stackvm"):
+        if not tvm.module.enabled(host):
+            return
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            return
+        func = tvm.build(s, [A0, A1, C],
+                         device, host,
+                         name="multiple_cache_write")
+        ctx = tvm.context(device, 0)
+        # launch the kernel.
+        n = 1024
+        a0 = tvm.nd.array(np.random.uniform(size=n).astype(A0.dtype), ctx)
+        a1 = tvm.nd.array(np.random.uniform(size=n).astype(A1.dtype), ctx)
+        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+        func(a0, a1, c)
+        np.testing.assert_allclose(
+            c.asnumpy(), a0.asnumpy() + a1.asnumpy() + (a0.asnumpy() * a1.asnumpy()),
+            rtol=1e-5)
+
+    check_device("cuda", "llvm")
+    check_device("vulkan")
+    check_device("opencl")
+
 def test_log_pow_llvm():
     # graph
@@ -199,6 +242,7 @@ def try_warp_memory():
 if __name__ == "__main__":
     test_exp()
     try_warp_memory()
+    test_multiple_cache_write()
     test_add()
     test_log_pow_llvm()
     test_popcount()
...
@@ -249,6 +249,20 @@ def test_schedule_cache_relayout3():
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
 
+def test_schedule_cache_relayout4():
+    def _compute(*indice):
+        return A(*indice) + 1, B(*indice) / 2
+    m = tvm.var('m')
+    n = tvm.var('n')
+    A = tvm.placeholder((m*4, n), name='A')
+    B = tvm.placeholder((m*4, n), name='B')
+    C1, C2 = tvm.compute(A.shape, _compute, name='C')
+    s = tvm.create_schedule([C1.op, C2.op])
+    C1_cache, C2_cache = s.cache_write([C1, C2], "local")
+    s = s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
 def test_schedule_bound_condition():
     A = tvm.placeholder((64,), name='A', dtype="float32")
@@ -265,6 +279,7 @@ def test_schedule_bound_condition():
 if __name__ == "__main__":
     test_schedule_middle_cache()
     test_inline_multi_reduce()
+    test_schedule_cache_relayout4()
     test_schedule_cache_relayout3()
     test_schedule_cache_relayout2()
     test_schedule_cache_relayout1()
...