Commit dabde40f by Siyuan Feng Committed by Leyuan Wang

[Perf] Enhance cudnn and cublas backend and enable TensorCore (#4353)

* add half and mix precision support to cublas backend

* add TensorCore support in CuDNN

* enhance CuDNN support

* address comments and fix lint

* fix

* add fp16 test
parent fbb2a356
...@@ -39,6 +39,14 @@ namespace runtime { ...@@ -39,6 +39,14 @@ namespace runtime {
inline bool TypeMatch(TVMType t, int code, int bits, int lanes = 1) { inline bool TypeMatch(TVMType t, int code, int bits, int lanes = 1) {
return t.code == code && t.bits == bits && t.lanes == lanes; return t.code == code && t.bits == bits && t.lanes == lanes;
} }
/*!
* \brief Check whether two types are equal .
* \param lhs The left operand.
* \param rhs The right operand.
*/
inline bool TypeEqual(TVMType lhs, TVMType rhs) {
return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
}
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
// Forward declare the intrinsic id we need // Forward declare the intrinsic id we need
......
...@@ -20,7 +20,7 @@ from __future__ import absolute_import as _abs ...@@ -20,7 +20,7 @@ from __future__ import absolute_import as _abs
from .. import api as _api from .. import api as _api
from .. import intrin as _intrin from .. import intrin as _intrin
def matmul(lhs, rhs, transa=False, transb=False): def matmul(lhs, rhs, transa=False, transb=False, dtype=None):
"""Create an extern op that compute matrix mult of A and rhs with cuBLAS """Create an extern op that compute matrix mult of A and rhs with cuBLAS
Parameters Parameters
...@@ -41,13 +41,14 @@ def matmul(lhs, rhs, transa=False, transb=False): ...@@ -41,13 +41,14 @@ def matmul(lhs, rhs, transa=False, transb=False):
""" """
n = lhs.shape[1] if transa else lhs.shape[0] n = lhs.shape[1] if transa else lhs.shape[0]
m = rhs.shape[0] if transb else rhs.shape[1] m = rhs.shape[0] if transb else rhs.shape[1]
dtype = dtype if dtype is not None else lhs.dtype
return _api.extern( return _api.extern(
(n, m), [lhs, rhs], (n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed( lambda ins, outs: _intrin.call_packed(
"tvm.contrib.cublas.matmul", "tvm.contrib.cublas.matmul",
ins[0], ins[1], outs[0], transa, transb), name="C") ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C")
def batch_matmul(lhs, rhs, transa=False, transb=False): def batch_matmul(lhs, rhs, transa=False, transb=False, dtype=None):
"""Create an extern op that compute batch matrix mult of A and rhs with cuBLAS """Create an extern op that compute batch matrix mult of A and rhs with cuBLAS
Parameters Parameters
...@@ -69,8 +70,9 @@ def batch_matmul(lhs, rhs, transa=False, transb=False): ...@@ -69,8 +70,9 @@ def batch_matmul(lhs, rhs, transa=False, transb=False):
b = lhs.shape[0] b = lhs.shape[0]
n = lhs.shape[2] if transa else lhs.shape[1] n = lhs.shape[2] if transa else lhs.shape[1]
m = rhs.shape[1] if transb else rhs.shape[2] m = rhs.shape[1] if transb else rhs.shape[2]
dtype = dtype if dtype is not None else lhs.dtype
return _api.extern( return _api.extern(
(b, n, m), [lhs, rhs], (b, n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed( lambda ins, outs: _intrin.call_packed(
"tvm.contrib.cublas.batch_matmul", "tvm.contrib.cublas.batch_matmul",
ins[0], ins[1], outs[0], transa, transb), name="C") ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C")
...@@ -22,7 +22,6 @@ from .. import api as _api ...@@ -22,7 +22,6 @@ from .. import api as _api
from .. import intrin as _intrin from .. import intrin as _intrin
from .. import get_global_func as _get_global_func from .. import get_global_func as _get_global_func
# algos can be read from cudnn.h # algos can be read from cudnn.h
_FWD_ALGOS = [ _FWD_ALGOS = [
"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM", "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM",
...@@ -67,6 +66,7 @@ _ALGO_TYPE = [ ...@@ -67,6 +66,7 @@ _ALGO_TYPE = [
"bwd_data" "bwd_data"
] ]
def algo_to_index(algo_type, algo_name): def algo_to_index(algo_type, algo_name):
"""Return a index represents the algorithm, which can be used in """Return a index represents the algorithm, which can be used in
calling CuDNN function calling CuDNN function
...@@ -172,6 +172,7 @@ def conv2d_w_shape(in_channel, ...@@ -172,6 +172,7 @@ def conv2d_w_shape(in_channel,
""" """
return [out_channel, in_channel, filter_h, filter_w] return [out_channel, in_channel, filter_h, filter_w]
def conv2d_output_shape(tensor_format, def conv2d_output_shape(tensor_format,
pad_h, pad_h,
pad_w, pad_w,
...@@ -180,7 +181,9 @@ def conv2d_output_shape(tensor_format, ...@@ -180,7 +181,9 @@ def conv2d_output_shape(tensor_format,
dilation_h, dilation_h,
dilation_w, dilation_w,
x_shape, x_shape,
w_shape): w_shape,
data_dtype,
conv_dtype):
"""Get output shape of 2D convolution """Get output shape of 2D convolution
Paramters Paramters
...@@ -232,7 +235,9 @@ def conv2d_output_shape(tensor_format, ...@@ -232,7 +235,9 @@ def conv2d_output_shape(tensor_format,
w_shape[1].value, w_shape[1].value,
w_shape[2].value, w_shape[2].value,
w_shape[3].value, w_shape[3].value,
_get_np_int32_array_handle(oshape)) _get_np_int32_array_handle(oshape),
data_dtype,
conv_dtype)
return list(oshape) return list(oshape)
...@@ -245,7 +250,9 @@ def conv2d_find_algo(tensor_format, ...@@ -245,7 +250,9 @@ def conv2d_find_algo(tensor_format,
dilation_w, dilation_w,
x_shape, x_shape,
w_shape, w_shape,
y_shape): y_shape,
data_dtype,
conv_dtype):
"""Choose the best algo for the given input. """Choose the best algo for the given input.
Paramters Paramters
...@@ -272,6 +279,10 @@ def conv2d_find_algo(tensor_format, ...@@ -272,6 +279,10 @@ def conv2d_find_algo(tensor_format,
weight shape weight shape
y_shape: list y_shape: list
output shape output shape
data_dtype: str
data type
conv_dtype: str
convolution type
Returns Returns
------- -------
...@@ -297,7 +308,9 @@ def conv2d_find_algo(tensor_format, ...@@ -297,7 +308,9 @@ def conv2d_find_algo(tensor_format,
int(y_shape[0]), int(y_shape[0]),
int(y_shape[1]), int(y_shape[1]),
int(y_shape[2]), int(y_shape[2]),
int(y_shape[3])) int(y_shape[3]),
data_dtype,
conv_dtype)
def conv2d_forward(x, def conv2d_forward(x,
...@@ -310,7 +323,8 @@ def conv2d_forward(x, ...@@ -310,7 +323,8 @@ def conv2d_forward(x,
dilation_w=1, dilation_w=1,
conv_mode=1, conv_mode=1,
tensor_format=0, tensor_format=0,
algo=-1): algo=-1,
conv_dtype=None):
"""Create an extern op that compute 2D convolution with CuDNN """Create an extern op that compute 2D convolution with CuDNN
Parameters Parameters
...@@ -341,12 +355,16 @@ def conv2d_forward(x, ...@@ -341,12 +355,16 @@ def conv2d_forward(x,
algo: int algo: int
Forward algorithm, get index from ```algo_to_index``` function Forward algorithm, get index from ```algo_to_index``` function
if algo == -1, the best algo will be chosen by CUDNN if algo == -1, the best algo will be chosen by CUDNN
conv_dtype: str
convolution type
Returns Returns
------- -------
y: Tensor y: Tensor
The result tensor The result tensor
""" """
conv_dtype = x.dtype if conv_dtype is None else conv_dtype
oshape = conv2d_output_shape(tensor_format, oshape = conv2d_output_shape(tensor_format,
pad_h, pad_h,
pad_w, pad_w,
...@@ -355,8 +373,16 @@ def conv2d_forward(x, ...@@ -355,8 +373,16 @@ def conv2d_forward(x,
dilation_h, dilation_h,
dilation_w, dilation_w,
list(x.shape), list(x.shape),
list(w.shape)) list(w.shape),
x.dtype,
conv_dtype)
if algo == -1: if algo == -1:
# For now if we try to call `cudnnFindConvolutionForwardAlgorithm` when
# using INT8 data type, CuDNN will crash down.
# On the other hand, CuDNN only support IMPLICIT_​PRECOMP_GEMM at NHWC format
if tensor_format == 1 and conv_dtype == "int32":
algo = 1
else:
algo = conv2d_find_algo(tensor_format, algo = conv2d_find_algo(tensor_format,
pad_h, pad_h,
pad_w, pad_w,
...@@ -366,7 +392,9 @@ def conv2d_forward(x, ...@@ -366,7 +392,9 @@ def conv2d_forward(x,
dilation_w, dilation_w,
list(x.shape), list(x.shape),
list(w.shape), list(w.shape),
oshape) oshape,
x.dtype,
conv_dtype)
return _api.extern( return _api.extern(
oshape, [x, w], oshape, [x, w],
...@@ -383,4 +411,5 @@ def conv2d_forward(x, ...@@ -383,4 +411,5 @@ def conv2d_forward(x,
dilation_w, dilation_w,
ins[0], ins[0],
ins[1], ins[1],
outs[0]), name="y") outs[0],
conv_dtype), name="y")
...@@ -93,13 +93,13 @@ inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) { ...@@ -93,13 +93,13 @@ inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) {
double alpha = args.size() > 5 ? args[5] : 1.0; double alpha = args.size() > 5 ? args[5] : 1.0;
double beta = args.size() > 6 ? args[6] : 0.0; double beta = args.size() > 6 ? args[6] : 0.0;
op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), op(transb, transa, ColumnCount(B, transb), RowCount(A, transa),
ColumnCount(A, transa), static_cast<float>(alpha), ColumnCount(A, transa), static_cast<typename TGemmOp::TDatatype>(alpha),
reinterpret_cast<typename TGemmOp::TDatatype *>( reinterpret_cast<typename TGemmOp::TDatatype *>(
static_cast<char *>(B->data) + B->byte_offset), static_cast<char *>(B->data) + B->byte_offset),
ColumnStride(B), ColumnStride(B),
reinterpret_cast<typename TGemmOp::TDatatype *>( reinterpret_cast<typename TGemmOp::TDatatype *>(
static_cast<char *>(A->data) + A->byte_offset), static_cast<char *>(A->data) + A->byte_offset),
ColumnStride(A), static_cast<float>(beta), ColumnStride(A), static_cast<typename TGemmOp::TDatatype>(beta),
reinterpret_cast<typename TGemmOp::TDatatype *>( reinterpret_cast<typename TGemmOp::TDatatype *>(
static_cast<char *>(C->data) + C->byte_offset), static_cast<char *>(C->data) + C->byte_offset),
ColumnStride(C)); ColumnStride(C));
...@@ -170,9 +170,10 @@ inline void CallBatchGemm(TVMArgs args, TVMRetValue *ret, TBatchGemmOp op) { ...@@ -170,9 +170,10 @@ inline void CallBatchGemm(TVMArgs args, TVMRetValue *ret, TBatchGemmOp op) {
DType *C_data = reinterpret_cast<typename TBatchGemmOp::TDatatype *>( DType *C_data = reinterpret_cast<typename TBatchGemmOp::TDatatype *>(
static_cast<char *>(C->data) + C->byte_offset); static_cast<char *>(C->data) + C->byte_offset);
op(batch_size, transb, transa, ColumnCount3D(B, transb), op(batch_size, transb, transa, ColumnCount3D(B, transb),
RowCount3D(A, transa), ColumnCount3D(A, transa), static_cast<float>(alpha), RowCount3D(A, transa), ColumnCount3D(A, transa),
static_cast<typename TBatchGemmOp::TDatatype>(alpha),
B_data, B_size, ColumnStride3D(B), A_data, A_size, ColumnStride3D(A), B_data, B_size, ColumnStride3D(B), A_data, A_size, ColumnStride3D(A),
static_cast<float>(beta), C_data, C_size, ColumnStride3D(C)); static_cast<typename TBatchGemmOp::TDatatype>(beta), C_data, C_size, ColumnStride3D(C));
} }
} // namespace contrib } // namespace contrib
......
...@@ -36,12 +36,40 @@ inline cublasOperation_t BooleanToTranspose(bool item) { ...@@ -36,12 +36,40 @@ inline cublasOperation_t BooleanToTranspose(bool item) {
return item ? CUBLAS_OP_T : CUBLAS_OP_N; return item ? CUBLAS_OP_T : CUBLAS_OP_N;
} }
inline void TryEnableTensorCore(cublasHandle_t hdl) {
// TensorCores are only supported in cublas 9.0 or higher
int version;
CHECK_CUBLAS_ERROR(cublasGetVersion(hdl, &version));
if (version >= 9000)
CHECK_CUBLAS_ERROR(cublasSetMathMode(hdl, CUBLAS_TENSOR_OP_MATH));
}
struct CublasHgemmOp {
typedef half TDatatype;
cublasHandle_t handle;
explicit CublasHgemmOp(cublasHandle_t hdl)
: handle(hdl) {}
void operator()(bool ta, bool tb,
int M, int N, int K,
half alpha, half* A, int lda,
half* B, int ldb,
half beta, half* C, int ldc) {
CHECK_CUBLAS_ERROR(cublasHgemm(handle,
BooleanToTranspose(ta),
BooleanToTranspose(tb),
M, N, K,
&alpha, A, lda,
B, ldb,
&beta, C, ldc));
}
};
struct CublasSgemmOp { struct CublasSgemmOp {
typedef float TDatatype; typedef float TDatatype;
cublasHandle_t handle; cublasHandle_t handle;
explicit CublasSgemmOp(cublasHandle_t hdl) explicit CublasSgemmOp(cublasHandle_t hdl)
: handle(hdl) : handle(hdl) {}
{}
void operator()(bool ta, bool tb, void operator()(bool ta, bool tb,
int M, int N, int K, int M, int N, int K,
...@@ -58,13 +86,11 @@ struct CublasSgemmOp { ...@@ -58,13 +86,11 @@ struct CublasSgemmOp {
} }
}; };
struct CublasDgemmOp { struct CublasDgemmOp {
typedef double TDatatype; typedef double TDatatype;
cublasHandle_t handle; cublasHandle_t handle;
explicit CublasDgemmOp(cublasHandle_t hdl) explicit CublasDgemmOp(cublasHandle_t hdl)
: handle(hdl) : handle(hdl) {}
{}
void operator()(bool ta, bool tb, void operator()(bool ta, bool tb,
int M, int N, int K, int M, int N, int K,
double alpha, double* A, int lda, double alpha, double* A, int lda,
...@@ -80,12 +106,32 @@ struct CublasDgemmOp { ...@@ -80,12 +106,32 @@ struct CublasDgemmOp {
} }
}; };
struct CublasHgemmBatchOp {
typedef half TDatatype;
cublasHandle_t handle;
explicit CublasHgemmBatchOp(cublasHandle_t hdl)
: handle(hdl) {}
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, half alpha, half* A,
int a_stride, int lda, half* B, int b_stride, int ldb, half beta, half* C,
int c_stride, int ldc) {
CHECK_CUBLAS_ERROR(cublasHgemmStridedBatched(handle,
BooleanToTranspose(ta),
BooleanToTranspose(tb),
M, N, K,
&alpha,
A, lda, a_stride,
B, ldb, b_stride,
&beta,
C, ldc, c_stride,
batch_size));
}
};
struct CublasSgemmBatchOp { struct CublasSgemmBatchOp {
typedef float TDatatype; typedef float TDatatype;
cublasHandle_t handle; cublasHandle_t handle;
explicit CublasSgemmBatchOp(cublasHandle_t hdl) explicit CublasSgemmBatchOp(cublasHandle_t hdl)
: handle(hdl) : handle(hdl) {}
{}
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A, void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C, int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
int c_stride, int ldc) { int c_stride, int ldc) {
...@@ -106,8 +152,7 @@ struct CublasDgemmBatchOp { ...@@ -106,8 +152,7 @@ struct CublasDgemmBatchOp {
typedef double TDatatype; typedef double TDatatype;
cublasHandle_t handle; cublasHandle_t handle;
explicit CublasDgemmBatchOp(cublasHandle_t hdl) explicit CublasDgemmBatchOp(cublasHandle_t hdl)
: handle(hdl) : handle(hdl) {}
{}
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A, void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C, int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
int c_stride, int ldc) { int c_stride, int ldc) {
...@@ -124,35 +169,203 @@ struct CublasDgemmBatchOp { ...@@ -124,35 +169,203 @@ struct CublasDgemmBatchOp {
} }
}; };
// Check cublas supported mix-precision computation type and return computeType
bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_support = true) {
if (int_support && TypeMatch(out_dtype, kDLInt, 32)) {
return TypeMatch(in_dtype, kDLInt, 8);
} else if (TypeMatch(out_dtype, kDLFloat, 32)) {
return TypeMatch(in_dtype, kDLInt, 8) ||
TypeMatch(in_dtype, kDLFloat, 16);
} else {
return false;
}
}
inline void CallGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) {
DLTensor *A = args[0];
DLTensor *B = args[1];
DLTensor *C = args[2];
bool transa = args[3];
bool transb = args[4];
CHECK_EQ(A->ndim, 2);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 2);
CHECK_EQ(ElementStride(A), 1);
CHECK_EQ(ElementStride(B), 1);
CHECK_EQ(ElementStride(C), 1);
CHECK(TypeEqual(A->dtype, B->dtype));
// C can never be transposed.
CHECK(!IsInPlaceTransposed(C));
// Reversed strides indicates an in-place transpose operation.
transa = IsInPlaceTransposed(A) ? !transa : transa;
transb = IsInPlaceTransposed(B) ? !transb : transb;
CHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type";
CHECK(!TypeMatch(A->dtype, kDLInt, 8) ||
ColumnStride(A) % 4 == 0) << "leading dimension must divide 4 for int8 gemm";
CHECK(!TypeMatch(B->dtype, kDLInt, 8) ||
ColumnStride(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm";
double alpha = args.size() > 5 ? args[5] : 1.0;
double beta = args.size() > 6 ? args[6] : 0.0;
cudaDataType_t cuda_in_type = GetCudaDataType(A->dtype);
cudaDataType_t cuda_out_type = GetCudaDataType(C->dtype);
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
void *alpha_ptr = nullptr, *beta_ptr = nullptr;
auto alpha_int = static_cast<int32_t>(alpha);
auto beta_int = static_cast<int32_t>(beta);
auto alpha_float = static_cast<float>(alpha);
auto beta_float = static_cast<float>(beta);
if (C->dtype.code == kDLInt) {
alpha_ptr = &alpha_int;
beta_ptr = &beta_int;
} else if (C->dtype.code == kDLFloat) {
alpha_ptr = &alpha_float;
beta_ptr = &beta_float;
}
auto A_data = reinterpret_cast<void *>(static_cast<char *>(A->data) + A->byte_offset);
auto B_data = reinterpret_cast<void *>(static_cast<char *>(B->data) + B->byte_offset);
auto C_data = reinterpret_cast<void *>(static_cast<char *>(C->data) + C->byte_offset);
CHECK_CUBLAS_ERROR(cublasGemmEx(hdl,
BooleanToTranspose(transb),
BooleanToTranspose(transa),
ColumnCount(B, transb),
RowCount(A, transa),
ColumnCount(A, transa),
alpha_ptr,
B_data, cuda_in_type, ColumnStride(B),
A_data, cuda_in_type, ColumnStride(A),
beta_ptr,
C_data, cuda_out_type, ColumnStride(C),
cuda_out_type, algo));
}
inline void CallBatchGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) {
DLTensor *A = args[0];
DLTensor *B = args[1];
DLTensor *C = args[2];
bool transa = args[3];
bool transb = args[4];
CHECK_EQ(A->ndim, 3);
CHECK_EQ(B->ndim, 3);
CHECK_EQ(C->ndim, 3);
int batch_size = BatchCount3D(A);
CHECK_EQ(BatchCount3D(B), batch_size);
CHECK_EQ(BatchCount3D(C), batch_size);
CHECK_EQ(ElementStride(A), 1);
CHECK_EQ(ElementStride(B), 1);
CHECK_EQ(ElementStride(C), 1);
CHECK(TypeEqual(A->dtype, B->dtype));
// C can never be transposed.
CHECK(!IsInPlaceTransposed(C));
// Reversed strides indicates an in-place transpose operation.
transa = IsInPlaceTransposed(A) ? !transa : transa;
transb = IsInPlaceTransposed(B) ? !transb : transb;
CHECK(CheckMixPrecisionType(A->dtype, C->dtype, false)) << "Unsupported data type";
CHECK(!TypeMatch(A->dtype, kDLInt, 8) ||
ColumnStride(A) % 4 == 0) << "leading dimension must divide 4 for int8 gemm";
CHECK(!TypeMatch(B->dtype, kDLInt, 8) ||
ColumnStride(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm";
double alpha = args.size() > 5 ? args[5] : 1.0;
double beta = args.size() > 6 ? args[6] : 0.0;
const int A_size = A->shape[1] * A->shape[2];
const int B_size = B->shape[1] * B->shape[2];
const int C_size = C->shape[1] * C->shape[2];
cudaDataType_t cuda_in_type = GetCudaDataType(A->dtype);
cudaDataType_t cuda_out_type = GetCudaDataType(C->dtype);
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
void *alpha_ptr = nullptr, *beta_ptr = nullptr;
auto alpha_int = static_cast<int32_t>(alpha);
auto beta_int = static_cast<int32_t>(beta);
auto alpha_float = static_cast<float>(alpha);
auto beta_float = static_cast<float>(beta);
if (C->dtype.code == kDLInt) {
alpha_ptr = &alpha_int;
beta_ptr = &beta_int;
} else if (C->dtype.code == kDLFloat) {
alpha_ptr = &alpha_float;
beta_ptr = &beta_float;
}
auto A_data = reinterpret_cast<void *>(static_cast<char *>(A->data) + A->byte_offset);
auto B_data = reinterpret_cast<void *>(static_cast<char *>(B->data) + B->byte_offset);
auto C_data = reinterpret_cast<void *>(static_cast<char *>(C->data) + C->byte_offset);
CHECK_CUBLAS_ERROR(cublasGemmStridedBatchedEx(hdl,
BooleanToTranspose(transb),
BooleanToTranspose(transa),
ColumnCount3D(B, transb),
RowCount3D(A, transa),
ColumnCount3D(A, transa),
alpha_ptr,
B_data, cuda_in_type, ColumnStride3D(B), B_size,
A_data, cuda_in_type, ColumnStride3D(A), A_size,
beta_ptr,
C_data, cuda_out_type, ColumnStride3D(C), C_size,
batch_size, cuda_out_type, algo));
}
// matrix multiplication for row major // matrix multiplication for row major
TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul") TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor* A = args[0]; DLTensor* A = args[0];
DLTensor* C = args[2];
CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
TypeMatch(A->dtype, kDLFloat, 64));
CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
if (TypeMatch(A->dtype, kDLFloat, 32)) TryEnableTensorCore(entry_ptr->handle);
if (TypeEqual(A->dtype, C->dtype)) {
CHECK(TypeMatch(A->dtype, kDLFloat, 16) ||
TypeMatch(A->dtype, kDLFloat, 32) ||
TypeMatch(A->dtype, kDLFloat, 64));
if (TypeMatch(A->dtype, kDLFloat, 16))
CallGemm(args, ret, CublasHgemmOp(entry_ptr->handle));
else if (TypeMatch(A->dtype, kDLFloat, 32))
CallGemm(args, ret, CublasSgemmOp(entry_ptr->handle)); CallGemm(args, ret, CublasSgemmOp(entry_ptr->handle));
else else
CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle)); CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle));
} else {
CallGemmEx(args, ret, entry_ptr->handle);
}
}); });
TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul") TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* A = args[0]; DLTensor* A = args[0];
DLTensor* C = args[2];
CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
TypeMatch(A->dtype, kDLFloat, 64));
CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
if (TypeMatch(A->dtype, kDLFloat, 32)) TryEnableTensorCore(entry_ptr->handle);
if (TypeEqual(A->dtype, C->dtype)) {
CHECK(TypeMatch(A->dtype, kDLFloat, 16) ||
TypeMatch(A->dtype, kDLFloat, 32) ||
TypeMatch(A->dtype, kDLFloat, 64));
if (TypeMatch(A->dtype, kDLFloat, 16))
CallBatchGemm(args, ret, CublasHgemmBatchOp(entry_ptr->handle));
else if (TypeMatch(A->dtype, kDLFloat, 32))
CallBatchGemm(args, ret, CublasSgemmBatchOp(entry_ptr->handle)); CallBatchGemm(args, ret, CublasSgemmBatchOp(entry_ptr->handle));
else else
CallBatchGemm(args, ret, CublasDgemmBatchOp(entry_ptr->handle)); CallBatchGemm(args, ret, CublasDgemmBatchOp(entry_ptr->handle));
} else {
CallBatchGemmEx(args, ret, entry_ptr->handle);
}
}); });
} // namespace contrib } // namespace contrib
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define TVM_RUNTIME_CONTRIB_CUBLAS_CUBLAS_UTILS_H_ #define TVM_RUNTIME_CONTRIB_CUBLAS_CUBLAS_UTILS_H_
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dlpack/dlpack.h>
#include <cublas_v2.h> #include <cublas_v2.h>
namespace tvm { namespace tvm {
...@@ -62,7 +62,27 @@ struct CuBlasThreadEntry { ...@@ -62,7 +62,27 @@ struct CuBlasThreadEntry {
static CuBlasThreadEntry* ThreadLocal(); static CuBlasThreadEntry* ThreadLocal();
}; // CuBlasThreadEntry }; // CuBlasThreadEntry
inline cudaDataType_t GetCudaDataType(DLDataType type) {
if (type.code == kDLInt) {
switch (type.bits) {
case 8: return CUDA_R_8I;
case 32: return CUDA_R_32I;
}
} else if (type.code == kDLUInt) {
switch (type.bits) {
case 8: return CUDA_R_8U;
case 32: return CUDA_R_32U;
}
} else if (type.code == kDLFloat) {
switch (type.bits) {
case 16: return CUDA_R_16F;
case 32: return CUDA_R_32F;
case 64: return CUDA_R_64F;
}
}
LOG(FATAL) << "Unsupported cuda type";
return CUDA_R_16F;
}
} // namespace contrib } // namespace contrib
} // namespace tvm } // namespace tvm
......
...@@ -42,9 +42,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") ...@@ -42,9 +42,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
int stride_w = args[6]; int stride_w = args[6];
int dilation_h = args[7]; int dilation_h = args[7];
int dilation_w = args[8]; int dilation_w = args[8];
DLTensor *x = args[9]; DLTensor* x = args[9];
DLTensor *w = args[10]; DLTensor* w = args[10];
DLTensor *y = args[11]; DLTensor* y = args[11];
std::string conv_dtype = args[12];
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
// Set Mode // Set Mode
entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode); entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
...@@ -55,7 +57,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") ...@@ -55,7 +57,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
// Set Ctx // Set Ctx
entry_ptr->conv_entry.ctx = x->ctx; entry_ptr->conv_entry.ctx = x->ctx;
// Set Data Type // Set Data Type
entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype); entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(conv_dtype));
cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
// Set Desc // Set Desc
CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc, CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc,
pad_h, pad_h,
...@@ -68,8 +71,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") ...@@ -68,8 +71,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
entry_ptr->conv_entry.data_type)); entry_ptr->conv_entry.data_type));
// Set Filter // Set Filter
CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc,
entry_ptr->conv_entry.data_type, data_type,
CUDNN_TENSOR_NCHW, entry_ptr->conv_entry.tensor_format,
static_cast<int>(w->shape[0]), static_cast<int>(w->shape[0]),
static_cast<int>(w->shape[1]), static_cast<int>(w->shape[1]),
static_cast<int>(w->shape[2]), static_cast<int>(w->shape[2]),
...@@ -77,7 +80,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") ...@@ -77,7 +80,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
// Set Input // Set Input
CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc, CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.tensor_format, entry_ptr->conv_entry.tensor_format,
entry_ptr->conv_entry.data_type, data_type,
static_cast<int>(x->shape[0]), static_cast<int>(x->shape[0]),
static_cast<int>(x->shape[1]), static_cast<int>(x->shape[1]),
static_cast<int>(x->shape[2]), static_cast<int>(x->shape[2]),
...@@ -85,11 +88,15 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") ...@@ -85,11 +88,15 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
// Set Output // Set Output
CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.output_desc, CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.output_desc,
entry_ptr->conv_entry.tensor_format, entry_ptr->conv_entry.tensor_format,
entry_ptr->conv_entry.data_type, data_type,
static_cast<int>(y->shape[0]), static_cast<int>(y->shape[0]),
static_cast<int>(y->shape[1]), static_cast<int>(y->shape[1]),
static_cast<int>(y->shape[2]), static_cast<int>(y->shape[2]),
static_cast<int>(y->shape[3]))); static_cast<int>(y->shape[3])));
if (cudnnGetVersion() > 7000) {
CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH))
}
// Set workspace // Set workspace
size_t workspace_size = 0; size_t workspace_size = 0;
CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(entry_ptr->handle, CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(entry_ptr->handle,
...@@ -135,6 +142,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.output_shape") ...@@ -135,6 +142,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.output_shape")
int w_dim2 = args[13]; int w_dim2 = args[13];
int w_dim3 = args[14]; int w_dim3 = args[14];
void *out_shape = args[15]; void *out_shape = args[15];
std::string data_dtype = args[16];
std::string conv_dtype = args[17];
// Set Data Type
entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(conv_dtype));
cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(data_dtype));
// Set Format // Set Format
entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format); entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
// conv desc // conv desc
...@@ -150,15 +162,15 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.output_shape") ...@@ -150,15 +162,15 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.output_shape")
// input desc // input desc
CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc, CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.tensor_format, entry_ptr->conv_entry.tensor_format,
CUDNN_DATA_FLOAT, data_type,
x_dim0, x_dim0,
x_dim1, x_dim1,
x_dim2, x_dim2,
x_dim3)); x_dim3));
// filter desc // filter desc
CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc,
CUDNN_DATA_FLOAT, data_type,
CUDNN_TENSOR_NCHW, entry_ptr->conv_entry.tensor_format,
w_dim0, w_dim0,
w_dim1, w_dim1,
w_dim2, w_dim2,
...@@ -196,7 +208,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.find_algo") ...@@ -196,7 +208,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.find_algo")
int y_dim1 = args[16]; int y_dim1 = args[16];
int y_dim2 = args[17]; int y_dim2 = args[17];
int y_dim3 = args[18]; int y_dim3 = args[18];
std::string data_dtype = args[19];
std::string conv_dtype = args[20];
// Set Data Type
entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(conv_dtype));
cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(data_dtype));
// Set Format // Set Format
entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format); entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
// conv desc // conv desc
...@@ -212,15 +229,15 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.find_algo") ...@@ -212,15 +229,15 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.find_algo")
// input desc // input desc
CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc, CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.tensor_format, entry_ptr->conv_entry.tensor_format,
CUDNN_DATA_FLOAT, data_type,
x_dim0, x_dim0,
x_dim1, x_dim1,
x_dim2, x_dim2,
x_dim3)); x_dim3));
// filter desc // filter desc
CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc,
CUDNN_DATA_FLOAT, data_type,
CUDNN_TENSOR_NCHW, entry_ptr->conv_entry.tensor_format,
w_dim0, w_dim0,
w_dim1, w_dim1,
w_dim2, w_dim2,
...@@ -229,11 +246,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.find_algo") ...@@ -229,11 +246,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.find_algo")
// output desc // output desc
CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.output_desc, CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.output_desc,
entry_ptr->conv_entry.tensor_format, entry_ptr->conv_entry.tensor_format,
entry_ptr->conv_entry.data_type, data_type,
y_dim0, y_dim0,
y_dim1, y_dim1,
y_dim2, y_dim2,
y_dim3)); y_dim3));
if (cudnnGetVersion() > 7000) {
CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH))
}
int returned_algo_count = 0; int returned_algo_count = 0;
cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT]; cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
......
...@@ -18,13 +18,13 @@ import tvm ...@@ -18,13 +18,13 @@ import tvm
import numpy as np import numpy as np
from tvm.contrib import cublas from tvm.contrib import cublas
def test_matmul_add(): def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5):
n = 1024 n = 1024
l = 128 l = 128
m = 235 m = 236
A = tvm.placeholder((n, l), name='A') A = tvm.placeholder((n, l), name='A', dtype=in_dtype)
B = tvm.placeholder((l, m), name='B') B = tvm.placeholder((l, m), name='B', dtype=in_dtype)
C = cublas.matmul(A, B) C = cublas.matmul(A, B, dtype=out_dtype)
s = tvm.create_schedule(C.op) s = tvm.create_schedule(C.op)
def verify(target="cuda"): def verify(target="cuda"):
...@@ -36,22 +36,22 @@ def test_matmul_add(): ...@@ -36,22 +36,22 @@ def test_matmul_add():
return return
ctx = tvm.gpu(0) ctx = tvm.gpu(0)
f = tvm.build(s, [A, B, C], target) f = tvm.build(s, [A, B, C], target)
a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx) a = tvm.nd.array(np.random.uniform(0, 128, size=(n, l)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx) b = tvm.nd.array(np.random.uniform(0, 128, size=(l, m)).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx) c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
f(a, b, c) f(a, b, c)
tvm.testing.assert_allclose( tvm.testing.assert_allclose(
c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5) c.asnumpy(), np.dot(a.asnumpy().astype(C.dtype), b.asnumpy().astype(C.dtype)), rtol=rtol)
verify() verify()
def test_batch_matmul(): def verify_batch_matmul(in_dtype, out_dtype, rtol=1e-5):
j = 16 j = 16
n = 1024 n = 1024
l = 128 l = 128
m = 235 m = 236
A = tvm.placeholder((j, n, l), name='A') A = tvm.placeholder((j, n, l), name='A', dtype=in_dtype)
B = tvm.placeholder((j, l, m), name='B') B = tvm.placeholder((j, l, m), name='B', dtype=in_dtype)
C = cublas.batch_matmul(A, B) C = cublas.batch_matmul(A, B, dtype=out_dtype)
s = tvm.create_schedule(C.op) s = tvm.create_schedule(C.op)
def verify(target="cuda"): def verify(target="cuda"):
...@@ -68,10 +68,22 @@ def test_batch_matmul(): ...@@ -68,10 +68,22 @@ def test_batch_matmul():
c = tvm.nd.array(np.zeros((j, n, m), dtype=C.dtype), ctx) c = tvm.nd.array(np.zeros((j, n, m), dtype=C.dtype), ctx)
f(a, b, c) f(a, b, c)
tvm.testing.assert_allclose( tvm.testing.assert_allclose(
c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy()), rtol=1e-5) c.asnumpy(), np.matmul(a.asnumpy().astype(C.dtype),
b.asnumpy().astype(C.dtype)).astype(C.dtype), rtol=rtol)
verify() verify()
def test_matmul_add():
verify_matmul_add('float', 'float')
verify_matmul_add('float16', 'float')
verify_matmul_add('float16', 'float16', rtol=1e-2)
verify_matmul_add('int8', 'int32')
def test_batch_matmul():
verify_batch_matmul('float', 'float')
verify_batch_matmul('float16', 'float')
verify_batch_matmul('float16', 'float16', rtol=1e-2)
if __name__ == "__main__": if __name__ == "__main__":
test_matmul_add() test_matmul_add()
test_batch_matmul() test_batch_matmul()
...@@ -19,8 +19,8 @@ from tvm.contrib import cudnn ...@@ -19,8 +19,8 @@ from tvm.contrib import cudnn
import numpy as np import numpy as np
def test_conv2d(): def verify_conv2d(data_dtype, conv_dtype, tensor_format=0):
in_channel = 3 in_channel = 4
out_channel = 32 out_channel = 32
filter_h = 3 filter_h = 3
filter_w = 3 filter_w = 3
...@@ -30,21 +30,25 @@ def test_conv2d(): ...@@ -30,21 +30,25 @@ def test_conv2d():
stride_w = 1 stride_w = 1
dilation_h = 1 dilation_h = 1
dilation_w = 1 dilation_w = 1
batch = 3
height = 32
weight = 32
xshape = [4, 3, 32, 32]
if not tvm.module.enabled("cuda"): if not tvm.module.enabled("cuda"):
print("skip because cuda is not enabled...") print("skip because cuda is not enabled...")
return return
if not tvm.get_global_func("tvm.contrib.cudnn.conv2d.output_shape", True): if not tvm.get_global_func("tvm.contrib.cudnn.conv2d.output_shape", True):
print("skip because cudnn is not enabled...") print("skip because cudnn is not enabled...")
return return
xshape = [batch, in_channel, height, weight]
wshape = cudnn.conv2d_w_shape(in_channel, wshape = cudnn.conv2d_w_shape(in_channel,
out_channel, out_channel,
filter_h, filter_h,
filter_w) filter_w)
X = tvm.placeholder(xshape, name='X') X = tvm.placeholder(xshape, name='X', dtype=data_dtype)
W = tvm.placeholder(wshape, name='W') W = tvm.placeholder(wshape, name='W', dtype=data_dtype)
Y = cudnn.conv2d_forward(X, Y = cudnn.conv2d_forward(X,
W, W,
stride_h, stride_h,
...@@ -54,24 +58,31 @@ def test_conv2d(): ...@@ -54,24 +58,31 @@ def test_conv2d():
dilation_h, dilation_h,
dilation_w, dilation_w,
conv_mode=1, conv_mode=1,
tensor_format=0, tensor_format=tensor_format,
algo=1) conv_dtype=conv_dtype,
algo=-1)
yshape = [x.value for x in Y.shape] yshape = [x.value for x in Y.shape]
s = tvm.create_schedule(Y.op) s = tvm.create_schedule(Y.op)
def verify(): def verify():
ctx = tvm.gpu(0) ctx = tvm.gpu(0)
f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv2d") f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv2d")
x = tvm.nd.array(np.random.uniform(-1, 1, xshape).astype(np.float32), x = tvm.nd.array(np.random.uniform(-1, 1, xshape).astype(data_dtype),
ctx) ctx)
w = tvm.nd.array(np.random.uniform(-1, 1, wshape).astype(np.float32), w = tvm.nd.array(np.random.uniform(-1, 1, wshape).astype(data_dtype),
ctx) ctx)
y = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), y = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(data_dtype),
ctx) ctx)
f(x, w, y) f(x, w, y)
verify() verify()
def test_conv2d():
verify_conv2d("float32", "float32", tensor_format=0)
verify_conv2d("float16", "float32", tensor_format=1)
verify_conv2d("float16", "float16", tensor_format=0)
verify_conv2d("int8", "int32", tensor_format=1)
if __name__ == "__main__": if __name__ == "__main__":
test_conv2d() test_conv2d()
...@@ -89,6 +89,13 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou ...@@ -89,6 +89,13 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) *\ cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) *\
((KW - 1) * dilation_w + 1)) ((KW - 1) * dilation_w + 1))
if data.dtype == "int8" or kernel.dtype == "int8":
if layout == 'NCHW':
raise ValueError("NCHW layout do not support int8 in cudnn")
dtype = "int32"
else:
dtype = data.dtype
return cudnn.conv2d_forward(data, return cudnn.conv2d_forward(data,
kernel, kernel,
stride_h, stride_h,
...@@ -99,7 +106,8 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou ...@@ -99,7 +106,8 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
dilation_w, dilation_w,
conv_mode=1, conv_mode=1,
tensor_format=tensor_format, tensor_format=tensor_format,
algo=-1) # let CUDNN choose the best algo algo=-1, # let CUDNN choose the best algo
conv_dtype=dtype)
if cfg.template_key == 'winograd': if cfg.template_key == 'winograd':
return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
......
...@@ -62,8 +62,7 @@ def dense_cuda(cfg, data, weight, bias=None, out_dtype=None): ...@@ -62,8 +62,7 @@ def dense_cuda(cfg, data, weight, bias=None, out_dtype=None):
out_dim, _ = weight.shape out_dim, _ = weight.shape
target = tvm.target.current_target() target = tvm.target.current_target()
if "cublas" in target.libs: if "cublas" in target.libs:
assert out_dtype == data.dtype, "Mixed precision not supported." matmul = cublas.matmul(data, weight, False, True, out_dtype)
matmul = cublas.matmul(data, weight, False, True)
if bias is not None: if bias is not None:
matmul = tvm.compute((batch, out_dim), \ matmul = tvm.compute((batch, out_dim), \
lambda i, j: matmul[i, j] + bias[j], \ lambda i, j: matmul[i, j] + bias[j], \
...@@ -256,9 +255,19 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None): ...@@ -256,9 +255,19 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
"""Dense operator for int8 on CUDA""" """Dense operator for int8 on CUDA"""
if out_dtype is None: if out_dtype is None:
out_dtype = data.dtype out_dtype = data.dtype
batch, in_dim = get_const_tuple(data.shape) batch, in_dim = get_const_tuple(data.shape)
out_dim, _ = get_const_tuple(weight.shape) out_dim, _ = get_const_tuple(weight.shape)
target = tvm.target.current_target()
if "cublas" in target.libs:
matmul = cublas.matmul(data, weight, False, True, out_dtype)
if bias is not None:
matmul = tvm.compute((batch, out_dim), \
lambda i, j: matmul[i, j] + bias[j].astype(out_dtype), \
tag=tag.BROADCAST)
return matmul
k = tvm.reduce_axis((0, in_dim), name='k') k = tvm.reduce_axis((0, in_dim), name='k')
matmul = tvm.compute((batch, out_dim), matmul = tvm.compute((batch, out_dim),
...@@ -279,7 +288,14 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None): ...@@ -279,7 +288,14 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
@autotvm.register_topi_schedule(generic.schedule_dense, ['cuda', 'gpu'], ['int8']) @autotvm.register_topi_schedule(generic.schedule_dense, ['cuda', 'gpu'], ['int8'])
def schedule_dense_int8(cfg, outs): def schedule_dense_int8(cfg, outs):
"""Dense schedule for int8 on CUDA"""
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
target = tvm.target.current_target()
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
if "cublas" in target.libs:
return generic.schedule_extern(outs)
def _callback(op): def _callback(op):
if "dense_int8" in op.tag: if "dense_int8" in op.tag:
_schedule_dense_int8(cfg, s, op.output(0)) _schedule_dense_int8(cfg, s, op.output(0))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment