/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file cblas.cc
 * \brief Use external cblas library call.
 */
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dmlc/logging.h>
#include "gemm_common.h"


extern "C" {
#if USE_MKL_BLAS == 1
#include <mkl_cblas.h>
#else
#include <cblas.h>
#endif
}

namespace tvm {
namespace contrib {

using namespace runtime;

/*!
 * \brief Translate a boolean "transpose this operand" flag into the CBLAS enum.
 * \param trans true when the operand should be transposed.
 * \return CblasTrans if trans is set, CblasNoTrans otherwise.
 */
inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) {
  if (!trans) {
    return CblasNoTrans;
  }
  return CblasTrans;
}

struct CblasSgemmOp {
  typedef float TDatatype;
  void operator()(bool ta, bool tb,
                  int M, int N, int K,
                  float alpha, float* A, int lda,
                  float* B, int ldb,
                  float beta, float* C, int ldc) {
    cblas_sgemm(CblasColMajor,
                BooleanToTranspose(ta),
                BooleanToTranspose(tb),
                M, N, K,
                alpha, A, lda,
                B, ldb,
                beta, C, ldc);
  }
};

struct CblasDgemmOp {
  typedef double TDatatype;
  void operator()(bool ta, bool tb,
                  int M, int N, int K,
                  double alpha, double* A, int lda,
                  double* B, int ldb,
                  double beta, double* C, int ldc) {
    cblas_dgemm(CblasColMajor,
                BooleanToTranspose(ta),
                BooleanToTranspose(tb),
                M, N, K,
                alpha, A, lda,
                B, ldb,
                beta, C, ldc);
  }
};


// Matrix multiplication for row-major tensors, exposed as the packed
// function "tvm.contrib.cblas.matmul".
// Dispatches on the dtype of the first argument: float32 -> cblas_sgemm,
// float64 -> cblas_dgemm; any other dtype fails the CHECK below.
// NOTE(review): the GEMM ops call CBLAS with CblasColMajor; CallGemm
// (gemm_common.h) presumably maps row-major TVM tensors onto that
// column-major call — confirm against gemm_common.h.
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    DLTensor* A = args[0];
    CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
          TypeMatch(A->dtype, kDLFloat, 64))
        << "tvm.contrib.cblas.matmul only supports float32 and float64 inputs";

    if (TypeMatch(A->dtype, kDLFloat, 32))
      CallGemm(args, ret, CblasSgemmOp());
    else
      CallGemm(args, ret, CblasDgemmOp());
  });
}  // namespace contrib
}  // namespace tvm