Commit 5dc70763 by tqchen Committed by Tianqi Chen

checkin dfs visit from min

parent 4baf150d
......@@ -10,6 +10,7 @@
#include <dmlc/any.h>
#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <dmlc/array_view.h>
namespace nngraph {
......@@ -17,6 +18,13 @@ namespace nngraph {
using any = dmlc::any;
/*!
* \brief array_veiw type
* \tparam ValueType The value content of array view.
*/
template<typename ValueType>
using array_view = dmlc::array_view<ValueType>;
/*!
* \brief get reference of type T stored in src.
* \param src The source container
* \return the reference to the type.
......
......@@ -8,7 +8,10 @@
#include <vector>
#include <string>
#include <utility>
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include "./node.h"
namespace nngraph {
......@@ -22,8 +25,75 @@ class Graph {
std::vector<NodeEntry> outputs;
/*! \brief attributes of a graph */
std::unordered_map<std::string, any> attrs;
/*!
* \brief perform a Post Order DFS visit to each node in the graph.
* This order is deterministic and is also topoligical sorted.
* \param fvisit a function of type std::function<void(const std::shared_ptr<Node>&)>
* \tparam FVisit The function type to perform the visit.
*/
template<typename FVisit>
inline void DFSVisit(FVisit fvisit) const;
};
// inline function implementations
template <typename GNode, typename HashType,
typename FVisit, typename HashFunc,
typename InDegree, typename GetInput>
void PostOrderDFSVisit(const std::vector<GNode>& heads,
FVisit fvisit,
HashFunc hash,
InDegree indegree,
GetInput getinput) {
std::vector<std::pair<GNode, uint32_t> > stack;
std::unordered_set<HashType> visited;
for (auto& head : heads) {
HashType head_hash = hash(head);
if (visited.count(head_hash) == 0) {
stack.push_back(std::make_pair(head, 0));
visited.insert(head_hash);
}
while (!stack.empty()) {
std::pair<GNode, uint32_t>& back = stack.back();
if (back.second == indegree(back.first)) {
fvisit(back.first);
stack.pop_back();
} else {
const GNode& input = getinput(back.first, back.second++);
HashType input_hash = hash(input);
if (visited.count(input_hash) == 0) {
stack.push_back(std::make_pair(input, 0));
visited.insert(input_hash);
}
}
}
}
}
template<typename FVisit>
inline void Graph::DFSVisit(FVisit fvisit) const {
typedef const std::shared_ptr<Node>* GNode;
std::vector<GNode> head_nodes(outputs.size());
std::transform(outputs.begin(), outputs.end(), head_nodes.begin(),
[](const NodeEntry& e)->GNode {
return &e.node;
});
PostOrderDFSVisit<GNode, Node*>(
head_nodes,
[fvisit](GNode n) { fvisit(*n); }, // FVisit
[](GNode n)->Node* { return n->get(); }, // HashFunc
[](GNode n)->uint32_t { // InDegree
return (*n)->inputs.size() + (*n)->control_deps.size();
},
[](GNode n, uint32_t index)->GNode { // GetInput
if (index < (*n)->inputs.size()) {
return &(*n)->inputs.at(index).node;
} else {
return &(*n)->control_deps.at(index - (*n)->inputs.size());
}
});
}
} // namespace nngraph
#endif // NNGRAPH_GRAPH_H_
/*!
* Copyright (c) 2016 by Contributors
* \file graph_attr_types.h
* \brief Data structures that can appear in graph attributes.
*/
#ifndef NNGRAPH_GRAPH_ATTR_TYPES_H_
#define NNGRAPH_GRAPH_ATTR_TYPES_H_
#include <vector>
#include <unordered_map>
#include "./graph.h"
namespace nngraph {
/*!
* \brief Index to the graph.
* Maps pointers to Node to consecutive integers.
*
* This is an Auxililary data structure that can be used
* to iterate over the graph in a more efficient manner.
* It also allows storing
*/
struct IndexedGraph {
public:
/*! \brief represents a data in the graph */
struct NodeEntry {
/*! \brief the source node id in the computation graph */
uint32_t node_id;
/*! \brief index of output from the source. */
uint32_t index;
/*!
* \brief compare equality
* \param other the other entry to compare
* \return whether two entries equals to each other
*/
inline bool operator==(const NodeEntry& other) const {
return node_id == other.node_id && index == other.index;
}
};
/*! \brief Node data structure in IndexedGraph */
struct Node {
/*! \brief pointer to the source node */
const nngraph::Node* source;
/*! \brief inputs to the node */
array_view<NodeEntry> inputs;
/*! \brief control flow dependencies to the node */
array_view<uint32_t> control_deps;
};
/*! \return number of nodes in the graph */
inline size_t num_nodes() const {
return nodes_.size();
}
/*! \return total number of NodeEntry in the graph */
inline size_t num_node_entries() const {
return entry_rptr_.back();
}
/*!
* \brief Get a unique entry id between 0 to num_node_entries()
* for a given IndexedGraph::NodeEntry
* \param e The entry to query for index.
* \return the unique index.
*/
inline uint32_t entry_id(const NodeEntry& e) const {
return entry_rptr_[e.node_id] + e.index;
}
/*!
* \brief Get a unique entry id between 0 to num_node_entries()
* for a given NodeEntry.
* \param e The entry to query for index.
* \return the unique index.
*/
inline uint32_t entry_id(const nngraph::NodeEntry& e) const {
return entry_rptr_[node_id(e.node.get())] + e.index;
}
/*!
* \brief Get the corresponding node id for a given Node in the IndexedGraph.
* \param node The Node to query for index.
* \return the node index.
*/
inline uint32_t node_id(const nngraph::Node* node) const {
return node2index_.at(node);
}
/*!
* \brief Get the corresponding Node structure for a given node_id.
* \param node_id The node id
* \return const reference to the corresponding IndexedGraph::Node
*/
inline const Node& operator[](uint32_t node_id) const {
return nodes_[node_id];
}
/*!
* \brief Get the corresponding Node structure
* \param node The pointer to the Node structure
* \return const reference to the corresponding IndexedGraph::Node
*/
inline const Node& operator[](const nngraph::Node* node) const {
return nodes_[node_id(node)];
}
/*!
* \brief Constructor an IndexedGraph from normal Graph
* \param other The source graph.
*/
explicit IndexedGraph(const Graph& other);
// disallow copy assign
IndexedGraph(const IndexedGraph& other) = delete;
private:
// node pointers in CSR structure.
std::vector<Node> nodes_;
// mapping from node to index.
std::unordered_map<const nngraph::Node*, uint32_t> node2index_;
// CSR pointer of node entries
std::vector<size_t> entry_rptr_;
// space to store input entries of each
std::vector<NodeEntry> input_entries_;
// control flow dependencies
std::vector<uint32_t> control_deps_;
};
} // namespace nngraph
#endif // NNGRAPH_GRAPH_ATTR_TYPES_H_
......@@ -60,7 +60,7 @@ class Tuple {
this->swap(src);
}
/*!
* \brief construct an Tuple to fill the value with v.
* \param ndim the number of dimension of the Tuple
* \param v The value to fill.
*/
......
......@@ -2,6 +2,7 @@
#include <nngraph/op.h>
#include <nngraph/graph.h>
#include <nngraph/tuple.h>
#include <nngraph/graph_attr_types.h>
#include <string>
void test_op() {
......@@ -32,6 +33,12 @@ void test_tuple() {
CHECK((s == TShape{1, 2, 3}));
}
void test_graph() {
nngraph::Graph g;
g.DFSVisit([](const std::shared_ptr<nngraph::Node>& n){
});
}
int main() {
test_tuple();
return 0;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment