Commit 2f9ab71e by Chris Nuernberger Committed by Tianqi Chen

Better gemm support for cublas and cpu (#1967)

parent f2d7787e
...@@ -63,7 +63,7 @@ macro(find_cuda use_cuda) ...@@ -63,7 +63,7 @@ macro(find_cuda use_cuda)
endif() endif()
find_library(CUDA_NVRTC_LIBRARY nvrtc find_library(CUDA_NVRTC_LIBRARY nvrtc
PATHS ${CUDA_TOOLKIT_ROOT_DIR} PATHS ${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib targets/x86_64-linux/lib/stubs lib64/stubs PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib targets/x86_64-linux/lib/stubs lib64/stubs lib/x86_64-linux-gnu
NO_DEFAULT_PATH) NO_DEFAULT_PATH)
find_library(CUDA_CUDNN_LIBRARY cudnn find_library(CUDA_CUDNN_LIBRARY cudnn
${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h> #include <tvm/runtime/util.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include "gemm_common.h"
extern "C" { extern "C" {
#if USE_MKL_BLAS == 1 #if USE_MKL_BLAS == 1
...@@ -19,38 +21,56 @@ namespace contrib { ...@@ -19,38 +21,56 @@ namespace contrib {
using namespace runtime; using namespace runtime;
inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) {
return trans ? CblasTrans : CblasNoTrans;
}
struct CblasSgemmOp {
typedef float TDatatype;
void operator()(bool ta, bool tb,
int M, int N, int K,
float alpha, float* A, int lda,
float* B, int ldb,
float beta, float* C, int ldc) {
cblas_sgemm(CblasColMajor,
BooleanToTranspose(ta),
BooleanToTranspose(tb),
M, N, K,
alpha, A, lda,
B, ldb,
beta, C, ldc);
}
};
struct CblasDgemmOp {
typedef double TDatatype;
void operator()(bool ta, bool tb,
int M, int N, int K,
double alpha, double* A, int lda,
double* B, int ldb,
double beta, double* C, int ldc) {
cblas_dgemm(CblasColMajor,
BooleanToTranspose(ta),
BooleanToTranspose(tb),
M, N, K,
alpha, A, lda,
B, ldb,
beta, C, ldc);
}
};
// matrix multiplication for row major // matrix multiplication for row major
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul") TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor* A = args[0]; DLTensor* A = args[0];
DLTensor* B = args[1]; CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
DLTensor* C = args[2]; TypeMatch(A->dtype, kDLFloat, 64));
bool transa = args[3];
bool transb = args[4]; if (TypeMatch(A->dtype, kDLFloat, 32))
// call gemm for simple compact code. CallGemm(args, ret, CblasSgemmOp());
CHECK_EQ(A->ndim, 2); else
CHECK_EQ(B->ndim, 2); CallGemm(args, ret, CblasDgemmOp());
CHECK_EQ(C->ndim, 2);
CHECK(C->strides == nullptr);
CHECK(B->strides == nullptr);
CHECK(A->strides == nullptr);
CHECK(TypeMatch(A->dtype, kDLFloat, 32));
CHECK(TypeMatch(B->dtype, kDLFloat, 32));
CHECK(TypeMatch(C->dtype, kDLFloat, 32));
cblas_sgemm(CblasColMajor,
transb ? CblasTrans : CblasNoTrans,
transa ? CblasTrans : CblasNoTrans,
transb ? B->shape[0] : B->shape[1],
transa ? A->shape[1] : A->shape[0],
transb ? B->shape[1] : B->shape[0],
1.0f,
reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset),
B->shape[1],
reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset),
A->shape[1],
0.0f,
reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset),
C->shape[1]);
}); });
} // namespace contrib } // namespace contrib
} // namespace tvm } // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/contrib/gemm.h
* \brief Shared implementation of gemm
*/
#ifndef TVM_CONTRIB_CBLAS_GEMM_COMMON_H_
#define TVM_CONTRIB_CBLAS_GEMM_COMMON_H_
#include <algorithm>
namespace tvm {
namespace contrib {
using namespace runtime;
inline int ColumnStride(DLTensor* tensor) {
// If the tensor itself is transposed then it will have strides
// backward from what we expect. Regardless, the max of the strides
// (the other stride is 1) is the column stride.
if (tensor->strides) {
return std::max(tensor->strides[0], tensor->strides[1]);
} else {
return tensor->shape[1];
}
}
inline int ElementStride(DLTensor* tensor) {
if (tensor->strides) {
return std::min(tensor->strides[0], tensor->strides[1]);
} else {
return 1;
}
}
// Reversed strides indicates an in-place transpose operation.
inline bool IsInPlaceTransposed(DLTensor* tensor) {
return tensor->strides && (tensor->strides[1] > tensor->strides[0]);
}
inline int RowCount(DLTensor* tensor, bool trans) {
return tensor->shape[trans ? 1 : 0];
}
inline int ColumnCount(DLTensor* tensor, bool trans) {
return tensor->shape[trans ? 0 : 1];
}
// Call a column major blas. Note that data is stored in tvm as row
// major, so this we switch the arguments.
template<typename TGemmOp>
inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) {
DLTensor* A = args[0];
DLTensor* B = args[1];
DLTensor* C = args[2];
bool transa = args[3];
bool transb = args[4];
int bit_depth = sizeof(typename TGemmOp::TDatatype) * 8;
CHECK_EQ(A->ndim, 2);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 2);
CHECK_EQ(ElementStride(A), 1);
CHECK_EQ(ElementStride(B), 1);
CHECK_EQ(ElementStride(C), 1);
// C can never be transposed.
CHECK(!IsInPlaceTransposed(C));
// Reversed strides indicates an in-place transpose operation.
transa = IsInPlaceTransposed(A) ? !transa : transa;
transb = IsInPlaceTransposed(B) ? !transb : transb;
CHECK(TypeMatch(B->dtype, kDLFloat, bit_depth));
CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth));
double alpha = args.size() > 5 ? args[5] : 1.0;
double beta = args.size() > 6 ? args[6] : 0.0;
op(transb,
transa,
ColumnCount(B, transb),
RowCount(A, transa),
ColumnCount(A, transa),
static_cast<float>(alpha),
reinterpret_cast<typename TGemmOp::TDatatype*>(static_cast<char*>(B->data)
+ B->byte_offset),
ColumnStride(B),
reinterpret_cast<typename TGemmOp::TDatatype*>(static_cast<char*>(A->data)
+ A->byte_offset),
ColumnStride(A),
static_cast<float>(beta),
reinterpret_cast<typename TGemmOp::TDatatype*>(static_cast<char*>(C->data)
+ C->byte_offset),
ColumnStride(C));
}
} // namespace contrib
} // namespace tvm
#endif // TVM_CONTRIB_CBLAS_GEMM_COMMON_H_
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2018 by Contributors
* \file Use external cblas library call. * \file Use external cblas library call.
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h> #include <tvm/runtime/util.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include "../cblas/gemm_common.h"
#include "cublas_utils.h"
extern "C" {
#include <cublas_v2.h>
}
namespace tvm { namespace tvm {
namespace contrib { namespace contrib {
using namespace runtime; using namespace runtime;
#ifndef CHECK_CUBLAS_ERROR inline cublasOperation_t BooleanToTranspose(bool item) {
#define CHECK_CUBLAS_ERROR(error) \ return item ? CUBLAS_OP_T : CUBLAS_OP_N;
if (error != CUBLAS_STATUS_SUCCESS) { \
fprintf(stderr, "cuBLAS error: "); \
if (error == CUBLAS_STATUS_NOT_INITIALIZED) fprintf(stderr, "CUBLAS_STATUS_NOT_INITIALIZED"); \
if (error == CUBLAS_STATUS_ALLOC_FAILED) fprintf(stderr, "CUBLAS_STATUS_ALLOC_FAILED"); \
if (error == CUBLAS_STATUS_INVALID_VALUE) fprintf(stderr, "CUBLAS_STATUS_INVALID_VALUE"); \
if (error == CUBLAS_STATUS_ARCH_MISMATCH) fprintf(stderr, "CUBLAS_STATUS_ARCH_MISMATCH"); \
if (error == CUBLAS_STATUS_MAPPING_ERROR) fprintf(stderr, "CUBLAS_STATUS_MAPPING_ERROR"); \
if (error == CUBLAS_STATUS_EXECUTION_FAILED) fprintf(stderr, "CUBLAS_STATUS_EXECUTION_FAILED"); \
if (error == CUBLAS_STATUS_INTERNAL_ERROR) fprintf(stderr, "CUBLAS_STATUS_INTERNAL_ERROR"); \
if (error == CUBLAS_STATUS_NOT_SUPPORTED) fprintf(stderr, "CUBLAS_STATUS_NOT_SUPPORTED"); \
if (error == CUBLAS_STATUS_LICENSE_ERROR) fprintf(stderr, "CUBLAS_STATUS_LICENSE_ERROR"); \
fprintf(stderr, "\n"); \
exit(EXIT_FAILURE); \
} }
#endif
struct CublasSgemmOp {
typedef float TDatatype;
cublasHandle_t handle;
explicit CublasSgemmOp(cublasHandle_t hdl)
: handle(hdl)
{}
void operator()(bool ta, bool tb,
int M, int N, int K,
float alpha, float* A, int lda,
float* B, int ldb,
float beta, float* C, int ldc) {
CHECK_CUBLAS_ERROR(cublasSgemm(handle,
BooleanToTranspose(ta),
BooleanToTranspose(tb),
M, N, K,
&alpha, A, lda,
B, ldb,
&beta, C, ldc));
}
};
struct CublasDgemmOp {
typedef double TDatatype;
cublasHandle_t handle;
explicit CublasDgemmOp(cublasHandle_t hdl)
: handle(hdl)
{}
void operator()(bool ta, bool tb,
int M, int N, int K,
double alpha, double* A, int lda,
double* B, int ldb,
double beta, double* C, int ldc) {
CHECK_CUBLAS_ERROR(cublasDgemm(handle,
BooleanToTranspose(ta),
BooleanToTranspose(tb),
M, N, K,
&alpha, A, lda,
B, ldb,
&beta, C, ldc));
}
};
// matrix multiplication for row major // matrix multiplication for row major
TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul") TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor* A = args[0]; DLTensor* A = args[0];
DLTensor* B = args[1];
DLTensor* C = args[2];
bool transa = args[3];
bool transb = args[4];
// call gemm for simple compact code.
CHECK_EQ(A->ndim, 2);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 2);
CHECK(C->strides == nullptr);
CHECK(B->strides == nullptr);
CHECK(A->strides == nullptr);
CHECK(TypeMatch(A->dtype, kDLFloat, 32));
CHECK(TypeMatch(B->dtype, kDLFloat, 32));
CHECK(TypeMatch(C->dtype, kDLFloat, 32));
cublasHandle_t handle; CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
CHECK_CUBLAS_ERROR(cublasCreate(&handle)); TypeMatch(A->dtype, kDLFloat, 64));
float alpha = 1.0;
float beta = 0.0;
float *A_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
float *B_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
float *C_ptr = reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset);
CHECK_CUBLAS_ERROR(cublasSgemm(handle, CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
transb ? CUBLAS_OP_T : CUBLAS_OP_N,
transa ? CUBLAS_OP_T : CUBLAS_OP_N,
transb ? B->shape[0] : B->shape[1],
transa ? A->shape[1] : A->shape[0],
transb ? B->shape[1] : B->shape[0],
&alpha,
A_ptr,
B->shape[1],
B_ptr,
A->shape[1],
&beta,
C_ptr,
C->shape[1]));
CHECK_CUBLAS_ERROR(cublasDestroy(handle)); if (TypeMatch(A->dtype, kDLFloat, 32))
CallGemm(args, ret, CublasSgemmOp(entry_ptr->handle));
else
CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle));
}); });
} // namespace contrib } // namespace contrib
} // namespace tvm } // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file Use external cudnn utils function
*/
#include "cublas_utils.h"
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include "../../runtime/cuda/cuda_common.h"
namespace tvm {
namespace contrib {
CuBlasThreadEntry::CuBlasThreadEntry() {
CHECK_CUBLAS_ERROR(cublasCreate(&handle));
}
CuBlasThreadEntry::~CuBlasThreadEntry() {
if (handle) {
cublasDestroy(handle);
handle = 0;
}
}
typedef dmlc::ThreadLocalStore<CuBlasThreadEntry> CuBlasThreadStore;
CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal() {
auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
CuBlasThreadEntry* retval = CuBlasThreadStore::Get();
CHECK_CUBLAS_ERROR(cublasSetStream(retval->handle, static_cast<cudaStream_t>(stream)));
return retval;
}
} // namespace contrib
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file Use external cudnn utils function
*/
#ifndef TVM_CONTRIB_CUBLAS_CUBLAS_UTILS_H_
#define TVM_CONTRIB_CUBLAS_CUBLAS_UTILS_H_
#include <dmlc/logging.h>
extern "C" {
#include <cublas_v2.h>
}
namespace tvm {
namespace contrib {
inline const char* GetCublasErrorString(int error) {
switch (error) {
case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR";
}
return "Unrecognized error";
}
#ifndef CHECK_CUBLAS_ERROR
#define CHECK_CUBLAS_ERROR(fn) \
do { \
int error = static_cast<int>(fn); \
CHECK_EQ(error, CUBLAS_STATUS_SUCCESS) << "CUBLAS: " << GetCublasErrorString(error); \
} while (0) // ; intentionally left off.
#endif // CHECK_CUBLAS_ERROR
struct CuBlasThreadEntry {
CuBlasThreadEntry();
~CuBlasThreadEntry();
cublasHandle_t handle{nullptr};
static CuBlasThreadEntry* ThreadLocal();
}; // CuBlasThreadEntry
} // namespace contrib
} // namespace tvm
#endif // TVM_CONTRIB_CUBLAS_CUBLAS_UTILS_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment