/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * Copyright (c) 2017 by Contributors * \file precompute_prune.cc * \brief Split the graph into a pre-compute graph and an execution graph. * * The pre-compute graph outputs parameters that can be taken * by the execution graph during the execution phase. 
*/ #include <nnvm/graph.h> #include <nnvm/op_attr_types.h> #include <nnvm/graph_attr_types.h> #include <nnvm/pass.h> #include <nnvm/compiler/op_attr_types.h> #include <unordered_set> namespace nnvm { namespace compiler { nnvm::Graph PrecomputePrune(nnvm::Graph src) { const auto& plist = src.GetAttr<std::vector<std::string> >("param_name_list"); std::unordered_set<std::string> params(plist.begin(), plist.end()); std::unordered_set<nnvm::Node*> pruned; nnvm::NodeEntryMap<nnvm::NodePtr> entry_var; std::unordered_set<std::string> unique_name; // number of edges that are not variable int non_var_edge = 0; auto replace_pruned_entry = [&] (const NodeEntry& e) { if (!entry_var.count(e)) { if (!e.node->is_variable()) { ++non_var_edge; } nnvm::NodePtr var = nnvm::Node::Create(); var->attrs.name = e.node->attrs.name; if (e.version) { var->attrs.name += "_" + std::to_string(e.version); } if (e.node->num_outputs() != 1) { var->attrs.name += "_output" + std::to_string(e.index); } entry_var.emplace(e, var); CHECK(!unique_name.count(var->attrs.name)); unique_name.insert(var->attrs.name); return nnvm::NodeEntry{var, 0, 0}; } else { return nnvm::NodeEntry{entry_var.at(e), 0, 0}; } }; DFSVisit(src.outputs, [&](const nnvm::NodePtr& n) { bool can_be_pruned = true; if (n->is_variable()) { if (params.count(n->attrs.name)) { pruned.emplace(n.get()); } can_be_pruned = false; } for (const auto& e : n->inputs) { if (!pruned.count(e.node.get())) { can_be_pruned = false; } } if (can_be_pruned) { pruned.emplace(n.get()); } else { // scan again to find edge nodes, skip variables for (auto& e : n->inputs) { if (pruned.count(e.node.get())) { e = replace_pruned_entry(e); } } } }); // nothing being pruned. 
if (non_var_edge == 0) { return src; } for (auto& e : src.outputs) { if (pruned.count(e.node.get())) { e = replace_pruned_entry(e); } } nnvm::Graph pre_graph; pre_graph.outputs.reserve(entry_var.size()); std::vector<std::string> output_names; output_names.reserve(entry_var.size()); for (auto kv : entry_var) { pre_graph.outputs.emplace_back(kv.first); output_names.emplace_back(kv.second->attrs.name); } // new parameter list pre_graph.attrs["output_names"] = std::make_shared<dmlc::any>(std::move(output_names)); src.attrs["precompute_graph"] = std::make_shared<dmlc::any>(std::move(pre_graph)); return src; } NNVM_REGISTER_PASS(PrecomputePrune) .set_body(PrecomputePrune); } // namespace compiler } // namespace nnvm