Commit 66b9ef23 by tqchen Committed by Tianqi Chen

[OP] Finalize Op registry

parent 871529b1
......@@ -3,13 +3,13 @@ export CFLAGS= -std=c++11 -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loops
-Iinclude -Idmlc-core/include -fPIC
# specify tensor path
.PHONY: clean all test
.PHONY: clean all test lint doc
all: lib/libnngraph.so test
SRC = $(wildcard src/*.cc src/*/*.cc)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC)) $(PLUGIN_OBJS)
SRC = $(wildcard src/*.cc src/*/*.cc example/*.cc)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
ALL_DEP = $(ALL_OBJ)
build/%.o: src/%.cc
......@@ -24,8 +24,14 @@ lib/libnngraph.so: $(ALL_DEP)
test: $(ALL_DEP)
$(CXX) $(CFLAGS) -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
lint:
python2 dmlc-core/scripts/lint.py nngraph cpp include src
doc:
doxygen docs/Doxyfile
clean:
$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o xgboost
$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o test
-include build/*.d
-include build/*/*.d
This source diff could not be displayed because it is too large. You can view the blob instead.
/*!
* Copyright (c) 2016 by Contributors
* \file attr_frame.h
* \brief Attribute frame data structure for properties in the graph.
* This data structure is inspired by data_frame for general.
*/
#include "./base.h"
namespace nngraph {
struct AttrFrame {
std::unique_ptr<std::unordered_map<std::string, any> > info;
};
} // namespace nngraph
......@@ -9,6 +9,7 @@
#include <dmlc/base.h>
#include <dmlc/any.h>
#include <dmlc/logging.h>
#include <dmlc/registry.h>
namespace nngraph {
......@@ -39,4 +40,5 @@ inline const T& get(const any& src) {
}
} // namespace nngraph
#endif // NNGRAPH_BASE_H_
......@@ -8,9 +8,9 @@
#include <vector>
#include "./node.h"
#include "./attr_frame.h"
namespace nngraph {
/*!
* \brief Symbolic computation graph.
*/
......@@ -33,8 +33,6 @@ class Graph {
private:
/*! \brief outputs of the graph. */
std::vector<NodeEntry> outputs_;
/*! \brief additional internal attribute */
AttrFrame attr_frame_;
};
} // namespace nngraph
......
/*!
* Copyright (c) 2016 by Contributors
* \file base.h
* \brief Configuation of nngraph as well as basic data structure.
* \brief Graph node data structure.
*/
#ifndef NNGRAPH_NODE_H_
#define NNGRAPH_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include <unordered_map>
#include "./op_prop.h"
#include "./op.h"
namespace nngraph {
// Forward declare node.
struct Node;
class Node;
/*! \brief an entry that represents output data from a node */
struct NodeEntry {
......@@ -23,43 +26,40 @@ struct NodeEntry {
};
/*!
* \brief The attributes of the current operation node.
* Usually are additional parameters like axis,
*/
struct NodeAttrs {
/*! \brief The dictionary representation of attributes */
std::unordered_map<std::string, std::string> dict;
/*!
* \brief A parsed version of attributes,
* This is generated if OpProperty.attr_parser is registered.
* The object can be used to quickly access attributes.
*/
any parsed;
};
/*!
* \brief Node represents an operation in a computation graph.
*/
struct Node {
class Node {
public:
/*! \brief name of the node */
std::string name;
/*! \brief the operator this node is pointing at */
const OpProperty *op;
const Op *op;
/*! \brief inputs to this node */
std::vector<NodeEntry> inputs;
/*! \brief The attributes in the node. */
NodeAttrs attrs;
/*! \brief destructor of node */
~Node();
/*!
* \brief additional attributes about the node,
* Use pointer to save space, as attr can be accessed in a slow way,
* not every node will have attributes.
* \brief create a new empty shared_ptr of Node.
* \return a created empty node.
*/
std::unordered_map<std::string, std::string> attr;
~Node() {
if (inputs.size() != 0) {
// explicit deletion via DFS
// this is used to avoid stackoverflow caused by chain of deletions
std::vector<Node*> stack{this};
std::vector<std::shared_ptr<Node> > to_delete;
while (!stack.empty()) {
Node* n = stack.back();
stack.pop_back();
for (NodeEntry& e: n->inputs) {
if (e.node.unique()) {
stack.push_back(e.node.get());
to_delete.emplace_back(std::move(e.node));
} else {
e.node.reset();
}
}
n->inputs.clear();
}
}
}
static std::shared_ptr<Node> Create();
};
} // namespace nngraph
......
/*!
* Copyright (c) 2016 by Contributors
* \file op.h
* \brief Operator information structor.
*/
#ifndef NNGRAPH_OP_H_
#define NNGRAPH_OP_H_
#include <string>
#include <vector>
#include <utility>
#include <typeinfo>
#include <functional>
#include "./base.h"
namespace nngraph {
// forward declarations
class Node;
struct NodeAttrs;
template<typename ValueType>
class OpMap;
class OpRegistryEntry;
/*! \brief constant to indicate variable length inout and output */
static const int kVarg = -1;
/*!
* \brief Operator structure.
*
* Besides the fields in the structure,
* arbitary additional information can be associated with each op.
* See function GetAttr for details.
*
* \code
* // Example usage of Op
*
* // registeration of oeprators
* // NOTE that the attr function can register any
* // additional attributes to the operator
* NNGRAPH_REGISTER_OP(add)
* .describe("add two inputs together")
* .set_num_inputs(2)
* .attr<OpKernel>("gpu_kernel", AddKernel);
*
* NNGRAPH_REGISTER_OP(sub)
* .describe("substract one tensor from another")
* .set_num_inputs(2);
*
* // Can call regster multiple times in different files
* // to register different part of information
* NNGRAPH_REGISTER_OP(sub)
* .attr<OpKernel>("gpu_kernel", SubKernel);
*
* // get operators from registry.
* void my_function() {
* const Op* add = Op::Get("add");
* const Op* sub = Op::Get("sub");
* // query basic information about each operator.
* assert(op->name == "plus");
* assert(op->num_inputs == 2);
*
* // get additional registered information,
* // Assume user registered a OpKernel type attribute as gpu_kernel on each operator.
* const OpMap<OpKernel>& kernel = Op::GetAttr<OpKernel>("gpu_kernel");
* // we can get the kernel functions by using operator as key.
* auto add_kernel = kernel[add];
* auto sub_kernel = kernel[sub];
* // subsequent code can make use of the queried kernel functions.
* }
* \endcode
*/
class Op {
public:
/*! \brief name of the operator */
std::string name;
/*! \brief detailed description of the operator */
std::string description;
/*!
* \brief number of inputs to the operator,
* -1 means it is variable length
*/
int num_inputs = 0;
/*!
* \brief number of outputs of the operator
* -1 means it is variable length
* The number of outputs will be decided by
* get_num_outputs function
* \sa get_num_outputs
*/
int num_outputs = 1;
/*!
* \brief get number of outputs given information about the node.
* This is only valid when num_outputs == -1.
* \param node The constructed node.
* \return number of outputs.
*/
int (*get_num_outputs)(const Node& node) = nullptr;
/*!
* \brief Attribute parser to parse the NodeAttrs information.
*
* This can help to get quick access to a parsed attribute
* object
*
* \code
* // Example usage of attr_parser.
*
* // Suppose we want to register operator sum.
* // The parameters about sum operator
* struct SumParam {
* int axis;
* };
* // The parser function
* void SumAttrParser(NodeAttrs* attrs) {
* // This will be invoked during node construction.
* SumParam param;
* // parse axis string to integer
* param.axis = atoi(attrs->dict["axis"].c_str());
* // set the parsed parameter
* attrs->parsed = std::move(param);
* }
* // The other function that can utilize the parsed result.
* TShape SumInferShape(const NodeAttrs& attrs,
* const std::vector<TShape>& ishapes) {
* // we can use the parsed version of param
* // without repeatively parsing the parameter
* const SumParam& param = nngraph::get<SumParam>(attrs.parsed);
* }
* \endcode
*/
void (*attr_parser)(NodeAttrs* attrs) = nullptr;
// function fields.
/*!
* \brief setter function during registration
* Set the description of operator
* \param descr the description string.
* \return reference to self.
*/
inline Op& describe(const std::string& descr); // NOLINT(*)
/*!
* \brief Set the num_inputs
* \param n The number of inputs to be set.
* \return reference to self.
*/
inline Op& set_num_inputs(int n); // NOLINT(*)
/*!
* \brief Set the num_outputs
* \param n The number of outputs to be set.
* \return reference to self.
*/
inline Op& set_num_outputs(int n); // NOLINT(*)
/*!
* \brief Set the get_num_outputs function.
* \param fn The function to be set.
* \return reference to self.
*/
inline Op& set_num_outputs(int (*fn)(const Node& node)); // NOLINT(*)
/*!
* \brief Set the attr_parser function.
* \param fn The number of outputs to be set.
* \return reference to self.
*/
inline Op& set_attr_parser(void (*fn)(NodeAttrs* attrs)); // NOLINT(*)
/*!
* \brief Register additional attributes to operator.
* \param attr_name The name of the attribute.
* \param value The value to be set.
* \tparam ValueType The type of the value to be set.
*/
template<typename ValueType>
inline Op& attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value);
/*!
* \brief Get an Op for a given operator name.
* Will raise an error if the op has not been registered.
* \param op_name Name of the operator.
* \return Pointer to a Op, valid throughout program lifetime.
*/
static const Op* Get(const std::string& op_name);
/*!
* \brief Get additional registered attribute about operators.
* \param attr_name The name of the attribute.
* \return An OpMap of specified attr_name.
* \tparam ValueType The type of the attribute.
*/
template<typename ValueType>
static const OpMap<ValueType>& GetAttr(const std::string& attr_name);
private:
template<typename ValueType>
friend class OpMap;
friend class dmlc::Registry<Op>;
// Program internal unique index of operator.
// Used to help index the program.
uint32_t index_{0};
// internal constructor
Op();
// get const reference to certain attribute
static const any& GetAttrMap(const std::string& key);
// update the attribute OpMap
static void UpdateAttrMap(const std::string& key,
std::function<void(any*)> updater);
};
/*!
* \brief A map data structure that takes Op* as key
* and returns ValueType
* \tparam ValueType The type of the value stored in map.
*/
template<typename ValueType>
class OpMap {
public:
/*!
* \brief get the corresponding value element at op
* \param op The key to the map
* \return the const reference to the content value.
*/
inline const ValueType& operator[](const Op* op) const;
/*!
* \brief Check if the map has op as key.
* \param op The key to the map
* \return 1 if op is contained in map, 0 otherwise.
*/
inline int count(const Op* op) const;
private:
friend class Op;
// internal attribute name
std::string attr_name_;
// internal data
std::vector<std::pair<ValueType, int> > data_;
OpMap() = default;
};
// internal macros to make
#define NNGRAPH_STR_CONCAT_(__x, __y) __x##__y
#define NNGRAPH_STR_CONCAT(__x, __y) NNGRAPH_STR_CONCAT_(__x, __y)
#define NNGRAPH_REGISTER_VAR_DEF(OpName) \
static ::nngraph::Op & __make_ ## NNGraphOp ## _ ## OpName
/*!
* \def NNGRAPH_REGISTER_OP
* \brief Register
* This macro must be used under namespace dmlc, and only used once in cc file.
* \param OpName The name of registry
*
* \code
*
* NNGRAPH_REGISTER_OP(add)
* .describe("add two inputs together")
* .set_num_inputs(2)
* .attr<OpKernel>("gpu_kernel", AddKernel);
*
* \endcode
*/
#define NNGRAPH_REGISTER_OP(OpName) \
NNGRAPH_STR_CONCAT(NNGRAPH_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
::dmlc::Registry<::nngraph::Op>::Get()->__REGISTER_OR_GET__(#OpName)
// implementations of template functions after this.
// member function of Op
template<typename ValueType>
inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
const any& ref = GetAttrMap(key);
return nngraph::get<OpMap<ValueType> >(ref);
}
template<typename ValueType>
inline Op& Op::attr( // NOLINT(*)
const std::string& attr_name, const ValueType& value) {
UpdateAttrMap(attr_name, [this, attr_name, value](any* pmap) {
if (pmap->empty()) {
OpMap<ValueType> pm;
pm.attr_name_ = attr_name;
*pmap = pm;
}
CHECK_EQ(pmap->type(), typeid(OpMap<ValueType>))
<< "Attribute " << attr_name
<< " of operator " << this->name
<< " is registered as inconsistent types"
<< " previously " << pmap->type().name()
<< " current " << typeid(OpMap<ValueType>).name();
std::vector<std::pair<ValueType, int> >& vec =
nngraph::get<OpMap<ValueType> >(*pmap).data_;
// resize the value type.
vec.resize(index_ + 1,
std::make_pair(ValueType(), 0));
std::pair<ValueType, int>& p = vec[index_];
CHECK(p.second == 0 || p.first == value)
<< "Attribute " << attr_name
<< " of operator " << this->name
<< " is already registered to a different value";
vec[index_] = std::make_pair(value, 1);
});
return *this;
}
inline Op& Op::describe(const std::string& descr) { // NOLINT(*)
this->description = descr;
return *this;
}
inline Op& Op::set_num_inputs(int n) { // NOLINT(*)
this->num_inputs = n;
return *this;
}
inline Op& Op::set_num_outputs(int n) { // NOLINT(*)
this->num_outputs = n;
return *this;
}
inline Op& Op::set_num_outputs(int (*fn)(const Node& node)) { // NOLINT(*)
this->num_outputs = kVarg;
this->get_num_outputs = fn;
return *this;
}
inline Op& Op::set_attr_parser(void (*fn)(NodeAttrs* attrs)) { // NOLINT(*)
this->attr_parser = fn;
return *this;
}
// member functions of OpMap
template<typename ValueType>
inline int OpMap<ValueType>::count(const Op* op) const {
const uint32_t idx = op->index_;
return idx < data_.size() ? data_[idx].second : 0;
}
template<typename ValueType>
inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const {
const uint32_t idx = op->index_;
CHECK(idx < data_.size() && data_[idx].second)
<< "Attribute " << attr_name_
<< " has not been registered for Operator " << op->name;
return data_[idx].first;
}
} // namespace nngraph
#endif // NNGRAPH_OP_H_
/*!
* Copyright (c) 2016 by Contributors
* \file op_prop.h
* \brief Data structure about property of operators
*/
#ifndef NNGRAPH_OP_PROP_H_
#define NNGRAPH_OP_PROP_H_
namespace nngraph {
/*!
* \brief operator specific data structure
*/
struct OpProperty {
/*! \brief name of the operator */
std::string name;
/*! \brief number of inputs to the operator */
int num_inputs;
/*! \brief number of outputs to the operator */
int num_outputs;
};
} // namespace nngraph
#endif // NNGRAPH_OP_PROP_H_
/*!
* Copyright (c) 2016 by Contributors
* \file node.cc
* \brief Graph node data structure.
*/
#include <nngraph/node.h>
namespace nngraph {
Node::~Node() {
if (inputs.size() != 0) {
// explicit deletion via DFS
// this is used to avoid stackoverflow caused by chain of deletions
std::vector<Node*> stack{this};
std::vector<std::shared_ptr<Node> > to_delete;
while (!stack.empty()) {
Node* n = stack.back();
stack.pop_back();
for (NodeEntry& e : n->inputs) {
if (e.node.unique()) {
stack.push_back(e.node.get());
to_delete.emplace_back(std::move(e.node));
} else {
e.node.reset();
}
}
n->inputs.clear();
}
}
}
std::shared_ptr<Node> Node::Create() {
// NOTE: possible change to thread local memory pool
// via std::allocate_shared instead for faster allocation.
return std::make_shared<Node>();
}
} // namespace nngraph
/*!
* Copyright (c) 2016 by Contributors
* \file op.cc
* \brief Support for operator registry.
*/
#include <nngraph/base.h>
#include <nngraph/op.h>
#include <atomic>
#include <mutex>
namespace dmlc {
// enable registry
DMLC_REGISTRY_ENABLE(nngraph::Op);
} // namespace dmlc
namespace nngraph {
// single manager of operator information.
struct OpManager {
// mutex to avoid registration from multiple threads.
std::mutex mutex;
// global operator counter
std::atomic<int> op_counter{0};
// storage of additional attribute table.
std::unordered_map<std::string, any> attr;
// get singleton of the
static OpManager* Global() {
static OpManager inst;
return &inst;
}
};
// constructor
Op::Op() {
OpManager* mgr = OpManager::Global();
index_ = mgr->op_counter++;
}
// find operator by name
const Op* Op::Get(const std::string& name) {
const Op* op = dmlc::Registry<Op>::Find(name);
CHECK(op != nullptr)
<< "Operator " << name << " is not registered";
return op;
}
// Get attribute map by key
const any& Op::GetAttrMap(const std::string& key) {
// assume no operator registration during
// the execution phase.
const auto& dict = OpManager::Global()->attr;
auto it = dict.find(key);
CHECK(it != dict.end() && it->first == key)
<< "Cannot find Operator attribute " << key
<< " for any operator";
return it->second;
}
// update attribute map by updater function.
void Op::UpdateAttrMap(const std::string& key,
std::function<void(any*)> updater) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex>(mgr->mutex);
any& value = mgr->attr[key];
updater(&value);
}
} // namespace nngraph
// Copyright (c) 2016 by Contributors
// This is an example on how we can register operator information to NNGRAPH
#include <nngraph/op.h>
#include <utility>
NNGRAPH_REGISTER_OP(add)
.describe("add two data together")
.set_num_inputs(2)
.attr("inplace_pair", std::make_pair(0, 0));
NNGRAPH_REGISTER_OP(add)
.attr<std::string>("nick_name", "plus");
// Copyright (c) 2016 by Contributors
#include <nngraph/op.h>
#include <nngraph/graph.h>
int main() {
nngraph::any a = 1;
LOG(INFO) << nngraph::get<int>(a);
using namespace nngraph;
auto add = Op::Get("add");
auto nick = Op::GetAttr<std::string>("nick_name");
LOG(INFO) << "nick=" << nick[add];
return 0;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment