/*!
 * Copyright (c) 2017 by Contributors
 * \file matrix_op.cc
 * \brief Matrix operators
 */
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/top/tensor.h>
#include "../op_common.h"
#include "../elemwise_op_common.h"

namespace nnvm {
namespace top {

DMLC_REGISTER_PARAMETER(MatMulParam);

inline bool DotShape(const nnvm::NodeAttrs& attrs,
                     std::vector<TShape> *in_attrs,
                     std::vector<TShape> *out_attrs) {
  const MatMulParam& param = nnvm::get<MatMulParam>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  TShape lshape = (*in_attrs)[0];
  TShape rshape = (*in_attrs)[1];
  // Promote 1-D operands to 2-D row vectors before inferring the output shape.
  if (lshape.ndim() == 1) lshape = TShape{1, lshape[0]};
  if (rshape.ndim() == 1) rshape = TShape{1, rshape[0]};
  if (param.transpose_a) std::reverse(lshape.begin(), lshape.end());
  if (param.transpose_b) std::reverse(rshape.begin(), rshape.end());
  // The reduction axes must agree: last axis of lhs vs. first axis of rhs.
  CHECK_EQ(lshape[lshape.ndim() - 1], rshape[0])
    << "dot shape inconsistent: " << lshape << " X " << rshape;
  // Output shape is lhs[0..ndim-2] concatenated with rhs[1..ndim-1].
  TShape oshape(lshape.ndim() + rshape.ndim() - 2);
  for (uint32_t i = 0; i < lshape.ndim() - 1; i++) oshape[i] = lshape[i];
  for (uint32_t i = 1; i < rshape.ndim(); i++) oshape[i + lshape.ndim() - 2] = rshape[i];
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape);
  return true;
}

inline bool DotCorrectLayout(const NodeAttrs& attrs,
                             std::vector<Layout> *ilayouts,
                             const std::vector<Layout> *last_ilayouts,
                             std::vector<Layout> *olayouts) {
  const MatMulParam& param = nnvm::get<MatMulParam>(attrs.parsed);
  CHECK_EQ(ilayouts->size(), 2U);
  CHECK_EQ(olayouts->size(), 1U);
  // Prefer the layouts the inputs originally carried, if defined.
  const Layout& lhs = last_ilayouts->at(0).defined() ? last_ilayouts->at(0)
                                                     : ilayouts->at(0);
  const Layout& rhs = last_ilayouts->at(1).defined() ? last_ilayouts->at(1)
                                                     : ilayouts->at(1);
  NNVM_ASSIGN_LAYOUT(*ilayouts, 0, lhs);
  NNVM_ASSIGN_LAYOUT(*ilayouts, 1, rhs);
  if (lhs.ndim() > 1 && rhs.ndim() > 1) {
    // Concatenate the lhs and rhs layouts, dropping the reduced axes.
    const Layout& lhs_out = param.transpose_a ? lhs.reverse() : lhs;
    const Layout& rhs_out = param.transpose_b ? rhs.reverse() : rhs;
    Layout out = lhs_out.sublayout(0, lhs_out.ndim() - 1) +
                 rhs_out.sublayout(1, rhs_out.ndim() - 1);
    NNVM_ASSIGN_LAYOUT(*olayouts, 0, out);
  }
  return true;
}

NNVM_REGISTER_OP(matmul)
.describe(R"doc(Matrix multiplication of two arrays.

``matmul``'s behavior depends on the input array dimensions:

- 1-D arrays: inner product of vectors
- 2-D arrays: matrix multiplication
- N-D arrays: a sum product over the last axis of the first input and the first
  axis of the second input

  For example, given 3-D ``x`` with shape `(n,m,k)` and ``y`` with shape `(k,r,s)`,
  the result array will have shape `(n,m,r,s)`.
  It is computed by::

    matmul(x,y)[i,j,a,b] = sum(x[i,j,:]*y[:,a,b])

)doc" NNVM_ADD_FILELINE)
.set_support_level(1)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<MatMulParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<MatMulParam>)
.add_arguments(MatMulParam::__FIELDS__())
.add_argument("lhs", "NDArray-or-Symbol", "The first input")
.add_argument("rhs", "NDArray-or-Symbol", "The second input")
.set_attr<FInferShape>("FInferShape", DotShape)
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", DotCorrectLayout)
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
    // z = x dot y
    // xshape (n,m,k), yshape (k,r,s)
    const MatMulParam& param = nnvm::get<MatMulParam>(n->attrs.parsed);
    bool Ta = param.transpose_a;
    bool Tb = param.transpose_b;
    // Ta = false, Tb = false
    // grad_x = grad_z dot y.T
    // grad_y = x.T dot grad_z
    if (!Ta && !Tb) {
      return std::vector<NodeEntry>{
        MakeNode("matmul", n->attrs.name + "_grad_0",
                 {ograds[0], n->inputs[1]},
                 {{"transpose_a", "false"},
                  {"transpose_b", "true"}}),
        MakeNode("matmul", n->attrs.name + "_grad_1",
                 {n->inputs[0], ograds[0]},
                 {{"transpose_a", "true"},
                  {"transpose_b", "false"}})
      };
    } else if (Ta && !Tb) {
      // Ta = true, Tb = false
      // grad_x = y dot grad_z.T
      // grad_y = x dot grad_z
      return std::vector<NodeEntry>{
        MakeNode("matmul", n->attrs.name + "_grad_0",
                 {n->inputs[1], ograds[0]},
                 {{"transpose_a", "false"},
                  {"transpose_b", "true"}}),
        MakeNode("matmul", n->attrs.name + "_grad_1",
                 {n->inputs[0], ograds[0]},
                 {{"transpose_a", "false"},
                  {"transpose_b", "false"}})
      };
    } else if (!Ta && Tb) {
      // Ta = false, Tb = true
      // grad_x = grad_z dot y
      // grad_y = grad_z.T dot x
      return std::vector<NodeEntry>{
        MakeNode("matmul", n->attrs.name + "_grad_0",
                 {ograds[0], n->inputs[1]},
                 {{"transpose_a", "false"},
                  {"transpose_b", "false"}}),
        MakeNode("matmul", n->attrs.name + "_grad_1",
                 {ograds[0], n->inputs[0]},
                 {{"transpose_a", "true"},
                  {"transpose_b", "false"}})
      };
    } else {
      // Ta = true, Tb = true
      // grad_x = y.T dot grad_z.T
      // grad_y = grad_z.T dot x.T
      return std::vector<NodeEntry>{
        MakeNode("matmul", n->attrs.name + "_grad_0",
                 {n->inputs[1], ograds[0]},
                 {{"transpose_a", "true"},
                  {"transpose_b", "true"}}),
        MakeNode("matmul", n->attrs.name + "_grad_1",
                 {ograds[0], n->inputs[0]},
                 {{"transpose_a", "true"},
                  {"transpose_b", "true"}})
      };
    }
});

}  // namespace top
}  // namespace nnvm
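
// -----------------------------------------------------------------------------
// Illustrative sketch (a hedged example, not part of the operator registration
// above; the `#if 0` guard keeps it out of this translation unit). It mirrors
// the shape arithmetic of DotShape with plain std::vector so the rule can be
// tried in isolation: 1-D operands are promoted to (1, n); transpose_a /
// transpose_b reverse the corresponding shape; the last axis of lhs must match
// the first axis of rhs; the output is lhs[0..ndim-2] ++ rhs[1..ndim-1].
// The helper name `MatMulOutShape` is hypothetical, and the block can be
// compiled on its own with any C++11 compiler.
#if 0
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

using Shape = std::vector<uint32_t>;

// Stand-in for DotShape's output-shape computation (no TShape dependency).
Shape MatMulOutShape(Shape lshape, Shape rshape,
                     bool transpose_a, bool transpose_b) {
  if (lshape.size() == 1) lshape = {1, lshape[0]};
  if (rshape.size() == 1) rshape = {1, rshape[0]};
  if (transpose_a) std::reverse(lshape.begin(), lshape.end());
  if (transpose_b) std::reverse(rshape.begin(), rshape.end());
  assert(lshape.back() == rshape.front() && "dot shape inconsistent");
  Shape oshape(lshape.begin(), lshape.end() - 1);                 // lhs[0..ndim-2]
  oshape.insert(oshape.end(), rshape.begin() + 1, rshape.end());  // ++ rhs[1..ndim-1]
  return oshape;
}

int main() {
  // The docstring example (n,m,k) x (k,r,s) -> (n,m,r,s), here (2,3,4) x (4,5,6).
  Shape out = MatMulOutShape({2, 3, 4}, {4, 5, 6}, false, false);
  for (uint32_t d : out) std::printf("%u ", d);  // prints: 2 3 5 6
  std::printf("\n");
  return 0;
}
#endif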