Commit edf09673 authored by Wuwei Lin, committed by Tianqi Chen

[TOPI] Add dp4a intrinsic to CUDA (#1707)

parent 49fb6e85
"""Tensor intrinsics on CUDA."""
#pylint: disable=invalid-name
import tvm
def dp4a(x_scope='local', y_scope='local', z_scope='local'):
"""
Int8 dot product reduced by every 4 elements using __dp4a
Parameters
----------
x_scope : str, optional
The storage scope of buffer for lhs
y_scope : str, optional
The storage scope of buffer for rhs
z_scope : str, optional
The storage scope of buffer for result
Returns
-------
intrin : TensorIntrin
The dp4a TensorIntrin that can be used in tensorizing schedule.
"""
n = 4 # dp4a requires operands packed by 4
x = tvm.placeholder((n,), name='x', dtype='int8')
y = tvm.placeholder((n,), name='y', dtype='int8')
k = tvm.reduce_axis((0, n), name='rc')
z = tvm.compute((1,), lambda i: tvm.sum(
x[k].astype('int32') * y[k].astype('int32'), axis=[k]))
def _intrin_func(ins, outs):
def _instr(index):
xx, yy = ins
zz = outs[0]
if index == 1:
return zz.vstore(0, 0)
ib = tvm.ir_builder.create()
vec_x = xx.vload(0, dtype='int8x4')
vec_y = yy.vload(0, dtype='int8x4')
prev_z = 0 if index == 0 else zz.vload(0)
new_z = tvm.call_pure_extern('int32', '__dp4a', vec_x, vec_y, prev_z)
ib.emit(zz.vstore(0, new_z))
return ib.get()
return _instr(0), _instr(1), _instr(2) # body, reset, update
with tvm.build_config(data_alignment=4, offset_factor=1) as cfg:
scopes = {x: x_scope, y: y_scope, z: z_scope}
binds = {t: tvm.decl_buffer(t.shape, t.dtype, t.op.name,
data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor,
scope=scopes[t]) for t in [x, y, z]}
return tvm.decl_tensor_intrin(z.op, _intrin_func, binds=binds)
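For reference, the CUDA __dp4a intrinsic computes a 4-way int8 dot product accumulated into an int32, which is the same reduction expressed by z above. A minimal NumPy sketch of those semantics (illustrative only, not part of this commit):

import numpy as np

def dp4a_reference(x, y, acc):
    # Scalar model of __dp4a(x, y, acc): widen the four int8 lanes to
    # int32, take their dot product, and add it to the accumulator.
    return acc + int(np.dot(x.astype(np.int32), y.astype(np.int32)))

x = np.array([1, -2, 3, 4], dtype=np.int8)
y = np.array([5, 6, -7, 8], dtype=np.int8)
print(dp4a_reference(x, y, 10))  # 10 + (5 - 12 - 21 + 32) = 14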
@@ -4,44 +4,12 @@ import sys
 import numpy as np
 import tvm
 from tvm import autotvm
+from topi.cuda.tensor_intrin import dp4a

 DO_TUNING = True
 PRETUNED_INDEX = 75333

-def intrin_dot():
-    n = 4  # dp4a requires operands packed by 4
-    x = tvm.placeholder((n,), name='x', dtype='int8')
-    y = tvm.placeholder((n,), name='y', dtype='int8')
-    k = tvm.reduce_axis((0, n), name='k')
-
-    z = tvm.compute(
-        (1,), lambda _: tvm.sum(
-            x[k].astype('int32') * y[k].astype('int32'), axis=k))
-
-    def intrin_func(ins, outs):
-        xx, yy = ins
-        zz = outs[0]
-        ib = tvm.ir_builder.create()
-
-        dp4a = zz.vstore(0, tvm.call_pure_extern('int32', '__dp4a',
-                                                 xx.vload(0, dtype='int8x4'),
-                                                 yy.vload(0, dtype='int8x4'),
-                                                 zz.vload(0)))
-        ib.emit(dp4a)
-
-        body = ib.get()
-        return body, zz.vstore(0, 0), body
-
-    with tvm.build_config(data_alignment=4, offset_factor=1) as cfg:
-        binds = {t: tvm.decl_buffer(t.shape, t.dtype, t.op.name,
-                                    data_alignment=cfg.data_alignment,
-                                    offset_factor=cfg.offset_factor,
-                                    scope='local') for t in [x, y, z]}
-        return tvm.decl_tensor_intrin(z.op, intrin_func, binds=binds)
-
-dot = intrin_dot()
+intrin_dp4a = dp4a('local', 'local', 'local')

 @autotvm.template
 def gemm_int8(n, m, l):
@@ -70,7 +38,7 @@ def gemm_int8(n, m, l):
     ko, kt, ki = cfg['tile_k'].apply(s, CC, k)
-    s[CC].tensorize(ki, dot)
+    s[CC].tensorize(ki, intrin_dp4a)

     block_x = tvm.thread_axis('blockIdx.x')
     block_y = tvm.thread_axis('blockIdx.y')
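The axis handed to tensorize must have extent 4 so it matches the dp4a packing; one hedged sketch of how the tile_k split knob could be constrained to guarantee this (the filter below is an assumption for illustration, not taken from this diff):

# Hypothetical sketch: restrict the k-axis split so its innermost
# factor is always 4, matching what the dp4a intrinsic consumes.
cfg.define_split('tile_k', k, num_outputs=3,
                 filter=lambda entity: entity.size[2] == 4)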