Unverified commit b41f4e55 by Tang, Shizhi, committed by GitHub

[TE] Support mixing normal and cross-thread reduction (#5193)

* Support mixing normal and cross-thread reduction

* minor improvements
parent 75e936e1
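In TE scheduling terms, a reduce axis bound to a thread axis may now coexist with a reduce axis left unbound; the unbound axis is lowered to a serial loop that accumulates into a thread-local temporary before the cross-thread allreduce. A minimal sketch of the newly supported pattern, mirroring the test added below (sizes and names are illustrative):

    import tvm
    import topi

    a = tvm.te.placeholder((32, 32), name="a")
    b = topi.sum(a)  # reduces over both axes of a

    s = tvm.te.create_schedule(b.op)
    i, j = b.op.reduce_axis
    # Bind only the first reduce axis; j remains a normal serial reduction.
    # Before this patch, this mix was rejected with
    # "Cannot mix normal reduction with thread reduce".
    s[b].bind(i, tvm.te.thread_axis("threadIdx.x"))
    func = tvm.build(s, [a, b], "cuda")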
@@ -443,8 +443,6 @@ ComputeType DetectComputeType(const ComputeOpNode* self,
        << "Cannot mix cross thread reduction with Tensorize";
    return ComputeType::kTensorize;
  }
-  CHECK(normal_red == 0 || thread_red == 0)
-      << "Cannot mix normal reduction with thread reduce";
  if (thread_red != 0) {
    return ComputeType::kCrossThreadReduction;
  } else {
......
@@ -57,11 +57,64 @@ Stmt MakeCrossThreadReduction(
  for (PrimExpr v : conds) {
    cond = cond && v;
  }
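  // Partition the loop nest: unbound reduce axes go to `normal_red` and are kept
  // as serial loops; everything else, including thread-bound reduce axes (whose
  // reduction is performed by tvm_thread_allreduce below), goes to `common`.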
  std::vector<std::vector<Stmt>> common, normal_red;
  for (size_t i = 0, n = stage->leaf_iter_vars.size(); i < n; ++i) {
    IterVar iv = stage->leaf_iter_vars[i];
    IterVarAttr attr;
    auto it = stage->iter_var_attrs.find(iv);
    if (it != stage->iter_var_attrs.end()) {
      attr = (*it).second;
    }
    if (iv->iter_type == kCommReduce) {
      if (attr.defined() && attr->bind_thread.defined()) {
        common.emplace_back(nest[i + 1]);
      } else {
        normal_red.emplace_back(nest[i + 1]);
      }
    } else {
      common.emplace_back(nest[i + 1]);
    }
  }
  // If we load from and then store into the same res_handles in the thread_allreduce
  // intrinsic, something goes wrong, so we use an extra variable here for normal reduction.
  std::vector<Var> normal_res_handles;
  std::vector<Stmt> normal_init, normal_update;
  if (!normal_red.empty()) {
    normal_res_handles.reserve(size);
    normal_init.reserve(size);
    normal_update.reserve(size);
    const CommReducerNode* combiner = reduces[0]->combiner.as<CommReducerNode>();
    CHECK(combiner);
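    // Build the serial reduction from the reduce combiner: `init` stores the
    // combiner's identity element into the local accumulator, and `update`
    // combines the accumulator (lhs) with the reduction source.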
    Array<PrimExpr> lhs;
    for (size_t i = 0; i < size; ++i) {
      DataType t = reduces[i]->dtype;
      normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i),
                                      DataType::Handle());
      lhs.push_back(LoadNode::make(t, normal_res_handles[i], 0, const_true(t.lanes())));
    }
    Array<PrimExpr> init_value = combiner->identity_element;
    Array<PrimExpr> update_value = (*combiner)(lhs, reduces[0]->source);
    for (size_t i = 0; i < size; ++i) {
      DataType t = reduces[i]->dtype;
      normal_init.emplace_back(
          StoreNode::make(normal_res_handles[i], init_value[i], 0, const_true(t.lanes())));
      normal_update.emplace_back(
          StoreNode::make(normal_res_handles[i], update_value[i], 0, const_true(t.lanes())));
    }
  }
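  // When a serial reduction is present, the cross-thread allreduce consumes the
  // locally accumulated value from normal_reduce_temp instead of the raw source.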
  Array<PrimExpr> freduce_args;
  freduce_args.push_back(make_const(DataType::UInt(32), static_cast<uint32_t>(size)));
  for (size_t i = 0; i < size; ++i) {
    if (!normal_red.empty()) {
      DataType t = reduces[i]->dtype;
      freduce_args.push_back(LoadNode::make(t, normal_res_handles[i], 0, const_true(t.lanes())));
    } else {
      freduce_args.push_back(reduces[0]->source[i]);
    }
  }
  freduce_args.push_back(cond);

  std::vector<Var> res_handles(size);
  for (size_t idx = 0; idx < size; ++idx) {
@@ -94,6 +147,15 @@ Stmt MakeCrossThreadReduction(
      tir::attr::reduce_scope,
      make_zero(DataType::Handle()),
      reduce_body);
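  // For the mixed case, the body becomes: initialize the local accumulator, run
  // the serial reduction nest, then do the cross-thread allreduce, with the
  // whole sequence guarded by the non-reduction predicates.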
  if (!normal_red.empty()) {
    Stmt init_body = SeqStmt::Flatten(normal_init);
    Stmt update_body = SeqStmt::Flatten(normal_update);
    update_body = MergeNest(normal_red, update_body);
    reduce_body = SeqStmt::Flatten(init_body, update_body, reduce_body);
    reduce_body = MergeNest(MakeIfNest(conds), reduce_body);
  }

  std::vector<Stmt> assigns(size);
  for (size_t idx = 0; idx < size; ++idx) {
    DataType t = reduces[idx]->dtype;
@@ -110,9 +172,15 @@ Stmt MakeCrossThreadReduction(
        res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
    body = AttrStmtNode::make(
        res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"), body);
    if (!normal_red.empty()) {
      body = AllocateNode::make(
          normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
      body = AttrStmtNode::make(
          normal_res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"),
          body);
    }
  }
  body = Substitute(body, value_map);
-  return MergeNest(nest, body);
+  return MergeNest(common, body);
}
} // namespace te
} // namespace tvm
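Lowering the mixed schedule makes the two pieces visible together. A sketch of how to inspect it, with the expected body shape in comments (not verbatim compiler output; buffer names follow the C++ above):

    import tvm
    import topi

    a = tvm.te.placeholder((32, 32), name="a")
    b = topi.sum(a)
    s = tvm.te.create_schedule(b.op)
    i, _ = b.op.reduce_axis
    s[b].bind(i, tvm.te.thread_axis("threadIdx.x"))
    print(tvm.lower(s, [a, b], simple_mode=True))
    # Expected shape of the body, roughly:
    #   normal_reduce_temp0[0] = 0f                # normal_init
    #   for (k1, 0, 32) {                          # unbound reduce axis (normal_red)
    #     normal_reduce_temp0[0] = normal_reduce_temp0[0] + a[(threadIdx.x*32) + k1]
    #   }
    #   tvm_thread_allreduce(1u, normal_reduce_temp0[0], (bool)1, reduce_temp0, threadIdx.x)
    #   b[0] = reduce_temp0[0]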
@@ -321,6 +321,33 @@ def test_cuda_reduction():
    check_cuda("float32")
    check_cuda("float16")


def test_cuda_mix_threaded_and_normal_reduction():
    def check_cuda(dtype, m=32, n=32):
        if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
            print("skip because cuda is not enabled..")
            return
        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
            print("Skip because gpu does not have fp16 support")
            return
        a = tvm.te.placeholder((m, n), name="a", dtype=dtype)
        b = topi.sum(a)
        with tvm.target.cuda():
            sb = tvm.te.create_schedule(b.op)
            i, _ = b.op.reduce_axis
            sb[b].bind(i, tvm.te.thread_axis("threadIdx.x"))
            ctx = tvm.gpu(0)
            func = tvm.build(sb, [a, b], 'cuda')
        a_np = np.random.uniform(size=(m, n)).astype(a.dtype)
        b_np = np.sum(a_np)
        a_nd = tvm.nd.array(a_np, ctx)
        b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), ctx)
        func(a_nd, b_nd)
        tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3)

    check_cuda("float32")
    check_cuda("float16")


def test_cuda_floordiv_with_vectorization():
    if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
        print("skip because cuda is not enabled..")
@@ -528,6 +555,7 @@ if __name__ == "__main__":
    test_rfactor_predicates()
    test_cuda_const_float_to_half()
    test_cuda_reduction()
    test_cuda_mix_threaded_and_normal_reduction()
    test_cuda_floordiv_with_vectorization()
    test_vectorized_intrin1()
    test_vectorized_intrin2()
......