Commit da78c4c5 by Tianqi Chen

do hint insertion after aggregation (#81)

parent 11883db5
...@@ -20,15 +20,13 @@ ...@@ -20,15 +20,13 @@
# choice of compiler # choice of compiler
#-------------------- #--------------------
export CC = gcc
export CXX = g++
export NVCC = nvcc export NVCC = nvcc
# the additional link flags you want to add # the additional link flags you want to add
ADD_LDFLAGS = ADD_LDFLAGS=
# the additional compile flags you want to add # the additional compile flags you want to add
ADD_CFLAGS = ADD_CFLAGS=
#---------------------------- #----------------------------
# plugins # plugins
......
...@@ -38,6 +38,7 @@ struct GradEntry { ...@@ -38,6 +38,7 @@ struct GradEntry {
NodeEntry sum{nullptr, 0, 0}; NodeEntry sum{nullptr, 0, 0};
#endif #endif
std::vector<NodeEntry> grads; std::vector<NodeEntry> grads;
bool need_attr_hint{true};
}; };
Graph Gradient(Graph src) { Graph Gradient(Graph src) {
...@@ -85,9 +86,6 @@ Graph Gradient(Graph src) { ...@@ -85,9 +86,6 @@ Graph Gradient(Graph src) {
CHECK_EQ(ys.size(), ys_out_grad.size()); CHECK_EQ(ys.size(), ys_out_grad.size());
for (size_t i = 0; i < ys.size(); ++i) { for (size_t i = 0; i < ys.size(); ++i) {
NodeEntry ograd = ys_out_grad[i]; NodeEntry ograd = ys_out_grad[i];
if (attr_hint_fun != nullptr) {
ograd = attr_hint_fun(ograd, ys[i]);
}
output_grads[ys[i].node.get()][ys[i].index].grads = { ograd }; output_grads[ys[i].node.get()][ys[i].index].grads = { ograd };
} }
...@@ -121,27 +119,29 @@ Graph Gradient(Graph src) { ...@@ -121,27 +119,29 @@ Graph Gradient(Graph src) {
const NodePtr& ptr = *rit; const NodePtr& ptr = *rit;
if (ptr->is_variable()) continue; if (ptr->is_variable()) continue;
out_agg_grads.clear(); out_agg_grads.clear();
for (GradEntry& e : output_grads.at(ptr.get())) { auto& out_grad_vec = output_grads.at(ptr.get());
for (uint32_t i = 0; i < out_grad_vec.size(); ++i) {
GradEntry& e = out_grad_vec[i];
e.sum = agg_fun(std::move(e.grads)); e.sum = agg_fun(std::move(e.grads));
if (e.need_attr_hint && attr_hint_fun != nullptr) {
e.sum = attr_hint_fun(e.sum, NodeEntry{ptr, 0, i});
}
out_agg_grads.push_back(e.sum); out_agg_grads.push_back(e.sum);
} }
if ((*rit)->inputs.size() != 0) { if ((*rit)->inputs.size() != 0) {
NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get())); NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()));
std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()]( std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()](
fwd_node, out_agg_grads); fwd_node, out_agg_grads);
if (attr_hint_fun != nullptr) {
// only insert hint when shape inference function is not available.
for (size_t i = 0; i < input_grads.size(); ++i) {
if (finfer_shape.count(input_grads[i].node->op())) continue;
input_grads[i] = attr_hint_fun(input_grads[i], fwd_node->inputs[i]);
}
}
CHECK_EQ((*rit)->inputs.size(), input_grads.size()) CHECK_EQ((*rit)->inputs.size(), input_grads.size())
<< "Gradient function not returning enough gradient"; << "Gradient function not returning enough gradient";
auto git = input_grads.begin(); auto git = input_grads.begin();
for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) { for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
output_grads[it->node.get()][it->index].grads.emplace_back(std::move(*git)); auto& ge = output_grads[it->node.get()][it->index];
// if any of the backward op can do shape inference, the hint is not necessary.
if (finfer_shape.count(git->node->op())) {
ge.need_attr_hint = false;
}
ge.grads.emplace_back(std::move(*git));
} }
} }
} }
...@@ -153,6 +153,9 @@ Graph Gradient(Graph src) { ...@@ -153,6 +153,9 @@ Graph Gradient(Graph src) {
// aggregate sum if there haven't been // aggregate sum if there haven't been
if (entry.sum.node.get() == nullptr) { if (entry.sum.node.get() == nullptr) {
entry.sum = agg_fun(std::move(entry.grads)); entry.sum = agg_fun(std::move(entry.grads));
if (entry.need_attr_hint && attr_hint_fun != nullptr) {
entry.sum = attr_hint_fun(entry.sum, e);
}
} }
ret.outputs.emplace_back(std::move(entry.sum)); ret.outputs.emplace_back(std::move(entry.sum));
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment