Commit b25c15de by Leyuan Wang Committed by Tianqi Chen

intel graphics conv2d schedule fixed for input shapes (300*300) and (512 * 512) (#1709)

parent 54f5e74c
......@@ -49,7 +49,11 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
stride = ast.literal_eval(attrs['strides'])
wkl = _get_workload(data, kernel, stride, padding, data.dtype)
oc_bn = 16
oc_bn = 1
kernel_shape = util.get_const_tuple(kernel.shape)
for oc_bn in range(16, 1, -1):
if kernel_shape[0] % oc_bn == 0:
break
new_attrs = {k: attrs[k] for k in attrs.keys()}
new_attrs['kernel_layout'] = 'OIHW%do' % (oc_bn)
......@@ -148,9 +152,6 @@ def _decl_cl_spatialpack_NCHWc(data, kernel, stride, padding, out_dtype='float16
out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
oshape = (batch, out_channel, out_height, out_width)
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right]
temp = pad(data, pad_before, pad_after, name="pad_temp")
rc = tvm.reduce_axis((0, in_channel), name='rc')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
......@@ -190,6 +191,10 @@ def _decl_cl_spatialpack_NCHWc(data, kernel, stride, padding, out_dtype='float16
if not out_width % block_w == 0:
c_w = (out_width // block_w + 1) * block_w
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down + c_h - block_h, pad_right + c_w - block_w]
temp = pad(data, pad_before, pad_after, name="pad_temp")
cshape = (batch, out_channel // nv, c_h, c_w, nv)
conv = tvm.compute(
......@@ -263,17 +268,8 @@ def _schedule_cl_spatialpack_NCHWc(s, op):
s[conv_L].compute_at(s[conv], vci)
i, oc, h, w, vc = s[conv_L].op.axis
rc, ry, rx = s[conv_L].op.reduce_axis
if in_channel == 2048:
rco, rci = s[conv_L].split(rc, nparts=128)
s[conv_L].unroll(rci)
s[conv_L].reorder(i, oc, rco, rci, ry, rx, vc, h, w)
s[temp_W].compute_at(s[conv_L], rco)
else:
s[conv_L].reorder(i, oc, rc, ry, rx, vc, h, w)
s[temp_W].compute_at(s[conv_L], rc)
if kernel.shape[3].value != 7:
s[conv_L].unroll(ry)
s[conv_L].unroll(rx)
s[conv_L].reorder(i, oc, rc, ry, rx, vc, h, w)
s[temp_W].compute_at(s[conv_L], rc)
if kernel.shape[3].value != 7:
s[conv_L].unroll(ry)
s[conv_L].unroll(rx)
......@@ -396,9 +392,6 @@ def _decl_cl_spatialpack(data, kernel, stride, padding, layout, out_dtype='float
out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
oshape = (batch, out_channel, out_height, out_width)
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right]
temp = pad(data, pad_before, pad_after, name="pad_temp")
rc = tvm.reduce_axis((0, in_channel), name='rc')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
......@@ -432,13 +425,21 @@ def _decl_cl_spatialpack(data, kernel, stride, padding, layout, out_dtype='float
c_h = out_height
c_w = out_width
if not out_width % block_w == 0:
c_w = (out_width // block_w + 1) * block_w
if not out_height % block_h == 0:
c_h = (out_height // block_h + 1) * block_h
if not out_width % block_w == 0:
c_w = (out_width // block_w + 1) * block_w
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down + c_h - block_h, pad_right + c_w - block_w]
temp = pad(data, pad_before, pad_after, name="pad_temp")
nv = 16
if not num_filter % nv == 0:
num_filter = (num_filter // nv + 1) * nv
out_channel = num_filter
cshape = (batch, out_channel // nv, c_h, c_w, nv)
kvshape = (num_filter // nv, channel, kernel_h, kernel_w, nv)
......@@ -520,14 +521,8 @@ def _schedule_cl_spatialpack(s, op):
s[conv_L].compute_at(s[conv], vci)
i, oc, h, w, vc = s[conv_L].op.axis
rc, ry, rx = s[conv_L].op.reduce_axis
if in_channel == 2048:
rco, rci = s[conv_L].split(rc, nparts=128)
s[conv_L].unroll(rci)
s[conv_L].reorder(i, oc, rco, rci, ry, rx, vc, h, w)
s[temp_W].compute_at(s[conv_L], rco)
else:
s[conv_L].reorder(i, oc, rc, ry, rx, vc, h, w)
s[temp_W].compute_at(s[conv_L], rc)
s[conv_L].reorder(i, oc, rc, ry, rx, vc, h, w)
s[temp_W].compute_at(s[conv_L], rc)
if kernel.shape[3].value != 7:
s[conv_L].unroll(ry)
s[conv_L].unroll(rx)
......
......@@ -161,6 +161,10 @@ def test_conv2d_nchw():
verify_conv2d_nchw(1, 2048, 8, 384, 1, 1, 0)
verify_conv2d_nchw(1, 2048, 8, 448, 1, 1, 0)
verify_conv2d_nchw(1, 2048, 8, 192, 1, 1, 0)
verify_conv2d_nchw(1, 1024, 19, 84, 3, 1, 1)
verify_conv2d_nchw(1, 2048, 10, 126, 3, 1, 1)
verify_conv2d_nchw(1, 512, 5, 126, 3, 1, 1)
verify_conv2d_nchw(1, 256, 3, 126, 3, 1, 1)
if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment