Commit 02141d4a by ziheng Committed by Tianqi Chen

[SYMBOL] Add __iter__ and GetChildren for symbol (#268)

* [SYMBOL] Add __iter__ and GetChildren for symbol

* [SYMBOL] Fix lint
parent a0dc8655
......@@ -267,6 +267,14 @@ NNVM_DLL int NNSymbolGetNumOutputs(SymbolHandle symbol,
NNVM_DLL int NNSymbolGetInternals(SymbolHandle symbol,
SymbolHandle *out);
/*!
* \brief Get a symbol that contains only direct children.
* \param symbol The symbol
* \param out The output symbol whose outputs are the direct children.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolGetChildren(SymbolHandle symbol,
SymbolHandle *out);
/*!
* \brief Get index-th outputs of the symbol.
* \param symbol The symbol
* \param index the Index of the output.
......
......@@ -143,6 +143,9 @@ class Symbol(SymbolBase):
self.handle, _base.nn_uint(index), _ctypes.byref(handle)))
return Symbol(handle=handle)
def __iter__(self):
return (self[i] for i in self.list_output_names())
def attr(self, key):
"""Get attribute string from the symbol, this function only works for non-grouped symbol.
......@@ -196,6 +199,17 @@ class Symbol(SymbolBase):
self.handle, _ctypes.byref(handle)))
return Symbol(handle=handle)
def get_children(self):
"""Gets a new grouped symbol whose output contains
inputs to output nodes of the original symbol."""
handle = _base.SymbolHandle()
_check_call(_LIB.NNSymbolGetChildren(
self.handle, _ctypes.byref(handle)))
ret = Symbol(handle=handle)
if not ret.list_output_names():
return None
return ret
def _get_list_copt(self, option):
"""internal function to get list option"""
if option == 'all':
......
......@@ -141,6 +141,15 @@ int NNSymbolGetInternals(SymbolHandle symbol,
API_END_HANDLE_ERROR(delete s);
}
int NNSymbolGetChildren(SymbolHandle symbol,
SymbolHandle *out) {
Symbol *s = new Symbol();
API_BEGIN();
*s = static_cast<Symbol*>(symbol)->GetChildren();
*out = s;
API_END_HANDLE_ERROR(delete s);
}
int NNSymbolFree(SymbolHandle symbol) {
API_BEGIN();
delete static_cast<Symbol*>(symbol);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment