Commit b63cb4d1 by Tianqi Chen

Introduce NodePtr (#9)

parent ea8d2292
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <dmlc/base.h> #include <dmlc/base.h>
#include <dmlc/any.h> #include <dmlc/any.h>
#include <dmlc/memory.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/registry.h> #include <dmlc/registry.h>
#include <dmlc/array_view.h> #include <dmlc/array_view.h>
......
...@@ -81,7 +81,7 @@ void PostOrderDFSVisit(const std::vector<GNode>& heads, ...@@ -81,7 +81,7 @@ void PostOrderDFSVisit(const std::vector<GNode>& heads,
template<typename FVisit> template<typename FVisit>
inline void DFSVisit(const std::vector<NodeEntry>& heads, inline void DFSVisit(const std::vector<NodeEntry>& heads,
FVisit fvisit) { FVisit fvisit) {
typedef const std::shared_ptr<Node>* GNode; typedef const NodePtr* GNode;
std::vector<GNode> head_nodes(heads.size()); std::vector<GNode> head_nodes(heads.size());
std::transform(heads.begin(), heads.end(), head_nodes.begin(), std::transform(heads.begin(), heads.end(), head_nodes.begin(),
[](const NodeEntry& e)->GNode { [](const NodeEntry& e)->GNode {
......
...@@ -20,8 +20,7 @@ class Node; ...@@ -20,8 +20,7 @@ class Node;
/*! /*!
* \brief we always used NodePtr for a reference pointer * \brief we always used NodePtr for a reference pointer
* to the node, so this alias can be changed in case we need * to the node, so this alias can be changed in case.
* even faster graph composition than 3M ops/sec.
* *
* By default, NodePtr is a std::shared_ptr of node * By default, NodePtr is a std::shared_ptr of node
*/ */
......
...@@ -14,7 +14,7 @@ IndexedGraph::IndexedGraph(const Graph &g) { ...@@ -14,7 +14,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
std::vector<size_t> inputs_rptr{0}, control_rptr{0}; std::vector<size_t> inputs_rptr{0}, control_rptr{0};
DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr] DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr]
(const std::shared_ptr<nnvm::Node>& n) { (const NodePtr& n) {
CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max()); CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
uint32_t nid = static_cast<uint32_t>(nodes_.size()); uint32_t nid = static_cast<uint32_t>(nodes_.size());
// nodes_ // nodes_
......
...@@ -12,7 +12,7 @@ Node::~Node() { ...@@ -12,7 +12,7 @@ Node::~Node() {
// explicit deletion via DFS // explicit deletion via DFS
// this is used to avoid stackoverflow caused by chain of deletions // this is used to avoid stackoverflow caused by chain of deletions
std::vector<Node*> stack{this}; std::vector<Node*> stack{this};
std::vector<std::shared_ptr<Node> > to_delete; std::vector<NodePtr> to_delete;
while (!stack.empty()) { while (!stack.empty()) {
Node* n = stack.back(); Node* n = stack.back();
stack.pop_back(); stack.pop_back();
...@@ -37,8 +37,6 @@ Node::~Node() { ...@@ -37,8 +37,6 @@ Node::~Node() {
} }
NodePtr Node::Create() { NodePtr Node::Create() {
// NOTE: possible change to thread local memory pool
// via std::allocate_shared instead for faster allocation.
return std::make_shared<Node>(); return std::make_shared<Node>();
} }
......
...@@ -35,7 +35,7 @@ void test_node_speed() { ...@@ -35,7 +35,7 @@ void test_node_speed() {
auto add = nnvm::Op::Get("add"); auto add = nnvm::Op::Get("add");
double tstart = dmlc::GetTime(); double tstart = dmlc::GetTime();
size_t rep = 1000; size_t rep = 1000;
size_t n = 100; size_t n = 1000;
for (size_t t = 0; t < rep; ++t) { for (size_t t = 0; t < rep; ++t) {
nnvm::Symbol s = nnvm::Symbol::CreateVariable("x"); nnvm::Symbol s = nnvm::Symbol::CreateVariable("x");
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment