Commit a0dc8655 by Chris Olivier Committed by Tianqi Chen

API call to get symbol output count (#270)

* Symbol __getitem__ using list_outputs() is too expensive, when it only cares about the output count in most cases

* Add cython cmake

* GetNumOutputs() and __len__ changes per PR comments

* set commit for tvm
parent 8d960c4b
...@@ -247,6 +247,17 @@ NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol, ...@@ -247,6 +247,17 @@ NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol,
NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol, NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol,
nn_uint *out_size, nn_uint *out_size,
const char ***out_str_array); const char ***out_str_array);
/*!
* \brief Supply number of outputs of the symbol.
* \param symbol the symbol
* \param output_count number of outputs
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolGetNumOutputs(SymbolHandle symbol,
nn_uint *output_count);
/*! /*!
* \brief Get a symbol that contains all the internals. * \brief Get a symbol that contains all the internals.
* \param symbol The symbol * \param symbol The symbol
......
...@@ -281,6 +281,14 @@ int NNSymbolListOutputNames(SymbolHandle symbol, ...@@ -281,6 +281,14 @@ int NNSymbolListOutputNames(SymbolHandle symbol,
API_END(); API_END();
} }
int NNSymbolGetNumOutputs(SymbolHandle symbol,
nn_uint *output_count) {
Symbol *s = static_cast<Symbol*>(symbol);
API_BEGIN();
*output_count = static_cast<nn_uint>(s->outputs.size());
API_END();
}
int NNSymbolCompose(SymbolHandle sym, int NNSymbolCompose(SymbolHandle sym,
const char *name, const char *name,
nn_uint num_args, nn_uint num_args,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment