Commit d5744844 by Yizhi Liu Committed by Tianqi Chen

Rename FInferLayout -> FCorrectLayout (#453)

* rename FInferLayout -> FCorrectLayout

* correct stupid IDE

* update submodule tvm
parent 9f8fcfc9
......@@ -66,7 +66,7 @@ using DTypeVector = std::vector<int>;
* int entry_layout = layouts[g.indexed_graph().entry_id(my_entry)];
* \endcode
*
* \sa FInferLayout
* \sa FCorrectLayout
*/
using LayoutVector = std::vector<Layout>;
......
......@@ -178,7 +178,7 @@ using FSetInputVarAttrOnCompose = std::function<void(
const int index)>;
/*!
* \brief Inference function of node layout. See \p Layout for layout convention
* \brief Infer & correct function of node layout. See \p Layout for layout convention
* \param attrs The attribute of the node.
* \param ilayouts Given the input layouts produced by ancestor nodes,
* it should be filled by layouts that the node requests.
......@@ -196,7 +196,7 @@ using FSetInputVarAttrOnCompose = std::function<void(
* \param olayouts Inferred output layouts.
* \return success flag.
*/
using FInferLayout = std::function<bool(
using FCorrectLayout = std::function<bool(
const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
......
......@@ -33,7 +33,7 @@ using LayoutAttrDict = std::unordered_map<const Node*, std::vector<Layout> >;
*/
nnvm::Graph CorrectLayout(nnvm::Graph src) {
static auto& op_infer_layout =
nnvm::Op::GetAttr<FInferLayout>("FInferLayout");
nnvm::Op::GetAttr<FCorrectLayout>("FCorrectLayout");
const IndexedGraph& idx = src.indexed_graph();
std::vector<nnvm::NodePtr> mirror_vec(idx.num_nodes(), nullptr);
......@@ -92,7 +92,7 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) {
}
const auto& flayout = op_infer_layout[new_node->op()];
CHECK(flayout != nullptr) << "Attribute FInferLayout"
CHECK(flayout != nullptr) << "Attribute FCorrectLayout"
<< " is not registered by op " << inode.source->op()->name
<< " we are not able to complete layout transform.";
CHECK(flayout(new_node->attrs, &request_ilayouts, &last_request_ilayouts, &produce_olayouts))
......
......@@ -271,7 +271,7 @@ inline bool ElemwiseBinaryKeepLeftLayout(const NodeAttrs& attrs,
.set_num_outputs(1) \
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \
.set_attr<FInferLayout>("FInferLayout", \
.set_attr<FCorrectLayout>("FCorrectLayout", \
ElemwiseArbitraryLayout<1, 1>) \
.set_attr<FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \
......@@ -298,7 +298,7 @@ inline bool ElemwiseBinaryKeepLeftLayout(const NodeAttrs& attrs,
.set_num_outputs(1) \
.set_attr<FInferShape>("FInferShape", ElemwiseShape<2, 1>) \
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) \
.set_attr<FInferLayout>("FInferLayout", \
.set_attr<FCorrectLayout>("FCorrectLayout", \
ElemwiseBinaryKeepLeftLayout) \
.set_attr<FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs) { \
......@@ -319,7 +319,7 @@ inline bool ElemwiseBinaryKeepLeftLayout(const NodeAttrs& attrs,
ParamGetAttrDict<ElementWiseReduceParam>) \
.set_attr<nnvm::FInferShape>("FInferShape", \
ElementWiseReduceShape) \
.set_attr<FInferLayout>("FInferLayout", \
.set_attr<FCorrectLayout>("FCorrectLayout", \
ElemwiseFixedLayoutCopyToOut<1, 1>) \
.set_attr<nnvm::FInferType>("FInferType", ElementWiseReduceType) \
.add_argument("args", "Symbol[]", "Positional input arguments")
......@@ -337,7 +337,7 @@ inline bool ElemwiseBinaryKeepLeftLayout(const NodeAttrs& attrs,
static_cast<int>(kFloat32)); \
return true; \
}) \
.set_attr<FInferLayout>("FInferLayout", \
.set_attr<FCorrectLayout>("FCorrectLayout", \
ElemwiseFixedLayoutUnknownOut<1, 1>) \
.set_attr<FGradient>( \
"FGradient", [](const NodePtr& n, \
......
......@@ -129,10 +129,10 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
return true;
}
inline bool Conv2DInferLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
inline bool Conv2DCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
const Layout in_layout(param.layout);
......@@ -189,7 +189,7 @@ a bias vector is created and added to the outputs.
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>)
.set_attr<FInferShape>("FInferShape", Conv2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FInferLayout>("FInferLayout", Conv2DInferLayout)
.set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout)
.set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<Conv2DParam>)
.set_support_level(2)
......@@ -214,7 +214,7 @@ NNVM_REGISTER_OP(_contrib_conv2d_NCHWc)
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>)
.set_attr<FInferShape>("FInferShape", Conv2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FInferLayout>("FInferLayout", Conv2DInferLayout)
.set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout)
.set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<Conv2DParam>)
.set_support_level(2);
......@@ -306,10 +306,10 @@ inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs,
return true;
}
inline bool Conv2DTransposeInferLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
inline bool Conv2DTransposeCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
const Conv2DTransposeParam& param = nnvm::get<Conv2DTransposeParam>(attrs.parsed);
const Layout in_layout(param.layout);
......@@ -363,7 +363,7 @@ said convolution.
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DTransposeParam>)
.set_attr<FInferShape>("FInferShape", Conv2DTransposeInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FInferLayout>("FInferLayout", Conv2DTransposeInferLayout)
.set_attr<FCorrectLayout>("FCorrectLayout", Conv2DTransposeCorrectLayout)
.set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<Conv2DTransposeParam>)
.set_support_level(2);
......
......@@ -87,7 +87,7 @@ If ``use_bias`` is set to be false, then the ``bias`` term is ignored.
.set_attr<FInferShape>("FInferShape", DenseInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
// leave weight & bias layout undefined
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutCopyToOut<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutCopyToOut<1, 1>)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
......@@ -167,7 +167,7 @@ NNVM_REGISTER_OP(dropout)
.set_num_outputs(2)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 2>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 2>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseArbitraryLayout<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
.set_attr<FNumVisibleOutputs>("FNumVisibleOutputs", [](const NodeAttrs& attrs) {
return 1;
})
......@@ -201,10 +201,10 @@ inline bool BatchNormInferShape(const nnvm::NodeAttrs& attrs,
return true;
}
inline bool BatchNormInferLayout(const NodeAttrs& attrs,
std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts) {
inline bool BatchNormCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts) {
const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
CHECK_EQ(in_layouts->size(), 5U);
CHECK_EQ(last_in_layouts->size(), 5U);
......@@ -307,7 +307,7 @@ axis to be the last item in the input shape.
.add_arguments(BatchNormParam::__FIELDS__())
.set_attr_parser(ParamParser<BatchNormParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<BatchNormParam>)
.set_attr<FInferLayout>("FInferLayout", BatchNormInferLayout)
.set_attr<FCorrectLayout>("FCorrectLayout", BatchNormCorrectLayout)
.set_num_inputs(5)
.set_num_outputs(3)
.set_attr<FInferShape>("FInferShape", BatchNormInferShape)
......@@ -345,7 +345,7 @@ NNVM_REGISTER_OP(softmax)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseArbitraryLayout<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
.set_support_level(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
......@@ -402,7 +402,7 @@ NNVM_REGISTER_OP(log_softmax)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseArbitraryLayout<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
......@@ -460,7 +460,7 @@ NNVM_REGISTER_OP(leaky_relu)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseArbitraryLayout<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
......@@ -512,10 +512,10 @@ inline bool PReluInferShape(const nnvm::NodeAttrs &attrs,
return true;
}
inline bool PReluInferLayout(const NodeAttrs& attrs,
std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts) {
inline bool PReluCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts) {
const PReLUParam& param = nnvm::get<PReLUParam>(attrs.parsed);
CHECK_EQ(in_layouts->size(), 2U);
CHECK_EQ(last_in_layouts->size(), 2U);
......@@ -550,7 +550,7 @@ where :math:`*` is an channelwise multiplication for each sample in the
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", PReluInferShape)
.set_attr<FInferLayout>("FInferLayout", PReluInferLayout)
.set_attr<FCorrectLayout>("FCorrectLayout", PReluCorrectLayout)
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "alpha"};
})
......@@ -597,7 +597,7 @@ NNVM_REGISTER_OP(pad)
.set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", PadInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutCopyToOut<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutCopyToOut<1, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
......@@ -651,8 +651,8 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
.set_attr_parser(ParamParser<LayoutTransformParam>)
.set_attr<FInferShape>("FInferShape", LayoutTransformInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>(
"FInferLayout", [](const NodeAttrs& attrs,
.set_attr<FCorrectLayout>(
"FCorrectLayout", [](const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
......
......@@ -66,10 +66,10 @@ inline bool Pool2DInferShape(const nnvm::NodeAttrs& attrs,
return true;
}
inline bool Pool2DInferLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
inline bool Pool2DCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
const Pool2DParam &param = nnvm::get<Pool2DParam>(attrs.parsed);
CHECK_EQ(ilayouts->size(), 1);
CHECK_EQ(last_ilayouts->size(), 1);
......@@ -121,7 +121,7 @@ NNVM_REGISTER_OP(max_pool2d)
.set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", Pool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", Pool2DInferLayout)
.set_attr<FCorrectLayout>("FCorrectLayout", Pool2DCorrectLayout)
.set_attr<FTVMCompute>("FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
......@@ -192,7 +192,7 @@ NNVM_REGISTER_OP(avg_pool2d)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Pool2DParam>)
.set_attr<FInferShape>("FInferShape", Pool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", Pool2DInferLayout)
.set_attr<FCorrectLayout>("FCorrectLayout", Pool2DCorrectLayout)
.set_attr<FTVMCompute>("FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
......@@ -252,10 +252,10 @@ inline bool GlobalPool2DInferShape(const nnvm::NodeAttrs& attrs,
return true;
}
inline bool GlobalPool2DInferLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
inline bool GlobalPool2DCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
const GlobalPool2DParam &param = nnvm::get<GlobalPool2DParam>(attrs.parsed);
CHECK_EQ(ilayouts->size(), 1);
CHECK_EQ(last_ilayouts->size(), 1);
......@@ -298,7 +298,7 @@ NNVM_REGISTER_OP(global_max_pool2d)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>)
.set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", GlobalPool2DInferLayout)
.set_attr<FCorrectLayout>("FCorrectLayout", GlobalPool2DCorrectLayout)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
......@@ -339,7 +339,7 @@ NNVM_REGISTER_OP(global_avg_pool2d)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>)
.set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", GlobalPool2DInferLayout)
.set_attr<FCorrectLayout>("FCorrectLayout", GlobalPool2DCorrectLayout)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
......
......@@ -60,7 +60,7 @@ NNVM_REGISTER_OP(upsampling)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<UpSamplingParam>)
.set_attr<FInferShape>("FInferShape", UpSamplingInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", UpsamplingLayout)
.set_attr<FCorrectLayout>("FCorrectLayout", UpsamplingLayout)
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(2);
......
......@@ -75,7 +75,7 @@ So with `shape=(2,0)`, we will obtain the same result as in the above example.
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<BroadcastToParam>)
.set_attr<FInferShape>("FInferShape", BroadcastToInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
......@@ -128,10 +128,10 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
return true;
}
inline bool BinaryBroadcastInferLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
inline bool BinaryBroadcastCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
CHECK_EQ(ilayouts->size(), 2U);
CHECK_EQ(olayouts->size(), 1U);
Layout lhs = (*ilayouts)[0];
......@@ -206,8 +206,8 @@ inline bool BinaryBroadcastInferLayout(const NodeAttrs& attrs,
.set_num_outputs(1) \
.set_attr<FInferShape>("FInferShape", BinaryBroadcastShape) \
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) \
.set_attr<FInferLayout>("FInferLayout", \
BinaryBroadcastInferLayout) \
.set_attr<FCorrectLayout>("FCorrectLayout", \
BinaryBroadcastCorrectLayout) \
.set_attr<FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs) { \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
......
......@@ -333,7 +333,7 @@ NNVM_REGISTER_INIT_OP(full)
.add_arguments(InitOpWithScalarParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpWithScalarParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpWithScalarParam>)
.set_attr<FInferLayout>("FInferLayout", ZeroLayout)
.set_attr<FCorrectLayout>("FCorrectLayout", ZeroLayout)
.set_support_level(4);
NNVM_REGISTER_INIT_OP(zeros)
......@@ -346,7 +346,7 @@ NNVM_REGISTER_INIT_OP(zeros)
.add_arguments(InitOpParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpParam>)
.set_attr<FInferLayout>("FInferLayout", ZeroLayout)
.set_attr<FCorrectLayout>("FCorrectLayout", ZeroLayout)
.set_support_level(4);
NNVM_REGISTER_INIT_OP(ones)
......@@ -359,7 +359,7 @@ NNVM_REGISTER_INIT_OP(ones)
.add_arguments(InitOpParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpParam>)
.set_attr<FInferLayout>("FInferLayout", ZeroLayout)
.set_attr<FCorrectLayout>("FCorrectLayout", ZeroLayout)
.set_support_level(4);
// full_like
......@@ -696,7 +696,7 @@ Example::
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ClipParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<nnvm::FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<nnvm::FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
......
......@@ -41,10 +41,10 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs,
return true;
}
inline bool DotInferLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
inline bool DotCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
const MatMulParam& param = nnvm::get<MatMulParam>(attrs.parsed);
CHECK_EQ(ilayouts->size(), 2U);
CHECK_EQ(olayouts->size(), 1U);
......@@ -67,7 +67,7 @@ inline bool DotInferLayout(const NodeAttrs& attrs,
}
NNVM_REGISTER_OP(matmul)
.describe(R"doc(Matrix multiplication of two arrays.
.describe(R"doc(Matrix multiplication of two arrays.
``dot``'s behavior depends on the input array dimensions:
......@@ -92,7 +92,7 @@ NNVM_REGISTER_OP(matmul)
.add_argument("rhs", "NDArray-or-Symbol", "The second input")
.set_attr<FInferShape>("FInferShape", DotShape)
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FInferLayout>("FInferLayout", DotInferLayout)
.set_attr<FCorrectLayout>("FCorrectLayout", DotCorrectLayout)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
......
......@@ -111,7 +111,7 @@ inline void AxesParamParser(nnvm::NodeAttrs* attrs) {
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReduceParam>) \
.set_attr<FInferShape>("FInferShape", ReduceShape) \
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \
.set_attr<FInferLayout>("FInferLayout", \
.set_attr<FCorrectLayout>("FCorrectLayout", \
ElemwiseFixedLayoutUnknownOut<1, 1>) \
.set_num_inputs(1) \
.set_num_outputs(1)
......
......@@ -45,8 +45,8 @@ This is an experimental operator.
return Array<Tensor>{ topi::identity(inputs[1]) };
})
.set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FInferLayout>(
"FInferLayout", [](const NodeAttrs& attrs,
.set_attr<FCorrectLayout>(
"FCorrectLayout", [](const NodeAttrs& attrs,
std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts) {
......
......@@ -64,7 +64,7 @@ Example::
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", FlattenInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.add_argument("data", "Tensor", "Input data.")
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
......@@ -121,10 +121,10 @@ inline bool ConcatenateInferShape(const NodeAttrs& attrs,
return dshape.Size() != 0;
}
inline bool ConcatenateInferLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
inline bool ConcatenateCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
CHECK_EQ(ilayouts->size(), last_ilayouts->size());
CHECK_EQ(olayouts->size(), 1U);
......@@ -174,7 +174,7 @@ Example::
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ConcatenateParam>)
.set_attr<FInferShape>("FInferShape", ConcatenateInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FInferLayout>("FInferLayout", ConcatenateInferLayout)
.set_attr<FCorrectLayout>("FCorrectLayout", ConcatenateCorrectLayout)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
......@@ -227,7 +227,7 @@ will return a new array with shape ``(2,1,1,1,1,1,3,4)``.
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ExpandDimsParam>)
.set_attr<FInferShape>("FInferShape", ExpandDimsInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FTVMCompute>(
......@@ -271,7 +271,7 @@ Examples::
.set_attr<nnvm::FInferShape>("FInferShape", AssignOutputAttr<TShape, 1, 0>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
// never transform layout of the second input array.
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<FGradient>(
......@@ -368,7 +368,7 @@ along which to split the array.
.set_attr_parser(SplitParamParser)
.set_attr<FInferShape>("FInferShape", SplitInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, -1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, -1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, -1>)
.set_num_inputs(1)
.set_num_outputs(SplitNumOutputs)
.set_attr<FTVMCompute>(
......@@ -411,7 +411,7 @@ NNVM_REGISTER_OP(cast)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<CastParam>)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", CastInferType)
.set_attr<FInferLayout>("FInferLayout", ElemwiseArbitraryLayout<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(1);
......@@ -564,7 +564,7 @@ The significance of each is explained below:
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReshapeParam>)
.set_attr<FInferShape>("FInferShape", ReshapeInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FTVMCompute>(
......@@ -605,7 +605,7 @@ the input array into an output array with the same shape as the second input arr
})
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
// never transform layout of the second input array.
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
......@@ -688,7 +688,7 @@ Examples::
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<SqueezeParam>)
.set_attr<nnvm::FInferShape>("FInferShape", SqueezeShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FTVMCompute>(
......@@ -737,10 +737,10 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
return true;
}
inline bool TransposeInferLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
inline bool TransposeCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
CHECK_EQ(ilayouts->size(), 1U);
CHECK_EQ(olayouts->size(), 1U);
......@@ -805,7 +805,7 @@ Examples::
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<TransposeParam>)
.set_attr<nnvm::FInferShape>("FInferShape", TransposeShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", TransposeInferLayout)
.set_attr<FCorrectLayout>("FCorrectLayout", TransposeCorrectLayout)
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(4)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment