/*!
 *  Copyright (c) 2017 by Contributors
 * \file Use external cuBLAS library call.
 */
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dmlc/logging.h>

extern "C" {
#include <cublas_v2.h>
}

namespace tvm {
namespace contrib {

using namespace runtime;

#ifndef CHECK_CUBLAS_ERROR
/*!
 * \brief Abort the process if a cuBLAS call does not return CUBLAS_STATUS_SUCCESS.
 *
 * \param expr An expression yielding a cublasStatus_t (typically a cuBLAS call).
 *
 * `expr` is evaluated exactly once, so the wrapped cuBLAS call is never
 * re-issued while the status is being inspected. On failure the status
 * enumerator name (or "unknown cuBLAS status") and its numeric code are
 * printed to stderr, then the process exits with EXIT_FAILURE.
 * The do/while(0) wrapper makes the macro a single statement, so it is safe
 * to use in an unbraced if/else.
 */
#define CHECK_CUBLAS_ERROR(expr)                                              \
  do {                                                                        \
    const cublasStatus_t cublas_check_status_ = (expr);                       \
    if (cublas_check_status_ != CUBLAS_STATUS_SUCCESS) {                      \
      const char* cublas_check_name_ = "unknown cuBLAS status";               \
      switch (cublas_check_status_) {                                         \
        case CUBLAS_STATUS_NOT_INITIALIZED:                                   \
          cublas_check_name_ = "CUBLAS_STATUS_NOT_INITIALIZED"; break;        \
        case CUBLAS_STATUS_ALLOC_FAILED:                                      \
          cublas_check_name_ = "CUBLAS_STATUS_ALLOC_FAILED"; break;           \
        case CUBLAS_STATUS_INVALID_VALUE:                                     \
          cublas_check_name_ = "CUBLAS_STATUS_INVALID_VALUE"; break;          \
        case CUBLAS_STATUS_ARCH_MISMATCH:                                     \
          cublas_check_name_ = "CUBLAS_STATUS_ARCH_MISMATCH"; break;          \
        case CUBLAS_STATUS_MAPPING_ERROR:                                     \
          cublas_check_name_ = "CUBLAS_STATUS_MAPPING_ERROR"; break;          \
        case CUBLAS_STATUS_EXECUTION_FAILED:                                  \
          cublas_check_name_ = "CUBLAS_STATUS_EXECUTION_FAILED"; break;       \
        case CUBLAS_STATUS_INTERNAL_ERROR:                                    \
          cublas_check_name_ = "CUBLAS_STATUS_INTERNAL_ERROR"; break;         \
        case CUBLAS_STATUS_NOT_SUPPORTED:                                     \
          cublas_check_name_ = "CUBLAS_STATUS_NOT_SUPPORTED"; break;          \
        case CUBLAS_STATUS_LICENSE_ERROR:                                     \
          cublas_check_name_ = "CUBLAS_STATUS_LICENSE_ERROR"; break;          \
        default:                                                              \
          break;                                                              \
      }                                                                       \
      fprintf(stderr, "cuBLAS error: %s (code %d)\n", cublas_check_name_,     \
              static_cast<int>(cublas_check_status_));                        \
      exit(EXIT_FAILURE);                                                     \
    }                                                                         \
  } while (0)
#endif

// Row-major single-precision matrix multiplication via cuBLAS SGEMM.
//
// Packed arguments:
//   args[0] A, args[1] B : 2-D float32 input tensors (compact, no strides).
//   args[2] C            : 2-D float32 output tensor (compact), overwritten.
//   args[3] transa       : if true, A^T is used instead of A.
//   args[4] transb       : if true, B^T is used instead of B.
// Computes C = op(A) * op(B) with op(A): M x K, op(B): K x N, C: M x N.
//
// cuBLAS expects column-major storage. A row-major buffer read as column-major
// is the transpose of the matrix, so the row-major product is obtained by
// asking cuBLAS for C^T = op(B)^T * op(A)^T: B is passed as cuBLAS's first
// operand and A as its second, and every leading dimension is the row length
// (shape[1]) of the corresponding row-major tensor.
TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    DLTensor* A = args[0];
    DLTensor* B = args[1];
    DLTensor* C = args[2];
    bool transa = args[3];
    bool transb = args[4];
    // Only compact 2-D float32 tensors are supported.
    CHECK_EQ(A->ndim, 2);
    CHECK_EQ(B->ndim, 2);
    CHECK_EQ(C->ndim, 2);
    CHECK(C->strides == nullptr);
    CHECK(B->strides == nullptr);
    CHECK(A->strides == nullptr);
    CHECK(TypeMatch(A->dtype, kDLFloat, 32));
    CHECK(TypeMatch(B->dtype, kDLFloat, 32));
    CHECK(TypeMatch(C->dtype, kDLFloat, 32));

    // Logical row-major extents of op(A) (M x K) and op(B) (K x N).
    const int64_t M = transa ? A->shape[1] : A->shape[0];
    const int64_t K = transa ? A->shape[0] : A->shape[1];
    const int64_t K_b = transb ? B->shape[1] : B->shape[0];
    const int64_t N = transb ? B->shape[0] : B->shape[1];
    // Reject inconsistent shapes before any cuBLAS work; otherwise SGEMM
    // would index device memory out of bounds.
    CHECK_EQ(K, K_b) << "cublas.matmul: inner dimensions of op(A) and op(B) differ";
    CHECK_EQ(C->shape[0], M) << "cublas.matmul: C has the wrong number of rows";
    CHECK_EQ(C->shape[1], N) << "cublas.matmul: C has the wrong number of columns";

    const float* a_data =
        reinterpret_cast<const float*>(static_cast<char*>(A->data) + A->byte_offset);
    const float* b_data =
        reinterpret_cast<const float*>(static_cast<char*>(B->data) + B->byte_offset);
    float* c_data =
        reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset);

    const float alpha = 1.0f;
    const float beta = 0.0f;

    // NOTE(review): a handle is created and destroyed on every call; the
    // cuBLAS docs recommend minimizing handle creation, so consider caching.
    cublasHandle_t handle;
    CHECK_CUBLAS_ERROR(cublasCreate(&handle));

    // Column-major call computing C^T (N x M) = op(B)^T (N x K) * op(A)^T (K x M).
    CHECK_CUBLAS_ERROR(cublasSgemm(handle,
                                   transb ? CUBLAS_OP_T : CUBLAS_OP_N,
                                   transa ? CUBLAS_OP_T : CUBLAS_OP_N,
                                   N, M, K,
                                   &alpha,
                                   b_data, B->shape[1],
                                   a_data, A->shape[1],
                                   &beta,
                                   c_data, C->shape[1]));

    CHECK_CUBLAS_ERROR(cublasDestroy(handle));
});
}  // namespace contrib
}  // namespace tvm