Commit a407ec15 by masahi Committed by Tianqi Chen

[CONTRIB] rocBLAS integration (#751)

* rocblas integration

* fix include

* fix lint
parent c4927378
...@@ -139,6 +139,7 @@ include make/contrib/cudnn.mk ...@@ -139,6 +139,7 @@ include make/contrib/cudnn.mk
include make/contrib/miopen.mk include make/contrib/miopen.mk
include make/contrib/mps.mk include make/contrib/mps.mk
include make/contrib/cublas.mk include make/contrib/cublas.mk
include make/contrib/rocblas.mk
ifdef ADD_CFLAGS ifdef ADD_CFLAGS
CFLAGS += $(ADD_CFLAGS) CFLAGS += $(ADD_CFLAGS)
......
...@@ -80,3 +80,6 @@ USE_MPS = 0 ...@@ -80,3 +80,6 @@ USE_MPS = 0
# Whether use cuBLAS # Whether use cuBLAS
USE_CUBLAS = 0 USE_CUBLAS = 0
# Whether to use rocBLAS
USE_ROCBLAS = 0
# Build rules for the optional rocBLAS contrib module.
# Collect every C++ source under src/contrib/rocblas/ ...
ROCBLAS_CONTRIB_SRC = $(wildcard src/contrib/rocblas/*.cc)
# ... and map each src/<path>.cc to its object file build/<path>.o.
ROCBLAS_CONTRIB_OBJ = $(patsubst src/%.cc, build/%.o, $(ROCBLAS_CONTRIB_SRC))
# Only when USE_ROCBLAS is set to 1: define TVM_USE_ROCBLAS for the compiler,
# link against librocblas, and add the contrib objects to RUNTIME_DEP.
ifeq ($(USE_ROCBLAS), 1)
CFLAGS += -DTVM_USE_ROCBLAS=1
ADD_LDFLAGS += -lrocblas
RUNTIME_DEP += $(ROCBLAS_CONTRIB_OBJ)
endif
...@@ -7,7 +7,7 @@ from .. import intrin as _intrin ...@@ -7,7 +7,7 @@ from .. import intrin as _intrin
def matmul(lhs, rhs, transa=False, transb=False): def matmul(lhs, rhs, transa=False, transb=False):
"""Create an extern op that compute matrix mult of A and rhs with CrhsLAS """Create an extern op that compute matrix mult of A and rhs with CrhsLAS
This function serves as an example on how to calle external libraries. This function serves as an example on how to call external libraries.
Parameters Parameters
---------- ----------
......
"""External function interface to rocBLAS libraries."""
from __future__ import absolute_import as _abs
from .. import api as _api
from .. import intrin as _intrin
def matmul(lhs, rhs, transa=False, transb=False):
    """Create an extern op that computes the matrix product of lhs and rhs with rocBLAS.

    Parameters
    ----------
    lhs : Tensor
        The left matrix operand
    rhs : Tensor
        The right matrix operand
    transa : bool
        Whether to transpose lhs
    transb : bool
        Whether to transpose rhs

    Returns
    -------
    C : Tensor
        The result tensor.
    """
    # Pick the axis of each operand that survives the (optional) transpose.
    row_axis = 1 if transa else 0
    col_axis = 0 if transb else 1
    out_shape = (lhs.shape[row_axis], rhs.shape[col_axis])

    def _call_rocblas(ins, outs):
        # Forward both inputs, the single output and the transpose flags to
        # the packed function registered under this name.
        return _intrin.call_packed(
            "tvm.contrib.rocblas.matmul",
            ins[0], ins[1], outs[0], transa, transb)

    return _api.extern(out_shape, [lhs, rhs], _call_rocblas, name="C")
/*!
* Copyright (c) 2017 by Contributors
* \file Use external rocblas library call.
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dmlc/logging.h>
#include "rocblas.h"
namespace tvm {
namespace contrib {
using namespace runtime;
#ifndef CHECK_ROCBLAS_ERROR
/*!
 * \brief Terminate the process with a diagnostic if a rocBLAS call failed.
 *
 * The argument is evaluated exactly once and its value is kept in a local, so
 * it is safe to pass a call expression such as rocblas_create_handle(&h): the
 * call is not repeated while the status is being decoded.  The body is wrapped
 * in do { ... } while (0) so that the macro behaves as a single statement and
 * can be used in an unbraced if/else.  Callers must terminate it with ';'.
 *
 * On any status other than rocblas_status_success, the name of the status
 * (when it is one of the values listed below) is written to stderr and the
 * process exits with EXIT_FAILURE.
 *
 * \param error Expression yielding a rocblas_status.
 */
#define CHECK_ROCBLAS_ERROR(error)                                          \
  do {                                                                      \
    const rocblas_status tvm_rocblas_status_ = (error);                     \
    if (tvm_rocblas_status_ != rocblas_status_success) {                    \
      fprintf(stderr, "rocBLAS error: ");                                   \
      if (tvm_rocblas_status_ == rocblas_status_invalid_handle)             \
        fprintf(stderr, "rocblas_status_invalid_handle");                   \
      if (tvm_rocblas_status_ == rocblas_status_not_implemented)            \
        fprintf(stderr, "rocblas_status_not_implemented");                  \
      if (tvm_rocblas_status_ == rocblas_status_invalid_pointer)            \
        fprintf(stderr, "rocblas_status_invalid_pointer");                  \
      if (tvm_rocblas_status_ == rocblas_status_invalid_size)               \
        fprintf(stderr, "rocblas_status_invalid_size");                     \
      if (tvm_rocblas_status_ == rocblas_status_memory_error)               \
        fprintf(stderr, "rocblas_status_memory_error");                     \
      if (tvm_rocblas_status_ == rocblas_status_internal_error)             \
        fprintf(stderr, "rocblas_status_internal_error");                   \
      fprintf(stderr, "\n");                                                \
      exit(EXIT_FAILURE);                                                   \
    }                                                                       \
  } while (0)
#endif
// matrix multiplication for row major
// Registers the packed function "tvm.contrib.rocblas.matmul".
// Arguments, in order: A, B, C (DLTensor*), transa, transb (bool).
// Computes C = op(A) * op(B) with a single rocblas_sgemm call (alpha = 1, beta = 0),
// where op() is a transpose when the corresponding flag is set.
TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor* A = args[0];
DLTensor* B = args[1];
DLTensor* C = args[2];
bool transa = args[3];
bool transb = args[4];
// call gemm for simple compact code.
// Only 2-D tensors are accepted.
CHECK_EQ(A->ndim, 2);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 2);
// Explicit strides are rejected; the leading dimensions passed below are
// taken directly from shape[1], which is only valid for compact buffers.
CHECK(C->strides == nullptr);
CHECK(B->strides == nullptr);
CHECK(A->strides == nullptr);
// sgemm is single precision, so every operand must be float32.
CHECK(TypeMatch(A->dtype, kDLFloat, 32));
CHECK(TypeMatch(B->dtype, kDLFloat, 32));
CHECK(TypeMatch(C->dtype, kDLFloat, 32));
// NOTE(review): a rocBLAS handle is created and destroyed on every invocation.
rocblas_handle handle;
CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle));
float alpha = 1.0;
float beta = 0.0;
// The operands are deliberately swapped: A_ptr addresses B's data and B_ptr
// addresses A's data.  rocBLAS GEMM works on column-major matrices, and a
// row-major buffer read as column-major is its transpose, so computing
// op(B)^T * op(A)^T in column-major yields C^T, i.e. row-major C.
float *A_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
float *B_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
float *C_ptr = reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset);
// Argument order after the handle: transpose flag for B, transpose flag for A,
// three dimension arguments (columns of op(B), rows of op(A), rows of op(B)),
// then alpha, the two inputs with their leading dimensions, beta, and the
// output with its leading dimension.
CHECK_ROCBLAS_ERROR(rocblas_sgemm(handle,
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
transb ? B->shape[0] : B->shape[1],
transa ? A->shape[1] : A->shape[0],
transb ? B->shape[1] : B->shape[0],
&alpha,
A_ptr,
B->shape[1],
B_ptr,
A->shape[1],
&beta,
C_ptr,
C->shape[1]));
CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle));
});
} // namespace contrib
} // namespace tvm
import tvm
import numpy as np
from tvm.contrib import rocblas
def test_matmul_add():
    """Check rocblas.matmul against numpy.dot on a (1024, 128) x (128, 235) product.

    The check is skipped (with a printed notice) when the target is not
    enabled or the rocBLAS extern function is not registered.
    """
    n, l, m = 1024, 128, 235
    A = tvm.placeholder((n, l), name='A')
    B = tvm.placeholder((l, m), name='B')
    C = rocblas.matmul(A, B)
    s = tvm.create_schedule(C.op)

    def verify(target="rocm"):
        """Build for `target`, run on device 0 and compare with numpy."""
        if not tvm.module.enabled(target):
            print("skip because %s is not enabled..." % target)
            return
        if not tvm.get_global_func("tvm.contrib.rocblas.matmul", True):
            print("skip because extern function is not avalable")
            return
        device = tvm.rocm(0)
        func = tvm.build(s, [A, B, C], target)
        lhs_nd = tvm.nd.array(
            np.random.uniform(size=(n, l)).astype(A.dtype), device)
        rhs_nd = tvm.nd.array(
            np.random.uniform(size=(l, m)).astype(B.dtype), device)
        out_nd = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), device)
        func(lhs_nd, rhs_nd, out_nd)
        actual = out_nd.asnumpy()
        expected = np.dot(lhs_nd.asnumpy(), rhs_nd.asnumpy())
        np.testing.assert_allclose(actual, expected, rtol=1e-5)

    verify()


if __name__ == "__main__":
    test_matmul_add()
...@@ -3,3 +3,4 @@ ...@@ -3,3 +3,4 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from .conv2d import * from .conv2d import *
from .dense import *
# pylint: disable=invalid-name, unused-variable
"""Schedule for dense operator"""
from __future__ import absolute_import as _abs
import tvm
from tvm.contrib import rocblas
import topi
from ..nn.dense import dense, dense_default
from .. import tag
from .. import generic
@dense.register("rocm")
def dense_rocm(data, weight, bias=None):
    """Dense operator for the rocm backend.

    When the current target lists "rocblas" among its libs, the product is
    delegated to rocBLAS; otherwise the default dense implementation is used.

    Parameters
    ----------
    data : tvm.Tensor
        2-D with shape [batch, in_dim]

    weight : tvm.Tensor
        2-D with shape [out_dim, in_dim]

    bias : tvm.Tensor, optional
        1-D with shape [out_dim]

    Returns
    -------
    output : tvm.Tensor
        2-D with shape [batch, out_dim]
    """
    assert len(data.shape) == 2 and len(weight.shape) == 2, \
        "only support 2-dim dense"
    if bias is not None:
        assert len(bias.shape) == 1
    batch, _ = data.shape
    out_dim, _ = weight.shape

    target = tvm.target.current_target()
    if "rocblas" not in target.libs:
        return dense_default(data, weight, bias)

    # weight is [out_dim, in_dim], so it is passed with transb=True.
    out = rocblas.matmul(data, weight, False, True)
    if bias is None:
        return out

    def _add_bias(i, j):
        # Broadcast the 1-D bias along the batch axis.
        return out[i, j] + bias[j]

    return tvm.compute((batch, out_dim), _add_bias, tag=tag.BROADCAST)
@generic.schedule_dense.register(["rocm"])
def schedule_dense(outs):
    """Schedule for dense operator.

    Uses the generic extern schedule when the current target is "rocm" with
    "rocblas" in its libs, and falls back to the cuda dense schedule otherwise.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of dense
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for dense.
    """
    target = tvm.target.current_target()
    uses_rocblas = target.target_name == "rocm" and "rocblas" in target.libs
    if not uses_rocblas:
        return topi.cuda.schedule_dense(outs)
    return generic.schedule_extern(outs)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment