/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2018 by Contributors
 * \file correct_layout.cc
 * \brief Infer and correct layout.
 */
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass.h>
#include <nnvm/layout.h>

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace nnvm {
namespace pass {

/*!
 * \brief Create a "__layout_transform__" node converting \p src into \p dst.
 *
 * The node receives a default name of the form "<src>_to_<dst><counter>",
 * has its "src_layout" / "dst_layout" attributes filled in and its attribute
 * parser invoked. No inputs are attached; the caller wires them up.
 *
 * \param src Layout of the data fed into the transform.
 * \param dst Layout the transform is asked to produce.
 * \return The newly created node.
 */
nnvm::NodePtr CreateLayoutTransformNode(const Layout& src,
                                        const Layout& dst) {
  static const nnvm::Op* trans_op = nnvm::Op::Get("__layout_transform__");
  // Only used to make the default node names distinct.
  // NOTE(review): plain static int, incremented without synchronization.
  static int count = 0;
  nnvm::NodePtr n = nnvm::Node::Create();
  n->attrs.op = trans_op;
  n->attrs.name = src.name() + "_to_" + dst.name() + std::to_string(count++);
  n->attrs.dict["src_layout"] = src.name();
  n->attrs.dict["dst_layout"] = dst.name();
  n->op()->attr_parser(&(n->attrs));
  return n;
}

/*! \brief Maps a node of the rewritten graph to the layouts of its outputs. */
using LayoutAttrDict = std::unordered_map<const Node*, std::vector<Layout> >;

/*!
 * \brief A simple layout infer & correct pass that will
 *  insert layout transform nodes automatically.
 *
 * Nodes are visited in indexed-graph order and copied into a new graph.
 * For every operator that registers FCorrectLayout, the function is called
 * with the layouts produced by the node's inputs; whenever the layout it
 * requests for an input differs from a defined produced layout, a
 * "__layout_transform__" node is inserted on that edge. When the produced
 * layout is undefined, the requested layout is written back to the producer
 * ("reverse infer").
 *
 * Graph attributes read from \p src (both optional):
 *  - "layout_inputs": std::vector<Layout>, one entry per graph input, in the
 *    order of IndexedGraph::input_nodes().
 *  - "layout": std::vector<Layout> indexed by node-entry id, the result of a
 *    previous run of this pass.
 *
 * \param src The graph to process.
 * \return A new graph whose "layout" attribute holds one Layout per node
 *  entry of the returned graph.
 */
nnvm::Graph CorrectLayout(nnvm::Graph src) {
  static auto& op_correct_layout =
      nnvm::Op::GetAttr<FCorrectLayout>("FCorrectLayout");

  const IndexedGraph& idx = src.indexed_graph();

  // The optional attributes do not change while we iterate, so look them up
  // once instead of once per node.
  const std::vector<Layout>* input_layouts =
      src.HasAttr("layout_inputs")
          ? &src.GetAttr<std::vector<Layout> >("layout_inputs")
          : nullptr;
  const std::vector<Layout>* prev_layouts =
      src.HasAttr("layout")
          ? &src.GetAttr<std::vector<Layout> >("layout")
          : nullptr;

  // node id -> position in idx.input_nodes(); avoids a linear search for
  // every variable node.
  const auto& input_nodes = idx.input_nodes();
  std::unordered_map<uint32_t, size_t> input_position;
  input_position.reserve(input_nodes.size());
  for (size_t pos = 0; pos < input_nodes.size(); ++pos) {
    input_position.emplace(input_nodes[pos], pos);
  }

  // old node id -> node of the new graph
  std::vector<nnvm::NodePtr> mirror_vec(idx.num_nodes(), nullptr);
  // (new) Node* -> output_layouts
  LayoutAttrDict new_layouts;

  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto& inode = idx[nid];
    nnvm::NodePtr new_node = nnvm::Node::Create();
    *new_node = *(inode.source);

    if (new_node->is_variable()) {
      // Variable node. No operator. Only one output entry.
      const auto pos_iter = input_position.find(nid);
      CHECK(pos_iter != input_position.end());
      if (input_layouts != nullptr) {
        const size_t input_id = pos_iter->second;
        CHECK_LT(input_id, input_layouts->size())
            << "layout_inputs has no entry for graph input " << input_id;
        new_layouts[new_node.get()] = {(*input_layouts)[input_id]};
      } else {
        new_layouts[new_node.get()] = {Layout::Undef()};
      }
      mirror_vec[nid] = new_node;
      continue;
    }

    const uint32_t num_inputs = inode.inputs.size();
    const uint32_t num_outputs = inode.source->num_outputs();

    // set up output and input layouts
    std::vector<Layout> request_ilayouts(num_inputs, Layout::Undef());
    for (size_t i = 0; i < num_inputs; ++i) {
      const IndexedGraph::NodeEntry& input_entry = inode.inputs[i];
      const NodePtr& new_input_node = mirror_vec[input_entry.node_id];
      CHECK(new_input_node != nullptr);

      // fill inputs by previous node (DFS order) inferred layouts.
      const auto layouts_iter = new_layouts.find(new_input_node.get());
      CHECK(layouts_iter != new_layouts.end());
      request_ilayouts[i] = layouts_iter->second[input_entry.index];
    }
    // layouts produced by previous node. Snapshot taken before the
    // operator's FCorrectLayout may overwrite request_ilayouts.
    std::vector<Layout> produce_ilayouts(request_ilayouts);
    // input layouts from last pass of LayoutTransform (if apply)
    std::vector<Layout> last_request_ilayouts(num_inputs, Layout::Undef());
    // fill outputs by last pass of LayoutTransform (if apply)
    std::vector<Layout> produce_olayouts(num_outputs, Layout::Undef());
    if (prev_layouts != nullptr) {
      for (uint32_t i = 0; i < num_outputs; ++i) {
        const uint32_t eid = idx.entry_id(nid, i);
        CHECK_LT(eid, prev_layouts->size())
            << "layout attribute has no entry for node entry " << eid;
        produce_olayouts[i] = (*prev_layouts)[eid];
      }
      for (uint32_t i = 0; i < num_inputs; ++i) {
        const uint32_t eid = idx.entry_id(inode.inputs[i]);
        CHECK_LT(eid, prev_layouts->size())
            << "layout attribute has no entry for node entry " << eid;
        last_request_ilayouts[i] = (*prev_layouts)[eid];
      }
    }

    if (op_correct_layout.count(new_node->op())) {
      const auto& flayout = op_correct_layout[new_node->op()];
      CHECK(flayout(new_node->attrs, &request_ilayouts,
                    &last_request_ilayouts, &produce_olayouts))
          << "Layout infer fail";
      CHECK_EQ(request_ilayouts.size(), num_inputs);
      CHECK_EQ(produce_olayouts.size(), num_outputs);
    }

    // update new layouts
    new_layouts[new_node.get()] = std::move(produce_olayouts);

    for (uint32_t i = 0; i < inode.inputs.size(); ++i) {
      const auto& e = inode.inputs[i];
      const nnvm::NodePtr& in = mirror_vec[e.node_id];
      new_node->inputs[i] = nnvm::NodeEntry{in, e.index, e.version};

      // insert layout_transform if necessary
      const Layout& produce = produce_ilayouts[i];
      const Layout& request = request_ilayouts[i];
      if (produce != request && produce.defined()) {
        nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request);
        // Replace the default name with "<producer name>_<requested layout>".
        tnode->attrs.name =
            idx[e.node_id].source->attrs.name + "_" + request.name();
        tnode->inputs.emplace_back(new_node->inputs[i]);
        nnvm::NodeEntry tnode_output(std::move(tnode), 0, 0);
        new_node->inputs[i] = tnode_output;
        // layout produced by LayoutTransformNode
        new_layouts[tnode_output.node.get()] = {request};
      } else if (!produce.defined()) {
        // do reverse infer
        new_layouts[in.get()][e.index] = request;
      }
    }
    mirror_vec[nid] = new_node;
  }

  std::vector<nnvm::NodeEntry> outputs;
  outputs.reserve(idx.outputs().size());
  for (const auto& e : idx.outputs()) {
    outputs.emplace_back(
        nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version});
  }

  // Temporary graph, used only to obtain the entry ids of the new graph.
  nnvm::Graph ret;
  ret.outputs = outputs;
  // restore the layouts to return graph
  const auto& ret_idx = ret.indexed_graph();
  std::vector<Layout> ret_layouts(ret_idx.num_node_entries(), Layout::Undef());
  for (uint32_t nid = 0; nid < ret_idx.num_nodes(); ++nid) {
    const auto& inode = ret_idx[nid];
    const auto layout_iter = new_layouts.find(inode.source);
    if (layout_iter != new_layouts.end()) {
      for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
        ret_layouts[ret_idx.entry_id(nid, i)] =
            std::move(layout_iter->second[i]);
      }
    }
  }

  // cannot call indexed_graph() before return the origin Graph,
  // thus create a new one
  nnvm::Graph new_ret;
  new_ret.outputs = std::move(outputs);
  new_ret.attrs["layout"] = std::make_shared<any>(std::move(ret_layouts));
  return new_ret;
}

// register pass
NNVM_REGISTER_PASS(CorrectLayout)
.describe("Return a layout-transformed graph of src.")
.set_body(CorrectLayout)
.provide_graph_attr("layout")
.set_change_graph(true);

DMLC_JSON_ENABLE_ANY(LayoutVector, list_layout);

}  // namespace pass
}  // namespace nnvm