/*!
 *  Copyright (c) 2017 by Contributors
 * \file
 * \brief Matrix multiplication backed by an external CBLAS library.
 */
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dmlc/logging.h>
#include "gemm_common.h"


extern "C" {
#if USE_MKL_BLAS == 1
#include <mkl_cblas.h>
#else
#include <cblas.h>
#endif
}

namespace tvm {
namespace contrib {

using namespace runtime;

/*!
 * \brief Translate a boolean "use transposed" flag into the CBLAS enum.
 * \param trans Whether the operand should be read transposed.
 * \return CblasTrans if trans is set, CblasNoTrans otherwise.
 */
inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) {
  if (!trans) {
    return CblasNoTrans;
  }
  return CblasTrans;
}

struct CblasSgemmOp {
  typedef float TDatatype;
  void operator()(bool ta, bool tb,
                  int M, int N, int K,
                  float alpha, float* A, int lda,
                  float* B, int ldb,
                  float beta, float* C, int ldc) {
    cblas_sgemm(CblasColMajor,
                BooleanToTranspose(ta),
                BooleanToTranspose(tb),
                M, N, K,
                alpha, A, lda,
                B, ldb,
                beta, C, ldc);
  }
};

struct CblasDgemmOp {
  typedef double TDatatype;
  void operator()(bool ta, bool tb,
                  int M, int N, int K,
                  double alpha, double* A, int lda,
                  double* B, int ldb,
                  double beta, double* C, int ldc) {
    cblas_dgemm(CblasColMajor,
                BooleanToTranspose(ta),
                BooleanToTranspose(tb),
                M, N, K,
                alpha, A, lda,
                B, ldb,
                beta, C, ldc);
  }
};


// Packed function "tvm.contrib.cblas.matmul": matrix multiplication for
// row major tensors. It dispatches on the dtype of the first argument (A):
// float32 goes through CblasSgemmOp and float64 through CblasDgemmOp.
// NOTE(review): the remaining argument layout and the row-major ->
// column-major translation (the ops call CBLAS with CblasColMajor) are
// presumably handled inside CallGemm (gemm_common.h); confirm there.
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    DLTensor* A = args[0];
    // Only float32/float64 have a BLAS kernel wired up below. Report what was
    // actually received so an unsupported dtype fails with a usable message.
    CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
          TypeMatch(A->dtype, kDLFloat, 64))
        << "tvm.contrib.cblas.matmul only supports float32 and float64 inputs,"
        << " got dtype code=" << static_cast<int>(A->dtype.code)
        << " bits=" << static_cast<int>(A->dtype.bits);

    if (TypeMatch(A->dtype, kDLFloat, 32)) {
      CallGemm(args, ret, CblasSgemmOp());
    } else {
      CallGemm(args, ret, CblasDgemmOp());
    }
  });
}  // namespace contrib
}  // namespace tvm