Commit b63cb4d1 by Tianqi Chen

Introduce NodePtr (#9)

parent ea8d2292
......@@ -8,6 +8,7 @@
#include <dmlc/base.h>
#include <dmlc/any.h>
#include <dmlc/memory.h>
#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <dmlc/array_view.h>
......
......@@ -81,7 +81,7 @@ void PostOrderDFSVisit(const std::vector<GNode>& heads,
template<typename FVisit>
inline void DFSVisit(const std::vector<NodeEntry>& heads,
FVisit fvisit) {
typedef const std::shared_ptr<Node>* GNode;
typedef const NodePtr* GNode;
std::vector<GNode> head_nodes(heads.size());
std::transform(heads.begin(), heads.end(), head_nodes.begin(),
[](const NodeEntry& e)->GNode {
......
......@@ -20,8 +20,7 @@ class Node;
/*!
* \brief we always used NodePtr for a reference pointer
* to the node, so this alias can be changed in case we need
* even faster graph composition than 3M ops/sec.
* to the node, so this alias can be changed in case.
*
* By default, NodePtr is a std::shared_ptr of node
*/
......
......@@ -14,7 +14,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
std::vector<size_t> inputs_rptr{0}, control_rptr{0};
DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr]
(const std::shared_ptr<nnvm::Node>& n) {
(const NodePtr& n) {
CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
uint32_t nid = static_cast<uint32_t>(nodes_.size());
// nodes_
......
......@@ -12,7 +12,7 @@ Node::~Node() {
// explicit deletion via DFS
// this is used to avoid stackoverflow caused by chain of deletions
std::vector<Node*> stack{this};
std::vector<std::shared_ptr<Node> > to_delete;
std::vector<NodePtr> to_delete;
while (!stack.empty()) {
Node* n = stack.back();
stack.pop_back();
......@@ -37,8 +37,6 @@ Node::~Node() {
}
NodePtr Node::Create() {
// NOTE: possible change to thread local memory pool
// via std::allocate_shared instead for faster allocation.
return std::make_shared<Node>();
}
......
......@@ -35,7 +35,7 @@ void test_node_speed() {
auto add = nnvm::Op::Get("add");
double tstart = dmlc::GetTime();
size_t rep = 1000;
size_t n = 100;
size_t n = 1000;
for (size_t t = 0; t < rep; ++t) {
nnvm::Symbol s = nnvm::Symbol::CreateVariable("x");
for (size_t i = 0; i < n; ++i) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment