Commit 02141d4a by ziheng Committed by Tianqi Chen

[SYMBOL] Add __iter__ and GetChildren for symbol (#268)

* [SYMBOL] Add __iter__ and GetChildren for symbol

* [SYMBOL] Fix lint
parent a0dc8655
...@@ -267,6 +267,14 @@ NNVM_DLL int NNSymbolGetNumOutputs(SymbolHandle symbol, ...@@ -267,6 +267,14 @@ NNVM_DLL int NNSymbolGetNumOutputs(SymbolHandle symbol,
NNVM_DLL int NNSymbolGetInternals(SymbolHandle symbol, NNVM_DLL int NNSymbolGetInternals(SymbolHandle symbol,
SymbolHandle *out); SymbolHandle *out);
/*! /*!
* \brief Get a symbol that contains only direct children.
* \param symbol The symbol
* \param out The output symbol whose outputs are the direct children.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolGetChildren(SymbolHandle symbol,
SymbolHandle *out);
/*!
* \brief Get index-th outputs of the symbol. * \brief Get index-th outputs of the symbol.
* \param symbol The symbol * \param symbol The symbol
* \param index the Index of the output. * \param index the Index of the output.
......
...@@ -143,6 +143,9 @@ class Symbol(SymbolBase): ...@@ -143,6 +143,9 @@ class Symbol(SymbolBase):
self.handle, _base.nn_uint(index), _ctypes.byref(handle))) self.handle, _base.nn_uint(index), _ctypes.byref(handle)))
return Symbol(handle=handle) return Symbol(handle=handle)
def __iter__(self):
return (self[i] for i in self.list_output_names())
def attr(self, key): def attr(self, key):
"""Get attribute string from the symbol, this function only works for non-grouped symbol. """Get attribute string from the symbol, this function only works for non-grouped symbol.
...@@ -196,6 +199,17 @@ class Symbol(SymbolBase): ...@@ -196,6 +199,17 @@ class Symbol(SymbolBase):
self.handle, _ctypes.byref(handle))) self.handle, _ctypes.byref(handle)))
return Symbol(handle=handle) return Symbol(handle=handle)
def get_children(self):
"""Gets a new grouped symbol whose output contains
inputs to output nodes of the original symbol."""
handle = _base.SymbolHandle()
_check_call(_LIB.NNSymbolGetChildren(
self.handle, _ctypes.byref(handle)))
ret = Symbol(handle=handle)
if not ret.list_output_names():
return None
return ret
def _get_list_copt(self, option): def _get_list_copt(self, option):
"""internal function to get list option""" """internal function to get list option"""
if option == 'all': if option == 'all':
......
...@@ -141,6 +141,15 @@ int NNSymbolGetInternals(SymbolHandle symbol, ...@@ -141,6 +141,15 @@ int NNSymbolGetInternals(SymbolHandle symbol,
API_END_HANDLE_ERROR(delete s); API_END_HANDLE_ERROR(delete s);
} }
int NNSymbolGetChildren(SymbolHandle symbol,
SymbolHandle *out) {
Symbol *s = new Symbol();
API_BEGIN();
*s = static_cast<Symbol*>(symbol)->GetChildren();
*out = s;
API_END_HANDLE_ERROR(delete s);
}
int NNSymbolFree(SymbolHandle symbol) { int NNSymbolFree(SymbolHandle symbol) {
API_BEGIN(); API_BEGIN();
delete static_cast<Symbol*>(symbol); delete static_cast<Symbol*>(symbol);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment