Commit b14bb7f9 by Sergey Mironov, committed by Tianqi Chen

[TOPI] Access topi::matmul from Python (#1744)

parent be77cf19
@@ -3,7 +3,7 @@
  * \file matrix_op.cc
  * \brief Matrix operators
  */
-#include <topi/nn.h>
+#include <topi/transform.h>
 #include <nnvm/op.h>
 #include <nnvm/node.h>
 #include <nnvm/op_attr_types.h>
......
@@ -238,10 +238,10 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
     tensor: Tensor
         The created tensor
     """
-    if _tag.TagScope.current is not None:
+    if _tag.TagScope.get_current() is not None:
         if tag != "":
             raise ValueError("nested tag is not allowed for now")
-        tag = _tag.TagScope.current.tag
+        tag = _tag.TagScope.get_current().tag
     shape = (shape,) if isinstance(shape, _expr.Expr) else shape
     ndim = len(shape)
     code = fcompute.__code__
@@ -311,10 +311,10 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attrs=None):
       s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
       res = tvm.scan(s_init, s_update, s_state, X)
     """
-    if _tag.TagScope.current is not None:
+    if _tag.TagScope.get_current() is not None:
         if tag != "":
             raise ValueError("nested tag is not allowed for now")
-        tag = _tag.TagScope.current.tag
+        tag = _tag.TagScope.get_current().tag
     if isinstance(init, _tensor.Tensor):
         init = [init]
     if isinstance(update, _tensor.Tensor):
@@ -407,10 +407,10 @@ def extern(shape,
                 "tvm.contrib.cblas.matmul",
                 ins[0], ins[1], outs[0], 0, 0), name="C")
     """
-    if _tag.TagScope.current is not None:
+    if _tag.TagScope.get_current() is not None:
         if tag != "":
             raise ValueError("nested tag is not allowed for now")
-        tag = _tag.TagScope.current.tag
+        tag = _tag.TagScope.get_current().tag
     shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
     shape = [shape] if isinstance(shape[0], (_expr.Expr, _Integral)) else shape
     if in_buffers is not None:
......
"""Tag class for TVM operators.""" """Tag class for TVM operators."""
import warnings
from ._ffi.base import decorate from ._ffi.base import decorate
class TagScope(object): class TagScope(object):
"""Tag scope object to set tag for operators, working as context """Tag scope object to set tag for operators, working as context
manager and decorator both. See also tag_scope. manager and decorator both. See also tag_scope.
""" """
current = None _current = None
@classmethod
def get_current(cls):
if cls._current:
cls._current.accessed = True
return cls._current
def __init__(self, tag): def __init__(self, tag):
self._old_scope = None self._old_scope = None
self.tag = tag self.tag = tag
self.accessed = False
def __enter__(self): def __enter__(self):
if TagScope.current is not None: if TagScope._current is not None:
raise ValueError("nested op_tag is not allowed for now") raise ValueError("nested op_tag is not allowed for now")
self._old_scope = TagScope.current self._old_scope = TagScope._current
TagScope.current = self TagScope._current = self
return self return self
def __exit__(self, ptype, value, trace): def __exit__(self, ptype, value, trace):
assert self._old_scope is None assert self._old_scope is None
TagScope.current = self._old_scope if not self.accessed:
warnings.warn("Tag '%s' declared via TagScope was not used." % (self.tag,))
TagScope._current = self._old_scope
def __call__(self, fdecl): def __call__(self, fdecl):
def tagged_fdecl(func, *args, **kwargs): def tagged_fdecl(func, *args, **kwargs):
......
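The reworked TagScope above keeps its dual context-manager/decorator role; what is new is that get_current() marks the scope as accessed, so a tag that is never consumed now triggers a warning on exit. A minimal sketch of both behaviors, assuming the tvm.tag_scope and tvm.compute APIs of this TVM generation (names n, A, B are illustrative):

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")

    # compute() reads the tag through TagScope.get_current(),
    # which flips the accessed flag.
    with tvm.tag_scope(tag="my_elemwise"):
        B = tvm.compute((n,), lambda i: A[i] * 2, name="B")
    assert B.op.tag == "my_elemwise"

    # A scope whose tag is never read now emits a UserWarning on exit.
    with tvm.tag_scope(tag="unused_tag"):
        pass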
@@ -201,37 +201,6 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
 }

 /*!
- * \brief Creates an operation that calculates a matrix multiplication
- *  (row-major notation):
- *    A(i, k) * B(k, j), if trans_a == trans_b
- *    the usual transposed combinations, otherwise
- *
- * \param A The matrix A
- * \param B The matrix B
- * \param trans_a Is A's layout transposed?
- * \param trans_b Is B's layout transposed?
- * \param name The name of the operation
- * \param tag The tag to mark the operation
- *
- * \return A Tensor whose op member is the matmul operation
- */
-inline tvm::Tensor matmul(const tvm::Tensor& A,
-                          const tvm::Tensor& B,
-                          bool trans_a = false,
-                          bool trans_b = false,
-                          std::string name = "tensor",
-                          std::string tag = kMatMul) {
-  tvm::Array<tvm::Expr> output_shape{A->shape[trans_a ? 1 : 0],
-                                     B->shape[trans_b ? 0 : 1]};
-  auto k = tvm::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
-  auto l = [&](tvm::Var i, tvm::Var j) {
-    return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]),
-                    {k});
-  };
-  return tvm::compute(output_shape, l, name, tag);
-}
-
-/*!
  * \brief Creates an operation that performs a 2-D convolution with an
  * NCHW-layout
  *
......
@@ -627,6 +627,37 @@ inline Tensor where(const Tensor& condition,
   return out;
 }

+/*!
+ * \brief Creates an operation that calculates a matrix multiplication
+ *  (row-major notation):
+ *    A(i, k) * B(k, j), if trans_a == trans_b
+ *    the usual transposed combinations, otherwise
+ *
+ * \param A The matrix A
+ * \param B The matrix B
+ * \param trans_a Is A's layout transposed?
+ * \param trans_b Is B's layout transposed?
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the matmul operation
+ */
+inline tvm::Tensor matmul(const tvm::Tensor& A,
+                          const tvm::Tensor& B,
+                          bool trans_a = false,
+                          bool trans_b = false,
+                          std::string name = "tensor",
+                          std::string tag = kMatMul) {
+  tvm::Array<tvm::Expr> output_shape{A->shape[trans_a ? 1 : 0],
+                                     B->shape[trans_b ? 0 : 1]};
+  auto k = tvm::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
+  auto l = [&](tvm::Var i, tvm::Var j) {
+    return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]),
+                    {k});
+  };
+  return tvm::compute(output_shape, l, name, tag);
+}
+
 }  // namespace topi
 #endif  // TOPI_TRANSFORM_H_
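The function body is TVM's standard compute-plus-reduction pattern. A Python sketch of the non-transposed case, equivalent to what the C++ lambda builds (shapes are illustrative; the tag string is assumed to match topi's kMatMul constant):

    import tvm

    A = tvm.placeholder((4, 6), name="A")
    B = tvm.placeholder((6, 5), name="B")
    k = tvm.reduce_axis((0, 6), name="k")  # reduction over the shared dimension
    C = tvm.compute((4, 5),
                    lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k),
                    name="C", tag="matmul")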
 # pylint: disable=redefined-builtin,consider-using-enumerate,no-member
 """Reduce operators"""
 from __future__ import absolute_import as _abs
-import tvm
 from . import cpp
-from . import tag


 def _get_real_axis(ndim, axis):
     if axis is None:
@@ -26,7 +24,6 @@ def _get_real_axis(ndim, axis):
     return real_axis


-@tvm.tag_scope(tag=tag.COMM_REDUCE)
 def sum(data, axis=None, keepdims=False):
     """Sum of array elements over a given axis or a list of axes
@@ -52,7 +49,6 @@ def sum(data, axis=None, keepdims=False):
     return cpp.sum(data, axis, keepdims)


-@tvm.tag_scope(tag=tag.COMM_REDUCE)
 def max(data, axis=None, keepdims=False):
     """Maximum of array elements over a given axis or a list of axes
@@ -78,7 +74,6 @@ def max(data, axis=None, keepdims=False):
     return cpp.max(data, axis, keepdims)


-@tvm.tag_scope(tag=tag.COMM_REDUCE)
 def min(data, axis=None, keepdims=False):
     """Minimum of array elements over a given axis or a list of axes
@@ -104,7 +99,6 @@ def min(data, axis=None, keepdims=False):
     return cpp.min(data, axis, keepdims)


-@tvm.tag_scope(tag=tag.COMM_REDUCE_IDX)
 def argmax(data, axis=None, keepdims=False):
     """Returns the indices of the maximum values along an axis.
@@ -130,7 +124,6 @@ def argmax(data, axis=None, keepdims=False):
     return cpp.argmax(data, axis, keepdims)


-@tvm.tag_scope(tag=tag.COMM_REDUCE_IDX)
 def argmin(data, axis=None, keepdims=False):
     """Returns the indices of the minimum values along an axis.
@@ -156,7 +149,6 @@ def argmin(data, axis=None, keepdims=False):
     return cpp.argmin(data, axis, keepdims)


-@tvm.tag_scope(tag=tag.COMM_REDUCE)
 def prod(data, axis=None, keepdims=False):
     """Product of array elements over a given axis or a list of axes
......
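With the Python-side @tvm.tag_scope decorators gone, the operator tags now come from the C++ implementations behind cpp.sum and friends, so the `import tvm` and `from . import tag` lines are no longer needed. A quick sanity sketch (the tag value is an assumption, matching topi's kCommReduce constant):

    import tvm
    import topi

    A = tvm.placeholder((4, 5), name="A")
    S = topi.sum(A, axis=1)
    print(S.op.tag)  # expected: "comm_reduce", now set on the C++ side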
 # pylint: disable=invalid-name,consider-using-enumerate,unused-argument,len-as-condition
 """Elementwise operators"""
 from __future__ import absolute_import as _abs
-import tvm
 from . import cpp
-from . import tag


-@tvm.tag_scope(tag=tag.ELEMWISE)
 def elemwise_sum(xs):
     """Perform element-wise sum on inputs
@@ -22,7 +19,6 @@ def elemwise_sum(xs):
     return cpp.elemwise_sum(xs)


-@tvm.tag_scope(tag=tag.ELEMWISE)
 def full(shape, dtype, fill_value):
     """Fill tensor with fill_value
@@ -43,7 +39,6 @@ def full(shape, dtype, fill_value):
     return cpp.full(shape, dtype, fill_value)


-@tvm.tag_scope(tag=tag.ELEMWISE)
 def full_like(x, fill_value):
     """Construct a tensor with same shape as input tensor,
     then fill tensor with fill_value.
......
@@ -111,7 +111,6 @@ def transpose(a, axes=None):
         return a(*idx)
     return tvm.compute(new_shape, _compute)


-@tvm.tag_scope(tag=tag.INJECTIVE)
 def flip(a, axis=0):
     """Flip/reverse elements of an array in a particular axis.
@@ -129,7 +128,6 @@ def flip(a, axis=0):
     """
     return cpp.flip(a, axis)


-@tvm.tag_scope(tag=tag.INJECTIVE)
 def strided_slice(a, begin, end, strides=None):
     """Slice of an array.
@@ -315,7 +313,6 @@ def split(ary, indices_or_sections, axis=0):
     # pylint: enable=cell-var-from-loop


-@tvm.tag_scope(tag=tag.INJECTIVE)
 def take(a, indices, axis=None):
     """Take elements from an array along an axis.
@@ -338,3 +335,22 @@ def take(a, indices, axis=None):
     if axis is None:
         return cpp.take(a, indices)
     return cpp.take(a, indices, int(axis))
+
+
+def matmul(a, b, transp_a=False, transp_b=False):
+    """
+    Creates an operation that calculates a matrix multiplication (row-major
+    notation): C(i, j) = sum_k A(i, k) * B(k, j) when both flags are False;
+    the transposed layout of A and/or B is used for each flag that is set.
+
+    Parameters
+    ----------
+    a : The matrix A
+    b : The matrix B
+    transp_a : Is A's layout transposed?
+    transp_b : Is B's layout transposed?
+
+    Returns
+    -------
+    A Tensor whose op member is the matmul operation
+    """
+    return cpp.matmul(a, b, transp_a, transp_b)
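A short usage sketch for the new wrapper (shapes are illustrative; transp_a=True multiplies the transpose of A, mirroring the (3, 5) x (3, 2) case exercised by the tests below):

    import tvm
    import topi

    A = tvm.placeholder((3, 5), name="A")
    B = tvm.placeholder((3, 2), name="B")
    C = topi.matmul(A, B, transp_a=True)  # A^T @ B, result shape (5, 2)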
@@ -292,6 +292,15 @@ TVM_REGISTER_GLOBAL("topi.where")
   *rv = where(args[0], args[1], args[2]);
   });

+TVM_REGISTER_GLOBAL("topi.matmul")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  switch (args.size()) {
+    case 2: *rv = matmul(args[0], args[1]); break;
+    case 3: *rv = matmul(args[0], args[1], args[2]); break;
+    case 4: *rv = matmul(args[0], args[1], args[2], args[3]); break;
+    default: CHECK(0) << "topi.matmul expects 2, 3 or 4 arguments";
+  }});
+
 TVM_REGISTER_GLOBAL("topi.strided_slice")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
   *rv = strided_slice(args[0], args[1], args[2], args[3]);
......
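The registration exposes the C++ matmul through TVM's global PackedFunc registry, which is what cpp.matmul resolves to on the Python side. The function can also be looked up directly; a sketch using tvm.get_global_func, the standard registry lookup:

    import tvm

    f = tvm.get_global_func("topi.matmul")
    A = tvm.placeholder((4, 4), name="A")
    B = tvm.placeholder((4, 4), name="B")
    C = f(A, B)  # two-argument form: both transpose flags default to false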
+import numpy as np
+import tvm
+import topi
+from topi.util import get_const_tuple
+
+
+def with_tvm(lam, *args):
+    """Take numpy arrays as args, convert them to TVM tensors and call `lam`.
+    Result of lambda is converted back to numpy array and returned.
+    """
+    ctx = tvm.cpu(0)
+    pls = []      # placeholders
+    vals_nd = []  # initial values
+    for i, arg in enumerate(args):
+        pls.append(tvm.placeholder(arg.shape, name='pl' + str(i)))
+        vals_nd.append(tvm.nd.array(arg, ctx))
+    out = lam(*pls)
+    out_nd = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out.dtype), ctx)
+    s = tvm.create_schedule([out.op])
+    m = tvm.build(s, pls + [out], "llvm")
+    m(*(vals_nd + [out_nd]))
+    return out_nd.asnumpy()
+
+
+def verify_matmul(sa, sb, transp_a, transp_b):
+    a = np.random.uniform(low=-1.0, high=1.0, size=sa).astype(np.float32)
+    b = np.random.uniform(low=-1.0, high=1.0, size=sb).astype(np.float32)
+    c1 = np.matmul(np.transpose(a) if transp_a else a,
+                   np.transpose(b) if transp_b else b)
+    c2 = with_tvm(lambda A, B: topi.matmul(A, B, transp_a, transp_b), a, b)
+    np.testing.assert_allclose(c1, c2, rtol=1e-5)
+
+
+def test_matmul():
+    verify_matmul((1, 1), (1, 1), False, False)
+    verify_matmul((1, 1), (1, 1), True, True)
+    verify_matmul((2, 2), (2, 2), False, False)
+    verify_matmul((2, 2), (2, 2), True, True)
+    verify_matmul((2, 3), (3, 5), False, False)
+    verify_matmul((5, 3), (3, 2), False, False)
+    verify_matmul((3, 5), (3, 2), True, False)
+    verify_matmul((3, 5), (2, 3), True, True)
+
+
+if __name__ == "__main__":
+    test_matmul()