Commit 3d5032ae by masahi, committed by Tianqi Chen

[CONTRIB] cuBLAS integration (#744)

* add cublas support

* integrate cublas to topi dense

* add cublas error check

* minor fix

* fix lint

* remove topi import from contrib unittest
parent 77299df3
......@@ -138,6 +138,7 @@ include make/contrib/nnpack.mk
include make/contrib/cudnn.mk
include make/contrib/miopen.mk
include make/contrib/mps.mk
include make/contrib/cublas.mk
ifdef ADD_CFLAGS
CFLAGS += $(ADD_CFLAGS)
......
......@@ -77,3 +77,6 @@ USE_MIOPEN = 0
# Whether use MPS
USE_MPS = 0
# Whether use cuBLAS
USE_CUBLAS = 0
CUBLAS_CONTRIB_SRC = $(wildcard src/contrib/cublas/*.cc)
CUBLAS_CONTRIB_OBJ = $(patsubst src/%.cc, build/%.o, $(CUBLAS_CONTRIB_SRC))
ifeq ($(USE_CUBLAS), 1)
CFLAGS += -DTVM_USE_CUBLAS=1
ADD_LDFLAGS += -lcublas
RUNTIME_DEP += $(CUBLAS_CONTRIB_OBJ)
endif
"""External function interface to BLAS libraroes."""
"""External function interface to BLAS libraries."""
from __future__ import absolute_import as _abs
from .. import api as _api
......
"""External function interface to cuBLAS libraries."""
from __future__ import absolute_import as _abs
from .. import api as _api
from .. import intrin as _intrin
def matmul(lhs, rhs, transa=False, transb=False):
    """Create an extern op that computes the matrix multiplication of lhs and rhs with cuBLAS.

    Parameters
    ----------
    lhs : Tensor
        The left matrix operand
    rhs : Tensor
        The right matrix operand
    transa : bool
        Whether to transpose lhs
    transb : bool
        Whether to transpose rhs

    Returns
    -------
    C : Tensor
        The result tensor.
    """
    n = lhs.shape[1] if transa else lhs.shape[0]
    m = rhs.shape[0] if transb else rhs.shape[1]
    return _api.extern(
        (n, m), [lhs, rhs],
        lambda ins, outs: _intrin.call_packed(
            "tvm.contrib.cublas.matmul",
            ins[0], ins[1], outs[0], transa, transb), name="C")
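
A minimal usage sketch (the shapes are illustrative; transb=True mirrors the call topi's dense makes further below, and building requires USE_CUBLAS=1):

    import tvm
    from tvm.contrib import cublas

    A = tvm.placeholder((1024, 128), name='A')            # [batch, in_dim]
    W = tvm.placeholder((256, 128), name='W')             # [out_dim, in_dim]
    C = cublas.matmul(A, W, transa=False, transb=True)    # C has shape (1024, 256)
    s = tvm.create_schedule(C.op)
    f = tvm.build(s, [A, W, C], "cuda")
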
......@@ -94,7 +94,8 @@ class Target(object):
# Parse device option
for item in self.options:
if item.startswith("-libs="):
self.libs.append(item.split("=")[1])
libs = item.split("=")[1]
self.libs += libs.split(",")
elif item.startswith("-device="):
self.device_name = item.split("=")[1]
# Target query searches device name first
......
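
With this change, -libs= accepts a comma-separated list of external libraries. A small sketch of the resulting behaviour, using the Target API as it appears elsewhere in this commit:

    target = tvm.target.create("cuda -libs=cublas,cudnn")
    assert "cublas" in target.libs
    assert "cudnn" in target.libs
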
......@@ -38,7 +38,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
transa ? CblasTrans : CblasNoTrans,
transb ? B->shape[0] : B->shape[1],
transa ? A->shape[1] : A->shape[0],
transa ? B->shape[1] : B->shape[0],
transb ? B->shape[1] : B->shape[0],
1.0f,
reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset),
B->shape[1],
......
/*!
* Copyright (c) 2017 by Contributors
* \file Use external cublas library call.
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dmlc/logging.h>
extern "C" {
#include <cublas_v2.h>
}
namespace tvm {
namespace contrib {
using namespace runtime;
#ifndef CHECK_CUBLAS_ERROR
#define CHECK_CUBLAS_ERROR(error) \
if (error != CUBLAS_STATUS_SUCCESS) { \
fprintf(stderr, "cuBLAS error: "); \
if (error == CUBLAS_STATUS_NOT_INITIALIZED) fprintf(stderr, "CUBLAS_STATUS_NOT_INITIALIZED"); \
if (error == CUBLAS_STATUS_ALLOC_FAILED) fprintf(stderr, "CUBLAS_STATUS_ALLOC_FAILED"); \
if (error == CUBLAS_STATUS_INVALID_VALUE) fprintf(stderr, "CUBLAS_STATUS_INVALID_VALUE"); \
if (error == CUBLAS_STATUS_ARCH_MISMATCH) fprintf(stderr, "CUBLAS_STATUS_ARCH_MISMATCH"); \
if (error == CUBLAS_STATUS_MAPPING_ERROR) fprintf(stderr, "CUBLAS_STATUS_MAPPING_ERROR"); \
if (error == CUBLAS_STATUS_EXECUTION_FAILED) fprintf(stderr, "CUBLAS_STATUS_EXECUTION_FAILED"); \
if (error == CUBLAS_STATUS_INTERNAL_ERROR) fprintf(stderr, "CUBLAS_STATUS_INTERNAL_ERROR"); \
if (error == CUBLAS_STATUS_NOT_SUPPORTED) fprintf(stderr, "CUBLAS_STATUS_NOT_SUPPORTED"); \
if (error == CUBLAS_STATUS_LICENSE_ERROR) fprintf(stderr, "CUBLAS_STATUS_LICENSE_ERROR"); \
fprintf(stderr, "\n"); \
exit(EXIT_FAILURE); \
}
#endif
// matrix multiplication for row major
TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor* A = args[0];
DLTensor* B = args[1];
DLTensor* C = args[2];
bool transa = args[3];
bool transb = args[4];
// call gemm for simple compact code.
CHECK_EQ(A->ndim, 2);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 2);
CHECK(C->strides == nullptr);
CHECK(B->strides == nullptr);
CHECK(A->strides == nullptr);
CHECK(TypeMatch(A->dtype, kDLFloat, 32));
CHECK(TypeMatch(B->dtype, kDLFloat, 32));
CHECK(TypeMatch(C->dtype, kDLFloat, 32));
cublasHandle_t handle;
CHECK_CUBLAS_ERROR(cublasCreate(&handle));
float alpha = 1.0;
float beta = 0.0;
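// NOTE: cuBLAS assumes column-major storage. The row-major product C = A * B
// is computed here as the column-major product C^T = B^T * A^T, which is why
// the A/B pointers and shape arguments passed to cublasSgemm are swapped.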
float *A_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
float *B_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
float *C_ptr = reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset);
CHECK_CUBLAS_ERROR(cublasSgemm(handle,
transb ? CUBLAS_OP_T : CUBLAS_OP_N,
transa ? CUBLAS_OP_T : CUBLAS_OP_N,
transb ? B->shape[0] : B->shape[1],
transa ? A->shape[1] : A->shape[0],
transb ? B->shape[1] : B->shape[0],
&alpha,
A_ptr,
B->shape[1],
B_ptr,
A->shape[1],
&beta,
C_ptr,
C->shape[1]));
CHECK_CUBLAS_ERROR(cublasDestroy(handle));
});
} // namespace contrib
} // namespace tvm
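
The packed function registered above is what the Python extern op invokes through call_packed; it can also be looked up directly (as the unit test below does) to check whether cuBLAS support was compiled in:

    # returns a falsy value instead of raising when the function is missing
    f = tvm.get_global_func("tvm.contrib.cublas.matmul", True)
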
import tvm
import numpy as np
from tvm.contrib import cublas
def test_matmul_add():
    n = 1024
    l = 128
    m = 235
    A = tvm.placeholder((n, l), name='A')
    B = tvm.placeholder((l, m), name='B')
    C = cublas.matmul(A, B)
    s = tvm.create_schedule(C.op)

    def verify(target="cuda"):
        if not tvm.module.enabled(target):
            print("skip because %s is not enabled..." % target)
            return
        if not tvm.get_global_func("tvm.contrib.cublas.matmul", True):
            print("skip because extern function is not available")
            return
        ctx = tvm.gpu(0)
        f = tvm.build(s, [A, B, C], target)
        a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
        c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
        f(a, b, c)
        np.testing.assert_allclose(
            c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
    verify()

if __name__ == "__main__":
    test_matmul_add()
......@@ -41,8 +41,7 @@ def test_conv2d():
tensor_format=0,
algo=1)
yshape = [x.value for x in Y.shape]
with tvm.target.create("cuda -libs=cudnn"):
s = tvm.create_schedule(Y.op)
s = tvm.create_schedule(Y.op)
def verify():
ctx = tvm.gpu(0)
......
......@@ -11,7 +11,7 @@ from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
from .reduction import schedule_reduce
from .softmax import schedule_softmax
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import schedule_dense
from .dense import dense_cuda, schedule_dense
from .pooling import schedule_pool, schedule_global_pool
from .conv2d_transpose_nchw import schedule_conv2d_transpose_nchw
from .extern import schedule_extern
......@@ -2,9 +2,48 @@
"""Schedule for dense operator"""
from __future__ import absolute_import as _abs
import tvm
from tvm.contrib import cublas
from ..nn.dense import dense, dense_default
from .. import tag
from .. import generic
@dense.register("cuda")
def dense_cuda(data, weight, bias=None):
"""Dense operator for cuda backend.
Parameters
----------
data : tvm.Tensor
2-D with shape [batch, in_dim]
weight : tvm.Tensor
2-D with shape [out_dim, in_dim]
bias : tvm.Tensor, optional
1-D with shape [out_dim]
Returns
-------
output : tvm.Tensor
2-D with shape [batch, out_dim]
"""
assert len(data.shape) == 2 and len(weight.shape) == 2, \
"only support 2-dim dense"
if bias is not None:
assert len(bias.shape) == 1
batch, in_dim = data.shape
out_dim, _ = weight.shape
target = tvm.target.current_target()
if "cublas" in target.libs:
matmul = cublas.matmul(data, weight, False, True)
if bias is not None:
matmul = tvm.compute((batch, out_dim), \
lambda i, j: matmul[i, j] + bias[j], \
tag=tag.BROADCAST)
return matmul
return dense_default(data, weight, bias)
@generic.schedule_dense.register(["cuda", "gpu"])
def schedule_dense(outs):
"""Schedule for dense operator.
......@@ -20,6 +59,10 @@ def schedule_dense(outs):
    s: Schedule
        The computation schedule for dense.
    """
    target = tvm.target.current_target()
    if target.target_name == "cuda" and "cublas" in target.libs:
        return generic.schedule_extern(outs)

    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])

    def _schedule(Dense):
......
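
A hedged end-to-end sketch of the new dispatch path (placeholder shapes are illustrative, and it assumes importing topi registers the cuda backend):

    import tvm
    import topi

    data = tvm.placeholder((1024, 128), name='data')      # [batch, in_dim]
    weight = tvm.placeholder((235, 128), name='weight')   # [out_dim, in_dim]
    with tvm.target.create("cuda -libs=cublas"):
        out = topi.nn.dense(data, weight)        # dispatches to dense_cuda -> cublas.matmul
        s = topi.generic.schedule_dense([out])   # takes the schedule_extern path
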
......@@ -3,9 +3,8 @@ from __future__ import absolute_import
import tvm
from .. import tag
def dense(data, weight, bias=None):
"""Applies a linear transformation: :math:`Y = XW^T + b`.
def dense_default(data, weight, bias=None):
"""The default implementation of dense in topi.
Parameters
----------
......@@ -38,3 +37,26 @@ def dense(data, weight, bias=None):
lambda i, j: matmul[i, j] + bias[j], \
tag=tag.BROADCAST)
return matmul
@tvm.target.generic_func
def dense(data, weight, bias=None):
"""Applies a linear transformation: :math:`Y = XW^T + b`.
Parameters
----------
data : tvm.Tensor
2-D with shape [batch, in_dim]
weight : tvm.Tensor
2-D with shape [out_dim, in_dim]
bias : tvm.Tensor, optional
1-D with shape [out_dim]
Returns
-------
output : tvm.Tensor
2-D with shape [batch, out_dim]
"""
return dense_default(data, weight, bias)
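
Since dense is now a generic_func, other backends can register their own implementation the same way topi.cuda does above. A purely hypothetical sketch (the "rocm" key and the trivial body are illustrative only):

    @dense.register("rocm")
    def dense_rocm(data, weight, bias=None):
        # illustrative only: fall back to the default compute
        return dense_default(data, weight, bias)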