Commit 3ad05398 by Przemyslaw Tredak Committed by Tianqi Chen

Handling duplicate NodeEntries on the edge of the gradient graph (#122)

* Handling duplicate NodeEntries on the edge of the graph

* Fix docs and segfault

* Suggestions from review

* Added attr_parser check
parent 1feabb0d
......@@ -132,6 +132,8 @@ inline Graph PlaceDevice(Graph graph,
* \param attr_hint_fun Optional, hint function that outputs a node like src, but whose attributes are the same as those of like.
* \param zero_ops Optional, list of operators that outputs a single zero array. The first one
* must be zeros_like.
* \param copy_op_str Optional, name of the copy operation required to handle duplicates
* on the edge of the graph
* \return A new graph, whose outputs correspond to inputs of xs.
*/
inline Graph Gradient(
......@@ -143,7 +145,8 @@ inline Graph Gradient(
std::function<int(const Node& node)> mirror_fun = nullptr,
std::function<NodeEntry(const NodeEntry& src, const NodeEntry &like)>
attr_hint_fun = nullptr,
std::vector<const Op*> zero_ops = std::vector<const Op*>()) {
std::vector<const Op*> zero_ops = std::vector<const Op*>(),
std::string copy_op_str = std::string()) {
graph.attrs["grad_ys"] = std::make_shared<any>(std::move(ys));
graph.attrs["grad_xs"] = std::make_shared<any>(std::move(xs));
......@@ -164,6 +167,10 @@ inline Graph Gradient(
graph.attrs["zero_ops"] = std::make_shared<any>(std::move(zero_ops));
}
if (copy_op_str != std::string()) {
graph.attrs["copy_op"] = std::make_shared<any>(std::move(copy_op_str));
}
return ApplyPass(std::move(graph), "Gradient");
}
......
......@@ -91,6 +91,9 @@ Graph Gradient(Graph src) {
if (src.attrs.count("zero_ops") != 0) {
zero_ops = src.GetAttr<std::vector<const Op*> >("zero_ops");
}
const Op* copy_op = (src.attrs.count("copy_op") != 0) ?
Op::Get(src.GetAttr<std::string>("copy_op")) :
nullptr;
// topo sort
std::vector<NodePtr> topo_order;
......@@ -190,7 +193,9 @@ Graph Gradient(Graph src) {
}
// take out the xs' grads
Graph ret;
ret.outputs.reserve(xs.size());
ret.outputs.resize(xs.size());
NodeEntryMap<std::pair<size_t, size_t> > unique_grads;
size_t counter = 0;
for (const NodeEntry& e : xs) {
GradEntry& entry = output_grads[e.node.get()][e.index];
// aggregate sum if there haven't been
......@@ -200,7 +205,32 @@ Graph Gradient(Graph src) {
entry.sum = attr_hint_fun(entry.sum, e);
}
}
ret.outputs.emplace_back(std::move(entry.sum));
if (copy_op != nullptr) {
auto kv = unique_grads.find(entry.sum);
if (kv == unique_grads.end()) {
unique_grads.emplace(std::move(entry.sum), std::make_pair(1, counter));
} else {
NodePtr copy_node = Node::Create();
std::ostringstream os;
os << entry.sum.node->attrs.name << "_" << kv->second.first << "_copy";
kv->second.first++;
copy_node->attrs.op = copy_op;
copy_node->attrs.name = os.str();
copy_node->inputs.emplace_back(entry.sum);
if (copy_node->attrs.op->attr_parser != nullptr) {
copy_node->attrs.op->attr_parser(&(copy_node->attrs));
}
unique_grads.emplace(NodeEntry{std::move(copy_node), 0, 0}, std::make_pair(1, counter));
}
} else {
ret.outputs[counter] = entry.sum;
}
++counter;
}
if (copy_op != nullptr) {
for (const auto& kv : unique_grads) {
ret.outputs[kv.second.second] = kv.first;
}
}
return ret;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment