Commit 98613dfc by Tianqi Chen

Enable control deps in API (#55)

* [SYMBOL] support control deps in API

* enable more generic tuple list

* fix

* fix
parent 45da8718
......@@ -134,6 +134,13 @@ NNVM_DLL int NNSymbolCreateGroup(nn_uint num_symbols,
SymbolHandle *symbols,
SymbolHandle *out);
/*!
* \brief Add src_dep to the handle as control dep.
* \param handle The symbol to add dependency edges on.
* \param src_dep the source handles.
*/
NNVM_DLL int NNAddControlDeps(SymbolHandle handle,
SymbolHandle src_dep);
/*!
* \brief Free the symbol handle.
* \param symbol the symbol
* \return 0 when success, -1 when failure happens
......
......@@ -348,8 +348,10 @@ inline Op& Op::set_attr( // NOLINT(*)
std::vector<std::pair<ValueType, int> >& vec =
nnvm::get<OpMap<ValueType> >(*pmap).data_;
// resize the value type.
if (vec.size() <= index_) {
vec.resize(index_ + 1,
std::make_pair(ValueType(), 0));
}
std::pair<ValueType, int>& p = vec[index_];
CHECK(p.second == 0)
<< "Attribute " << attr_name
......
......@@ -227,7 +227,7 @@ class Tuple {
return is;
}
is.get();
if (ch == '(') break;
if (ch == '(' || ch == '[') break;
if (!isspace(ch)) {
is.setstate(std::ios::failbit);
return is;
......@@ -250,13 +250,13 @@ class Tuple {
if (isspace(ch)) {
is.get(); continue;
}
if (ch == ')') {
if (ch == ')' || ch == ']') {
is.get(); break;
}
break;
}
if (ch == ')') break;
} else if (ch == ')') {
if (ch == ')' || ch == ']') break;
} else if (ch == ')' || ch == ']') {
break;
} else {
is.setstate(std::ios::failbit);
......
......@@ -233,6 +233,22 @@ class Symbol(SymbolBase):
self.handle, _ctypes.byref(debug_str)))
return _base.py_str(debug_str.value)
def _add_control_deps(self, deps):
"""Add control flow dependencies.
This makes current op depend on the deps.
Only use when necessary,
this function mutate the current symbol node.
Returns
-------
deps : Symbol for list of symbol
The dependencies
"""
if isinstance(deps, list):
deps = Group(deps)
_check_call(_LIB.NNAddControlDeps(
self.handle, deps.handle))
def Variable(name, **kwargs):
"""Create a symbolic variable with specified name.
......
......@@ -40,6 +40,14 @@ int NNListUniqueOps(nn_uint *out_size,
API_END();
}
int NNAddControlDeps(SymbolHandle handle,
SymbolHandle src_dep) {
API_BEGIN();
static_cast<Symbol*>(handle)->AddControlDeps(
*static_cast<Symbol*>(src_dep));
API_END();
}
int NNGetOpInfo(OpHandle handle,
const char **name,
const char **description,
......
......@@ -76,10 +76,11 @@ def test_infer_shape_known_partial():
def test_infer_type():
x = sym.Variable('x')
x = sym.Variable('x', dtype=0)
y = sym.add(x, x, name='add1')
y = sym.cast(y, dtype=1, name="cast1")
g = graph.create(y)
g._set_json_attr("dtype_attr_key", "dtype")
g = g.apply('InferType')
jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
jnodes = jgraph['nodes']
......
......@@ -43,8 +43,16 @@ def test_copy():
name='exp', gpu=1, attr={"kk": "1"})
assert y.__copy__().debug_str() == y.debug_str()
def test_control_dep():
x = sym.Variable('x')
y = sym.conv2d(data=x, name='conv')
z = sym.assign(x, y)
t = sym.add(x, x)
t._add_control_deps([z, y])
if __name__ == "__main__":
test_copy()
test_default_input()
test_compose()
test_mutate_input()
test_control_dep()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment