/*!
 *  Copyright (c) 2017 by Contributors
 * \file cblas.cc
 * \brief Use external cblas library call.
 */
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dmlc/logging.h>

extern "C" {
#include <cblas.h>
}

namespace tvm {
namespace contrib {

using namespace runtime;

/*!
 * \brief Row-major single-precision matrix multiplication via cblas_sgemm.
 *
 * Computes C = op(A) * op(B), where op(X) is X or X^T depending on the
 * corresponding transpose flag.
 *
 * Packed arguments:
 *   args[0] A      : 2-D float32 tensor, compact (strides == nullptr).
 *   args[1] B      : 2-D float32 tensor, compact (strides == nullptr).
 *   args[2] C      : 2-D float32 output tensor, compact (strides == nullptr).
 *   args[3] transa : use A^T instead of A.
 *   args[4] transb : use B^T instead of B.
 *
 * The gemm is issued in column-major order. A row-major buffer viewed as
 * column-major is the transpose of the matrix, so the row-major product
 * C = op(A) * op(B) is obtained by computing C^T = op(B)^T * op(A)^T:
 * B is passed first, A second, and each operand's row length (shape[1])
 * is used as its leading dimension.
 */
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    DLTensor* A = args[0];
    DLTensor* B = args[1];
    DLTensor* C = args[2];
    bool transa = args[3];
    bool transb = args[4];
    // call gemm for simple compact code.
    CHECK_EQ(A->ndim, 2);
    CHECK_EQ(B->ndim, 2);
    CHECK_EQ(C->ndim, 2);
    CHECK(C->strides == nullptr);
    CHECK(B->strides == nullptr);
    CHECK(A->strides == nullptr);
    CHECK(TypeMatch(A->dtype, kFloat, 32));
    CHECK(TypeMatch(B->dtype, kFloat, 32));
    CHECK(TypeMatch(C->dtype, kFloat, 32));

    // Row-major problem dimensions: op(A) is m x k, op(B) is k x n, C is m x n.
    const int64_t m = transa ? A->shape[1] : A->shape[0];
    const int64_t k = transa ? A->shape[0] : A->shape[1];
    const int64_t n = transb ? B->shape[0] : B->shape[1];
    // The inner dimension of op(B) depends on transb (not transa); the two
    // operands must agree on it, and C must match the product's shape,
    // otherwise the gemm would access memory outside the buffers.
    const int64_t k_b = transb ? B->shape[1] : B->shape[0];
    CHECK_EQ(k_b, k) << "cblas matmul: inner dimensions of op(A) and op(B) differ";
    CHECK_EQ(C->shape[0], m) << "cblas matmul: C has wrong number of rows";
    CHECK_EQ(C->shape[1], n) << "cblas matmul: C has wrong number of columns";

    float* a_data = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
    float* b_data = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
    float* c_data = reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset);

    // Column-major call computing C^T (n x m) = op(B)^T (n x k) * op(A)^T (k x m).
    cblas_sgemm(CblasColMajor,
                transb ? CblasTrans : CblasNoTrans,
                transa ? CblasTrans : CblasNoTrans,
                static_cast<int>(n),
                static_cast<int>(m),
                static_cast<int>(k),
                1.0f,
                b_data,
                static_cast<int>(B->shape[1]),
                a_data,
                static_cast<int>(A->shape[1]),
                0.0f,
                c_data,
                static_cast<int>(C->shape[1]));
  });
}  // namespace contrib
}  // namespace tvm