/*!
 *  Copyright (c) 2018 by Contributors
 * \file Use external cublas library call.
 */
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dmlc/logging.h>
#include "../cblas/gemm_common.h"
#include "cublas_utils.h"
namespace tvm {
namespace contrib {

using namespace runtime;

/*!
 * \brief Map a boolean transpose flag onto the matching cuBLAS operation enum.
 * \param item Flag to convert.
 * \return CUBLAS_OP_T when item is true, CUBLAS_OP_N otherwise.
 */
inline cublasOperation_t BooleanToTranspose(bool item) {
  if (item) {
    return CUBLAS_OP_T;
  }
  return CUBLAS_OP_N;
}
/*!
 * \brief Callable wrapper that issues a single-precision GEMM through cublasSgemm.
 *
 * The functor keeps the cuBLAS handle it was constructed with and passes it to
 * every cublasSgemm call made through operator().
 */
struct CublasSgemmOp {
  /*! \brief Element type handled by this functor. */
  using TDatatype = float;
  /*! \brief cuBLAS handle passed to cublasSgemm; stored as given, not created or destroyed here. */
  cublasHandle_t handle;

  /*!
   * \brief Bind the functor to a cuBLAS handle.
   * \param hdl Handle stored in the `handle` member.
   */
  explicit CublasSgemmOp(cublasHandle_t hdl) : handle(hdl) {}

  /*!
   * \brief Forward the arguments to cublasSgemm and check the returned status
   *        with CHECK_CUBLAS_ERROR.
   * \param ta Flag for the first operand, converted with BooleanToTranspose.
   * \param tb Flag for the second operand, converted with BooleanToTranspose.
   * \param M Passed through as the m argument of cublasSgemm.
   * \param N Passed through as the n argument of cublasSgemm.
   * \param K Passed through as the k argument of cublasSgemm.
   * \param alpha Scalar passed to cublasSgemm by address.
   * \param A First matrix pointer (presumably device memory -- confirm against the caller).
   * \param lda Leading dimension passed along with A.
   * \param B Second matrix pointer (presumably device memory -- confirm against the caller).
   * \param ldb Leading dimension passed along with B.
   * \param beta Scalar passed to cublasSgemm by address.
   * \param C Output matrix pointer (presumably device memory -- confirm against the caller).
   * \param ldc Leading dimension passed along with C.
   */
  void operator()(bool ta, bool tb,
                  int M, int N, int K,
                  float alpha, float* A, int lda,
                  float* B, int ldb,
                  float beta, float* C, int ldc) {
    // alpha and beta are by-value parameters whose addresses are handed to cuBLAS.
    // NOTE(review): this presumes the handle uses host pointer mode -- confirm.
    CHECK_CUBLAS_ERROR(cublasSgemm(handle, BooleanToTranspose(ta), BooleanToTranspose(tb),
                                   M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc));
  }
};


/*!
 * \brief Callable wrapper that issues a double-precision GEMM through cublasDgemm.
 *
 * The functor keeps the cuBLAS handle it was constructed with and passes it to
 * every cublasDgemm call made through operator().
 */
struct CublasDgemmOp {
  /*! \brief Element type handled by this functor. */
  using TDatatype = double;
  /*! \brief cuBLAS handle passed to cublasDgemm; stored as given, not created or destroyed here. */
  cublasHandle_t handle;

  /*!
   * \brief Bind the functor to a cuBLAS handle.
   * \param hdl Handle stored in the `handle` member.
   */
  explicit CublasDgemmOp(cublasHandle_t hdl) : handle(hdl) {}

  /*!
   * \brief Forward the arguments to cublasDgemm and check the returned status
   *        with CHECK_CUBLAS_ERROR.
   * \param ta Flag for the first operand, converted with BooleanToTranspose.
   * \param tb Flag for the second operand, converted with BooleanToTranspose.
   * \param M Passed through as the m argument of cublasDgemm.
   * \param N Passed through as the n argument of cublasDgemm.
   * \param K Passed through as the k argument of cublasDgemm.
   * \param alpha Scalar passed to cublasDgemm by address.
   * \param A First matrix pointer (presumably device memory -- confirm against the caller).
   * \param lda Leading dimension passed along with A.
   * \param B Second matrix pointer (presumably device memory -- confirm against the caller).
   * \param ldb Leading dimension passed along with B.
   * \param beta Scalar passed to cublasDgemm by address.
   * \param C Output matrix pointer (presumably device memory -- confirm against the caller).
   * \param ldc Leading dimension passed along with C.
   */
  void operator()(bool ta, bool tb,
                  int M, int N, int K,
                  double alpha, double* A, int lda,
                  double* B, int ldb,
                  double beta, double* C, int ldc) {
    // alpha and beta are by-value parameters whose addresses are handed to cuBLAS.
    // NOTE(review): this presumes the handle uses host pointer mode -- confirm.
    CHECK_CUBLAS_ERROR(cublasDgemm(handle, BooleanToTranspose(ta), BooleanToTranspose(tb),
                                   M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc));
  }
};
// Matrix multiplication for row-major tensors, registered as the global
// function "tvm.contrib.cublas.matmul". The dtype of the first argument must be
// float32 or float64 and selects the single- or double-precision cuBLAS functor;
// the arguments are then handed unchanged to CallGemm.
TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    DLTensor* A = args[0];

    // Reject every dtype other than float32 / float64 before touching cuBLAS state.
    CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
          TypeMatch(A->dtype, kDLFloat, 64));

    // Fetch the handle from the entry returned by CuBlasThreadEntry::ThreadLocal().
    CuBlasThreadEntry* thread_entry = CuBlasThreadEntry::ThreadLocal();
    cublasHandle_t handle = thread_entry->handle;

    if (TypeMatch(A->dtype, kDLFloat, 32)) {
      CallGemm(args, ret, CublasSgemmOp(handle));
      return;
    }
    // The CHECK above leaves float64 as the only remaining dtype.
    CallGemm(args, ret, CublasDgemmOp(handle));
});
}  // namespace contrib
}  // namespace tvm