Commit 472c3146 by Leyuan Wang Committed by Tianqi Chen

[Bugfix][TOPI] conv2d_transpose bugfix (#3138)

* deconv tests

* deconv bug fixed for certain cases; tests added
parent 17b60b90
...@@ -174,7 +174,6 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs): ...@@ -174,7 +174,6 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
bf = s[output].fuse(n, bf)
s[output].bind(bf, tvm.thread_axis("blockIdx.z")) s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
s[output].bind(by, tvm.thread_axis("blockIdx.y")) s[output].bind(by, tvm.thread_axis("blockIdx.y"))
s[output].bind(bx, tvm.thread_axis("blockIdx.x")) s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
...@@ -184,7 +183,7 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs): ...@@ -184,7 +183,7 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
s[output].bind(tf, tvm.thread_axis("threadIdx.z")) s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
s[output].bind(ty, tvm.thread_axis("threadIdx.y")) s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
s[output].bind(tx, tvm.thread_axis("threadIdx.x")) s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi) s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
s[OL].compute_at(s[output], tx) s[OL].compute_at(s[output], tx)
# tile reduction axes # tile reduction axes
...@@ -193,13 +192,13 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs): ...@@ -193,13 +192,13 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc) rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
s[OL].reorder(rco, rcm, ry, rx, rci, n, f, y, x) s[OL].reorder(rco, rcm, ry, rx, rci, n, f, y, x)
s[AA].compute_at(s[OL], rcm) s[AA].compute_at(s[OL], rx)
s[WW].compute_at(s[OL], rcm) s[WW].compute_at(s[OL], rx)
# cooperative fetching # cooperative fetching
for load in [AA, WW]: for load in [AA, WW]:
n, f, y, x = s[load].op.axis n, f, y, x = s[load].op.axis
fused = s[load].fuse(n, f, y, x) fused = s[load].fuse(f, y, x)
tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2]) tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2]) ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2]) tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
......
...@@ -74,6 +74,7 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, ...@@ -74,6 +74,7 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
def test_conv2d_transpose_nchw(): def test_conv2d_transpose_nchw():
verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 1, 0) verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 1, 0)
verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 2, 1) verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 2, 1)
verify_conv2d_transpose_nchw(1, 3, 224, 32, 2, 2, 0)
verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 1, 0) verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 1, 0)
verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 2, 1) verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 2, 1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment