Commit 769544ad by Leyuan Wang Committed by Tianqi Chen

conv2d schedule fall back warning fixed (#450)

parent 220fa040
...@@ -5,7 +5,7 @@ from .. import util ...@@ -5,7 +5,7 @@ from .. import util
def conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L): def conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L):
"""Schedule conv2d for specific feature_in_out_filter pattern""" """Schedule conv2d for specific feature_in_out_filter pattern"""
# sheduler params # scheduler params
ofactor = 16 ofactor = 16
hfactor = 2 hfactor = 2
ow_size = util.get_const_int(Out.shape[3]) ow_size = util.get_const_int(Out.shape[3])
...@@ -102,7 +102,7 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag): ...@@ -102,7 +102,7 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
s[Filter_S].bind(ioc, thread_y) s[Filter_S].bind(ioc, thread_y)
s[Filter_S].bind(ii, thread_x) s[Filter_S].bind(ii, thread_x)
else: else:
# sheduler params # scheduler params
num_thread = 8 num_thread = 8
vthread = 2 vthread = 2
opart2 = 4 opart2 = 4
...@@ -145,11 +145,9 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag): ...@@ -145,11 +145,9 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
ic, dh, dw = s[Out_L].op.reduce_axis ic, dh, dw = s[Out_L].op.reduce_axis
oic, iic = s[Out_L].split(ic, factor=ifactor) oic, iic = s[Out_L].split(ic, factor=ifactor)
s[Out_L].reorder(oic, dh, dw, iic, h, w) s[Out_L].reorder(oic, dh, dw, iic, h, w)
s[Out_L].fuse(iic, dw) oic = s[Out_L].fuse(dh, oic)
dh = s[Out_L].fuse(dh, oic) s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], oic)
s[temp_S].compute_at(s[Out_L], dh)
s[Filter_S].compute_at(s[Out_L], dh)
#schedule temp_S shared mem load #schedule temp_S shared mem load
i, ic, h, w = s[temp_S].op.axis i, ic, h, w = s[temp_S].op.axis
...@@ -170,7 +168,6 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag): ...@@ -170,7 +168,6 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
ic, dh, dw = s[Out_L].op.reduce_axis ic, dh, dw = s[Out_L].op.reduce_axis
oic, iic = s[Out_L].split(ic, factor=ifactor) oic, iic = s[Out_L].split(ic, factor=ifactor)
s[Out_L].reorder(oic, dh, dw, iic, h, w) s[Out_L].reorder(oic, dh, dw, iic, h, w)
# dh = s[Out_L].fuse(dh, oic)
s[temp_S].compute_at(s[Out_L], oic) s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], dw) s[Filter_S].compute_at(s[Out_L], dw)
...@@ -192,7 +189,7 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag): ...@@ -192,7 +189,7 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
def conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L): def conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L):
"""Schedule conv2d for specific feature_in_out_filter pattern""" """Schedule conv2d for specific feature_in_out_filter pattern"""
if util.get_const_int(Filter.shape[1]) == 256: if util.get_const_int(Filter.shape[1]) == 256:
# sheduler params # scheduler params
vthread_x = util.get_const_int(Out.shape[3]) vthread_x = util.get_const_int(Out.shape[3])
num_thread_x = 64 num_thread_x = 64
ofactor = 8 ofactor = 8
...@@ -235,7 +232,7 @@ def conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L): ...@@ -235,7 +232,7 @@ def conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L):
s[Filter_S].storage_align(s[Filter_S].op.axis[0], 2, 1) s[Filter_S].storage_align(s[Filter_S].op.axis[0], 2, 1)
else: else:
# sheduler params # scheduler params
vthread_x = util.get_const_int(Out.shape[2]) vthread_x = util.get_const_int(Out.shape[2])
num_thread_x = 16 num_thread_x = 16
num_thread_y = util.get_const_int(Out.shape[3]) num_thread_y = util.get_const_int(Out.shape[3])
...@@ -283,7 +280,7 @@ def conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L): ...@@ -283,7 +280,7 @@ def conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L):
def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L): def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L):
"""Schedule conv2d for specific feature_in_out_filter pattern""" """Schedule conv2d for specific feature_in_out_filter pattern"""
# sheduler params # scheduler params
num_thread = 8 num_thread = 8
vthread = 2 vthread = 2
opart2 = 4 opart2 = 4
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment