Commit ea8d2292 by Tianqi Chen

Rename shared_ptr<Node> to NodePtr (#8)

parent db9a9a79
...@@ -18,10 +18,19 @@ namespace nnvm { ...@@ -18,10 +18,19 @@ namespace nnvm {
// Forward declare node. // Forward declare node.
class Node; class Node;
/*!
* \brief we always used NodePtr for a reference pointer
* to the node, so this alias can be changed in case we need
* even faster graph composition than 3M ops/sec.
*
* By default, NodePtr is a std::shared_ptr of node
*/
using NodePtr = std::shared_ptr<Node>;
/*! \brief an entry that represents output data from a node */ /*! \brief an entry that represents output data from a node */
struct NodeEntry { struct NodeEntry {
/*! \brief the source node of this data */ /*! \brief the source node of this data */
std::shared_ptr<Node> node; NodePtr node;
/*! \brief index of output from the source. */ /*! \brief index of output from the source. */
uint32_t index; uint32_t index;
/*! /*!
...@@ -66,7 +75,7 @@ class Node { ...@@ -66,7 +75,7 @@ class Node {
* \brief Optional control flow dependencies * \brief Optional control flow dependencies
* Gives operation must be performed before this operation. * Gives operation must be performed before this operation.
*/ */
std::vector<std::shared_ptr<Node> > control_deps; std::vector<NodePtr> control_deps;
/*! \brief The attributes in the node. */ /*! \brief The attributes in the node. */
NodeAttrs attrs; NodeAttrs attrs;
/*! \brief destructor of node */ /*! \brief destructor of node */
...@@ -85,7 +94,7 @@ class Node { ...@@ -85,7 +94,7 @@ class Node {
* \brief create a new empty shared_ptr of Node. * \brief create a new empty shared_ptr of Node.
* \return a created empty node. * \return a created empty node.
*/ */
static std::shared_ptr<Node> Create(); static NodePtr Create();
}; };
// implementation of functions. // implementation of functions.
......
...@@ -24,7 +24,7 @@ Node::~Node() { ...@@ -24,7 +24,7 @@ Node::~Node() {
e.node.reset(); e.node.reset();
} }
} }
for (std::shared_ptr<Node>& sp : n->control_deps) { for (NodePtr& sp : n->control_deps) {
if (sp.unique()) { if (sp.unique()) {
stack.push_back(sp.get()); stack.push_back(sp.get());
} else { } else {
...@@ -36,7 +36,7 @@ Node::~Node() { ...@@ -36,7 +36,7 @@ Node::~Node() {
} }
} }
std::shared_ptr<Node> Node::Create() { NodePtr Node::Create() {
// NOTE: possible change to thread local memory pool // NOTE: possible change to thread local memory pool
// via std::allocate_shared instead for faster allocation. // via std::allocate_shared instead for faster allocation.
return std::make_shared<Node>(); return std::make_shared<Node>();
......
...@@ -18,8 +18,8 @@ struct VariableParam { ...@@ -18,8 +18,8 @@ struct VariableParam {
uint32_t version{0}; uint32_t version{0};
}; };
std::shared_ptr<Node> CreateVariableNode(const std::string& name) { NodePtr CreateVariableNode(const std::string& name) {
std::shared_ptr<Node> n = Node::Create(); NodePtr n = Node::Create();
n->op = nullptr; n->op = nullptr;
n->attrs.name = name; n->attrs.name = name;
n->attrs.parsed = VariableParam(); n->attrs.parsed = VariableParam();
...@@ -95,10 +95,10 @@ inline bool IsAtomic(const std::vector<NodeEntry>& outputs) { ...@@ -95,10 +95,10 @@ inline bool IsAtomic(const std::vector<NodeEntry>& outputs) {
// public functions // public functions
Symbol Symbol::Copy() const { Symbol Symbol::Copy() const {
std::unordered_map<Node*, std::shared_ptr<Node> > old_new; std::unordered_map<Node*, NodePtr> old_new;
// use DFSVisit to copy all the nodes // use DFSVisit to copy all the nodes
DFSVisit(this->outputs, [&old_new](const std::shared_ptr<Node>& node) { DFSVisit(this->outputs, [&old_new](const NodePtr& node) {
std::shared_ptr<Node> np = Node::Create(); NodePtr np = Node::Create();
np->op = node->op; np->op = node->op;
np->attrs = node->attrs; np->attrs = node->attrs;
old_new[node.get()] = std::move(np); old_new[node.get()] = std::move(np);
...@@ -109,7 +109,7 @@ Symbol Symbol::Copy() const { ...@@ -109,7 +109,7 @@ Symbol Symbol::Copy() const {
Node *ptr = e.node.get(); Node *ptr = e.node.get();
kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version}); kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version});
} }
for (const std::shared_ptr<Node>& p : kv.first->control_deps) { for (const NodePtr& p : kv.first->control_deps) {
kv.second->control_deps.emplace_back(old_new[p.get()]); kv.second->control_deps.emplace_back(old_new[p.get()]);
} }
} }
...@@ -131,7 +131,7 @@ void Symbol::Print(std::ostream &os) const { ...@@ -131,7 +131,7 @@ void Symbol::Print(std::ostream &os) const {
os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name
<< '(' << outputs[i].index << ")\n"; << '(' << outputs[i].index << ")\n";
} }
DFSVisit(this->outputs, [&os](const std::shared_ptr<Node>& node) { DFSVisit(this->outputs, [&os](const NodePtr& node) {
if (node->is_variable()) { if (node->is_variable()) {
os << "Variable:" << node->attrs.name << '\n'; os << "Variable:" << node->attrs.name << '\n';
} else { } else {
...@@ -179,7 +179,7 @@ Symbol Symbol::operator[] (size_t index) const { ...@@ -179,7 +179,7 @@ Symbol Symbol::operator[] (size_t index) const {
std::vector<std::string> Symbol::ListArguments() const { std::vector<std::string> Symbol::ListArguments() const {
std::vector<std::string> ret; std::vector<std::string> ret;
DFSVisit(this->outputs, [&ret](const std::shared_ptr<Node> &node) { DFSVisit(this->outputs, [&ret](const NodePtr &node) {
if (node->is_variable()) { if (node->is_variable()) {
ret.push_back(node->attrs.name); ret.push_back(node->attrs.name);
} }
...@@ -295,7 +295,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args, ...@@ -295,7 +295,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
std::unordered_map<Node *, const NodeEntry*> replace_map; std::unordered_map<Node *, const NodeEntry*> replace_map;
// replace map stores the existing replacement plan for arguments node // replace map stores the existing replacement plan for arguments node
auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, &replace_map] auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, &replace_map]
(const std::shared_ptr<Node> &node) { (const NodePtr &node) {
if (node->is_variable()) { if (node->is_variable()) {
if (arg_counter < args.size()) { if (arg_counter < args.size()) {
replace_map[node.get()] = &(args[arg_counter]->outputs[0]); replace_map[node.get()] = &(args[arg_counter]->outputs[0]);
...@@ -316,7 +316,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args, ...@@ -316,7 +316,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
std::vector<Node*> update_nodes; std::vector<Node*> update_nodes;
std::vector<std::pair<NodeEntry*, const NodeEntry*> > replace_plan; std::vector<std::pair<NodeEntry*, const NodeEntry*> > replace_plan;
auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes] auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes]
(const std::shared_ptr<Node> &node) { (const NodePtr &node) {
// visit all the childs, find possible replacement // visit all the childs, find possible replacement
bool repl = false; bool repl = false;
for (size_t i = 0; i < node->inputs.size(); ++i) { for (size_t i = 0; i < node->inputs.size(); ++i) {
...@@ -368,7 +368,7 @@ void Symbol::AddControlDeps(const Symbol& src) { ...@@ -368,7 +368,7 @@ void Symbol::AddControlDeps(const Symbol& src) {
Symbol Symbol::GetInternals() const { Symbol Symbol::GetInternals() const {
Symbol ret; Symbol ret;
DFSVisit(this->outputs, [&ret](const std::shared_ptr<Node>& node) { DFSVisit(this->outputs, [&ret](const NodePtr& node) {
Node* n = node.get(); Node* n = node.get();
if (n->is_variable()) { if (n->is_variable()) {
// grab version from variable. // grab version from variable.
...@@ -421,7 +421,7 @@ bool Symbol::GetAttr(const std::string& key, std::string* out) const { ...@@ -421,7 +421,7 @@ bool Symbol::GetAttr(const std::string& key, std::string* out) const {
std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption option) const { std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption option) const {
if (option == kRecursive) { if (option == kRecursive) {
std::unordered_map<std::string, std::string> ret; std::unordered_map<std::string, std::string> ret;
DFSVisit(this->outputs, [&ret](const std::shared_ptr<Node>& n) { DFSVisit(this->outputs, [&ret](const NodePtr& n) {
for (const auto& it : n->attrs.dict) { for (const auto& it : n->attrs.dict) {
ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second; ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second;
} }
...@@ -435,7 +435,7 @@ std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption op ...@@ -435,7 +435,7 @@ std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption op
Symbol Symbol::CreateFunctor(const Op* op, Symbol Symbol::CreateFunctor(const Op* op,
std::unordered_map<std::string, std::string>&& attrs) { std::unordered_map<std::string, std::string>&& attrs) {
Symbol s; Symbol s;
std::shared_ptr<Node> n = Node::Create(); NodePtr n = Node::Create();
n->op = op; n->op = op;
n->attrs.dict = std::move(attrs); n->attrs.dict = std::move(attrs);
if (n->op->attr_parser != nullptr) { if (n->op->attr_parser != nullptr) {
......
...@@ -21,7 +21,7 @@ inline T get_with_default(const std::unordered_map<Node*, T> &map, ...@@ -21,7 +21,7 @@ inline T get_with_default(const std::unordered_map<Node*, T> &map,
Graph OrderMutation(const Graph& src) { Graph OrderMutation(const Graph& src) {
std::unordered_map<Node*, std::vector<NodeEntry> > version_hist; std::unordered_map<Node*, std::vector<NodeEntry> > version_hist;
DFSVisit(src.outputs, [&version_hist](const std::shared_ptr<Node>& n) { DFSVisit(src.outputs, [&version_hist](const NodePtr& n) {
for (const NodeEntry& e : n->inputs) { for (const NodeEntry& e : n->inputs) {
if (e.node->is_variable()) { if (e.node->is_variable()) {
if (e.version != 0 && version_hist.count(e.node.get()) == 0) { if (e.version != 0 && version_hist.count(e.node.get()) == 0) {
...@@ -33,8 +33,8 @@ Graph OrderMutation(const Graph& src) { ...@@ -33,8 +33,8 @@ Graph OrderMutation(const Graph& src) {
// no mutation happens, everything if fine. // no mutation happens, everything if fine.
if (version_hist.size() == 0) return src; if (version_hist.size() == 0) return src;
// start preparing for remapping the nodes. // start preparing for remapping the nodes.
std::unordered_map<Node*, std::shared_ptr<Node> > old_new; std::unordered_map<Node*, NodePtr> old_new;
auto prepare = [&version_hist, &old_new] (const std::shared_ptr<Node>& n) { auto prepare = [&version_hist, &old_new] (const NodePtr& n) {
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput"); static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
bool need_repl = false; bool need_repl = false;
for (size_t i = 0; i < n->inputs.size(); ++i) { for (size_t i = 0; i < n->inputs.size(); ++i) {
...@@ -52,11 +52,11 @@ Graph OrderMutation(const Graph& src) { ...@@ -52,11 +52,11 @@ Graph OrderMutation(const Graph& src) {
if (old_new.count(e.node.get()) != 0) need_repl = true; if (old_new.count(e.node.get()) != 0) need_repl = true;
} }
} }
for (const std::shared_ptr<Node>& p : n->control_deps) { for (const NodePtr& p : n->control_deps) {
if (old_new.count(p.get()) != 0) need_repl = true; if (old_new.count(p.get()) != 0) need_repl = true;
} }
if (need_repl) { if (need_repl) {
std::shared_ptr<Node> np = Node::Create(); NodePtr np = Node::Create();
np->op = n->op; np->op = n->op;
np->attrs = n->attrs; np->attrs = n->attrs;
old_new[n.get()] = std::move(np); old_new[n.get()] = std::move(np);
...@@ -84,7 +84,7 @@ Graph OrderMutation(const Graph& src) { ...@@ -84,7 +84,7 @@ Graph OrderMutation(const Graph& src) {
kv.second->inputs.push_back(e); kv.second->inputs.push_back(e);
} }
} }
for (const std::shared_ptr<Node>& p : kv.first->control_deps) { for (const NodePtr& p : kv.first->control_deps) {
kv.second->control_deps.emplace_back( kv.second->control_deps.emplace_back(
get_with_default(old_new, p.get(), p)); get_with_default(old_new, p.get(), p));
} }
......
...@@ -32,7 +32,7 @@ struct JSONNode { ...@@ -32,7 +32,7 @@ struct JSONNode {
// the node entry structure in serialized format // the node entry structure in serialized format
typedef std::pair<uint32_t, uint32_t> Entry; typedef std::pair<uint32_t, uint32_t> Entry;
// pointer to the graph node // pointer to the graph node
std::shared_ptr<Node> node; NodePtr node;
// inputs // inputs
std::vector<Entry> inputs; std::vector<Entry> inputs;
// control flow dependencies // control flow dependencies
...@@ -159,7 +159,7 @@ Graph LoadJSON(const Graph& src) { ...@@ -159,7 +159,7 @@ Graph LoadJSON(const Graph& src) {
Graph SaveJSON(const Graph& src) { Graph SaveJSON(const Graph& src) {
JSONGraph jgraph; JSONGraph jgraph;
std::unordered_map<Node*, uint32_t> node2index; std::unordered_map<Node*, uint32_t> node2index;
DFSVisit(src.outputs, [&node2index, &jgraph](const std::shared_ptr<Node>& n) { DFSVisit(src.outputs, [&node2index, &jgraph](const NodePtr& n) {
uint32_t nid = static_cast<uint32_t>(jgraph.nodes.size()); uint32_t nid = static_cast<uint32_t>(jgraph.nodes.size());
node2index[n.get()] = nid; node2index[n.get()] = nid;
if (n->is_variable()) { if (n->is_variable()) {
...@@ -172,7 +172,7 @@ Graph SaveJSON(const Graph& src) { ...@@ -172,7 +172,7 @@ Graph SaveJSON(const Graph& src) {
jnode.inputs.emplace_back( jnode.inputs.emplace_back(
std::make_pair(node2index.at(e.node.get()), e.index)); std::make_pair(node2index.at(e.node.get()), e.index));
} }
for (const std::shared_ptr<Node>& c : n->control_deps) { for (const NodePtr& c : n->control_deps) {
jnode.control_deps.push_back(node2index.at(c.get())); jnode.control_deps.push_back(node2index.at(c.get()));
} }
jgraph.nodes.emplace_back(std::move(jnode)); jgraph.nodes.emplace_back(std::move(jnode));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment