Commit 66b9ef23 by tqchen Committed by Tianqi Chen

[OP] Finalize Op registry

parent 871529b1
...@@ -3,13 +3,13 @@ export CFLAGS= -std=c++11 -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loops ...@@ -3,13 +3,13 @@ export CFLAGS= -std=c++11 -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loops
-Iinclude -Idmlc-core/include -fPIC -Iinclude -Idmlc-core/include -fPIC
# specify tensor path # specify tensor path
.PHONY: clean all test .PHONY: clean all test lint doc
all: lib/libnngraph.so test all: lib/libnngraph.so test
SRC = $(wildcard src/*.cc src/*/*.cc) SRC = $(wildcard src/*.cc src/*/*.cc example/*.cc)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC)) $(PLUGIN_OBJS) ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
ALL_DEP = $(ALL_OBJ) ALL_DEP = $(ALL_OBJ)
build/%.o: src/%.cc build/%.o: src/%.cc
...@@ -24,8 +24,14 @@ lib/libnngraph.so: $(ALL_DEP) ...@@ -24,8 +24,14 @@ lib/libnngraph.so: $(ALL_DEP)
test: $(ALL_DEP) test: $(ALL_DEP)
$(CXX) $(CFLAGS) -o $@ $(filter %.o %.a, $^) $(LDFLAGS) $(CXX) $(CFLAGS) -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
lint:
python2 dmlc-core/scripts/lint.py nngraph cpp include src
doc:
doxygen docs/Doxyfile
clean: clean:
$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o xgboost $(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o test
-include build/*.d -include build/*.d
-include build/*/*.d -include build/*/*.d
This source diff could not be displayed because it is too large. You can view the blob instead.
/*!
* Copyright (c) 2016 by Contributors
* \file attr_frame.h
* \brief Attribute frame data structure for properties in the graph.
* This data structure is inspired by data_frame for general.
*/
#include "./base.h"
namespace nngraph {
struct AttrFrame {
std::unique_ptr<std::unordered_map<std::string, any> > info;
};
} // namespace nngraph
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <dmlc/base.h> #include <dmlc/base.h>
#include <dmlc/any.h> #include <dmlc/any.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/registry.h>
namespace nngraph { namespace nngraph {
...@@ -39,4 +40,5 @@ inline const T& get(const any& src) { ...@@ -39,4 +40,5 @@ inline const T& get(const any& src) {
} }
} // namespace nngraph } // namespace nngraph
#endif // NNGRAPH_BASE_H_ #endif // NNGRAPH_BASE_H_
...@@ -8,9 +8,9 @@ ...@@ -8,9 +8,9 @@
#include <vector> #include <vector>
#include "./node.h" #include "./node.h"
#include "./attr_frame.h"
namespace nngraph { namespace nngraph {
/*! /*!
* \brief Symbolic computation graph. * \brief Symbolic computation graph.
*/ */
...@@ -33,8 +33,6 @@ class Graph { ...@@ -33,8 +33,6 @@ class Graph {
private: private:
/*! \brief outputs of the graph. */ /*! \brief outputs of the graph. */
std::vector<NodeEntry> outputs_; std::vector<NodeEntry> outputs_;
/*! \brief additional internal attribute */
AttrFrame attr_frame_;
}; };
} // namespace nngraph } // namespace nngraph
......
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file base.h * \file base.h
* \brief Configuation of nngraph as well as basic data structure. * \brief Graph node data structure.
*/ */
#ifndef NNGRAPH_NODE_H_ #ifndef NNGRAPH_NODE_H_
#define NNGRAPH_NODE_H_ #define NNGRAPH_NODE_H_
#include <memory> #include <memory>
#include <string>
#include <vector>
#include <unordered_map> #include <unordered_map>
#include "./op_prop.h" #include "./op.h"
namespace nngraph { namespace nngraph {
// Forward declare node. // Forward declare node.
struct Node; class Node;
/*! \brief an entry that represents output data from a node */ /*! \brief an entry that represents output data from a node */
struct NodeEntry { struct NodeEntry {
...@@ -23,43 +26,40 @@ struct NodeEntry { ...@@ -23,43 +26,40 @@ struct NodeEntry {
}; };
/*! /*!
* \brief The attributes of the current operation node.
* Usually are additional parameters like axis,
*/
struct NodeAttrs {
/*! \brief The dictionary representation of attributes */
std::unordered_map<std::string, std::string> dict;
/*!
* \brief A parsed version of attributes,
* This is generated if OpProperty.attr_parser is registered.
* The object can be used to quickly access attributes.
*/
any parsed;
};
/*!
* \brief Node represents an operation in a computation graph. * \brief Node represents an operation in a computation graph.
*/ */
struct Node { class Node {
public:
/*! \brief name of the node */ /*! \brief name of the node */
std::string name; std::string name;
/*! \brief the operator this node is pointing at */ /*! \brief the operator this node is pointing at */
const OpProperty *op; const Op *op;
/*! \brief inputs to this node */ /*! \brief inputs to this node */
std::vector<NodeEntry> inputs; std::vector<NodeEntry> inputs;
/*! \brief The attributes in the node. */
NodeAttrs attrs;
/*! \brief destructor of node */
~Node();
/*! /*!
* \brief additional attributes about the node, * \brief create a new empty shared_ptr of Node.
* Use pointer to save space, as attr can be accessed in a slow way, * \return a created empty node.
* not every node will have attributes.
*/ */
std::unordered_map<std::string, std::string> attr; static std::shared_ptr<Node> Create();
~Node() {
if (inputs.size() != 0) {
// explicit deletion via DFS
// this is used to avoid stackoverflow caused by chain of deletions
std::vector<Node*> stack{this};
std::vector<std::shared_ptr<Node> > to_delete;
while (!stack.empty()) {
Node* n = stack.back();
stack.pop_back();
for (NodeEntry& e: n->inputs) {
if (e.node.unique()) {
stack.push_back(e.node.get());
to_delete.emplace_back(std::move(e.node));
} else {
e.node.reset();
}
}
n->inputs.clear();
}
}
}
}; };
} // namespace nngraph } // namespace nngraph
......
/*!
* Copyright (c) 2016 by Contributors
* \file op_prop.h
* \brief Data structure about property of operators
*/
#ifndef NNGRAPH_OP_PROP_H_
#define NNGRAPH_OP_PROP_H_
namespace nngraph {
/*!
* \brief operator specific data structure
*/
struct OpProperty {
/*! \brief name of the operator */
std::string name;
/*! \brief number of inputs to the operator */
int num_inputs;
/*! \brief number of outputs to the operator */
int num_outputs;
};
} // namespace nngraph
#endif // NNGRAPH_OP_PROP_H_
/*!
* Copyright (c) 2016 by Contributors
* \file node.cc
* \brief Graph node data structure.
*/
#include <nngraph/node.h>
namespace nngraph {
Node::~Node() {
if (inputs.size() != 0) {
// explicit deletion via DFS
// this is used to avoid stackoverflow caused by chain of deletions
std::vector<Node*> stack{this};
std::vector<std::shared_ptr<Node> > to_delete;
while (!stack.empty()) {
Node* n = stack.back();
stack.pop_back();
for (NodeEntry& e : n->inputs) {
if (e.node.unique()) {
stack.push_back(e.node.get());
to_delete.emplace_back(std::move(e.node));
} else {
e.node.reset();
}
}
n->inputs.clear();
}
}
}
std::shared_ptr<Node> Node::Create() {
// NOTE: possible change to thread local memory pool
// via std::allocate_shared instead for faster allocation.
return std::make_shared<Node>();
}
} // namespace nngraph
/*!
* Copyright (c) 2016 by Contributors
* \file op.cc
* \brief Support for operator registry.
*/
#include <nngraph/base.h>
#include <nngraph/op.h>
#include <atomic>
#include <mutex>
namespace dmlc {
// enable registry
DMLC_REGISTRY_ENABLE(nngraph::Op);
} // namespace dmlc
namespace nngraph {
// single manager of operator information.
struct OpManager {
// mutex to avoid registration from multiple threads.
std::mutex mutex;
// global operator counter
std::atomic<int> op_counter{0};
// storage of additional attribute table.
std::unordered_map<std::string, any> attr;
// get singleton of the
static OpManager* Global() {
static OpManager inst;
return &inst;
}
};
// constructor
Op::Op() {
OpManager* mgr = OpManager::Global();
index_ = mgr->op_counter++;
}
// find operator by name
const Op* Op::Get(const std::string& name) {
const Op* op = dmlc::Registry<Op>::Find(name);
CHECK(op != nullptr)
<< "Operator " << name << " is not registered";
return op;
}
// Get attribute map by key
const any& Op::GetAttrMap(const std::string& key) {
// assume no operator registration during
// the execution phase.
const auto& dict = OpManager::Global()->attr;
auto it = dict.find(key);
CHECK(it != dict.end() && it->first == key)
<< "Cannot find Operator attribute " << key
<< " for any operator";
return it->second;
}
// update attribute map by updater function.
void Op::UpdateAttrMap(const std::string& key,
std::function<void(any*)> updater) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex>(mgr->mutex);
any& value = mgr->attr[key];
updater(&value);
}
} // namespace nngraph
// Copyright (c) 2016 by Contributors
// This is an example on how we can register operator information to NNGRAPH
#include <nngraph/op.h>
#include <utility>
NNGRAPH_REGISTER_OP(add)
.describe("add two data together")
.set_num_inputs(2)
.attr("inplace_pair", std::make_pair(0, 0));
NNGRAPH_REGISTER_OP(add)
.attr<std::string>("nick_name", "plus");
// Copyright (c) 2016 by Contributors
#include <nngraph/op.h>
#include <nngraph/graph.h> #include <nngraph/graph.h>
int main() { int main() {
nngraph::any a = 1; using namespace nngraph;
LOG(INFO) << nngraph::get<int>(a); auto add = Op::Get("add");
auto nick = Op::GetAttr<std::string>("nick_name");
LOG(INFO) << "nick=" << nick[add];
return 0; return 0;
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment