Commit f6bb7aba by Tianqi Chen Committed by GitHub

[GPU][TOPI] Fix cross thread reduction schedule (#414)

parent adf39837
......@@ -261,6 +261,7 @@ class ThreadAllreduceBuilder final : public IRMutator {
if (in_warp_seq.size() != 0) {
Stmt warp_body = MergeSeq(in_warp_seq);
seq.emplace_back(IfThenElse::make(in_warp_cond, warp_body));
seq.emplace_back(SyncThread("shared"));
}
return MergeSeq(seq);
}
......
......@@ -34,7 +34,7 @@ def schedule_softmax(outs):
s[expsum].bind(s[expsum].op.axis[0], block_x)
s[expsum].bind(s[expsum].op.reduce_axis[0], thread_x)
s[EF].compute_at(s[expsum], s[expsum].op.reduce_axis[0])
s[expsum].set_store_predicate(thread_x.var.equal(0))
tx, xi = s[softmax].split(softmax.op.axis[1], nparts=num_thread)
s[softmax].bind(softmax.op.axis[0], block_x)
s[softmax].bind(tx, thread_x)
......
......@@ -108,8 +108,10 @@ print(s[B].op.body)
xo, xi = s[B].split(s[B].op.axis[0], factor=32)
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, tvm.thread_axis("threadIdx.y"))
s[B].bind(s[B].op.reduce_axis[0], tvm.thread_axis("threadIdx.x"))
tx = tvm.thread_axis("threadIdx.x")
s[B].bind(s[B].op.reduce_axis[0], tx)
s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
s[B].set_store_predicate(tx.var.equal(0))
fcuda = tvm.build(s, [A, B], "cuda")
print(fcuda.imported_modules[0].get_source())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment