Commit f6bb7aba by Tianqi Chen Committed by GitHub

[GPU][TOPI] Fix cross thread reduction schedule (#414)

parent adf39837
@@ -261,6 +261,7 @@ class ThreadAllreduceBuilder final : public IRMutator {
     if (in_warp_seq.size() != 0) {
       Stmt warp_body = MergeSeq(in_warp_seq);
       seq.emplace_back(IfThenElse::make(in_warp_cond, warp_body));
+      seq.emplace_back(SyncThread("shared"));
     }
     return MergeSeq(seq);
   }
......
@@ -34,7 +34,7 @@ def schedule_softmax(outs):
     s[expsum].bind(s[expsum].op.axis[0], block_x)
     s[expsum].bind(s[expsum].op.reduce_axis[0], thread_x)
     s[EF].compute_at(s[expsum], s[expsum].op.reduce_axis[0])
+    s[expsum].set_store_predicate(thread_x.var.equal(0))
     tx, xi = s[softmax].split(softmax.op.axis[1], nparts=num_thread)
     s[softmax].bind(softmax.op.axis[0], block_x)
     s[softmax].bind(tx, thread_x)
......
@@ -108,8 +108,10 @@ print(s[B].op.body)
 xo, xi = s[B].split(s[B].op.axis[0], factor=32)
 s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
 s[B].bind(xi, tvm.thread_axis("threadIdx.y"))
-s[B].bind(s[B].op.reduce_axis[0], tvm.thread_axis("threadIdx.x"))
+tx = tvm.thread_axis("threadIdx.x")
+s[B].bind(s[B].op.reduce_axis[0], tx)
 s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
+s[B].set_store_predicate(tx.var.equal(0))
 fcuda = tvm.build(s, [A, B], "cuda")
 print(fcuda.imported_modules[0].get_source())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment