Commit b14bb7f9 by Sergey Mironov, committed by Tianqi Chen

[TOPI] Access topi::matmul from Python (#1744)

parent be77cf19
......@@ -3,7 +3,7 @@
* \file matrix_op.cc
* \brief Matrix operators
*/
- #include <topi/nn.h>
+ #include <topi/transform.h>
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
......
......@@ -238,10 +238,10 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
tensor: Tensor
The created tensor
"""
- if _tag.TagScope.current is not None:
+ if _tag.TagScope.get_current() is not None:
if tag != "":
raise ValueError("nested tag is not allowed for now")
- tag = _tag.TagScope.current.tag
+ tag = _tag.TagScope.get_current().tag
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
ndim = len(shape)
code = fcompute.__code__
......@@ -311,10 +311,10 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attr
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
res = tvm.scan(s_init, s_update, s_state, X)
"""
- if _tag.TagScope.current is not None:
+ if _tag.TagScope.get_current() is not None:
if tag != "":
raise ValueError("nested tag is not allowed for now")
- tag = _tag.TagScope.current.tag
+ tag = _tag.TagScope.get_current().tag
if isinstance(init, _tensor.Tensor):
init = [init]
if isinstance(update, _tensor.Tensor):
......@@ -407,10 +407,10 @@ def extern(shape,
"tvm.contrib.cblas.matmul",
ins[0], ins[1], outs[0], 0, 0), name="C")
"""
- if _tag.TagScope.current is not None:
+ if _tag.TagScope.get_current() is not None:
if tag != "":
raise ValueError("nested tag is not allowed for now")
- tag = _tag.TagScope.current.tag
+ tag = _tag.TagScope.get_current().tag
shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
shape = [shape] if isinstance(shape[0], (_expr.Expr, _Integral)) else shape
if in_buffers is not None:
......
"""Tag class for TVM operators."""
+ import warnings
from ._ffi.base import decorate
class TagScope(object):
"""Tag scope object to set tag for operators, working as context
manager and decorator both. See also tag_scope.
"""
- current = None
+ _current = None
+ @classmethod
+ def get_current(cls):
+     if cls._current:
+         cls._current.accessed = True
+     return cls._current
def __init__(self, tag):
self._old_scope = None
self.tag = tag
+ self.accessed = False
def __enter__(self):
- if TagScope.current is not None:
+ if TagScope._current is not None:
raise ValueError("nested op_tag is not allowed for now")
- self._old_scope = TagScope.current
- TagScope.current = self
+ self._old_scope = TagScope._current
+ TagScope._current = self
return self
def __exit__(self, ptype, value, trace):
assert self._old_scope is None
- TagScope.current = self._old_scope
+ if not self.accessed:
+     warnings.warn("Tag '%s' declared via TagScope was not used." % (self.tag,))
+ TagScope._current = self._old_scope
def __call__(self, fdecl):
def tagged_fdecl(func, *args, **kwargs):
......
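The TagScope docstring above says the scope works both as a context manager and as a decorator, via the tag_scope wrapper that the TOPI files later in this diff use as @tvm.tag_scope(...). A minimal usage sketch, not part of the commit; the function name and tag strings are made up for illustration:

    import tvm

    # As a decorator: compute ops created inside the decorated call get the tag.
    @tvm.tag_scope(tag="custom_scale")
    def scale(x, factor):
        return tvm.compute(x.shape, lambda *i: x(*i) * factor)

    # As a context manager: same effect for ops created inside the block.
    x = tvm.placeholder((16,), name="x")
    with tvm.tag_scope(tag="custom_add"):
        y = tvm.compute(x.shape, lambda i: x[i] + 1.0)
    assert y.op.tag == "custom_add"

With the accessed/warnings change above, a scope that is entered but whose tag is never picked up by compute/scan/extern now emits a warning on exit.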
......@@ -201,37 +201,6 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
}
- /*!
- * \brief Creates an operation that calculates a matrix multiplication
- * (row-major notation):
- * A(i, k) * B(k, j), if trans_a == trans_b
- * the usual transposed combinations, otherwise
- *
- * \param A The matrix A
- * \param B The matrix B
- * \param trans_a Is A's layout transposed?
- * \param trans_b Is B's layout transposed?
- * \param name The name of the operation
- * \param tag The tag to mark the operation
- *
- * \return A Tensor whose op member is the matmul operation
- */
- inline tvm::Tensor matmul(const tvm::Tensor& A,
- const tvm::Tensor& B,
- bool trans_a = false,
- bool trans_b = false,
- std::string name = "tensor",
- std::string tag = kMatMul) {
- tvm::Array<tvm::Expr> output_shape{A->shape[trans_a ? 1 : 0],
- B->shape[trans_b ? 0 : 1]};
- auto k = tvm::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
- auto l = [&](tvm::Var i, tvm::Var j) {
- return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]),
- {k});
- };
- return tvm::compute(output_shape, l, name, tag);
- }
/*!
* \brief Creates an operation that performs a 2-D convolution with an
* NCHW-layout
*
......
......@@ -627,6 +627,37 @@ inline Tensor where(const Tensor& condition,
return out;
}
+ /*!
+ * \brief Creates an operation that calculates a matrix multiplication
+ * (row-major notation):
+ * A(i, k) * B(k, j), if trans_a == trans_b
+ * the usual transposed combinations, otherwise
+ *
+ * \param A The matrix A
+ * \param B The matrix B
+ * \param trans_a Is A's layout transposed?
+ * \param trans_b Is B's layout transposed?
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the matmul operation
+ */
+ inline tvm::Tensor matmul(const tvm::Tensor& A,
+ const tvm::Tensor& B,
+ bool trans_a = false,
+ bool trans_b = false,
+ std::string name = "tensor",
+ std::string tag = kMatMul) {
+ tvm::Array<tvm::Expr> output_shape{A->shape[trans_a ? 1 : 0],
+ B->shape[trans_b ? 0 : 1]};
+ auto k = tvm::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
+ auto l = [&](tvm::Var i, tvm::Var j) {
+ return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]),
+ {k});
+ };
+ return tvm::compute(output_shape, l, name, tag);
+ }
} // namespace topi
#endif // TOPI_TRANSFORM_H_
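The shape rule in the matmul added above is output = (A.shape[trans_a ? 1 : 0], B.shape[trans_b ? 0 : 1]), with the reduction length taken from A.shape[trans_a ? 0 : 1]; a transposed operand is therefore stored with its logical dimensions swapped. A small numpy illustration of the resulting shapes, not part of the commit, mirroring the (3,5) x (2,3) case exercised by the new test at the end of this diff:

    import numpy as np

    A = np.random.rand(3, 5).astype("float32")  # trans_a=True: stored as (k=3, m=5)
    B = np.random.rand(2, 3).astype("float32")  # trans_b=True: stored as (n=2, k=3)

    # C(i, j) = sum_k A(k, i) * B(j, k), the trans_a = trans_b = True combination
    C = np.matmul(A.T, B.T)
    assert C.shape == (5, 2)  # (A.shape[1], B.shape[0])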
# pylint: disable=redefined-builtin,consider-using-enumerate,no-member
"""Reduce operators"""
from __future__ import absolute_import as _abs
import tvm
from . import cpp
from . import tag
def _get_real_axis(ndim, axis):
if axis is None:
......@@ -26,7 +24,6 @@ def _get_real_axis(ndim, axis):
return real_axis
@tvm.tag_scope(tag=tag.COMM_REDUCE)
def sum(data, axis=None, keepdims=False):
"""Sum of array elements over a given axis or a list of axes
......@@ -52,7 +49,6 @@ def sum(data, axis=None, keepdims=False):
return cpp.sum(data, axis, keepdims)
@tvm.tag_scope(tag=tag.COMM_REDUCE)
def max(data, axis=None, keepdims=False):
"""Maximum of array elements over a given axis or a list of axes
......@@ -78,7 +74,6 @@ def max(data, axis=None, keepdims=False):
return cpp.max(data, axis, keepdims)
@tvm.tag_scope(tag=tag.COMM_REDUCE)
def min(data, axis=None, keepdims=False):
"""Minimum of array elements over a given axis or a list of axes
......@@ -104,7 +99,6 @@ def min(data, axis=None, keepdims=False):
return cpp.min(data, axis, keepdims)
@tvm.tag_scope(tag=tag.COMM_REDUCE_IDX)
def argmax(data, axis=None, keepdims=False):
"""Returns the indices of the maximum values along an axis.
......@@ -130,7 +124,6 @@ def argmax(data, axis=None, keepdims=False):
return cpp.argmax(data, axis, keepdims)
@tvm.tag_scope(tag=tag.COMM_REDUCE_IDX)
def argmin(data, axis=None, keepdims=False):
"""Returns the indices of the minimum values along an axis.
......@@ -156,7 +149,6 @@ def argmin(data, axis=None, keepdims=False):
return cpp.argmin(data, axis, keepdims)
@tvm.tag_scope(tag=tag.COMM_REDUCE)
def prod(data, axis=None, keepdims=False):
"""Product of array elements over a given axis or a list of axes
......
# pylint: disable=invalid-name,consider-using-enumerate,unused-argument,len-as-condition
"""Elementwise operators"""
from __future__ import absolute_import as _abs
import tvm
from . import cpp
from . import tag
@tvm.tag_scope(tag=tag.ELEMWISE)
def elemwise_sum(xs):
"""Perform element-wise sum on inputs
......@@ -22,7 +19,6 @@ def elemwise_sum(xs):
return cpp.elemwise_sum(xs)
@tvm.tag_scope(tag=tag.ELEMWISE)
def full(shape, dtype, fill_value):
"""Fill tensor with fill_value
......@@ -43,7 +39,6 @@ def full(shape, dtype, fill_value):
return cpp.full(shape, dtype, fill_value)
@tvm.tag_scope(tag=tag.ELEMWISE)
def full_like(x, fill_value):
"""Construct a tensor with same shape as input tensor,
then fill tensor with fill_value.
......
......@@ -111,7 +111,6 @@ def transpose(a, axes=None):
return a(*idx)
return tvm.compute(new_shape, _compute)
@tvm.tag_scope(tag=tag.INJECTIVE)
def flip(a, axis=0):
"""Flip/reverse elements of an array in a particular axis.
......@@ -129,7 +128,6 @@ def flip(a, axis=0):
"""
return cpp.flip(a, axis)
@tvm.tag_scope(tag=tag.INJECTIVE)
def strided_slice(a, begin, end, strides=None):
"""Slice of an array.
......@@ -315,7 +313,6 @@ def split(ary, indices_or_sections, axis=0):
# pylint: enable=cell-var-from-loop
@tvm.tag_scope(tag=tag.INJECTIVE)
def take(a, indices, axis=None):
"""Take elements from an array along an axis.
......@@ -338,3 +335,22 @@ def take(a, indices, axis=None):
if axis is None:
return cpp.take(a, indices)
return cpp.take(a, indices, int(axis))
+ def matmul(a, b, transp_a=False, transp_b=False):
+     """Creates an operation that calculates a matrix multiplication
+     (row-major notation):
+         A(i, k) * B(k, j), if transp_a == transp_b
+         the usual transposed combinations, otherwise
+
+     Parameters
+     ----------
+     a : tvm.Tensor
+         The matrix A
+     b : tvm.Tensor
+         The matrix B
+     transp_a : bool
+         Is A's layout transposed?
+     transp_b : bool
+         Is B's layout transposed?
+
+     Returns
+     -------
+     out : tvm.Tensor
+         A Tensor whose op member is the matmul operation
+     """
+     return cpp.matmul(a, b, transp_a, transp_b)
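A minimal sketch of driving this new Python entry point end to end, not part of the commit; it follows the same pattern as the with_tvm helper in the test file below and assumes a TVM build with the LLVM backend:

    import numpy as np
    import tvm
    import topi

    A = tvm.placeholder((5, 3), name="A")
    B = tvm.placeholder((3, 2), name="B")
    C = topi.matmul(A, B)                  # no transposition, so C.shape == (5, 2)

    s = tvm.create_schedule(C.op)
    f = tvm.build(s, [A, B, C], "llvm")

    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.rand(5, 3).astype("float32"), ctx)
    b = tvm.nd.array(np.random.rand(3, 2).astype("float32"), ctx)
    c = tvm.nd.array(np.zeros((5, 2), dtype="float32"), ctx)
    f(a, b, c)                             # c now holds a @ b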
......@@ -292,6 +292,15 @@ TVM_REGISTER_GLOBAL("topi.where")
*rv = where(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.matmul")
.set_body([](TVMArgs args, TVMRetValue *rv) {
switch ( args.size() ) {
case 2: *rv = matmul(args[0], args[1]); break;
case 3: *rv = matmul(args[0], args[1], args[2]); break;
case 4: *rv = matmul(args[0], args[1], args[2], args[3]); break;
default: CHECK(0) << "topi.matmul expects 2, 3 or 4 arguments";
}});
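The switch on args.size() above is what makes the two bool arguments optional at the PackedFunc level. Not part of the commit: the registered global can also be fetched directly through the FFI, which is roughly what the cpp.matmul call in the Python wrapper resolves to; the sketch assumes importing topi loads the shared library that performs this registration:

    import tvm
    import topi  # loading topi registers the "topi.*" globals

    A = tvm.placeholder((2, 3), name="A")
    B = tvm.placeholder((3, 4), name="B")

    f = tvm.get_global_func("topi.matmul")
    C = f(A, B)                # the 2-argument case of the switch
    D = f(A, B, False, False)  # the 4-argument case, equivalent here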
TVM_REGISTER_GLOBAL("topi.strided_slice")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = strided_slice(args[0], args[1], args[2], args[3]);
......
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
def with_tvm(lam, *args):
""" Take numpy arrays as args, convert them to TVM tensors and call `lam`.
Result of lambda is converted back to numpy array and returned.
"""
ctx = tvm.cpu(0)
pls = [] # placeholders
vals_nd = [] # initial values
for i,arg in enumerate(args):
pls.append(tvm.placeholder(arg.shape, name='pl'+str(i)))
vals_nd.append(tvm.nd.array(arg, ctx))
out = lam(*pls)
out_nd = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out.dtype), ctx)
s = tvm.create_schedule([out.op])
m = tvm.build(s, pls + [out], "llvm")
m(*(vals_nd+[out_nd]))
return out_nd.asnumpy()
def verify_matmul(sa, sb, transp_a, transp_b):
a = np.random.uniform(low=-1.0, high=1.0, size=sa).astype(np.float32)
b = np.random.uniform(low=-1.0, high=1.0, size=sb).astype(np.float32)
c1 = np.matmul(np.transpose(a) if transp_a else a,
np.transpose(b) if transp_b else b)
c2 = with_tvm(lambda A,B: topi.matmul(A,B,transp_a,transp_b), a,b)
np.testing.assert_allclose(c1, c2, rtol=1e-5)
def test_matmul():
verify_matmul((1,1),(1,1),False,False)
verify_matmul((1,1),(1,1),True,True)
verify_matmul((2,2),(2,2),False,False)
verify_matmul((2,2),(2,2),True,True)
verify_matmul((2,3),(3,5),False,False)
verify_matmul((5,3),(3,2),False,False)
verify_matmul((3,5),(3,2),True,False)
verify_matmul((3,5),(2,3),True,True)
if __name__ == "__main__":
test_matmul()