/*!
 *  Copyright (c) 2017 by Contributors
 * \file precompute_prune.cc
 * \brief Split the graph into a pre-compute graph and an execution graph.
 *
 *  The pre-compute graph outputs parameters that can be taken
 *  by the execution graph during the execution phase.
 */
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass.h>
#include <nnvm/compiler/op_attr_types.h>
#include <unordered_set>

namespace nnvm {
namespace compiler {

19 20 21 22
nnvm::Graph PrecomputePrune(nnvm::Graph src) {
  const auto& plist
      = src.GetAttr<std::vector<std::string> >("param_name_list");
  std::unordered_set<std::string> params(plist.begin(), plist.end());
23 24 25

  std::unordered_set<nnvm::Node*> pruned;
  nnvm::NodeEntryMap<nnvm::NodePtr> entry_var;
26
  std::unordered_set<std::string> unique_name;
27 28
  // number of edges that are not variable
  int non_var_edge = 0;
29

30 31 32 33 34 35 36
  auto replace_pruned_entry = [&] (const NodeEntry& e) {
    if (!entry_var.count(e)) {
      if (!e.node->is_variable()) {
        ++non_var_edge;
      }
      nnvm::NodePtr var = nnvm::Node::Create();
      var->attrs.name = e.node->attrs.name;
37 38 39
      if (e.version) {
          var->attrs.name += "_" + std::to_string(e.version);
      }
40 41 42 43 44 45 46 47 48 49 50 51
      if (e.node->num_outputs() != 1) {
        var->attrs.name += "_output" + std::to_string(e.index);
      }
      entry_var.emplace(e, var);
      CHECK(!unique_name.count(var->attrs.name));
      unique_name.insert(var->attrs.name);
      return nnvm::NodeEntry{var, 0, 0};
    } else {
      return nnvm::NodeEntry{entry_var.at(e), 0, 0};
    }
  };

52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
  DFSVisit(src.outputs, [&](const nnvm::NodePtr& n) {
    bool can_be_pruned = true;
    if (n->is_variable()) {
      if (params.count(n->attrs.name)) {
        pruned.emplace(n.get());
      }
      can_be_pruned = false;
    }

    for (const auto& e : n->inputs) {
      if (!pruned.count(e.node.get())) {
        can_be_pruned = false;
      }
    }
    if (can_be_pruned) {
      pruned.emplace(n.get());
    } else {
      // scan again to find edge nodes, skip variables
      for (auto& e : n->inputs) {
71
        if (pruned.count(e.node.get())) {
72
          e = replace_pruned_entry(e);
73 74 75 76 77
        }
      }
    }
  });

78 79 80 81 82
  // nothing being pruned.
  if (non_var_edge == 0) {
    return src;
  }

83 84 85 86 87 88
  for (auto& e : src.outputs) {
    if (pruned.count(e.node.get())) {
      e = replace_pruned_entry(e);
    }
  }

89 90 91 92
  nnvm::Graph pre_graph;
  pre_graph.outputs.reserve(entry_var.size());
  std::vector<std::string> output_names;
  output_names.reserve(entry_var.size());
93

94 95 96 97
  for (auto kv : entry_var) {
    pre_graph.outputs.emplace_back(kv.first);
    output_names.emplace_back(kv.second->attrs.name);
  }
98 99 100 101 102
  // new parameter list
  pre_graph.attrs["output_names"] =
      std::make_shared<dmlc::any>(std::move(output_names));
  src.attrs["precompute_graph"] =
      std::make_shared<dmlc::any>(std::move(pre_graph));
103 104 105
  return src;
}

106 107
NNVM_REGISTER_PASS(PrecomputePrune)
.set_body(PrecomputePrune);
108 109
}  // namespace compiler
}  // namespace nnvm