Commit 1e4bb2f8 by Yizhi Liu Committed by Tianqi Chen

allow missing FCorrectLayout (#457)

* allow missing FCorrectLayout

* misunderstood OpMap[], fix
parent 1b5877f2
...@@ -28,11 +28,11 @@ nnvm::NodePtr CreateLayoutTransformNode(const Layout& src, ...@@ -28,11 +28,11 @@ nnvm::NodePtr CreateLayoutTransformNode(const Layout& src,
using LayoutAttrDict = std::unordered_map<const Node*, std::vector<Layout> >; using LayoutAttrDict = std::unordered_map<const Node*, std::vector<Layout> >;
/*! /*!
* \brief A simple layout infer pass that will * \brief A simple layout infer & correct pass that will
* insert layout transform nodes automatically. * insert layout transform nodes automatically.
*/ */
nnvm::Graph CorrectLayout(nnvm::Graph src) { nnvm::Graph CorrectLayout(nnvm::Graph src) {
static auto& op_infer_layout = static auto& op_correct_layout =
nnvm::Op::GetAttr<FCorrectLayout>("FCorrectLayout"); nnvm::Op::GetAttr<FCorrectLayout>("FCorrectLayout");
const IndexedGraph& idx = src.indexed_graph(); const IndexedGraph& idx = src.indexed_graph();
...@@ -91,14 +91,13 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { ...@@ -91,14 +91,13 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) {
} }
} }
const auto& flayout = op_infer_layout[new_node->op()]; if (op_correct_layout.count(new_node->op())) {
CHECK(flayout != nullptr) << "Attribute FCorrectLayout" const auto &flayout = op_correct_layout[new_node->op()];
<< " is not registered by op " << inode.source->op()->name
<< " we are not able to complete layout transform.";
CHECK(flayout(new_node->attrs, &request_ilayouts, &last_request_ilayouts, &produce_olayouts)) CHECK(flayout(new_node->attrs, &request_ilayouts, &last_request_ilayouts, &produce_olayouts))
<< "Layout infer fail"; << "Layout infer fail";
CHECK_EQ(request_ilayouts.size(), num_inputs); CHECK_EQ(request_ilayouts.size(), num_inputs);
CHECK_EQ(produce_olayouts.size(), num_outputs); CHECK_EQ(produce_olayouts.size(), num_outputs);
}
// update new layouts // update new layouts
new_layouts[new_node.get()] = std::move(produce_olayouts); new_layouts[new_node.get()] = std::move(produce_olayouts);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment