Commit dabde40f by Siyuan Feng Committed by Leyuan Wang

[Perf] Enhance cudnn and cublas backend and enable TensorCore (#4353)

* add half and mix precision support to cublas backend

* add TensorCore support in CuDNN

* enhance CuDNN support

* address comments and fix lint

* fix

* add fp16 test
parent fbb2a356
......@@ -39,6 +39,14 @@ namespace runtime {
inline bool TypeMatch(TVMType t, int code, int bits, int lanes = 1) {
return t.code == code && t.bits == bits && t.lanes == lanes;
}
/*!
* \brief Check whether two types are equal .
* \param lhs The left operand.
* \param rhs The right operand.
*/
inline bool TypeEqual(TVMType lhs, TVMType rhs) {
return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
}
} // namespace runtime
} // namespace tvm
// Forward declare the intrinsic id we need
......
......@@ -20,7 +20,7 @@ from __future__ import absolute_import as _abs
from .. import api as _api
from .. import intrin as _intrin
def matmul(lhs, rhs, transa=False, transb=False):
def matmul(lhs, rhs, transa=False, transb=False, dtype=None):
"""Create an extern op that compute matrix mult of A and rhs with cuBLAS
Parameters
......@@ -41,13 +41,14 @@ def matmul(lhs, rhs, transa=False, transb=False):
"""
n = lhs.shape[1] if transa else lhs.shape[0]
m = rhs.shape[0] if transb else rhs.shape[1]
dtype = dtype if dtype is not None else lhs.dtype
return _api.extern(
(n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.cublas.matmul",
ins[0], ins[1], outs[0], transa, transb), name="C")
ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C")
def batch_matmul(lhs, rhs, transa=False, transb=False):
def batch_matmul(lhs, rhs, transa=False, transb=False, dtype=None):
"""Create an extern op that compute batch matrix mult of A and rhs with cuBLAS
Parameters
......@@ -69,8 +70,9 @@ def batch_matmul(lhs, rhs, transa=False, transb=False):
b = lhs.shape[0]
n = lhs.shape[2] if transa else lhs.shape[1]
m = rhs.shape[1] if transb else rhs.shape[2]
dtype = dtype if dtype is not None else lhs.dtype
return _api.extern(
(b, n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.cublas.batch_matmul",
ins[0], ins[1], outs[0], transa, transb), name="C")
ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C")
......@@ -22,7 +22,6 @@ from .. import api as _api
from .. import intrin as _intrin
from .. import get_global_func as _get_global_func
# algos can be read from cudnn.h
_FWD_ALGOS = [
"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM",
......@@ -67,6 +66,7 @@ _ALGO_TYPE = [
"bwd_data"
]
def algo_to_index(algo_type, algo_name):
"""Return a index represents the algorithm, which can be used in
calling CuDNN function
......@@ -172,6 +172,7 @@ def conv2d_w_shape(in_channel,
"""
return [out_channel, in_channel, filter_h, filter_w]
def conv2d_output_shape(tensor_format,
pad_h,
pad_w,
......@@ -180,7 +181,9 @@ def conv2d_output_shape(tensor_format,
dilation_h,
dilation_w,
x_shape,
w_shape):
w_shape,
data_dtype,
conv_dtype):
"""Get output shape of 2D convolution
Paramters
......@@ -232,7 +235,9 @@ def conv2d_output_shape(tensor_format,
w_shape[1].value,
w_shape[2].value,
w_shape[3].value,
_get_np_int32_array_handle(oshape))
_get_np_int32_array_handle(oshape),
data_dtype,
conv_dtype)
return list(oshape)
......@@ -245,7 +250,9 @@ def conv2d_find_algo(tensor_format,
dilation_w,
x_shape,
w_shape,
y_shape):
y_shape,
data_dtype,
conv_dtype):
"""Choose the best algo for the given input.
Paramters
......@@ -272,6 +279,10 @@ def conv2d_find_algo(tensor_format,
weight shape
y_shape: list
output shape
data_dtype: str
data type
conv_dtype: str
convolution type
Returns
-------
......@@ -297,7 +308,9 @@ def conv2d_find_algo(tensor_format,
int(y_shape[0]),
int(y_shape[1]),
int(y_shape[2]),
int(y_shape[3]))
int(y_shape[3]),
data_dtype,
conv_dtype)
def conv2d_forward(x,
......@@ -310,7 +323,8 @@ def conv2d_forward(x,
dilation_w=1,
conv_mode=1,
tensor_format=0,
algo=-1):
algo=-1,
conv_dtype=None):
"""Create an extern op that compute 2D convolution with CuDNN
Parameters
......@@ -341,12 +355,16 @@ def conv2d_forward(x,
algo: int
Forward algorithm, get index from ```algo_to_index``` function
if algo == -1, the best algo will be chosen by CUDNN
conv_dtype: str
convolution type
Returns
-------
y: Tensor
The result tensor
"""
conv_dtype = x.dtype if conv_dtype is None else conv_dtype
oshape = conv2d_output_shape(tensor_format,
pad_h,
pad_w,
......@@ -355,18 +373,28 @@ def conv2d_forward(x,
dilation_h,
dilation_w,
list(x.shape),
list(w.shape))
list(w.shape),
x.dtype,
conv_dtype)
if algo == -1:
algo = conv2d_find_algo(tensor_format,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
list(x.shape),
list(w.shape),
oshape)
# For now if we try to call `cudnnFindConvolutionForwardAlgorithm` when
# using INT8 data type, CuDNN will crash down.
# On the other hand, CuDNN only support IMPLICIT_​PRECOMP_GEMM at NHWC format
if tensor_format == 1 and conv_dtype == "int32":
algo = 1
else:
algo = conv2d_find_algo(tensor_format,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
list(x.shape),
list(w.shape),
oshape,
x.dtype,
conv_dtype)
return _api.extern(
oshape, [x, w],
......@@ -383,4 +411,5 @@ def conv2d_forward(x,
dilation_w,
ins[0],
ins[1],
outs[0]), name="y")
outs[0],
conv_dtype), name="y")
......@@ -93,13 +93,13 @@ inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) {
double alpha = args.size() > 5 ? args[5] : 1.0;
double beta = args.size() > 6 ? args[6] : 0.0;
op(transb, transa, ColumnCount(B, transb), RowCount(A, transa),
ColumnCount(A, transa), static_cast<float>(alpha),
ColumnCount(A, transa), static_cast<typename TGemmOp::TDatatype>(alpha),
reinterpret_cast<typename TGemmOp::TDatatype *>(
static_cast<char *>(B->data) + B->byte_offset),
ColumnStride(B),
reinterpret_cast<typename TGemmOp::TDatatype *>(
static_cast<char *>(A->data) + A->byte_offset),
ColumnStride(A), static_cast<float>(beta),
ColumnStride(A), static_cast<typename TGemmOp::TDatatype>(beta),
reinterpret_cast<typename TGemmOp::TDatatype *>(
static_cast<char *>(C->data) + C->byte_offset),
ColumnStride(C));
......@@ -170,9 +170,10 @@ inline void CallBatchGemm(TVMArgs args, TVMRetValue *ret, TBatchGemmOp op) {
DType *C_data = reinterpret_cast<typename TBatchGemmOp::TDatatype *>(
static_cast<char *>(C->data) + C->byte_offset);
op(batch_size, transb, transa, ColumnCount3D(B, transb),
RowCount3D(A, transa), ColumnCount3D(A, transa), static_cast<float>(alpha),
RowCount3D(A, transa), ColumnCount3D(A, transa),
static_cast<typename TBatchGemmOp::TDatatype>(alpha),
B_data, B_size, ColumnStride3D(B), A_data, A_size, ColumnStride3D(A),
static_cast<float>(beta), C_data, C_size, ColumnStride3D(C));
static_cast<typename TBatchGemmOp::TDatatype>(beta), C_data, C_size, ColumnStride3D(C));
}
} // namespace contrib
......
......@@ -25,7 +25,7 @@
#define TVM_RUNTIME_CONTRIB_CUBLAS_CUBLAS_UTILS_H_
#include <dmlc/logging.h>
#include <dlpack/dlpack.h>
#include <cublas_v2.h>
namespace tvm {
......@@ -62,7 +62,27 @@ struct CuBlasThreadEntry {
static CuBlasThreadEntry* ThreadLocal();
}; // CuBlasThreadEntry
inline cudaDataType_t GetCudaDataType(DLDataType type) {
if (type.code == kDLInt) {
switch (type.bits) {
case 8: return CUDA_R_8I;
case 32: return CUDA_R_32I;
}
} else if (type.code == kDLUInt) {
switch (type.bits) {
case 8: return CUDA_R_8U;
case 32: return CUDA_R_32U;
}
} else if (type.code == kDLFloat) {
switch (type.bits) {
case 16: return CUDA_R_16F;
case 32: return CUDA_R_32F;
case 64: return CUDA_R_64F;
}
}
LOG(FATAL) << "Unsupported cuda type";
return CUDA_R_16F;
}
} // namespace contrib
} // namespace tvm
......
......@@ -42,9 +42,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
int stride_w = args[6];
int dilation_h = args[7];
int dilation_w = args[8];
DLTensor *x = args[9];
DLTensor *w = args[10];
DLTensor *y = args[11];
DLTensor* x = args[9];
DLTensor* w = args[10];
DLTensor* y = args[11];
std::string conv_dtype = args[12];
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
// Set Mode
entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
......@@ -55,7 +57,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
// Set Ctx
entry_ptr->conv_entry.ctx = x->ctx;
// Set Data Type
entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(conv_dtype));
cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
// Set Desc
CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc,
pad_h,
......@@ -68,8 +71,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
entry_ptr->conv_entry.data_type));
// Set Filter
CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc,
entry_ptr->conv_entry.data_type,
CUDNN_TENSOR_NCHW,
data_type,
entry_ptr->conv_entry.tensor_format,
static_cast<int>(w->shape[0]),
static_cast<int>(w->shape[1]),
static_cast<int>(w->shape[2]),
......@@ -77,7 +80,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
// Set Input
CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.tensor_format,
entry_ptr->conv_entry.data_type,
data_type,
static_cast<int>(x->shape[0]),
static_cast<int>(x->shape[1]),
static_cast<int>(x->shape[2]),
......@@ -85,11 +88,15 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
// Set Output
CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.output_desc,
entry_ptr->conv_entry.tensor_format,
entry_ptr->conv_entry.data_type,
data_type,
static_cast<int>(y->shape[0]),
static_cast<int>(y->shape[1]),
static_cast<int>(y->shape[2]),
static_cast<int>(y->shape[3])));
if (cudnnGetVersion() > 7000) {
CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH))
}
// Set workspace
size_t workspace_size = 0;
CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(entry_ptr->handle,
......@@ -135,6 +142,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.output_shape")
int w_dim2 = args[13];
int w_dim3 = args[14];
void *out_shape = args[15];
std::string data_dtype = args[16];
std::string conv_dtype = args[17];
// Set Data Type
entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(conv_dtype));
cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(data_dtype));
// Set Format
entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
// conv desc
......@@ -150,15 +162,15 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.output_shape")
// input desc
CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.tensor_format,
CUDNN_DATA_FLOAT,
data_type,
x_dim0,
x_dim1,
x_dim2,
x_dim3));
// filter desc
CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc,
CUDNN_DATA_FLOAT,
CUDNN_TENSOR_NCHW,
data_type,
entry_ptr->conv_entry.tensor_format,
w_dim0,
w_dim1,
w_dim2,
......@@ -196,7 +208,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.find_algo")
int y_dim1 = args[16];
int y_dim2 = args[17];
int y_dim3 = args[18];
std::string data_dtype = args[19];
std::string conv_dtype = args[20];
// Set Data Type
entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(conv_dtype));
cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(data_dtype));
// Set Format
entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
// conv desc
......@@ -212,15 +229,15 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.find_algo")
// input desc
CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.tensor_format,
CUDNN_DATA_FLOAT,
data_type,
x_dim0,
x_dim1,
x_dim2,
x_dim3));
// filter desc
CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc,
CUDNN_DATA_FLOAT,
CUDNN_TENSOR_NCHW,
data_type,
entry_ptr->conv_entry.tensor_format,
w_dim0,
w_dim1,
w_dim2,
......@@ -229,11 +246,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.find_algo")
// output desc
CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.output_desc,
entry_ptr->conv_entry.tensor_format,
entry_ptr->conv_entry.data_type,
data_type,
y_dim0,
y_dim1,
y_dim2,
y_dim3));
if (cudnnGetVersion() > 7000) {
CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH))
}
int returned_algo_count = 0;
cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
......
......@@ -18,13 +18,13 @@ import tvm
import numpy as np
from tvm.contrib import cublas
def test_matmul_add():
def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5):
n = 1024
l = 128
m = 235
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((l, m), name='B')
C = cublas.matmul(A, B)
m = 236
A = tvm.placeholder((n, l), name='A', dtype=in_dtype)
B = tvm.placeholder((l, m), name='B', dtype=in_dtype)
C = cublas.matmul(A, B, dtype=out_dtype)
s = tvm.create_schedule(C.op)
def verify(target="cuda"):
......@@ -36,22 +36,22 @@ def test_matmul_add():
return
ctx = tvm.gpu(0)
f = tvm.build(s, [A, B, C], target)
a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
a = tvm.nd.array(np.random.uniform(0, 128, size=(n, l)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(0, 128, size=(l, m)).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
f(a, b, c)
tvm.testing.assert_allclose(
c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
c.asnumpy(), np.dot(a.asnumpy().astype(C.dtype), b.asnumpy().astype(C.dtype)), rtol=rtol)
verify()
def test_batch_matmul():
def verify_batch_matmul(in_dtype, out_dtype, rtol=1e-5):
j = 16
n = 1024
l = 128
m = 235
A = tvm.placeholder((j, n, l), name='A')
B = tvm.placeholder((j, l, m), name='B')
C = cublas.batch_matmul(A, B)
m = 236
A = tvm.placeholder((j, n, l), name='A', dtype=in_dtype)
B = tvm.placeholder((j, l, m), name='B', dtype=in_dtype)
C = cublas.batch_matmul(A, B, dtype=out_dtype)
s = tvm.create_schedule(C.op)
def verify(target="cuda"):
......@@ -68,10 +68,22 @@ def test_batch_matmul():
c = tvm.nd.array(np.zeros((j, n, m), dtype=C.dtype), ctx)
f(a, b, c)
tvm.testing.assert_allclose(
c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy()), rtol=1e-5)
c.asnumpy(), np.matmul(a.asnumpy().astype(C.dtype),
b.asnumpy().astype(C.dtype)).astype(C.dtype), rtol=rtol)
verify()
def test_matmul_add():
verify_matmul_add('float', 'float')
verify_matmul_add('float16', 'float')
verify_matmul_add('float16', 'float16', rtol=1e-2)
verify_matmul_add('int8', 'int32')
def test_batch_matmul():
verify_batch_matmul('float', 'float')
verify_batch_matmul('float16', 'float')
verify_batch_matmul('float16', 'float16', rtol=1e-2)
if __name__ == "__main__":
test_matmul_add()
test_batch_matmul()
......@@ -19,8 +19,8 @@ from tvm.contrib import cudnn
import numpy as np
def test_conv2d():
in_channel = 3
def verify_conv2d(data_dtype, conv_dtype, tensor_format=0):
in_channel = 4
out_channel = 32
filter_h = 3
filter_w = 3
......@@ -30,21 +30,25 @@ def test_conv2d():
stride_w = 1
dilation_h = 1
dilation_w = 1
batch = 3
height = 32
weight = 32
xshape = [4, 3, 32, 32]
if not tvm.module.enabled("cuda"):
print("skip because cuda is not enabled...")
return
if not tvm.get_global_func("tvm.contrib.cudnn.conv2d.output_shape", True):
print("skip because cudnn is not enabled...")
return
xshape = [batch, in_channel, height, weight]
wshape = cudnn.conv2d_w_shape(in_channel,
out_channel,
filter_h,
filter_w)
out_channel,
filter_h,
filter_w)
X = tvm.placeholder(xshape, name='X')
W = tvm.placeholder(wshape, name='W')
X = tvm.placeholder(xshape, name='X', dtype=data_dtype)
W = tvm.placeholder(wshape, name='W', dtype=data_dtype)
Y = cudnn.conv2d_forward(X,
W,
stride_h,
......@@ -54,24 +58,31 @@ def test_conv2d():
dilation_h,
dilation_w,
conv_mode=1,
tensor_format=0,
algo=1)
tensor_format=tensor_format,
conv_dtype=conv_dtype,
algo=-1)
yshape = [x.value for x in Y.shape]
s = tvm.create_schedule(Y.op)
s = tvm.create_schedule(Y.op)
def verify():
ctx = tvm.gpu(0)
f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv2d")
x = tvm.nd.array(np.random.uniform(-1, 1, xshape).astype(np.float32),
x = tvm.nd.array(np.random.uniform(-1, 1, xshape).astype(data_dtype),
ctx)
w = tvm.nd.array(np.random.uniform(-1, 1, wshape).astype(np.float32),
w = tvm.nd.array(np.random.uniform(-1, 1, wshape).astype(data_dtype),
ctx)
y = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32),
y = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(data_dtype),
ctx)
f(x, w, y)
verify()
def test_conv2d():
verify_conv2d("float32", "float32", tensor_format=0)
verify_conv2d("float16", "float32", tensor_format=1)
verify_conv2d("float16", "float16", tensor_format=0)
verify_conv2d("int8", "int32", tensor_format=1)
if __name__ == "__main__":
test_conv2d()
......@@ -89,6 +89,13 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) *\
((KW - 1) * dilation_w + 1))
if data.dtype == "int8" or kernel.dtype == "int8":
if layout == 'NCHW':
raise ValueError("NCHW layout do not support int8 in cudnn")
dtype = "int32"
else:
dtype = data.dtype
return cudnn.conv2d_forward(data,
kernel,
stride_h,
......@@ -99,7 +106,8 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
dilation_w,
conv_mode=1,
tensor_format=tensor_format,
algo=-1) # let CUDNN choose the best algo
algo=-1, # let CUDNN choose the best algo
conv_dtype=dtype)
if cfg.template_key == 'winograd':
return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
......
......@@ -62,8 +62,7 @@ def dense_cuda(cfg, data, weight, bias=None, out_dtype=None):
out_dim, _ = weight.shape
target = tvm.target.current_target()
if "cublas" in target.libs:
assert out_dtype == data.dtype, "Mixed precision not supported."
matmul = cublas.matmul(data, weight, False, True)
matmul = cublas.matmul(data, weight, False, True, out_dtype)
if bias is not None:
matmul = tvm.compute((batch, out_dim), \
lambda i, j: matmul[i, j] + bias[j], \
......@@ -256,9 +255,19 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
"""Dense operator for int8 on CUDA"""
if out_dtype is None:
out_dtype = data.dtype
batch, in_dim = get_const_tuple(data.shape)
out_dim, _ = get_const_tuple(weight.shape)
target = tvm.target.current_target()
if "cublas" in target.libs:
matmul = cublas.matmul(data, weight, False, True, out_dtype)
if bias is not None:
matmul = tvm.compute((batch, out_dim), \
lambda i, j: matmul[i, j] + bias[j].astype(out_dtype), \
tag=tag.BROADCAST)
return matmul
k = tvm.reduce_axis((0, in_dim), name='k')
matmul = tvm.compute((batch, out_dim),
......@@ -279,7 +288,14 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
@autotvm.register_topi_schedule(generic.schedule_dense, ['cuda', 'gpu'], ['int8'])
def schedule_dense_int8(cfg, outs):
"""Dense schedule for int8 on CUDA"""
s = tvm.create_schedule([x.op for x in outs])
target = tvm.target.current_target()
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
if "cublas" in target.libs:
return generic.schedule_extern(outs)
def _callback(op):
if "dense_int8" in op.tag:
_schedule_dense_int8(cfg, s, op.output(0))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment