Commit 72821b20 by Haichen Shen Committed by Tianqi Chen

[Contrib] Add MKL DNN option (#4323)

* [Contrib] Add MKL DNN

* update

* update
parent 2573b3b8
......@@ -53,6 +53,7 @@ tvm_option(PICOJSON_PATH "Path to PicoJSON" "3rdparty/picojson")
# Contrib library options
tvm_option(USE_BLAS "The blas library to be linked" none)
tvm_option(USE_MKL_PATH "MKL root path when use MKL blas" none)
tvm_option(USE_MKLDNN "Build with MKLDNN" OFF)
tvm_option(USE_CUDNN "Build with cuDNN" OFF)
tvm_option(USE_CUBLAS "Build with cuBLAS" OFF)
tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF)
......
......@@ -115,6 +115,9 @@ set(USE_BLAS none)
# set(USE_MKL_PATH <path to venv or site-packages directory>) if using `pip install mkl`
set(USE_MKL_PATH none)
# Whether use MKLDNN library
set(USE_MKLDNN OFF)
# Whether use OpenMP thread pool, choices: gnu, intel
# Note: "gnu" uses gomp library, "intel" uses iomp5 library
set(USE_OPENMP none)
......
......@@ -55,3 +55,10 @@ elseif(USE_BLAS STREQUAL "none")
else()
message(FATAL_ERROR "Invalid option: USE_BLAS=" ${USE_BLAS})
endif()
if(USE_MKLDNN STREQUAL "ON")
find_library(BLAS_LIBRARY_MKLDNN dnnl)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY_MKLDNN})
add_definitions(-DUSE_DNNL=1)
message(STATUS "Use MKLDNN library " ${BLAS_LIBRARY_MKLDNN})
endif()
......@@ -31,6 +31,9 @@ extern "C" {
#else
#include <cblas.h>
#endif
#if USE_DNNL == 1
#include <dnnl.h>
#endif
}
namespace tvm {
......@@ -40,12 +43,19 @@ using namespace runtime;
inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { return trans ? CblasTrans : CblasNoTrans; }
inline char BooleanToTransposeChar(bool trans) { return trans ? 'T' : 'N'; }
struct CblasSgemmOp {
typedef float TDatatype;
void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B,
int ldb, float beta, float* C, int ldc) {
#if USE_DNNL == 1
dnnl_sgemm(BooleanToTransposeChar(tb), BooleanToTransposeChar(ta), N, M, K, alpha, B,
ldb, A, lda, beta, C, ldc);
#else
cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
#endif
}
};
......
......@@ -32,7 +32,7 @@ def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
if "cblas" in target.libs:
C = cblas.matmul(data, weight, False, True)
if bias is not None:
C = tvm.compute(C.shape, lambda i, j: C[i, j] + bias[j].astype(out_dtype),
C = tvm.compute(C.shape, lambda i, j: C[i, j] + bias[j],
tag=tag.BROADCAST)
return C
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment