Commit ebf52dfb by Wuwei Lin Committed by masahi

[TOPI, CUDA] Improve conv2d_transpose schedule template (#3796)

parent 80fc943f
...@@ -138,6 +138,7 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs): ...@@ -138,6 +138,7 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
##### space definition begin ##### ##### space definition begin #####
n, f, y, x = s[conv].op.axis n, f, y, x = s[conv].op.axis
rc = s[conv].op.reduce_axis[0] rc = s[conv].op.reduce_axis[0]
cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
cfg.define_split("tile_f", cfg.axis(f), num_outputs=4) cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
cfg.define_split("tile_y", cfg.axis(y), num_outputs=4) cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
cfg.define_split("tile_x", cfg.axis(x), num_outputs=4) cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
...@@ -170,21 +171,43 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs): ...@@ -170,21 +171,43 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
# tile and bind spatial axes # tile and bind spatial axes
n, f, y, x = s[output].op.axis n, f, y, x = s[output].op.axis
kernel_scope, n = s[output].split(n, nparts=1) kernel_scope, n = s[output].split(n, nparts=1)
bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f) bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
s[output].bind(bf, tvm.thread_axis("blockIdx.z")) s[output].reorder(bn, bf, by, bx, vn, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
s[output].bind(by, tvm.thread_axis("blockIdx.y")) s[output].bind(bn, tvm.thread_axis("blockIdx.z"))
s[output].bind(bx, tvm.thread_axis("blockIdx.x")) s[output].bind(bf, tvm.thread_axis("blockIdx.y"))
s[output].bind(s[output].fuse(by, bx), tvm.thread_axis("blockIdx.x"))
s[output].bind(vn, tvm.thread_axis("vthread"))
s[output].bind(vf, tvm.thread_axis("vthread")) s[output].bind(vf, tvm.thread_axis("vthread"))
s[output].bind(vy, tvm.thread_axis("vthread")) s[output].bind(vy, tvm.thread_axis("vthread"))
s[output].bind(vx, tvm.thread_axis("vthread")) s[output].bind(vx, tvm.thread_axis("vthread"))
s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
s[output].bind(ty, tvm.thread_axis("threadIdx.y")) cfg.define_knob("fuse_yx", [0, 1]) # fuse ty,tx or tn,tf
s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi) if cfg["fuse_yx"].val:
s[OL].compute_at(s[output], tx) s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
tyx = s[output].fuse(ty, tx)
s[output].bind(s[output].fuse(ty, tx), tvm.thread_axis("threadIdx.x"))
s[OL].compute_at(s[output], tyx)
# number of threads
n_tz = cfg["tile_n"].size[2]
n_ty = cfg["tile_f"].size[2]
n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
else:
s[output].bind(s[output].fuse(tn, tf), tvm.thread_axis("threadIdx.z"))
s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
s[OL].compute_at(s[output], tx)
# number of threads
n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
n_ty = cfg["tile_y"].size[2]
n_tx = cfg["tile_x"].size[2]
# tile reduction axes # tile reduction axes
n, f, y, x = s[OL].op.axis n, f, y, x = s[OL].op.axis
...@@ -199,9 +222,9 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs): ...@@ -199,9 +222,9 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
for load in [AA, WW]: for load in [AA, WW]:
n, f, y, x = s[load].op.axis n, f, y, x = s[load].op.axis
fused = s[load].fuse(f, y, x) fused = s[load].fuse(f, y, x)
tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2]) tz, fused = s[load].split(fused, nparts=n_tz)
ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2]) ty, fused = s[load].split(fused, nparts=n_ty)
tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2]) tx, fused = s[load].split(fused, nparts=n_tx)
s[load].bind(tz, tvm.thread_axis("threadIdx.z")) s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
s[load].bind(ty, tvm.thread_axis("threadIdx.y")) s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
s[load].bind(tx, tvm.thread_axis("threadIdx.x")) s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment