Commit 04edd05d by Eric Junyuan Xie Committed by Tianqi Chen

add symbol::GetChildren (#104)

parent 37c450fe
...@@ -137,6 +137,12 @@ class Symbol { ...@@ -137,6 +137,12 @@ class Symbol {
* including input variables and intermediate outputs. * including input variables and intermediate outputs.
*/ */
Symbol GetInternals() const; Symbol GetInternals() const;
/*
* \brief Get the direct inputs of the head node(s) of this symbol.
* \return symbol A new symbol whose output contains all the inputs of the head
* node(s).
*/
Symbol GetChildren() const;
/*! /*!
* \brief Set additional attributes to current node. * \brief Set additional attributes to current node.
* *
......
...@@ -435,6 +435,19 @@ Symbol Symbol::GetInternals() const { ...@@ -435,6 +435,19 @@ Symbol Symbol::GetInternals() const {
return ret; return ret;
} }
Symbol Symbol::GetChildren() const {
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
Symbol ret;
std::unordered_set<Node*> visited;
for (const auto& p : this->outputs) {
Node* node = p.node.get();
if (visited.count(node)) continue;
visited.insert(node);
ret.outputs.insert(ret.outputs.end(), node->inputs.begin(), node->inputs.end());
}
return ret;
}
void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs) { void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs) {
Node* node = outputs[0].node.get(); Node* node = outputs[0].node.get();
for (const NodeEntry& e : outputs) { for (const NodeEntry& e : outputs) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment