Commit bd20bfd8 by Tianqi Chen

[Pass] Check in infershape, move indexedgraph to graph.h (#15)

parent 94ae677a
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
namespace nnvm { namespace nnvm {
class IndexedGraph;
/*! /*!
* \brief Symbolic computation graph. * \brief Symbolic computation graph.
* This is the intermediate representation for optimization pass. * This is the intermediate representation for optimization pass.
...@@ -32,6 +34,145 @@ class Graph { ...@@ -32,6 +34,145 @@ class Graph {
* and can be shared across multiple Instance of graph * and can be shared across multiple Instance of graph
*/ */
std::unordered_map<std::string, std::shared_ptr<const any> > attrs; std::unordered_map<std::string, std::shared_ptr<const any> > attrs;
/*!
* \brief Get the attribute from attrs.
* \param attr_name the name of the attribute
* \return the reference to corresponding attribute
* \tparam T the type of the attribute.
*/
template<typename T>
inline const T& GetAttr(const std::string& attr_name);
/*!
* \brief get a indexed graph of current graph, if not exist, create it on demand
* \return The indexed graph.
* \sa IndexedGraph
*/
const IndexedGraph& indexed_graph();
private:
// internal structure of indexed graph
std::shared_ptr<const IndexedGraph> indexed_graph_;
};
/*!
* \brief Auxililary data structure to index a graph.
* It maps Nodes in the graph to consecutive integers node_id.
* It also maps IndexedGraph::NodeEntry to consecutive integer entry_id.
* This allows storing properties of Node and NodeEntry into
* compact vector and quickly access them without resorting to hashmap.
*
* The node_id and entry_rptr are the same as the JSON graph produced by SaveJSON Pass.
*/
class IndexedGraph {
public:
/*! \brief represents a data in the graph */
struct NodeEntry {
/*! \brief the source node id in the computation graph */
uint32_t node_id;
/*! \brief index of output from the source. */
uint32_t index;
/*!
* \brief compare equality
* \param other the other entry to compare
* \return whether two entries equals to each other
*/
inline bool operator==(const NodeEntry& other) const {
return node_id == other.node_id && index == other.index;
}
};
/*! \brief Node data structure in IndexedGraph */
struct Node {
/*! \brief pointer to the source node */
const nnvm::Node* source;
/*! \brief inputs to the node */
array_view<NodeEntry> inputs;
/*! \brief control flow dependencies to the node */
array_view<uint32_t> control_deps;
};
/*! \return number of nodes in the graph */
inline size_t num_nodes() const {
return nodes_.size();
}
/*! \return total number of NodeEntry in the graph */
inline size_t num_node_entries() const {
return entry_rptr_.back();
}
/*!
* \brief Get a unique entry id between 0 to num_node_entries()
* for a given IndexedGraph::NodeEntry
* \param node_id The node index
* \param index the output index
* \return the unique index.
*/
inline uint32_t entry_id(uint32_t node_id, uint32_t index) const {
return entry_rptr_[node_id] + index;
}
/*!
* \brief Get a unique entry id between 0 to num_node_entries()
* for a given IndexedGraph::NodeEntry
* \param e The entry to query for index.
* \return the unique index.
*/
inline uint32_t entry_id(const NodeEntry& e) const {
return entry_rptr_[e.node_id] + e.index;
}
/*!
* \brief Get a unique entry id between 0 to num_node_entries()
* for a given NodeEntry.
* \param e The entry to query for index.
* \return the unique index.
*/
inline uint32_t entry_id(const nnvm::NodeEntry& e) const {
return entry_rptr_[node_id(e.node.get())] + e.index;
}
/*!
* \brief Get the corresponding node id for a given Node in the IndexedGraph.
* \param node The Node to query for index.
* \return the node index.
*/
inline uint32_t node_id(const nnvm::Node* node) const {
return node2index_.at(node);
}
/*!
* \brief Get the corresponding Node structure for a given node_id.
* \param node_id The node id
* \return const reference to the corresponding IndexedGraph::Node
*/
inline const Node& operator[](uint32_t node_id) const {
return nodes_[node_id];
}
/*!
* \brief Get the corresponding Node structure
* \param node The pointer to the Node structure
* \return const reference to the corresponding IndexedGraph::Node
*/
inline const Node& operator[](const nnvm::Node* node) const {
return nodes_[node_id(node)];
}
/*! \return list of argument nodes */
inline const std::vector<uint32_t>& arg_nodes() const {
return arg_nodes_;
}
private:
friend class Graph;
/*!
* \brief Constructor an IndexedGraph from normal Graph
* \param other The source graph.
*/
explicit IndexedGraph(const Graph& other);
// node pointers in CSR structure.
std::vector<Node> nodes_;
// index to argument nodes
std::vector<uint32_t> arg_nodes_;
// mapping from node to index.
std::unordered_map<const nnvm::Node*, uint32_t> node2index_;
// CSR pointer of node entries
std::vector<size_t> entry_rptr_;
// space to store input entries of each
std::vector<NodeEntry> input_entries_;
// control flow dependencies
std::vector<uint32_t> control_deps_;
}; };
/*! /*!
...@@ -45,6 +186,14 @@ template<typename FVisit> ...@@ -45,6 +186,14 @@ template<typename FVisit>
inline void DFSVisit(const std::vector<NodeEntry>& heads, FVisit fvisit); inline void DFSVisit(const std::vector<NodeEntry>& heads, FVisit fvisit);
// inline function implementations // inline function implementations
template<typename T>
inline const T& Graph::GetAttr(const std::string& attr_name) {
auto it = attrs.find(attr_name);
CHECK(it != attrs.end())
<< "Cannot find attribute " << attr_name << " in the graph";
return nnvm::get<T>(*it->second);
}
template <typename GNode, typename HashType, template <typename GNode, typename HashType,
typename FVisit, typename HashFunc, typename FVisit, typename HashFunc,
typename InDegree, typename GetInput> typename InDegree, typename GetInput>
......
...@@ -7,120 +7,37 @@ ...@@ -7,120 +7,37 @@
#define NNVM_GRAPH_ATTR_TYPES_H_ #define NNVM_GRAPH_ATTR_TYPES_H_
#include <vector> #include <vector>
#include <unordered_map> #include <string>
#include "./graph.h" #include "./tuple.h"
namespace nnvm { namespace nnvm {
/*! /*!
* \brief Auxililary data structure to index a graph. * \brief The result holder of JSON serializer
* It maps Nodes in the graph to consecutive integers node_id. *
* It also maps IndexedGraph::NodeEntry to consecutive integer entry_id. * \note Stored under ret.attrs["json"], provided by Pass "SaveJSON"
* This allows storing properties of Node and NodeEntry into
* compact vector and quickly access them without resorting to hashmap. * \code
* Graph ret = ApplyPass(src_graph, {"SaveJSON"});
* const JSONString& json = ret.GetAttr<JSONString>("shape");
* \endcode
*/ */
struct IndexedGraph { using JSONString = std::string;
public:
/*! \brief represents a data in the graph */
struct NodeEntry {
/*! \brief the source node id in the computation graph */
uint32_t node_id;
/*! \brief index of output from the source. */
uint32_t index;
/*!
* \brief compare equality
* \param other the other entry to compare
* \return whether two entries equals to each other
*/
inline bool operator==(const NodeEntry& other) const {
return node_id == other.node_id && index == other.index;
}
};
/*! \brief Node data structure in IndexedGraph */
struct Node {
/*! \brief pointer to the source node */
const nnvm::Node* source;
/*! \brief inputs to the node */
array_view<NodeEntry> inputs;
/*! \brief control flow dependencies to the node */
array_view<uint32_t> control_deps;
};
/*! \return number of nodes in the graph */
inline size_t num_nodes() const {
return nodes_.size();
}
/*! \return total number of NodeEntry in the graph */
inline size_t num_node_entries() const {
return entry_rptr_.back();
}
/*!
* \brief Get a unique entry id between 0 to num_node_entries()
* for a given IndexedGraph::NodeEntry
* \param e The entry to query for index.
* \return the unique index.
*/
inline uint32_t entry_id(const NodeEntry& e) const {
return entry_rptr_[e.node_id] + e.index;
}
/*!
* \brief Get a unique entry id between 0 to num_node_entries()
* for a given NodeEntry.
* \param e The entry to query for index.
* \return the unique index.
*/
inline uint32_t entry_id(const nnvm::NodeEntry& e) const {
return entry_rptr_[node_id(e.node.get())] + e.index;
}
/*!
* \brief Get the corresponding node id for a given Node in the IndexedGraph.
* \param node The Node to query for index.
* \return the node index.
*/
inline uint32_t node_id(const nnvm::Node* node) const {
return node2index_.at(node);
}
/*!
* \brief Get the corresponding Node structure for a given node_id.
* \param node_id The node id
* \return const reference to the corresponding IndexedGraph::Node
*/
inline const Node& operator[](uint32_t node_id) const {
return nodes_[node_id];
}
/*!
* \brief Get the corresponding Node structure
* \param node The pointer to the Node structure
* \return const reference to the corresponding IndexedGraph::Node
*/
inline const Node& operator[](const nnvm::Node* node) const {
return nodes_[node_id(node)];
}
/*! \return list of argument nodes */
inline const std::vector<uint32_t>& arg_nodes() const {
return arg_nodes_;
}
/*!
* \brief Constructor an IndexedGraph from normal Graph
* \param other The source graph.
*/
explicit IndexedGraph(const Graph& other);
// disallow copy assign
IndexedGraph(const IndexedGraph& other) = delete;
private: /*!
// node pointers in CSR structure. * \brief The result holder of shape of each NodeEntry in the graph.
std::vector<Node> nodes_; * \note Stored under graph.attrs["shape"], provided by Pass "InferShape"
// index to argument nodes *
std::vector<uint32_t> arg_nodes_; * \code
// mapping from node to index. * Graph g = ApplyPass(src_graph, {"InferShape"});
std::unordered_map<const nnvm::Node*, uint32_t> node2index_; * const ShapeVector& shapes = g.GetAttr<ShapeVector>("shape");
// CSR pointer of node entries * // get shape by entry id
std::vector<size_t> entry_rptr_; * TShape entry_shape = shapes[g.indexed_graph().entry_id(my_entry)];
// space to store input entries of each * \endcode
std::vector<NodeEntry> input_entries_; *
// control flow dependencies * \sa FInferShape
std::vector<uint32_t> control_deps_; */
}; using ShapeVector = std::vector<TShape>;
} // namespace nnvm } // namespace nnvm
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include <vector> #include <vector>
#include <string> #include <string>
#include <functional> #include <functional>
#include "./base.h"
#include "./tuple.h"
namespace nnvm { namespace nnvm {
...@@ -39,6 +41,7 @@ using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs ...@@ -39,6 +41,7 @@ using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs
/*! /*!
* \brief Check whether operator will mutate k-th input. * \brief Check whether operator will mutate k-th input.
* \param attrs The attributes of the node.
* \param index The input index * \param index The input index
* \return Whether this operator will mutate index-th input. * \return Whether this operator will mutate index-th input.
* *
...@@ -47,6 +50,26 @@ using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs ...@@ -47,6 +50,26 @@ using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs
*/ */
using FMutateInput = std::function<bool (const NodeAttrs& attrs, uint32_t index)>; using FMutateInput = std::function<bool (const NodeAttrs& attrs, uint32_t index)>;
/*!
* \brief Shape inference function.
* Update the shapes given the input shape information.
* TShape.ndim() == 0 means the shape is still unknown.
*
* \param attrs The attributes of the node.
* \param in_shapes Array of shapes from the inputs.
* \param out_shapes Array of shapes from the outputs.
*
* \return Whether all the shapes are known.
*
* \note Register under "FInferShape",
* by default do not update any shapes.
*
* FInferShape is needed by shape inference
*/
using FInferShape = std::function<bool (const NodeAttrs& attrs,
array_view<TShape*> in_shapes,
array_view<TShape*> out_shapes)>;
} // namespace nnvm } // namespace nnvm
#endif // NNVM_OP_ATTR_TYPES_H_ #endif // NNVM_OP_ATTR_TYPES_H_
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <type_traits> #include <type_traits>
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include "./base.h"
namespace nnvm { namespace nnvm {
...@@ -179,7 +180,23 @@ class Tuple { ...@@ -179,7 +180,23 @@ class Tuple {
inline const ValueType& operator[](index_t i) const { inline const ValueType& operator[](index_t i) const {
return begin()[i]; return begin()[i];
} }
/*!
* \brief Save Tuple to JSON.
* \param writer JSONWriter
*/
inline void Save(dmlc::JSONWriter* writer) const {
std::vector<ValueType> tmp(begin(), end());
writer->Write(tmp);
}
/*!
* \brief Load Tuple from JSON.
* \param reader JSONReader
*/
inline void Load(dmlc::JSONReader* reader) {
std::vector<ValueType> tmp;
reader->Read(&tmp);
this->assign(tmp.begin(), tmp.end());
}
/*! /*!
* \brief allow output string of tuple to ostream * \brief allow output string of tuple to ostream
* \param os the output stream * \param os the output stream
...@@ -287,6 +304,8 @@ class TShape : public Tuple<index_t> { ...@@ -287,6 +304,8 @@ class TShape : public Tuple<index_t> {
public: public:
// inheritate other constructors from Tuple // inheritate other constructors from Tuple
using Tuple<index_t>::Tuple; using Tuple<index_t>::Tuple;
/*! \brief default constructor */
TShape() = default;
/*! /*!
* \brief copy constructor of TShape * \brief copy constructor of TShape
* \param s source shape. * \param s source shape.
......
...@@ -3,11 +3,18 @@ ...@@ -3,11 +3,18 @@
* \file graph_attr_types.cc * \file graph_attr_types.cc
* \brief Graph node data structure. * \brief Graph node data structure.
*/ */
#include <nnvm/graph_attr_types.h> #include <nnvm/graph.h>
#include <limits> #include <limits>
namespace nnvm { namespace nnvm {
const IndexedGraph& Graph::indexed_graph() {
if (indexed_graph_ == nullptr) {
indexed_graph_.reset(new IndexedGraph(*this));
}
return *indexed_graph_;
}
// implement constructor from graph // implement constructor from graph
IndexedGraph::IndexedGraph(const Graph &g) { IndexedGraph::IndexedGraph(const Graph &g) {
entry_rptr_.push_back(0); entry_rptr_.push_back(0);
......
// Copyright (c) 2016 by Contributors // Copyright (c) 2016 by Contributors
// This is an example on how we can register operator information to NNVM // This is an example on how we can register operator information to NNVM
#include <nnvm/base.h>
#include <nnvm/op.h> #include <nnvm/op.h>
#include <nnvm/op_attr_types.h> #include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <utility> #include <utility>
namespace myproject {
using nnvm::FListInputNames; using nnvm::FListInputNames;
using nnvm::FMutateInput; using nnvm::FMutateInput;
using nnvm::FInferShape;
using nnvm::NodeAttrs; using nnvm::NodeAttrs;
using nnvm::TShape;
using nnvm::array_view;
// simply return the shape as same
inline bool SameShape(const NodeAttrs& attrs,
array_view<TShape*> ishape,
array_view<TShape*> oshape) {
if (ishape.size() == 0 || ishape[0]->ndim() == 0) return false;
for (TShape* pshape : oshape) {
*pshape = *ishape[0];
}
for (TShape* pshape : ishape) {
*pshape = *ishape[0];
}
return true;
}
NNVM_REGISTER_OP(add) NNVM_REGISTER_OP(add)
.describe("add two data together") .describe("add two data together")
.set_num_inputs(2); .set_num_inputs(2)
.attr<FInferShape>("FInferShape", SameShape);
NNVM_REGISTER_OP(__add_symbol__) NNVM_REGISTER_OP(__add_symbol__)
.describe("Alias of add") .describe("Alias of add")
...@@ -20,7 +42,8 @@ NNVM_REGISTER_OP(__add_symbol__) ...@@ -20,7 +42,8 @@ NNVM_REGISTER_OP(__add_symbol__)
NNVM_REGISTER_OP(exp) NNVM_REGISTER_OP(exp)
.describe("take exponmential") .describe("take exponmential")
.set_num_inputs(1) .set_num_inputs(1)
.attr("inplace_pair", std::make_pair(0, 0)); .attr("inplace_pair", std::make_pair(0, 0))
.attr<FInferShape>("FInferShape", SameShape);
NNVM_REGISTER_OP(conv2d) NNVM_REGISTER_OP(conv2d)
...@@ -39,3 +62,5 @@ NNVM_REGISTER_OP(assign) ...@@ -39,3 +62,5 @@ NNVM_REGISTER_OP(assign)
.attr<FMutateInput>("FMutateInput", [](const NodeAttrs& attrs, uint32_t index) { .attr<FMutateInput>("FMutateInput", [](const NodeAttrs& attrs, uint32_t index) {
return index == 0; return index == 0;
}); });
} // namespace myproject
/*!
* Copyright (c) 2016 by Contributors
* \file infer_shape.cc
* \brief Inference the shapes given
*/
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
namespace nnvm {
namespace pass {
Graph InferShape(const Graph& src) {
Graph ret = src;
const IndexedGraph& idx = ret.indexed_graph();
static auto& finfer_shape = Op::GetAttr<FInferShape>("FInferShape");
// reshape shape vector
ShapeVector rshape(idx.num_node_entries());
// temp space for shape inference.
std::vector<TShape*> ishape, oshape;
// number of completed nodes
size_t num_known = 0;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
ishape.resize(inode.inputs.size());
for (uint32_t i = 0; i < ishape.size(); ++i) {
ishape[i] = &rshape[idx.entry_id(inode.inputs[i])];
}
oshape.resize(inode.source->num_outputs());
for (uint32_t i = 0; i < oshape.size(); ++i) {
oshape[i] = &rshape[idx.entry_id(nid, i)];
}
if (finfer_shape.count(inode.source->op)) {
num_known +=
finfer_shape[inode.source->op](inode.source->attrs, ishape, oshape);
}
}
// set the shapes
ret.attrs["shape"] = std::make_shared<any>(std::move(rshape));
// number of nodes who knows the shape.
ret.attrs["shape_num_known_nodes"] = std::make_shared<any>(num_known);
return ret;
}
} // namespace pass
} // namespace nnvm
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <nnvm/op_attr_types.h> #include <nnvm/op_attr_types.h>
namespace nnvm { namespace nnvm {
namespace pass {
template<typename T> template<typename T>
inline T get_with_default(const std::unordered_map<Node*, T> &map, inline T get_with_default(const std::unordered_map<Node*, T> &map,
...@@ -139,4 +140,5 @@ NNVM_REGISTER_PASS(OrderMutation) ...@@ -139,4 +140,5 @@ NNVM_REGISTER_PASS(OrderMutation)
.set_body(OrderMutation) .set_body(OrderMutation)
.set_change_graph(true); .set_change_graph(true);
} // namespace pass
} // namespace nnvm } // namespace nnvm
...@@ -120,6 +120,7 @@ struct JSONNode { ...@@ -120,6 +120,7 @@ struct JSONNode {
struct JSONGraph { struct JSONGraph {
std::vector<JSONNode> nodes; std::vector<JSONNode> nodes;
std::vector<uint32_t> arg_nodes; std::vector<uint32_t> arg_nodes;
std::vector<uint32_t> node_row_ptr;
std::vector<JSONNode::Entry> heads; std::vector<JSONNode::Entry> heads;
std::unordered_map<std::string, std::shared_ptr<const any> > attrs; std::unordered_map<std::string, std::shared_ptr<const any> > attrs;
...@@ -127,6 +128,7 @@ struct JSONGraph { ...@@ -127,6 +128,7 @@ struct JSONGraph {
writer->BeginObject(); writer->BeginObject();
writer->WriteObjectKeyValue("nodes", nodes); writer->WriteObjectKeyValue("nodes", nodes);
writer->WriteObjectKeyValue("arg_nodes", arg_nodes); writer->WriteObjectKeyValue("arg_nodes", arg_nodes);
writer->WriteObjectKeyValue("node_row_ptr", node_row_ptr);
writer->WriteObjectKeyValue("heads", heads); writer->WriteObjectKeyValue("heads", heads);
if (attrs.size() != 0) { if (attrs.size() != 0) {
writer->WriteObjectKeyValue("attrs", attrs); writer->WriteObjectKeyValue("attrs", attrs);
...@@ -140,6 +142,7 @@ struct JSONGraph { ...@@ -140,6 +142,7 @@ struct JSONGraph {
helper.DeclareField("nodes", &nodes); helper.DeclareField("nodes", &nodes);
helper.DeclareField("arg_nodes", &arg_nodes); helper.DeclareField("arg_nodes", &arg_nodes);
helper.DeclareField("heads", &heads); helper.DeclareField("heads", &heads);
helper.DeclareOptionalField("node_row_ptr", &node_row_ptr);
helper.DeclareOptionalField("attrs", &attrs); helper.DeclareOptionalField("attrs", &attrs);
helper.ReadAllFields(reader); helper.ReadAllFields(reader);
} }
...@@ -188,6 +191,7 @@ Graph LoadJSON(const Graph& src) { ...@@ -188,6 +191,7 @@ Graph LoadJSON(const Graph& src) {
Graph SaveJSON(const Graph& src) { Graph SaveJSON(const Graph& src) {
JSONGraph jgraph; JSONGraph jgraph;
std::unordered_map<Node*, uint32_t> node2index; std::unordered_map<Node*, uint32_t> node2index;
jgraph.node_row_ptr.push_back(0);
DFSVisit(src.outputs, [&node2index, &jgraph](const NodePtr& n) { DFSVisit(src.outputs, [&node2index, &jgraph](const NodePtr& n) {
uint32_t nid = static_cast<uint32_t>(jgraph.nodes.size()); uint32_t nid = static_cast<uint32_t>(jgraph.nodes.size());
node2index[n.get()] = nid; node2index[n.get()] = nid;
...@@ -204,6 +208,8 @@ Graph SaveJSON(const Graph& src) { ...@@ -204,6 +208,8 @@ Graph SaveJSON(const Graph& src) {
for (const NodePtr& c : n->control_deps) { for (const NodePtr& c : n->control_deps) {
jnode.control_deps.push_back(node2index.at(c.get())); jnode.control_deps.push_back(node2index.at(c.get()));
} }
jgraph.node_row_ptr.push_back(
jgraph.node_row_ptr.back() + n->num_outputs());
jgraph.nodes.emplace_back(std::move(jnode)); jgraph.nodes.emplace_back(std::move(jnode));
}); });
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment