Commit 59a8d099 by nhynes, committed by Tianqi Chen

[NNVM][TOPI] Add FTVMCompute for matmul (#1239)

parent d29b1c9e
@@ -73,6 +73,9 @@ def schedule_dense(_, outs, target):
 reg.register_pattern("dense", OpPattern.OUT_ELEMWISE_FUSABLE)
+#matmul
+reg.register_pattern("matmul", OpPattern.OUT_ELEMWISE_FUSABLE)
+reg.register_schedule("matmul", _fschedule_injective)
 # conv2d
 @reg.register_compute("conv2d")
...
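Note: `_fschedule_injective` (defined earlier in this file) routes the new op through TOPI's generic injective scheduling. A rough sketch of what such a schedule callback boils down to is shown below; the helper name `_schedule_injective` and the exact wrapper shape are illustrative assumptions, not part of this commit.

import tvm
import topi

# Illustrative only: an injective-style schedule callback of the shape
# nnvm expects (graph attrs, output tensors, target string).
def _schedule_injective(_, outs, target):
    with tvm.target.create(target):
        return topi.generic.schedule_injective(outs)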
@@ -3,9 +3,11 @@
  * \file matrix_op.cc
  * \brief Matrix operators
  */
+#include <topi/nn.h>
 #include <nnvm/op.h>
 #include <nnvm/node.h>
 #include <nnvm/op_attr_types.h>
+#include <nnvm/compiler/op_attr_types.h>
 #include <nnvm/top/tensor.h>
 #include "../op_common.h"
 #include "../elemwise_op_common.h"
@@ -13,6 +15,8 @@
 namespace nnvm {
 namespace top {
+using namespace nnvm::compiler;
 DMLC_REGISTER_PARAMETER(MatMulParam);
 inline bool DotShape(const nnvm::NodeAttrs& attrs,
@@ -93,6 +97,15 @@ NNVM_REGISTER_OP(matmul)
 .set_attr<FInferShape>("FInferShape", DotShape)
 .set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
 .set_attr<FCorrectLayout>("FCorrectLayout", DotCorrectLayout)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const MatMulParam& param = nnvm::get<MatMulParam>(attrs.parsed);
+    return Array<Tensor>{
+      topi::matmul(inputs[0], inputs[1], param.transpose_a, param.transpose_b)
+    };
+  })
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds) {
...
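With the FTVMCompute attribute in place, `matmul` can be lowered through the normal nnvm compilation path. A hedged end-to-end sketch follows; the shapes, variable names, and choice of the LLVM target are illustrative, while the `transpose_a`/`transpose_b` keyword arguments correspond to the fields of `MatMulParam`.

import numpy as np
import tvm
from tvm.contrib import graph_runtime
import nnvm.symbol as sym
import nnvm.compiler

# Build a small graph that multiplies a (4, 8) matrix by an (8, 3) matrix.
a = sym.Variable("a")
b = sym.Variable("b")
c = sym.matmul(a, b, transpose_a=False, transpose_b=False)

graph, lib, _ = nnvm.compiler.build(
    c, target="llvm", shape={"a": (4, 8), "b": (8, 3)})

# Run the compiled module on random inputs.
module = graph_runtime.create(graph, lib, tvm.cpu(0))
module.set_input(a=np.random.rand(4, 8).astype("float32"),
                 b=np.random.rand(8, 3).astype("float32"))
module.run()
out = module.get_output(0, tvm.nd.empty((4, 3), "float32"))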
@@ -214,14 +214,14 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
  * \param name The name of the operation
  * \param tag The tag to mark the operation
  *
- * \return A Tensor whose op member is the matmult operation
+ * \return A Tensor whose op member is the matmul operation
  */
-inline tvm::Tensor matmult(const tvm::Tensor& A,
+inline tvm::Tensor matmul(const tvm::Tensor& A,
                           const tvm::Tensor& B,
                           bool trans_a = false,
                           bool trans_b = false,
                           std::string name = "tensor",
-                          std::string tag = kMatMult) {
+                          std::string tag = kMatMul) {
   tvm::Array<tvm::Expr> output_shape{A->shape[trans_a ? 1 : 0],
                                      B->shape[trans_b ? 0 : 1]};
   auto k = tvm::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
...
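For reference, the renamed `topi::matmul` declares the usual reduction C[i, j] = sum_k A[i, k] * B[k, j], with the optional transposes swapping which axis is contracted. A rough Python analogue written against the TVM tensor-expression API is sketched below; the function name `matmul_py` is illustrative, not part of this commit.

import tvm

def matmul_py(A, B, trans_a=False, trans_b=False):
    # Output shape and reduction length mirror the C++ helper above.
    out_shape = (A.shape[1] if trans_a else A.shape[0],
                 B.shape[0] if trans_b else B.shape[1])
    k = tvm.reduce_axis((0, A.shape[0] if trans_a else A.shape[1]), name="k")
    return tvm.compute(
        out_shape,
        lambda i, j: tvm.sum((A[k, i] if trans_a else A[i, k]) *
                             (B[j, k] if trans_b else B[k, j]), axis=k),
        name="matmul")

# Example declaration:
A = tvm.placeholder((4, 8), name="A")
B = tvm.placeholder((8, 3), name="B")
C = matmul_py(A, B)  # shape (4, 3)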
@@ -15,7 +15,7 @@ constexpr auto kInjective = "injective";
 constexpr auto kCommReduce = "comm_reduce";
 constexpr auto kCommReduceIdx = "comm_reduce_idx";
 constexpr auto kBroadcast = "broadcast";
-constexpr auto kMatMult = "matmult";
+constexpr auto kMatMul = "matmul";
 constexpr auto kConv2dNCHW = "conv2d_nchw";
 constexpr auto kConv2dHWCN = "conv2d_hwcn";
 constexpr auto kDepthwiseConv2dNCHW = "depthwise_conv2d_nchw";
...