Commit dd5722d3 by Wuwei Lin Committed by Tianqi Chen

Fix conv2d int8 schedule on CUDA (#2074)

parent 644a15c3
@@ -138,10 +138,6 @@ _dp4a = dp4a('shared', 'shared', 'local')
 def schedule_conv2d_NCHWc_int8(cfg, s, output):
     """Schedule conv2d int8 NCHWc template"""
-    workload = output.op.attrs["workload"]
-    stride = workload[3]
     conv = output.op.input_tensors[0]
     packed_data, packed_kernel = conv.op.input_tensors
@@ -166,11 +162,6 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output):
     if pad_data != packed_data:
         s[pad_data].compute_inline()
-    if isinstance(stride, int):
-        stride_h = stride_w = stride
-    else:
-        stride_h, stride_w = stride
     # create cache stage
     AA = s.cache_read(pad_data, 'shared', [conv])
     WW = s.cache_read(packed_kernel, 'shared', [conv])
@@ -250,18 +241,11 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output):
     # cooperative fetching
     for load in [AA, WW]:
-        if load == AA:
-            n, f, y, x, c = s[load].op.axis
-            if pad_data == packed_data and stride_h == 1 and stride_w == 1:
-                s[load].vectorize(c)
-                fused = s[load].fuse(n, f, y, x)
-            else:
-                c, _ = s[load].split(c, factor=4)
-                fused = s[load].fuse(n, f, y, x, c)
-        else:
-            n, f, y, x, oc_chunk, c = s[load].op.axis
-            fused = s[load].fuse(n, f, y, x, oc_chunk)
+        c = s[load].op.axis[-1]
+        c_outer, c = s[load].split(c, factor=4)
         s[load].vectorize(c)
+        fused = s[load].op.axis[:-1] + [c_outer]
+        fused = s[load].fuse(*fused)
         fused, tx = s[load].split(fused, factor=n_tx)
         fused, ty = s[load].split(fused, factor=n_ty)
......
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment