Commit 86379cee by Tianqi Chen

Enable optional dependency memory and attr hint (#79)

* Enable optional dependency memory and attr hint

* fix travis
parent a1f59908
@@ -137,6 +137,18 @@ using FInplaceOption = std::function<
std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>;
/*!
 * \brief Get the list of inputs in the op whose contents are not actually used by the operator.
 * These are dummy inputs that can be used, for example, in zeros_like and ones_like.
 *
 * \param attrs The attributes of the node
 * \return List of input indices that are not used by the operator.
*
* \note Register under "FIgnoreInputs".
*/
using FIgnoreInputs = std::function<
std::vector<uint32_t> (const NodeAttrs& attrs)>;
/*!
 * \brief Get the gradient node of the op node
 * This function generates the backward graph of the node
 * \param nodeptr The node to take gradient
...
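For context, here is a minimal sketch of how an operator whose input values are never read might register this attribute. The zeros_like operator and the exact registration shown are illustrative assumptions, not part of this commit.

#include <cstdint>
#include <vector>
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>

// Hypothetical registration: input 0 only supplies shape/dtype information,
// so its values are never read by the kernel.
NNVM_REGISTER_OP(zeros_like)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FIgnoreInputs>(
    "FIgnoreInputs",
    [](const nnvm::NodeAttrs& attrs) {
      return std::vector<uint32_t>{0};  // ignore input 0
    });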
@@ -129,6 +129,7 @@ inline Graph PlaceDevice(Graph graph,
 * \param ys_out_grad The symbol for additional gradient to be propagate back to y.
 * \param aggregate_fun Aggregation function applied to aggregate the inputs.
 * \param mirror_fun Optional mirror function to do mirror optimization and save memory.
 * \param attr_hint_fun Optional hint function that outputs a node behaving like src, but whose attributes are the same as those of like.
 * \return A new graph, whose outputs correspond to inputs of xs.
 */
inline Graph Gradient(
@@ -137,7 +138,9 @@ inline Graph Gradient(
std::vector<NodeEntry> xs,
std::vector<NodeEntry> ys_out_grad,
std::function<NodeEntry(std::vector<NodeEntry>&& inputs)> aggregate_fun = nullptr,
-std::function<int(const Node& node)> mirror_fun = nullptr) {
+std::function<int(const Node& node)> mirror_fun = nullptr,
std::function<NodeEntry(const NodeEntry& src, const NodeEntry &like)>
attr_hint_fun = nullptr) {
graph.attrs["grad_ys"] = std::make_shared<any>(std::move(ys)); graph.attrs["grad_ys"] = std::make_shared<any>(std::move(ys));
graph.attrs["grad_xs"] = std::make_shared<any>(std::move(xs)); graph.attrs["grad_xs"] = std::make_shared<any>(std::move(xs));
...@@ -145,10 +148,15 @@ inline Graph Gradient( ...@@ -145,10 +148,15 @@ inline Graph Gradient(
if (aggregate_fun != nullptr) { if (aggregate_fun != nullptr) {
graph.attrs["grad_aggregate_fun"] = std::make_shared<any>(aggregate_fun); graph.attrs["grad_aggregate_fun"] = std::make_shared<any>(aggregate_fun);
} }
if (mirror_fun != nullptr) { if (mirror_fun != nullptr) {
graph.attrs["grad_mirror_fun"] = std::make_shared<any>(mirror_fun); graph.attrs["grad_mirror_fun"] = std::make_shared<any>(mirror_fun);
} }
if (attr_hint_fun != nullptr) {
graph.attrs["attr_hint_fun"] = std::make_shared<any>(attr_hint_fun);
}
return ApplyPass(std::move(graph), "Gradient");
}
...
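Below is a hedged sketch of how a caller might supply attr_hint_fun to the Gradient pass. The "_copy_like" hint operator and the TakeGradient wrapper are assumptions made for illustration; the commit itself only adds the attr_hint_fun parameter and attribute.

#include <utility>
#include <vector>
#include <nnvm/graph.h>
#include <nnvm/node.h>
#include <nnvm/op.h>
#include <nnvm/pass_functions.h>

// Hint function: wrap src in a hypothetical "_copy_like" op that forwards the
// value of src while letting attribute inference copy shape/dtype from like.
nnvm::NodeEntry AttrHint(const nnvm::NodeEntry& src, const nnvm::NodeEntry& like) {
  nnvm::NodePtr n = nnvm::Node::Create();
  n->attrs.op = nnvm::Op::Get("_copy_like");  // assumed to be registered elsewhere
  n->attrs.name = src.node->attrs.name + "_hint";
  n->inputs = {src, like};
  return nnvm::NodeEntry{n, 0, 0};
}

nnvm::Graph TakeGradient(nnvm::Graph fwd,
                         std::vector<nnvm::NodeEntry> ys,
                         std::vector<nnvm::NodeEntry> xs,
                         std::vector<nnvm::NodeEntry> ys_out_grad) {
  return nnvm::pass::Gradient(std::move(fwd), std::move(ys), std::move(xs),
                              std::move(ys_out_grad),
                              nullptr /*aggregate_fun*/,
                              nullptr /*mirror_fun*/,
                              AttrHint);
}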
@@ -43,6 +43,7 @@ struct GradEntry {
Graph Gradient(Graph src) {
using nnvm::FGradient;
using MirrorFun = std::function<int (const Node& node)>;
using AttrHintFun = std::function<NodeEntry (const NodeEntry& src, const NodeEntry &like)>;
CHECK_NE(src.attrs.count("grad_ys"), 0)
<< "Gradient require grad_ys to be presented.";
@@ -65,6 +66,10 @@ Graph Gradient(Graph src) {
if (src.attrs.count("grad_mirror_fun") != 0) {
mirror_fun = src.GetAttr<MirrorFun>("grad_mirror_fun");
}
AttrHintFun attr_hint_fun = nullptr;
if (src.attrs.count("attr_hint_fun") != 0) {
attr_hint_fun = src.GetAttr<AttrHintFun>("attr_hint_fun");
}
// topo sort
std::vector<NodePtr> topo_order;
@@ -79,7 +84,11 @@ Graph Gradient(Graph src) {
CHECK_EQ(ys.size(), ys_out_grad.size());
for (size_t i = 0; i < ys.size(); ++i) {
-output_grads[ys[i].node.get()][ys[i].index].grads = { ys_out_grad[i] };
+NodeEntry ograd = ys_out_grad[i];
if (attr_hint_fun != nullptr) {
ograd = attr_hint_fun(ograd, ys[i]);
}
output_grads[ys[i].node.get()][ys[i].index].grads = { ograd };
}
// construct mirror reduece memory strategy if needed
@@ -105,6 +114,8 @@ Graph Gradient(Graph src) {
// traverse backward
static auto& grad_fun_map = Op::GetAttr<FGradient>("FGradient");
static auto& finfer_shape = Op::GetAttr<FInferShape>("FInferShape");
std::vector<NodeEntry> out_agg_grads;
for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) {
const NodePtr& ptr = *rit;
@@ -115,8 +126,17 @@ Graph Gradient(Graph src) {
out_agg_grads.push_back(e.sum);
}
if ((*rit)->inputs.size() != 0) {
-std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()]
-    (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()), out_agg_grads);
+NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()));
+std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()](
+    fwd_node, out_agg_grads);
if (attr_hint_fun != nullptr) {
// only insert hint when shape inference function is not available.
for (size_t i = 0; i < input_grads.size(); ++i) {
if (finfer_shape.count(input_grads[i].node->op())) continue;
input_grads[i] = attr_hint_fun(input_grads[i], fwd_node->inputs[i]);
}
}
CHECK_EQ((*rit)->inputs.size(), input_grads.size())
<< "Gradient function not returning enough gradient";
auto git = input_grads.begin();
...
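The hint is only inserted when the gradient node's op does not register FInferShape, since such a node cannot recover its output shape on its own. As a hedged illustration of when that happens, the sketch below registers an FGradient that returns a constant-zero gradient node; the stop_gradient and zeros op names are hypothetical placeholders, not operators defined by this commit.

#include <vector>
#include <nnvm/node.h>
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>

// Hypothetical op whose gradient is identically zero. If the "zeros" op below
// does not provide FInferShape, the Gradient pass wraps the returned entry
// with attr_hint_fun so the zero gradient inherits the forward input's attrs.
NNVM_REGISTER_OP(stop_gradient)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FGradient>(
    "FGradient",
    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
      nnvm::NodePtr zero = nnvm::Node::Create();
      zero->attrs.op = nnvm::Op::Get("zeros");  // assumed zero-valued op
      zero->attrs.name = n->attrs.name + "_backward";
      return std::vector<nnvm::NodeEntry>{nnvm::NodeEntry{zero, 0, 0}};
    });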
@@ -137,13 +137,25 @@ class GraphAllocator {
Graph PlanMemory(Graph ret) {
// setup ref counter
const IndexedGraph& idx = ret.indexed_graph();
static auto& fignore_inputs = Op::GetAttr<FIgnoreInputs>("FIgnoreInputs");
// reference counter of each node
std::vector<uint32_t> ref_count(idx.num_node_entries(), 0);
// step 1: initialize reference count
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
-for (const auto& e : idx[nid].inputs) {
+const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
for (const auto& e : inode.inputs) {
++ref_count[idx.entry_id(e)];
}
// no dataflow dependency is needed for inputs that are ignored.
// revoke the dependency counter.
if (fignore_inputs.count(inode.source->op()) != 0) {
auto ignore_inputs = fignore_inputs[inode.source->op()](inode.source->attrs);
for (uint32_t i : ignore_inputs) {
--ref_count[idx.entry_id(inode.inputs[i])];
}
}
}
for (const auto& e : idx.outputs()) {
++ref_count[idx.entry_id(e)];
@@ -195,8 +207,18 @@ Graph PlanMemory(Graph ret) {
storage[eid] = allocator.Request(dev_id, dtype_vec[eid], shape_vec[eid], nid);
}
}
// check if certain inputs are ignored.
std::vector<uint32_t> ignore_inputs;
if (fignore_inputs.count(inode.source->op()) != 0) {
ignore_inputs = fignore_inputs[inode.source->op()](inode.source->attrs);
std::sort(ignore_inputs.begin(), ignore_inputs.end());
}
// then free inputs
-for (const auto& e : inode.inputs) {
+for (size_t i = 0; i < inode.inputs.size(); ++i) {
// ref counter of ignored input is already decreased.
if (std::binary_search(ignore_inputs.begin(), ignore_inputs.end(), i)) continue;
const auto& e = inode.inputs[i];
uint32_t eid = idx.entry_id(e);
// temp_ref_count == 0 means it is taken by inplace op
if (ref_count[eid] == 0) continue;
...
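To make the bookkeeping above concrete, here is a small self-contained sketch using a toy node type instead of IndexedGraph; it only shows how ignored inputs are excluded from the dataflow reference counts, and none of these names come from the pass itself.

#include <cstdint>
#include <vector>

struct ToyNode {
  std::vector<uint32_t> inputs;         // entry ids read by this node
  std::vector<uint32_t> ignore_inputs;  // positions whose content is unused
};

// Count how many real consumers each entry has; ignored inputs do not count,
// so the planner can recycle their storage earlier.
std::vector<uint32_t> CountRefs(const std::vector<ToyNode>& nodes,
                                uint32_t num_entries) {
  std::vector<uint32_t> ref_count(num_entries, 0);
  for (const ToyNode& n : nodes) {
    for (uint32_t eid : n.inputs) ++ref_count[eid];
    for (uint32_t pos : n.ignore_inputs) --ref_count[n.inputs[pos]];  // revoke
  }
  return ref_count;
}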
@@ -5,11 +5,11 @@ if [ ${TRAVIS_OS_NAME} == "osx" ]; then
brew update
brew install python3
if [ ${TASK} == "python_test" ]; then
-python -m pip install nose numpy cython --user `whoami`
-python3 -m pip install nose numpy cython --user `whoami`
+python -m pip install --user nose numpy cython
+python3 -m pip install --user nose numpy cython
fi
fi
if [ ${TASK} == "lint" ]; then
-pip install cpplint 'pylint==1.4.4' 'astroid==1.3.6' --user `whoami`
+pip install --user cpplint 'pylint==1.4.4' 'astroid==1.3.6'
fi