Commit 3c895464, authored by Leyuan Wang and committed by Tianqi Chen

vgg16 workload error fixed (#598)

parent 88662130
...@@ -80,6 +80,9 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag): ...@@ -80,6 +80,9 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
if mark % i == 0 and vthread_x > 0: if mark % i == 0 and vthread_x > 0:
num_thread_x = i num_thread_x = i
break break
if num_thread_x * vthread_x > 128:
num_thread_x = 8
vthread_x = 8
num_thread_y = 8 num_thread_y = 8
vthread_y = 2 vthread_y = 2
ifactor = 8 ifactor = 8
...@@ -92,20 +95,20 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag): ...@@ -92,20 +95,20 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
thread_yz = tvm.thread_axis((0, vthread_y), "vthread", name="vy") thread_yz = tvm.thread_axis((0, vthread_y), "vthread", name="vy")
i, oc, h, w = s[Out].op.axis i, oc, h, w = s[Out].op.axis
w = s[Out].fuse(h, w) ow, iw = s[Out].split(w, factor=num_thread_x)
ow, iw = s[Out].split(w, factor=num_thread_x*vthread_x) oh, ih = s[Out].split(h, factor=vthread_x)
ooc, ioc = s[Out].split(oc, factor=num_thread_y*vthread_y) ooc, ioc = s[Out].split(oc, factor=num_thread_y*vthread_y)
oiw, iiw = s[Out].split(iw, nparts=vthread_x)
oioc, iioc = s[Out].split(ioc, nparts=vthread_y) oioc, iioc = s[Out].split(ioc, nparts=vthread_y)
s[Out].reorder(i, ooc, ow, oioc, oiw, iioc, iiw) s[Out].reorder(i, ooc, oh, ow, oioc, ih, iioc, iw)
s[Out].bind(iiw, thread_x) oh = s[Out].fuse(oh, ow)
s[Out].bind(iw, thread_x)
s[Out].bind(iioc, thread_y) s[Out].bind(iioc, thread_y)
s[Out].bind(oiw, thread_xz) s[Out].bind(ih, thread_xz)
s[Out].bind(oioc, thread_yz) s[Out].bind(oioc, thread_yz)
s[Out].bind(ow, block_x) s[Out].bind(oh, block_x)
s[Out].bind(ooc, block_y) s[Out].bind(ooc, block_y)
s[Out_L].compute_at(s[Out], iiw) s[Out_L].compute_at(s[Out], iw)
# schedule Out_L local write # schedule Out_L local write
i, oc, h, w = s[Out_L].op.axis i, oc, h, w = s[Out_L].op.axis
......
...@@ -68,6 +68,7 @@ def test_conv2d_nchw(): ...@@ -68,6 +68,7 @@ def test_conv2d_nchw():
verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1) verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1)
verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0) verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0)
verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1) verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1)
verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1)
if __name__ == "__main__": if __name__ == "__main__":
test_conv2d_nchw() test_conv2d_nchw()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment