Commit 7c3d18c6 by Tianqi Chen

Enable copy on write in graph attrs (#31)

* [INFER] Enhance backward op policy

* [SYMBOL] add list inputs

* relax graph attr to enable copy-on-write
Parent: de076999
@@ -30,12 +30,17 @@ class Graph {
   std::vector<NodeEntry> outputs;
   /*!
    * \brief attributes of a graph
-   * Each attribute is immutable,
-   * and can be shared across multiple instances of graph.
+   * Note that attributes are shared pointers and can be shared across graphs.
+   *
+   * It is highly recommended to keep each attribute immutable.
+   * It is also safe to implement copy-on-write semantics:
+   * copy when shared_ptr.unique() is not true, and reuse the original space
+   * when shared_ptr.unique() is true.
    */
-  std::unordered_map<std::string, std::shared_ptr<const any> > attrs;
+  std::unordered_map<std::string, std::shared_ptr<any> > attrs;
   /*!
-   * \brief Get the attribute from attrs.
+   * \brief Get the immutable attribute from attrs.
    * \param attr_name the name of the attribute
    * \return the reference to corresponding attribute
    * \tparam T the type of the attribute.
@@ -43,6 +48,17 @@ class Graph {
   template<typename T>
   inline const T& GetAttr(const std::string& attr_name);
   /*!
+   * \brief Get a moved copy of the attribute, implementing copy-on-write semantics.
+   *  The content is moved out if the reference count of the shared_ptr is 1,
+   *  and copied otherwise. The attribute is erased from attrs after the call.
+   *
+   * \param attr_name the name of the attribute
+   * \return a new copy of the corresponding attribute.
+   * \tparam T the type of the attribute.
+   */
+  template<typename T>
+  inline T MoveCopyAttr(const std::string& attr_name);
+  /*!
    * \brief get an indexed graph of the current graph; if it does not exist, create it on demand
    * \return The indexed graph.
    * \sa IndexedGraph
@@ -200,6 +216,20 @@ inline const T& Graph::GetAttr(const std::string& attr_name) {
   return nnvm::get<T>(*it->second);
 }
 
+template<typename T>
+inline T Graph::MoveCopyAttr(const std::string& attr_name) {
+  auto it = attrs.find(attr_name);
+  CHECK(it != attrs.end())
+      << "Cannot find attribute " << attr_name << " in the graph";
+  std::shared_ptr<any> sptr = it->second;
+  attrs.erase(it);
+  if (sptr.unique()) {
+    return std::move(nnvm::get<T>(*sptr));
+  } else {
+    return nnvm::get<T>(*sptr);
+  }
+}
+
 template <typename GNode, typename HashType,
           typename FVisit, typename HashFunc,
           typename InDegree, typename GetInput>
......
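The copy-on-write contract works as follows: a pass that wants to mutate a graph attribute moves it out with MoveCopyAttr, edits the private copy, and stores it back as a fresh shared_ptr. A minimal sketch of such a caller, assuming the "device" attribute holds a std::vector<int> as in the PlaceDevice change below (the pass itself is hypothetical, not part of this commit):

    // Hypothetical pass illustrating the copy-on-write contract of MoveCopyAttr.
    #include <nnvm/graph.h>
    #include <algorithm>
    #include <memory>
    #include <vector>

    nnvm::Graph AssignAllToDevice0(nnvm::Graph src) {
      // MoveCopyAttr moves the vector out when src is the sole owner of the
      // attribute and copies it when the attribute is shared with another
      // graph; either way the result is private to this pass.
      std::vector<int> device = src.MoveCopyAttr<std::vector<int> >("device");
      std::fill(device.begin(), device.end(), 0);  // mutate the private copy
      // store the result back as a fresh shared_ptr<any>
      src.attrs["device"] = std::make_shared<nnvm::any>(std::move(device));
      return src;
    }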
@@ -82,17 +82,18 @@ using FInferShape = FInferNodeEntryAttr<TShape>;
 using FInferType = FInferNodeEntryAttr<int>;
 
 /*!
- * \brief Whether this op is an explicit backward operator
+ * \brief Whether this op is an explicit backward operator,
+ *  and the correspondence of each of its outputs to an input.
  *
- * If TIsBackwardOp is set to be true:
+ * If FBackwardOutToInIndex exists:
  *  - The first control_deps of the node points to the corresponding forward operator.
- *  - The outputs operator corresponds to exactly inputs of forward op one by one.
+ *  - The k-th output corresponds to the FBackwardOutToInIndex()[k]-th input of the forward op.
  *
- * \note Register under "TIsBackwardOp", default to false.
+ * \note Register under "FBackwardOutToInIndex".
  *
  * This enables easier shape/type inference for backward operators for slice and reduction.
  */
-using TIsBackwardOp = bool;
+using FBackwardOutToInIndex = std::function<
+    std::vector<uint32_t> (const NodeAttrs& attrs)>;
 
 /*!
  * \brief Get possible inplace options.
......
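For instance, a backward operator whose single output is the gradient of the forward op's first input would register the mapping {0}. A hedged sketch of such a registration (the op name and its input/output arity are illustrative, not part of this commit):

    // Illustrative registration; op name "_slice_backward" is hypothetical.
    NNVM_REGISTER_OP(_slice_backward)
    .set_num_inputs(1)
    .set_num_outputs(1)
    .set_attr<FBackwardOutToInIndex>("FBackwardOutToInIndex",
        [](const NodeAttrs& attrs) {
          // output 0 of the backward op maps to input 0 of the forward slice
          return std::vector<uint32_t>{0};
        });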
@@ -63,6 +63,15 @@ class Symbol {
    */
   Symbol operator[] (size_t index) const;
   /*!
+   * \brief List the input variable nodes.
+   * \param option The option to list the arguments.
+   *
+   *  The positions in the returned list correspond to the calling positions in operator().
+   * \return the list of inputs of this symbol; they can be either named or unnamed (empty string).
+   * \sa ListInputOption
+   */
+  std::vector<NodePtr> ListInputs(ListInputOption option) const;
+  /*!
    * \brief List the input names.
    * \param option The options to list the arguments.
    *
......
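A usage sketch, assuming the ListInputOption enum (kAll, kReadOnlyArgs, kAuxiliaryStates) is exposed on Symbol as the implementation below suggests, and that sym is an existing Symbol:

    // List the variable nodes themselves instead of just their names.
    std::vector<nnvm::NodePtr> args = sym.ListInputs(nnvm::Symbol::kReadOnlyArgs);
    std::vector<nnvm::NodePtr> aux  = sym.ListInputs(nnvm::Symbol::kAuxiliaryStates);
    for (const nnvm::NodePtr& n : args) {
      LOG(INFO) << "read-only argument: " << n->attrs.name;
    }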
@@ -233,7 +233,7 @@ class Tuple {
       return is;
     }
   }
-  index_t idx;
+  ValueType idx;
   std::vector<ValueType> tmp;
   while (is >> idx) {
     tmp.push_back(idx);
......
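The idx type change matters when ValueType is wider than index_t: reading each element into an index_t temporary would silently truncate large values. A hedged illustration (assuming index_t is a 32-bit unsigned type, as it typically is in this codebase):

    // Before the fix, each value was parsed into an index_t, truncating
    // 64-bit entries; with ValueType idx the value survives intact.
    std::istringstream is("4294967296 1");
    nnvm::Tuple<int64_t> t;
    is >> t;  // t == (4294967296, 1) with this fix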
@@ -180,37 +180,46 @@ Symbol Symbol::operator[] (size_t index) const {
   }
 }
 
-std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
-  std::vector<std::string> ret;
+std::vector<NodePtr> Symbol::ListInputs(ListInputOption option) const {
+  std::vector<NodePtr> ret;
   if (option == kAll) {
     DFSVisit(this->outputs, [&ret](const NodePtr &node) {
         if (node->is_variable()) {
-          ret.push_back(node->attrs.name);
+          ret.push_back(node);
         }
       });
   } else {
     std::unordered_set<Node*> mutable_set;
-    std::vector<Node*> vlist;
+    std::vector<NodePtr> vlist;
     static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
     DFSVisit(this->outputs, [&ret, &mutable_set, &vlist](const NodePtr &node) {
         if (node->is_variable()) {
-          vlist.push_back(node.get());
+          vlist.push_back(node);
         } else if (fmutate_inputs.count(node->op())) {
           for (uint32_t i : fmutate_inputs[node->op()](node->attrs)) {
             mutable_set.insert(node->inputs[i].node.get());
          }
        }
      });
-    for (Node* node : vlist) {
-      if ((option == kReadOnlyArgs && mutable_set.count(node) == 0) ||
-          (option == kAuxiliaryStates && mutable_set.count(node) != 0)) {
-        ret.push_back(node->attrs.name);
+    for (const NodePtr& node : vlist) {
+      if ((option == kReadOnlyArgs && mutable_set.count(node.get()) == 0) ||
+          (option == kAuxiliaryStates && mutable_set.count(node.get()) != 0)) {
+        ret.emplace_back(node);
       }
     }
   }
   return ret;
 }
 
+std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
+  std::vector<NodePtr> inputs = ListInputs(option);
+  std::vector<std::string> ret(inputs.size());
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    ret[i] = inputs[i]->attrs.name;
+  }
+  return ret;
+}
+
 std::vector<std::string> Symbol::ListOutputNames() const {
   static auto& flist_ouputs = Op::GetAttr<FListOutputNames>("FListOutputNames");
   std::vector<std::string> ret;
......
@@ -24,8 +24,8 @@ Graph InferAttr(Graph &&ret,
   const IndexedGraph& idx = ret.indexed_graph();
   static auto& finfer_shape =
       Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
-  static auto& is_backward =
-      Op::GetAttr<TIsBackwardOp>("TIsBackwardOp");
+  static auto& backward_map =
+      Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex");
   // reshape shape vector
   AttrVector rshape(idx.num_node_entries(), def_value);
@@ -82,16 +82,19 @@ Graph InferAttr(Graph &&ret,
       for (uint32_t i = 0; i < num_outputs; ++i) {
         rshape[idx.entry_id(nid, i)] = oshape[i];
       }
-    } else if (is_backward.get(inode.source->op(), false)) {
+    } else if (backward_map.count(inode.source->op())) {
       // backward operator inference.
       CHECK_GE(inode.control_deps.size(), 1)
           << "BackwardOp need to have control_deps to its forward op";
       const auto& fnode = idx[inode.control_deps[0]];
-      CHECK_EQ(fnode.inputs.size(), num_outputs)
-          << "BackwardOp need to correspond to the forward node";
+      std::vector<uint32_t> out_map =
+          backward_map[inode.source->op()](inode.source->attrs);
       bool known = true;
-      for (size_t i = 0; i < fnode.inputs.size(); ++i) {
-        rshape[idx.entry_id(nid, i)] = rshape[idx.entry_id(fnode.inputs[i])];
+      for (size_t i = 0; i < out_map.size(); ++i) {
+        uint32_t in_id = out_map[i];
+        CHECK_LT(in_id, fnode.inputs.size());
+        rshape[idx.entry_id(nid, i)] =
+            rshape[idx.entry_id(fnode.inputs[in_id])];
         if (fis_none(rshape[idx.entry_id(nid, i)])) known = false;
       }
       num_unknown += !known;
......
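The loop above just forwards known attributes from the forward node's inputs to the backward node's outputs through out_map. The same indexing as a standalone sketch (function name and signature are hypothetical; TShape is the shape type used throughout nnvm):

    // For a backward node, output k takes the shape of forward input out_map[k].
    std::vector<TShape> InferBackwardOutShapes(
        const std::vector<TShape>& fwd_in_shapes,
        const std::vector<uint32_t>& out_map) {
      std::vector<TShape> out_shapes(out_map.size());
      for (size_t k = 0; k < out_map.size(); ++k) {
        CHECK_LT(out_map[k], fwd_in_shapes.size());
        out_shapes[k] = fwd_in_shapes[out_map[k]];
      }
      return out_shapes;
    }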
@@ -12,6 +12,7 @@ namespace nnvm {
 namespace pass {
 namespace {
 
 // simple logic to place device according to device_group hint
+// insert copy node when there is a cross-device data flow
 Graph PlaceDevice(Graph src) {
@@ -21,13 +22,20 @@ Graph PlaceDevice(Graph src) {
       << "Need graph attribute \"device_assign_map\" in PlaceDevice";
   CHECK_NE(src.attrs.count("device_copy_op"), 0)
       << "Need graph attribute \"device_copy_op\" in PlaceDevice";
   std::string device_group_attr_key = src.GetAttr<std::string>("device_group_attr_key");
   const Op* copy_op = Op::Get(src.GetAttr<std::string>("device_copy_op"));
   auto& device_assign_map = src.GetAttr<DeviceAssignMap>("device_assign_map");
   const IndexedGraph& idx = src.indexed_graph();
-  DeviceVector device(idx.num_nodes(), -1);
+  DeviceVector device;
+  // copy-on-write semantics
+  if (src.attrs.count("device") != 0) {
+    device = src.MoveCopyAttr<DeviceVector>("device");
+    CHECK_EQ(device.size(), idx.num_nodes());
+  } else {
+    device.resize(idx.num_nodes(), -1);
+  }
 
   // forward pass
   for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
     const auto& inode = idx[nid];
......
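With this change a caller can seed a partial device assignment through the "device" attribute, and the pass reuses or copies it as ownership dictates. A hedged usage sketch (DeviceVector is assumed to be std::vector<int>, and the pass-invocation call is an assumption about the pass API, not shown in this diff):

    nnvm::Graph g;
    g.outputs = sym.outputs;  // sym is an existing Symbol
    nnvm::DeviceVector device(g.indexed_graph().num_nodes(), -1);
    device[0] = 1;  // pin node 0 to device 1; -1 means unassigned
    g.attrs["device"] = std::make_shared<nnvm::any>(std::move(device));
    // device_group_attr_key, device_assign_map, device_copy_op set as required above
    g = nnvm::ApplyPass(std::move(g), "PlaceDevice");  // assumed pass API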
@@ -12,11 +12,11 @@ namespace dmlc {
 namespace json {
 
 // overload handler for shared ptr
 template<>
-struct Handler<std::shared_ptr<const any> > {
-  inline static void Write(JSONWriter *writer, const std::shared_ptr<const any> &data) {
+struct Handler<std::shared_ptr<any> > {
+  inline static void Write(JSONWriter *writer, const std::shared_ptr<any> &data) {
     writer->Write(*data);
   }
-  inline static void Read(JSONReader *reader, std::shared_ptr<const any> *data) {
+  inline static void Read(JSONReader *reader, std::shared_ptr<any> *data) {
     any v;
     reader->Read(&v);
     *data = std::make_shared<any>(std::move(v));
@@ -131,7 +131,7 @@ struct JSONGraph {
   std::vector<uint32_t> arg_nodes;
   std::vector<uint32_t> node_row_ptr;
   std::vector<JSONNode::Entry> heads;
-  std::unordered_map<std::string, std::shared_ptr<const any> > attrs;
+  std::unordered_map<std::string, std::shared_ptr<any> > attrs;
 
   void Save(dmlc::JSONWriter *writer) const {
     writer->BeginObject();
......
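Since graph attributes are now shared_ptr<any> rather than shared_ptr<const any>, the JSON handler reads and writes the same map the passes mutate. A heavily hedged sketch of how the overloaded handler dispatches (the contained type must already be registered for any-to-JSON conversion via dmlc's AnyJSONManager; that registration is an assumption about the surrounding code):

    // Write a shared attribute pointer directly through dmlc JSON.
    std::shared_ptr<nnvm::any> attr =
        std::make_shared<nnvm::any>(std::string("v0.1"));
    std::ostringstream os;
    dmlc::JSONWriter writer(&os);
    writer.Write(attr);  // dispatches to Handler<std::shared_ptr<any> >::Write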