Commit 42dc24a3 by Wuwei Lin Committed by Tianqi Chen

[TOPI][CUDA] batched int8 conv2d (#1961)

parent 62a94c76
......@@ -9,7 +9,7 @@ from .tensor_intrin import dp4a
from ..nn.conv2d import conv2d_NCHWc_int8_prepacked
from ..nn.pad import pad
from ..nn.util import get_pad_tuple
from ..util import get_const_tuple, get_const_int, traverse_inline
from ..util import get_const_tuple, traverse_inline
def _conv2d_NCHWc_int8_arg_to_workload(data, kernel, stride, padding, out_dtype):
......@@ -183,7 +183,7 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
_schedule_injective(packed_data.op, s)
_schedule_injective(packed_kernel.op, s)
else:
kernel = packed_data
kernel = packed_kernel
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
......@@ -191,7 +191,6 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
if pad_data != packed_data:
s[pad_data].compute_inline()
batch = get_const_int(packed_data.shape[0])
if isinstance(stride, int):
stride_h = stride_w = stride
else:
......@@ -210,34 +209,51 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
# tile and bind spatial axes
n, f, y, x, c = s[output].op.axis
cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
# this is the scope to attach global config inside this kernel
kernel_scope, n = s[output].split(n, nparts=1)
bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
# this is the scope to attach global config inside this kernel
kernel_scope, n = s[output].split(n, nparts=1)
max_block_z = 128
if batch > max_block_z:
_, n = s[output].split(n, factor=max_block_z)
s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
fused_byx = s[output].fuse(by, bx)
s[output].bind(n, tvm.thread_axis("blockIdx.z"))
s[output].reorder(bn, bf, by, bx, vn, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
s[output].bind(bn, tvm.thread_axis("blockIdx.z"))
s[output].bind(bf, tvm.thread_axis("blockIdx.y"))
s[output].bind(fused_byx, tvm.thread_axis("blockIdx.x"))
s[output].bind(s[output].fuse(by, bx), tvm.thread_axis("blockIdx.x"))
s[output].bind(vn, tvm.thread_axis("vthread"))
s[output].bind(vf, tvm.thread_axis("vthread"))
s[output].bind(vy, tvm.thread_axis("vthread"))
s[output].bind(vx, tvm.thread_axis("vthread"))
s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
cfg.define_knob("fuse_yx", [0, 1]) # fuse ty,tx or tn,tf
if cfg["fuse_yx"].val:
s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
tyx = s[output].fuse(ty, tx)
s[output].bind(tyx, tvm.thread_axis("threadIdx.x"))
s[conv].compute_at(s[output], tyx)
# number of threads
n_tz = cfg["tile_n"].size[2]
n_ty = cfg["tile_f"].size[2]
n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
else:
s[output].bind(s[output].fuse(tn, tf), tvm.thread_axis("threadIdx.z"))
s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
s[conv].compute_at(s[output], tx)
# number of threads
n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
n_ty = cfg["tile_y"].size[2]
n_tx = cfg["tile_x"].size[2]
# tile and bind reduction axes
n, f, y, x, c = s[conv].op.axis
......@@ -272,9 +288,9 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
fused = s[load].fuse(n, f, y, x, oc_chunk)
s[load].vectorize(c)
fused, tx = s[load].split(fused, factor=cfg["tile_x"].size[2])
fused, ty = s[load].split(fused, factor=cfg["tile_y"].size[2])
fused, tz = s[load].split(fused, factor=cfg["tile_f"].size[2])
fused, tx = s[load].split(fused, factor=n_tx)
fused, ty = s[load].split(fused, factor=n_ty)
fused, tz = s[load].split(fused, factor=n_tz)
s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
......
......@@ -172,6 +172,10 @@ def test_conv2d_nchw():
verify_conv2d_NCHWc_int8(1, 2048, 8, 192, 1, 1, 0)
verify_conv2d_NCHWc_int8(1, 1024, 19, 84, 3, 1, 1)
# batch > 1
verify_conv2d_NCHWc_int8(7, 32, 149, 32, 3, 1, 0)
verify_conv2d_NCHWc_int8(8, 32, 149, 32, 3, 1, 0)
verify_conv2d_NCHWc_int8(32, 32, 149, 32, 3, 1, 0)
if __name__ == "__main__":
test_conv2d_nchw()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment