/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file cublas.cc
 * \brief Use external cuBLAS library call.
 */
#include <tvm/runtime/registry.h>
#include <tvm/runtime/data_type.h>
#include <dmlc/logging.h>
#include "../cblas/gemm_common.h"
#include "cublas_utils.h"


namespace tvm {
namespace contrib {

using namespace runtime;

35 36
inline cublasOperation_t BooleanToTranspose(bool item) {
  return item ? CUBLAS_OP_T : CUBLAS_OP_N;
37
}
38

39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
inline void TryEnableTensorCore(cublasHandle_t hdl) {
  // TensorCores are only supported in cublas 9.0 or higher
  int version;
  CHECK_CUBLAS_ERROR(cublasGetVersion(hdl, &version));
  if (version >= 9000)
    CHECK_CUBLAS_ERROR(cublasSetMathMode(hdl, CUBLAS_TENSOR_OP_MATH));
}

struct CublasHgemmOp {
  typedef half TDatatype;
  cublasHandle_t handle;
  explicit CublasHgemmOp(cublasHandle_t hdl)
      : handle(hdl) {}

  void operator()(bool ta, bool tb,
                  int M, int N, int K,
                  half alpha, half* A, int lda,
                  half* B, int ldb,
                  half beta, half* C, int ldc) {
    CHECK_CUBLAS_ERROR(cublasHgemm(handle,
                                   BooleanToTranspose(ta),
                                   BooleanToTranspose(tb),
                                   M, N, K,
                                   &alpha, A, lda,
                                   B, ldb,
                                   &beta, C, ldc));
  }
};

68 69 70 71
struct CublasSgemmOp {
  typedef float TDatatype;
  cublasHandle_t handle;
  explicit CublasSgemmOp(cublasHandle_t hdl)
72
    : handle(hdl) {}
73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92

  void operator()(bool ta, bool tb,
                  int M, int N, int K,
                  float alpha, float* A, int lda,
                  float* B, int ldb,
                  float beta, float* C, int ldc) {
    CHECK_CUBLAS_ERROR(cublasSgemm(handle,
                                   BooleanToTranspose(ta),
                                   BooleanToTranspose(tb),
                                   M, N, K,
                                   &alpha, A, lda,
                                   B, ldb,
                                   &beta, C, ldc));
  }
};

struct CublasDgemmOp {
  typedef double TDatatype;
  cublasHandle_t handle;
  explicit CublasDgemmOp(cublasHandle_t hdl)
93
    : handle(hdl) {}
94 95 96 97 98 99 100 101 102 103 104 105 106 107
  void operator()(bool ta, bool tb,
                  int M, int N, int K,
                  double alpha, double* A, int lda,
                  double* B, int ldb,
                  double beta, double* C, int ldc) {
    CHECK_CUBLAS_ERROR(cublasDgemm(handle,
                                   BooleanToTranspose(ta),
                                   BooleanToTranspose(tb),
                                   M, N, K,
                                   &alpha, A, lda,
                                   B, ldb,
                                   &beta, C, ldc));
  }
};
108

109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
struct CublasHgemmBatchOp {
  typedef half TDatatype;
  cublasHandle_t handle;
  explicit CublasHgemmBatchOp(cublasHandle_t hdl)
      : handle(hdl) {}
  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, half alpha, half* A,
                  int a_stride, int lda, half* B, int b_stride, int ldb, half beta, half* C,
                  int c_stride, int ldc) {
    CHECK_CUBLAS_ERROR(cublasHgemmStridedBatched(handle,
                                                 BooleanToTranspose(ta),
                                                 BooleanToTranspose(tb),
                                                 M, N, K,
                                                 &alpha,
                                                 A, lda, a_stride,
                                                 B, ldb, b_stride,
                                                 &beta,
                                                 C, ldc, c_stride,
                                                 batch_size));
  }
};

130 131 132 133
struct CublasSgemmBatchOp {
  typedef float TDatatype;
  cublasHandle_t handle;
  explicit CublasSgemmBatchOp(cublasHandle_t hdl)
134
    : handle(hdl) {}
135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
                  int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
                  int c_stride, int ldc) {
    CHECK_CUBLAS_ERROR(cublasSgemmStridedBatched(handle,
                                                 BooleanToTranspose(ta),
                                                 BooleanToTranspose(tb),
                                                 M, N, K,
                                                 &alpha,
                                                 A, lda, a_stride,
                                                 B, ldb, b_stride,
                                                 &beta,
                                                 C, ldc, c_stride,
                                                 batch_size));
  }
};

struct CublasDgemmBatchOp {
  typedef double TDatatype;
  cublasHandle_t handle;
  explicit CublasDgemmBatchOp(cublasHandle_t hdl)
155
    : handle(hdl) {}
156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
                  int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
                  int c_stride, int ldc) {
    CHECK_CUBLAS_ERROR(cublasDgemmStridedBatched(handle,
                                                 BooleanToTranspose(ta),
                                                 BooleanToTranspose(tb),
                                                 M, N, K,
                                                 &alpha,
                                                 A, lda, a_stride,
                                                 B, ldb, b_stride,
                                                 &beta,
                                                 C, ldc, c_stride,
                                                 batch_size));
  }
};

172 173 174 175 176 177 178 179 180 181 182 183
// Check cublas supported mix-precision computation type and return computeType
bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_support = true) {
  if (int_support && TypeMatch(out_dtype, kDLInt, 32)) {
    return TypeMatch(in_dtype, kDLInt, 8);
  } else if (TypeMatch(out_dtype, kDLFloat, 32)) {
    return TypeMatch(in_dtype, kDLInt, 8) ||
           TypeMatch(in_dtype, kDLFloat, 16);
  } else {
    return false;
  }
}

/*!
 * \brief Round v up to a multiple of d.
 * \param v Value to round; the result is meaningful for non-negative v.
 * \param d Rounding granularity; assumed positive.
 * \return For v >= 0 and d > 0, the smallest multiple of d that is >= v.
 */
int roundoff(int v, int d) {
  const int num_chunks = (v + d - 1) / d;
  return num_chunks * d;
}

#if CUDART_VERSION >= 10010
/*!
 * \brief int8 x int8 -> int32 GEMM through cublasLt.
 *
 * Operand mapping: the first cublasLt operand is B's data (described by Adesc)
 * and the second is A's data (described by Bdesc). This mirrors the
 * row-major argument swap used by the other GEMM paths in this file.
 *
 * Layouts: Adesc and Cdesc are declared as CUBLASLT_ORDER_COL32 and Bdesc as
 * CUBLASLT_ORDER_COL4_4R2_8C. The tensors are therefore assumed to already be
 * stored in those layouts; TODO confirm with the caller.
 *
 * args layout: A, B, C (DLTensor*), transa, transb, [alpha=1], [beta=0].
 * \param args Packed arguments as described above.
 * \param ret Unused.
 * \param hdl cublasLt handle used for the matmul; ownership stays with the caller.
 */
inline void CallLtIgemm(TVMArgs args, TVMRetValue *ret, cublasLtHandle_t hdl) {
  DLTensor *A = args[0];
  DLTensor *B = args[1];
  DLTensor *C = args[2];
  bool transa = args[3];
  bool transb = args[4];

  // Validate rank, layout and dtype first. The RowCount/ColumnCount/
  // IsInPlaceTransposed helpers below are 2-D helpers and must only see rank-2 tensors.
  CHECK_EQ(A->ndim, 2);
  CHECK_EQ(B->ndim, 2);
  CHECK_EQ(C->ndim, 2);

  CHECK_EQ(ElementStride(A), 1);
  CHECK_EQ(ElementStride(B), 1);
  CHECK_EQ(ElementStride(C), 1);

  CHECK(TypeEqual(A->dtype, B->dtype));
  CHECK(TypeMatch(A->dtype, kDLInt, 8));
  CHECK(TypeMatch(C->dtype, kDLInt, 32));

  CHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type";

  // Reversed strides indicates an in-place transpose operation.
  transa = IsInPlaceTransposed(A) ? !transa : transa;
  transb = IsInPlaceTransposed(B) ? !transb : transb;
  int M = ColumnCount(B, transb);
  int N = RowCount(A, transa);
  int K = ColumnCount(A, transa);
  int N_out = ColumnCount(C, false);
  // roundoff(x, 32) / 32 is 0 for x == 0, which would be an integer division
  // by zero in the leading-dimension computations below.
  CHECK_GT(K, 0) << "cublasLt igemm requires a non-empty reduction dimension";
  CHECK_GT(N_out, 0) << "cublasLt igemm requires a non-empty output";

  // NOTE(review): n and k are both set to m, so the layout descriptors below are
  // only self-consistent when M == N == K, while the leading dimensions are
  // derived from M, N and K independently. Kept as-is; confirm against the
  // intended COL32 / COL4_4R2_8C layout before using non-square shapes.
  int m = M;
  int n = m;
  int k = m;
  int lda = M * K / (roundoff(K, 32) / 32);
  int ldb = K * N / (roundoff(K, 32) / 32);
  int ldc = M * N_out / (roundoff(N_out, 32) / 32);

  int32_t alpha = args.size() > 5 ? args[5] : 1;
  int32_t beta = args.size() > 6 ? args[6] : 0;
  auto A_data = reinterpret_cast<void*>(static_cast<char*>(A->data) + A->byte_offset);
  auto B_data = reinterpret_cast<void*>(static_cast<char*>(B->data) + B->byte_offset);
  auto C_data = reinterpret_cast<void*>(static_cast<char*>(C->data) + C->byte_offset);

  cublasOperation_t opTransA = BooleanToTranspose(transa);
  cublasOperation_t opTransB = BooleanToTranspose(transb);
  cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
  cublasLtOrder_t order_COL4_4R2_8C = CUBLASLT_ORDER_COL4_4R2_8C;

  cublasLtMatmulDesc_t operationDesc = nullptr;
  cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr;

  // Operation descriptor: int32 compute with the requested transposes. Each
  // attribute is set exactly once.
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, CUDA_R_32I));
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opTransA, sizeof(opTransA)));
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTransB, sizeof(opTransB)));

  // Create descriptors for the original matrices
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(
          &Adesc, CUDA_R_8I, opTransA == CUBLAS_OP_N ? m : k,
          opTransA == CUBLAS_OP_N ? k : m, lda));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(
          &Bdesc, CUDA_R_8I, opTransB == CUBLAS_OP_N ? k : n,
          opTransB == CUBLAS_OP_N ? n : k, ldb));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc));

  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
          Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
          Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL4_4R2_8C, sizeof(order_COL4_4R2_8C)));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
          Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)));

  // C is both the input accumulator (scaled by beta) and the output.
  CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl,
                                    operationDesc,
                                    &alpha,
                                    B_data,
                                    Adesc,
                                    A_data,
                                    Bdesc,
                                    &beta,
                                    C_data,
                                    Cdesc,
                                    C_data,
                                    Cdesc,
                                    nullptr,
                                    nullptr,
                                    0,
                                    0));

  // Each *Create call above allocates a descriptor that must be released with
  // the matching *Destroy; otherwise every invocation leaks four descriptors.
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(Cdesc));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(Bdesc));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(Adesc));
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(operationDesc));
}
#endif

276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410
inline void CallGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) {
  DLTensor *A = args[0];
  DLTensor *B = args[1];
  DLTensor *C = args[2];
  bool transa = args[3];
  bool transb = args[4];
  CHECK_EQ(A->ndim, 2);
  CHECK_EQ(B->ndim, 2);
  CHECK_EQ(C->ndim, 2);

  CHECK_EQ(ElementStride(A), 1);
  CHECK_EQ(ElementStride(B), 1);
  CHECK_EQ(ElementStride(C), 1);

  CHECK(TypeEqual(A->dtype, B->dtype));

  // C can never be transposed.
  CHECK(!IsInPlaceTransposed(C));

  // Reversed strides indicates an in-place transpose operation.
  transa = IsInPlaceTransposed(A) ? !transa : transa;
  transb = IsInPlaceTransposed(B) ? !transb : transb;

  CHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type";
  CHECK(!TypeMatch(A->dtype, kDLInt, 8) ||
      ColumnStride(A) % 4 == 0) << "leading dimension must divide 4 for int8 gemm";
  CHECK(!TypeMatch(B->dtype, kDLInt, 8) ||
      ColumnStride(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm";
  double alpha = args.size() > 5 ? args[5] : 1.0;
  double beta = args.size() > 6 ? args[6] : 0.0;

  cudaDataType_t cuda_in_type = GetCudaDataType(A->dtype);
  cudaDataType_t cuda_out_type = GetCudaDataType(C->dtype);
  cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
  void *alpha_ptr = nullptr, *beta_ptr = nullptr;
  auto alpha_int = static_cast<int32_t>(alpha);
  auto beta_int = static_cast<int32_t>(beta);
  auto alpha_float = static_cast<float>(alpha);
  auto beta_float = static_cast<float>(beta);
  if (C->dtype.code == kDLInt) {
    alpha_ptr = &alpha_int;
    beta_ptr = &beta_int;
  } else if (C->dtype.code == kDLFloat) {
    alpha_ptr = &alpha_float;
    beta_ptr = &beta_float;
  }

  auto A_data = reinterpret_cast<void *>(static_cast<char *>(A->data) + A->byte_offset);
  auto B_data = reinterpret_cast<void *>(static_cast<char *>(B->data) + B->byte_offset);
  auto C_data = reinterpret_cast<void *>(static_cast<char *>(C->data) + C->byte_offset);

  CHECK_CUBLAS_ERROR(cublasGemmEx(hdl,
                                 BooleanToTranspose(transb),
                                 BooleanToTranspose(transa),
                                 ColumnCount(B, transb),
                                 RowCount(A, transa),
                                 ColumnCount(A, transa),
                                 alpha_ptr,
                                 B_data, cuda_in_type, ColumnStride(B),
                                 A_data, cuda_in_type, ColumnStride(A),
                                 beta_ptr,
                                 C_data, cuda_out_type, ColumnStride(C),
                                 cuda_out_type, algo));
}

/*!
 * \brief Mixed-precision strided-batched GEMM through cublasGemmStridedBatchedEx.
 *
 * Used by "tvm.contrib.cublas.batch_matmul" when A and C have different dtypes.
 * int32 output is rejected here (CheckMixPrecisionType with int_support =
 * false), so the accepted pairs are int8 -> float32 and float16 -> float32.
 *
 * Operands are rank-3 tensors. Consecutive matrices are taken to be
 * shape[1] * shape[2] elements apart, i.e. the batch is dimension 0.
 *
 * args layout: A, B, C (DLTensor*), transa, transb, [alpha=1.0], [beta=0.0].
 * \param args Packed arguments as described above.
 * \param ret Unused.
 * \param hdl cuBLAS handle the GEMM is issued on.
 */
inline void CallBatchGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) {
  DLTensor *A = args[0];
  DLTensor *B = args[1];
  DLTensor *C = args[2];
  bool transa = args[3];
  bool transb = args[4];
  CHECK_EQ(A->ndim, 3);
  CHECK_EQ(B->ndim, 3);
  CHECK_EQ(C->ndim, 3);
  int batch_size = BatchCount3D(A);
  CHECK_EQ(BatchCount3D(B), batch_size);
  CHECK_EQ(BatchCount3D(C), batch_size);
  CHECK_EQ(ElementStride(A), 1);
  CHECK_EQ(ElementStride(B), 1);
  CHECK_EQ(ElementStride(C), 1);

  CHECK(TypeEqual(A->dtype, B->dtype));

  // C can never be transposed.
  CHECK(!IsInPlaceTransposed(C));

  // Reversed strides indicates an in-place transpose operation.
  transa = IsInPlaceTransposed(A) ? !transa : transa;
  transb = IsInPlaceTransposed(B) ? !transb : transb;

  CHECK(CheckMixPrecisionType(A->dtype, C->dtype, false)) << "Unsupported data type";
  // The operands are rank 3, so validate the same 3-D leading dimension that is
  // passed to cuBLAS below, not the 2-D ColumnStride.
  CHECK(!TypeMatch(A->dtype, kDLInt, 8) ||
      ColumnStride3D(A) % 4 == 0) << "leading dimension must divide 4 for int8 gemm";
  CHECK(!TypeMatch(B->dtype, kDLInt, 8) ||
      ColumnStride3D(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm";
  double alpha = args.size() > 5 ? args[5] : 1.0;
  double beta = args.size() > 6 ? args[6] : 0.0;

  // Element distance between consecutive matrices in each batch. The cuBLAS
  // stride parameters are long long, so the int64_t shape product is kept in
  // 64 bits rather than narrowed to int.
  const long long A_size = A->shape[1] * A->shape[2];
  const long long B_size = B->shape[1] * B->shape[2];
  const long long C_size = C->shape[1] * C->shape[2];

  cudaDataType_t cuda_in_type = GetCudaDataType(A->dtype);
  cudaDataType_t cuda_out_type = GetCudaDataType(C->dtype);
  cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
  // alpha/beta must be passed with the compute type, which is C's type here.
  void *alpha_ptr = nullptr, *beta_ptr = nullptr;
  auto alpha_int = static_cast<int32_t>(alpha);
  auto beta_int = static_cast<int32_t>(beta);
  auto alpha_float = static_cast<float>(alpha);
  auto beta_float = static_cast<float>(beta);
  if (C->dtype.code == kDLInt) {
    alpha_ptr = &alpha_int;
    beta_ptr = &beta_int;
  } else if (C->dtype.code == kDLFloat) {
    alpha_ptr = &alpha_float;
    beta_ptr = &beta_float;
  }

  auto A_data = reinterpret_cast<void *>(static_cast<char *>(A->data) + A->byte_offset);
  auto B_data = reinterpret_cast<void *>(static_cast<char *>(B->data) + B->byte_offset);
  auto C_data = reinterpret_cast<void *>(static_cast<char *>(C->data) + C->byte_offset);
  // B is passed first: row-major C = op(A) op(B) is computed as column-major
  // C^T = op(B)^T op(A)^T.
  CHECK_CUBLAS_ERROR(cublasGemmStridedBatchedEx(hdl,
                                  BooleanToTranspose(transb),
                                  BooleanToTranspose(transa),
                                  ColumnCount3D(B, transb),
                                  RowCount3D(A, transa),
                                  ColumnCount3D(A, transa),
                                  alpha_ptr,
                                  B_data, cuda_in_type, ColumnStride3D(B), B_size,
                                  A_data, cuda_in_type, ColumnStride3D(A), A_size,
                                  beta_ptr,
                                  C_data, cuda_out_type, ColumnStride3D(C), C_size,
                                  batch_size, cuda_out_type, algo));
}

411 412 413 414
// matrix multiplication for row major
TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    DLTensor* A = args[0];
415
    DLTensor* C = args[2];
416

417
    CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
418

419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434
    TryEnableTensorCore(entry_ptr->handle);

    if (TypeEqual(A->dtype, C->dtype)) {
      CHECK(TypeMatch(A->dtype, kDLFloat, 16) ||
          TypeMatch(A->dtype, kDLFloat, 32) ||
          TypeMatch(A->dtype, kDLFloat, 64));

      if (TypeMatch(A->dtype, kDLFloat, 16))
        CallGemm(args, ret, CublasHgemmOp(entry_ptr->handle));
      else if (TypeMatch(A->dtype, kDLFloat, 32))
        CallGemm(args, ret, CublasSgemmOp(entry_ptr->handle));
      else
        CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle));
    } else {
      CallGemmEx(args, ret, entry_ptr->handle);
    }
435
});
436

#if CUDART_VERSION >= 10010
// int8 matmul through cublasLt. Operands must already be in the tiled layouts
// that CallLtIgemm declares for them.
TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    DLTensor* A = args[0];

    CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();

    // Note: this configures the regular cuBLAS handle. The cublasLt call below
    // uses its own handle.
    TryEnableTensorCore(entry_ptr->handle);

    CHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n";
    // A short-lived cublasLt handle is created for this call. If CallLtIgemm
    // throws (e.g. a failed CHECK), release the handle before propagating the
    // error instead of leaking it.
    cublasLtHandle_t ltHandle;
    CHECK_CUBLAS_ERROR(cublasLtCreate(&ltHandle));
    try {
      CallLtIgemm(args, ret, ltHandle);
    } catch (...) {
      cublasLtDestroy(ltHandle);
      throw;
    }
    CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle));
});
#endif  // CUDART_VERSION >= 10010

// Batched matrix multiplication for row major. Same-dtype float16/32/64
// problems use the strided-batched cuBLAS functors. Differing input/output
// dtypes take the mixed-precision CallBatchGemmEx path.
TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    DLTensor* A = args[0];
    DLTensor* C = args[2];

    cublasHandle_t hdl = CuBlasThreadEntry::ThreadLocal()->handle;
    TryEnableTensorCore(hdl);

    if (!TypeEqual(A->dtype, C->dtype)) {
      CallBatchGemmEx(args, ret, hdl);
      return;
    }

    CHECK(TypeMatch(A->dtype, kDLFloat, 16) ||
        TypeMatch(A->dtype, kDLFloat, 32) ||
        TypeMatch(A->dtype, kDLFloat, 64));

    if (TypeMatch(A->dtype, kDLFloat, 16)) {
      CallBatchGemm(args, ret, CublasHgemmBatchOp(hdl));
    } else if (TypeMatch(A->dtype, kDLFloat, 32)) {
      CallBatchGemm(args, ret, CublasSgemmBatchOp(hdl));
    } else {
      CallBatchGemm(args, ret, CublasDgemmBatchOp(hdl));
    }
});
}  // namespace contrib
}  // namespace tvm