Commit 13970eba by Leyuan Wang Committed by Tianqi Chen

conv2d data re-layout fix out of threads bug (#514)

* conv2d layout change bug fixed

* remove debug msg

* misaligned error fixed
parent 7d42f9f3
......@@ -10,19 +10,18 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
ofactor = 16
hfactor = 2
ow_size = util.get_const_int(Out.shape[3])
num_thread = ow_size*hfactor
vthread = hfactor
num_thread = ow_size * hfactor
vthread = ofactor
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
i, oc, h, w = s[Out].op.axis
ooc, ioc = s[Out].split(oc, factor=ofactor)
ooc, ioc = s[Out].split(oc, factor=vthread)
oh, ih = s[Out].split(h, factor=hfactor)
s[Out].reorder(ooc, oh, ioc, ih, w)
oc = s[Out].fuse(ooc, oh)
w = s[Out].fuse(w, ih)
s[Out].bind(w, thread_x)
s[Out].bind(ioc, thread_xz)
s[Out].bind(oc, block_x)
......@@ -261,9 +260,9 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
else:
# scheduler params
vthread_x = util.get_const_int(Out.shape[2])
vthread_x = min(8, util.get_const_int(Out.shape[2]))
num_thread_x = 16
num_thread_y = util.get_const_int(Out.shape[3])
num_thread_y = min(8, util.get_const_int(Out.shape[3]))
ofactor = 8
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
......@@ -272,10 +271,12 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
i, oc, h, w = s[Out].op.axis
ooc, ioc = s[Out].split(oc, factor=num_thread_x)
s[Out].reorder(i, ooc, h, w, ioc)
oh, ih = s[Out].split(h, factor=vthread_x)
ow, iw = s[Out].split(w, factor=num_thread_y)
s[Out].reorder(i, ooc, oh, ih, ow, iw, ioc)
s[Out].bind(ioc, thread_x)
s[Out].bind(w, thread_y)
s[Out].bind(h, thread_xz)
s[Out].bind(iw, thread_y)
s[Out].bind(ih, thread_xz)
s[Out].bind(ooc, block_x)
s[Out_L].compute_at(s[Out], ioc)
......@@ -289,21 +290,19 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], oic)
rfactor = util.get_const_int(Filter.shape[1])
thread_xx = tvm.thread_axis((0, rfactor), "threadIdx.x")
num_thread = 512
thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x")
block_xx = tvm.thread_axis("blockIdx.x")
i, ic, h, w = s[temp].op.axis
ic = s[temp].fuse(ic, h, w)
oic, iic = s[temp].split(ic, factor=rfactor)
s[temp].bind(iic, thread_xx)
s[temp].bind(oic, block_xx)
i, h, w, oic, iic = s[temp_R].op.axis
ic = s[temp_R].fuse(oic, iic)
s[temp_R].bind(ic, thread_xx)
h = s[temp_R].fuse(h, w)
s[temp_R].bind(h, block_xx)
i = s[temp].fuse(*s[temp].op.axis)
bx, tx = s[temp].split(i, factor=num_thread)
s[temp].bind(tx, thread_xx)
s[temp].bind(bx, block_xx)
i = s[temp_R].fuse(*s[temp_R].op.axis)
bx, tx = s[temp_R].split(i, factor=num_thread)
s[temp_R].bind(tx, thread_xx)
s[temp_R].bind(bx, block_xx)
#schedule temp_S shared mem load
i, h, w, oc, ic = s[temp_S].op.axis
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment