Commit 2da23bd8 by Pedro Larroy Committed by Tianqi Chen

Optimize move semantics of NodeEntry reducing copies of shared_ptr which causes…

Optimize move semantics of NodeEntry reducing copies of shared_ptr which causes atomic contention (#2576)
parent 390b7445
...@@ -30,6 +30,18 @@ using NodePtr = std::shared_ptr<Node>; ...@@ -30,6 +30,18 @@ using NodePtr = std::shared_ptr<Node>;
/*! \brief an entry that represents output data from a node */ /*! \brief an entry that represents output data from a node */
struct NodeEntry { struct NodeEntry {
NodeEntry(NodePtr node, uint32_t index, uint32_t version):
node(std::move(node)),
index(index),
version(version)
{}
NodeEntry():
node(),
index(),
version()
{}
/*! \brief the source node of this data */ /*! \brief the source node of this data */
NodePtr node; NodePtr node;
/*! \brief index of output from the source. */ /*! \brief index of output from the source. */
...@@ -113,6 +125,11 @@ struct NodeAttrs { ...@@ -113,6 +125,11 @@ struct NodeAttrs {
*/ */
class NNVM_DLL Node { class NNVM_DLL Node {
public: public:
Node() = default;
Node(const Op* op, const std::string& name) {
this->attrs.op = op;
this->attrs.name = name;
}
/*! \brief The attributes in the node. */ /*! \brief The attributes in the node. */
NodeAttrs attrs; NodeAttrs attrs;
/*! \brief inputs to this node */ /*! \brief inputs to this node */
...@@ -142,7 +159,10 @@ class NNVM_DLL Node { ...@@ -142,7 +159,10 @@ class NNVM_DLL Node {
* \brief create a new empty shared_ptr of Node. * \brief create a new empty shared_ptr of Node.
* \return a created empty node. * \return a created empty node.
*/ */
static NodePtr Create(); template<class ...Args>
static NodePtr Create(Args&&... args) {
return std::make_shared<Node>(std::forward<Args>(args)...);
}
}; };
/*! /*!
...@@ -167,13 +187,14 @@ inline NodeEntry MakeNode( ...@@ -167,13 +187,14 @@ inline NodeEntry MakeNode(
p->attrs.op->attr_parser(&(p->attrs)); p->attrs.op->attr_parser(&(p->attrs));
} }
p->inputs = std::move(inputs); p->inputs = std::move(inputs);
return NodeEntry{p, 0, 0}; return NodeEntry(p, 0, 0);
} }
// implementation of functions. // implementation of functions.
inline const Op* Node::op() const { inline const Op* Node::op() const {
return this->attrs.op; return this->attrs.op;
} }
inline bool Node::is_variable() const { inline bool Node::is_variable() const {
return this->op() == nullptr; return this->op() == nullptr;
} }
......
...@@ -37,8 +37,4 @@ Node::~Node() { ...@@ -37,8 +37,4 @@ Node::~Node() {
} }
} }
NodePtr Node::Create() {
return std::make_shared<Node>();
}
} // namespace nnvm } // namespace nnvm
...@@ -601,8 +601,8 @@ Symbol Symbol::CreateFunctor(const Op* op, ...@@ -601,8 +601,8 @@ Symbol Symbol::CreateFunctor(const Op* op,
if (fnum_vis_output.count(n->op())) { if (fnum_vis_output.count(n->op())) {
nout = fnum_vis_output[n->op()](n->attrs); nout = fnum_vis_output[n->op()](n->attrs);
} }
for (uint32_t i = 0; i < nout; ++i) { for (size_t i = 0; i < nout; i++) {
s.outputs.emplace_back(NodeEntry{n, i, 0}); s.outputs.emplace_back(n, i, 0);
} }
return s; return s;
} }
...@@ -618,7 +618,7 @@ Symbol Symbol::CreateFunctor(const NodeAttrs& attrs) { ...@@ -618,7 +618,7 @@ Symbol Symbol::CreateFunctor(const NodeAttrs& attrs) {
nout = fnum_vis_output[n->op()](n->attrs); nout = fnum_vis_output[n->op()](n->attrs);
} }
for (uint32_t i = 0; i < nout; ++i) { for (uint32_t i = 0; i < nout; ++i) {
s.outputs.emplace_back(NodeEntry{n, i, 0}); s.outputs.emplace_back(n, i, 0);
} }
return s; return s;
} }
...@@ -633,7 +633,7 @@ Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) { ...@@ -633,7 +633,7 @@ Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {
Symbol Symbol::CreateVariable(const std::string& name) { Symbol Symbol::CreateVariable(const std::string& name) {
Symbol s; Symbol s;
s.outputs.emplace_back(NodeEntry{CreateVariableNode(name), 0, 0}); s.outputs.emplace_back(CreateVariableNode(name), 0, 0);
return s; return s;
} }
......
...@@ -114,10 +114,10 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { ...@@ -114,10 +114,10 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) {
nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request); nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request);
tnode->attrs.name = idx[e.node_id].source->attrs.name + "_" + request.name(); tnode->attrs.name = idx[e.node_id].source->attrs.name + "_" + request.name();
tnode->inputs.emplace_back(new_node->inputs[i]); tnode->inputs.emplace_back(new_node->inputs[i]);
nnvm::NodeEntry tnode_output{tnode, 0, 0}; nnvm::NodeEntry tnode_output(std::move(tnode), 0, 0);
new_node->inputs[i] = tnode_output; new_node->inputs[i] = tnode_output;
// layout produced by LayoutTransformNode // layout produced by LayoutTransformNode
new_layouts[tnode.get()] = {request}; new_layouts[tnode_output.node.get()] = {request};
} else if (!produce.defined()) { } else if (!produce.defined()) {
// do reverse infer // do reverse infer
new_layouts[in.get()][e.index] = request; new_layouts[in.get()][e.index] = request;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment