Commit 3d1d17e3 by hlu1 Committed by Leyuan Wang

[Contrib] cblas batch_matmul (#3210)

parent 21935dcb
......@@ -27,7 +27,11 @@ elseif(USE_BLAS STREQUAL "mkl")
if(NOT IS_DIRECTORY ${USE_MKL_PATH})
set(USE_MKL_PATH /opt/intel/mkl)
endif()
find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
if(APPLE)
find_library(BLAS_LIBRARY NAMES mklml HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
elseif(UNIX)
find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
endif()
include_directories(${USE_MKL_PATH}/include)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY})
list(APPEND RUNTIME_SRCS ${CBLAS_CONTRIB_SRC})
......
......@@ -17,10 +17,10 @@
"""External function interface to BLAS libraries."""
from __future__ import absolute_import as _abs
from .. import api as _api
from .. import intrin as _intrin
from .. import api as _api, intrin as _intrin
def matmul(lhs, rhs, transa=False, transb=False):
def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
"""Create an extern op that compute matrix mult of A and rhs with CrhsLAS
This function serves as an example on how to call external libraries.
......@@ -44,7 +44,50 @@ def matmul(lhs, rhs, transa=False, transb=False):
n = lhs.shape[1] if transa else lhs.shape[0]
m = rhs.shape[0] if transb else rhs.shape[1]
return _api.extern(
(n, m), [lhs, rhs],
(n, m),
[lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], transa, transb
),
name="C",
**kwargs
)
def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs):
"""Create an extern op that compute batched matrix mult of A and rhs with CBLAS
This function serves as an example on how to call external libraries.
Parameters
----------
lhs : Tensor
The left matrix operand
rhs : Tensor
The right matrix operand
transa : bool
Whether transpose lhs
transb : bool
Whether transpose rhs
Returns
-------
C : Tensor
The result tensor.
"""
b = lhs.shape[0]
n = lhs.shape[2] if transa else lhs.shape[1]
m = rhs.shape[1] if transb else rhs.shape[2]
return _api.extern(
(b, n, m),
[lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.cblas.matmul",
ins[0], ins[1], outs[0], transa, transb), name="C")
"tvm.contrib.cblas.batch_matmul"
if not iterative
else "tvm.contrib.cblas.batch_matmul_iterative",
ins[0],
ins[1],
outs[0],
transa,
transb,
),
name="C",
**kwargs
)
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -21,12 +21,11 @@
* Copyright (c) 2017 by Contributors
* \file Use external cblas library call.
*/
#include <dmlc/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dmlc/logging.h>
#include "gemm_common.h"
extern "C" {
#if USE_MKL_BLAS == 1
#include <mkl_cblas.h>
......@@ -40,56 +39,148 @@ namespace contrib {
using namespace runtime;
inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) {
return trans ? CblasTrans : CblasNoTrans;
}
inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { return trans ? CblasTrans : CblasNoTrans; }
struct CblasSgemmOp {
typedef float TDatatype;
void operator()(bool ta, bool tb,
int M, int N, int K,
float alpha, float* A, int lda,
float* B, int ldb,
float beta, float* C, int ldc) {
cblas_sgemm(CblasColMajor,
BooleanToTranspose(ta),
BooleanToTranspose(tb),
M, N, K,
alpha, A, lda,
B, ldb,
beta, C, ldc);
void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B,
int ldb, float beta, float* C, int ldc) {
cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
}
};
struct CblasDgemmOp {
typedef double TDatatype;
void operator()(bool ta, bool tb,
int M, int N, int K,
double alpha, double* A, int lda,
double* B, int ldb,
double beta, double* C, int ldc) {
cblas_dgemm(CblasColMajor,
BooleanToTranspose(ta),
BooleanToTranspose(tb),
M, N, K,
alpha, A, lda,
B, ldb,
beta, C, ldc);
void operator()(bool ta, bool tb, int M, int N, int K, double alpha, double* A, int lda,
double* B, int ldb, double beta, double* C, int ldc) {
cblas_dgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc);
}
};
struct CblasSgemmBatchOp {
typedef float TDatatype;
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
int c_stride, int ldc) {
CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
#if USE_MKL_BLAS == 1
std::vector<const float*> A_array(batch_size);
std::vector<const float*> B_array(batch_size);
std::vector<float*> C_array(batch_size);
for (int i = 0; i < batch_size; ++i) {
A_array[i] = A + i * a_stride;
B_array[i] = B + i * b_stride;
C_array[i] = C + i * c_stride;
}
cblas_sgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda,
B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size);
#else
for (int i = 0; i < batch_size; ++i) {
cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
A += a_stride;
B += b_stride;
C += c_stride;
}
#endif
}
};
struct CblasSgemmBatchIterativeOp {
typedef float TDatatype;
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
int c_stride, int ldc) {
CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
for (int i = 0; i < batch_size; ++i) {
cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
A += a_stride;
B += b_stride;
C += c_stride;
}
}
};
struct CblasDgemmBatchOp {
typedef double TDatatype;
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
int c_stride, int ldc) {
CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
#if USE_MKL_BLAS == 1
std::vector<const double*> A_array(batch_size);
std::vector<const double*> B_array(batch_size);
std::vector<double*> C_array(batch_size);
for (int i = 0; i < batch_size; ++i) {
A_array[i] = A + i * a_stride;
B_array[i] = B + i * b_stride;
C_array[i] = C + i * c_stride;
}
cblas_dgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda,
B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size);
#else
for (int i = 0; i < batch_size; ++i) {
cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
A += a_stride;
B += b_stride;
C += c_stride;
}
#endif
}
};
struct CblasDgemmBatchIterativeOp {
typedef double TDatatype;
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
int c_stride, int ldc) {
CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
for (int i = 0; i < batch_size; ++i) {
cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
A += a_stride;
B += b_stride;
C += c_stride;
}
}
};
// matrix multiplication for row major
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor* A = args[0];
CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
TypeMatch(A->dtype, kDLFloat, 64));
.set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* A = args[0];
CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
if (TypeMatch(A->dtype, kDLFloat, 32))
CallGemm(args, ret, CblasSgemmOp());
else
CallGemm(args, ret, CblasDgemmOp());
});
if (TypeMatch(A->dtype, kDLFloat, 32))
CallGemm(args, ret, CblasSgemmOp());
else
CallGemm(args, ret, CblasDgemmOp());
});
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul")
.set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* A = args[0];
CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
if (TypeMatch(A->dtype, kDLFloat, 32)) {
CallBatchGemm(args, ret, CblasSgemmBatchOp());
} else {
CallBatchGemm(args, ret, CblasDgemmBatchOp());
}
});
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative")
.set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* A = args[0];
CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
if (TypeMatch(A->dtype, kDLFloat, 32)) {
CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp());
} else {
CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp());
}
});
} // namespace contrib
} // namespace tvm
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -22,16 +22,17 @@
* \file tvm/contrib/gemm.h
* \brief Shared implementation of gemm
*/
#ifndef TVM_CONTRIB_CBLAS_GEMM_COMMON_H_
#define TVM_CONTRIB_CBLAS_GEMM_COMMON_H_
#pragma once
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <algorithm>
namespace tvm {
namespace contrib {
using namespace runtime;
inline int ColumnStride(DLTensor* tensor) {
inline int ColumnStride(DLTensor *tensor) {
// If the tensor itself is transposed then it will have strides
// backward from what we expect. Regardless, the max of the strides
// (the other stride is 1) is the column stride.
......@@ -42,8 +43,7 @@ inline int ColumnStride(DLTensor* tensor) {
}
}
inline int ElementStride(DLTensor* tensor) {
inline int ElementStride(DLTensor *tensor) {
if (tensor->strides) {
return std::min(tensor->strides[0], tensor->strides[1]);
} else {
......@@ -51,29 +51,26 @@ inline int ElementStride(DLTensor* tensor) {
}
}
// Reversed strides indicates an in-place transpose operation.
inline bool IsInPlaceTransposed(DLTensor* tensor) {
inline bool IsInPlaceTransposed(DLTensor *tensor) {
return tensor->strides && (tensor->strides[1] > tensor->strides[0]);
}
inline int RowCount(DLTensor* tensor, bool trans) {
inline int RowCount(DLTensor *tensor, bool trans) {
return tensor->shape[trans ? 1 : 0];
}
inline int ColumnCount(DLTensor* tensor, bool trans) {
inline int ColumnCount(DLTensor *tensor, bool trans) {
return tensor->shape[trans ? 0 : 1];
}
// Call a column major blas. Note that data is stored in tvm as row
// major, so this we switch the arguments.
template<typename TGemmOp>
template <typename TGemmOp>
inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) {
DLTensor* A = args[0];
DLTensor* B = args[1];
DLTensor* C = args[2];
DLTensor *A = args[0];
DLTensor *B = args[1];
DLTensor *C = args[2];
bool transa = args[3];
bool transb = args[4];
int bit_depth = sizeof(typename TGemmOp::TDatatype) * 8;
......@@ -96,25 +93,88 @@ inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) {
CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth));
double alpha = args.size() > 5 ? args[5] : 1.0;
double beta = args.size() > 6 ? args[6] : 0.0;
op(transb,
transa,
ColumnCount(B, transb),
RowCount(A, transa),
ColumnCount(A, transa),
static_cast<float>(alpha),
reinterpret_cast<typename TGemmOp::TDatatype*>(static_cast<char*>(B->data)
+ B->byte_offset),
op(transb, transa, ColumnCount(B, transb), RowCount(A, transa),
ColumnCount(A, transa), static_cast<float>(alpha),
reinterpret_cast<typename TGemmOp::TDatatype *>(
static_cast<char *>(B->data) + B->byte_offset),
ColumnStride(B),
reinterpret_cast<typename TGemmOp::TDatatype*>(static_cast<char*>(A->data)
+ A->byte_offset),
ColumnStride(A),
static_cast<float>(beta),
reinterpret_cast<typename TGemmOp::TDatatype*>(static_cast<char*>(C->data)
+ C->byte_offset),
reinterpret_cast<typename TGemmOp::TDatatype *>(
static_cast<char *>(A->data) + A->byte_offset),
ColumnStride(A), static_cast<float>(beta),
reinterpret_cast<typename TGemmOp::TDatatype *>(
static_cast<char *>(C->data) + C->byte_offset),
ColumnStride(C));
}
inline int ColumnStride3D(DLTensor *tensor) {
// If the tensor itself is transposed then it will have strides
// backward from what we expect. Regardless, the max of the strides
// (the other stride is 1) is the column stride.
if (tensor->strides) {
return std::max(tensor->strides[1], tensor->strides[2]);
} else {
return tensor->shape[2];
}
}
inline int ElementStride3D(DLTensor *tensor) {
if (tensor->strides) {
return std::min(tensor->strides[1], tensor->strides[2]);
} else {
return 1;
}
}
// Reversed strides indicates an in-place transpose operation.
inline bool IsInPlaceTransposed3D(DLTensor *tensor) {
return tensor->strides && (tensor->strides[2] > tensor->strides[1]);
}
inline int BatchCount3D(DLTensor *tensor) { return tensor->shape[0]; }
inline int RowCount3D(DLTensor *tensor, bool trans) {
return tensor->shape[trans ? 2 : 1];
}
inline int ColumnCount3D(DLTensor *tensor, bool trans) {
return tensor->shape[trans ? 1 : 2];
}
template <typename TBatchGemmOp>
inline void CallBatchGemm(TVMArgs args, TVMRetValue *ret, TBatchGemmOp op) {
using DType = typename TBatchGemmOp::TDatatype;
DLTensor *A = args[0];
DLTensor *B = args[1];
DLTensor *C = args[2];
bool transa = args[3];
bool transb = args[4];
int bit_depth = sizeof(DType) * 8;
CHECK_EQ(A->ndim, 3);
CHECK_EQ(B->ndim, 3);
CHECK_EQ(C->ndim, 3);
int batch_size = BatchCount3D(A);
CHECK_EQ(BatchCount3D(B), batch_size);
CHECK_EQ(BatchCount3D(C), batch_size);
CHECK_EQ(ElementStride(A), 1);
CHECK_EQ(ElementStride(B), 1);
CHECK_EQ(ElementStride(C), 1);
// C can never be transposed.
CHECK(!IsInPlaceTransposed3D(C));
// Reversed strides indicates an in-place transpose operation.
transa = IsInPlaceTransposed3D(A) ? !transa : transa;
transb = IsInPlaceTransposed3D(B) ? !transb : transb;
CHECK(TypeMatch(B->dtype, kDLFloat, bit_depth));
CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth));
double alpha = args.size() > 5 ? args[5] : 1.0;
double beta = args.size() > 6 ? args[6] : 0.0;
const int A_size = A->shape[1] * A->shape[2];
const int B_size = B->shape[1] * B->shape[2];
const int C_size = C->shape[1] * C->shape[2];
DType *A_data = reinterpret_cast<typename TBatchGemmOp::TDatatype *>(
static_cast<char *>(A->data) + A->byte_offset);
DType *B_data = reinterpret_cast<typename TBatchGemmOp::TDatatype *>(
static_cast<char *>(B->data) + B->byte_offset);
DType *C_data = reinterpret_cast<typename TBatchGemmOp::TDatatype *>(
static_cast<char *>(C->data) + C->byte_offset);
op(batch_size, transb, transa, ColumnCount3D(B, transb),
RowCount3D(A, transa), ColumnCount3D(A, transa), static_cast<float>(alpha),
B_data, B_size, ColumnStride3D(B), A_data, A_size, ColumnStride3D(A),
static_cast<float>(beta), C_data, C_size, ColumnStride3D(C));
}
} // namespace contrib
} // namespace tvm
#endif // TVM_CONTRIB_CBLAS_GEMM_COMMON_H_
......@@ -16,19 +16,26 @@
# under the License.
import tvm
import numpy as np
import topi.testing
from tvm.contrib import cblas
def test_matmul_add():
n = 1024
l = 128
m = 235
bias = tvm.var('bias', dtype=tvm.float32)
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((l, m), name='B')
C = cblas.matmul(A, B)
def verify_matmul_add(m, l, n, transa=False, transb=False, dtype=tvm.float32):
bias = tvm.var('bias', dtype=dtype)
ashape = (l, n) if transa else (n, l)
bshape = (m, l) if transb else (l, m)
A = tvm.placeholder(ashape, name='A', dtype=dtype)
B = tvm.placeholder(bshape, name='B', dtype=dtype)
C = cblas.matmul(A, B, transa, transb)
D = tvm.compute(C.shape, lambda i, j: C[i,j] + bias, name="D")
s = tvm.create_schedule(D.op)
def get_numpy(a, b, bb, transa, transb):
if transa:
a = a.transpose()
if transb:
b = b.transpose()
return np.dot(a, b) + bb
def verify(target="llvm"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
......@@ -38,15 +45,69 @@ def test_matmul_add():
return
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, D, bias], target)
a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
bb = 10.0
f(a, b, d, bb)
tvm.testing.assert_allclose(
d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + bb, rtol=1e-5)
d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), bb, transa, transb), rtol=1e-5)
verify()
def test_matmul_add():
verify_matmul_add(235, 128, 1024)
verify_matmul_add(235, 128, 1024, True, False)
verify_matmul_add(235, 128, 1024, False, True)
verify_matmul_add(235, 128, 1024, True, True)
verify_matmul_add(1, 16, 4)
verify_matmul_add(1, 16, 3, True, False)
verify_matmul_add(1, 16, 3, False, False)
verify_matmul_add(1, 16, 3, True, True)
def verify_batch_matmul(batch, m, l, n, transa=False, transb=False, iterative=False, dtype=tvm.float32):
ashape = (batch, l, n) if transa else (batch, n, l)
bshape = (batch, m, l) if transb else (batch, l, m)
A = tvm.placeholder(ashape, name='A', dtype=dtype)
B = tvm.placeholder(bshape, name='B', dtype=dtype)
C = cblas.batch_matmul(A, B, transa, transb)
D = tvm.compute(C.shape, lambda k, i, j: C[k, i,j], name="D")
s = tvm.create_schedule(D.op)
def get_numpy(a, b, transa, transb):
if transa:
a = a.transpose(0, 2, 1)
if not transb:
b = b.transpose(0, 2, 1)
return topi.testing.batch_matmul(a, b)
def verify(target="llvm"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.cblas.matmul", True):
print("skip because extern function is not available")
return
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, D], target)
a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros((batch, n, m), dtype=D.dtype), ctx)
f(a, b, d)
tvm.testing.assert_allclose(
d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), transa, transb), rtol=1e-5)
verify()
def test_batch_matmul():
verify_batch_matmul(16, 235, 128, 1024)
verify_batch_matmul(16, 235, 128, 1024, True, False)
verify_batch_matmul(16, 235, 128, 1024, False, True)
verify_batch_matmul(16, 235, 128, 1024, True, True)
verify_batch_matmul(1, 1, 16, 3)
verify_batch_matmul(1, 1, 16, 3, True, False)
verify_batch_matmul(1, 1, 16, 3, False, False)
verify_batch_matmul(1, 1, 16, 3, True, True)
verify_batch_matmul(1, 1, 16, 3, iterative=True)
if __name__ == "__main__":
test_matmul_add()
test_batch_matmul()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment