Commit 7ad3c51e by Yuwei HU Committed by Tianqi Chen

[TOPI] improve elemwise schedule (#393)

* [TOPI] improve elemwise schedule

* modify fuse
parent 0560e156
...@@ -23,17 +23,10 @@ def schedule_elemwise(outs): ...@@ -23,17 +23,10 @@ def schedule_elemwise(outs):
x = outs[0] x = outs[0]
num_dim = len(x.shape) num_dim = len(x.shape)
block_factor = tvm.ir_pass.Simplify(x.op.output(0).shape[num_dim-1]).value fused = s[x].fuse(*x.op.axis)
if block_factor % 48 == 0: num_thread = 64
block_factor = 48 bx, tx = s[x].split(fused, factor=num_thread)
elif block_factor % 32 == 0:
block_factor = 32
bx, tx = s[x].split(x.op.axis[num_dim-1], factor=block_factor)
for i in range(num_dim-2, 0, -1):
bx = s[x].fuse(bx, x.op.axis[i])
s[x].bind(bx, tvm.thread_axis("blockIdx.x")) s[x].bind(bx, tvm.thread_axis("blockIdx.x"))
s[x].bind(tx, tvm.thread_axis("threadIdx.x")) s[x].bind(tx, tvm.thread_axis("threadIdx.x"))
return s return s
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment