Commit 94359722 by Leyuan Wang, committed by Tianqi Chen

conv2d adjusted to fix different workloads (#511)

parent 2f4a5ad9
-#pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments
+#pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches
 """Schedule for conv2d_nchw with auto fusion"""
 import tvm
 from .. import util
@@ -81,20 +81,20 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
     thread_yz = tvm.thread_axis((0, vthread_y), "vthread", name="vy")
     i, oc, h, w = s[Out].op.axis
-    w = s[Out].fuse(h, w)
-    ow, iw = s[Out].split(w, factor=num_thread_x*vthread_x)
+    oh, ih = s[Out].split(h, nparts=vthread_x)
+    w = s[Out].fuse(ih, w)
     ooc, ioc = s[Out].split(oc, factor=num_thread_y*vthread_y)
-    oiw, iiw = s[Out].split(iw, nparts=vthread_x)
+    ow, iw = s[Out].split(w, factor=num_thread_x)
     oioc, iioc = s[Out].split(ioc, nparts=vthread_y)
-    s[Out].reorder(i, ooc, ow, oioc, oiw, iioc, iiw)
-    s[Out].bind(iiw, thread_x)
+    s[Out].reorder(i, ooc, oh, oioc, ow, iioc, iw)
+    s[Out].bind(iw, thread_x)
     s[Out].bind(iioc, thread_y)
-    s[Out].bind(oiw, thread_xz)
+    s[Out].bind(ow, thread_xz)
     s[Out].bind(oioc, thread_yz)
-    s[Out].bind(ow, block_x)
+    s[Out].bind(oh, block_x)
     s[Out].bind(ooc, block_y)
-    s[Out_L].compute_at(s[Out], iiw)
+    s[Out_L].compute_at(s[Out], iw)
     # schedule Out_L local write
     i, oc, h, w = s[Out_L].op.axis
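
The new hunk splits the height axis across virtual threads first, fuses the inner height piece with width, and only then carves out the threadIdx.x dimension, so blockIdx.x now walks height tiles instead of the fused output pixels. Below is a minimal standalone sketch of that binding pattern on a dummy elementwise op, written against the 2017-era TVM API this file uses; the shapes, vthread counts and thread counts are invented so that every bound extent matches its thread axis, and are not the tuned parameters of the actual conv2d_56_64_128 schedule.

# Minimal sketch (not the commit's code) of the new split/fuse/bind pattern.
# Assumed: 2017-era TVM API (tvm.placeholder/create_schedule/thread_axis);
# shapes and thread extents are made-up values chosen to be consistent.
import tvm

N, C, H, W = 1, 128, 56, 56
vthread_x, vthread_y = 7, 2          # virtual thread extents (assumed values)
num_thread_x, num_thread_y = 64, 8   # physical thread extents (assumed values)

A = tvm.placeholder((N, C, H, W), name="A")
Out = tvm.compute((N, C, H, W), lambda n, c, y, x: A[n, c, y, x] * 2.0, name="Out")
s = tvm.create_schedule(Out.op)

block_x = tvm.thread_axis("blockIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
thread_xz = tvm.thread_axis((0, vthread_x), "vthread", name="vx")
thread_yz = tvm.thread_axis((0, vthread_y), "vthread", name="vy")

i, oc, h, w = s[Out].op.axis
# Split height across vthreads first, then fuse the inner height piece with
# width and split that fused axis for threadIdx.x.
oh, ih = s[Out].split(h, nparts=vthread_x)
w = s[Out].fuse(ih, w)
ooc, ioc = s[Out].split(oc, factor=num_thread_y * vthread_y)
ow, iw = s[Out].split(w, factor=num_thread_x)
oioc, iioc = s[Out].split(ioc, nparts=vthread_y)
s[Out].reorder(i, ooc, oh, oioc, ow, iioc, iw)
s[Out].bind(iw, thread_x)      # innermost fused h*w piece -> threadIdx.x
s[Out].bind(iioc, thread_y)    # inner channel piece       -> threadIdx.y
s[Out].bind(ow, thread_xz)     # middle h*w piece          -> vthread vx
s[Out].bind(oioc, thread_yz)   # middle channel piece      -> vthread vy
s[Out].bind(oh, block_x)       # outer height tiles        -> blockIdx.x
s[Out].bind(ooc, block_y)      # outer channel tiles       -> blockIdx.y

print(tvm.lower(s, [A, Out], simple_mode=True))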
@@ -216,7 +216,7 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
 def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
     """Schedule conv2d for specific feature_in_out_filter pattern"""
-    if util.get_const_int(Filter.shape[1]) == 256:
+    if util.get_const_int(Filter.shape[0]) + util.get_const_int(Filter.shape[1]) <= 768:
         # scheduler params
         vthread_x = util.get_const_int(Out.shape[3])
         num_thread_x = 64
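
The dispatch test for this branch is also loosened: instead of requiring exactly 256 input channels, it now accepts any filter whose combined channel count is at most 768. A tiny standalone comparison of the two predicates follows, assuming the usual NCHW filter layout of (out_channel, in_channel, kh, kw), so shape[0] and shape[1] are the two channel counts; the channel pairs are made up for illustration.

# Standalone illustration (not part of the commit) of the old vs. new
# dispatch predicates, assuming Filter.shape = (out_channel, in_channel, kh, kw).
def old_condition(out_channel, in_channel):
    return in_channel == 256

def new_condition(out_channel, in_channel):
    return out_channel + in_channel <= 768

for out_c, in_c in [(256, 256), (512, 256), (256, 512), (512, 512)]:
    print((out_c, in_c), old_condition(out_c, in_c), new_condition(out_c, in_c))
# (256, 256) True  True
# (512, 256) True  True   -- 768 is still within the budget
# (256, 512) False True   -- now also routed to this branch
# (512, 512) False False  -- not handled by this branch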
@@ -391,7 +391,7 @@ def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L):
     s[Filter_S].bind(ii, thread_y)
 def schedule_conv2d_small_batch(outs):
-    """Create schedule for tensors or return error if batch size is larager than 1"""
+    """Create schedule for tensors or return error if batch size is larger than 1"""
     s = tvm.create_schedule([x.op for x in outs])
     def schedule(temp, Filter, Output):