Commit cd49ed0e by tqchen Committed by Tianqi Chen

check in pass

parent ea811968
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <algorithm> #include <algorithm>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "./base.h"
#include "./node.h" #include "./node.h"
namespace nngraph { namespace nngraph {
...@@ -23,9 +24,12 @@ class Graph { ...@@ -23,9 +24,12 @@ class Graph {
public: public:
/*! \brief outputs of the computation graph. */ /*! \brief outputs of the computation graph. */
std::vector<NodeEntry> outputs; std::vector<NodeEntry> outputs;
/*! \brief attributes of a graph */ /*!
std::unordered_map<std::string, any> attrs; * \brief attributes of a graph
* Each attribute is immutable,
* and can be shared across multiple Instance of graph
*/
std::unordered_map<std::string, std::shared_ptr<const any> > attrs;
/*! /*!
* \brief perform a Post Order DFS visit to each node in the graph. * \brief perform a Post Order DFS visit to each node in the graph.
* This order is deterministic and is also topoligical sorted. * This order is deterministic and is also topoligical sorted.
......
...@@ -13,12 +13,11 @@ ...@@ -13,12 +13,11 @@
namespace nngraph { namespace nngraph {
/*! /*!
* \brief Index to the graph. * \brief Auxililary data structure to index a graph.
* Maps pointers to Node to consecutive integers. * It maps Nodes in the graph to consecutive integers node_id.
* * It also maps IndexedGraph::NodeEntry to consecutive integer entry_id.
* This is an Auxililary data structure that can be used * This allows storing properties of Node and NodeEntry into
* to iterate over the graph in a more efficient manner. * compact vector and quickly access them without resorting to hashmap.
* It also allows storing
*/ */
struct IndexedGraph { struct IndexedGraph {
public: public:
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include "./base.h"
#include "./op.h" #include "./op.h"
namespace nngraph { namespace nngraph {
......
/*!
* Copyright (c) 2016 by Contributors
* \file pass.h
* \brief Pass that can be applied to a graph.
*/
#ifndef NNGRAPH_PASS_H_
#define NNGRAPH_PASS_H_
#include <vector>
#include <functional>
#include "./base.h"
#include "./graph.h"
namespace nngraph {
/*!
* \brief A PassFunction is a basic "Operator on Graph"
* It takes a source graph
*
* A pass function can either change the graph structure of g,
* generating a new Graph, or add new attributes to the graph.
*
* \param src The graph to be transformed.
* \return The generated graph.
*/
typedef std::function<Graph (const Graph& src)> PassFunction;
/*!
* \brief Apply a series of pass transformations on g.
* \param src The graph to be transformed.
* \param pass The name of pass to be applied.
* \return The transformed graph
*/
Graph ApplyPass(const Graph& src,
const std::vector<std::string>& pass);
/*!
* \brief Registry entry for DataIterator factory functions.
*/
struct PassFunctionReg
: public dmlc::FunctionRegEntryBase<PassFunctionReg,
PassFunction> {
/*!
* \brief Whether the pass will change graph structure
* If this is false, the pass will only change attributes.
*/
bool change_graph{false};
/*! \brief dependencies on operator attributes */
std::vector<std::string> op_attr_dependency;
/*! \brief dependencies on attributes in the graph */
std::vector<std::string> graph_attr_dependency;
/*! \brief generated targets of graph attributes */
std::vector<std::string> graph_attr_targets;
/*!
* \brief set whether this pass will change graph structure.
* \param v the value to set
* \return reference to self.
*/
PassFunctionReg& set_change_graph(bool v) { // NOLINT(*)
change_graph = v;
return *this;
}
/*!
* \brief Declare this pass require operator attribute attr_name to be available.
* \param attr_name Name of the attribute.
* \return reference to self.
*/
PassFunctionReg& provide_graph_attr(const std::string& attr_name) { // NOLINT(*)
graph_attr_targets.push_back(attr_name);
return *this;
}
/*!
* \brief declare this pass require operator attribute attr_name to be available.
* \param attr_name Name of the attribute.
* \return reference to self.
*/
PassFunctionReg& depend_op_attr(const std::string& attr_name) { // NOLINT(*)
op_attr_dependency.push_back(attr_name);
return *this;
}
/*!
* \brief declare this pass require graph attribute attr_name to be available.
* \param attr_name Name of the attribute.
* \return reference to self.
*/
PassFunctionReg& depend_graph_attr(const std::string& attr_name) { // NOLINT(*)
graph_attr_dependency.push_back(attr_name);
return *this;
}
};
/*!
* \def NNGRAPH_REGISTER_PASS
* \brief Macro to register pass fuctions.
*
* \code
* // example of registering a shape inference pass
* NNGRAPH_REGISTER_PASS(InferShape)
* .describe("Shape Inference function, generate graph attributes")
* .provide_graph_attr("data_shape")
* .depend_graph_attr("indexed_graph")
* .depend_op_attr("infer_shape")
* .set_body([](const Graph& g) {
* // shape inference logic
* });
* \endcode
*/
#define NNGRAPH_REGISTER_PASS(name) \
DMLC_REGISTRY_REGISTER(::nngraph::PassFunctionReg, PassFunctionReg, name)
} // namespace nngraph
#endif // NNGRAPH_PASS_H_
...@@ -46,6 +46,7 @@ IndexedGraph::IndexedGraph(const Graph &g) { ...@@ -46,6 +46,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
}); });
// setup array view // setup array view
// input_entries_ and control_rptr must not change after this step.
const NodeEntry* iptr = dmlc::BeginPtr(input_entries_); const NodeEntry* iptr = dmlc::BeginPtr(input_entries_);
for (size_t nid = 0; nid < nodes_.size(); ++nid) { for (size_t nid = 0; nid < nodes_.size(); ++nid) {
nodes_[nid].inputs = array_view<NodeEntry>( nodes_[nid].inputs = array_view<NodeEntry>(
......
/*!
* Copyright (c) 2016 by Contributors
* \file pass.cc
* \brief Support for pass registry.
*/
#include <nngraph/pass.h>
#include <algorithm>
namespace dmlc {
// enable registry
DMLC_REGISTRY_ENABLE(nngraph::PassFunctionReg);
} // namespace dmlc
namespace nngraph {
const PassFunctionReg* FindPassDep(const std::string&attr_name) {
for (auto* r : dmlc::Registry<PassFunctionReg>::List()) {
for (auto& s : r->graph_attr_targets) {
if (s == attr_name) return r;
}
}
return nullptr;
}
Graph ApplyPass(const Graph& src,
const std::vector<std::string>& pass) {
std::vector<const PassFunctionReg*> fpass;
for (auto& name : pass) {
auto* reg = dmlc::Registry<PassFunctionReg>::Find(name);
CHECK(reg != nullptr)
<< "Cannot find pass " << name << " in the registry";
fpass.push_back(reg);
}
Graph g;
const Graph* s = &src;
for (auto r : fpass) {
for (auto& dep : r->graph_attr_dependency) {
if (s->attrs.count(dep) == 0) {
auto* pass_dep = FindPassDep(dep);
std::string msg;
if (pass_dep != nullptr) {
msg = " The attribute is provided by pass " + pass_dep->name;
}
LOG(FATAL) << "Graph attr dependency " << dep
<< " is required by pass " << r->name
<< " but is not available "
<< msg;
}
}
g = r->body(*s);
s = &g;
}
return g;
}
} // namespace nngraph
...@@ -36,7 +36,7 @@ void test_tuple() { ...@@ -36,7 +36,7 @@ void test_tuple() {
void test_graph() { void test_graph() {
nngraph::Graph g; nngraph::Graph g;
g.DFSVisit([](const std::shared_ptr<nngraph::Node>& n){ g.DFSVisit([](const std::shared_ptr<const nngraph::Node>& n){
}); });
} }
int main() { int main() {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment