Commit a7f0f818 by Leyuan Wang, committed by Tianqi Chen

Packing and data layout change added to conv2d_nchw (#479)

* conv2d layout change and packing added for the last workload

* packing added for other workloads

* conv2d added packing for first workload

* fix pylint error
parent 75d53777
......@@ -4,7 +4,7 @@ import tvm
from .. import util
from .. import tag
def conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L):
def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
"""Schedule conv2d for specific feature_in_out_filter pattern"""
# scheduler params
ofactor = 16
......@@ -36,10 +36,26 @@ def conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L):
s[temp_S].compute_at(s[Out_L], ic)
s[Filter_S].compute_at(s[Out_L], w)
num_thread1 = 512
thread_xx = tvm.thread_axis((0, num_thread1), "threadIdx.x")
block_xx = tvm.thread_axis("blockIdx.x")
i = s[temp].fuse(*s[temp].op.axis)
bx, tx = s[temp].split(i, factor=num_thread1)
s[temp].bind(tx, thread_xx)
s[temp].bind(bx, block_xx)
i = s[temp_R].fuse(*s[temp_R].op.axis)
bx, tx = s[temp_R].split(i, factor=num_thread1)
s[temp_R].bind(tx, thread_xx)
s[temp_R].bind(bx, block_xx)
#schedule temp_S shared mem load
i, ic, h, w = s[temp_S].op.axis
tx, ih = s[temp_S].split(w, nparts=num_thread)
i, ic, h, ow, iw = s[temp_S].op.axis
h = s[temp_S].fuse(h, ow)
_, tx = s[temp_S].split(h, factor=num_thread)
s[temp_S].bind(tx, thread_x)
s[temp_S].vectorize(iw)
#schedule Filter_S shared mem load
i, oc, h, w = s[Filter_S].op.axis
......@@ -48,7 +64,7 @@ def conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L):
tx, _ = s[Filter_S].split(w, nparts=num_thread)
s[Filter_S].bind(tx, thread_x)
def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
"""Schedule conv2d for specific feature_in_out_filter pattern"""
if util.get_const_int(Filter_S.shape[0]) == util.get_const_int(Filter_S.shape[1]):
num_thread_x = 8
......@@ -89,13 +105,28 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], dw)
num_thread = 512
thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x")
block_xx = tvm.thread_axis("blockIdx.x")
i = s[temp].fuse(*s[temp].op.axis)
bx, tx = s[temp].split(i, factor=num_thread)
s[temp].bind(tx, thread_xx)
s[temp].bind(bx, block_xx)
i = s[temp_R].fuse(*s[temp_R].op.axis)
bx, tx = s[temp_R].split(i, factor=num_thread)
s[temp_R].bind(tx, thread_xx)
s[temp_R].bind(bx, block_xx)
#schedule temp_S shared mem load
i, ic, h, w = s[temp_S].op.axis
_, iic = s[temp_S].split(ic, factor=num_thread_y)
w = s[temp_S].fuse(h, w)
_, iw = s[temp_S].split(w, factor=num_thread_x)
s[temp_S].bind(iic, thread_y)
s[temp_S].bind(iw, thread_x)
i, oic, h, w, iic = s[temp_S].op.axis
oic = s[temp_S].fuse(oic, h, w)
ooic, ioic = s[temp_S].split(oic, factor=num_thread_x)
_, iooic = s[temp_S].split(ooic, factor=num_thread_y)
s[temp_S].bind(ioic, thread_x)
s[temp_S].bind(iooic, thread_y)
s[temp_S].vectorize(iic)
i, oc, h, w = s[Filter_S].op.axis
_, ioc = s[Filter_S].split(oc, factor=num_thread_y)
......@@ -104,7 +135,6 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
s[Filter_S].bind(ii, thread_x)
else:
# scheduler params
num_thread = 8
vthread = 2
opart2 = 4
ofactor = 64
......@@ -112,13 +142,13 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
ifactor = 8
if flag > 256:
wfactor = 14
sfactor = max(1, ofactor//(opart2*2))
spart = max(1, (wfactor + vthread-1) // vthread)
num_thread_x = max(1, ofactor//(opart2*2))
num_thread_y = max(1, (wfactor + vthread-1) // vthread)
block_x = tvm.thread_axis("blockIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
block_z = tvm.thread_axis("blockIdx.z")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
......@@ -140,54 +170,51 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
s[Out_L].compute_at(s[Out], iiioc)
# schedule Out_L local write
i, oc, h, w = s[Out_L].op.axis
ic, dh, dw = s[Out_L].op.reduce_axis
oic, iic = s[Out_L].split(ic, factor=ifactor)
s[Out_L].reorder(oic, dh, dw, iic, h, w)
if util.get_const_int(Filter_S.shape[1]) == 128:
# schedule Out_L local write
i, oc, h, w = s[Out_L].op.axis
ic, dh, dw = s[Out_L].op.reduce_axis
oic, iic = s[Out_L].split(ic, factor=ifactor)
s[Out_L].reorder(oic, dh, dw, iic, h, w)
oic = s[Out_L].fuse(dh, oic)
s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], oic)
#schedule temp_S shared mem load
i, ic, h, w = s[temp_S].op.axis
_, iic = s[temp_S].split(ic, factor=sfactor)
_, iw = s[temp_S].split(w, factor=spart)
s[temp_S].bind(iic, thread_x)
s[temp_S].bind(iw, thread_y)
#schedule Filter_S shared mem load
i, oc, h, w = s[Filter_S].op.axis
_, ioc = s[Filter_S].split(oc, factor=sfactor)
_, ii = s[Filter_S].split(i, factor=spart)
s[Filter_S].bind(ioc, thread_x)
s[Filter_S].bind(ii, thread_y)
num_thread = 512
else:
# schedule Out_L local write
i, oc, h, w = s[Out_L].op.axis
ic, dh, dw = s[Out_L].op.reduce_axis
oic, iic = s[Out_L].split(ic, factor=ifactor)
s[Out_L].reorder(oic, dh, dw, iic, h, w)
s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], dw)
num_thread = 456
thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x")
block_xx = tvm.thread_axis("blockIdx.x")
#schedule temp_S shared mem load
i, ic, h, w = s[temp_S].op.axis
_, iic = s[temp_S].split(ic, factor=sfactor)
_, iw = s[temp_S].split(w, factor=spart)
s[temp_S].bind(iic, thread_x)
s[temp_S].bind(iw, thread_y)
#schedule Filter_S shared mem load
i, oc, h, w = s[Filter_S].op.axis
_, ioc = s[Filter_S].split(oc, factor=sfactor)
_, ii = s[Filter_S].split(i, factor=spart)
s[Filter_S].bind(ioc, thread_x)
s[Filter_S].bind(ii, thread_y)
def conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L):
i = s[temp].fuse(*s[temp].op.axis)
bx, tx = s[temp].split(i, factor=num_thread)
s[temp].bind(tx, thread_xx)
s[temp].bind(bx, block_xx)
i = s[temp_R].fuse(*s[temp_R].op.axis)
bx, tx = s[temp_R].split(i, factor=num_thread)
s[temp_R].bind(tx, thread_xx)
s[temp_R].bind(bx, block_xx)
#schedule temp_S shared mem load
i, oic, h, w, iic = s[temp_S].op.axis
oic = s[temp_S].fuse(oic, h, w)
ooic, ioic = s[temp_S].split(oic, factor=num_thread_x)
_, iooic = s[temp_S].split(ooic, factor=num_thread_y)
s[temp_S].bind(ioic, thread_x)
s[temp_S].bind(iooic, thread_y)
s[temp_S].vectorize(iic)
#schedule Filter_S shared mem load
i, oc, h, w = s[Filter_S].op.axis
_, ioc = s[Filter_S].split(oc, factor=num_thread_x)
_, ii = s[Filter_S].split(i, factor=num_thread_y)
s[Filter_S].bind(ioc, thread_x)
s[Filter_S].bind(ii, thread_y)
def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
"""Schedule conv2d for specific feature_in_out_filter pattern"""
if util.get_const_int(Filter.shape[1]) == 256:
# scheduler params
......@@ -262,13 +289,30 @@ def conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L):
s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], oic)
rfactor = util.get_const_int(Filter.shape[1])
thread_xx = tvm.thread_axis((0, rfactor), "threadIdx.x")
block_xx = tvm.thread_axis("blockIdx.x")
i, ic, h, w = s[temp].op.axis
ic = s[temp].fuse(ic, h, w)
oic, iic = s[temp].split(ic, factor=rfactor)
s[temp].bind(iic, thread_xx)
s[temp].bind(oic, block_xx)
i, h, w, oic, iic = s[temp_R].op.axis
ic = s[temp_R].fuse(oic, iic)
s[temp_R].bind(ic, thread_xx)
h = s[temp_R].fuse(h, w)
s[temp_R].bind(h, block_xx)
#schedule temp_S shared mem load
i, ic, h, w = s[temp_S].op.axis
ic = s[temp_S].fuse(w, h, ic)
oic, iic = s[temp_S].split(ic, factor=num_thread_x)
i, h, w, oc, ic = s[temp_S].op.axis
icc = s[temp_S].fuse(oc, w, h)
oic, iic = s[temp_S].split(icc, factor=num_thread_x)
_, ioic = s[temp_S].split(oic, factor=num_thread_y)
s[temp_S].bind(iic, thread_x)
s[temp_S].bind(ioic, thread_y)
s[temp_S].vectorize(ic)
#schedule Filter_S shared mem load
i, oc, h, w = s[Filter_S].op.axis
......@@ -363,9 +407,36 @@ def schedule_conv2d_small_batch(outs):
elif block_w % 32 == 0:
block_w = 32
s[temp].compute_inline()
flag = util.get_const_int(Filter.shape[0])+util.get_const_int(Filter.shape[1])
if flag > 768:
temp_G = s.cache_read(temp, "global", [Output])
s[temp_G].compute_inline()
i, ic, h, w = s[temp_G].op.axis
oic, iic = s[temp_G].split(ic, factor=4)
s[temp_G].reorder(i, h, w, oic, iic)
temp_R = s.cache_write(temp_G, "global")
temp_S = s.cache_read(temp_R, "shared", [temp_G])
elif 128 < flag < 512:
temp_G = s.cache_read(temp, "global", [Output])
s[temp_G].compute_inline()
i, ic, h, w = s[temp_G].op.axis
oic, iic = s[temp_G].split(ic, factor=4)
s[temp_G].reorder(i, oic, h, w, iic)
temp_R = s.cache_write(temp_G, "global")
temp_S = s.cache_read(temp_R, "shared", [temp_G])
elif util.get_const_int(Filter.shape[3]) == 7:
temp_G = s.cache_read(temp, "global", [Output])
s[temp_G].compute_inline()
i, ic, h, w = s[temp_G].op.axis
s[temp_G].split(w, factor=4)
temp_R = s.cache_write(temp_G, "global")
temp_S = s.cache_read(temp_R, "shared", [temp_G])
else:
s[temp].compute_inline()
temp_S = s.cache_read(temp, "shared", [Output])
temp_R = temp_S
temp_S = s.cache_read(temp, "shared", [Output])
Filter_S = s.cache_read(Filter, "shared", [Output])
if Output.op in s.outputs:
......@@ -376,14 +447,12 @@ def schedule_conv2d_small_batch(outs):
s[Output].set_scope("local")
Out_L = Output
flag = util.get_const_int(Filter.shape[0])+util.get_const_int(Filter.shape[1])
if util.get_const_int(Filter.shape[3]) == 7:
conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L)
conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L)
elif 128 < flag < 512:
conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag)
conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag)
elif flag >= 512:
conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L)
conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L)
else:
conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment