Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
88f9bfd4
Commit
88f9bfd4
authored
Sep 12, 2019
by
Jon Soifer
Committed by
Haichen Shen
Sep 12, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[TOPI][CUDA] Support cuBLAS BatchMatMul (#3936)
* Support cuBLAS BatchMatMul * Add test and check target name
parent
1de52bb0
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
176 additions
and
1 deletions
+176
-1
python/tvm/contrib/cublas.py
+28
-0
src/contrib/cublas/cublas.cc
+60
-0
tests/python/contrib/test_cublas.py
+28
-0
topi/include/topi/contrib/cublas.h
+32
-0
topi/python/topi/cuda/batch_matmul.py
+28
-1
No files found.
python/tvm/contrib/cublas.py
View file @
88f9bfd4
...
@@ -46,3 +46,31 @@ def matmul(lhs, rhs, transa=False, transb=False):
...
@@ -46,3 +46,31 @@ def matmul(lhs, rhs, transa=False, transb=False):
lambda
ins
,
outs
:
_intrin
.
call_packed
(
lambda
ins
,
outs
:
_intrin
.
call_packed
(
"tvm.contrib.cublas.matmul"
,
"tvm.contrib.cublas.matmul"
,
ins
[
0
],
ins
[
1
],
outs
[
0
],
transa
,
transb
),
name
=
"C"
)
ins
[
0
],
ins
[
1
],
outs
[
0
],
transa
,
transb
),
name
=
"C"
)
def
batch_matmul
(
lhs
,
rhs
,
transa
=
False
,
transb
=
False
):
"""Create an extern op that compute batch matrix mult of A and rhs with cuBLAS
Parameters
----------
lhs : Tensor
The left matrix operand
rhs : Tensor
The right matrix operand
transa : bool
Whether transpose lhs
transb : bool
Whether transpose rhs
Returns
-------
C : Tensor
The result tensor.
"""
b
=
lhs
.
shape
[
0
]
n
=
lhs
.
shape
[
2
]
if
transa
else
lhs
.
shape
[
1
]
m
=
rhs
.
shape
[
1
]
if
transb
else
rhs
.
shape
[
2
]
return
_api
.
extern
(
(
b
,
n
,
m
),
[
lhs
,
rhs
],
lambda
ins
,
outs
:
_intrin
.
call_packed
(
"tvm.contrib.cublas.batch_matmul"
,
ins
[
0
],
ins
[
1
],
outs
[
0
],
transa
,
transb
),
name
=
"C"
)
src/contrib/cublas/cublas.cc
View file @
88f9bfd4
...
@@ -81,6 +81,50 @@ struct CublasDgemmOp {
...
@@ -81,6 +81,50 @@ struct CublasDgemmOp {
}
}
};
};
struct
CublasSgemmBatchOp
{
typedef
float
TDatatype
;
cublasHandle_t
handle
;
explicit
CublasSgemmBatchOp
(
cublasHandle_t
hdl
)
:
handle
(
hdl
)
{}
void
operator
()(
int
batch_size
,
bool
ta
,
bool
tb
,
int
M
,
int
N
,
int
K
,
float
alpha
,
float
*
A
,
int
a_stride
,
int
lda
,
float
*
B
,
int
b_stride
,
int
ldb
,
float
beta
,
float
*
C
,
int
c_stride
,
int
ldc
)
{
CHECK_CUBLAS_ERROR
(
cublasSgemmStridedBatched
(
handle
,
BooleanToTranspose
(
ta
),
BooleanToTranspose
(
tb
),
M
,
N
,
K
,
&
alpha
,
A
,
lda
,
a_stride
,
B
,
ldb
,
b_stride
,
&
beta
,
C
,
ldc
,
c_stride
,
batch_size
));
}
};
struct
CublasDgemmBatchOp
{
typedef
double
TDatatype
;
cublasHandle_t
handle
;
explicit
CublasDgemmBatchOp
(
cublasHandle_t
hdl
)
:
handle
(
hdl
)
{}
void
operator
()(
int
batch_size
,
bool
ta
,
bool
tb
,
int
M
,
int
N
,
int
K
,
double
alpha
,
double
*
A
,
int
a_stride
,
int
lda
,
double
*
B
,
int
b_stride
,
int
ldb
,
double
beta
,
double
*
C
,
int
c_stride
,
int
ldc
)
{
CHECK_CUBLAS_ERROR
(
cublasDgemmStridedBatched
(
handle
,
BooleanToTranspose
(
ta
),
BooleanToTranspose
(
tb
),
M
,
N
,
K
,
&
alpha
,
A
,
lda
,
a_stride
,
B
,
ldb
,
b_stride
,
&
beta
,
C
,
ldc
,
c_stride
,
batch_size
));
}
};
// matrix multiplication for row major
// matrix multiplication for row major
TVM_REGISTER_GLOBAL
(
"tvm.contrib.cublas.matmul"
)
TVM_REGISTER_GLOBAL
(
"tvm.contrib.cublas.matmul"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
...
@@ -96,5 +140,21 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
...
@@ -96,5 +140,21 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
else
else
CallGemm
(
args
,
ret
,
CublasDgemmOp
(
entry_ptr
->
handle
));
CallGemm
(
args
,
ret
,
CublasDgemmOp
(
entry_ptr
->
handle
));
});
});
TVM_REGISTER_GLOBAL
(
"tvm.contrib.cublas.batch_matmul"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
DLTensor
*
A
=
args
[
0
];
CHECK
(
TypeMatch
(
A
->
dtype
,
kDLFloat
,
32
)
||
TypeMatch
(
A
->
dtype
,
kDLFloat
,
64
));
CuBlasThreadEntry
*
entry_ptr
=
CuBlasThreadEntry
::
ThreadLocal
();
if
(
TypeMatch
(
A
->
dtype
,
kDLFloat
,
32
))
CallBatchGemm
(
args
,
ret
,
CublasSgemmBatchOp
(
entry_ptr
->
handle
));
else
CallBatchGemm
(
args
,
ret
,
CublasDgemmBatchOp
(
entry_ptr
->
handle
));
});
}
// namespace contrib
}
// namespace contrib
}
// namespace tvm
}
// namespace tvm
tests/python/contrib/test_cublas.py
View file @
88f9bfd4
...
@@ -44,6 +44,34 @@ def test_matmul_add():
...
@@ -44,6 +44,34 @@ def test_matmul_add():
c
.
asnumpy
(),
np
.
dot
(
a
.
asnumpy
(),
b
.
asnumpy
()),
rtol
=
1e-5
)
c
.
asnumpy
(),
np
.
dot
(
a
.
asnumpy
(),
b
.
asnumpy
()),
rtol
=
1e-5
)
verify
()
verify
()
def
test_batch_matmul
():
j
=
16
n
=
1024
l
=
128
m
=
235
A
=
tvm
.
placeholder
((
j
,
n
,
l
),
name
=
'A'
)
B
=
tvm
.
placeholder
((
j
,
l
,
m
),
name
=
'B'
)
C
=
cublas
.
batch_matmul
(
A
,
B
)
s
=
tvm
.
create_schedule
(
C
.
op
)
def
verify
(
target
=
"cuda"
):
if
not
tvm
.
module
.
enabled
(
target
):
print
(
"skip because
%
s is not enabled..."
%
target
)
return
if
not
tvm
.
get_global_func
(
"tvm.contrib.cublas.matmul"
,
True
):
print
(
"skip because extern function is not available"
)
return
ctx
=
tvm
.
gpu
(
0
)
f
=
tvm
.
build
(
s
,
[
A
,
B
,
C
],
target
)
a
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
(
j
,
n
,
l
))
.
astype
(
A
.
dtype
),
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
(
j
,
l
,
m
))
.
astype
(
B
.
dtype
),
ctx
)
c
=
tvm
.
nd
.
array
(
np
.
zeros
((
j
,
n
,
m
),
dtype
=
C
.
dtype
),
ctx
)
f
(
a
,
b
,
c
)
tvm
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
np
.
matmul
(
a
.
asnumpy
(),
b
.
asnumpy
()),
rtol
=
1e-5
)
verify
()
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_matmul_add
()
test_matmul_add
()
test_batch_matmul
()
topi/include/topi/contrib/cublas.h
View file @
88f9bfd4
...
@@ -61,6 +61,38 @@ inline Tensor cublas_matmul(const Tensor& lhs,
...
@@ -61,6 +61,38 @@ inline Tensor cublas_matmul(const Tensor& lhs,
},
"C"
,
""
,
{})[
0
];
},
"C"
,
""
,
{})[
0
];
}
}
/*!
* \brief Create an op that multiplies batch matrices
* lhs and rhs with cuBLAS
*
* \param lhs The left matrix operand
* \param rhs The right matrix operand
* \param transa Whether to transpose lhs
* \param transb Whether to transpose rhs
*
* \return The output tensor
*/
inline
Tensor
cublas_batch_matmul
(
const
Tensor
&
lhs
,
const
Tensor
&
rhs
,
bool
transa
,
bool
transb
)
{
auto
b
=
lhs
->
shape
[
0
];
auto
n
=
transa
?
lhs
->
shape
[
2
]
:
lhs
->
shape
[
1
];
auto
m
=
transb
?
rhs
->
shape
[
1
]
:
rhs
->
shape
[
2
];
return
make_extern
(
{
{
b
,
n
,
m
}
},
{
lhs
->
dtype
},
{
lhs
,
rhs
},
[
&
](
Array
<
Buffer
>
ins
,
Array
<
Buffer
>
outs
)
{
return
call_packed
({
Expr
(
"tvm.contrib.cublas.batch_matmul"
),
pack_buffer
(
ins
[
0
]),
pack_buffer
(
ins
[
1
]),
pack_buffer
(
outs
[
0
]),
transa
,
transb
});
},
"C"
,
""
,
{})[
0
];
}
}
// namespace contrib
}
// namespace contrib
}
// namespace topi
}
// namespace topi
...
...
topi/python/topi/cuda/batch_matmul.py
View file @
88f9bfd4
...
@@ -18,10 +18,33 @@
...
@@ -18,10 +18,33 @@
"""cuda batch_matmul operators"""
"""cuda batch_matmul operators"""
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
import
tvm
import
tvm
from
tvm.contrib
import
cublas
from
topi.nn
import
batch_matmul
,
batch_matmul_default
from
..
import
generic
from
..
import
generic
from
..util
import
traverse_inline
,
get_const_tuple
,
get_max_power2_factor
from
..util
import
traverse_inline
,
get_const_tuple
,
get_max_power2_factor
@batch_matmul.register
([
"cuda"
,
"gpu"
])
def
batch_matmul_cuda
(
x
,
y
):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.
Parameters
----------
x : tvm.Tensor
3-D with shape [batch, M, K]
y : tvm.Tensor
3-D with shape [batch, N, K]
Returns
-------
output : tvm.Tensor
3-D with shape [batch, M, N]
"""
target
=
tvm
.
target
.
current_target
()
if
target
.
target_name
==
"cuda"
and
"cublas"
in
target
.
libs
:
return
cublas
.
batch_matmul
(
x
,
y
,
False
,
True
)
return
batch_matmul_default
(
x
,
y
)
@generic.schedule_batch_matmul.register
([
"cuda"
,
"gpu"
])
@generic.schedule_batch_matmul.register
([
"cuda"
,
"gpu"
])
def
schedule_batch_matmul
(
outs
):
def
schedule_batch_matmul
(
outs
):
...
@@ -38,6 +61,10 @@ def schedule_batch_matmul(outs):
...
@@ -38,6 +61,10 @@ def schedule_batch_matmul(outs):
s: Schedule
s: Schedule
The computation schedule for the op.
The computation schedule for the op.
"""
"""
target
=
tvm
.
target
.
current_target
()
if
target
.
target_name
==
"cuda"
and
"cublas"
in
target
.
libs
:
return
generic
.
schedule_extern
(
outs
)
outs
=
[
outs
]
if
isinstance
(
outs
,
tvm
.
tensor
.
Tensor
)
else
outs
outs
=
[
outs
]
if
isinstance
(
outs
,
tvm
.
tensor
.
Tensor
)
else
outs
s
=
tvm
.
create_schedule
([
x
.
op
for
x
in
outs
])
s
=
tvm
.
create_schedule
([
x
.
op
for
x
in
outs
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment