Commit d5744844 by Yizhi Liu Committed by Tianqi Chen

Rename FInferLayout -> FCorrectLayout (#453)

* rename FInferLayout -> FCorrectLayout

* correct stupid IDE

* update submodule tvm
parent 9f8fcfc9
...@@ -66,7 +66,7 @@ using DTypeVector = std::vector<int>; ...@@ -66,7 +66,7 @@ using DTypeVector = std::vector<int>;
* int entry_layout = layouts[g.indexed_graph().entry_id(my_entry)]; * int entry_layout = layouts[g.indexed_graph().entry_id(my_entry)];
* \endcode * \endcode
* *
* \sa FInferLayout * \sa FCorrectLayout
*/ */
using LayoutVector = std::vector<Layout>; using LayoutVector = std::vector<Layout>;
......
...@@ -178,7 +178,7 @@ using FSetInputVarAttrOnCompose = std::function<void( ...@@ -178,7 +178,7 @@ using FSetInputVarAttrOnCompose = std::function<void(
const int index)>; const int index)>;
/*! /*!
* \brief Inference function of node layout. See \p Layout for layout convention * \brief Infer & correct function of node layout. See \p Layout for layout convention
* \param attrs The attribute of the node. * \param attrs The attribute of the node.
* \param ilayouts Given the input layouts produced by ancestor nodes, * \param ilayouts Given the input layouts produced by ancestor nodes,
* it should be filled by layouts that the node requests. * it should be filled by layouts that the node requests.
...@@ -196,7 +196,7 @@ using FSetInputVarAttrOnCompose = std::function<void( ...@@ -196,7 +196,7 @@ using FSetInputVarAttrOnCompose = std::function<void(
* \param olayouts Inferred output layouts. * \param olayouts Inferred output layouts.
* \return success flag. * \return success flag.
*/ */
using FInferLayout = std::function<bool( using FCorrectLayout = std::function<bool(
const NodeAttrs& attrs, const NodeAttrs& attrs,
std::vector<Layout> *ilayouts, std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts, const std::vector<Layout> *last_ilayouts,
......
...@@ -33,7 +33,7 @@ using LayoutAttrDict = std::unordered_map<const Node*, std::vector<Layout> >; ...@@ -33,7 +33,7 @@ using LayoutAttrDict = std::unordered_map<const Node*, std::vector<Layout> >;
*/ */
nnvm::Graph CorrectLayout(nnvm::Graph src) { nnvm::Graph CorrectLayout(nnvm::Graph src) {
static auto& op_infer_layout = static auto& op_infer_layout =
nnvm::Op::GetAttr<FInferLayout>("FInferLayout"); nnvm::Op::GetAttr<FCorrectLayout>("FCorrectLayout");
const IndexedGraph& idx = src.indexed_graph(); const IndexedGraph& idx = src.indexed_graph();
std::vector<nnvm::NodePtr> mirror_vec(idx.num_nodes(), nullptr); std::vector<nnvm::NodePtr> mirror_vec(idx.num_nodes(), nullptr);
...@@ -92,7 +92,7 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { ...@@ -92,7 +92,7 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) {
} }
const auto& flayout = op_infer_layout[new_node->op()]; const auto& flayout = op_infer_layout[new_node->op()];
CHECK(flayout != nullptr) << "Attribute FInferLayout" CHECK(flayout != nullptr) << "Attribute FCorrectLayout"
<< " is not registered by op " << inode.source->op()->name << " is not registered by op " << inode.source->op()->name
<< " we are not able to complete layout transform."; << " we are not able to complete layout transform.";
CHECK(flayout(new_node->attrs, &request_ilayouts, &last_request_ilayouts, &produce_olayouts)) CHECK(flayout(new_node->attrs, &request_ilayouts, &last_request_ilayouts, &produce_olayouts))
......
...@@ -271,7 +271,7 @@ inline bool ElemwiseBinaryKeepLeftLayout(const NodeAttrs& attrs, ...@@ -271,7 +271,7 @@ inline bool ElemwiseBinaryKeepLeftLayout(const NodeAttrs& attrs,
.set_num_outputs(1) \ .set_num_outputs(1) \
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) \ .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \ .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \
.set_attr<FInferLayout>("FInferLayout", \ .set_attr<FCorrectLayout>("FCorrectLayout", \
ElemwiseArbitraryLayout<1, 1>) \ ElemwiseArbitraryLayout<1, 1>) \
.set_attr<FInplaceOption>("FInplaceOption", \ .set_attr<FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \ [](const NodeAttrs& attrs){ \
...@@ -298,7 +298,7 @@ inline bool ElemwiseBinaryKeepLeftLayout(const NodeAttrs& attrs, ...@@ -298,7 +298,7 @@ inline bool ElemwiseBinaryKeepLeftLayout(const NodeAttrs& attrs,
.set_num_outputs(1) \ .set_num_outputs(1) \
.set_attr<FInferShape>("FInferShape", ElemwiseShape<2, 1>) \ .set_attr<FInferShape>("FInferShape", ElemwiseShape<2, 1>) \
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) \ .set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) \
.set_attr<FInferLayout>("FInferLayout", \ .set_attr<FCorrectLayout>("FCorrectLayout", \
ElemwiseBinaryKeepLeftLayout) \ ElemwiseBinaryKeepLeftLayout) \
.set_attr<FInplaceOption>("FInplaceOption", \ .set_attr<FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs) { \ [](const NodeAttrs& attrs) { \
...@@ -319,7 +319,7 @@ inline bool ElemwiseBinaryKeepLeftLayout(const NodeAttrs& attrs, ...@@ -319,7 +319,7 @@ inline bool ElemwiseBinaryKeepLeftLayout(const NodeAttrs& attrs,
ParamGetAttrDict<ElementWiseReduceParam>) \ ParamGetAttrDict<ElementWiseReduceParam>) \
.set_attr<nnvm::FInferShape>("FInferShape", \ .set_attr<nnvm::FInferShape>("FInferShape", \
ElementWiseReduceShape) \ ElementWiseReduceShape) \
.set_attr<FInferLayout>("FInferLayout", \ .set_attr<FCorrectLayout>("FCorrectLayout", \
ElemwiseFixedLayoutCopyToOut<1, 1>) \ ElemwiseFixedLayoutCopyToOut<1, 1>) \
.set_attr<nnvm::FInferType>("FInferType", ElementWiseReduceType) \ .set_attr<nnvm::FInferType>("FInferType", ElementWiseReduceType) \
.add_argument("args", "Symbol[]", "Positional input arguments") .add_argument("args", "Symbol[]", "Positional input arguments")
...@@ -337,7 +337,7 @@ inline bool ElemwiseBinaryKeepLeftLayout(const NodeAttrs& attrs, ...@@ -337,7 +337,7 @@ inline bool ElemwiseBinaryKeepLeftLayout(const NodeAttrs& attrs,
static_cast<int>(kFloat32)); \ static_cast<int>(kFloat32)); \
return true; \ return true; \
}) \ }) \
.set_attr<FInferLayout>("FInferLayout", \ .set_attr<FCorrectLayout>("FCorrectLayout", \
ElemwiseFixedLayoutUnknownOut<1, 1>) \ ElemwiseFixedLayoutUnknownOut<1, 1>) \
.set_attr<FGradient>( \ .set_attr<FGradient>( \
"FGradient", [](const NodePtr& n, \ "FGradient", [](const NodePtr& n, \
......
...@@ -129,7 +129,7 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs, ...@@ -129,7 +129,7 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
return true; return true;
} }
inline bool Conv2DInferLayout(const NodeAttrs& attrs, inline bool Conv2DCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts, std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts, const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) { std::vector<Layout> *olayouts) {
...@@ -189,7 +189,7 @@ a bias vector is created and added to the outputs. ...@@ -189,7 +189,7 @@ a bias vector is created and added to the outputs.
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>) .set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>)
.set_attr<FInferShape>("FInferShape", Conv2DInferShape) .set_attr<FInferShape>("FInferShape", Conv2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FInferLayout>("FInferLayout", Conv2DInferLayout) .set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout)
.set_num_outputs(1) .set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<Conv2DParam>) .set_num_inputs(UseBiasNumInputs<Conv2DParam>)
.set_support_level(2) .set_support_level(2)
...@@ -214,7 +214,7 @@ NNVM_REGISTER_OP(_contrib_conv2d_NCHWc) ...@@ -214,7 +214,7 @@ NNVM_REGISTER_OP(_contrib_conv2d_NCHWc)
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>) .set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>)
.set_attr<FInferShape>("FInferShape", Conv2DInferShape) .set_attr<FInferShape>("FInferShape", Conv2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FInferLayout>("FInferLayout", Conv2DInferLayout) .set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout)
.set_num_outputs(1) .set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<Conv2DParam>) .set_num_inputs(UseBiasNumInputs<Conv2DParam>)
.set_support_level(2); .set_support_level(2);
...@@ -306,7 +306,7 @@ inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs, ...@@ -306,7 +306,7 @@ inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs,
return true; return true;
} }
inline bool Conv2DTransposeInferLayout(const NodeAttrs& attrs, inline bool Conv2DTransposeCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts, std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts, const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) { std::vector<Layout> *olayouts) {
...@@ -363,7 +363,7 @@ said convolution. ...@@ -363,7 +363,7 @@ said convolution.
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DTransposeParam>) .set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DTransposeParam>)
.set_attr<FInferShape>("FInferShape", Conv2DTransposeInferShape) .set_attr<FInferShape>("FInferShape", Conv2DTransposeInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FInferLayout>("FInferLayout", Conv2DTransposeInferLayout) .set_attr<FCorrectLayout>("FCorrectLayout", Conv2DTransposeCorrectLayout)
.set_num_outputs(1) .set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<Conv2DTransposeParam>) .set_num_inputs(UseBiasNumInputs<Conv2DTransposeParam>)
.set_support_level(2); .set_support_level(2);
......
...@@ -87,7 +87,7 @@ If ``use_bias`` is set to be false, then the ``bias`` term is ignored. ...@@ -87,7 +87,7 @@ If ``use_bias`` is set to be false, then the ``bias`` term is ignored.
.set_attr<FInferShape>("FInferShape", DenseInferShape) .set_attr<FInferShape>("FInferShape", DenseInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
// leave weight & bias layout undefined // leave weight & bias layout undefined
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutCopyToOut<1, 1>) .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutCopyToOut<1, 1>)
.set_attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) { const std::vector<NodeEntry>& ograds) {
...@@ -167,7 +167,7 @@ NNVM_REGISTER_OP(dropout) ...@@ -167,7 +167,7 @@ NNVM_REGISTER_OP(dropout)
.set_num_outputs(2) .set_num_outputs(2)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 2>) .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 2>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 2>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 2>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseArbitraryLayout<1, 1>) .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
.set_attr<FNumVisibleOutputs>("FNumVisibleOutputs", [](const NodeAttrs& attrs) { .set_attr<FNumVisibleOutputs>("FNumVisibleOutputs", [](const NodeAttrs& attrs) {
return 1; return 1;
}) })
...@@ -201,7 +201,7 @@ inline bool BatchNormInferShape(const nnvm::NodeAttrs& attrs, ...@@ -201,7 +201,7 @@ inline bool BatchNormInferShape(const nnvm::NodeAttrs& attrs,
return true; return true;
} }
inline bool BatchNormInferLayout(const NodeAttrs& attrs, inline bool BatchNormCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *in_layouts, std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts, const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts) { std::vector<Layout> *out_layouts) {
...@@ -307,7 +307,7 @@ axis to be the last item in the input shape. ...@@ -307,7 +307,7 @@ axis to be the last item in the input shape.
.add_arguments(BatchNormParam::__FIELDS__()) .add_arguments(BatchNormParam::__FIELDS__())
.set_attr_parser(ParamParser<BatchNormParam>) .set_attr_parser(ParamParser<BatchNormParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<BatchNormParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<BatchNormParam>)
.set_attr<FInferLayout>("FInferLayout", BatchNormInferLayout) .set_attr<FCorrectLayout>("FCorrectLayout", BatchNormCorrectLayout)
.set_num_inputs(5) .set_num_inputs(5)
.set_num_outputs(3) .set_num_outputs(3)
.set_attr<FInferShape>("FInferShape", BatchNormInferShape) .set_attr<FInferShape>("FInferShape", BatchNormInferShape)
...@@ -345,7 +345,7 @@ NNVM_REGISTER_OP(softmax) ...@@ -345,7 +345,7 @@ NNVM_REGISTER_OP(softmax)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseArbitraryLayout<1, 1>) .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
.set_support_level(1) .set_support_level(1)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
...@@ -402,7 +402,7 @@ NNVM_REGISTER_OP(log_softmax) ...@@ -402,7 +402,7 @@ NNVM_REGISTER_OP(log_softmax)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseArbitraryLayout<1, 1>) .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
...@@ -460,7 +460,7 @@ NNVM_REGISTER_OP(leaky_relu) ...@@ -460,7 +460,7 @@ NNVM_REGISTER_OP(leaky_relu)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseArbitraryLayout<1, 1>) .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
...@@ -512,7 +512,7 @@ inline bool PReluInferShape(const nnvm::NodeAttrs &attrs, ...@@ -512,7 +512,7 @@ inline bool PReluInferShape(const nnvm::NodeAttrs &attrs,
return true; return true;
} }
inline bool PReluInferLayout(const NodeAttrs& attrs, inline bool PReluCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *in_layouts, std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts, const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts) { std::vector<Layout> *out_layouts) {
...@@ -550,7 +550,7 @@ where :math:`*` is an channelwise multiplication for each sample in the ...@@ -550,7 +550,7 @@ where :math:`*` is an channelwise multiplication for each sample in the
.set_num_inputs(2) .set_num_inputs(2)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", PReluInferShape) .set_attr<FInferShape>("FInferShape", PReluInferShape)
.set_attr<FInferLayout>("FInferLayout", PReluInferLayout) .set_attr<FCorrectLayout>("FCorrectLayout", PReluCorrectLayout)
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) { .set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "alpha"}; return std::vector<std::string>{"data", "alpha"};
}) })
...@@ -597,7 +597,7 @@ NNVM_REGISTER_OP(pad) ...@@ -597,7 +597,7 @@ NNVM_REGISTER_OP(pad)
.set_num_inputs(1) .set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", PadInferShape) .set_attr<FInferShape>("FInferShape", PadInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutCopyToOut<1, 1>) .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutCopyToOut<1, 1>)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
...@@ -651,8 +651,8 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] ...@@ -651,8 +651,8 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
.set_attr_parser(ParamParser<LayoutTransformParam>) .set_attr_parser(ParamParser<LayoutTransformParam>)
.set_attr<FInferShape>("FInferShape", LayoutTransformInferShape) .set_attr<FInferShape>("FInferShape", LayoutTransformInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>( .set_attr<FCorrectLayout>(
"FInferLayout", [](const NodeAttrs& attrs, "FCorrectLayout", [](const NodeAttrs& attrs,
std::vector<Layout> *ilayouts, std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts, const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) { std::vector<Layout> *olayouts) {
......
...@@ -66,7 +66,7 @@ inline bool Pool2DInferShape(const nnvm::NodeAttrs& attrs, ...@@ -66,7 +66,7 @@ inline bool Pool2DInferShape(const nnvm::NodeAttrs& attrs,
return true; return true;
} }
inline bool Pool2DInferLayout(const NodeAttrs& attrs, inline bool Pool2DCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts, std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts, const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) { std::vector<Layout> *olayouts) {
...@@ -121,7 +121,7 @@ NNVM_REGISTER_OP(max_pool2d) ...@@ -121,7 +121,7 @@ NNVM_REGISTER_OP(max_pool2d)
.set_num_inputs(1) .set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", Pool2DInferShape) .set_attr<FInferShape>("FInferShape", Pool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", Pool2DInferLayout) .set_attr<FCorrectLayout>("FCorrectLayout", Pool2DCorrectLayout)
.set_attr<FTVMCompute>("FTVMCompute", [](const NodeAttrs& attrs, .set_attr<FTVMCompute>("FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Array<Tensor>& out_info) { const Array<Tensor>& out_info) {
...@@ -192,7 +192,7 @@ NNVM_REGISTER_OP(avg_pool2d) ...@@ -192,7 +192,7 @@ NNVM_REGISTER_OP(avg_pool2d)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Pool2DParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Pool2DParam>)
.set_attr<FInferShape>("FInferShape", Pool2DInferShape) .set_attr<FInferShape>("FInferShape", Pool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", Pool2DInferLayout) .set_attr<FCorrectLayout>("FCorrectLayout", Pool2DCorrectLayout)
.set_attr<FTVMCompute>("FTVMCompute", [](const NodeAttrs& attrs, .set_attr<FTVMCompute>("FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Array<Tensor>& out_info) { const Array<Tensor>& out_info) {
...@@ -252,7 +252,7 @@ inline bool GlobalPool2DInferShape(const nnvm::NodeAttrs& attrs, ...@@ -252,7 +252,7 @@ inline bool GlobalPool2DInferShape(const nnvm::NodeAttrs& attrs,
return true; return true;
} }
inline bool GlobalPool2DInferLayout(const NodeAttrs& attrs, inline bool GlobalPool2DCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts, std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts, const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) { std::vector<Layout> *olayouts) {
...@@ -298,7 +298,7 @@ NNVM_REGISTER_OP(global_max_pool2d) ...@@ -298,7 +298,7 @@ NNVM_REGISTER_OP(global_max_pool2d)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>)
.set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape) .set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", GlobalPool2DInferLayout) .set_attr<FCorrectLayout>("FCorrectLayout", GlobalPool2DCorrectLayout)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
...@@ -339,7 +339,7 @@ NNVM_REGISTER_OP(global_avg_pool2d) ...@@ -339,7 +339,7 @@ NNVM_REGISTER_OP(global_avg_pool2d)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>)
.set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape) .set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", GlobalPool2DInferLayout) .set_attr<FCorrectLayout>("FCorrectLayout", GlobalPool2DCorrectLayout)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
......
...@@ -60,7 +60,7 @@ NNVM_REGISTER_OP(upsampling) ...@@ -60,7 +60,7 @@ NNVM_REGISTER_OP(upsampling)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<UpSamplingParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<UpSamplingParam>)
.set_attr<FInferShape>("FInferShape", UpSamplingInferShape) .set_attr<FInferShape>("FInferShape", UpSamplingInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", UpsamplingLayout) .set_attr<FCorrectLayout>("FCorrectLayout", UpsamplingLayout)
.set_num_outputs(1) .set_num_outputs(1)
.set_num_inputs(1) .set_num_inputs(1)
.set_support_level(2); .set_support_level(2);
......
...@@ -75,7 +75,7 @@ So with `shape=(2,0)`, we will obtain the same result as in the above example. ...@@ -75,7 +75,7 @@ So with `shape=(2,0)`, we will obtain the same result as in the above example.
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<BroadcastToParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<BroadcastToParam>)
.set_attr<FInferShape>("FInferShape", BroadcastToInferShape) .set_attr<FInferShape>("FInferShape", BroadcastToInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>) .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
...@@ -128,7 +128,7 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs, ...@@ -128,7 +128,7 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
return true; return true;
} }
inline bool BinaryBroadcastInferLayout(const NodeAttrs& attrs, inline bool BinaryBroadcastCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts, std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts, const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) { std::vector<Layout> *olayouts) {
...@@ -206,8 +206,8 @@ inline bool BinaryBroadcastInferLayout(const NodeAttrs& attrs, ...@@ -206,8 +206,8 @@ inline bool BinaryBroadcastInferLayout(const NodeAttrs& attrs,
.set_num_outputs(1) \ .set_num_outputs(1) \
.set_attr<FInferShape>("FInferShape", BinaryBroadcastShape) \ .set_attr<FInferShape>("FInferShape", BinaryBroadcastShape) \
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) \ .set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) \
.set_attr<FInferLayout>("FInferLayout", \ .set_attr<FCorrectLayout>("FCorrectLayout", \
BinaryBroadcastInferLayout) \ BinaryBroadcastCorrectLayout) \
.set_attr<FInplaceOption>("FInplaceOption", \ .set_attr<FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs) { \ [](const NodeAttrs& attrs) { \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \ return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
......
...@@ -333,7 +333,7 @@ NNVM_REGISTER_INIT_OP(full) ...@@ -333,7 +333,7 @@ NNVM_REGISTER_INIT_OP(full)
.add_arguments(InitOpWithScalarParam::__FIELDS__()) .add_arguments(InitOpWithScalarParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpWithScalarParam>) .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpWithScalarParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpWithScalarParam>) .set_attr<FInferType>("FInferType", ZeroType<InitOpWithScalarParam>)
.set_attr<FInferLayout>("FInferLayout", ZeroLayout) .set_attr<FCorrectLayout>("FCorrectLayout", ZeroLayout)
.set_support_level(4); .set_support_level(4);
NNVM_REGISTER_INIT_OP(zeros) NNVM_REGISTER_INIT_OP(zeros)
...@@ -346,7 +346,7 @@ NNVM_REGISTER_INIT_OP(zeros) ...@@ -346,7 +346,7 @@ NNVM_REGISTER_INIT_OP(zeros)
.add_arguments(InitOpParam::__FIELDS__()) .add_arguments(InitOpParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>) .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpParam>) .set_attr<FInferType>("FInferType", ZeroType<InitOpParam>)
.set_attr<FInferLayout>("FInferLayout", ZeroLayout) .set_attr<FCorrectLayout>("FCorrectLayout", ZeroLayout)
.set_support_level(4); .set_support_level(4);
NNVM_REGISTER_INIT_OP(ones) NNVM_REGISTER_INIT_OP(ones)
...@@ -359,7 +359,7 @@ NNVM_REGISTER_INIT_OP(ones) ...@@ -359,7 +359,7 @@ NNVM_REGISTER_INIT_OP(ones)
.add_arguments(InitOpParam::__FIELDS__()) .add_arguments(InitOpParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>) .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpParam>) .set_attr<FInferType>("FInferType", ZeroType<InitOpParam>)
.set_attr<FInferLayout>("FInferLayout", ZeroLayout) .set_attr<FCorrectLayout>("FCorrectLayout", ZeroLayout)
.set_support_level(4); .set_support_level(4);
// full_like // full_like
...@@ -696,7 +696,7 @@ Example:: ...@@ -696,7 +696,7 @@ Example::
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ClipParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ClipParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<nnvm::FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>) .set_attr<nnvm::FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
......
...@@ -41,7 +41,7 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs, ...@@ -41,7 +41,7 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs,
return true; return true;
} }
inline bool DotInferLayout(const NodeAttrs& attrs, inline bool DotCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts, std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts, const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) { std::vector<Layout> *olayouts) {
...@@ -67,7 +67,7 @@ inline bool DotInferLayout(const NodeAttrs& attrs, ...@@ -67,7 +67,7 @@ inline bool DotInferLayout(const NodeAttrs& attrs,
} }
NNVM_REGISTER_OP(matmul) NNVM_REGISTER_OP(matmul)
.describe(R"doc(Matrix multiplication of two arrays. .describe(R"doc(Matrix multiplication of two arrays.
``dot``'s behavior depends on the input array dimensions: ``dot``'s behavior depends on the input array dimensions:
...@@ -92,7 +92,7 @@ NNVM_REGISTER_OP(matmul) ...@@ -92,7 +92,7 @@ NNVM_REGISTER_OP(matmul)
.add_argument("rhs", "NDArray-or-Symbol", "The second input") .add_argument("rhs", "NDArray-or-Symbol", "The second input")
.set_attr<FInferShape>("FInferShape", DotShape) .set_attr<FInferShape>("FInferShape", DotShape)
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FInferLayout>("FInferLayout", DotInferLayout) .set_attr<FCorrectLayout>("FCorrectLayout", DotCorrectLayout)
.set_attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) { const std::vector<NodeEntry>& ograds) {
......
...@@ -111,7 +111,7 @@ inline void AxesParamParser(nnvm::NodeAttrs* attrs) { ...@@ -111,7 +111,7 @@ inline void AxesParamParser(nnvm::NodeAttrs* attrs) {
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReduceParam>) \ .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReduceParam>) \
.set_attr<FInferShape>("FInferShape", ReduceShape) \ .set_attr<FInferShape>("FInferShape", ReduceShape) \
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \ .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \
.set_attr<FInferLayout>("FInferLayout", \ .set_attr<FCorrectLayout>("FCorrectLayout", \
ElemwiseFixedLayoutUnknownOut<1, 1>) \ ElemwiseFixedLayoutUnknownOut<1, 1>) \
.set_num_inputs(1) \ .set_num_inputs(1) \
.set_num_outputs(1) .set_num_outputs(1)
......
...@@ -45,8 +45,8 @@ This is an experimental operator. ...@@ -45,8 +45,8 @@ This is an experimental operator.
return Array<Tensor>{ topi::identity(inputs[1]) }; return Array<Tensor>{ topi::identity(inputs[1]) };
}) })
.set_attr<FInferShape>("FInferShape", SameShape) .set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FInferLayout>( .set_attr<FCorrectLayout>(
"FInferLayout", [](const NodeAttrs& attrs, "FCorrectLayout", [](const NodeAttrs& attrs,
std::vector<Layout> *in_layouts, std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts, const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts) { std::vector<Layout> *out_layouts) {
......
...@@ -64,7 +64,7 @@ Example:: ...@@ -64,7 +64,7 @@ Example::
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", FlattenInferShape) .set_attr<FInferShape>("FInferShape", FlattenInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>) .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.add_argument("data", "Tensor", "Input data.") .add_argument("data", "Tensor", "Input data.")
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
...@@ -121,7 +121,7 @@ inline bool ConcatenateInferShape(const NodeAttrs& attrs, ...@@ -121,7 +121,7 @@ inline bool ConcatenateInferShape(const NodeAttrs& attrs,
return dshape.Size() != 0; return dshape.Size() != 0;
} }
inline bool ConcatenateInferLayout(const NodeAttrs& attrs, inline bool ConcatenateCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts, std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts, const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) { std::vector<Layout> *olayouts) {
...@@ -174,7 +174,7 @@ Example:: ...@@ -174,7 +174,7 @@ Example::
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ConcatenateParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ConcatenateParam>)
.set_attr<FInferShape>("FInferShape", ConcatenateInferShape) .set_attr<FInferShape>("FInferShape", ConcatenateInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FInferLayout>("FInferLayout", ConcatenateInferLayout) .set_attr<FCorrectLayout>("FCorrectLayout", ConcatenateCorrectLayout)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
...@@ -227,7 +227,7 @@ will return a new array with shape ``(2,1,1,1,1,1,3,4)``. ...@@ -227,7 +227,7 @@ will return a new array with shape ``(2,1,1,1,1,1,3,4)``.
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ExpandDimsParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ExpandDimsParam>)
.set_attr<FInferShape>("FInferShape", ExpandDimsInferShape) .set_attr<FInferShape>("FInferShape", ExpandDimsInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>) .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
...@@ -271,7 +271,7 @@ Examples:: ...@@ -271,7 +271,7 @@ Examples::
.set_attr<nnvm::FInferShape>("FInferShape", AssignOutputAttr<TShape, 1, 0>) .set_attr<nnvm::FInferShape>("FInferShape", AssignOutputAttr<TShape, 1, 0>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
// never transform layout of the second input array. // never transform layout of the second input array.
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>) .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_num_inputs(2) .set_num_inputs(2)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FGradient>( .set_attr<FGradient>(
...@@ -368,7 +368,7 @@ along which to split the array. ...@@ -368,7 +368,7 @@ along which to split the array.
.set_attr_parser(SplitParamParser) .set_attr_parser(SplitParamParser)
.set_attr<FInferShape>("FInferShape", SplitInferShape) .set_attr<FInferShape>("FInferShape", SplitInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, -1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, -1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, -1>) .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, -1>)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(SplitNumOutputs) .set_num_outputs(SplitNumOutputs)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
...@@ -411,7 +411,7 @@ NNVM_REGISTER_OP(cast) ...@@ -411,7 +411,7 @@ NNVM_REGISTER_OP(cast)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<CastParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<CastParam>)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", CastInferType) .set_attr<FInferType>("FInferType", CastInferType)
.set_attr<FInferLayout>("FInferLayout", ElemwiseArbitraryLayout<1, 1>) .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(1) .set_num_outputs(1)
.set_support_level(1); .set_support_level(1);
...@@ -564,7 +564,7 @@ The significance of each is explained below: ...@@ -564,7 +564,7 @@ The significance of each is explained below:
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReshapeParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReshapeParam>)
.set_attr<FInferShape>("FInferShape", ReshapeInferShape) .set_attr<FInferShape>("FInferShape", ReshapeInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>) .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
...@@ -605,7 +605,7 @@ the input array into an output array with the same shape as the second input arr ...@@ -605,7 +605,7 @@ the input array into an output array with the same shape as the second input arr
}) })
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
// never transform layout of the second input array. // never transform layout of the second input array.
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>) .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) { const std::vector<NodeEntry>& ograds) {
...@@ -688,7 +688,7 @@ Examples:: ...@@ -688,7 +688,7 @@ Examples::
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<SqueezeParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<SqueezeParam>)
.set_attr<nnvm::FInferShape>("FInferShape", SqueezeShape) .set_attr<nnvm::FInferShape>("FInferShape", SqueezeShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>) .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
...@@ -737,7 +737,7 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs, ...@@ -737,7 +737,7 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
return true; return true;
} }
inline bool TransposeInferLayout(const NodeAttrs& attrs, inline bool TransposeCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts, std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts, const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) { std::vector<Layout> *olayouts) {
...@@ -805,7 +805,7 @@ Examples:: ...@@ -805,7 +805,7 @@ Examples::
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<TransposeParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<TransposeParam>)
.set_attr<nnvm::FInferShape>("FInferShape", TransposeShape) .set_attr<nnvm::FInferShape>("FInferShape", TransposeShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", TransposeInferLayout) .set_attr<FCorrectLayout>("FCorrectLayout", TransposeCorrectLayout)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(1) .set_num_outputs(1)
.set_support_level(4) .set_support_level(4)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment