Commit 095ce875 by ziheng Committed by Tianqi Chen

[EXECUTOR] PruneGraph pass (#274)

parent 29d253d0
......@@ -499,5 +499,61 @@ NNVM_REGISTER_OP(layout_transform)
.set_num_outputs(1)
.add_argument("data", "NDArray-or-Symbol", "Input data")
.add_arguments(LayoutTransformParam::__FIELDS__());
nnvm::Graph PruneGraph(nnvm::Graph src) {
const auto& params = src.GetAttr<std::unordered_set<std::string>>("params");
std::unordered_set<nnvm::Node*> pruned;
nnvm::NodeEntryMap<nnvm::NodePtr> entry_var;
DFSVisit(src.outputs, [&](const nnvm::NodePtr& n) {
bool can_be_pruned = true;
if (n->is_variable()) {
if (params.count(n->attrs.name)) {
pruned.emplace(n.get());
}
can_be_pruned = false;
}
for (const auto& e : n->inputs) {
if (!pruned.count(e.node.get())) {
can_be_pruned = false;
}
}
if (can_be_pruned) {
pruned.emplace(n.get());
} else {
// scan again to find edge nodes
for (auto& e : n->inputs) {
if (pruned.count(e.node.get())) {
if (!entry_var.count(e)) {
nnvm::NodePtr var = nnvm::Node::Create();
var->attrs.name = e.node->attrs.name + "_" + std::to_string(e.index);
entry_var.emplace(e, var);
}
e = nnvm::NodeEntry{entry_var.at(e), 0, 0};
}
}
}
});
nnvm::Graph pre_graph;
pre_graph.outputs.reserve(entry_var.size());
std::vector<std::string> output_names;
output_names.reserve(entry_var.size());
for (auto kv : entry_var) {
pre_graph.outputs.emplace_back(kv.first);
output_names.emplace_back(kv.second->attrs.name);
}
pre_graph.attrs["pruned_params"] =
std::make_shared<dmlc::any>(std::move(output_names));
src.attrs["pre_graph"] =
std::make_shared<dmlc::any>(std::move(pre_graph));
return src;
}
NNVM_REGISTER_PASS(PruneGraph)
.set_body(PruneGraph);
} // namespace contrib
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment