wenyuanbo / tic · Commits · 3d1d17e3

Commit 3d1d17e3 — [Contrib] cblas batch_matmul (#3210)

Authored May 21, 2019 by hlu1; committed by Leyuan Wang, May 21, 2019.
Parent: 21935dcb

Showing 5 changed files with 351 additions and 92 deletions (+351 −92):

- cmake/modules/contrib/BLAS.cmake (+5 −1)
- python/tvm/contrib/cblas.py (+49 −6)
- src/contrib/cblas/cblas.cc (+131 −40)
- src/contrib/cblas/gemm_common.h (+94 −34)
- tests/python/contrib/test_cblas.py (+72 −11)
cmake/modules/contrib/BLAS.cmake

The single MKL `find_library` call is split into Apple and generic Unix cases, since macOS MKL distributions ship `mklml` rather than `mkl_rt`/`mklml_gnu`:

```diff
@@ -27,7 +27,11 @@ elseif(USE_BLAS STREQUAL "mkl")
   if(NOT IS_DIRECTORY ${USE_MKL_PATH})
     set(USE_MKL_PATH /opt/intel/mkl)
   endif()
-  find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
+  if(APPLE)
+    find_library(BLAS_LIBRARY NAMES mklml HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
+  elseif(UNIX)
+    find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
+  endif()
   include_directories(${USE_MKL_PATH}/include)
   list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY})
   list(APPEND RUNTIME_SRCS ${CBLAS_CONTRIB_SRC})
```
python/tvm/contrib/cblas.py

`matmul` gains a `**kwargs` passthrough to `_api.extern`, and a new `batch_matmul` entry point is added that dispatches to either the batched or the iterative extern:

```diff
@@ -17,10 +17,10 @@
 """External function interface to BLAS libraries."""
 from __future__ import absolute_import as _abs
-from .. import api as _api
-from .. import intrin as _intrin
+from .. import api as _api, intrin as _intrin


-def matmul(lhs, rhs, transa=False, transb=False):
+def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
     """Create an extern op that compute matrix mult of A and rhs with CrhsLAS
     This function serves as an example on how to call external libraries.
@@ -44,7 +44,50 @@ def matmul(lhs, rhs, transa=False, transb=False):
     n = lhs.shape[1] if transa else lhs.shape[0]
     m = rhs.shape[0] if transb else rhs.shape[1]
     return _api.extern(
-        (n, m), [lhs, rhs],
+        (n, m),
+        [lhs, rhs],
         lambda ins, outs: _intrin.call_packed(
-            "tvm.contrib.cblas.matmul",
-            ins[0], ins[1], outs[0], transa, transb), name="C")
+            "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], transa, transb
+        ),
+        name="C",
+        **kwargs
+    )
+
+
+def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs):
+    """Create an extern op that compute batched matrix mult of A and rhs with CBLAS
+    This function serves as an example on how to call external libraries.
+
+    Parameters
+    ----------
+    lhs : Tensor
+        The left matrix operand
+    rhs : Tensor
+        The right matrix operand
+    transa : bool
+        Whether transpose lhs
+    transb : bool
+        Whether transpose rhs
+
+    Returns
+    -------
+    C : Tensor
+        The result tensor.
+    """
+    b = lhs.shape[0]
+    n = lhs.shape[2] if transa else lhs.shape[1]
+    m = rhs.shape[1] if transb else rhs.shape[2]
+    return _api.extern(
+        (b, n, m),
+        [lhs, rhs],
+        lambda ins, outs: _intrin.call_packed(
+            "tvm.contrib.cblas.batch_matmul"
+            if not iterative
+            else "tvm.contrib.cblas.batch_matmul_iterative",
+            ins[0],
+            ins[1],
+            outs[0],
+            transa,
+            transb,
+        ),
+        name="C",
+        **kwargs
+    )
```
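For orientation, here is a minimal usage sketch of the new entry point, patterned on the tests added below (it assumes a pre-0.6 TVM build with `USE_BLAS` enabled; the shapes are illustrative, not from the commit). Passing `iterative=True` would instead select the `tvm.contrib.cblas.batch_matmul_iterative` extern, which loops over per-matrix GEMM calls rather than using MKL's grouped `cblas_?gemm_batch`:

```python
import numpy as np
import tvm
from tvm.contrib import cblas

# Batched C[k] = A[k] @ B[k], lowered to the cblas extern op.
A = tvm.placeholder((16, 32, 64), name="A", dtype="float32")
B = tvm.placeholder((16, 64, 128), name="B", dtype="float32")
C = cblas.batch_matmul(A, B)  # shape (16, 32, 128)
s = tvm.create_schedule(C.op)
f = tvm.build(s, [A, B, C], target="llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(16, 32, 64)).astype("float32"), ctx)
b = tvm.nd.array(np.random.uniform(size=(16, 64, 128)).astype("float32"), ctx)
c = tvm.nd.array(np.zeros((16, 32, 128), dtype="float32"), ctx)
f(a, b, c)
np.testing.assert_allclose(
    c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy()), rtol=1e-5)
```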
src/contrib/cblas/cblas.cc

@@ -6,9 +6,9 @@ — trailing whitespace is stripped from the Apache license header (no content change).

The `dmlc/logging.h` include moves ahead of the TVM runtime headers:

```diff
@@ -21,12 +21,11 @@
  * Copyright (c) 2017 by Contributors
  * \file Use external cblas library call.
  */
+#include <dmlc/logging.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/util.h>
-#include <dmlc/logging.h>
 #include "gemm_common.h"

 extern "C" {
 #if USE_MKL_BLAS == 1
 #include <mkl_cblas.h>
```

In the main hunk (@@ -40,56 +39,148 @@ namespace contrib {) the existing single-matrix wrappers are reformatted without behavior change, and four batch operators plus two new packed-function registrations are added. The section now reads:

```cpp
using namespace runtime;

inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) {
  return trans ? CblasTrans : CblasNoTrans;
}

struct CblasSgemmOp {
  typedef float TDatatype;
  void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A,
                  int lda, float* B, int ldb, float beta, float* C, int ldc) {
    cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb),
                M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
  }
};

struct CblasDgemmOp {
  typedef double TDatatype;
  void operator()(bool ta, bool tb, int M, int N, int K, double alpha,
                  double* A, int lda, double* B, int ldb, double beta,
                  double* C, int ldc) {
    cblas_dgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb),
                M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
  }
};

struct CblasSgemmBatchOp {
  typedef float TDatatype;
  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K,
                  float alpha, float* A, int a_stride, int lda, float* B,
                  int b_stride, int ldb, float beta, float* C, int c_stride,
                  int ldc) {
    CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
    CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
#if USE_MKL_BLAS == 1
    std::vector<const float*> A_array(batch_size);
    std::vector<const float*> B_array(batch_size);
    std::vector<float*> C_array(batch_size);
    for (int i = 0; i < batch_size; ++i) {
      A_array[i] = A + i * a_stride;
      B_array[i] = B + i * b_stride;
      C_array[i] = C + i * c_stride;
    }
    cblas_sgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha,
                      A_array.data(), &lda, B_array.data(), &ldb, &beta,
                      C_array.data(), &ldc, 1, &batch_size);
#else
    for (int i = 0; i < batch_size; ++i) {
      cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B,
                  ldb, beta, C, ldc);
      A += a_stride;
      B += b_stride;
      C += c_stride;
    }
#endif
  }
};

struct CblasSgemmBatchIterativeOp {
  typedef float TDatatype;
  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K,
                  float alpha, float* A, int a_stride, int lda, float* B,
                  int b_stride, int ldb, float beta, float* C, int c_stride,
                  int ldc) {
    CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
    CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
    for (int i = 0; i < batch_size; ++i) {
      cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B,
                  ldb, beta, C, ldc);
      A += a_stride;
      B += b_stride;
      C += c_stride;
    }
  }
};

struct CblasDgemmBatchOp {
  typedef double TDatatype;
  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K,
                  double alpha, double* A, int a_stride, int lda, double* B,
                  int b_stride, int ldb, double beta, double* C, int c_stride,
                  int ldc) {
    CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
    CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
#if USE_MKL_BLAS == 1
    std::vector<const double*> A_array(batch_size);
    std::vector<const double*> B_array(batch_size);
    std::vector<double*> C_array(batch_size);
    for (int i = 0; i < batch_size; ++i) {
      A_array[i] = A + i * a_stride;
      B_array[i] = B + i * b_stride;
      C_array[i] = C + i * c_stride;
    }
    cblas_dgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha,
                      A_array.data(), &lda, B_array.data(), &ldb, &beta,
                      C_array.data(), &ldc, 1, &batch_size);
#else
    for (int i = 0; i < batch_size; ++i) {
      cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B,
                  ldb, beta, C, ldc);
      A += a_stride;
      B += b_stride;
      C += c_stride;
    }
#endif
  }
};

struct CblasDgemmBatchIterativeOp {
  typedef double TDatatype;
  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K,
                  double alpha, double* A, int a_stride, int lda, double* B,
                  int b_stride, int ldb, double beta, double* C, int c_stride,
                  int ldc) {
    CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
    CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
    for (int i = 0; i < batch_size; ++i) {
      cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B,
                  ldb, beta, C, ldc);
      A += a_stride;
      B += b_stride;
      C += c_stride;
    }
  }
};

// matrix multiplication for row major
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      DLTensor* A = args[0];
      CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
            TypeMatch(A->dtype, kDLFloat, 64));
      if (TypeMatch(A->dtype, kDLFloat, 32))
        CallGemm(args, ret, CblasSgemmOp());
      else
        CallGemm(args, ret, CblasDgemmOp());
    });

TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      DLTensor* A = args[0];
      CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
            TypeMatch(A->dtype, kDLFloat, 64));
      if (TypeMatch(A->dtype, kDLFloat, 32)) {
        CallBatchGemm(args, ret, CblasSgemmBatchOp());
      } else {
        CallBatchGemm(args, ret, CblasDgemmBatchOp());
      }
    });

TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      DLTensor* A = args[0];
      CHECK(TypeMatch(A->dtype, kDLFloat, 32) ||
            TypeMatch(A->dtype, kDLFloat, 64));
      if (TypeMatch(A->dtype, kDLFloat, 32)) {
        CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp());
      } else {
        CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp());
      }
    });
}  // namespace contrib
}  // namespace tvm
```
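Once TVM is built with `USE_BLAS` pointing at a CBLAS implementation, the packed functions registered above are reachable from Python by name. A minimal sketch, assuming the 0.6-era runtime API (the shapes are illustrative; the `allow_missing=True` lookup mirrors what the tests below do):

```python
import numpy as np
import tvm

# Fetch the global packed function registered by TVM_REGISTER_GLOBAL above;
# with allow_missing=True this returns None if TVM was built without CBLAS.
batch_matmul = tvm.get_global_func("tvm.contrib.cblas.batch_matmul", True)
if batch_matmul is not None:
    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.uniform(size=(4, 8, 16)).astype("float32"), ctx)
    b = tvm.nd.array(np.random.uniform(size=(4, 16, 32)).astype("float32"), ctx)
    c = tvm.nd.array(np.zeros((4, 8, 32), dtype="float32"), ctx)
    # Argument order mirrors CallBatchGemm: A, B, C, transa, transb.
    batch_matmul(a, b, c, False, False)
    np.testing.assert_allclose(
        c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy()), rtol=1e-5)
```

The design offers two strategies behind one Python API: MKL builds get the grouped `cblas_?gemm_batch` call (one group of `batch_size` identically shaped GEMMs), while the iterative variants simply stride through the batch with ordinary `cblas_?gemm` calls, which is also the fallback when MKL is unavailable.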
src/contrib/cblas/gemm_common.h

@@ -6,9 +6,9 @@ — trailing whitespace is stripped from the Apache license header (no content change).

The `#pragma once` guard is replaced by a conventional include guard (closed by the new `#endif` at the end of the file):

```diff
@@ -22,16 +22,17 @@
  * \file tvm/contrib/gemm.h
  * \brief Shared implementation of gemm
  */
-#pragma once
+#ifndef TVM_CONTRIB_CBLAS_GEMM_COMMON_H_
+#define TVM_CONTRIB_CBLAS_GEMM_COMMON_H_

 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/util.h>
 #include <algorithm>

 namespace tvm {
 namespace contrib {

 using namespace runtime;

 inline int ColumnStride(DLTensor* tensor) {
   // If the tensor itself is transposed then it will have strides
   // backward from what we expect. Regardless, the max of the strides
   // (the other stride is 1) is the column stride.
```

The hunks @@ -42,8 +43,7 @@ and @@ -51,29 +51,26 @@ only reformat the existing 2-D helpers; the final hunk (@@ -96,25 +93,88 @@) reindents the `op(...)` call inside `CallGemm` and appends the new 3-D helpers and `CallBatchGemm`. The affected code now reads:

```cpp
inline int ElementStride(DLTensor* tensor) {
  if (tensor->strides) {
    return std::min(tensor->strides[0], tensor->strides[1]);
  } else {
    // ... (else branch elided by the diff)
  }
}

// Reversed strides indicates an in-place transpose operation.
inline bool IsInPlaceTransposed(DLTensor* tensor) {
  return tensor->strides && (tensor->strides[1] > tensor->strides[0]);
}

inline int RowCount(DLTensor* tensor, bool trans) {
  return tensor->shape[trans ? 1 : 0];
}

inline int ColumnCount(DLTensor* tensor, bool trans) {
  return tensor->shape[trans ? 0 : 1];
}

// Call a column major blas.  Note that data is stored in tvm as row
// major, so this we switch the arguments.
template <typename TGemmOp>
inline void CallGemm(TVMArgs args, TVMRetValue* ret, TGemmOp op) {
  DLTensor* A = args[0];
  DLTensor* B = args[1];
  DLTensor* C = args[2];
  bool transa = args[3];
  bool transb = args[4];
  int bit_depth = sizeof(typename TGemmOp::TDatatype) * 8;
  // ... (shape and dtype checks elided by the diff)
  CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth));
  double alpha = args.size() > 5 ? args[5] : 1.0;
  double beta = args.size() > 6 ? args[6] : 0.0;
  op(transb, transa, ColumnCount(B, transb), RowCount(A, transa),
     ColumnCount(A, transa), static_cast<float>(alpha),
     reinterpret_cast<typename TGemmOp::TDatatype*>(
         static_cast<char*>(B->data) + B->byte_offset),
     ColumnStride(B),
     reinterpret_cast<typename TGemmOp::TDatatype*>(
         static_cast<char*>(A->data) + A->byte_offset),
     ColumnStride(A), static_cast<float>(beta),
     reinterpret_cast<typename TGemmOp::TDatatype*>(
         static_cast<char*>(C->data) + C->byte_offset),
     ColumnStride(C));
}

inline int ColumnStride3D(DLTensor* tensor) {
  // If the tensor itself is transposed then it will have strides
  // backward from what we expect. Regardless, the max of the strides
  // (the other stride is 1) is the column stride.
  if (tensor->strides) {
    return std::max(tensor->strides[1], tensor->strides[2]);
  } else {
    return tensor->shape[2];
  }
}

inline int ElementStride3D(DLTensor* tensor) {
  if (tensor->strides) {
    return std::min(tensor->strides[1], tensor->strides[2]);
  } else {
    return 1;
  }
}

// Reversed strides indicates an in-place transpose operation.
inline bool IsInPlaceTransposed3D(DLTensor* tensor) {
  return tensor->strides && (tensor->strides[2] > tensor->strides[1]);
}

inline int BatchCount3D(DLTensor* tensor) { return tensor->shape[0]; }

inline int RowCount3D(DLTensor* tensor, bool trans) {
  return tensor->shape[trans ? 2 : 1];
}

inline int ColumnCount3D(DLTensor* tensor, bool trans) {
  return tensor->shape[trans ? 1 : 2];
}

template <typename TBatchGemmOp>
inline void CallBatchGemm(TVMArgs args, TVMRetValue* ret, TBatchGemmOp op) {
  using DType = typename TBatchGemmOp::TDatatype;
  DLTensor* A = args[0];
  DLTensor* B = args[1];
  DLTensor* C = args[2];
  bool transa = args[3];
  bool transb = args[4];
  int bit_depth = sizeof(DType) * 8;
  CHECK_EQ(A->ndim, 3);
  CHECK_EQ(B->ndim, 3);
  CHECK_EQ(C->ndim, 3);
  int batch_size = BatchCount3D(A);
  CHECK_EQ(BatchCount3D(B), batch_size);
  CHECK_EQ(BatchCount3D(C), batch_size);
  CHECK_EQ(ElementStride(A), 1);
  CHECK_EQ(ElementStride(B), 1);
  CHECK_EQ(ElementStride(C), 1);
  // C can never be transposed.
  CHECK(!IsInPlaceTransposed3D(C));
  // Reversed strides indicates an in-place transpose operation.
  transa = IsInPlaceTransposed3D(A) ? !transa : transa;
  transb = IsInPlaceTransposed3D(B) ? !transb : transb;
  CHECK(TypeMatch(B->dtype, kDLFloat, bit_depth));
  CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth));
  double alpha = args.size() > 5 ? args[5] : 1.0;
  double beta = args.size() > 6 ? args[6] : 0.0;
  const int A_size = A->shape[1] * A->shape[2];
  const int B_size = B->shape[1] * B->shape[2];
  const int C_size = C->shape[1] * C->shape[2];
  DType* A_data = reinterpret_cast<typename TBatchGemmOp::TDatatype*>(
      static_cast<char*>(A->data) + A->byte_offset);
  DType* B_data = reinterpret_cast<typename TBatchGemmOp::TDatatype*>(
      static_cast<char*>(B->data) + B->byte_offset);
  DType* C_data = reinterpret_cast<typename TBatchGemmOp::TDatatype*>(
      static_cast<char*>(C->data) + C->byte_offset);
  op(batch_size, transb, transa, ColumnCount3D(B, transb),
     RowCount3D(A, transa), ColumnCount3D(A, transa),
     static_cast<float>(alpha), B_data, B_size, ColumnStride3D(B), A_data,
     A_size, ColumnStride3D(A), static_cast<float>(beta), C_data, C_size,
     ColumnStride3D(C));
}

}  // namespace contrib
}  // namespace tvm
#endif  // TVM_CONTRIB_CBLAS_GEMM_COMMON_H_
```
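The operand swap in `CallGemm` and `CallBatchGemm` rests on the identity (A·B)ᵀ = Bᵀ·Aᵀ: a row-major buffer holding C = A·B is bit-for-bit a column-major buffer holding Cᵀ = Bᵀ·Aᵀ, so the column-major BLAS is handed B before A, with the transpose flags and the M/N extents swapped accordingly. A quick NumPy check of the identity (illustrative, not part of the commit):

```python
import numpy as np

# Row-major C = A @ B, seen through column-major eyes, is C^T = B^T @ A^T.
# This is why CallGemm passes B first and swaps transa/transb (and M/N).
a = np.random.uniform(size=(4, 3))
b = np.random.uniform(size=(3, 5))
np.testing.assert_allclose((a @ b).T, b.T @ a.T, rtol=1e-12)
```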
tests/python/contrib/test_cblas.py

The fixed-shape `test_matmul_add` is generalized into a parameterized `verify_matmul_add` covering transposed operands, and a `verify_batch_matmul`/`test_batch_matmul` pair is added for the new op:

```diff
@@ -16,19 +16,26 @@
 # under the License.
 import tvm
 import numpy as np
+import topi.testing
 from tvm.contrib import cblas

-def test_matmul_add():
-    n = 1024
-    l = 128
-    m = 235
-    bias = tvm.var('bias', dtype=tvm.float32)
-    A = tvm.placeholder((n, l), name='A')
-    B = tvm.placeholder((l, m), name='B')
-    C = cblas.matmul(A, B)
+def verify_matmul_add(m, l, n, transa=False, transb=False, dtype=tvm.float32):
+    bias = tvm.var('bias', dtype=dtype)
+    ashape = (l, n) if transa else (n, l)
+    bshape = (m, l) if transb else (l, m)
+    A = tvm.placeholder(ashape, name='A', dtype=dtype)
+    B = tvm.placeholder(bshape, name='B', dtype=dtype)
+    C = cblas.matmul(A, B, transa, transb)
     D = tvm.compute(C.shape, lambda i, j: C[i, j] + bias, name="D")
     s = tvm.create_schedule(D.op)
+
+    def get_numpy(a, b, bb, transa, transb):
+        if transa:
+            a = a.transpose()
+        if transb:
+            b = b.transpose()
+        return np.dot(a, b) + bb
+
     def verify(target="llvm"):
         if not tvm.module.enabled(target):
             print("skip because %s is not enabled..." % target)
@@ -38,15 +45,69 @@ def test_matmul_add():
             return
         ctx = tvm.cpu(0)
         f = tvm.build(s, [A, B, D, bias], target)
-        a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
-        b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
+        a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx)
         d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
         bb = 10.0
         f(a, b, d, bb)
         tvm.testing.assert_allclose(
-            d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + bb, rtol=1e-5)
+            d.asnumpy(),
+            get_numpy(a.asnumpy(), b.asnumpy(), bb, transa, transb),
+            rtol=1e-5)
     verify()

+def test_matmul_add():
+    verify_matmul_add(235, 128, 1024)
+    verify_matmul_add(235, 128, 1024, True, False)
+    verify_matmul_add(235, 128, 1024, False, True)
+    verify_matmul_add(235, 128, 1024, True, True)
+    verify_matmul_add(1, 16, 4)
+    verify_matmul_add(1, 16, 3, True, False)
+    verify_matmul_add(1, 16, 3, False, False)
+    verify_matmul_add(1, 16, 3, True, True)
+
+def verify_batch_matmul(batch, m, l, n, transa=False, transb=False,
+                        iterative=False, dtype=tvm.float32):
+    ashape = (batch, l, n) if transa else (batch, n, l)
+    bshape = (batch, m, l) if transb else (batch, l, m)
+    A = tvm.placeholder(ashape, name='A', dtype=dtype)
+    B = tvm.placeholder(bshape, name='B', dtype=dtype)
+    C = cblas.batch_matmul(A, B, transa, transb)
+    D = tvm.compute(C.shape, lambda k, i, j: C[k, i, j], name="D")
+    s = tvm.create_schedule(D.op)
+
+    def get_numpy(a, b, transa, transb):
+        if transa:
+            a = a.transpose(0, 2, 1)
+        if not transb:
+            b = b.transpose(0, 2, 1)
+        return topi.testing.batch_matmul(a, b)
+
+    def verify(target="llvm"):
+        if not tvm.module.enabled(target):
+            print("skip because %s is not enabled..." % target)
+            return
+        if not tvm.get_global_func("tvm.contrib.cblas.matmul", True):
+            print("skip because extern function is not available")
+            return
+        ctx = tvm.cpu(0)
+        f = tvm.build(s, [A, B, D], target)
+        a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx)
+        d = tvm.nd.array(np.zeros((batch, n, m), dtype=D.dtype), ctx)
+        f(a, b, d)
+        tvm.testing.assert_allclose(
+            d.asnumpy(),
+            get_numpy(a.asnumpy(), b.asnumpy(), transa, transb),
+            rtol=1e-5)
+    verify()
+
+def test_batch_matmul():
+    verify_batch_matmul(16, 235, 128, 1024)
+    verify_batch_matmul(16, 235, 128, 1024, True, False)
+    verify_batch_matmul(16, 235, 128, 1024, False, True)
+    verify_batch_matmul(16, 235, 128, 1024, True, True)
+    verify_batch_matmul(1, 1, 16, 3)
+    verify_batch_matmul(1, 1, 16, 3, True, False)
+    verify_batch_matmul(1, 1, 16, 3, False, False)
+    verify_batch_matmul(1, 1, 16, 3, True, True)
+    verify_batch_matmul(1, 1, 16, 3, iterative=True)
+
 if __name__ == "__main__":
     test_matmul_add()
+    test_batch_matmul()
```