Commit 3c895464 by Leyuan Wang Committed by Tianqi Chen

vgg16 workload error fixed (#598)

parent 88662130
......@@ -80,6 +80,9 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
if mark % i == 0 and vthread_x > 0:
num_thread_x = i
break
if num_thread_x * vthread_x > 128:
num_thread_x = 8
vthread_x = 8
num_thread_y = 8
vthread_y = 2
ifactor = 8
......@@ -92,20 +95,20 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
thread_yz = tvm.thread_axis((0, vthread_y), "vthread", name="vy")
i, oc, h, w = s[Out].op.axis
w = s[Out].fuse(h, w)
ow, iw = s[Out].split(w, factor=num_thread_x*vthread_x)
ow, iw = s[Out].split(w, factor=num_thread_x)
oh, ih = s[Out].split(h, factor=vthread_x)
ooc, ioc = s[Out].split(oc, factor=num_thread_y*vthread_y)
oiw, iiw = s[Out].split(iw, nparts=vthread_x)
oioc, iioc = s[Out].split(ioc, nparts=vthread_y)
s[Out].reorder(i, ooc, ow, oioc, oiw, iioc, iiw)
s[Out].bind(iiw, thread_x)
s[Out].reorder(i, ooc, oh, ow, oioc, ih, iioc, iw)
oh = s[Out].fuse(oh, ow)
s[Out].bind(iw, thread_x)
s[Out].bind(iioc, thread_y)
s[Out].bind(oiw, thread_xz)
s[Out].bind(ih, thread_xz)
s[Out].bind(oioc, thread_yz)
s[Out].bind(ow, block_x)
s[Out].bind(oh, block_x)
s[Out].bind(ooc, block_y)
s[Out_L].compute_at(s[Out], iiw)
s[Out_L].compute_at(s[Out], iw)
# schedule Out_L local write
i, oc, h, w = s[Out_L].op.axis
......
......@@ -68,6 +68,7 @@ def test_conv2d_nchw():
verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1)
verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0)
verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1)
verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1)
if __name__ == "__main__":
test_conv2d_nchw()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment