Commit 93534ece by Tianqi Chen

[SYMBOLIC] Add symbolic API (#2)

* [SYMBOLIC] Add symbolic API

* Update Testcase to nnvm
parent 5d407324
......@@ -14,11 +14,13 @@
#include <unordered_set>
#include "./base.h"
#include "./node.h"
#include "./symbolic.h"
namespace nnvm {
/*!
* \brief Symbolic computation graph.
* This is the intermediate representation for optimization pass.
*/
class Graph {
public:
......@@ -30,16 +32,18 @@ class Graph {
* and can be shared across multiple Instance of graph
*/
std::unordered_map<std::string, std::shared_ptr<const any> > attrs;
/*!
* \brief perform a Post Order DFS visit to each node in the graph.
* This order is deterministic and is also topoligical sorted.
* \param fvisit a function of type std::function<void(const std::shared_ptr<Node>&)>
* \tparam FVisit The function type to perform the visit.
*/
template<typename FVisit>
inline void DFSVisit(FVisit fvisit) const;
};
/*!
* \brief perform a Post Order DFS visit to each node in the graph.
* This order is deterministic and is also topoligical sorted.
* \param heads The heads in the graph.
* \param fvisit a function of type std::function<void(const std::shared_ptr<Node>&)>
* \tparam FVisit The function type to perform the visit.
*/
template<typename FVisit>
inline void DFSVisit(const std::vector<NodeEntry>& heads, FVisit fvisit);
// inline function implementations
template <typename GNode, typename HashType,
typename FVisit, typename HashFunc,
......@@ -75,10 +79,11 @@ void PostOrderDFSVisit(const std::vector<GNode>& heads,
}
template<typename FVisit>
inline void Graph::DFSVisit(FVisit fvisit) const {
inline void DFSVisit(const std::vector<NodeEntry>& heads,
FVisit fvisit) {
typedef const std::shared_ptr<Node>* GNode;
std::vector<GNode> head_nodes(outputs.size());
std::transform(outputs.begin(), outputs.end(), head_nodes.begin(),
std::vector<GNode> head_nodes(heads.size());
std::transform(heads.begin(), heads.end(), head_nodes.begin(),
[](const NodeEntry& e)->GNode {
return &e.node;
});
......
......@@ -72,6 +72,8 @@ class Node {
inline bool is_variable() const;
/*! \return number of outputs from this node */
inline uint32_t num_outputs() const;
/*! \return number of inputs from this node */
inline uint32_t num_inputs() const;
/*!
* \brief create a new empty shared_ptr of Node.
* \return a created empty node.
......@@ -86,10 +88,19 @@ inline bool Node::is_variable() const {
inline uint32_t Node::num_outputs() const {
if (is_variable()) return 1;
if (this->op->num_outputs >= 0) {
return static_cast<uint32_t>(this->op->num_outputs);
if (this->op->get_num_outputs == nullptr) {
return this->op->num_outputs;
} else {
return this->op->get_num_outputs(*this);
return this->op->get_num_outputs(this->attrs);
}
}
inline uint32_t Node::num_inputs() const {
if (is_variable()) return 1;
if (this->op->get_num_inputs == nullptr) {
return this->op->num_inputs;
} else {
return this->op->get_num_inputs(this->attrs);
}
}
......
......@@ -10,6 +10,7 @@
#include <vector>
#include <utility>
#include <typeinfo>
#include <limits>
#include <functional>
#include "./base.h"
......@@ -22,8 +23,8 @@ template<typename ValueType>
class OpMap;
class OpRegistryEntry;
/*! \brief constant to indicate variable length inout and output */
static const int kVarg = -1;
/*! \brief constant to indicate it take any length of positional inputs */
static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
/*!
* \brief Operator structure.
......@@ -79,23 +80,31 @@ class Op {
/*!
* \brief number of inputs to the operator,
* -1 means it is variable length
* When get_num_inputs is presented,
* the number will be decided by get_num_inputs instead.
* \sa get_num_inputs
*/
int num_inputs = 0;
uint32_t num_inputs = 1;
/*!
* \brief number of outputs of the operator
* -1 means it is variable length
* When get_num_outputs is presented.
* The number of outputs will be decided by
* get_num_outputs function
* \sa get_num_outputs
*/
int num_outputs = 1;
uint32_t num_outputs = 1;
/*!
* \brief get number of outputs given information about the node.
* This is only valid when num_outputs == -1.
* \param node The constructed node.
* \param attrs The attribute of the node
* \return number of outputs.
*/
int (*get_num_outputs)(const Node& node) = nullptr;
uint32_t (*get_num_outputs)(const NodeAttrs& attrs) = nullptr;
/*!
* \brief get number of inputs given information about the node.
* \param attrs The attribute of the node
* \return number of inputs
*/
uint32_t (*get_num_inputs)(const NodeAttrs& attrs) = nullptr;
/*!
* \brief Attribute parser to parse the NodeAttrs information.
*
......@@ -143,19 +152,25 @@ class Op {
* \param n The number of inputs to be set.
* \return reference to self.
*/
inline Op& set_num_inputs(int n); // NOLINT(*)
inline Op& set_num_inputs(uint32_t n); // NOLINT(*)
/*!
* \brief Set the get_num_outputs function.
* \param fn The function to be set.
* \return reference to self.
*/
inline Op& set_num_inputs(uint32_t (*fn)(const NodeAttrs& attr)); // NOLINT(*)
/*!
* \brief Set the num_outputs
* \param n The number of outputs to be set.
* \return reference to self.
*/
inline Op& set_num_outputs(int n); // NOLINT(*)
inline Op& set_num_outputs(uint32_t n); // NOLINT(*)
/*!
* \brief Set the get_num_outputs function.
* \param fn The function to be set.
* \return reference to self.
*/
inline Op& set_num_outputs(int (*fn)(const Node& node)); // NOLINT(*)
inline Op& set_num_outputs(uint32_t (*fn)(const NodeAttrs& attr)); // NOLINT(*)
/*!
* \brief Set the attr_parser function.
* \param fn The number of outputs to be set.
......@@ -180,6 +195,7 @@ class Op {
static const Op* Get(const std::string& op_name);
/*!
* \brief Get additional registered attribute about operators.
* If nothing has been registered, an empty OpMap will be returned.
* \param attr_name The name of the attribute.
* \return An OpMap of specified attr_name.
* \tparam ValueType The type of the attribute.
......@@ -197,7 +213,7 @@ class Op {
// internal constructor
Op();
// get const reference to certain attribute
static const any& GetAttrMap(const std::string& key);
static const any* GetAttrMap(const std::string& key);
// update the attribute OpMap
static void UpdateAttrMap(const std::string& key,
std::function<void(any*)> updater);
......@@ -218,6 +234,13 @@ class OpMap {
*/
inline const ValueType& operator[](const Op* op) const;
/*!
* \brief get the corresponding value element at op with default value.
* \param op The key to the map
* \param def_value The default value when the key does not exist.
* \return the const reference to the content value.
*/
inline const ValueType& get(const Op* op, const ValueType& def_value) const;
/*!
* \brief Check if the map has op as key.
* \param op The key to the map
* \return 1 if op is contained in map, 0 otherwise.
......@@ -262,8 +285,18 @@ class OpMap {
// member function of Op
template<typename ValueType>
inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
const any& ref = GetAttrMap(key);
return nnvm::get<OpMap<ValueType> >(ref);
const any* ref = GetAttrMap(key);
if (ref == nullptr) {
UpdateAttrMap(key, [key](any* pmap) {
if (pmap->empty()) {
OpMap<ValueType> pm;
pm.attr_name_ = key;
*pmap = std::move(pm);
}
});
ref = GetAttrMap(key);
}
return nnvm::get<OpMap<ValueType> >(*ref);
}
template<typename ValueType>
......@@ -273,7 +306,7 @@ inline Op& Op::attr( // NOLINT(*)
if (pmap->empty()) {
OpMap<ValueType> pm;
pm.attr_name_ = attr_name;
*pmap = pm;
*pmap = std::move(pm);
}
CHECK_EQ(pmap->type(), typeid(OpMap<ValueType>))
<< "Attribute " << attr_name
......@@ -301,18 +334,22 @@ inline Op& Op::describe(const std::string& descr) { // NOLINT(*)
return *this;
}
inline Op& Op::set_num_inputs(int n) { // NOLINT(*)
inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
this->num_inputs = n;
return *this;
}
inline Op& Op::set_num_outputs(int n) { // NOLINT(*)
inline Op& Op::set_num_inputs(uint32_t (*fn)(const NodeAttrs&)) { // NOLINT(*)
this->get_num_inputs = fn;
return *this;
}
inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
this->num_outputs = n;
return *this;
}
inline Op& Op::set_num_outputs(int (*fn)(const Node& node)) { // NOLINT(*)
this->num_outputs = kVarg;
inline Op& Op::set_num_outputs(uint32_t (*fn)(const NodeAttrs&)) { // NOLINT(*)
this->get_num_outputs = fn;
return *this;
}
......@@ -338,6 +375,16 @@ inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const {
return data_[idx].first;
}
template<typename ValueType>
inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def_value) const {
const uint32_t idx = op->index_;
if (idx < data_.size() && data_[idx].second) {
return data_[idx].first;
} else {
return def_value;
}
}
} // namespace nnvm
#endif // NNVM_OP_H_
/*!
* Copyright (c) 2016 by Contributors
* \file op_attr_types.h
* \brief Data structures that can appear in operator attributes.
*/
#ifndef NNVM_OP_ATTR_TYPES_H_
#define NNVM_OP_ATTR_TYPES_H_
#include <vector>
#include <string>
#include <functional>
namespace nnvm {
// These types are optional attributes in each op
// Some of them are needed for certain pass.
/*!
* \brief Return list of input arguments names of each operator.
*
* \param attrs The attributes of the node.
* \return list of inputs
* \note Register under "FListInputNames", default return {"data"}.
*
* FListInputNames enables automatic variable creation for missing arguments.
*/
using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;
/*!
* \brief Return list of output arguments names of each operator.
*
* \param attrs The attributes of the node.
* \return list of inputs
* \note Register under "FListOutputNames", default return {"outputs"}.
*
* FListOutputNames customized naming for operator outputs.
*/
using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;
} // namespace nnvm
#endif // NNVM_OP_ATTR_TYPES_H_
/*!
* Copyright (c) 2016 by Contributors
* \file symbolic.h
* \brief Symbolic graph construction API
*
* This API is optional, but useful to allow user
* to construct NNVM Graph easily, and quickly create
* front-end host languages.
*/
#ifndef NNVM_SYMBOLIC_H_
#define NNVM_SYMBOLIC_H_
#include <string>
#include <vector>
#include <utility>
#include "./base.h"
#include "./node.h"
namespace nnvm {
/*!
* \brief Symbol is used to represent the
*/
class Symbol {
public:
/*! \brief option passed to ListAttr */
enum ListAttrOption {
/*! \brief recursively list all attributes */
kRecursive,
/*! \brief only list attributes in current node */
kShallow
};
/*! \brief output entries contained in the symbol */
std::vector<NodeEntry> outputs;
/*!
* \brief copy the symbol
* \return a deep copy of the symbolic graph.
*/
Symbol Copy() const;
/*!
* \brief print the symbol info to output stream.
* \param os the output stream we like to print to
*/
void Print(std::ostream &os) const; // NOLINT(*)
/*!
* \brief get the index th element from the returned tuple.
* \param index index of multi output
* \return the symbol corresponds to the indexed element.
*/
Symbol operator[] (size_t index) const;
/*!
* \brief List the arguments names.
*
* The position of the returned list also corresponds to calling position in operator()
* \return the arguments list of this symbol, they can be either named or unnamed (empty string).
*/
std::vector<std::string> ListArguments() const;
/*!
* \brief List the names of outputs for this symbol.
* For normal operators, it is usually symbol node name + "_output"
* \return get the descriptions of outputs for this symbol.
*/
std::vector<std::string> ListOutputs() const;
/*!
* \brief Compose the symbol with arguments, this changes the current symbol.
* The kwargs passed in can be in-complete,
*
* The rest of the symbols will remain the same name.
*
* \param positional arguments
* \param kwargs keyword arguments for the symbol
* \param name name of returned symbol.
*/
void Compose(const std::vector<Symbol>& args,
const std::unordered_map<std::string, Symbol>& kwargs,
const std::string& name);
/*!
* \brief Apply the symbol as a function, compose with arguments
* This is equivalent to Copy then Compose.
* \param args positional arguments for the symbol
* \param kwargs keyword arguments for the symbol
* \param name name of returned symbol.
* \return a new Symbol which is the composition of current symbol with its arguments
*/
Symbol operator () (const std::vector<Symbol>& args,
const std::unordered_map<std::string, Symbol>& kwargs,
const std::string& name) const;
/*!
* \brief Add control flow depenencies to operators involved in symbols.
* For grouped sybmbol, an error will be raised.
* This mutate current symbolic Node.
*
* \param src The symbols to depend on.
*/
void AddControlDeps(const Symbol& src);
/*
* \brief Get all the internal nodes of the symbol.
* \return symbol A new symbol whose output contains all the outputs of the symbols
* Including input variables and intermediate outputs.
*/
Symbol GetInternals() const;
/*!
* \brief set additional attributes to current node.
* This only works for symbol with outputs from single operators.
* For grouped sybmbol, an error will be raised.
*
* This function mutate the node's symbol and is not recommended.
*
* \param key the key of the attribute
* \param value the value of the attribute.
*/
void SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs);
/*!
* \brief Get attribute dictionary from the symbol.
* For grouped sybmbol, an error will be raised.
* \param option If recursive is set, the attributes of all children are retrieved,
* The name of symbol will be pre-pended to each key.
* \return The created attribute.
*/
std::unordered_map<std::string, std::string> ListAttr(ListAttrOption option) const;
/*!
* \brief create symbolic functor(AtomicSymbol) by given operator and attributes.
* \param op_name The name of the operator.
* \param attrs The additional attributes.
*
* \return Symbol that can be used to call compose further.
*/
static Symbol CreateFunctor(const std::string& op_name,
const std::unordered_map<std::string, std::string>& attrs);
/*!
* \brief create variable symbol node
* \param name name of the variable
* \return the new variable
*/
static Symbol CreateVariable(const std::string& name);
/*!
* \brief create equivalence of symbol by grouping the symbols together
* \param symbols list of symbols
* \return the grouped symbol
*/
static Symbol CreateGroup(const std::vector<Symbol>& symbols);
};
} // namespace nnvm
#endif // NNVM_SYMBOLIC_H_
......@@ -13,7 +13,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
entry_rptr_.push_back(0);
std::vector<size_t> inputs_rptr{0}, control_rptr{0};
g.DFSVisit([this, &inputs_rptr, &control_rptr]
DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr]
(const std::shared_ptr<nnvm::Node>& n) {
CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
uint32_t nid = static_cast<uint32_t>(nodes_.size());
......
......@@ -6,6 +6,7 @@
#include <nnvm/base.h>
#include <nnvm/op.h>
#include <memory>
#include <atomic>
#include <mutex>
......@@ -23,7 +24,7 @@ struct OpManager {
// global operator counter
std::atomic<int> op_counter{0};
// storage of additional attribute table.
std::unordered_map<std::string, any> attr;
std::unordered_map<std::string, std::unique_ptr<any> > attr;
// get singleton of the
static OpManager* Global() {
static OpManager inst;
......@@ -46,24 +47,24 @@ const Op* Op::Get(const std::string& name) {
}
// Get attribute map by key
const any& Op::GetAttrMap(const std::string& key) {
// assume no operator registration during
// the execution phase.
const auto& dict = OpManager::Global()->attr;
const any* Op::GetAttrMap(const std::string& key) {
auto& dict = OpManager::Global()->attr;
auto it = dict.find(key);
CHECK(it != dict.end() && it->first == key)
<< "Cannot find Operator attribute " << key
<< " for any operator";
return it->second;
if (it != dict.end()) {
return it->second.get();
} else {
return nullptr;
}
}
// update attribute map by updater function.
// update attribute map
void Op::UpdateAttrMap(const std::string& key,
std::function<void(any*)> updater) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex>(mgr->mutex);
any& value = mgr->attr[key];
updater(&value);
std::unique_ptr<any>& value = mgr->attr[key];
if (value.get() == nullptr) value.reset(new any());
if (updater != nullptr) updater(value.get());
}
} // namespace nnvm
......@@ -158,7 +158,7 @@ Graph LoadJSON(const Graph& src) {
Graph SaveJSON(const Graph& src) {
JSONGraph jgraph;
std::unordered_map<Node*, uint32_t> node2index;
src.DFSVisit([&node2index, &jgraph](const std::shared_ptr<Node>& n) {
DFSVisit(src.outputs, [&node2index, &jgraph](const std::shared_ptr<Node>& n) {
uint32_t nid = static_cast<uint32_t>(jgraph.nodes.size());
node2index[n.get()] = nid;
if (n->is_variable()) {
......
......@@ -8,7 +8,7 @@
void test_op() {
using namespace nnvm;
auto add = Op::Get("add");
auto nick = Op::GetAttr<std::string>("nick_name");
static auto& nick = Op::GetAttr<std::string>("nick_name");
LOG(INFO) << "nick=" << nick[add];
}
......@@ -35,9 +35,7 @@ void test_tuple() {
void test_graph() {
nnvm::Graph g;
g.DFSVisit([](const std::shared_ptr<const nnvm::Node>& n){
});
nnvm::Symbol s;
}
int main() {
test_tuple();
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <nngraph/op.h>
#include <nnvm/op.h>
#include <utility>
NNGRAPH_REGISTER_OP(add)
NNVM_REGISTER_OP(add)
.describe("add two data together")
.set_num_inputs(2)
.attr("inplace_pair", std::make_pair(0, 0));
NNGRAPH_REGISTER_OP(add)
NNVM_REGISTER_OP(add)
.attr<std::string>("nick_name", "plus");
TEST(Op, GetAttr) {
using namespace nngraph;
using namespace nnvm;
auto add = Op::Get("add");
auto nick = Op::GetAttr<std::string>("nick_name");
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <nngraph/tuple.h>
#include <nnvm/tuple.h>
TEST(Tuple, Basic) {
using nngraph::Tuple;
using nngraph::TShape;
using nnvm::Tuple;
using nnvm::TShape;
Tuple<int> x{1, 2, 3};
Tuple<int> y{1, 2, 3, 5, 6};
x = std::move(y);
......@@ -17,7 +17,7 @@ TEST(Tuple, Basic) {
std::istringstream is(os.str());
is >> y;
CHECK_EQ(x, y);
Tuple<nngraph::index_t> ss{1, 2, 3};
Tuple<nnvm::index_t> ss{1, 2, 3};
TShape s = ss;
s = std::move(ss);
CHECK((s == TShape{1, 2, 3}));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment