Commit b14bb7f9, authored Sep 27, 2018 by Sergey Mironov, committed by Tianqi Chen, Sep 26, 2018
[TOPI] Access topi::matmul from Python (#1744)
parent be77cf19
Showing 10 changed files with 126 additions and 59 deletions (+126 / -59):
nnvm/src/top/tensor/matrix_op.cc       +1   -1
python/tvm/api.py                      +6   -6
python/tvm/tag.py                      +16  -5
topi/include/topi/nn.h                 +0   -31
topi/include/topi/transform.h          +31  -0
topi/python/topi/reduction.py          +0   -8
topi/python/topi/tensor.py             +0   -5
topi/python/topi/transform.py          +19  -3
topi/src/topi.cc                       +9   -0
topi/tests/python/test_topi_matmul.py  +44  -0
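
The net effect of this commit is that topi::matmul, previously reachable only from C++, can be called as topi.matmul from Python through the new "topi.matmul" FFI registration. A minimal usage sketch follows; the shapes, schedule, and target are illustrative assumptions, not part of the commit:

    import numpy as np
    import tvm
    import topi

    # Symbolic inputs; (3, 5) x (5, 2) is an arbitrary example shape.
    A = tvm.placeholder((3, 5), name='A')
    B = tvm.placeholder((5, 2), name='B')

    # Forwards to the C++ topi::matmul through topi.cpp.matmul.
    C = topi.matmul(A, B, False, False)

    # Build with a default schedule and run on CPU.
    s = tvm.create_schedule([C.op])
    f = tvm.build(s, [A, B, C], "llvm")
    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.rand(3, 5).astype(np.float32), ctx)
    b = tvm.nd.array(np.random.rand(5, 2).astype(np.float32), ctx)
    c = tvm.nd.array(np.zeros((3, 2), dtype=np.float32), ctx)
    f(a, b, c)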
nnvm/src/top/tensor/matrix_op.cc
@@ -3,7 +3,7 @@
  * \file matrix_op.cc
  * \brief Matrix operators
  */
-#include <topi/nn.h>
+#include <topi/transform.h>
 #include <nnvm/op.h>
 #include <nnvm/node.h>
 #include <nnvm/op_attr_types.h>
...
python/tvm/api.py
@@ -238,10 +238,10 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
     tensor: Tensor
         The created tensor
     """
-    if _tag.TagScope.current is not None:
+    if _tag.TagScope.get_current() is not None:
         if tag != "":
             raise ValueError("nested tag is not allowed for now")
-        tag = _tag.TagScope.current.tag
+        tag = _tag.TagScope.get_current().tag
     shape = (shape,) if isinstance(shape, _expr.Expr) else shape
     ndim = len(shape)
     code = fcompute.__code__
...
@@ -311,10 +311,10 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attr
         s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
         res = tvm.scan(s_init, s_update, s_state, X)
     """
-    if _tag.TagScope.current is not None:
+    if _tag.TagScope.get_current() is not None:
         if tag != "":
             raise ValueError("nested tag is not allowed for now")
-        tag = _tag.TagScope.current.tag
+        tag = _tag.TagScope.get_current().tag
     if isinstance(init, _tensor.Tensor):
         init = [init]
     if isinstance(update, _tensor.Tensor):
...
@@ -407,10 +407,10 @@ def extern(shape,
                 "tvm.contrib.cblas.matmul",
                 ins[0], ins[1], outs[0], 0, 0), name="C")
     """
-    if _tag.TagScope.current is not None:
+    if _tag.TagScope.get_current() is not None:
         if tag != "":
             raise ValueError("nested tag is not allowed for now")
-        tag = _tag.TagScope.current.tag
+        tag = _tag.TagScope.get_current().tag
     shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
     shape = [shape] if isinstance(shape[0], (_expr.Expr, _Integral)) else shape
     if in_buffers is not None:
...
python/tvm/tag.py
"""Tag class for TVM operators."""
"""Tag class for TVM operators."""
import
warnings
from
._ffi.base
import
decorate
from
._ffi.base
import
decorate
class
TagScope
(
object
):
class
TagScope
(
object
):
"""Tag scope object to set tag for operators, working as context
"""Tag scope object to set tag for operators, working as context
manager and decorator both. See also tag_scope.
manager and decorator both. See also tag_scope.
"""
"""
current
=
None
_current
=
None
@classmethod
def
get_current
(
cls
):
if
cls
.
_current
:
cls
.
_current
.
accessed
=
True
return
cls
.
_current
def
__init__
(
self
,
tag
):
def
__init__
(
self
,
tag
):
self
.
_old_scope
=
None
self
.
_old_scope
=
None
self
.
tag
=
tag
self
.
tag
=
tag
self
.
accessed
=
False
def
__enter__
(
self
):
def
__enter__
(
self
):
if
TagScope
.
current
is
not
None
:
if
TagScope
.
_
current
is
not
None
:
raise
ValueError
(
"nested op_tag is not allowed for now"
)
raise
ValueError
(
"nested op_tag is not allowed for now"
)
self
.
_old_scope
=
TagScope
.
current
self
.
_old_scope
=
TagScope
.
_
current
TagScope
.
current
=
self
TagScope
.
_
current
=
self
return
self
return
self
def
__exit__
(
self
,
ptype
,
value
,
trace
):
def
__exit__
(
self
,
ptype
,
value
,
trace
):
assert
self
.
_old_scope
is
None
assert
self
.
_old_scope
is
None
TagScope
.
current
=
self
.
_old_scope
if
not
self
.
accessed
:
warnings
.
warn
(
"Tag '
%
s' declared via TagScope was not used."
%
(
self
.
tag
,))
TagScope
.
_current
=
self
.
_old_scope
def
__call__
(
self
,
fdecl
):
def
__call__
(
self
,
fdecl
):
def
tagged_fdecl
(
func
,
*
args
,
**
kwargs
):
def
tagged_fdecl
(
func
,
*
args
,
**
kwargs
):
...
...
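
The refactoring above routes every read of the active tag through TagScope.get_current(), which marks the scope as accessed; a scope that is never consumed now emits a warning on exit. A small sketch of the resulting behavior, assuming the tvm.tag_scope entry point of this era (names and shapes are illustrative):

    import tvm

    # Consumed scope: tvm.compute reads the tag via TagScope.get_current(),
    # marking it accessed, so no warning fires on exit.
    with tvm.tag_scope(tag="custom_tag"):
        A = tvm.placeholder((16,), name='A')
        B = tvm.compute((16,), lambda i: A[i] * 2, name='B')
    assert B.op.tag == "custom_tag"

    # Unused scope: nothing reads the tag, so __exit__ warns that
    # "Tag 'unused' declared via TagScope was not used."
    with tvm.tag_scope(tag="unused"):
        pass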
topi/include/topi/nn.h
@@ -201,37 +201,6 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
 }

-/*!
- * \brief Creates an operation that calculates a matrix multiplication
- *  (row-major notation):
- *    A(i, k) * B(k, j), if trans_a == trans_b
- *    the usual transposed combinations, otherwise
- *
- * \param A The matrix A
- * \param B The matrix B
- * \param trans_a Is A's layout transposed?
- * \param trans_b Is B's layout transposed?
- * \param name The name of the operation
- * \param tag The tag to mark the operation
- *
- * \return A Tensor whose op member is the matmul operation
- */
-inline tvm::Tensor matmul(const tvm::Tensor& A,
-                          const tvm::Tensor& B,
-                          bool trans_a = false,
-                          bool trans_b = false,
-                          std::string name = "tensor",
-                          std::string tag = kMatMul) {
-  tvm::Array<tvm::Expr> output_shape{A->shape[trans_a ? 1 : 0],
-                                     B->shape[trans_b ? 0 : 1]};
-  auto k = tvm::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
-  auto l = [&](tvm::Var i, tvm::Var j) {
-    return tvm::sum((trans_a ? A[k][i] : A[i][k]) *
-                    (trans_b ? B[j][k] : B[k][j]),
-                    {k});
-  };
-  return tvm::compute(output_shape, l, name, tag);
-}
-
 /*!
  * \brief Creates an operation that performs a 2-D convolution with an
  *        NCHW-layout
  *
...
topi/include/topi/transform.h
@@ -627,6 +627,37 @@ inline Tensor where(const Tensor& condition,
   return out;
 }

+/*!
+ * \brief Creates an operation that calculates a matrix multiplication
+ *  (row-major notation):
+ *    A(i, k) * B(k, j), if trans_a == trans_b
+ *    the usual transposed combinations, otherwise
+ *
+ * \param A The matrix A
+ * \param B The matrix B
+ * \param trans_a Is A's layout transposed?
+ * \param trans_b Is B's layout transposed?
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the matmul operation
+ */
+inline tvm::Tensor matmul(const tvm::Tensor& A,
+                          const tvm::Tensor& B,
+                          bool trans_a = false,
+                          bool trans_b = false,
+                          std::string name = "tensor",
+                          std::string tag = kMatMul) {
+  tvm::Array<tvm::Expr> output_shape{A->shape[trans_a ? 1 : 0],
+                                     B->shape[trans_b ? 0 : 1]};
+  auto k = tvm::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
+  auto l = [&](tvm::Var i, tvm::Var j) {
+    return tvm::sum((trans_a ? A[k][i] : A[i][k]) *
+                    (trans_b ? B[j][k] : B[k][j]),
+                    {k});
+  };
+  return tvm::compute(output_shape, l, name, tag);
+}
+
 }  // namespace topi
 #endif  // TOPI_TRANSFORM_H_
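
For reference, the transpose flags in the moved kernel only change which operand index lands on the reduction axis k; in NumPy terms the compute rule is equivalent to the following sketch (not part of the commit, included to make the lambda's indexing explicit):

    import numpy as np

    def matmul_reference(A, B, trans_a=False, trans_b=False):
        # Mirrors the C++ lambda: sum over k of
        # (A[k, i] if trans_a else A[i, k]) * (B[j, k] if trans_b else B[k, j]).
        return np.matmul(A.T if trans_a else A, B.T if trans_b else B)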
topi/python/topi/reduction.py
 # pylint: disable=redefined-builtin,consider-using-enumerate,no-member
 """Reduce operators"""
 from __future__ import absolute_import as _abs
-import tvm
 from . import cpp
-from . import tag

 def _get_real_axis(ndim, axis):
     if axis is None:
...
@@ -26,7 +24,6 @@ def _get_real_axis(ndim, axis):
     return real_axis

-@tvm.tag_scope(tag=tag.COMM_REDUCE)
 def sum(data, axis=None, keepdims=False):
     """Sum of array elements over a given axis or a list of axes
...
@@ -52,7 +49,6 @@ def sum(data, axis=None, keepdims=False):
     return cpp.sum(data, axis, keepdims)

-@tvm.tag_scope(tag=tag.COMM_REDUCE)
 def max(data, axis=None, keepdims=False):
     """Maximum of array elements over a given axis or a list of axes
...
@@ -78,7 +74,6 @@ def max(data, axis=None, keepdims=False):
     return cpp.max(data, axis, keepdims)

-@tvm.tag_scope(tag=tag.COMM_REDUCE)
 def min(data, axis=None, keepdims=False):
     """Minimum of array elements over a given axis or a list of axes
...
@@ -104,7 +99,6 @@ def min(data, axis=None, keepdims=False):
     return cpp.min(data, axis, keepdims)

-@tvm.tag_scope(tag=tag.COMM_REDUCE_IDX)
 def argmax(data, axis=None, keepdims=False):
     """Returns the indices of the maximum values along an axis.
...
@@ -130,7 +124,6 @@ def argmax(data, axis=None, keepdims=False):
     return cpp.argmax(data, axis, keepdims)

-@tvm.tag_scope(tag=tag.COMM_REDUCE_IDX)
 def argmin(data, axis=None, keepdims=False):
     """Returns the indices of the minimum values along an axis.
...
@@ -156,7 +149,6 @@ def argmin(data, axis=None, keepdims=False):
     return cpp.argmin(data, axis, keepdims)

-@tvm.tag_scope(tag=tag.COMM_REDUCE)
 def prod(data, axis=None, keepdims=False):
     """Product of array elements over a given axis or a list of axes
...
topi/python/topi/tensor.py
 # pylint: disable=invalid-name,consider-using-enumerate,unused-argument,len-as-condition
 """Elementwise operators"""
 from __future__ import absolute_import as _abs
-import tvm
 from . import cpp
-from . import tag

-@tvm.tag_scope(tag=tag.ELEMWISE)
 def elemwise_sum(xs):
     """Perform element-wise sum on inputs
...
@@ -22,7 +19,6 @@ def elemwise_sum(xs):
     return cpp.elemwise_sum(xs)

-@tvm.tag_scope(tag=tag.ELEMWISE)
 def full(shape, dtype, fill_value):
     """Fill tensor with fill_value
...
@@ -43,7 +39,6 @@ def full(shape, dtype, fill_value):
     return cpp.full(shape, dtype, fill_value)

-@tvm.tag_scope(tag=tag.ELEMWISE)
 def full_like(x, fill_value):
     """Construct a tensor with same shape as input tensor,
     then fill tensor with fill_value.
...
topi/python/topi/transform.py
@@ -111,7 +111,6 @@ def transpose(a, axes=None):
             return a(*idx)
     return tvm.compute(new_shape, _compute)

-@tvm.tag_scope(tag=tag.INJECTIVE)
 def flip(a, axis=0):
     """Flip/reverse elements of an array in a particular axis.
...
@@ -129,7 +128,6 @@ def flip(a, axis=0):
     """
     return cpp.flip(a, axis)

-@tvm.tag_scope(tag=tag.INJECTIVE)
 def strided_slice(a, begin, end, strides=None):
     """Slice of an array.
...
@@ -315,7 +313,6 @@ def split(ary, indices_or_sections, axis=0):
 # pylint: enable=cell-var-from-loop

-@tvm.tag_scope(tag=tag.INJECTIVE)
 def take(a, indices, axis=None):
     """Take elements from an array along an axis.
...
@@ -338,3 +335,22 @@ def take(a, indices, axis=None):
     if axis is None:
         return cpp.take(a, indices)
     return cpp.take(a, indices, int(axis))
+
+def matmul(a, b, transp_a=False, transp_b=False):
+    """
+    Creates an operation that calculates a matrix multiplication (row-major notation):
+        A(i, k) * B(k, j)
+    if trans_a == trans_b, the usual transposed combinations, otherwise
+
+    Parameters
+    ----------
+    a : The matrix A
+    b : The matrix B
+    trans_a : Is A's layout transposed?
+    trans_b : Is B's layout transposed?
+
+    Returns
+    -------
+    A Tensor whose op member is the matmul operation
+    """
+    return cpp.matmul(a, b, transp_a, transp_b)
topi/src/topi.cc
@@ -292,6 +292,15 @@ TVM_REGISTER_GLOBAL("topi.where")
   *rv = where(args[0], args[1], args[2]);
   });

+TVM_REGISTER_GLOBAL("topi.matmul")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  switch ( args.size() ) {
+    case 2: *rv = matmul(args[0], args[1]); break;
+    case 3: *rv = matmul(args[0], args[1], args[2]); break;
+    case 4: *rv = matmul(args[0], args[1], args[2], args[3]); break;
+    default: CHECK(0) << "topi.matmul expects 2, 3 or 4 arguments";
+  }});
+
 TVM_REGISTER_GLOBAL("topi.strided_slice")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
   *rv = strided_slice(args[0], args[1], args[2], args[3]);
...
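
The switch on args.size() lets the Python wrapper keep transp_a and transp_b optional across the FFI boundary. The registered packed function can also be reached directly, bypassing the topi.matmul wrapper; a sketch assuming TVM's generic tvm.get_global_func lookup (shapes are illustrative):

    import tvm

    # Lookup of the entry created by TVM_REGISTER_GLOBAL("topi.matmul").
    fmatmul = tvm.get_global_func("topi.matmul")

    A = tvm.placeholder((3, 5), name='A')
    B = tvm.placeholder((5, 2), name='B')

    C = fmatmul(A, B)                  # hits 'case 2'
    C_t = fmatmul(A, B, False, False)  # hits 'case 4'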
topi/tests/python/test_topi_matmul.py
new file (mode 100644)
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple

def with_tvm(lam, *args):
    """ Take numpy arrays as args, convert them to TVM tensors and call `lam`.
    Result of lambda is converted back to numpy array and returned.
    """
    ctx = tvm.cpu(0)
    pls = []      # placeholders
    vals_nd = []  # initial values
    for i, arg in enumerate(args):
        pls.append(tvm.placeholder(arg.shape, name='pl'+str(i)))
        vals_nd.append(tvm.nd.array(arg, ctx))

    out = lam(*pls)
    out_nd = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out.dtype), ctx)
    s = tvm.create_schedule([out.op])
    m = tvm.build(s, pls + [out], "llvm")
    m(*(vals_nd + [out_nd]))
    return out_nd.asnumpy()

def verify_matmul(sa, sb, transp_a, transp_b):
    a = np.random.uniform(low=-1.0, high=1.0, size=sa).astype(np.float32)
    b = np.random.uniform(low=-1.0, high=1.0, size=sb).astype(np.float32)
    c1 = np.matmul(np.transpose(a) if transp_a else a,
                   np.transpose(b) if transp_b else b)
    c2 = with_tvm(lambda A, B: topi.matmul(A, B, transp_a, transp_b), a, b)
    np.testing.assert_allclose(c1, c2, rtol=1e-5)

def test_matmul():
    verify_matmul((1, 1), (1, 1), False, False)
    verify_matmul((1, 1), (1, 1), True, True)
    verify_matmul((2, 2), (2, 2), False, False)
    verify_matmul((2, 2), (2, 2), True, True)
    verify_matmul((2, 3), (3, 5), False, False)
    verify_matmul((5, 3), (3, 2), False, False)
    verify_matmul((3, 5), (3, 2), True, False)
    verify_matmul((3, 5), (2, 3), True, True)

if __name__ == "__main__":
    test_matmul()