/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * Copyright (c) 2017 by Contributors * \file graph_transform.h * \brief A mutator class that does local pattern matching and mutates a node. */ #ifndef NNVM_COMPILER_GRAPH_TRANSFORM_H_ #define NNVM_COMPILER_GRAPH_TRANSFORM_H_ #include <nnvm/graph.h> #include <vector> #include <utility> #include <unordered_map> namespace nnvm { namespace compiler { /*! * \brief Transform the graph to build a new Graph, in post DFS order. * * Automatically copies node when some of its children or control_deps changed. * This function won't be called in Variable. * * \param graph The original graph * * \param ftransform Function of (int nid, const NodePtr& node, std::vector<NodeEntry>* out) -> bool * * If empty vector is returned, it means original entries should be kept. * * \tparam FTransform The transformation function. */ template<typename FTransform> Graph GraphTransform(Graph graph, FTransform ftransform) { const IndexedGraph& idx = graph.indexed_graph(); // new nodes std::vector<NodeEntry> new_entry_map(idx.num_node_entries()); std::vector<bool> updated(idx.num_node_entries(), false); // setup inputs and placeholder. for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; bool need_copy = false; for (const IndexedGraph::NodeEntry& e : inode.inputs) { if (updated[idx.entry_id(e)]) { need_copy = true; break; } } if (!need_copy) { for (const uint32_t cid : inode.control_deps) { const auto& cnode = idx[cid]; for (uint32_t i = 0 ; i < cnode.source->num_outputs(); ++i) { if (updated[idx.entry_id(cid, i)]) { need_copy = true; } } if (need_copy) break; } } if (!need_copy) { std::vector<NodeEntry> ret; if (ftransform(nid, inode.weak_ref.lock(), &ret)) { CHECK_EQ(ret.size(), static_cast<size_t>(inode.source->num_outputs())); for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) { updated[idx.entry_id(nid, i)] = true; new_entry_map[idx.entry_id(nid, i)] = ret[i]; } } } else { NodePtr node = Node::Create(); node->attrs = inode.source->attrs; for (size_t i = 0; i < inode.inputs.size(); ++i) { const IndexedGraph::NodeEntry& e = inode.inputs[i]; if (updated[idx.entry_id(e)]) { node->inputs.push_back(new_entry_map[idx.entry_id(e)]); } else { node->inputs.push_back(inode.source->inputs[i]); } } for (size_t i = 0; i < inode.control_deps.size(); ++i) { const uint32_t cid = inode.control_deps[i]; const auto& cnode = idx[cid]; CHECK_NE(cnode.source->num_outputs(), 0U); NodePtr selected_ptr; for (uint32_t j = 0 ; j < cnode.source->num_outputs(); ++j) { NodePtr cptr = updated[idx.entry_id(cid, j)] ? new_entry_map[idx.entry_id(cid, j)].node : inode.source->control_deps[i]; if (selected_ptr == nullptr) { selected_ptr = std::move(cptr); } else { CHECK(selected_ptr.get() == cptr.get()) << "Control dependency node changed to more than one node"; } } node->control_deps.push_back(selected_ptr); } std::vector<NodeEntry> ret; if (ftransform(nid, node, &ret)) { CHECK_EQ(ret.size(), static_cast<size_t>(inode.source->num_outputs())); for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) { updated[idx.entry_id(nid, i)] = true; new_entry_map[idx.entry_id(nid, i)] = ret[i]; } } else { for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) { updated[idx.entry_id(nid, i)] = true; new_entry_map[idx.entry_id(nid, i)] = NodeEntry{node, i, 0}; } } } } Graph ret; for (size_t i = 0; i < idx.outputs().size(); ++i) { const IndexedGraph::NodeEntry& e = idx.outputs()[i]; if (updated[idx.entry_id(e)]) { ret.outputs.push_back(new_entry_map[idx.entry_id(e)]); } else { ret.outputs.push_back(graph.outputs[i]); } } return ret; } } // namespace compiler } // namespace nnvm #endif // NNVM_COMPILER_GRAPH_TRANSFORM_H_