Commit eea05d77 by Tianqi Chen

[SYMBOL] Change list_input->list_input_names, add list_input_variables (#59)

* [SYMBOL] Change list_input->list_input_names, add list_input_variables

* fix
parent e8fee6dc
......@@ -205,8 +205,25 @@ NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol,
int recursive_option,
nn_uint *out_size,
const char*** out);
/*!
* \brief List inputs variables in the symbol.
* \param symbol the symbol
* \param option The option to list the inputs
* option=0 means list all arguments.
* option=1 means list arguments that are readed only by the graph.
* option=2 means list arguments that are mutated by the graph.
* \param out_size output size
* \param out_sym_array the output array.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolListInputVariables(SymbolHandle symbol,
int option,
nn_uint *out_size,
SymbolHandle** out_sym_array);
/*!
* \brief List inputs in the symbol.
* \brief List input names in the symbol.
* \param symbol the symbol
* \param option The option to list the inputs
* option=0 means list all arguments.
......
......@@ -108,7 +108,7 @@ class Symbol(SymbolBase):
def __getitem__(self, index):
if isinstance(index, _base.string_types):
idx = None
for i, name in enumerate(self.list_outputs()):
for i, name in enumerate(self.list_output_names()):
if name == index:
if idx is not None:
raise ValueError('There are multiple outputs with name \"%s\"' % index)
......@@ -177,7 +177,40 @@ class Symbol(SymbolBase):
self.handle, _ctypes.byref(handle)))
return Symbol(handle=handle)
def list_inputs(self, option='all'):
def _get_list_copt(self, option):
"""internal function to get list option"""
if option == 'all':
return _ctypes.c_int(0)
elif option == 'read_only':
return _ctypes.c_int(1)
elif option == 'aux_state':
return _ctypes.c_int(2)
else:
raise ValueError("option need to be in {'all', 'read_only, 'aux_state'}")
def list_input_variables(self, option='all'):
"""List all the input variables in the symbol.
Parameters
----------
option : {'all', 'read_only', 'aux_state'}, optional
The listing option
- 'all' will list all the arguments.
- 'read_only' lists arguments that are readed by the graph.
- 'aux_state' lists arguments that are mutated by the graph as state.
Returns
-------
vars : list of symbol
List of all the variables
"""
size = _ctypes.c_uint()
sarr = _ctypes.POINTER(_base.SymbolHandle)()
_check_call(_LIB.NNSymbolListInputVariables(
self.handle, self._get_list_copt(option),
_ctypes.byref(size), _ctypes.byref(sarr)))
return [Symbol(_base.SymbolHandle(sarr[i])) for i in range(size.value)]
def list_input_names(self, option='all'):
"""List all the inputs in the symbol.
Parameters
......@@ -194,19 +227,12 @@ class Symbol(SymbolBase):
"""
size = _ctypes.c_uint()
sarr = _ctypes.POINTER(_ctypes.c_char_p)()
if option == 'all':
copt = _ctypes.c_int(0)
elif option == 'read_only':
copt = _ctypes.c_int(1)
elif option == 'aux_state':
copt = _ctypes.c_int(2)
else:
raise ValueError("option need to be in {'all', 'read_only, 'aux_state'}")
_check_call(_LIB.NNSymbolListInputNames(
self.handle, copt, _ctypes.byref(size), _ctypes.byref(sarr)))
self.handle, self._get_list_copt(option),
_ctypes.byref(size), _ctypes.byref(sarr)))
return [_base.py_str(sarr[i]) for i in range(size.value)]
def list_outputs(self):
def list_output_names(self):
"""List all outputs in the symbol.
Returns
......
......@@ -221,6 +221,25 @@ int NNSymbolListAttrs(SymbolHandle symbol,
API_END();
}
int NNSymbolListInputVariables(SymbolHandle symbol,
int option,
nn_uint *out_size,
SymbolHandle** out_sym_array) {
Symbol *s = static_cast<Symbol*>(symbol);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
std::vector<NodePtr> vs = s->ListInputs(Symbol::ListInputOption(option));
ret->ret_handles.clear();
for (size_t i = 0; i < vs.size(); ++i) {
nnvm::Symbol* rs = new nnvm::Symbol();
rs->outputs.push_back(NodeEntry{vs[i], 0, 0});
ret->ret_handles.push_back(rs);
}
*out_size = static_cast<nn_uint>(vs.size());
*out_sym_array = dmlc::BeginPtr(ret->ret_handles);
API_END();
}
int NNSymbolListInputNames(SymbolHandle symbol,
int option,
nn_uint *out_size,
......
......@@ -87,7 +87,8 @@ inline std::vector<std::string> GetKeys(
// whether the symbol is atomic functor
inline bool IsAtomic(const std::vector<NodeEntry>& outputs) {
return outputs[0].node->inputs.size() == 0;
return outputs[0].node->inputs.size() == 0 &&
outputs[0].node->control_deps.size() == 0;
}
// public functions
......@@ -118,7 +119,9 @@ Symbol Symbol::Copy() const {
}
void Symbol::Print(std::ostream &os) const {
if (outputs.size() == 1 && outputs[0].node->inputs.size() == 0) {
if (outputs.size() == 1 &&
outputs[0].node->inputs.size() == 0 &&
outputs[0].node->control_deps.size() == 0) {
if (outputs[0].node->is_variable()) {
os << "Variable:" << outputs[0].node->attrs.name << '\n';
} else {
......
......@@ -69,6 +69,7 @@ Graph Gradient(Graph src) {
// topo sort
std::vector<NodePtr> topo_order;
std::unordered_map<Node*, std::vector<GradEntry> > output_grads;
DFSVisit(ys, [&](const NodePtr& node) {
if (output_grads.count(node.get()) == 0) {
output_grads[node.get()].resize(node->num_outputs());
......@@ -113,13 +114,15 @@ Graph Gradient(Graph src) {
e.sum = agg_fun(std::move(e.grads));
out_agg_grads.push_back(e.sum);
}
std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()]
(mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()), out_agg_grads);
CHECK_EQ((*rit)->inputs.size(), input_grads.size())
<< "Gradient function not returning enough gradient";
auto git = input_grads.begin();
for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
output_grads[it->node.get()][it->index].grads.emplace_back(std::move(*git));
if ((*rit)->inputs.size() != 0) {
std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()]
(mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()), out_agg_grads);
CHECK_EQ((*rit)->inputs.size(), input_grads.size())
<< "Gradient function not returning enough gradient";
auto git = input_grads.begin();
for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
output_grads[it->node.get()][it->index].grads.emplace_back(std::move(*git));
}
}
}
// take out the xs' grads
......
......@@ -42,8 +42,8 @@ def test_list_args():
y = sym.add(y, z, name='add1')
# write after read
z = sym.assign(x, y, name='assign')
assert z.list_inputs('read_only') == ['conv_weight', 'z']
assert z.list_inputs('aux_state') == ['x']
assert z.list_input_names('read_only') == ['conv_weight', 'z']
assert z.list_input_names('aux_state') == ['x']
def test_infer_shape():
x = sym.Variable('x', shape=(4, 2))
......
......@@ -7,17 +7,19 @@ def test_compose():
y = sym.exp(sym.add(x, x, name='add', gpu=2),
name='exp', gpu=1, attr={"kk": "1"})
assert y.list_inputs() == ['x']
assert y.list_outputs() == ["exp_output"]
assert y.list_input_names() == ['x']
assert y.list_output_names() == ["exp_output"]
assert y.list_attr()['gpu'] == '1'
z = y.get_internals()
assert z['add_output'].list_outputs() == ['add_output']
assert z['add_output'].list_output_names() == ['add_output']
assert y.list_attr(recursive=True)['add_gpu'] == '2'
def test_default_input():
x = sym.Variable('x')
y = sym.conv2d(data=x, name='conv')
assert y.list_inputs() == ['x', 'conv_weight']
assert y.list_input_names() == ['x', 'conv_weight']
tname = [z.list_output_names()[0] for z in y.list_input_variables()]
assert tname == y.list_input_names()
try:
z = sym.add(x)
assert False
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment