Commit eea05d77 by Tianqi Chen

[SYMBOL] Change list_input->list_input_names, add list_input_variables (#59)

* [SYMBOL] Change list_input->list_input_names, add list_input_variables

* fix
parent e8fee6dc
...@@ -205,8 +205,25 @@ NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol, ...@@ -205,8 +205,25 @@ NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol,
int recursive_option, int recursive_option,
nn_uint *out_size, nn_uint *out_size,
const char*** out); const char*** out);
/*!
 * \brief List input variables in the symbol.
 *
 *  Each returned handle wraps one input variable node of the graph,
 *  in contrast to NNSymbolListInputNames which returns only the names.
 *
 * \param symbol the symbol
 * \param option The option to list the inputs
 *     option=0 means list all arguments.
 *     option=1 means list arguments that are read-only by the graph.
 *     option=2 means list arguments that are mutated by the graph.
 * \param out_size output size
 * \param out_sym_array the output array of symbol handles.
 * \return 0 when success, -1 when failure happens
 */
NNVM_DLL int NNSymbolListInputVariables(SymbolHandle symbol,
                                        int option,
                                        nn_uint *out_size,
                                        SymbolHandle** out_sym_array);
/*! /*!
* \brief List inputs in the symbol. * \brief List input names in the symbol.
* \param symbol the symbol * \param symbol the symbol
* \param option The option to list the inputs * \param option The option to list the inputs
* option=0 means list all arguments. * option=0 means list all arguments.
......
...@@ -108,7 +108,7 @@ class Symbol(SymbolBase): ...@@ -108,7 +108,7 @@ class Symbol(SymbolBase):
def __getitem__(self, index): def __getitem__(self, index):
if isinstance(index, _base.string_types): if isinstance(index, _base.string_types):
idx = None idx = None
for i, name in enumerate(self.list_outputs()): for i, name in enumerate(self.list_output_names()):
if name == index: if name == index:
if idx is not None: if idx is not None:
raise ValueError('There are multiple outputs with name \"%s\"' % index) raise ValueError('There are multiple outputs with name \"%s\"' % index)
...@@ -177,7 +177,40 @@ class Symbol(SymbolBase): ...@@ -177,7 +177,40 @@ class Symbol(SymbolBase):
self.handle, _ctypes.byref(handle))) self.handle, _ctypes.byref(handle)))
return Symbol(handle=handle) return Symbol(handle=handle)
def list_inputs(self, option='all'): def _get_list_copt(self, option):
"""internal function to get list option"""
if option == 'all':
return _ctypes.c_int(0)
elif option == 'read_only':
return _ctypes.c_int(1)
elif option == 'aux_state':
return _ctypes.c_int(2)
else:
raise ValueError("option need to be in {'all', 'read_only, 'aux_state'}")
def list_input_variables(self, option='all'):
    """Return the input variables of the symbol as symbol objects.

    Parameters
    ----------
    option : {'all', 'read_only', 'aux_state'}, optional
        Which inputs to list:

        - 'all' : every argument.
        - 'read_only' : arguments that are only read by the graph.
        - 'aux_state' : arguments that are mutated by the graph as state.

    Returns
    -------
    vars : list of Symbol
        One variable symbol per selected input.
    """
    count = _ctypes.c_uint()
    handle_array = _ctypes.POINTER(_base.SymbolHandle)()
    _check_call(_LIB.NNSymbolListInputVariables(
        self.handle, self._get_list_copt(option),
        _ctypes.byref(count), _ctypes.byref(handle_array)))
    variables = []
    for idx in range(count.value):
        variables.append(Symbol(_base.SymbolHandle(handle_array[idx])))
    return variables
def list_input_names(self, option='all'):
"""List all the inputs in the symbol. """List all the inputs in the symbol.
Parameters Parameters
...@@ -194,19 +227,12 @@ class Symbol(SymbolBase): ...@@ -194,19 +227,12 @@ class Symbol(SymbolBase):
""" """
size = _ctypes.c_uint() size = _ctypes.c_uint()
sarr = _ctypes.POINTER(_ctypes.c_char_p)() sarr = _ctypes.POINTER(_ctypes.c_char_p)()
if option == 'all':
copt = _ctypes.c_int(0)
elif option == 'read_only':
copt = _ctypes.c_int(1)
elif option == 'aux_state':
copt = _ctypes.c_int(2)
else:
raise ValueError("option need to be in {'all', 'read_only, 'aux_state'}")
_check_call(_LIB.NNSymbolListInputNames( _check_call(_LIB.NNSymbolListInputNames(
self.handle, copt, _ctypes.byref(size), _ctypes.byref(sarr))) self.handle, self._get_list_copt(option),
_ctypes.byref(size), _ctypes.byref(sarr)))
return [_base.py_str(sarr[i]) for i in range(size.value)] return [_base.py_str(sarr[i]) for i in range(size.value)]
def list_outputs(self): def list_output_names(self):
"""List all outputs in the symbol. """List all outputs in the symbol.
Returns Returns
......
...@@ -221,6 +221,25 @@ int NNSymbolListAttrs(SymbolHandle symbol, ...@@ -221,6 +221,25 @@ int NNSymbolListAttrs(SymbolHandle symbol,
API_END(); API_END();
} }
// C API entry: list the input variables of a symbol as new symbol handles.
// option maps to Symbol::ListInputOption (0 = all, 1 = read-only, 2 = aux state).
int NNSymbolListInputVariables(SymbolHandle symbol,
                               int option,
                               nn_uint *out_size,
                               SymbolHandle** out_sym_array) {
  Symbol *s = static_cast<Symbol*>(symbol);
  // Thread-local scratch keeps the returned pointer array alive after return;
  // it is only valid until the next API call on this thread.
  NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
  API_BEGIN();
  std::vector<NodePtr> vs = s->ListInputs(Symbol::ListInputOption(option));
  ret->ret_handles.clear();
  for (size_t i = 0; i < vs.size(); ++i) {
    // Wrap each input node in a fresh single-output Symbol.
    // NOTE(review): these heap-allocated Symbols appear to pass ownership to
    // the caller (to be released via the symbol-free API) — confirm against
    // the C API's handle-free convention.
    nnvm::Symbol* rs = new nnvm::Symbol();
    rs->outputs.push_back(NodeEntry{vs[i], 0, 0});
    ret->ret_handles.push_back(rs);
  }
  *out_size = static_cast<nn_uint>(vs.size());
  *out_sym_array = dmlc::BeginPtr(ret->ret_handles);
  API_END();
}
int NNSymbolListInputNames(SymbolHandle symbol, int NNSymbolListInputNames(SymbolHandle symbol,
int option, int option,
nn_uint *out_size, nn_uint *out_size,
......
...@@ -87,7 +87,8 @@ inline std::vector<std::string> GetKeys( ...@@ -87,7 +87,8 @@ inline std::vector<std::string> GetKeys(
// whether the symbol is atomic functor // whether the symbol is atomic functor
inline bool IsAtomic(const std::vector<NodeEntry>& outputs) { inline bool IsAtomic(const std::vector<NodeEntry>& outputs) {
return outputs[0].node->inputs.size() == 0; return outputs[0].node->inputs.size() == 0 &&
outputs[0].node->control_deps.size() == 0;
} }
// public functions // public functions
...@@ -118,7 +119,9 @@ Symbol Symbol::Copy() const { ...@@ -118,7 +119,9 @@ Symbol Symbol::Copy() const {
} }
void Symbol::Print(std::ostream &os) const { void Symbol::Print(std::ostream &os) const {
if (outputs.size() == 1 && outputs[0].node->inputs.size() == 0) { if (outputs.size() == 1 &&
outputs[0].node->inputs.size() == 0 &&
outputs[0].node->control_deps.size() == 0) {
if (outputs[0].node->is_variable()) { if (outputs[0].node->is_variable()) {
os << "Variable:" << outputs[0].node->attrs.name << '\n'; os << "Variable:" << outputs[0].node->attrs.name << '\n';
} else { } else {
......
...@@ -69,6 +69,7 @@ Graph Gradient(Graph src) { ...@@ -69,6 +69,7 @@ Graph Gradient(Graph src) {
// topo sort // topo sort
std::vector<NodePtr> topo_order; std::vector<NodePtr> topo_order;
std::unordered_map<Node*, std::vector<GradEntry> > output_grads; std::unordered_map<Node*, std::vector<GradEntry> > output_grads;
DFSVisit(ys, [&](const NodePtr& node) { DFSVisit(ys, [&](const NodePtr& node) {
if (output_grads.count(node.get()) == 0) { if (output_grads.count(node.get()) == 0) {
output_grads[node.get()].resize(node->num_outputs()); output_grads[node.get()].resize(node->num_outputs());
...@@ -113,6 +114,7 @@ Graph Gradient(Graph src) { ...@@ -113,6 +114,7 @@ Graph Gradient(Graph src) {
e.sum = agg_fun(std::move(e.grads)); e.sum = agg_fun(std::move(e.grads));
out_agg_grads.push_back(e.sum); out_agg_grads.push_back(e.sum);
} }
if ((*rit)->inputs.size() != 0) {
std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()] std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()]
(mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()), out_agg_grads); (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()), out_agg_grads);
CHECK_EQ((*rit)->inputs.size(), input_grads.size()) CHECK_EQ((*rit)->inputs.size(), input_grads.size())
...@@ -122,6 +124,7 @@ Graph Gradient(Graph src) { ...@@ -122,6 +124,7 @@ Graph Gradient(Graph src) {
output_grads[it->node.get()][it->index].grads.emplace_back(std::move(*git)); output_grads[it->node.get()][it->index].grads.emplace_back(std::move(*git));
} }
} }
}
// take out the xs' grads // take out the xs' grads
Graph ret; Graph ret;
ret.outputs.reserve(xs.size()); ret.outputs.reserve(xs.size());
......
...@@ -42,8 +42,8 @@ def test_list_args(): ...@@ -42,8 +42,8 @@ def test_list_args():
y = sym.add(y, z, name='add1') y = sym.add(y, z, name='add1')
# write after read # write after read
z = sym.assign(x, y, name='assign') z = sym.assign(x, y, name='assign')
assert z.list_inputs('read_only') == ['conv_weight', 'z'] assert z.list_input_names('read_only') == ['conv_weight', 'z']
assert z.list_inputs('aux_state') == ['x'] assert z.list_input_names('aux_state') == ['x']
def test_infer_shape(): def test_infer_shape():
x = sym.Variable('x', shape=(4, 2)) x = sym.Variable('x', shape=(4, 2))
......
...@@ -7,17 +7,19 @@ def test_compose(): ...@@ -7,17 +7,19 @@ def test_compose():
y = sym.exp(sym.add(x, x, name='add', gpu=2), y = sym.exp(sym.add(x, x, name='add', gpu=2),
name='exp', gpu=1, attr={"kk": "1"}) name='exp', gpu=1, attr={"kk": "1"})
assert y.list_inputs() == ['x'] assert y.list_input_names() == ['x']
assert y.list_outputs() == ["exp_output"] assert y.list_output_names() == ["exp_output"]
assert y.list_attr()['gpu'] == '1' assert y.list_attr()['gpu'] == '1'
z = y.get_internals() z = y.get_internals()
assert z['add_output'].list_outputs() == ['add_output'] assert z['add_output'].list_output_names() == ['add_output']
assert y.list_attr(recursive=True)['add_gpu'] == '2' assert y.list_attr(recursive=True)['add_gpu'] == '2'
def test_default_input(): def test_default_input():
x = sym.Variable('x') x = sym.Variable('x')
y = sym.conv2d(data=x, name='conv') y = sym.conv2d(data=x, name='conv')
assert y.list_inputs() == ['x', 'conv_weight'] assert y.list_input_names() == ['x', 'conv_weight']
tname = [z.list_output_names()[0] for z in y.list_input_variables()]
assert tname == y.list_input_names()
try: try:
z = sym.add(x) z = sym.add(x)
assert False assert False
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment