Commit 2da23bd8 by Pedro Larroy, committed by Tianqi Chen

Optimize move semantics of NodeEntry, reducing copies of shared_ptr that cause atomic contention (#2576)
parent 390b7445
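For context: NodeEntry holds a std::shared_ptr<Node> (NodePtr), and every copy of a shared_ptr performs an atomic reference-count update that contends across threads, while a move only transfers the pointer. A minimal stand-alone sketch (plain C++, not nnvm code) of the difference this commit exploits:

    #include <memory>
    #include <utility>

    struct Node {};  // stand-in type, not nnvm::Node
    using NodePtr = std::shared_ptr<Node>;

    int main() {
      NodePtr a = std::make_shared<Node>();
      NodePtr b = a;             // copy: atomic use-count goes 1 -> 2 (the contended operation)
      NodePtr c = std::move(a);  // move: the pointer is transferred, no count update; a is left empty
      return 0;
    }
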
@@ -30,6 +30,18 @@ using NodePtr = std::shared_ptr<Node>;
 /*! \brief an entry that represents output data from a node */
 struct NodeEntry {
+  NodeEntry(NodePtr node, uint32_t index, uint32_t version):
+    node(std::move(node)),
+    index(index),
+    version(version)
+  {}
+  NodeEntry():
+    node(),
+    index(),
+    version()
+  {}
   /*! \brief the source node of this data */
   NodePtr node;
   /*! \brief index of output from the source. */
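The new three-argument constructor takes the NodePtr by value and moves it into the member, so a caller passing an rvalue pays no reference-count update at all, and a caller passing an lvalue pays exactly one copy. A hypothetical usage sketch, assuming the header from this commit:

    #include <cstdint>
    #include <utility>
    #include <nnvm/node.h>

    // Hypothetical helper (not part of nnvm): builds an output entry without an
    // extra shared_ptr copy by moving the NodePtr into the new constructor.
    nnvm::NodeEntry MakeOutput(nnvm::NodePtr n, uint32_t index) {
      return nnvm::NodeEntry(std::move(n), index, /*version=*/0);
    }
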
@@ -113,6 +125,11 @@ struct NodeAttrs {
  */
 class NNVM_DLL Node {
  public:
+  Node() = default;
+  Node(const Op* op, const std::string& name) {
+    this->attrs.op = op;
+    this->attrs.name = name;
+  }
   /*! \brief The attributes in the node. */
   NodeAttrs attrs;
   /*! \brief inputs to this node */
@@ -142,7 +159,10 @@ class NNVM_DLL Node {
    * \brief create a new empty shared_ptr of Node.
    * \return a created empty node.
    */
-  static NodePtr Create();
+  template<class ...Args>
+  static NodePtr Create(Args&&... args) {
+    return std::make_shared<Node>(std::forward<Args>(args)...);
+  }
 };
 /*!
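Create() now forwards its arguments to std::make_shared<Node>, so the op and name can be supplied at construction time in a single allocation instead of being assigned afterwards. A hypothetical usage sketch ("my_op" and the helper name are illustrative only):

    #include <string>
    #include <nnvm/node.h>
    #include <nnvm/op.h>

    nnvm::NodePtr MakeNamedNode(const std::string& name) {
      // "my_op" is an assumed operator name; Op::Get looks up a registered operator.
      const nnvm::Op* op = nnvm::Op::Get("my_op");
      // Forwarded to std::make_shared<Node>(op, name) via the variadic Create().
      return nnvm::Node::Create(op, name);
    }
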
@@ -167,13 +187,14 @@ inline NodeEntry MakeNode(
     p->attrs.op->attr_parser(&(p->attrs));
   }
   p->inputs = std::move(inputs);
-  return NodeEntry{p, 0, 0};
+  return NodeEntry(p, 0, 0);
 }
 // implementation of functions.
 inline const Op* Node::op() const {
   return this->attrs.op;
 }
 inline bool Node::is_variable() const {
   return this->op() == nullptr;
 }
@@ -37,8 +37,4 @@ Node::~Node() {
   }
 }
-NodePtr Node::Create() {
-  return std::make_shared<Node>();
-}
 }  // namespace nnvm
@@ -601,8 +601,8 @@ Symbol Symbol::CreateFunctor(const Op* op,
   if (fnum_vis_output.count(n->op())) {
     nout = fnum_vis_output[n->op()](n->attrs);
   }
-  for (uint32_t i = 0; i < nout; ++i) {
-    s.outputs.emplace_back(NodeEntry{n, i, 0});
+  for (size_t i = 0; i < nout; i++) {
+    s.outputs.emplace_back(n, i, 0);
   }
   return s;
 }
@@ -618,7 +618,7 @@ Symbol Symbol::CreateFunctor(const NodeAttrs& attrs) {
     nout = fnum_vis_output[n->op()](n->attrs);
   }
   for (uint32_t i = 0; i < nout; ++i) {
-    s.outputs.emplace_back(NodeEntry{n, i, 0});
+    s.outputs.emplace_back(n, i, 0);
   }
   return s;
 }
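With the new constructor in place, emplace_back(n, i, 0) builds the NodeEntry directly inside the vector element instead of materializing a temporary NodeEntry{n, i, 0} and moving it in; the shared_ptr is copied once into the constructor parameter (n is still needed for later iterations) and then moved into the member. A hypothetical sketch of the same pattern:

    #include <cstdint>
    #include <vector>
    #include <nnvm/node.h>

    // Hypothetical helper (not part of nnvm): appends one entry per output,
    // constructing each NodeEntry in place inside the vector.
    void AppendOutputs(std::vector<nnvm::NodeEntry>* outputs,
                       const nnvm::NodePtr& n, uint32_t nout) {
      for (uint32_t i = 0; i < nout; ++i) {
        outputs->emplace_back(n, i, 0);
      }
    }
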
@@ -633,7 +633,7 @@ Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {
 Symbol Symbol::CreateVariable(const std::string& name) {
   Symbol s;
-  s.outputs.emplace_back(NodeEntry{CreateVariableNode(name), 0, 0});
+  s.outputs.emplace_back(CreateVariableNode(name), 0, 0);
   return s;
 }
@@ -114,10 +114,10 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) {
         nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request);
         tnode->attrs.name = idx[e.node_id].source->attrs.name + "_" + request.name();
         tnode->inputs.emplace_back(new_node->inputs[i]);
-        nnvm::NodeEntry tnode_output{tnode, 0, 0};
+        nnvm::NodeEntry tnode_output(std::move(tnode), 0, 0);
         new_node->inputs[i] = tnode_output;
         // layout produced by LayoutTransformNode
-        new_layouts[tnode.get()] = {request};
+        new_layouts[tnode_output.node.get()] = {request};
       } else if (!produce.defined()) {
         // do reverse infer
         new_layouts[in.get()][e.index] = request;
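One detail worth noting in the hunk above: after std::move(tnode) the local tnode is a null shared_ptr, which is why the layout map is now keyed by tnode_output.node.get() rather than tnode.get(). A stand-alone sketch (plain C++, not nnvm code) of that moved-from behaviour:

    #include <cassert>
    #include <memory>
    #include <utility>

    struct Node {};  // stand-in type, not nnvm::Node

    int main() {
      std::shared_ptr<Node> tnode = std::make_shared<Node>();
      std::shared_ptr<Node> owner = std::move(tnode);  // tnode is left empty
      assert(tnode.get() == nullptr);                  // using tnode.get() as a key would be wrong
      assert(owner.get() != nullptr);                  // key the map with the surviving owner instead
      return 0;
    }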