Commit b25c15de by Leyuan Wang, committed by Tianqi Chen

Intel graphics conv2d schedule fixed for input shapes (300x300) and (512x512) (#1709)

parent 54f5e74c
@@ -49,7 +49,11 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
     stride = ast.literal_eval(attrs['strides'])
     wkl = _get_workload(data, kernel, stride, padding, data.dtype)
-    oc_bn = 16
+    oc_bn = 1
+    kernel_shape = util.get_const_tuple(kernel.shape)
+    for oc_bn in range(16, 1, -1):
+        if kernel_shape[0] % oc_bn == 0:
+            break
     new_attrs = {k: attrs[k] for k in attrs.keys()}
     new_attrs['kernel_layout'] = 'OIHW%do' % (oc_bn)
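The loop above replaces the hard-coded output-channel block of 16: it picks the largest block size up to 16 that evenly divides the number of output filters (kernel_shape[0] in OIHW layout), falling back to 1. A standalone sketch of that selection, with pick_oc_block as a hypothetical helper name:

def pick_oc_block(num_filter, max_block=16):
    # mirror of the loop above: try 16, 15, ..., 2, else fall back to 1
    for oc_bn in range(max_block, 1, -1):
        if num_filter % oc_bn == 0:
            return oc_bn
    return 1

print(pick_oc_block(384))  # 16 -> kernel_layout 'OIHW16o', unchanged behaviour
print(pick_oc_block(84))   # 14 -> kernel_layout 'OIHW14o'; 84 is not divisible by 16

For the 300x300/512x512 workloads this means filter counts such as 84 or 126 get a usable kernel_layout instead of an indivisible block of 16.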
@@ -148,9 +152,6 @@ def _decl_cl_spatialpack_NCHWc(data, kernel, stride, padding, out_dtype='float16'):
     out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1)
     out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
     oshape = (batch, out_channel, out_height, out_width)
-    pad_before = [0, 0, pad_top, pad_left]
-    pad_after = [0, 0, pad_down, pad_right]
-    temp = pad(data, pad_before, pad_after, name="pad_temp")
     rc = tvm.reduce_axis((0, in_channel), name='rc')
     ry = tvm.reduce_axis((0, kernel_h), name='ry')
@@ -190,6 +191,10 @@ def _decl_cl_spatialpack_NCHWc(data, kernel, stride, padding, out_dtype='float16'):
     if not out_width % block_w == 0:
         c_w = (out_width // block_w + 1) * block_w
+    pad_before = [0, 0, pad_top, pad_left]
+    pad_after = [0, 0, pad_down + c_h - block_h, pad_right + c_w - block_w]
+    temp = pad(data, pad_before, pad_after, name="pad_temp")
     cshape = (batch, out_channel // nv, c_h, c_w, nv)
     conv = tvm.compute(
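The padding is now declared after c_h and c_w are known, so the padded input covers the block-aligned compute region rather than only the nominal output. A minimal sketch of the alignment arithmetic; the block sizes here are illustrative, the real ones are chosen per workload earlier in the file:

def round_up(extent, block):
    # same arithmetic as the c_h / c_w computation above
    return extent if extent % block == 0 else (extent // block + 1) * block

out_height, out_width = 19, 19    # e.g. a 3x3, stride-1, pad-1 conv on a 19x19 input
block_h, block_w = 2, 14          # illustrative block sizes
c_h = round_up(out_height, block_h)   # 20
c_w = round_up(out_width, block_w)    # 28
pad_after_h = 1 + c_h - block_h       # pad_down + c_h - block_h = 19 rows of zeros below
pad_after_w = 1 + c_w - block_w       # pad_right + c_w - block_w = 15 columns of zeros to the right
print(c_h, c_w, pad_after_h, pad_after_w)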
@@ -263,20 +268,11 @@ def _schedule_cl_spatialpack_NCHWc(s, op):
     s[conv_L].compute_at(s[conv], vci)
     i, oc, h, w, vc = s[conv_L].op.axis
     rc, ry, rx = s[conv_L].op.reduce_axis
-    if in_channel == 2048:
-        rco, rci = s[conv_L].split(rc, nparts=128)
-        s[conv_L].unroll(rci)
-        s[conv_L].reorder(i, oc, rco, rci, ry, rx, vc, h, w)
-        s[temp_W].compute_at(s[conv_L], rco)
-    else:
-        s[conv_L].reorder(i, oc, rc, ry, rx, vc, h, w)
-        s[temp_W].compute_at(s[conv_L], rc)
-        if kernel.shape[3].value != 7:
-            s[conv_L].unroll(ry)
-            s[conv_L].unroll(rx)
-    if kernel.shape[3].value != 7:
-        s[conv_L].unroll(ry)
-        s[conv_L].unroll(rx)
+    s[conv_L].reorder(i, oc, rc, ry, rx, vc, h, w)
+    s[temp_W].compute_at(s[conv_L], rc)
+    if kernel.shape[3].value != 7:
+        s[conv_L].unroll(ry)
+        s[conv_L].unroll(rx)

     # schedule temp
     _, ci, h, w = s[temp].op.axis
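With the in_channel == 2048 special case removed, the NCHWc schedule always uses the single reorder, compute_at and unroll sequence kept above. A toy TVM snippet (not the intel_graphics schedule itself, written against the 0.4-era API this commit targets) showing those primitives on a simple reduction:

import tvm

A = tvm.placeholder((16, 16), name='A')
k = tvm.reduce_axis((0, 16), name='k')
B = tvm.compute((16,), lambda i: tvm.sum(A[i, k], axis=k), name='B')

s = tvm.create_schedule(B.op)
i, = s[B].op.axis
k, = s[B].op.reduce_axis
s[B].reorder(i, k)   # place the reduction axis explicitly, as in the hunk above
s[B].unroll(k)       # unroll the small reduction loop, like ry/rx for non-7x7 kernels
print(tvm.lower(s, [A, B], simple_mode=True))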
@@ -396,9 +392,6 @@ def _decl_cl_spatialpack(data, kernel, stride, padding, layout, out_dtype='float16'):
     out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1)
     out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
     oshape = (batch, out_channel, out_height, out_width)
-    pad_before = [0, 0, pad_top, pad_left]
-    pad_after = [0, 0, pad_down, pad_right]
-    temp = pad(data, pad_before, pad_after, name="pad_temp")
     rc = tvm.reduce_axis((0, in_channel), name='rc')
     ry = tvm.reduce_axis((0, kernel_h), name='ry')
@@ -432,13 +425,21 @@ def _decl_cl_spatialpack(data, kernel, stride, padding, layout, out_dtype='float16'):
     c_h = out_height
     c_w = out_width
+    if not out_width % block_w == 0:
+        c_w = (out_width // block_w + 1) * block_w
     if not out_height % block_h == 0:
         c_h = (out_height // block_h + 1) * block_h
-    if not out_width % block_w == 0:
-        c_w = (out_width // block_w + 1) * block_w
+    pad_before = [0, 0, pad_top, pad_left]
+    pad_after = [0, 0, pad_down + c_h - block_h, pad_right + c_w - block_w]
+    temp = pad(data, pad_before, pad_after, name="pad_temp")
     nv = 16
+    if not num_filter % nv == 0:
+        num_filter = (num_filter // nv + 1) * nv
+        out_channel = num_filter
     cshape = (batch, out_channel // nv, c_h, c_w, nv)
     kvshape = (num_filter // nv, channel, kernel_h, kernel_w, nv)
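The spatial-pack declaration groups output channels into blocks of nv = 16, so out_channel // nv only divides evenly when the filter count is a multiple of 16. The added lines round num_filter up to the next multiple; the NCHW output shape declared earlier keeps the original count, so as far as this diff shows the extra packed channels exist only in the intermediate tensors. A quick check of the round-up for the filter counts used in the new tests:

nv = 16
for num_filter in (84, 126, 192, 384):
    padded = num_filter if num_filter % nv == 0 else (num_filter // nv + 1) * nv
    print(num_filter, '->', padded, '=', padded // nv, 'blocks of', nv)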
@@ -520,12 +521,6 @@ def _schedule_cl_spatialpack(s, op):
     s[conv_L].compute_at(s[conv], vci)
     i, oc, h, w, vc = s[conv_L].op.axis
     rc, ry, rx = s[conv_L].op.reduce_axis
-    if in_channel == 2048:
-        rco, rci = s[conv_L].split(rc, nparts=128)
-        s[conv_L].unroll(rci)
-        s[conv_L].reorder(i, oc, rco, rci, ry, rx, vc, h, w)
-        s[temp_W].compute_at(s[conv_L], rco)
-    else:
-        s[conv_L].reorder(i, oc, rc, ry, rx, vc, h, w)
-        s[temp_W].compute_at(s[conv_L], rc)
+    s[conv_L].reorder(i, oc, rc, ry, rx, vc, h, w)
+    s[temp_W].compute_at(s[conv_L], rc)
     if kernel.shape[3].value != 7:
@@ -161,6 +161,10 @@ def test_conv2d_nchw():
     verify_conv2d_nchw(1, 2048, 8, 384, 1, 1, 0)
     verify_conv2d_nchw(1, 2048, 8, 448, 1, 1, 0)
     verify_conv2d_nchw(1, 2048, 8, 192, 1, 1, 0)
+    verify_conv2d_nchw(1, 1024, 19, 84, 3, 1, 1)
+    verify_conv2d_nchw(1, 2048, 10, 126, 3, 1, 1)
+    verify_conv2d_nchw(1, 512, 5, 126, 3, 1, 1)
+    verify_conv2d_nchw(1, 256, 3, 126, 3, 1, 1)

 if __name__ == "__main__":
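The four new cases exercise exactly what the changes above fix: small feature maps whose filter counts (84, 126) are not multiples of 16. Assuming the usual verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding) argument order, the expected output sizes follow from the same formula used in the declarations:

def out_size(in_size, kernel, stride, padding):
    # matches simplify((in - kernel + pad_top + pad_down) // stride + 1) with symmetric padding
    return (in_size - kernel + 2 * padding) // stride + 1

for _, _, in_size, num_filter, kernel, stride, padding in [
        (1, 1024, 19, 84, 3, 1, 1),
        (1, 2048, 10, 126, 3, 1, 1),
        (1, 512, 5, 126, 3, 1, 1),
        (1, 256, 3, 126, 3, 1, 1)]:
    print(out_size(in_size, kernel, stride, padding), num_filter, num_filter % 16)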