graph.cc 2.79 KB
Newer Older
1 2 3 4 5
/*!
 *  Copyright (c) 2016 by Contributors
 * \file graph.cc
 * \brief Implementation of Graph and IndexedGraph.
 */
6
#include <nnvm/graph.h>
Tianqi Chen committed
7
#include <nnvm/op_attr_types.h>
8 9
#include <limits>

tqchen committed
10
namespace nnvm {
11

12 13 14 15 16 17 18
/*!
 * \brief Return the integer-indexed view of this graph.
 *
 * The IndexedGraph is constructed from *this on the first call and cached in
 * indexed_graph_; later calls return the cached object without rebuilding it.
 * NOTE(review): nothing here invalidates the cache if the graph is modified
 * after the first call — confirm callers never do that.
 */
const IndexedGraph& Graph::indexed_graph() {
  const bool needs_build = (indexed_graph_ == nullptr);
  if (needs_build) indexed_graph_.reset(new IndexedGraph(*this));
  return *indexed_graph_;
}

19 20 21 22 23
// implement constructor from graph
/*!
 * \brief Build the integer-indexed view of graph g.
 *
 * Every node reached by DFSVisit from g.outputs is appended to nodes_; a
 * node's id is its position in nodes_. Each node's inputs and control
 * dependencies are first accumulated into the flat arrays input_entries_ and
 * control_deps_, then exposed to the node as array_views over those arrays.
 *
 * A node's inputs and control dependencies must already have been indexed
 * when the node itself is visited; CHECKs enforce this.
 *
 * For nodes whose op registers the "FMutateInputs" attribute, the ids of the
 * nodes producing the listed inputs are added to mutable_input_nodes_.
 *
 * \param g The graph to index.
 */
IndexedGraph::IndexedGraph(const Graph &g) {
  entry_rptr_.push_back(0);
  // Offset arrays: node nid's inputs occupy
  // input_entries_[inputs_rptr[nid], inputs_rptr[nid + 1]), and likewise its
  // control dependencies occupy a range of control_deps_ given by control_rptr.
  std::vector<size_t> inputs_rptr{0}, control_rptr{0};

  DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr]
             (const NodePtr& n) {
      // Node ids are uint32_t, so the node count must stay representable.
      CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
      const uint32_t nid = static_cast<uint32_t>(nodes_.size());
      // nodes_
      IndexedGraph::Node new_node;
      new_node.source = n.get();
      nodes_.emplace_back(std::move(new_node));
      // input_nodes_ collects the variable nodes.
      if (n->is_variable()) {
        input_nodes_.push_back(nid);
      }
      // node2index_
      node2index_[n.get()] = nid;
      // entry_rptr_: running total of output entries, one offset per node.
      entry_rptr_.push_back(entry_rptr_.back() + n->num_outputs());
      // Input entries, translated from node pointers to node ids.
      for (const auto& e : n->inputs) {
        auto it = node2index_.find(e.node.get());
        CHECK(it != node2index_.end())
            << "input node must be indexed before the node that consumes it";
        input_entries_.emplace_back(NodeEntry{it->second, e.index, e.version});
      }
      inputs_rptr.push_back(input_entries_.size());
      // Control dependencies, translated from node pointers to node ids.
      for (const auto& nptr : n->control_deps) {
        auto it = node2index_.find(nptr.get());
        CHECK(it != node2index_.end())
            << "control dependency must be indexed before the node that depends on it";
        control_deps_.push_back(it->second);
      }
      control_rptr.push_back(control_deps_.size());
  });

  for (const auto& e : g.outputs) {
    outputs_.emplace_back(NodeEntry{
        node2index_.at(e.node.get()), e.index, e.version});
  }

  static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");

  // Set up the per-node array views. The views point directly into the
  // storage of input_entries_ and control_deps_, so those vectors must not be
  // modified (which could reallocate them) after this point.
  const NodeEntry* iptr = dmlc::BeginPtr(input_entries_);
  const uint32_t* cptr = dmlc::BeginPtr(control_deps_);
  for (size_t nid = 0; nid < nodes_.size(); ++nid) {
    nodes_[nid].inputs = array_view<NodeEntry>(
        iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]);
    nodes_[nid].control_deps = array_view<uint32_t>(
        cptr + control_rptr[nid], cptr + control_rptr[nid + 1]);
    // Record the producers of the inputs listed by the op's FMutateInputs.
    const auto* source = nodes_[nid].source;
    if (source->op != nullptr && fmutate_inputs.count(source->op)) {
      const size_t num_inputs = inputs_rptr[nid + 1] - inputs_rptr[nid];
      for (uint32_t i : fmutate_inputs[source->op](source->attrs)) {
        // An index past the node's inputs would read outside the view.
        CHECK_LT(i, num_inputs)
            << "FMutateInputs returned out-of-range input index " << i;
        mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id);
      }
    }
  }
}

tqchen committed
83
}  // namespace nnvm