Commit 8c5078c9 authored by Leyuan Wang, committed by Tianqi Chen

Fixed bugs for conv2d (#1465)

parent b0ef376a
@@ -13,14 +13,16 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
hfactor = 2
if flag >= 96:
hfactor = 4
+ max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
ow_size = util.get_const_int(Out.shape[3])
- num_thread = ow_size * hfactor
+ num_thread = min(max_threads, ow_size * hfactor)
vthread = ofactor
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
i, oc, h, w = s[Out].op.axis
+ if ow_size * hfactor == num_thread:
ooc, ioc = s[Out].split(oc, factor=vthread)
oh, ih = s[Out].split(h, factor=hfactor)
s[Out].reorder(ooc, oh, ioc, ih, w)
@@ -30,6 +32,10 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
s[Out].bind(w, thread_x)
s[Out].bind(ioc, thread_xz)
s[Out].bind(oc, block_x)
+ else:
+ ow, w = s[Out].split(w, factor=num_thread)
+ s[Out].bind(w, thread_x)
+ s[Out].bind(ow, block_x)
s[Out_L].compute_at(s[Out], w)
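
The two hunks above implement one fix: cap the threadIdx.x extent at the device's max_num_threads, and when the natural extent (ow_size * hfactor) does not fit, split the axis and push the remainder onto the grid instead of binding it directly. A minimal sketch of the pattern with illustrative names (bind_capped is not part of the commit); it assumes a current target is set, e.g. inside a with tvm.target.create("cuda") block:

    import tvm

    def bind_capped(s, Out, axis, natural_extent):
        # Never request more threads than the device allows per block.
        max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
        num_thread = min(max_threads, natural_extent)
        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
        block_x = tvm.thread_axis("blockIdx.x")
        if natural_extent == num_thread:
            # The whole axis fits in one thread block: bind it directly.
            s[Out].bind(axis, thread_x)
        else:
            # Otherwise split so the bound extent equals the thread count
            # and spill the outer loop onto the grid.
            outer, inner = s[Out].split(axis, factor=num_thread)
            s[Out].bind(inner, thread_x)
            s[Out].bind(outer, block_x)
        return num_thread
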
@@ -40,7 +46,7 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
s[temp_S].compute_at(s[Out_L], ic)
s[Filter_S].compute_at(s[Out_L], w)
- num_thread1 = tvm.target.current_target(allow_none=False).max_num_threads
+ num_thread1 = max_threads
thread_xx = tvm.thread_axis((0, num_thread1), "threadIdx.x")
block_xx = tvm.thread_axis("blockIdx.x")
@@ -59,6 +65,7 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
h = s[temp_S].fuse(h, ow)
_, tx = s[temp_S].split(h, factor=num_thread)
s[temp_S].bind(tx, thread_x)
+ if num_thread < max_threads:
s[temp_S].vectorize(iw)
#schedule Filter_S shared mem load
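
The one-line guard above keeps the shared-memory load vectorized only when the thread count was not capped by the device limit. A sketch of the load schedule it sits in, assuming a cache_read stage whose innermost axis is contiguous (schedule_shared_load and the axis slicing are illustrative, not the commit's code):

    def schedule_shared_load(s, stage, num_thread, max_threads, thread_x):
        axes = s[stage].op.axis
        # Fuse the outer axes, carve off one thread's worth of iterations,
        # and bind it to threadIdx.x for a cooperative load.
        fused = s[stage].fuse(*axes[:-1])
        _, tx = s[stage].split(fused, factor=num_thread)
        s[stage].bind(tx, thread_x)
        if num_thread < max_threads:
            # Mirrors the commit's guard: only widen each thread's load
            # when the block is not already at max size.
            s[stage].vectorize(axes[-1])
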
@@ -250,12 +257,13 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
"""Schedule conv2d for specific feature_in_out_filter pattern"""
+ max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
if util.get_const_int(Filter.shape[0]) + util.get_const_int(Filter.shape[1]) <= 768:
# scheduler params
vthread_x = util.get_const_int(Out.shape[3])
num_thread_x = 64
ofactor = 8
- if util.get_const_int(Filter.shape[3]) == 1:
+ if util.get_const_int(Filter.shape[3]) == 1 and vthread_x * 5 <= max_threads:
ofactor = 64
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
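
The new condition ties the choice of ofactor to the device budget: the larger factor is taken only for 1x1 kernels and only when vthread_x times five (the multiplier 5 comes from the commit itself) still fits under max_threads. As a standalone sketch (choose_params is illustrative, not in the commit):

    def choose_params(out_width, kernel_w, max_threads):
        vthread_x = out_width
        num_thread_x = 64
        ofactor = 8
        # The aggressive output factor is gated on both a 1x1 kernel and
        # the device having headroom for vthread_x * 5 threads.
        if kernel_w == 1 and vthread_x * 5 <= max_threads:
            ofactor = 64
        return vthread_x, num_thread_x, ofactor
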
@@ -295,9 +303,9 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
else:
# scheduler params
- vthread_x = util.get_const_int(Out.shape[2])
+ vthread_x = min(8, util.get_const_int(Out.shape[2]))
num_thread_x = 16
- num_thread_y = util.get_const_int(Out.shape[3])
+ num_thread_y = min(max_threads // num_thread_x, util.get_const_int(Out.shape[3]))
ofactor = 8
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
@@ -305,11 +313,13 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
thread_xz = tvm.thread_axis((0, vthread_x), "vthread", name="vx")
i, oc, h, w = s[Out].op.axis
+ ow, iw = s[Out].split(w, factor=num_thread_y)
+ oh, ih = s[Out].split(h, factor=vthread_x)
ooc, ioc = s[Out].split(oc, factor=num_thread_x)
- s[Out].reorder(i, ooc, h, w, ioc)
+ s[Out].reorder(i, ooc, oh, ih, ow, iw, ioc)
s[Out].bind(ioc, thread_x)
- s[Out].bind(w, thread_y)
- s[Out].bind(h, thread_xz)
+ s[Out].bind(iw, thread_y)
+ s[Out].bind(ih, thread_xz)
s[Out].bind(ooc, block_x)
s[Out_L].compute_at(s[Out], ioc)
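
The rewritten binding never attaches a raw h or w loop to a thread or vthread axis; every axis is split to a known factor first, so the per-block thread footprint is bounded by construction. The budget arithmetic behind the new factors, as a sketch (thread_budget is an illustrative name):

    def thread_budget(max_threads, num_thread_x, out_h, out_w):
        # A block's thread count is the product of all threadIdx extents,
        # so threadIdx.y gets whatever is left after threadIdx.x is fixed.
        num_thread_y = min(max_threads // num_thread_x, out_w)
        # vthreads are virtual (unrolled by codegen), but capping them at 8,
        # as the commit does, bounds the unrolling.
        vthread_x = min(8, out_h)
        assert num_thread_x * num_thread_y <= max_threads
        return num_thread_y, vthread_x
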
@@ -323,7 +333,7 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], oic)
- num_thread = tvm.target.current_target(allow_none=False).max_num_threads
+ num_thread = max_threads
thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x")
block_xx = tvm.thread_axis("blockIdx.x")
......
- # pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return
+ # pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return, too-many-arguments, too-many-locals, too-many-statements, no-member, too-many-branches
"""conv2d schedule on Intel Graphics"""
from __future__ import absolute_import as _abs
@@ -57,7 +57,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
@conv2d_NCHWc.register(["intel_graphics"])
- def _decl_conv2d(data, kernel, num_filter, kernel_size, stride, padding, out_dtype='float32'):
+ def _decl_conv2d(data, kernel, num_filter, kernel_size, stride, padding, layout,\
+ out_layout, out_dtype='float32'):
"""Conv2D operator for Intel Graphics backend.
Parameters
@@ -96,7 +97,7 @@ def _decl_conv2d(data, kernel, num_filter, kernel_size, stride, padding, out_dty
return _decl_cl_spatialpack_NCHWc(data, kernel, stride, padding, out_dtype)
@generic.schedule_conv2d_NCHWc.register(["intel_graphics"])
- def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, outs):
+ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, layout, out_layout, outs):
"""Schedule for conv2d_nchw for Intel Graphics
Parameters
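
Both signature changes in this file track the generic conv2d_NCHWc / schedule_conv2d_NCHWc interface, which now passes layout and out_layout through to the backend. A toy illustration of why a registered override must match the generic signature (conv2d_demo is made up, not TVM API surface): target-based dispatch forwards all arguments verbatim, so an override with a stale signature raises a TypeError at call time.

    import tvm

    @tvm.target.generic_func
    def conv2d_demo(data, layout, out_layout):
        return ("generic", layout, out_layout)

    @conv2d_demo.register(["cuda"])
    def _conv2d_demo_cuda(data, layout, out_layout):
        return ("cuda", layout, out_layout)

    with tvm.target.create("cuda"):
        # Dispatches to the cuda override; drop an argument from the
        # override's signature and this call fails with a TypeError.
        print(conv2d_demo(None, "NCHWc", "NCHWc"))
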
......
@@ -74,7 +74,7 @@ def test_conv2d_nchw():
verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1)
verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0)
verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1)
- # ResNet 50 workloads
+ # ResNet50 workloads
verify_conv2d_nchw(1, 64, 56, 256, 1, 1, 0)
verify_conv2d_nchw(1, 256, 56, 64, 1, 1, 0)
verify_conv2d_nchw(1, 256, 56, 128, 1, 2, 0)
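
For reference, a helper like verify_conv2d_nchw typically declares the op, runs it on the target device, and checks the result against TOPI's numpy reference. The sketch below assumes the signature (batch, in_channel, in_size, num_filter, kernel, stride, padding) implied by the calls above and a CUDA target; it is not the test file's exact code.

    import numpy as np
    import tvm
    import topi
    from topi.util import get_const_tuple

    def verify_conv2d_nchw(batch, in_channel, in_size, num_filter,
                           kernel, stride, padding):
        A = tvm.placeholder((batch, in_channel, in_size, in_size), name='A')
        W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
        with tvm.target.create("cuda"):
            B = topi.nn.conv2d(A, W, stride, padding, layout='NCHW')
            s = topi.generic.schedule_conv2d_nchw([B])
        a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
        w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
        # Reference result from TOPI's python implementation.
        b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
        ctx = tvm.gpu(0)
        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
        f = tvm.build(s, [A, W, B], "cuda")
        f(a, w, b)
        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
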
......