Commit 98003afa by Dmitrii Murygin Committed by Tianqi Chen

[TOPI] Add tensor multiplication. (#2106)

parent bbc78221
...@@ -753,6 +753,121 @@ inline tvm::Tensor matmul(const tvm::Tensor& A, ...@@ -753,6 +753,121 @@ inline tvm::Tensor matmul(const tvm::Tensor& A,
return tvm::compute(output_shape, l, name, tag); return tvm::compute(output_shape, l, name, tag);
} }
/*!
 * \brief A generalization of matrix multiplication to tensors.
 *
 * Contracts the last \p axes dimensions of A with the first \p axes
 * dimensions of B. The output shape is the leading
 * (rank(A) - axes) dimensions of A followed by the trailing
 * (rank(B) - axes) dimensions of B.
 *
 * The extents of the reduction axes are taken from B; this function does not
 * verify that the corresponding extents of A are equal.
 *
 * \param A The tensor A
 * \param B The tensor B
 * \param axes The number of dimensions to reduce over. Must be non-negative
 *             and no larger than the rank of either tensor.
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor computing the result
 */
inline Tensor tensordot(const Tensor& A,
                        const tvm::Tensor& B,
                        int axes = 2,
                        std::string name = "tensor",
                        std::string tag = kMatMul) {
  // Validate in the signed domain so that a negative count is rejected
  // explicitly instead of through an implicit conversion to unsigned.
  CHECK_GE(axes, 0) << "tensordot: axes must be non-negative, but got " << axes;
  const int A_rank = static_cast<int>(A->shape.size());
  const int B_rank = static_cast<int>(B->shape.size());
  CHECK_GE(A_rank, axes);
  CHECK_GE(B_rank, axes);

  // Number of leading dimensions of A that survive into the output.
  const int A_free = A_rank - axes;

  Array<Expr> output_shape(A->shape.begin(), A->shape.begin() + A_free);
  for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) {
    output_shape.push_back(*it);
  }

  // One reduction variable per contracted dimension, sized from B.
  Array<IterVar> iter_vars;
  for (int i = 0; i < axes; ++i) {
    iter_vars.push_back(reduce_axis(Range(0, B->shape[i]), "k" + std::to_string(i)));
  }

  auto func =
    [&A, &B, &iter_vars, A_free]
    (const Array<Var>& input_indices) {
      // A is indexed by the first A_free output indices, then the reduction variables.
      Array<Expr> A_indices(
          input_indices.begin(),
          input_indices.begin() + A_free);
      for (auto& v : iter_vars) {
        A_indices.push_back(v);
      }

      // B is indexed by the reduction variables, then the remaining output indices.
      Array<Expr> B_indices;
      for (auto& v : iter_vars) {
        B_indices.push_back(v);
      }
      for (auto it = input_indices.begin() + A_free; it != input_indices.end(); ++it) {
        B_indices.push_back(*it);
      }

      // Some passes don't like reductions with empty axis, so avoid it here
      if (iter_vars.empty()) {
        return A(A_indices) * B(B_indices);
      } else {
        return sum(A(A_indices) * B(B_indices), iter_vars);
      }
    };

  return compute(output_shape, func, name, tag);
}
/*!
 * \brief A generalization of matrix multiplication to tensors.
 *
 * Contracts dimension A_axes[i] of A with dimension B_axes[i] of B for every
 * i. The output shape is the non-contracted dimensions of A in their original
 * order, followed by the non-contracted dimensions of B in their original
 * order.
 *
 * The extents of the reduction axes are taken from B; this function does not
 * verify that the corresponding extents of A are equal.
 *
 * \param A The tensor A
 * \param B The tensor B
 * \param A_axes The indices of the dimensions of tensor A to reduce over.
 *               Each must be a constant in [0, rank(A)).
 * \param B_axes The indices of the dimensions of tensor B to reduce over.
 *               Each must be a constant in [0, rank(B)). Must have the same
 *               length as A_axes.
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor computing the result
 */
inline Tensor tensordot(const Tensor& A,
                        const tvm::Tensor& B,
                        Array<Expr> A_axes,
                        Array<Expr> B_axes,
                        std::string name = "tensor",
                        std::string tag = kMatMul) {
  CHECK_EQ(A_axes.size(), B_axes.size());

  const auto A_axes_val = GetConstIntValues(A_axes, "A_axes");
  const auto B_axes_val = GetConstIntValues(B_axes, "B_axes");

  const int A_rank = static_cast<int>(A->shape.size());
  const int B_rank = static_cast<int>(B->shape.size());

  // The axis values are used as indices into the shape arrays below, so
  // reject anything outside the valid range up front.
  for (size_t i = 0; i < A_axes_val.size(); ++i) {
    const int axis = A_axes_val[i];
    CHECK(axis >= 0 && axis < A_rank)
        << "tensordot: A_axes[" << i << "] = " << axis
        << " is out of range for a tensor of rank " << A_rank;
  }
  for (size_t i = 0; i < B_axes_val.size(); ++i) {
    const int axis = B_axes_val[i];
    CHECK(axis >= 0 && axis < B_rank)
        << "tensordot: B_axes[" << i << "] = " << axis
        << " is out of range for a tensor of rank " << B_rank;
  }

  // Output shape: free dimensions of A, then free dimensions of B.
  Array<Expr> output_shape;
  for (int i = 0; i < A_rank; ++i) {
    if (std::find(A_axes_val.begin(), A_axes_val.end(), i) == A_axes_val.end()) {
      output_shape.push_back(A->shape[i]);
    }
  }
  for (int i = 0; i < B_rank; ++i) {
    if (std::find(B_axes_val.begin(), B_axes_val.end(), i) == B_axes_val.end()) {
      output_shape.push_back(B->shape[i]);
    }
  }

  // One reduction variable per contracted pair, sized from B.
  Array<IterVar> iter_vars;
  for (size_t i = 0; i < B_axes_val.size(); ++i) {
    iter_vars.push_back(reduce_axis(Range(0, B->shape[B_axes_val[i]]), "k" + std::to_string(i)));
  }

  auto func =
    [&A, &B, &iter_vars, A_axes_val, B_axes_val, A_rank, B_rank]
    (const Array<Var>& input_indices) {
      // Output indices are consumed in order: first by the free dimensions
      // of A, then by the free dimensions of B.
      int idx_input = 0;

      Array<Expr> A_indices;
      for (int i = 0; i < A_rank; ++i) {
        const auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i);
        if (axes_pos == A_axes_val.end()) {
          A_indices.push_back(input_indices[idx_input++]);
        } else {
          // Contracted dimension: use the reduction variable of the matching pair.
          A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]);
        }
      }

      Array<Expr> B_indices;
      for (int i = 0; i < B_rank; ++i) {
        const auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i);
        if (axes_pos == B_axes_val.end()) {
          B_indices.push_back(input_indices[idx_input++]);
        } else {
          B_indices.push_back(iter_vars[axes_pos - B_axes_val.begin()]);
        }
      }

      // Some passes don't like reductions with empty axis, so avoid it here
      // (same handling as the overload taking a number of axes).
      if (iter_vars.empty()) {
        return A(A_indices) * B(B_indices);
      } else {
        return sum(A(A_indices) * B(B_indices), iter_vars);
      }
    };

  return compute(output_shape, func, name, tag);
}
} // namespace topi } // namespace topi
#endif // TOPI_TRANSFORM_H_ #endif // TOPI_TRANSFORM_H_
...@@ -269,3 +269,23 @@ def matmul(a, b, transp_a=False, transp_b=False): ...@@ -269,3 +269,23 @@ def matmul(a, b, transp_a=False, transp_b=False):
A Tensor whose op member is the matmul operation A Tensor whose op member is the matmul operation
""" """
return cpp.matmul(a, b, transp_a, transp_b) return cpp.matmul(a, b, transp_a, transp_b)
def tensordot(a, b, axes):
    """A generalization of matrix multiplication to tensor.

    Parameters
    ----------
    a : The tensor A

    b : The tensor B

    axes : int, pair of ints, or pair of int sequences
        If an int, the number of dimensions to reduce over.
        If a pair of ints ``(a_axis, b_axis)``, the single dimension of
        ``a`` and of ``b`` to reduce over.
        Otherwise ``axes[0]`` and ``axes[1]`` are passed through as the
        dimensions of ``a`` and of ``b`` to reduce over.

    Returns
    -------
    A Tensor computing the result
    """
    if isinstance(axes, int):
        return cpp.tensordot(a, b, axes)

    a_axes, b_axes = axes[0], axes[1]
    if isinstance(a_axes, int):
        # A bare pair of ints names one axis per tensor; wrap each in a tuple.
        a_axes, b_axes = (a_axes,), (b_axes,)
    return cpp.tensordot(a, b, a_axes, b_axes)
...@@ -305,6 +305,18 @@ TVM_REGISTER_GLOBAL("topi.matmul") ...@@ -305,6 +305,18 @@ TVM_REGISTER_GLOBAL("topi.matmul")
default: CHECK(0) << "topi.matmul expects 2, 3 or 4 arguments"; default: CHECK(0) << "topi.matmul expects 2, 3 or 4 arguments";
}}); }});
// Packed-function entry point for tensordot. Accepted call forms:
//   (A, B)                  - reduce over the default number of axes
//   (A, B, axes)            - reduce over the given number of axes
//   (A, B, A_axes, B_axes)  - reduce over explicitly listed axes
TVM_REGISTER_GLOBAL("topi.tensordot")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  switch (args.size()) {
    case 2:
      *rv = tensordot(args[0], args[1]);
      break;
    case 3:
      *rv = tensordot(args[0], args[1], args[2]);
      break;
    case 4: {
      // Materialise the last argument with an explicit type so the call
      // resolves to the overload taking two axis lists.
      Array<Expr> B_axes = args[3];
      *rv = tensordot(args[0], args[1], args[2], B_axes);
      break;
    }
    default:
      CHECK(0) << "topi.tensordot expects 2, 3 or 4 arguments";
  }
  });
TVM_REGISTER_GLOBAL("topi.strided_slice") TVM_REGISTER_GLOBAL("topi.strided_slice")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = strided_slice(args[0], args[1], args[2], args[3]); *rv = strided_slice(args[0], args[1], args[2], args[3]);
......
...@@ -39,6 +39,23 @@ def test_matmul(): ...@@ -39,6 +39,23 @@ def test_matmul():
verify_matmul((3,5),(3,2),True,False) verify_matmul((3,5),(3,2),True,False)
verify_matmul((3,5),(2,3),True,True) verify_matmul((3,5),(2,3),True,True)
def verify_tensordot(sa, sb, axes):
    """Check topi.tensordot against np.tensordot on random float32 inputs.

    Parameters
    ----------
    sa : shape used to generate the first operand
    sb : shape used to generate the second operand
    axes : axes argument forwarded unchanged to both implementations
    """
    lhs = np.random.uniform(low=-1.0, high=1.0, size=sa).astype(np.float32)
    rhs = np.random.uniform(low=-1.0, high=1.0, size=sb).astype(np.float32)

    expected = np.tensordot(lhs, rhs, axes)
    # NOTE(review): with_tvm is defined elsewhere in this file; presumably it
    # builds the expression on placeholders and returns the evaluated result.
    actual = with_tvm(lambda A, B: topi.tensordot(A, B, axes), lhs, rhs)

    tvm.testing.assert_allclose(expected, actual, rtol=1e-5, atol=1e-5)
def test_tensordot():
    """Run verify_tensordot over each supported form of the axes argument."""
    cases = [
        # axes given as a number of dimensions to reduce over
        ((3,), (3,), 0),
        ((2, 3), (3, 5), 1),
        ((2, 2, 3), (2, 3, 5), 2),
        ((2, 2, 3, 4), (2, 3, 4, 5), 3),
        # axes given as a single (a_axis, b_axis) pair
        ((3, 2, 2), (2, 3, 5), (1, 0)),
        # axes given as a pair of axis lists
        ((3, 2, 2), (2, 3, 5), ((1, 0), (0, 1))),
        ((4, 3, 2, 2), (2, 4, 3, 5), ((1, 2, 0), (2, 0, 1))),
    ]
    for sa, sb, axes in cases:
        verify_tensordot(sa, sb, axes)
# Run the tests directly when this file is executed as a script.
if __name__ == "__main__":
    test_matmul()
    test_tensordot()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment