graph.cc 5.56 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24
/*!
 *  Copyright (c) 2016 by Contributors
 * \file graph.cc
 * \brief Implementation of Graph and IndexedGraph.
 */
25
#include <nnvm/graph.h>
Tianqi Chen committed
26
#include <nnvm/op_attr_types.h>
27 28
#include <limits>

tqchen committed
29
namespace nnvm {
30

31
const IndexedGraph& Graph::indexed_graph() const {
  // indexed_graph_ is a cache: the IndexedGraph is constructed from this
  // graph on the first call and the same instance is returned afterwards.
  // No synchronization is performed here.
  const bool not_built_yet = !indexed_graph_;
  if (not_built_yet) indexed_graph_.reset(new IndexedGraph(*this));
  return *indexed_graph_;
}

38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
// a subgraph should not refer to any nodes with higher level
// where "level" refers to the nested depth of the subgraph
// e.g. the main graph is level 0
// subgraphs of the main graph is level 1
// subgraphs of the subgraphs of the main graph is level 2
static void SubgraphSanityCheck(const std::vector<std::shared_ptr<Symbol>> &subgraphs) {
  std::vector<const std::vector<nnvm::NodeEntry>*> curr_level;
  std::vector<const std::vector<nnvm::NodeEntry>*> next_level;
  std::unordered_map<nnvm::Node*, uint32_t> node2level;
  for (auto &subgraph : subgraphs)
    next_level.push_back(&subgraph->outputs);
  for (uint32_t level = 0; !next_level.empty(); ++level) {
    curr_level.swap(next_level);
    next_level.clear();
    for (const std::vector<NodeEntry> *graph_ptr : curr_level) {
      const std::vector<NodeEntry> &graph = *graph_ptr;
      DFSVisit(graph, [&next_level, &node2level, level](const NodePtr& n) {
        nnvm::Node *node = n.get();
        // if the node is visited, but on a different level, then check failed
        // if check failed here or before, we stop doing anything, but raise an error
        CHECK(!node2level.count(node) || node2level[node] == level)
          << "A subgraph should not depend on the outputs of nodes on higher levels";
        // otherwise, this node belongs to the current level
        node2level[node] = level;
        // subgraphs of current node belongs to next level
        for (const auto& subgraph : n->attrs.subgraphs) {
          next_level.push_back(&subgraph->outputs);
        }
      });
    }
  }
}

71 72 73 74
// Build an IndexedGraph from `g`.
//
// Nodes reachable from g.outputs are numbered in the order DFSVisit visits
// them. When a node is visited, all of its inputs must already have been
// numbered (enforced by a CHECK), so inputs get smaller ids than their
// consumers. Non-variable nodes whose op has TIsGhost set are skipped: they
// are not numbered, and ghost control dependencies are dropped.
// After numbering, subgraphs attached to nodes are sanity-checked, the graph
// outputs are translated to indexed entries, and each node's `inputs` and
// `control_deps` views are pointed into the flat input_entries_ /
// control_deps_ arrays.
IndexedGraph::IndexedGraph(const Graph &g) {
  entry_rptr_.push_back(0);
  // Row pointers into input_entries_ / control_deps_:
  // node nid owns the half-open range [rptr[nid], rptr[nid + 1]).
  std::vector<size_t> inputs_rptr{0}, control_rptr{0};
  std::vector<std::shared_ptr<Symbol>> subgraphs;

  // Fetched once here instead of once per visited node.
  const auto& is_ghost = Op::GetAttr<TIsGhost>("TIsGhost");

  DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs, &is_ghost]
             (const NodePtr& n) {
      // Validate n before its first dereference.
      CHECK(n);
      if (!n->is_variable() && is_ghost.get(n->op(), false)) return;
      CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
      const uint32_t nid = static_cast<uint32_t>(nodes_.size());
      for (const auto &subgraph : n->attrs.subgraphs)
        subgraphs.push_back(subgraph);
      // nodes_
      IndexedGraph::Node new_node;
      new_node.source = n.get();
      new_node.weak_ref = n;
      nodes_.emplace_back(std::move(new_node));
      // input_nodes_: variable nodes are recorded as graph inputs.
      if (n->is_variable()) {
        input_nodes_.push_back(nid);
      }
      // node2index_
      node2index_[n.get()] = nid;
      // entry rptr: this node owns num_outputs() consecutive entry ids.
      entry_rptr_.push_back(entry_rptr_.back() + n->num_outputs());
      // input entries, expressed as (node id, output index, version)
      for (const auto& e : n->inputs) {
        auto it = node2index_.find(e.node.get());
        CHECK(it != node2index_.end()) << "input node not found in graph";
        input_entries_.emplace_back(NodeEntry{it->second, e.index, e.version});
      }
      inputs_rptr.push_back(input_entries_.size());
      // control deps (ghost dependencies are dropped)
      for (const auto& nptr : n->control_deps) {
        if (!nptr->is_variable() && is_ghost.get(nptr->op(), false)) continue;
        auto it = node2index_.find(nptr.get());
        CHECK(it != node2index_.end()) << "control dep not found in graph";
        control_deps_.push_back(it->second);
      }
      control_rptr.push_back(control_deps_.size());
  });

  if (!subgraphs.empty())
    SubgraphSanityCheck(subgraphs);

  for (const auto& e : g.outputs) {
    outputs_.emplace_back(NodeEntry{
        node2index_.at(e.node.get()), e.index, e.version});
  }

  static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
  // Set up the array views. They point directly into the storage of
  // input_entries_ and control_deps_, so neither vector may change after
  // this step.
  const NodeEntry* iptr = dmlc::BeginPtr(input_entries_);
  const uint32_t* cptr = dmlc::BeginPtr(control_deps_);
  for (size_t nid = 0; nid < nodes_.size(); ++nid) {
    IndexedGraph::Node& inode = nodes_[nid];
    inode.inputs = array_view<NodeEntry>(
        iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]);
    inode.control_deps = array_view<uint32_t>(
        cptr + control_rptr[nid], cptr + control_rptr[nid + 1]);
    // Record nodes whose outputs are mutated in place by this node's op.
    const Op* op = inode.source->op();
    if (op != nullptr && fmutate_inputs.count(op)) {
      for (uint32_t i : fmutate_inputs[op](inode.source->attrs)) {
        // An out-of-range index would read past this node's input view.
        CHECK_LT(i, inode.inputs.size())
            << "FMutateInputs of operator " << op->name
            << " returned out-of-range input index " << i;
        mutable_input_nodes_.insert(inode.inputs[i].node_id);
      }
    }
  }
}

tqchen committed
144
}  // namespace nnvm