Commit bec08fec by Cody Hao Yu, committed by Yuwei Hu

[TOPI] Add proper scheduling for dense on CUDA (#3923)

* add proper scheduling for dense on CUDA

* add fallback config and fix unit test

* fix corner cases

* refactoring

* fix bias and add testcase

* let fusion happen
parent 1d00c083
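
A minimal sketch of how the new schedule path gets exercised, modeled on the pattern used in TVM's dense unit tests (the shapes below are illustrative and not taken from this commit):

import tvm
import topi

# data is (batch, in_dim); weight is (out_dim, in_dim)
A = tvm.placeholder((128, 1024), name='A')
B = tvm.placeholder((1000, 1024), name='B')
with tvm.target.cuda():
    C = topi.nn.dense(A, B)               # dispatches to dense_cuda
    s = topi.generic.schedule_dense([C])  # batch >= 32 -> large-batch schedule
f = tvm.build(s, [A, B, C], 'cuda')
# Without a tuning log, the fallback config added in this commit is used.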
@@ -17,8 +17,10 @@
# pylint: disable=invalid-name, unused-variable
"""Schedule for dense operator"""
from __future__ import absolute_import as _abs
import logging
import tvm
import tvm.autotvm as autotvm
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cublas
from .tensor_intrin import dp4a
from ..nn.dense import dense, dense_default
@@ -26,6 +28,8 @@ from .. import tag
from .. import generic
from ..util import traverse_inline, get_const_tuple
logger = logging.getLogger('topi')
@autotvm.register_topi_compute(dense, ["cuda", "gpu"], "direct")
def dense_cuda(cfg, data, weight, bias=None, out_dtype=None):
@@ -85,31 +89,23 @@ def schedule_dense(cfg, outs):
"""
# pylint: disable=unused-argument
target = tvm.target.current_target()
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
if target.target_name == "cuda" and "cublas" in target.libs:
A, B = outs[0].op.input_tensors
b, i = get_const_tuple(A.shape)
o, _ = get_const_tuple(B.shape)
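        # Each of the b x o outputs is a length-i dot product, so the matmul costs
        # roughly 2 * b * i * o FLOPs; recording it lets AutoTVM report GFLOPS even
        # when the work is offloaded to cuBLAS as an extern op.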
        cfg.add_flop(2 * i * b * o)
        return generic.schedule_extern(outs)

    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])
    def _schedule(Dense):
        num_thread = 64
        k = Dense.op.reduce_axis[0]
        ko, kf = s[Dense].split(k, factor=num_thread)
        DenseF = s.rfactor(Dense, kf)

        if Dense.op in s.outputs:
            Out = Dense
        else:
            Out = outs[0].op.output(0)
            s[Dense].compute_at(s[Out], s[Out].op.axis[1])
        s[Out].bind(s[Out].op.axis[0], tvm.thread_axis("blockIdx.y"))
        s[Out].bind(s[Out].op.axis[1], tvm.thread_axis("blockIdx.x"))

        tx = s[Dense].op.reduce_axis[0]
        thread_x = tvm.thread_axis("threadIdx.x")
        s[Dense].bind(tx, thread_x)
        s[DenseF].compute_at(s[Dense], tx)
        s[Dense].set_store_predicate(thread_x.var.equal(0))
        s[Out].set_store_predicate(thread_x.var.equal(0))
    def _schedule(C):
        A, _ = C.op.input_tensors
        batch, _ = get_const_tuple(A.shape)
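        # Dispatch on batch size: fewer than 32 rows uses the cross-thread-reduction
        # schedule below; larger batches use the tiled shared-memory schedule.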
        if batch < 32:
            return schedule_dense_small_batch(cfg, s, C)
        return schedule_dense_large_batch(cfg, s, C)
    scheduled_ops = []
@@ -135,6 +131,130 @@ def schedule_dense(cfg, outs):
    return s

def schedule_dense_small_batch(cfg, s, C):
"""Schedule float32/64 dense with small batch size"""
    A, _ = C.op.input_tensors
    _, in_dim = get_const_tuple(A.shape)
    cfg.define_split('tile_k', in_dim, num_outputs=2)
    if cfg.is_fallback:
        cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])

    _, kf = cfg['tile_k'].apply(s, C, C.op.reduce_axis[0])
    CF = s.rfactor(C, kf)

    if C.op in s.outputs:
        Out = C
    else:
        Out = s.outputs[0].output(0)
        s[C].compute_at(s[Out], s[Out].op.axis[1])
    s[Out].bind(s[Out].op.axis[0], tvm.thread_axis("blockIdx.y"))
    s[Out].bind(s[Out].op.axis[1], tvm.thread_axis("blockIdx.x"))

    tx = s[C].op.reduce_axis[0]
    thread_x = tvm.thread_axis("threadIdx.x")
    s[C].bind(tx, thread_x)
    s[CF].compute_at(s[C], tx)
    s[C].set_store_predicate(thread_x.var.equal(0))
    s[Out].set_store_predicate(thread_x.var.equal(0))

def schedule_dense_large_batch(cfg, s, C):
"""Schedule float32/64 dense with large batch size"""
    A, B = C.op.input_tensors
    batch, in_dim = get_const_tuple(A.shape)
    out_dim, _ = get_const_tuple(B.shape)
    k = C.op.reduce_axis[0]

    # create tuning space
    try:
        block_cand = [64, 128]
        vthread_cand = [2**x for x in range(1, 7)]
        n_thread_cand = [2**x for x in range(3, 7)]
        cfg.define_split('tile_x', batch, num_outputs=4,
                         filter=lambda x: (x.size[1] in vthread_cand and
                                           x.size[2] in n_thread_cand and
                                           (x.size[1] * x.size[2] * x.size[3]) in block_cand))
        cfg.define_split('tile_y', out_dim, num_outputs=4,
                         filter=lambda x: (x.size[1] in vthread_cand and
                                           x.size[2] in n_thread_cand and
                                           (x.size[1] * x.size[2] * x.size[3]) in block_cand))
        cfg.define_split('tile_k', in_dim, num_outputs=3, filter=lambda x: x.size[0] > 2)
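        # The filters keep only split configurations whose vthread counts lie in
        # [2, 64], whose thread counts lie in [8, 64], and whose combined
        # vthread x thread x inner extent per dimension is 64 or 128, pruning the
        # tuning space before search; tile_k additionally keeps the outer reduction
        # loop longer than 2.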
    except IndexError:
        # An IndexError is raised when no entities are left after filtering; the
        # filtering is designed to prune the tuning space for better search
        # efficiency, so fall back to unfiltered splits for these shapes.
        logger.debug(
            'Tuning space was created without pruning due to unfit shapes')
        cfg.define_split('tile_x', batch, num_outputs=4)
        cfg.define_split('tile_y', out_dim, num_outputs=4)
        cfg.define_split('tile_k', in_dim, num_outputs=3)
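    # Fallback knobs for when no tuned entry exists in the AutoTVM log (e.g. in the
    # unit tests), so the schedule is still valid without running a tuning job.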
    if cfg.is_fallback:
        if batch > 1:
            cfg['tile_x'] = SplitEntity([-1, 2, 16, 2])
        else:
            cfg['tile_x'] = SplitEntity([1, 1, 1, 1])
        if out_dim > 1:
            cfg['tile_y'] = SplitEntity([-1, 2, 16, 2])
        else:
            cfg['tile_y'] = SplitEntity([1, 1, 1, 1])
        if in_dim > 8:
            cfg['tile_k'] = SplitEntity([-1, 8, 1])
        else:
            cfg['tile_k'] = SplitEntity([-1, 1, 1])
    # Explicit memory access
    AA = s.cache_read(A, "shared", [C])
    BB = s.cache_read(B, "shared", [C])
    AL = s.cache_read(AA, "local", [C])
    BL = s.cache_read(BB, "local", [C])
    CC = s.cache_write(C, "local")

    # Deal with op fusion
    if C.op not in s.outputs:
        s[C].compute_inline()
        C = s.outputs[0].output(0)
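    # When dense is fused with an elementwise op such as bias add or relu (the
    # "let fusion happen" item in the commit message), the fused output becomes the
    # schedule root and the bare dense stage is inlined into it.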
    # Split and reorder computation
    bx, txz, tx, xi = cfg['tile_x'].apply(s, C, C.op.axis[0])
    by, tyz, ty, yi = cfg['tile_y'].apply(s, C, C.op.axis[1])
    s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)
    s[CC].compute_at(s[C], tx)

    # Binding
    s[C].bind(by, tvm.thread_axis("blockIdx.y"))
    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(tyz, tvm.thread_axis("vthread"))
    s[C].bind(txz, tvm.thread_axis("vthread"))
    s[C].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))

    # Split reduction
    yo, xo = CC.op.axis
    ko, kt, ki = cfg['tile_k'].apply(s, CC, k)
    s[CC].reorder(ko, kt, ki, yo, xo)
    s[AA].compute_at(s[CC], ko)
    s[BB].compute_at(s[CC], ko)
    s[CC].unroll(kt)
    s[AL].compute_at(s[CC], kt)
    s[BL].compute_at(s[CC], kt)
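    # The shared-memory tiles AA/BB are refreshed once per outer reduction step ko,
    # while the register tiles AL/BL are reloaded inside the unrolled kt loop, so
    # the innermost ki/yo/xo loops read only from registers.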
    # Schedule for A's shared memory load
    num_thread_x = cfg['tile_x'].size[2]
    ty, _ = s[AA].split(s[AA].op.axis[0], nparts=num_thread_x)
    _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread_x * 4)
    tx, xi = s[AA].split(xi, nparts=num_thread_x)
    s[AA].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[AA].bind(tx, tvm.thread_axis("threadIdx.x"))
    s[AA].double_buffer()
    # Schedule for B's shared memory load
    num_thread_y = cfg['tile_y'].size[2]
    ty, _ = s[BB].split(s[BB].op.axis[0], nparts=num_thread_y)
    _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread_y * 4)
    tx, xi = s[BB].split(xi, nparts=num_thread_y)
    s[BB].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[BB].bind(tx, tvm.thread_axis("threadIdx.x"))
    s[BB].double_buffer()
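    # In these cooperative loads each thread copies a contiguous 4-element strip of
    # the tile (the factor num_thread * 4 split followed by the nparts split), and
    # double_buffer() lets the next tile's global-to-shared copy overlap with
    # computation on the current one.
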
@autotvm.register_topi_compute(dense, ['cuda'], ['int8'])
def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
"""Dense operator for int8 on CUDA"""
@@ -117,8 +117,9 @@ def verify_dense_int8(batch, in_dim, out_dim, use_bias=True):
def test_dense():
    verify_dense(1, 1024, 1000, use_bias=True)
    verify_dense(1, 1024, 1000, use_bias=False)
    verify_dense(2, 1024, 1000, use_bias=True)
    verify_dense(128, 1024, 1000, use_bias=False)
    verify_dense(128, 1024, 1000, use_bias=True)
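    # Together these cases exercise both paths of the new CUDA schedule: batch=2
    # goes through schedule_dense_small_batch (batch < 32), while batch=128 goes
    # through schedule_dense_large_batch.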
def test_dense_int8():