Commit 13970eba by Leyuan Wang Committed by Tianqi Chen

conv2d data re-layout fix out of threads bug (#514)

* conv2d layout change bug fixed

* remove debug msg

* misaligned error fixed
parent 7d42f9f3
...@@ -10,19 +10,18 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L): ...@@ -10,19 +10,18 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
ofactor = 16 ofactor = 16
hfactor = 2 hfactor = 2
ow_size = util.get_const_int(Out.shape[3]) ow_size = util.get_const_int(Out.shape[3])
num_thread = ow_size*hfactor num_thread = ow_size * hfactor
vthread = hfactor vthread = ofactor
block_x = tvm.thread_axis("blockIdx.x") block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx") thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
i, oc, h, w = s[Out].op.axis i, oc, h, w = s[Out].op.axis
ooc, ioc = s[Out].split(oc, factor=ofactor) ooc, ioc = s[Out].split(oc, factor=vthread)
oh, ih = s[Out].split(h, factor=hfactor) oh, ih = s[Out].split(h, factor=hfactor)
s[Out].reorder(ooc, oh, ioc, ih, w) s[Out].reorder(ooc, oh, ioc, ih, w)
oc = s[Out].fuse(ooc, oh) oc = s[Out].fuse(ooc, oh)
w = s[Out].fuse(w, ih) w = s[Out].fuse(w, ih)
s[Out].bind(w, thread_x) s[Out].bind(w, thread_x)
s[Out].bind(ioc, thread_xz) s[Out].bind(ioc, thread_xz)
s[Out].bind(oc, block_x) s[Out].bind(oc, block_x)
...@@ -261,9 +260,9 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L): ...@@ -261,9 +260,9 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
else: else:
# scheduler params # scheduler params
vthread_x = util.get_const_int(Out.shape[2]) vthread_x = min(8, util.get_const_int(Out.shape[2]))
num_thread_x = 16 num_thread_x = 16
num_thread_y = util.get_const_int(Out.shape[3]) num_thread_y = min(8, util.get_const_int(Out.shape[3]))
ofactor = 8 ofactor = 8
block_x = tvm.thread_axis("blockIdx.x") block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x") thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
...@@ -272,10 +271,12 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L): ...@@ -272,10 +271,12 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
i, oc, h, w = s[Out].op.axis i, oc, h, w = s[Out].op.axis
ooc, ioc = s[Out].split(oc, factor=num_thread_x) ooc, ioc = s[Out].split(oc, factor=num_thread_x)
s[Out].reorder(i, ooc, h, w, ioc) oh, ih = s[Out].split(h, factor=vthread_x)
ow, iw = s[Out].split(w, factor=num_thread_y)
s[Out].reorder(i, ooc, oh, ih, ow, iw, ioc)
s[Out].bind(ioc, thread_x) s[Out].bind(ioc, thread_x)
s[Out].bind(w, thread_y) s[Out].bind(iw, thread_y)
s[Out].bind(h, thread_xz) s[Out].bind(ih, thread_xz)
s[Out].bind(ooc, block_x) s[Out].bind(ooc, block_x)
s[Out_L].compute_at(s[Out], ioc) s[Out_L].compute_at(s[Out], ioc)
...@@ -289,21 +290,19 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L): ...@@ -289,21 +290,19 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
s[temp_S].compute_at(s[Out_L], oic) s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], oic) s[Filter_S].compute_at(s[Out_L], oic)
rfactor = util.get_const_int(Filter.shape[1]) num_thread = 512
thread_xx = tvm.thread_axis((0, rfactor), "threadIdx.x") thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x")
block_xx = tvm.thread_axis("blockIdx.x") block_xx = tvm.thread_axis("blockIdx.x")
i, ic, h, w = s[temp].op.axis i = s[temp].fuse(*s[temp].op.axis)
ic = s[temp].fuse(ic, h, w) bx, tx = s[temp].split(i, factor=num_thread)
oic, iic = s[temp].split(ic, factor=rfactor) s[temp].bind(tx, thread_xx)
s[temp].bind(iic, thread_xx) s[temp].bind(bx, block_xx)
s[temp].bind(oic, block_xx)
i = s[temp_R].fuse(*s[temp_R].op.axis)
i, h, w, oic, iic = s[temp_R].op.axis bx, tx = s[temp_R].split(i, factor=num_thread)
ic = s[temp_R].fuse(oic, iic) s[temp_R].bind(tx, thread_xx)
s[temp_R].bind(ic, thread_xx) s[temp_R].bind(bx, block_xx)
h = s[temp_R].fuse(h, w)
s[temp_R].bind(h, block_xx)
#schedule temp_S shared mem load #schedule temp_S shared mem load
i, h, w, oc, ic = s[temp_S].op.axis i, h, w, oc, ic = s[temp_S].op.axis
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment