Commit ab17bf65 by Leyuan Wang, committed by Tianqi Chen

[TOPI] Improve conv2d for resnet18 workload (#427)

* relu activation migrated to topi

* reviews addressed

* relu compute deleted

* conv2d_nchw updated

* resnet18 hand tuned schedule added

* pylint error fixed

* one more workload test for conv2d_nchw

* conv2d schedule subfunctions added for different patterns

* reviews addressed
parent 5ea4072c
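A minimal usage sketch for context (hedged: it assumes the TVM/TOPI API of this era, i.e. tvm.placeholder, topi.nn.conv2d_nchw, and a topi.cuda schedule entry point; the exact exported schedule name is an assumption, the diff below defines schedule_conv2d_small_batch):

    import tvm
    import topi

    # ResNet-18 stem workload (assumed values): 1x3x224x224 input, 64 filters of 7x7, stride 2, padding 3
    A = tvm.placeholder((1, 3, 224, 224), name='A')
    W = tvm.placeholder((64, 3, 7, 7), name='W')
    B = topi.nn.conv2d_nchw(A, W, 2, 3)

    # the hand-tuned sub-schedules in this commit are selected by workload shape
    s = topi.cuda.schedule_conv2d_nchw([B])  # assumed export; this diff defines schedule_conv2d_small_batch
    f = tvm.build(s, [A, W, B], 'cuda')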
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements
#pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments
"""Schedule for conv2d_nchw with auto fusion"""
import tvm
from .. import util
def conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L):
"""Schedule conv2d for specific feature_in_out_filter pattern"""
# scheduler params
ofactor = 16
hfactor = 2
ow_size = util.get_const_int(Out.shape[3])
num_thread = ow_size*hfactor
vthread = hfactor
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
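# Hedged note: for the ResNet-18 stem conv (224x224 input, 7x7 kernel) the output
# width is 112, so num_thread = 112 * hfactor = 224 threads per block, and hfactor
# (= 2) output rows are covered through the vthread axis.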
def schedule_conv2d_small_batch(outs):
"""Create schedule for tensors or return error if batch size is larager than 1"""
s = tvm.create_schedule([x.op for x in outs])
i, oc, h, w = s[Out].op.axis
ooc, ioc = s[Out].split(oc, factor=ofactor)
oh, ih = s[Out].split(h, factor=hfactor)
s[Out].reorder(ooc, oh, ioc, ih, w)
oc = s[Out].fuse(ooc, oh)
w = s[Out].fuse(w, ih)
def schedule(temp, Filter, Output):
"""Schedule conv2d_nchw"""
block_h = util.get_const_int(Output.shape[3])
block_w = util.get_const_int(temp.shape[1])
if block_h % 48 == 0:
block_h = 48
elif block_h % 32 == 0:
block_h = 32
if block_w % 48 == 0:
block_w = 48
elif block_w % 32 == 0:
block_w = 32
s[Out].bind(w, thread_x)
s[Out].bind(ioc, thread_xz)
s[Out].bind(oc, block_x)
s[temp].compute_inline()
s[Out_L].compute_at(s[Out], w)
temp_S = s.cache_read(temp, "shared", [Output])
Filter_S = s.cache_read(Filter, "shared", [Output])
# schedule Out_L local write
i, oc, h, w = s[Out_L].op.axis
ic, dh, dw = s[Out_L].op.reduce_axis
s[Out_L].reorder(i, oc, h, w, ic, dh, dw)
s[temp_S].compute_at(s[Out_L], ic)
s[Filter_S].compute_at(s[Out_L], w)
if Output.op in s.outputs:
Out = Output
Out_L = s.cache_write(Out, "local")
else:
Out = outs[0].op.output(0)
s[Output].set_scope("local")
Out_L = Output
#schedule temp_S shared mem load
i, ic, h, w = s[temp_S].op.axis
tx, ih = s[temp_S].split(w, nparts=num_thread)
s[temp_S].bind(tx, thread_x)
#schedule Filter_S shared mem load
i, oc, h, w = s[Filter_S].op.axis
fuse_index = s[Filter_S].fuse(w, h)
w = s[Filter_S].fuse(fuse_index, oc)
tx, _ = s[Filter_S].split(w, nparts=num_thread)
s[Filter_S].bind(tx, thread_x)
def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
"""Schedule conv2d for specific feature_in_out_filter pattern"""
# scheduler params
num_thread = 8
vthread = 2
opart2 = 4
ofactor = 64
wfactor = 28
ifactor = 8
if flag > 256:
wfactor = 14
sfactor = max(1, ofactor//(opart2*2))
spart = max(1, (wfactor + vthread-1) // vthread)
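# Worked example (hedged): with the defaults ofactor=64, opart2=4, wfactor=28 and
# vthread=2 this gives sfactor = max(1, 64 // 8) = 8 and
# spart = max(1, (28 + 2 - 1) // 2) = 14; when flag > 256, wfactor drops to 14 and
# spart becomes (14 + 1) // 2 = 7.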
block_x = tvm.thread_axis("blockIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
block_z = tvm.thread_axis("blockIdx.z")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
i, oc, h, w = s[Out].op.axis
ooc, ioc = s[Out].split(oc, factor=ofactor)
ow, iw = s[Out].split(w, factor=wfactor)
ow = s[Out].fuse(ow, h)
oioc, iioc = s[Out].split(ioc, nparts=vthread)
oiw, iiw = s[Out].split(iw, nparts=vthread)
oiioc, iiioc = s[Out].split(iioc, nparts=opart2)
s[Out].reorder(i, ooc, ow, oioc, oiw, oiioc, iiw, iiioc)
s[Out].bind(iiioc, thread_x)
s[Out].bind(iiw, thread_y)
s[Out].bind(oiioc, thread_xz)
s[Out].bind(oiw, thread_yz)
s[Out].bind(oioc, block_x)
s[Out].bind(ow, block_y)
s[Out].bind(ooc, block_z)
s[Out_L].compute_at(s[Out], iiioc)
# schedule Out_L local write
i, oc, h, w = s[Out_L].op.axis
ic, dh, dw = s[Out_L].op.reduce_axis
oic, iic = s[Out_L].split(ic, factor=ifactor)
s[Out_L].reorder(oic, dh, dw, iic, h, w)
s[Out_L].fuse(iic, dw)
dh = s[Out_L].fuse(dh, oic)
s[temp_S].compute_at(s[Out_L], dh)
s[Filter_S].compute_at(s[Out_L], dh)
#schedule temp_S shared mem load
i, ic, h, w = s[temp_S].op.axis
_, iic = s[temp_S].split(ic, factor=sfactor)
_, iw = s[temp_S].split(w, factor=spart)
s[temp_S].bind(iic, thread_x)
s[temp_S].bind(iw, thread_y)
#schedule Filter_S shared mem load
i, oc, h, w = s[Filter_S].op.axis
_, ioc = s[Filter_S].split(oc, factor=sfactor)
_, ii = s[Filter_S].split(i, factor=spart)
s[Filter_S].bind(ioc, thread_x)
s[Filter_S].bind(ii, thread_y)
def conv2d_14_256_256(s, temp_S, Filter_S, Out, Out_L):
"""Schedule conv2d for specific feature_in_out_filter pattern"""
# scheduler params
vthread_x = util.get_const_int(Out.shape[3])
num_thread_x = 64
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
thread_xz = tvm.thread_axis((0, vthread_x), "vthread", name="vx")
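# Hedged reading of the binding below: each block handles one output row of a
# 64-output-channel group; the 64 channels map to threadIdx.x and the row's width
# (14 for the 14x14 layers) is spread across virtual threads.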
i, oc, h, w = s[Out].op.axis
ooc, ioc = s[Out].split(oc, factor=num_thread_x)
s[Out].reorder(i, ooc, h, w, ioc)
ooc = s[Out].fuse(h, ooc)
s[Out].bind(ioc, thread_x)
s[Out].bind(w, thread_xz)
s[Out].bind(ooc, block_x)
s[Out_L].compute_at(s[Out], ioc)
# schedule Out_L local write
i, oc, h, w = s[Out_L].op.axis
ic, dh, dw = s[Out_L].op.reduce_axis
oic, iic = s[Out_L].split(ic, factor=8)
s[Out_L].reorder(oic, dh, dw, iic, h, w)
s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], oic)
#schedule temp_S shared mem load
i, ic, h, w = s[temp_S].op.axis
s[temp_S].reorder(i, ic, w, h)
ic = s[temp_S].fuse(w, ic)
_, iic = s[temp_S].split(ic, factor=num_thread_x)
s[temp_S].bind(iic, thread_x)
#schedule Filter_S shared mem load
i, oc, h, w = s[Filter_S].op.axis
_, ii = s[Filter_S].split(i, factor=num_thread_x)
s[Filter_S].bind(ii, thread_x)
def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L):
"""Schedule conv2d for specific feature_in_out_filter pattern"""
# scheduler params
num_thread = 8
vthread = 2
@@ -41,8 +160,12 @@ def schedule_conv2d_small_batch(outs):
ofactor = 64
wfactor = 56
ifactor = 8
if util.get_const_int(Filter.shape[0]) == 64:
opart2 = 8
ifactor = 16
sfactor = max(1, ofactor//(opart2*2))
spart = max(1, (wfactor + vthread-1) // vthread)
block_x = tvm.thread_axis("blockIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
block_z = tvm.thread_axis("blockIdx.z")
@@ -74,10 +197,10 @@ def schedule_conv2d_small_batch(outs):
ic, dh, dw = s[Out_L].op.reduce_axis
oic, iic = s[Out_L].split(ic, factor=ifactor)
s[Out_L].reorder(oic, dh, dw, iic, h, w)
fuse_index = s[Out_L].fuse(dw, dh)
fuse_index = s[Out_L].fuse(fuse_index, oic)
dw = fuse_index
s[temp_S].compute_at(s[Out_L], dw)
s[Filter_S].compute_at(s[Out_L], dw)
@@ -95,6 +218,47 @@ def schedule_conv2d_small_batch(outs):
s[Filter_S].bind(ioc, thread_x)
s[Filter_S].bind(ii, thread_y)
def schedule_conv2d_small_batch(outs):
"""Create schedule for tensors or return error if batch size is larager than 1"""
s = tvm.create_schedule([x.op for x in outs])
def schedule(temp, Filter, Output):
"""Schedule conv2d_nchw"""
block_h = util.get_const_int(Output.shape[3])
block_w = util.get_const_int(temp.shape[1])
if block_h % 48 == 0:
block_h = 48
elif block_h % 32 == 0:
block_h = 32
if block_w % 48 == 0:
block_w = 48
elif block_w % 32 == 0:
block_w = 32
s[temp].compute_inline()
temp_S = s.cache_read(temp, "shared", [Output])
Filter_S = s.cache_read(Filter, "shared", [Output])
if Output.op in s.outputs:
Out = Output
Out_L = s.cache_write(Out, "local")
else:
Out = outs[0].op.output(0)
s[Output].set_scope("local")
Out_L = Output
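# Dispatch heuristic (hedged reading of the branches below): 'flag' is the sum of the
# filter's output- and input-channel extents. A 7x7 kernel (Filter.shape[3] == 7)
# selects the 224x224 3->64 stem schedule; 128 < flag < 512 picks the 56x56 64->128
# pattern; flag >= 512 picks the 14x14 256->256 pattern; everything else
# (e.g. 64 + 64 = 128) falls back to the 56x56 64->64 schedule.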
flag = util.get_const_int(Filter.shape[0])+util.get_const_int(Filter.shape[1])
if util.get_const_int(Filter.shape[3]) == 7:
conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L)
elif 128 < flag < 512:
conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag)
elif flag >= 512:
conv2d_14_256_256(s, temp_S, Filter_S, Out, Out_L)
else:
conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L)
def traverse(OP):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
......
@@ -22,7 +22,6 @@ def schedule_elemwise(outs):
tvm.schedule.AutoInlineInjective(s)
x = outs[0]
num_dim = len(x.shape)
fused = s[x].fuse(*x.op.axis)
num_thread = 64
bx, tx = s[x].split(fused, factor=num_thread)
......
@@ -55,6 +55,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
def test_conv2d_nchw():
verify_conv2d_nchw(1, 3, 224, 64, 7, 3, 2)
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1)
verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0)
verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1)
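# Hedged mapping of these workloads to ResNet-18 layers: the 7x7 3->64 stem on a
# 224x224 input, the 3x3 and (newly added) 1x1 convolutions at 56x56 with 64
# channels, and the 64->128 3x3 convolution at 56x56.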
......