Commit 94359722 by Leyuan Wang Committed by Tianqi Chen

conv2d adjusted to fix different workloads (#511)

parent 2f4a5ad9
#pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments
#pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches
"""Schedule for conv2d_nchw with auto fusion"""
import tvm
from .. import util
......@@ -81,20 +81,20 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
thread_yz = tvm.thread_axis((0, vthread_y), "vthread", name="vy")
i, oc, h, w = s[Out].op.axis
w = s[Out].fuse(h, w)
ow, iw = s[Out].split(w, factor=num_thread_x*vthread_x)
oh, ih = s[Out].split(h, nparts=vthread_x)
w = s[Out].fuse(ih, w)
ooc, ioc = s[Out].split(oc, factor=num_thread_y*vthread_y)
oiw, iiw = s[Out].split(iw, nparts=vthread_x)
ow, iw = s[Out].split(w, factor=num_thread_x)
oioc, iioc = s[Out].split(ioc, nparts=vthread_y)
s[Out].reorder(i, ooc, ow, oioc, oiw, iioc, iiw)
s[Out].bind(iiw, thread_x)
s[Out].reorder(i, ooc, oh, oioc, ow, iioc, iw)
s[Out].bind(iw, thread_x)
s[Out].bind(iioc, thread_y)
s[Out].bind(oiw, thread_xz)
s[Out].bind(ow, thread_xz)
s[Out].bind(oioc, thread_yz)
s[Out].bind(ow, block_x)
s[Out].bind(oh, block_x)
s[Out].bind(ooc, block_y)
s[Out_L].compute_at(s[Out], iiw)
s[Out_L].compute_at(s[Out], iw)
# schedule Out_L local write
i, oc, h, w = s[Out_L].op.axis
......@@ -216,7 +216,7 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
"""Schedule conv2d for specific feature_in_out_filter pattern"""
if util.get_const_int(Filter.shape[1]) == 256:
if util.get_const_int(Filter.shape[0]) + util.get_const_int(Filter.shape[1]) <= 768:
# scheduler params
vthread_x = util.get_const_int(Out.shape[3])
num_thread_x = 64
......@@ -391,7 +391,7 @@ def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L):
s[Filter_S].bind(ii, thread_y)
def schedule_conv2d_small_batch(outs):
"""Create schedule for tensors or return error if batch size is larager than 1"""
"""Create schedule for tensors or return error if batch size is larger than 1"""
s = tvm.create_schedule([x.op for x in outs])
def schedule(temp, Filter, Output):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment