Commit afc693dc by Leyuan Wang, committed by Tianqi Chen

conv2d perf improved for conv2d_56_64_128, super resolution workloads added (#643)

* conv2d perf improved for conv2d_56_64_128, test name added to differentiate workloads

* fix lint error
parent a908b831
@@ -95,20 +95,30 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
     thread_yz = tvm.thread_axis((0, vthread_y), "vthread", name="vy")

     i, oc, h, w = s[Out].op.axis
-    ow, iw = s[Out].split(w, factor=num_thread_x)
-    oh, ih = s[Out].split(h, factor=vthread_x)
+    factor = util.get_const_int(Out.shape[3])
     ooc, ioc = s[Out].split(oc, factor=num_thread_y*vthread_y)
     oioc, iioc = s[Out].split(ioc, nparts=vthread_y)
-    s[Out].reorder(i, ooc, oh, ow, oioc, ih, iioc, iw)
-    oh = s[Out].fuse(oh, ow)
-    s[Out].bind(iw, thread_x)
     s[Out].bind(iioc, thread_y)
-    s[Out].bind(ih, thread_xz)
     s[Out].bind(oioc, thread_yz)
-    s[Out].bind(oh, block_x)
     s[Out].bind(ooc, block_y)
-    s[Out_L].compute_at(s[Out], iw)
+    if factor < num_thread_x*vthread_x:
+        oh, ih = s[Out].split(h, factor=num_thread_x*vthread_x//factor)
+        w = s[Out].fuse(ih, w)
+        ow, iw = s[Out].split(w, nparts=vthread_x)
+        s[Out].reorder(i, ooc, oh, oioc, ow, iioc, iw)
+        s[Out].bind(iw, thread_x)
+        s[Out].bind(ow, thread_xz)
+        s[Out].bind(oh, block_x)
+        s[Out_L].compute_at(s[Out], iw)
+    else:
+        ow, iw = s[Out].split(w, factor=num_thread_x)
+        oh, ih = s[Out].split(h, factor=vthread_x)
+        s[Out].reorder(i, ooc, oh, ow, oioc, ih, iioc, iw)
+        oh = s[Out].fuse(oh, ow)
+        s[Out].bind(iw, thread_x)
+        s[Out].bind(ih, thread_xz)
+        s[Out].bind(oh, block_x)
+        s[Out_L].compute_at(s[Out], iw)

     # schedule Out_L local write
     i, oc, h, w = s[Out_L].op.axis
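The point of the branch introduced above: factor is the output width (Out.shape[3]). When it is smaller than the x-direction thread budget num_thread_x*vthread_x, binding w alone would leave threadIdx.x lanes idle, so the schedule first peels rows off h and fuses them into w, then splits the fused axis across vthreads and threads. Below is a minimal, self-contained sketch of that fuse-then-split pattern, written against the pre-0.7 tvm namespace API; the toy op and the values of H, W, num_thread_x, vthread_x are illustrative assumptions, not taken from the TOPI schedule.

import tvm

# Illustrative sizes: output width W is smaller than the thread budget.
H, W = 56, 8
num_thread_x, vthread_x = 16, 2   # assumed budget: 16 threads x 2 vthreads

A = tvm.placeholder((H, W), name="A")
B = tvm.compute((H, W), lambda h, w: A[h, w] * 2.0, name="B")
s = tvm.create_schedule(B.op)

h, w = s[B].op.axis
# W (8) < num_thread_x*vthread_x (32): fold rows of h into w so every
# threadIdx.x lane gets work.
oh, ih = s[B].split(h, factor=num_thread_x * vthread_x // W)  # ih extent = 4
w = s[B].fuse(ih, w)                      # fused extent = 4 * 8 = 32
ow, iw = s[B].split(w, nparts=vthread_x)  # ow extent = 2, iw extent = 16
s[B].bind(iw, tvm.thread_axis((0, num_thread_x), "threadIdx.x"))
s[B].bind(ow, tvm.thread_axis((0, vthread_x), "vthread", name="vx"))
s[B].bind(oh, tvm.thread_axis("blockIdx.x"))
print(tvm.lower(s, [A, B], simple_mode=True))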
@@ -350,14 +360,14 @@ def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L):
     if util.get_const_int(Filter.shape[0]) == 64:
         opart2 = 8
         ifactor = 16
-    sfactor = max(1, ofactor//(opart2*2))
+    sfactor = max(1, ofactor // (opart2*2))
     spart = max(1, (wfactor + vthread-1) // vthread)

     block_x = tvm.thread_axis("blockIdx.x")
     block_y = tvm.thread_axis("blockIdx.y")
     block_z = tvm.thread_axis("blockIdx.z")
     thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
-    thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
+    thread_y = tvm.thread_axis((0, wfactor // vthread), "threadIdx.y")
     thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
     thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
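The one-line thread_y change above shrinks the declared threadIdx.y extent from num_thread to wfactor // vthread, so that it matches the extent of the axis the schedule later binds to it; over-declaring the extent launches threads that never get work. A small sketch of that invariant, assuming illustrative values wfactor=28 and vthread=2 (hypothetical, not the schedule's actual parameters):

import tvm

wfactor, vthread = 28, 2
A = tvm.placeholder((64, wfactor), name="A")
B = tvm.compute((64, wfactor), lambda i, j: A[i, j] + 1.0, name="B")
s = tvm.create_schedule(B.op)

i, j = s[B].op.axis
oj, ij = s[B].split(j, nparts=vthread)   # ij extent = wfactor // vthread = 14
# Declare threadIdx.y with exactly the extent of the axis bound to it.
s[B].bind(ij, tvm.thread_axis((0, wfactor // vthread), "threadIdx.y"))
s[B].bind(oj, tvm.thread_axis((0, vthread), "vthread", name="vy"))
s[B].bind(i, tvm.thread_axis("blockIdx.y"))
print(tvm.lower(s, [A, B], simple_mode=True))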
@@ -6,7 +6,6 @@ import topi
 from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
-

 def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
     in_height = in_width = in_size
@@ -42,10 +41,10 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
     w = tvm.nd.array(w_np, ctx)
     b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
     c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
-    with tvm.build_config(auto_unroll_max_step=128,
+    with tvm.build_config(auto_unroll_max_step=1400,
                           unroll_explicit=(device != "cuda")):
-        func1 = tvm.build(s1, [A, W, B], device)
-        func2 = tvm.build(s2, [A, W, C], device)
+        func1 = tvm.build(s1, [A, W, B], device, name="conv2d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
+        func2 = tvm.build(s2, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
         func1(a, w, b)
         func2(a, w, c)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
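Two harness changes above: auto_unroll_max_step goes from 128 to 1400 (letting the generated conv kernels unroll much larger loop nests), and each built function now gets a per-workload name, which is the "test name added to differentiate workloads" from the commit message; the name becomes the emitted kernel's symbol, so different workloads are distinguishable in profiler traces. A minimal sketch of the same build pattern on a trivial op; the llvm target and the add_%d name are placeholder assumptions, not taken from the test:

import numpy as np
import tvm

n = 1024
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
s = tvm.create_schedule(B.op)

with tvm.build_config(auto_unroll_max_step=1400, unroll_explicit=True):
    # `name` sets the exported symbol, so profiles show add_1024 rather
    # than a generic default name.
    func = tvm.build(s, [A, B], "llvm", name="add_%d" % n)

ctx = tvm.cpu(0)
a = tvm.nd.array(np.arange(n, dtype="float32"), ctx)
b = tvm.nd.array(np.zeros(n, dtype="float32"), ctx)
func(a, b)
np.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1.0)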
@@ -56,6 +55,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):

 def test_conv2d_nchw():
     # ResNet18 workloads
     verify_conv2d_nchw(1, 3, 224, 64, 7, 3, 2)
+    verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1)
     verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0)
@@ -68,7 +68,13 @@ def test_conv2d_nchw():
     verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1)
     verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0)
     verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1)
+    # Vgg16 workloads
+    verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1)
+    # Super resolution workloads
+    verify_conv2d_nchw(1, 1, 224, 64, 5, 1, 2)
+    verify_conv2d_nchw(1, 64, 224, 64, 3, 1, 1)
+    verify_conv2d_nchw(1, 64, 224, 32, 3, 1, 1)
+    verify_conv2d_nchw(1, 32, 224, 9, 3, 1, 1)

 if __name__ == "__main__":
     test_conv2d_nchw()
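For reference, the verify_conv2d_nchw arguments are (batch, in_channel, in_size, num_filter, kernel, stride, padding), so all four super resolution workloads preserve the 224x224 spatial size: out = (in + 2*padding - kernel) // stride + 1 = 224. The 1 -> 64 -> 64 -> 32 -> 9 channel progression ending in 9 output channels is consistent with a 3x pixel-shuffle upscale (9 = 3*3). A quick plain-Python sanity check; conv_out_size is a hypothetical helper, not part of the test file:

def conv_out_size(in_size, kernel, stride, padding):
    # Standard convolution output-size formula.
    return (in_size + 2 * padding - kernel) // stride + 1

for args in [(1, 1, 224, 64, 5, 1, 2),
             (1, 64, 224, 64, 3, 1, 1),
             (1, 64, 224, 32, 3, 1, 1),
             (1, 32, 224, 9, 3, 1, 1)]:
    _, _, in_size, _, kernel, stride, padding = args
    assert conv_out_size(in_size, kernel, stride, padding) == 224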