Commit c0cde10d by ziheng Committed by Tianqi Chen

[OPT] Improve PreComputePrune When Output Is Pruned (#178)

parent 9a6feca6
......@@ -27,6 +27,25 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) {
// number of edges that are not variable
int non_var_edge = 0;
auto replace_pruned_entry = [&] (const NodeEntry& e) {
if (!entry_var.count(e)) {
if (!e.node->is_variable()) {
++non_var_edge;
}
nnvm::NodePtr var = nnvm::Node::Create();
var->attrs.name = e.node->attrs.name;
if (e.node->num_outputs() != 1) {
var->attrs.name += "_output" + std::to_string(e.index);
}
entry_var.emplace(e, var);
CHECK(!unique_name.count(var->attrs.name));
unique_name.insert(var->attrs.name);
return nnvm::NodeEntry{var, 0, 0};
} else {
return nnvm::NodeEntry{entry_var.at(e), 0, 0};
}
};
DFSVisit(src.outputs, [&](const nnvm::NodePtr& n) {
bool can_be_pruned = true;
if (n->is_variable()) {
......@@ -47,20 +66,7 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) {
// scan again to find edge nodes, skip variables
for (auto& e : n->inputs) {
if (pruned.count(e.node.get())) {
if (!entry_var.count(e)) {
if (!e.node->is_variable()) {
++non_var_edge;
}
nnvm::NodePtr var = nnvm::Node::Create();
var->attrs.name = e.node->attrs.name;
if (e.node->num_outputs() != 1) {
var->attrs.name += "_output" + std::to_string(e.index);
}
entry_var.emplace(e, var);
CHECK(!unique_name.count(var->attrs.name));
unique_name.insert(var->attrs.name);
}
e = nnvm::NodeEntry{entry_var.at(e), 0, 0};
e = replace_pruned_entry(e);
}
}
}
......@@ -71,6 +77,12 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) {
return src;
}
for (auto& e : src.outputs) {
if (pruned.count(e.node.get())) {
e = replace_pruned_entry(e);
}
}
nnvm::Graph pre_graph;
pre_graph.outputs.reserve(entry_var.size());
std::vector<std::string> output_names;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment