Commit 46657ed1 by Leyuan Wang Committed by Tianqi Chen

Conv2d modified for better performance (#516)

* conv2d tweaked for better end-to-end performance

* syntax changed
parent 13970eba
...@@ -66,9 +66,20 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L): ...@@ -66,9 +66,20 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag): def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
"""Schedule conv2d for specific feature_in_out_filter pattern""" """Schedule conv2d for specific feature_in_out_filter pattern"""
if util.get_const_int(Filter_S.shape[0]) == util.get_const_int(Filter_S.shape[1]): if util.get_const_int(Filter_S.shape[0]) == util.get_const_int(Filter_S.shape[1]):
num_thread_x = 8 mark = util.get_const_int(Out.shape[2]) * util.get_const_int(Out.shape[3])
num_thread_x = 0
if mark % 8 == 0 and mark % 7 == 0:
num_thread_x = 8
vthread_x = 7
else:
for i in range(5, mark):
if mark % i == 0 and num_thread_x == 0:
vthread_x = i
mark = mark // i
if mark % i == 0 and vthread_x > 0:
num_thread_x = i
break
num_thread_y = 8 num_thread_y = 8
vthread_x = 7
vthread_y = 2 vthread_y = 2
ifactor = 8 ifactor = 8
...@@ -80,20 +91,20 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag): ...@@ -80,20 +91,20 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
thread_yz = tvm.thread_axis((0, vthread_y), "vthread", name="vy") thread_yz = tvm.thread_axis((0, vthread_y), "vthread", name="vy")
i, oc, h, w = s[Out].op.axis i, oc, h, w = s[Out].op.axis
oh, ih = s[Out].split(h, nparts=vthread_x) w = s[Out].fuse(h, w)
w = s[Out].fuse(ih, w) ow, iw = s[Out].split(w, factor=num_thread_x*vthread_x)
ooc, ioc = s[Out].split(oc, factor=num_thread_y*vthread_y) ooc, ioc = s[Out].split(oc, factor=num_thread_y*vthread_y)
ow, iw = s[Out].split(w, factor=num_thread_x) oiw, iiw = s[Out].split(iw, nparts=vthread_x)
oioc, iioc = s[Out].split(ioc, nparts=vthread_y) oioc, iioc = s[Out].split(ioc, nparts=vthread_y)
s[Out].reorder(i, ooc, oh, oioc, ow, iioc, iw) s[Out].reorder(i, ooc, ow, oioc, oiw, iioc, iiw)
s[Out].bind(iw, thread_x) s[Out].bind(iiw, thread_x)
s[Out].bind(iioc, thread_y) s[Out].bind(iioc, thread_y)
s[Out].bind(ow, thread_xz) s[Out].bind(oiw, thread_xz)
s[Out].bind(oioc, thread_yz) s[Out].bind(oioc, thread_yz)
s[Out].bind(oh, block_x) s[Out].bind(ow, block_x)
s[Out].bind(ooc, block_y) s[Out].bind(ooc, block_y)
s[Out_L].compute_at(s[Out], iw) s[Out_L].compute_at(s[Out], iiw)
# schedule Out_L local write # schedule Out_L local write
i, oc, h, w = s[Out_L].op.axis i, oc, h, w = s[Out_L].op.axis
...@@ -260,9 +271,9 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L): ...@@ -260,9 +271,9 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
else: else:
# scheduler params # scheduler params
vthread_x = min(8, util.get_const_int(Out.shape[2])) vthread_x = util.get_const_int(Out.shape[2])
num_thread_x = 16 num_thread_x = 16
num_thread_y = min(8, util.get_const_int(Out.shape[3])) num_thread_y = util.get_const_int(Out.shape[3])
ofactor = 8 ofactor = 8
block_x = tvm.thread_axis("blockIdx.x") block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x") thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
...@@ -271,12 +282,10 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L): ...@@ -271,12 +282,10 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
i, oc, h, w = s[Out].op.axis i, oc, h, w = s[Out].op.axis
ooc, ioc = s[Out].split(oc, factor=num_thread_x) ooc, ioc = s[Out].split(oc, factor=num_thread_x)
oh, ih = s[Out].split(h, factor=vthread_x) s[Out].reorder(i, ooc, h, w, ioc)
ow, iw = s[Out].split(w, factor=num_thread_y)
s[Out].reorder(i, ooc, oh, ih, ow, iw, ioc)
s[Out].bind(ioc, thread_x) s[Out].bind(ioc, thread_x)
s[Out].bind(iw, thread_y) s[Out].bind(w, thread_y)
s[Out].bind(ih, thread_xz) s[Out].bind(h, thread_xz)
s[Out].bind(ooc, block_x) s[Out].bind(ooc, block_x)
s[Out_L].compute_at(s[Out], ioc) s[Out_L].compute_at(s[Out], ioc)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment