Commit 769544ad by Leyuan Wang Committed by Tianqi Chen

conv2d schedule fall back warning fixed (#450)

parent 220fa040
......@@ -5,7 +5,7 @@ from .. import util
def conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L):
"""Schedule conv2d for specific feature_in_out_filter pattern"""
# sheduler params
# scheduler params
ofactor = 16
hfactor = 2
ow_size = util.get_const_int(Out.shape[3])
......@@ -102,7 +102,7 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
s[Filter_S].bind(ioc, thread_y)
s[Filter_S].bind(ii, thread_x)
else:
# sheduler params
# scheduler params
num_thread = 8
vthread = 2
opart2 = 4
......@@ -145,11 +145,9 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
ic, dh, dw = s[Out_L].op.reduce_axis
oic, iic = s[Out_L].split(ic, factor=ifactor)
s[Out_L].reorder(oic, dh, dw, iic, h, w)
s[Out_L].fuse(iic, dw)
dh = s[Out_L].fuse(dh, oic)
s[temp_S].compute_at(s[Out_L], dh)
s[Filter_S].compute_at(s[Out_L], dh)
oic = s[Out_L].fuse(dh, oic)
s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], oic)
#schedule temp_S shared mem load
i, ic, h, w = s[temp_S].op.axis
......@@ -170,7 +168,6 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
ic, dh, dw = s[Out_L].op.reduce_axis
oic, iic = s[Out_L].split(ic, factor=ifactor)
s[Out_L].reorder(oic, dh, dw, iic, h, w)
# dh = s[Out_L].fuse(dh, oic)
s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], dw)
......@@ -192,7 +189,7 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
def conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L):
"""Schedule conv2d for specific feature_in_out_filter pattern"""
if util.get_const_int(Filter.shape[1]) == 256:
# sheduler params
# scheduler params
vthread_x = util.get_const_int(Out.shape[3])
num_thread_x = 64
ofactor = 8
......@@ -235,7 +232,7 @@ def conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L):
s[Filter_S].storage_align(s[Filter_S].op.axis[0], 2, 1)
else:
# sheduler params
# scheduler params
vthread_x = util.get_const_int(Out.shape[2])
num_thread_x = 16
num_thread_y = util.get_const_int(Out.shape[3])
......@@ -283,7 +280,7 @@ def conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L):
def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L):
"""Schedule conv2d for specific feature_in_out_filter pattern"""
# sheduler params
# scheduler params
num_thread = 8
vthread = 2
opart2 = 4
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment