Commit 3d1d17e3 by hlu1, committed by Leyuan Wang

[Contrib] cblas batch_matmul (#3210)

parent 21935dcb
...@@ -27,7 +27,11 @@ elseif(USE_BLAS STREQUAL "mkl")
  if(NOT IS_DIRECTORY ${USE_MKL_PATH})
    set(USE_MKL_PATH /opt/intel/mkl)
  endif()
  if(APPLE)
    find_library(BLAS_LIBRARY NAMES mklml HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
  elseif(UNIX)
    find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
  endif()
  include_directories(${USE_MKL_PATH}/include)
  list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY})
  list(APPEND RUNTIME_SRCS ${CBLAS_CONTRIB_SRC})
......
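The CMake change only decides which MKL library gets linked; whether the cblas extern functions were actually compiled into the runtime can be checked from Python with the same allow-missing lookup the tests below use. A minimal sketch (the printed labels are illustrative; the function names are the ones registered in cblas.cc):

import tvm

# Returns None instead of raising when the packed function was not compiled in,
# e.g. when TVM was built with USE_BLAS=none.
matmul_fn = tvm.get_global_func("tvm.contrib.cblas.matmul", True)
batch_matmul_fn = tvm.get_global_func("tvm.contrib.cblas.batch_matmul", True)

print("cblas matmul available:", matmul_fn is not None)
print("cblas batch_matmul available:", batch_matmul_fn is not None)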
...@@ -17,10 +17,10 @@
"""External function interface to BLAS libraries."""
from __future__ import absolute_import as _abs

from .. import api as _api, intrin as _intrin


def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
    """Create an extern op that computes matrix mult of lhs and rhs with CBLAS
    This function serves as an example on how to call external libraries.
...@@ -44,7 +44,50 @@ def matmul(lhs, rhs, transa=False, transb=False):
    n = lhs.shape[1] if transa else lhs.shape[0]
    m = rhs.shape[0] if transb else rhs.shape[1]
    return _api.extern(
        (n, m),
        [lhs, rhs],
        lambda ins, outs: _intrin.call_packed(
            "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], transa, transb
        ),
        name="C",
        **kwargs
    )


def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs):
    """Create an extern op that computes batched matrix mult of lhs and rhs with CBLAS
    This function serves as an example on how to call external libraries.

    Parameters
    ----------
    lhs : Tensor
        The left matrix operand
    rhs : Tensor
        The right matrix operand
    transa : bool
        Whether to transpose lhs
    transb : bool
        Whether to transpose rhs

    Returns
    -------
    C : Tensor
        The result tensor.
    """
    b = lhs.shape[0]
    n = lhs.shape[2] if transa else lhs.shape[1]
    m = rhs.shape[1] if transb else rhs.shape[2]
    return _api.extern(
        (b, n, m),
        [lhs, rhs],
        lambda ins, outs: _intrin.call_packed(
            "tvm.contrib.cblas.batch_matmul"
            if not iterative
            else "tvm.contrib.cblas.batch_matmul_iterative",
            ins[0],
            ins[1],
            outs[0],
            transa,
            transb,
        ),
        name="C",
        **kwargs
    )
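A usage sketch for the new wrapper (shapes and variable names are illustrative, not part of the commit): batch_matmul takes two 3-D placeholders with the batch axis first and returns a (batch, n, m) extern tensor that schedules and builds like any other op; passing iterative=True routes to the loop-based kernel instead of the grouped one.

import numpy as np
import tvm
from tvm.contrib import cblas

batch, n, l, m = 4, 32, 16, 24
A = tvm.placeholder((batch, n, l), name="A", dtype="float32")
B = tvm.placeholder((batch, l, m), name="B", dtype="float32")
# iterative=True would dispatch to tvm.contrib.cblas.batch_matmul_iterative
C = cblas.batch_matmul(A, B, transa=False, transb=False)

s = tvm.create_schedule(C.op)
f = tvm.build(s, [A, B, C], "llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(batch, n, l)).astype("float32"), ctx)
b = tvm.nd.array(np.random.uniform(size=(batch, l, m)).astype("float32"), ctx)
c = tvm.nd.array(np.zeros((batch, n, m), dtype="float32"), ctx)
f(a, b, c)
np.testing.assert_allclose(c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy()), rtol=1e-5)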
...@@ -6,9 +6,9 @@
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -21,12 +21,11 @@
 * Copyright (c) 2017 by Contributors
 * \file Use external cblas library call.
 */
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dmlc/logging.h>

#include "gemm_common.h"

extern "C" {
#if USE_MKL_BLAS == 1
#include <mkl_cblas.h>
...@@ -40,56 +39,148 @@ namespace contrib {
using namespace runtime;

inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { return trans ? CblasTrans : CblasNoTrans; }

struct CblasSgemmOp {
  typedef float TDatatype;
  void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B,
                  int ldb, float beta, float* C, int ldc) {
    cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A,
                lda, B, ldb, beta, C, ldc);
  }
};

struct CblasDgemmOp {
  typedef double TDatatype;
  void operator()(bool ta, bool tb, int M, int N, int K, double alpha, double* A, int lda,
                  double* B, int ldb, double beta, double* C, int ldc) {
    cblas_dgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A,
                lda, B, ldb, beta, C, ldc);
  }
};
struct CblasSgemmBatchOp {
  typedef float TDatatype;
  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
                  int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
                  int c_stride, int ldc) {
    CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
    CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
#if USE_MKL_BLAS == 1
    std::vector<const float*> A_array(batch_size);
    std::vector<const float*> B_array(batch_size);
    std::vector<float*> C_array(batch_size);
    for (int i = 0; i < batch_size; ++i) {
      A_array[i] = A + i * a_stride;
      B_array[i] = B + i * b_stride;
      C_array[i] = C + i * c_stride;
    }
    cblas_sgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda,
                      B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size);
#else
    for (int i = 0; i < batch_size; ++i) {
      cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      A += a_stride;
      B += b_stride;
      C += c_stride;
    }
#endif
  }
};

struct CblasSgemmBatchIterativeOp {
  typedef float TDatatype;
  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
                  int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
                  int c_stride, int ldc) {
    CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
    CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
    for (int i = 0; i < batch_size; ++i) {
      cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      A += a_stride;
      B += b_stride;
      C += c_stride;
    }
  }
};

struct CblasDgemmBatchOp {
  typedef double TDatatype;
  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
                  int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
                  int c_stride, int ldc) {
    CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
    CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
#if USE_MKL_BLAS == 1
    std::vector<const double*> A_array(batch_size);
    std::vector<const double*> B_array(batch_size);
    std::vector<double*> C_array(batch_size);
    for (int i = 0; i < batch_size; ++i) {
      A_array[i] = A + i * a_stride;
      B_array[i] = B + i * b_stride;
      C_array[i] = C + i * c_stride;
    }
    cblas_dgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda,
                      B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size);
#else
    for (int i = 0; i < batch_size; ++i) {
      cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      A += a_stride;
      B += b_stride;
      C += c_stride;
    }
#endif
  }
};

struct CblasDgemmBatchIterativeOp {
  typedef double TDatatype;
  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
                  int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
                  int c_stride, int ldc) {
    CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
    CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
    for (int i = 0; i < batch_size; ++i) {
      cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      A += a_stride;
      B += b_stride;
      C += c_stride;
    }
  }
};
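All four batch ops implement the same contract: one GEMM per batch element, with A, B, and C advancing by fixed element strides between batches. Under MKL the whole batch is handed to cblas_sgemm_batch/cblas_dgemm_batch as a single group; otherwise, and always in the *Iterative variants, it is a plain loop of cblas_sgemm/cblas_dgemm calls. A rough numpy sketch of that contract, purely illustrative:

import numpy as np

def batch_gemm_reference(A, B, transa=False, transb=False, alpha=1.0, beta=0.0, C=None):
    # A, B (and optionally C) are 3-D arrays whose first axis is the batch;
    # each batch element is one independent GEMM, which is exactly what the
    # strided loop in the C++ ops computes.
    a = np.swapaxes(A, 1, 2) if transa else A
    b = np.swapaxes(B, 1, 2) if transb else B
    out = np.empty((a.shape[0], a.shape[1], b.shape[2]), dtype=A.dtype)
    for i in range(a.shape[0]):  # one sgemm/dgemm call per iteration
        acc = alpha * np.dot(a[i], b[i])
        out[i] = acc if C is None else acc + beta * C[i]
    return out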
// matrix multiplication for row major
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      DLTensor* A = args[0];
      CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
      if (TypeMatch(A->dtype, kDLFloat, 32))
        CallGemm(args, ret, CblasSgemmOp());
      else
        CallGemm(args, ret, CblasDgemmOp());
    });

TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      DLTensor* A = args[0];
      CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
      if (TypeMatch(A->dtype, kDLFloat, 32)) {
        CallBatchGemm(args, ret, CblasSgemmBatchOp());
      } else {
        CallBatchGemm(args, ret, CblasDgemmBatchOp());
      }
    });

TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      DLTensor* A = args[0];
      CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
      if (TypeMatch(A->dtype, kDLFloat, 32)) {
        CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp());
      } else {
        CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp());
      }
    });
}  // namespace contrib
}  // namespace tvm
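Because the kernels are exposed through the packed-function registry, they can also be invoked directly from Python, which is a convenient way to compare the grouped (batch_matmul) and loop-based (batch_matmul_iterative) paths. A sketch under the assumption that NDArray arguments are accepted where the registered body expects DLTensors; the argument order follows CallBatchGemm below (A, B, C, transa, transb), with alpha/beta left at their defaults:

import numpy as np
import tvm

fbatch = tvm.get_global_func("tvm.contrib.cblas.batch_matmul", True)
if fbatch is not None:
    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.rand(8, 5, 3).astype("float32"), ctx)
    b = tvm.nd.array(np.random.rand(8, 3, 4).astype("float32"), ctx)
    c = tvm.nd.array(np.zeros((8, 5, 4), dtype="float32"), ctx)
    fbatch(a, b, c, False, False)  # alpha/beta default to 1.0/0.0 inside CallBatchGemm
    np.testing.assert_allclose(c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy()), rtol=1e-5)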
...@@ -6,9 +6,9 @@
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -22,16 +22,17 @@
 * \file tvm/contrib/gemm.h
 * \brief Shared implementation of gemm
 */
#pragma once

#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <algorithm>

namespace tvm {
namespace contrib {

using namespace runtime;

inline int ColumnStride(DLTensor *tensor) {
  // If the tensor itself is transposed then it will have strides
  // backward from what we expect. Regardless, the max of the strides
  // (the other stride is 1) is the column stride.
...@@ -42,8 +43,7 @@ inline int ColumnStride(DLTensor* tensor) {
  }
}

inline int ElementStride(DLTensor *tensor) {
  if (tensor->strides) {
    return std::min(tensor->strides[0], tensor->strides[1]);
  } else {
...@@ -51,29 +51,26 @@ inline int ElementStride(DLTensor* tensor) {
  }
}

// Reversed strides indicates an in-place transpose operation.
inline bool IsInPlaceTransposed(DLTensor *tensor) {
  return tensor->strides && (tensor->strides[1] > tensor->strides[0]);
}

inline int RowCount(DLTensor *tensor, bool trans) {
  return tensor->shape[trans ? 1 : 0];
}

inline int ColumnCount(DLTensor *tensor, bool trans) {
  return tensor->shape[trans ? 0 : 1];
}

// Call a column major blas. Note that data is stored in tvm as row
// major, so we switch the arguments.
template <typename TGemmOp>
inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) {
  DLTensor *A = args[0];
  DLTensor *B = args[1];
  DLTensor *C = args[2];
  bool transa = args[3];
  bool transb = args[4];
  int bit_depth = sizeof(typename TGemmOp::TDatatype) * 8;
...@@ -96,25 +93,88 @@ inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) {
  CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth));
  double alpha = args.size() > 5 ? args[5] : 1.0;
  double beta = args.size() > 6 ? args[6] : 0.0;
  op(transb, transa, ColumnCount(B, transb), RowCount(A, transa),
     ColumnCount(A, transa), static_cast<float>(alpha),
     reinterpret_cast<typename TGemmOp::TDatatype *>(
         static_cast<char *>(B->data) + B->byte_offset),
     ColumnStride(B),
     reinterpret_cast<typename TGemmOp::TDatatype *>(
         static_cast<char *>(A->data) + A->byte_offset),
     ColumnStride(A), static_cast<float>(beta),
     reinterpret_cast<typename TGemmOp::TDatatype *>(
         static_cast<char *>(C->data) + C->byte_offset),
     ColumnStride(C));
}
inline int ColumnStride3D(DLTensor *tensor) {
  // If the tensor itself is transposed then it will have strides
  // backward from what we expect. Regardless, the max of the strides
  // (the other stride is 1) is the column stride.
  if (tensor->strides) {
    return std::max(tensor->strides[1], tensor->strides[2]);
  } else {
    return tensor->shape[2];
  }
}

inline int ElementStride3D(DLTensor *tensor) {
  if (tensor->strides) {
    return std::min(tensor->strides[1], tensor->strides[2]);
  } else {
    return 1;
  }
}

// Reversed strides indicates an in-place transpose operation.
inline bool IsInPlaceTransposed3D(DLTensor *tensor) {
  return tensor->strides && (tensor->strides[2] > tensor->strides[1]);
}

inline int BatchCount3D(DLTensor *tensor) { return tensor->shape[0]; }

inline int RowCount3D(DLTensor *tensor, bool trans) {
  return tensor->shape[trans ? 2 : 1];
}

inline int ColumnCount3D(DLTensor *tensor, bool trans) {
  return tensor->shape[trans ? 1 : 2];
}

template <typename TBatchGemmOp>
inline void CallBatchGemm(TVMArgs args, TVMRetValue *ret, TBatchGemmOp op) {
  using DType = typename TBatchGemmOp::TDatatype;
  DLTensor *A = args[0];
  DLTensor *B = args[1];
  DLTensor *C = args[2];
  bool transa = args[3];
  bool transb = args[4];
  int bit_depth = sizeof(DType) * 8;
  CHECK_EQ(A->ndim, 3);
  CHECK_EQ(B->ndim, 3);
  CHECK_EQ(C->ndim, 3);
  int batch_size = BatchCount3D(A);
  CHECK_EQ(BatchCount3D(B), batch_size);
  CHECK_EQ(BatchCount3D(C), batch_size);
  CHECK_EQ(ElementStride(A), 1);
  CHECK_EQ(ElementStride(B), 1);
  CHECK_EQ(ElementStride(C), 1);
  // C can never be transposed.
  CHECK(!IsInPlaceTransposed3D(C));
  // Reversed strides indicates an in-place transpose operation.
  transa = IsInPlaceTransposed3D(A) ? !transa : transa;
  transb = IsInPlaceTransposed3D(B) ? !transb : transb;
  CHECK(TypeMatch(B->dtype, kDLFloat, bit_depth));
  CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth));
  double alpha = args.size() > 5 ? args[5] : 1.0;
  double beta = args.size() > 6 ? args[6] : 0.0;
  const int A_size = A->shape[1] * A->shape[2];
  const int B_size = B->shape[1] * B->shape[2];
  const int C_size = C->shape[1] * C->shape[2];
  DType *A_data = reinterpret_cast<typename TBatchGemmOp::TDatatype *>(
      static_cast<char *>(A->data) + A->byte_offset);
  DType *B_data = reinterpret_cast<typename TBatchGemmOp::TDatatype *>(
      static_cast<char *>(B->data) + B->byte_offset);
  DType *C_data = reinterpret_cast<typename TBatchGemmOp::TDatatype *>(
      static_cast<char *>(C->data) + C->byte_offset);
  op(batch_size, transb, transa, ColumnCount3D(B, transb),
     RowCount3D(A, transa), ColumnCount3D(A, transa), static_cast<float>(alpha),
     B_data, B_size, ColumnStride3D(B), A_data, A_size, ColumnStride3D(A),
     static_cast<float>(beta), C_data, C_size, ColumnStride3D(C));
}
}  // namespace contrib
}  // namespace tvm
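A note on the argument swap in CallGemm and CallBatchGemm above: TVM stores tensors row major while the CBLAS entry points here are called column major, so C = A·B is computed as Cᵀ = Bᵀ·Aᵀ, which is why B is passed before A and the M/N dimensions are taken from B's columns and A's rows. The identity itself, checked with numpy purely for illustration:

import numpy as np

rng = np.random.RandomState(0)
A = rng.rand(5, 3).astype("float32")  # row-major (n, k)
B = rng.rand(3, 4).astype("float32")  # row-major (k, m)

# A row-major matrix reinterpreted as column-major is its transpose, so a
# column-major GEMM fed (B, A) produces C^T in column-major terms, i.e. C
# laid out row major.
C = A @ B
C_via_swap = (B.T @ A.T).T
np.testing.assert_allclose(C, C_via_swap, rtol=1e-6)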
...@@ -16,19 +16,26 @@
# under the License.
import tvm
import numpy as np
import topi.testing
from tvm.contrib import cblas

def verify_matmul_add(m, l, n, transa=False, transb=False, dtype=tvm.float32):
    bias = tvm.var('bias', dtype=dtype)
    ashape = (l, n) if transa else (n, l)
    bshape = (m, l) if transb else (l, m)
    A = tvm.placeholder(ashape, name='A', dtype=dtype)
    B = tvm.placeholder(bshape, name='B', dtype=dtype)
    C = cblas.matmul(A, B, transa, transb)
    D = tvm.compute(C.shape, lambda i, j: C[i,j] + bias, name="D")
    s = tvm.create_schedule(D.op)

    def get_numpy(a, b, bb, transa, transb):
        if transa:
            a = a.transpose()
        if transb:
            b = b.transpose()
        return np.dot(a, b) + bb

    def verify(target="llvm"):
        if not tvm.module.enabled(target):
            print("skip because %s is not enabled..." % target)
...@@ -38,15 +45,69 @@ def verify_matmul_add(m, l, n, transa=False, transb=False, dtype=tvm.float32):
            return
        ctx = tvm.cpu(0)
        f = tvm.build(s, [A, B, D, bias], target)
        a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx)
        d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
        bb = 10.0
        f(a, b, d, bb)
        tvm.testing.assert_allclose(
            d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), bb, transa, transb), rtol=1e-5)
    verify()

def test_matmul_add():
    verify_matmul_add(235, 128, 1024)
    verify_matmul_add(235, 128, 1024, True, False)
    verify_matmul_add(235, 128, 1024, False, True)
    verify_matmul_add(235, 128, 1024, True, True)
    verify_matmul_add(1, 16, 4)
    verify_matmul_add(1, 16, 3, True, False)
    verify_matmul_add(1, 16, 3, False, False)
    verify_matmul_add(1, 16, 3, True, True)

def verify_batch_matmul(batch, m, l, n, transa=False, transb=False, iterative=False, dtype=tvm.float32):
    ashape = (batch, l, n) if transa else (batch, n, l)
    bshape = (batch, m, l) if transb else (batch, l, m)
    A = tvm.placeholder(ashape, name='A', dtype=dtype)
    B = tvm.placeholder(bshape, name='B', dtype=dtype)
    C = cblas.batch_matmul(A, B, transa, transb)
    D = tvm.compute(C.shape, lambda k, i, j: C[k, i,j], name="D")
    s = tvm.create_schedule(D.op)

    def get_numpy(a, b, transa, transb):
        if transa:
            a = a.transpose(0, 2, 1)
        if not transb:
            b = b.transpose(0, 2, 1)
        return topi.testing.batch_matmul(a, b)

    def verify(target="llvm"):
        if not tvm.module.enabled(target):
            print("skip because %s is not enabled..." % target)
            return
        if not tvm.get_global_func("tvm.contrib.cblas.matmul", True):
            print("skip because extern function is not available")
            return
        ctx = tvm.cpu(0)
        f = tvm.build(s, [A, B, D], target)
        a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx)
        d = tvm.nd.array(np.zeros((batch, n, m), dtype=D.dtype), ctx)
        f(a, b, d)
        tvm.testing.assert_allclose(
            d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), transa, transb), rtol=1e-5)
    verify()

def test_batch_matmul():
    verify_batch_matmul(16, 235, 128, 1024)
    verify_batch_matmul(16, 235, 128, 1024, True, False)
    verify_batch_matmul(16, 235, 128, 1024, False, True)
    verify_batch_matmul(16, 235, 128, 1024, True, True)
    verify_batch_matmul(1, 1, 16, 3)
    verify_batch_matmul(1, 1, 16, 3, True, False)
    verify_batch_matmul(1, 1, 16, 3, False, False)
    verify_batch_matmul(1, 1, 16, 3, True, True)
    verify_batch_matmul(1, 1, 16, 3, iterative=True)

if __name__ == "__main__":
    test_matmul_add()
    test_batch_matmul()