Commit ea8d2292 by Tianqi Chen

Rename shared_ptr<Node> to NodePtr (#8)

parent db9a9a79
......@@ -18,10 +18,19 @@ namespace nnvm {
// Forward declare node.
class Node;
/*!
* \brief we always used NodePtr for a reference pointer
* to the node, so this alias can be changed in case we need
* even faster graph composition than 3M ops/sec.
*
* By default, NodePtr is a std::shared_ptr of node
*/
using NodePtr = std::shared_ptr<Node>;
/*! \brief an entry that represents output data from a node */
struct NodeEntry {
/*! \brief the source node of this data */
std::shared_ptr<Node> node;
NodePtr node;
/*! \brief index of output from the source. */
uint32_t index;
/*!
......@@ -66,7 +75,7 @@ class Node {
* \brief Optional control flow dependencies
* Gives operation must be performed before this operation.
*/
std::vector<std::shared_ptr<Node> > control_deps;
std::vector<NodePtr> control_deps;
/*! \brief The attributes in the node. */
NodeAttrs attrs;
/*! \brief destructor of node */
......@@ -85,7 +94,7 @@ class Node {
* \brief create a new empty shared_ptr of Node.
* \return a created empty node.
*/
static std::shared_ptr<Node> Create();
static NodePtr Create();
};
// implementation of functions.
......
......@@ -24,7 +24,7 @@ Node::~Node() {
e.node.reset();
}
}
for (std::shared_ptr<Node>& sp : n->control_deps) {
for (NodePtr& sp : n->control_deps) {
if (sp.unique()) {
stack.push_back(sp.get());
} else {
......@@ -36,7 +36,7 @@ Node::~Node() {
}
}
std::shared_ptr<Node> Node::Create() {
NodePtr Node::Create() {
// NOTE: possible change to thread local memory pool
// via std::allocate_shared instead for faster allocation.
return std::make_shared<Node>();
......
......@@ -18,8 +18,8 @@ struct VariableParam {
uint32_t version{0};
};
std::shared_ptr<Node> CreateVariableNode(const std::string& name) {
std::shared_ptr<Node> n = Node::Create();
NodePtr CreateVariableNode(const std::string& name) {
NodePtr n = Node::Create();
n->op = nullptr;
n->attrs.name = name;
n->attrs.parsed = VariableParam();
......@@ -95,10 +95,10 @@ inline bool IsAtomic(const std::vector<NodeEntry>& outputs) {
// public functions
Symbol Symbol::Copy() const {
std::unordered_map<Node*, std::shared_ptr<Node> > old_new;
std::unordered_map<Node*, NodePtr> old_new;
// use DFSVisit to copy all the nodes
DFSVisit(this->outputs, [&old_new](const std::shared_ptr<Node>& node) {
std::shared_ptr<Node> np = Node::Create();
DFSVisit(this->outputs, [&old_new](const NodePtr& node) {
NodePtr np = Node::Create();
np->op = node->op;
np->attrs = node->attrs;
old_new[node.get()] = std::move(np);
......@@ -109,7 +109,7 @@ Symbol Symbol::Copy() const {
Node *ptr = e.node.get();
kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version});
}
for (const std::shared_ptr<Node>& p : kv.first->control_deps) {
for (const NodePtr& p : kv.first->control_deps) {
kv.second->control_deps.emplace_back(old_new[p.get()]);
}
}
......@@ -131,7 +131,7 @@ void Symbol::Print(std::ostream &os) const {
os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name
<< '(' << outputs[i].index << ")\n";
}
DFSVisit(this->outputs, [&os](const std::shared_ptr<Node>& node) {
DFSVisit(this->outputs, [&os](const NodePtr& node) {
if (node->is_variable()) {
os << "Variable:" << node->attrs.name << '\n';
} else {
......@@ -179,7 +179,7 @@ Symbol Symbol::operator[] (size_t index) const {
std::vector<std::string> Symbol::ListArguments() const {
std::vector<std::string> ret;
DFSVisit(this->outputs, [&ret](const std::shared_ptr<Node> &node) {
DFSVisit(this->outputs, [&ret](const NodePtr &node) {
if (node->is_variable()) {
ret.push_back(node->attrs.name);
}
......@@ -295,7 +295,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
std::unordered_map<Node *, const NodeEntry*> replace_map;
// replace map stores the existing replacement plan for arguments node
auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, &replace_map]
(const std::shared_ptr<Node> &node) {
(const NodePtr &node) {
if (node->is_variable()) {
if (arg_counter < args.size()) {
replace_map[node.get()] = &(args[arg_counter]->outputs[0]);
......@@ -316,7 +316,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
std::vector<Node*> update_nodes;
std::vector<std::pair<NodeEntry*, const NodeEntry*> > replace_plan;
auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes]
(const std::shared_ptr<Node> &node) {
(const NodePtr &node) {
// visit all the childs, find possible replacement
bool repl = false;
for (size_t i = 0; i < node->inputs.size(); ++i) {
......@@ -368,7 +368,7 @@ void Symbol::AddControlDeps(const Symbol& src) {
Symbol Symbol::GetInternals() const {
Symbol ret;
DFSVisit(this->outputs, [&ret](const std::shared_ptr<Node>& node) {
DFSVisit(this->outputs, [&ret](const NodePtr& node) {
Node* n = node.get();
if (n->is_variable()) {
// grab version from variable.
......@@ -421,7 +421,7 @@ bool Symbol::GetAttr(const std::string& key, std::string* out) const {
std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption option) const {
if (option == kRecursive) {
std::unordered_map<std::string, std::string> ret;
DFSVisit(this->outputs, [&ret](const std::shared_ptr<Node>& n) {
DFSVisit(this->outputs, [&ret](const NodePtr& n) {
for (const auto& it : n->attrs.dict) {
ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second;
}
......@@ -435,7 +435,7 @@ std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption op
Symbol Symbol::CreateFunctor(const Op* op,
std::unordered_map<std::string, std::string>&& attrs) {
Symbol s;
std::shared_ptr<Node> n = Node::Create();
NodePtr n = Node::Create();
n->op = op;
n->attrs.dict = std::move(attrs);
if (n->op->attr_parser != nullptr) {
......
......@@ -21,7 +21,7 @@ inline T get_with_default(const std::unordered_map<Node*, T> &map,
Graph OrderMutation(const Graph& src) {
std::unordered_map<Node*, std::vector<NodeEntry> > version_hist;
DFSVisit(src.outputs, [&version_hist](const std::shared_ptr<Node>& n) {
DFSVisit(src.outputs, [&version_hist](const NodePtr& n) {
for (const NodeEntry& e : n->inputs) {
if (e.node->is_variable()) {
if (e.version != 0 && version_hist.count(e.node.get()) == 0) {
......@@ -33,8 +33,8 @@ Graph OrderMutation(const Graph& src) {
// no mutation happens, everything if fine.
if (version_hist.size() == 0) return src;
// start preparing for remapping the nodes.
std::unordered_map<Node*, std::shared_ptr<Node> > old_new;
auto prepare = [&version_hist, &old_new] (const std::shared_ptr<Node>& n) {
std::unordered_map<Node*, NodePtr> old_new;
auto prepare = [&version_hist, &old_new] (const NodePtr& n) {
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
bool need_repl = false;
for (size_t i = 0; i < n->inputs.size(); ++i) {
......@@ -52,11 +52,11 @@ Graph OrderMutation(const Graph& src) {
if (old_new.count(e.node.get()) != 0) need_repl = true;
}
}
for (const std::shared_ptr<Node>& p : n->control_deps) {
for (const NodePtr& p : n->control_deps) {
if (old_new.count(p.get()) != 0) need_repl = true;
}
if (need_repl) {
std::shared_ptr<Node> np = Node::Create();
NodePtr np = Node::Create();
np->op = n->op;
np->attrs = n->attrs;
old_new[n.get()] = std::move(np);
......@@ -84,7 +84,7 @@ Graph OrderMutation(const Graph& src) {
kv.second->inputs.push_back(e);
}
}
for (const std::shared_ptr<Node>& p : kv.first->control_deps) {
for (const NodePtr& p : kv.first->control_deps) {
kv.second->control_deps.emplace_back(
get_with_default(old_new, p.get(), p));
}
......
......@@ -32,7 +32,7 @@ struct JSONNode {
// the node entry structure in serialized format
typedef std::pair<uint32_t, uint32_t> Entry;
// pointer to the graph node
std::shared_ptr<Node> node;
NodePtr node;
// inputs
std::vector<Entry> inputs;
// control flow dependencies
......@@ -159,7 +159,7 @@ Graph LoadJSON(const Graph& src) {
Graph SaveJSON(const Graph& src) {
JSONGraph jgraph;
std::unordered_map<Node*, uint32_t> node2index;
DFSVisit(src.outputs, [&node2index, &jgraph](const std::shared_ptr<Node>& n) {
DFSVisit(src.outputs, [&node2index, &jgraph](const NodePtr& n) {
uint32_t nid = static_cast<uint32_t>(jgraph.nodes.size());
node2index[n.get()] = nid;
if (n->is_variable()) {
......@@ -172,7 +172,7 @@ Graph SaveJSON(const Graph& src) {
jnode.inputs.emplace_back(
std::make_pair(node2index.at(e.node.get()), e.index));
}
for (const std::shared_ptr<Node>& c : n->control_deps) {
for (const NodePtr& c : n->control_deps) {
jnode.control_deps.push_back(node2index.at(c.get()));
}
jgraph.nodes.emplace_back(std::move(jnode));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment