Commit ea811968 by tqchen Committed by Tianqi Chen

[GRAPH] checkin the constructor of indexed graph

parent 5dc70763
......@@ -96,6 +96,10 @@ struct IndexedGraph {
inline const Node& operator[](const nngraph::Node* node) const {
return nodes_[node_id(node)];
}
/*! \return list of argument nodes */
inline const std::vector<uint32_t>& arg_nodes() const {
return arg_nodes_;
}
/*!
* \brief Constructor an IndexedGraph from normal Graph
* \param other The source graph.
......@@ -107,6 +111,8 @@ struct IndexedGraph {
private:
// node pointers in CSR structure.
std::vector<Node> nodes_;
// index to argument nodes
std::vector<uint32_t> arg_nodes_;
// mapping from node to index.
std::unordered_map<const nngraph::Node*, uint32_t> node2index_;
// CSR pointer of node entries
......
......@@ -69,6 +69,8 @@ class Node {
* \return whether node is placeholder input variable
*/
inline bool is_variable() const;
/*! \return number of outputs from this node */
inline uint32_t num_outputs() const;
/*!
* \brief create a new empty shared_ptr of Node.
* \return a created empty node.
......@@ -81,6 +83,15 @@ inline bool Node::is_variable() const {
return this->op == nullptr;
}
inline uint32_t Node::num_outputs() const {
if (is_variable()) return 1;
if (this->op->num_outputs >= 0) {
return static_cast<uint32_t>(this->op->num_outputs);
} else {
return this->op->get_num_outputs(*this);
}
}
} // namespace nngraph
#endif // NNGRAPH_NODE_H_
/*!
* Copyright (c) 2016 by Contributors
* \file graph_attr_types.cc
* \brief Graph node data structure.
*/
#include <nngraph/graph_attr_types.h>
#include <limits>
namespace nngraph {
// implement constructor from graph
IndexedGraph::IndexedGraph(const Graph &g) {
entry_rptr_.push_back(0);
std::vector<size_t> inputs_rptr{0}, control_rptr{0};
g.DFSVisit([this, &inputs_rptr, &control_rptr]
(const std::shared_ptr<nngraph::Node>& n) {
CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
uint32_t nid = static_cast<uint32_t>(nodes_.size());
// nodes_
IndexedGraph::Node new_node;
new_node.source = n.get();
nodes_.emplace_back(std::move(new_node));
// arg_nodes_
if (n->is_variable()) {
arg_nodes_.push_back(nid);
}
// node2index_
node2index_[n.get()] = nid;
// entry rptr
entry_rptr_.push_back(entry_rptr_.back() + n->num_outputs());
// input entries
for (const auto& e : n->inputs) {
auto it = node2index_.find(e.node.get());
CHECK(it != node2index_.end() && it->first == e.node.get());
input_entries_.emplace_back(NodeEntry{it->second, e.index});
}
inputs_rptr.push_back(input_entries_.size());
// control deps
for (const auto& nptr : n->control_deps) {
auto it = node2index_.find(nptr.get());
CHECK(it != node2index_.end() && it->first == nptr.get());
control_deps_.push_back(it->second);
}
control_rptr.push_back(control_deps_.size());
});
// setup array view
const NodeEntry* iptr = dmlc::BeginPtr(input_entries_);
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
nodes_[nid].inputs = array_view<NodeEntry>(
iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]);
}
const uint32_t* cptr = dmlc::BeginPtr(control_deps_);
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
nodes_[nid].control_deps = array_view<uint32_t>(
cptr + control_rptr[nid], cptr + control_rptr[nid + 1]);
}
}
} // namespace nngraph
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment