Commit 1e4bb2f8 by Yizhi Liu Committed by Tianqi Chen

allow missing FCorrectLayout (#457)

* allow missing FCorrectLayout

* misunderstood OpMap[], fix
parent 1b5877f2
......@@ -28,11 +28,11 @@ nnvm::NodePtr CreateLayoutTransformNode(const Layout& src,
using LayoutAttrDict = std::unordered_map<const Node*, std::vector<Layout> >;
/*!
* \brief A simple layout infer pass that will
* \brief A simple layout infer & correct pass that will
* insert layout transform nodes automatically.
*/
nnvm::Graph CorrectLayout(nnvm::Graph src) {
static auto& op_infer_layout =
static auto& op_correct_layout =
nnvm::Op::GetAttr<FCorrectLayout>("FCorrectLayout");
const IndexedGraph& idx = src.indexed_graph();
......@@ -91,14 +91,13 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) {
}
}
const auto& flayout = op_infer_layout[new_node->op()];
CHECK(flayout != nullptr) << "Attribute FCorrectLayout"
<< " is not registered by op " << inode.source->op()->name
<< " we are not able to complete layout transform.";
CHECK(flayout(new_node->attrs, &request_ilayouts, &last_request_ilayouts, &produce_olayouts))
if (op_correct_layout.count(new_node->op())) {
const auto &flayout = op_correct_layout[new_node->op()];
CHECK(flayout(new_node->attrs, &request_ilayouts, &last_request_ilayouts, &produce_olayouts))
<< "Layout infer fail";
CHECK_EQ(request_ilayouts.size(), num_inputs);
CHECK_EQ(produce_olayouts.size(), num_outputs);
CHECK_EQ(request_ilayouts.size(), num_inputs);
CHECK_EQ(produce_olayouts.size(), num_outputs);
}
// update new layouts
new_layouts[new_node.get()] = std::move(produce_olayouts);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment