Commit 88f9bfd4 by Jon Soifer Committed by Haichen Shen

[TOPI][CUDA] Support cuBLAS BatchMatMul (#3936)

* Support cuBLAS BatchMatMul

* Add test and check target name
parent 1de52bb0
......@@ -46,3 +46,31 @@ def matmul(lhs, rhs, transa=False, transb=False):
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.cublas.matmul",
ins[0], ins[1], outs[0], transa, transb), name="C")
def batch_matmul(lhs, rhs, transa=False, transb=False):
"""Create an extern op that compute batch matrix mult of A and rhs with cuBLAS
Parameters
----------
lhs : Tensor
The left matrix operand
rhs : Tensor
The right matrix operand
transa : bool
Whether transpose lhs
transb : bool
Whether transpose rhs
Returns
-------
C : Tensor
The result tensor.
"""
b = lhs.shape[0]
n = lhs.shape[2] if transa else lhs.shape[1]
m = rhs.shape[1] if transb else rhs.shape[2]
return _api.extern(
(b, n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.cublas.batch_matmul",
ins[0], ins[1], outs[0], transa, transb), name="C")
......@@ -81,6 +81,50 @@ struct CublasDgemmOp {
}
};
struct CublasSgemmBatchOp {
typedef float TDatatype;
cublasHandle_t handle;
explicit CublasSgemmBatchOp(cublasHandle_t hdl)
: handle(hdl)
{}
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A,
int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C,
int c_stride, int ldc) {
CHECK_CUBLAS_ERROR(cublasSgemmStridedBatched(handle,
BooleanToTranspose(ta),
BooleanToTranspose(tb),
M, N, K,
&alpha,
A, lda, a_stride,
B, ldb, b_stride,
&beta,
C, ldc, c_stride,
batch_size));
}
};
struct CublasDgemmBatchOp {
typedef double TDatatype;
cublasHandle_t handle;
explicit CublasDgemmBatchOp(cublasHandle_t hdl)
: handle(hdl)
{}
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
int c_stride, int ldc) {
CHECK_CUBLAS_ERROR(cublasDgemmStridedBatched(handle,
BooleanToTranspose(ta),
BooleanToTranspose(tb),
M, N, K,
&alpha,
A, lda, a_stride,
B, ldb, b_stride,
&beta,
C, ldc, c_stride,
batch_size));
}
};
// matrix multiplication for row major
TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
......@@ -96,5 +140,21 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
else
CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle));
});
TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul")
.set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* A = args[0];
CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
TypeMatch(A->dtype, kDLFloat, 64));
CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
if (TypeMatch(A->dtype, kDLFloat, 32))
CallBatchGemm(args, ret, CublasSgemmBatchOp(entry_ptr->handle));
else
CallBatchGemm(args, ret, CublasDgemmBatchOp(entry_ptr->handle));
});
} // namespace contrib
} // namespace tvm
......@@ -44,6 +44,34 @@ def test_matmul_add():
c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
verify()
def test_batch_matmul():
j = 16
n = 1024
l = 128
m = 235
A = tvm.placeholder((j, n, l), name='A')
B = tvm.placeholder((j, l, m), name='B')
C = cublas.batch_matmul(A, B)
s = tvm.create_schedule(C.op)
def verify(target="cuda"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.cublas.matmul", True):
print("skip because extern function is not available")
return
ctx = tvm.gpu(0)
f = tvm.build(s, [A, B, C], target)
a = tvm.nd.array(np.random.uniform(size=(j, n, l)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(j, l, m)).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((j, n, m), dtype=C.dtype), ctx)
f(a, b, c)
tvm.testing.assert_allclose(
c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy()), rtol=1e-5)
verify()
if __name__ == "__main__":
test_matmul_add()
test_batch_matmul()
......@@ -61,6 +61,38 @@ inline Tensor cublas_matmul(const Tensor& lhs,
}, "C", "", {})[0];
}
/*!
* \brief Create an op that multiplies batch matrices
* lhs and rhs with cuBLAS
*
* \param lhs The left matrix operand
* \param rhs The right matrix operand
* \param transa Whether to transpose lhs
* \param transb Whether to transpose rhs
*
* \return The output tensor
*/
inline Tensor cublas_batch_matmul(const Tensor& lhs,
const Tensor& rhs,
bool transa,
bool transb) {
auto b = lhs->shape[0];
auto n = transa ? lhs->shape[2] : lhs->shape[1];
auto m = transb ? rhs->shape[1] : rhs->shape[2];
return make_extern(
{ { b, n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
Expr("tvm.contrib.cublas.batch_matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
transa,
transb });
}, "C", "", {})[0];
}
} // namespace contrib
} // namespace topi
......
......@@ -18,10 +18,33 @@
"""cuda batch_matmul operators"""
from __future__ import absolute_import as _abs
import tvm
from tvm.contrib import cublas
from topi.nn import batch_matmul, batch_matmul_default
from .. import generic
from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
@batch_matmul.register(["cuda", "gpu"])
def batch_matmul_cuda(x, y):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.
Parameters
----------
x : tvm.Tensor
3-D with shape [batch, M, K]
y : tvm.Tensor
3-D with shape [batch, N, K]
Returns
-------
output : tvm.Tensor
3-D with shape [batch, M, N]
"""
target = tvm.target.current_target()
if target.target_name == "cuda" and "cublas" in target.libs:
return cublas.batch_matmul(x, y, False, True)
return batch_matmul_default(x, y)
@generic.schedule_batch_matmul.register(["cuda", "gpu"])
def schedule_batch_matmul(outs):
......@@ -38,6 +61,10 @@ def schedule_batch_matmul(outs):
s: Schedule
The computation schedule for the op.
"""
target = tvm.target.current_target()
if target.target_name == "cuda" and "cublas" in target.libs:
return generic.schedule_extern(outs)
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment