Commit 7d620be4 by Leyuan Wang Committed by Tianqi Chen

Conv2d scheduler tweaked for super resolution perf (#652)

* scheduler tweaked for super resolution perf

* lint error fixed

* lint error fixed

* conv2d_transpose schedule error fixed
parent a2aa154c
#pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches #pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches, line-too-long
"""Schedule for conv2d_nchw with auto fusion""" """Schedule for conv2d_nchw with auto fusion"""
import tvm import tvm
from .. import util from .. import util
from .. import tag from .. import tag
from .. import generic from .. import generic
def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L): def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
"""Schedule conv2d for specific feature_in_out_filter pattern""" """Schedule conv2d for specific feature_in_out_filter pattern"""
# scheduler params # scheduler params
ofactor = 16 ofactor = 16
hfactor = 2 hfactor = 2
if flag >= 96:
hfactor = 4
ow_size = util.get_const_int(Out.shape[3]) ow_size = util.get_const_int(Out.shape[3])
num_thread = ow_size * hfactor num_thread = ow_size * hfactor
vthread = ofactor vthread = ofactor
...@@ -22,7 +24,8 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L): ...@@ -22,7 +24,8 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
oh, ih = s[Out].split(h, factor=hfactor) oh, ih = s[Out].split(h, factor=hfactor)
s[Out].reorder(ooc, oh, ioc, ih, w) s[Out].reorder(ooc, oh, ioc, ih, w)
oc = s[Out].fuse(ooc, oh) oc = s[Out].fuse(ooc, oh)
w = s[Out].fuse(w, ih) ow, _ = s[Out].split(w, nparts=ow_size)
w = s[Out].fuse(ow, ih)
s[Out].bind(w, thread_x) s[Out].bind(w, thread_x)
s[Out].bind(ioc, thread_xz) s[Out].bind(ioc, thread_xz)
s[Out].bind(oc, block_x) s[Out].bind(oc, block_x)
...@@ -360,7 +363,11 @@ def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L): ...@@ -360,7 +363,11 @@ def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L):
if util.get_const_int(Filter.shape[0]) == 64: if util.get_const_int(Filter.shape[0]) == 64:
opart2 = 8 opart2 = 8
ifactor = 16 ifactor = 16
sfactor = max(1, ofactor // (opart2*2)) if util.get_const_int(Out.shape[2]) == 224:
num_thread = 4
wfactor = 112
ifactor = 4
sfactor = max(1, ofactor // (opart2*vthread))
spart = max(1, (wfactor + vthread-1) // vthread) spart = max(1, (wfactor + vthread-1) // vthread)
block_x = tvm.thread_axis("blockIdx.x") block_x = tvm.thread_axis("blockIdx.x")
...@@ -368,7 +375,7 @@ def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L): ...@@ -368,7 +375,7 @@ def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L):
block_z = tvm.thread_axis("blockIdx.z") block_z = tvm.thread_axis("blockIdx.z")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
thread_y = tvm.thread_axis((0, wfactor // vthread), "threadIdx.y") thread_y = tvm.thread_axis((0, wfactor // vthread), "threadIdx.y")
thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx") thread_xz = tvm.thread_axis((0, opart2), "vthread", name="vx")
thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy") thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
i, oc, h, w = s[Out].op.axis i, oc, h, w = s[Out].op.axis
...@@ -394,10 +401,10 @@ def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L): ...@@ -394,10 +401,10 @@ def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L):
ic, dh, dw = s[Out_L].op.reduce_axis ic, dh, dw = s[Out_L].op.reduce_axis
oic, iic = s[Out_L].split(ic, factor=ifactor) oic, iic = s[Out_L].split(ic, factor=ifactor)
s[Out_L].reorder(oic, dh, dw, iic, h, w) s[Out_L].reorder(oic, dh, dw, iic, h, w)
fuse_index = s[Out_L].fuse(dw, dh) fuse_index = s[Out_L].fuse(dw, dh)
fuse_index = s[Out_L].fuse(fuse_index, oic) fuse_index = s[Out_L].fuse(fuse_index, oic)
dw = fuse_index dw = fuse_index
s[temp_S].compute_at(s[Out_L], dw) s[temp_S].compute_at(s[Out_L], dw)
s[Filter_S].compute_at(s[Out_L], dw) s[Filter_S].compute_at(s[Out_L], dw)
...@@ -421,16 +428,6 @@ def schedule_conv2d_small_batch(outs): ...@@ -421,16 +428,6 @@ def schedule_conv2d_small_batch(outs):
def schedule(temp, Filter, Output): def schedule(temp, Filter, Output):
"""Schedule conv2d_nchw""" """Schedule conv2d_nchw"""
block_h = util.get_const_int(Output.shape[3])
block_w = util.get_const_int(temp.shape[1])
if block_h % 48 == 0:
block_h = 48
elif block_h % 32 == 0:
block_h = 32
if block_w % 48 == 0:
block_w = 48
elif block_w % 32 == 0:
block_w = 32
flag = util.get_const_int(Filter.shape[0])+util.get_const_int(Filter.shape[1]) flag = util.get_const_int(Filter.shape[0])+util.get_const_int(Filter.shape[1])
...@@ -450,7 +447,7 @@ def schedule_conv2d_small_batch(outs): ...@@ -450,7 +447,7 @@ def schedule_conv2d_small_batch(outs):
s[temp_G].reorder(i, oic, h, w, iic) s[temp_G].reorder(i, oic, h, w, iic)
temp_R = s.cache_write(temp_G, "global") temp_R = s.cache_write(temp_G, "global")
temp_S = s.cache_read(temp_R, "shared", [temp_G]) temp_S = s.cache_read(temp_R, "shared", [temp_G])
elif util.get_const_int(Filter.shape[3]) == 7: elif util.get_const_int(Filter.shape[3]) == 7 or (util.get_const_int(Output.shape[2]) == 224 and flag < 128):
temp_G = s.cache_read(temp, "global", [Output]) temp_G = s.cache_read(temp, "global", [Output])
s[temp_G].compute_inline() s[temp_G].compute_inline()
i, ic, h, w = s[temp_G].op.axis i, ic, h, w = s[temp_G].op.axis
...@@ -472,8 +469,8 @@ def schedule_conv2d_small_batch(outs): ...@@ -472,8 +469,8 @@ def schedule_conv2d_small_batch(outs):
s[Output].set_scope("local") s[Output].set_scope("local")
Out_L = Output Out_L = Output
if util.get_const_int(Filter.shape[3]) == 7: if util.get_const_int(Filter.shape[3]) == 7 or (util.get_const_int(Output.shape[2]) == 224 and flag < 128):
conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L) conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag)
elif 128 < flag < 512: elif 128 < flag < 512:
conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag) conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag)
elif flag >= 512: elif flag >= 512:
......
#pylint: disable=invalid-name #pylint: disable=invalid-name, line-too-long
"""Schedule for conv2d_transpose_nchw with auto fusion""" """Schedule for conv2d_transpose_nchw with auto fusion"""
import tvm import tvm
from .. import util from .. import util
...@@ -42,7 +42,7 @@ def schedule_conv2d_transpose_small_batch(outs): ...@@ -42,7 +42,7 @@ def schedule_conv2d_transpose_small_batch(outs):
s[temp_G].reorder(i, oic, h, w, iic) s[temp_G].reorder(i, oic, h, w, iic)
temp_R = s.cache_write(temp_G, "global") temp_R = s.cache_write(temp_G, "global")
temp_S = s.cache_read(temp_R, "shared", [temp_G]) temp_S = s.cache_read(temp_R, "shared", [temp_G])
elif util.get_const_int(Filter.shape[3]) == 7: elif util.get_const_int(Filter.shape[3]) == 7 or (util.get_const_int(Output.shape[2]) == 224 and flag < 128):
temp_G = s.cache_read(temp, "global", [Output]) temp_G = s.cache_read(temp, "global", [Output])
s[temp_G].compute_inline() s[temp_G].compute_inline()
i, ic, h, w = s[temp_G].op.axis i, ic, h, w = s[temp_G].op.axis
...@@ -64,8 +64,8 @@ def schedule_conv2d_transpose_small_batch(outs): ...@@ -64,8 +64,8 @@ def schedule_conv2d_transpose_small_batch(outs):
s[Output].set_scope("local") s[Output].set_scope("local")
Out_L = Output Out_L = Output
if util.get_const_int(Filter.shape[3]) == 7: if util.get_const_int(Filter.shape[3]) == 7 or (util.get_const_int(Output.shape[2]) == 224 and flag < 128):
conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L) conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag)
elif 128 < flag < 512: elif 128 < flag < 512:
conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag) conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag)
elif flag >= 512: elif flag >= 512:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment