Commit 7c3d18c6 by Tianqi Chen

Enable copy on write in graph attrs (#31)

* [INFER] Enhance backward op policy

* [SYMBOL] add list inputs

* relax graph attr to enable copy-on-write
parent de076999
......@@ -30,12 +30,17 @@ class Graph {
std::vector<NodeEntry> outputs;
/*!
* \brief attributes of a graph
* Each attribute is immutable,
* and can be shared across multiple Instance of graph
* Note that attribute is shared pointer and can be shared across graphs.
*
* It is highly recommended to keep each attribute immutable.
* It is also safe to implement copy-on-write semantics.
*
* Copy the content when the shared_ptr is shared with other owners,
* and reuse the original storage when it is the sole owner (unique).
*/
std::unordered_map<std::string, std::shared_ptr<const any> > attrs;
std::unordered_map<std::string, std::shared_ptr<any> > attrs;
/*!
* \brief Get the attribute from attrs.
* \brief Get the immutable attribute from attrs.
* \param attr_name the name of the attribute
* \return the reference to corresponding attribute
* \tparam T the type of the attribute.
......@@ -43,6 +48,17 @@ class Graph {
template<typename T>
inline const T& GetAttr(const std::string& attr_name);
/*!
* \brief Get a moved copy of the attribute, implementing copy-on-write semantics.
* The content is moved out (no copy) if the reference count of the shared_ptr is 1.
* The attribute is erased from attrs after the call.
*
* \param attr_name the name of the attribute
* \return a new copy of the corresponding attribute.
* \tparam T the type of the attribute.
*/
template<typename T>
inline T MoveCopyAttr(const std::string& attr_name);
/*!
* \brief get a indexed graph of current graph, if not exist, create it on demand
* \return The indexed graph.
* \sa IndexedGraph
......@@ -200,6 +216,20 @@ inline const T& Graph::GetAttr(const std::string& attr_name) {
return nnvm::get<T>(*it->second);
}
template<typename T>
inline T Graph::MoveCopyAttr(const std::string& attr_name) {
auto it = attrs.find(attr_name);
CHECK(it != attrs.end())
<< "Cannot find attribute " << attr_name << " in the graph";
std::shared_ptr<any> sptr = it->second;
attrs.erase(it);
if (sptr.unique()) {
return std::move(nnvm::get<T>(*sptr));
} else {
return nnvm::get<T>(*sptr);
}
}
template <typename GNode, typename HashType,
typename FVisit, typename HashFunc,
typename InDegree, typename GetInput>
......
......@@ -82,17 +82,18 @@ using FInferShape = FInferNodeEntryAttr<TShape>;
using FInferType = FInferNodeEntryAttr<int>;
/*!
* \brief Whether this op is an explicit backward operator
* \brief Whether this op is an explicit backward operator,
* and the correspondence of each output to input.
*
* If TIsBackwardOp is set to be true:
* If FBackwardOutToInIndex exists:
* - The first control_deps of the node points to the corresponding forward operator.
* - The outputs operator corresponds to exactly inputs of forward op one by one.
*
* \note Register under "TIsBackwardOp", default to false.
* - The k-th outputs corresponds to the FBackwardOutputToInputIndex()[k]-th input of forward op.
*
* \note Register under "FBackwardOutToInIndex"
* This enables easier shape/type inference for backward operators for slice and reduction.
*/
using TIsBackwardOp = bool;
using FBackwardOutToInIndex = std::function<
std::vector<uint32_t> (const NodeAttrs& attrs)>;
/*!
* \brief Get possible inplace options.
......
......@@ -63,6 +63,15 @@ class Symbol {
*/
Symbol operator[] (size_t index) const;
/*!
* \brief List the input variable nodes
* \param option The options to list the arguments.
*
* The positions in the returned list correspond to the calling positions in operator().
* \return the arguments list of this symbol, they can be either named or unnamed (empty string).
* \sa ListInputOption
*/
std::vector<NodePtr> ListInputs(ListInputOption option) const;
/*!
* \brief List the input names.
* \param option The options to list the arguments.
*
......
......@@ -233,7 +233,7 @@ class Tuple {
return is;
}
}
index_t idx;
ValueType idx;
std::vector<ValueType> tmp;
while (is >> idx) {
tmp.push_back(idx);
......
......@@ -180,37 +180,46 @@ Symbol Symbol::operator[] (size_t index) const {
}
}
std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
std::vector<std::string> ret;
std::vector<NodePtr> Symbol::ListInputs(ListInputOption option) const {
std::vector<NodePtr> ret;
if (option == kAll) {
DFSVisit(this->outputs, [&ret](const NodePtr &node) {
if (node->is_variable()) {
ret.push_back(node->attrs.name);
ret.push_back(node);
}
});
} else {
std::unordered_set<Node*> mutable_set;
std::vector<Node*> vlist;
std::vector<NodePtr> vlist;
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
DFSVisit(this->outputs, [&ret, &mutable_set, &vlist](const NodePtr &node) {
if (node->is_variable()) {
vlist.push_back(node.get());
vlist.push_back(node);
} else if (fmutate_inputs.count(node->op())) {
for (uint32_t i : fmutate_inputs[node->op()](node->attrs)){
mutable_set.insert(node->inputs[i].node.get());
}
}
});
for (Node* node : vlist) {
if ((option == kReadOnlyArgs && mutable_set.count(node) == 0) ||
(option == kAuxiliaryStates && mutable_set.count(node) != 0)) {
ret.push_back(node->attrs.name);
for (const NodePtr& node : vlist) {
if ((option == kReadOnlyArgs && mutable_set.count(node.get()) == 0) ||
(option == kAuxiliaryStates && mutable_set.count(node.get()) != 0)) {
ret.emplace_back(node);
}
}
}
return ret;
}
// List the names of the input variable nodes selected by option.
// Thin wrapper over ListInputs: projects each returned node to its name,
// preserving order (which matches the calling positions in operator()).
std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
  const std::vector<NodePtr> inputs = ListInputs(option);
  std::vector<std::string> names;
  names.reserve(inputs.size());
  for (const NodePtr& node : inputs) {
    names.push_back(node->attrs.name);
  }
  return names;
}
std::vector<std::string> Symbol::ListOutputNames() const {
static auto& flist_ouputs = Op::GetAttr<FListOutputNames>("FListOutputNames");
std::vector<std::string> ret;
......
......@@ -24,8 +24,8 @@ Graph InferAttr(Graph &&ret,
const IndexedGraph& idx = ret.indexed_graph();
static auto& finfer_shape =
Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
static auto& is_backward =
Op::GetAttr<TIsBackwardOp>("TIsBackwardOp");
static auto& backward_map =
Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex");
// reshape shape vector
AttrVector rshape(idx.num_node_entries(), def_value);
......@@ -82,16 +82,19 @@ Graph InferAttr(Graph &&ret,
for (uint32_t i = 0; i < num_outputs; ++i) {
rshape[idx.entry_id(nid, i)] = oshape[i];
}
} else if (is_backward.get(inode.source->op(), false)) {
} else if (backward_map.count(inode.source->op())) {
// backward operator inference.
CHECK_GE(inode.control_deps.size(), 1)
<< "BackwardOp need to have control_deps to its forward op";
const auto& fnode = idx[inode.control_deps[0]];
CHECK_EQ(fnode.inputs.size(), num_outputs)
<< "BackwardOp need to correspond to the forward node";
std::vector<uint32_t> out_map =
backward_map[inode.source->op()](inode.source->attrs);
bool known = true;
for (size_t i = 0; i < fnode.inputs.size(); ++i) {
rshape[idx.entry_id(nid, i)] = rshape[idx.entry_id(fnode.inputs[i])];
for (size_t i = 0; i < out_map.size(); ++i) {
uint32_t in_id = out_map[i];
CHECK_LT(in_id, fnode.inputs.size());
rshape[idx.entry_id(nid, i)] =
rshape[idx.entry_id(fnode.inputs[in_id])];
if (fis_none(rshape[idx.entry_id(nid, i)])) known = false;
}
num_unknown += !known;
......
......@@ -12,6 +12,7 @@ namespace nnvm {
namespace pass {
namespace {
// simple logic to place devices according to the device_group hint;
// inserts a copy node where a cross-device data transfer is needed
Graph PlaceDevice(Graph src) {
......@@ -21,13 +22,20 @@ Graph PlaceDevice(Graph src) {
<< "Need graph attribute \"device_assign_map\" in PlaceDevice";
CHECK_NE(src.attrs.count("device_copy_op"), 0)
<< "Need graph attribute \"device_copy_op\" in PlaceDevice";
std::string device_group_attr_key = src.GetAttr<std::string>("device_group_attr_key");
const Op* copy_op = Op::Get(src.GetAttr<std::string>("device_copy_op"));
auto& device_assign_map = src.GetAttr<DeviceAssignMap>("device_assign_map");
const IndexedGraph& idx = src.indexed_graph();
DeviceVector device(idx.num_nodes(), -1);
DeviceVector device;
// copy-on-write semantics
if (src.attrs.count("device") != 0) {
device = src.MoveCopyAttr<DeviceVector>("device");
CHECK_EQ(device.size(), idx.num_nodes());
} else {
device.resize(idx.num_nodes(), -1);
}
// forward pass
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
......
......@@ -12,11 +12,11 @@ namespace dmlc {
namespace json {
// overload handler for shared ptr
template<>
struct Handler<std::shared_ptr<const any> > {
inline static void Write(JSONWriter *writer, const std::shared_ptr<const any> &data) {
struct Handler<std::shared_ptr<any> > {
inline static void Write(JSONWriter *writer, const std::shared_ptr<any> &data) {
writer->Write(*data);
}
inline static void Read(JSONReader *reader, std::shared_ptr<const any> *data) {
inline static void Read(JSONReader *reader, std::shared_ptr<any> *data) {
any v;
reader->Read(&v);
*data = std::make_shared<any>(std::move(v));
......@@ -131,7 +131,7 @@ struct JSONGraph {
std::vector<uint32_t> arg_nodes;
std::vector<uint32_t> node_row_ptr;
std::vector<JSONNode::Entry> heads;
std::unordered_map<std::string, std::shared_ptr<const any> > attrs;
std::unordered_map<std::string, std::shared_ptr<any> > attrs;
void Save(dmlc::JSONWriter *writer) const {
writer->BeginObject();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment