Commit 98613dfc by Tianqi Chen

Enable control deps in API (#55)

* [SYMBOL] support control deps in API

* enable more generic tuple list

* fix

* fix
parent 45da8718
...@@ -134,6 +134,13 @@ NNVM_DLL int NNSymbolCreateGroup(nn_uint num_symbols, ...@@ -134,6 +134,13 @@ NNVM_DLL int NNSymbolCreateGroup(nn_uint num_symbols,
SymbolHandle *symbols, SymbolHandle *symbols,
SymbolHandle *out); SymbolHandle *out);
/*! /*!
* \brief Add src_dep to the handle as control dep.
* \param handle The symbol to add dependency edges on.
* \param src_dep the source handles.
*/
NNVM_DLL int NNAddControlDeps(SymbolHandle handle,
SymbolHandle src_dep);
/*!
* \brief Free the symbol handle. * \brief Free the symbol handle.
* \param symbol the symbol * \param symbol the symbol
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
......
...@@ -348,8 +348,10 @@ inline Op& Op::set_attr( // NOLINT(*) ...@@ -348,8 +348,10 @@ inline Op& Op::set_attr( // NOLINT(*)
std::vector<std::pair<ValueType, int> >& vec = std::vector<std::pair<ValueType, int> >& vec =
nnvm::get<OpMap<ValueType> >(*pmap).data_; nnvm::get<OpMap<ValueType> >(*pmap).data_;
// resize the value type. // resize the value type.
if (vec.size() <= index_) {
vec.resize(index_ + 1, vec.resize(index_ + 1,
std::make_pair(ValueType(), 0)); std::make_pair(ValueType(), 0));
}
std::pair<ValueType, int>& p = vec[index_]; std::pair<ValueType, int>& p = vec[index_];
CHECK(p.second == 0) CHECK(p.second == 0)
<< "Attribute " << attr_name << "Attribute " << attr_name
......
...@@ -227,7 +227,7 @@ class Tuple { ...@@ -227,7 +227,7 @@ class Tuple {
return is; return is;
} }
is.get(); is.get();
if (ch == '(') break; if (ch == '(' || ch == '[') break;
if (!isspace(ch)) { if (!isspace(ch)) {
is.setstate(std::ios::failbit); is.setstate(std::ios::failbit);
return is; return is;
...@@ -250,13 +250,13 @@ class Tuple { ...@@ -250,13 +250,13 @@ class Tuple {
if (isspace(ch)) { if (isspace(ch)) {
is.get(); continue; is.get(); continue;
} }
if (ch == ')') { if (ch == ')' || ch == ']') {
is.get(); break; is.get(); break;
} }
break; break;
} }
if (ch == ')') break; if (ch == ')' || ch == ']') break;
} else if (ch == ')') { } else if (ch == ')' || ch == ']') {
break; break;
} else { } else {
is.setstate(std::ios::failbit); is.setstate(std::ios::failbit);
......
...@@ -233,6 +233,22 @@ class Symbol(SymbolBase): ...@@ -233,6 +233,22 @@ class Symbol(SymbolBase):
self.handle, _ctypes.byref(debug_str))) self.handle, _ctypes.byref(debug_str)))
return _base.py_str(debug_str.value) return _base.py_str(debug_str.value)
def _add_control_deps(self, deps):
"""Add control flow dependencies.
This makes current op depend on the deps.
Only use when necessary,
this function mutate the current symbol node.
Returns
-------
deps : Symbol for list of symbol
The dependencies
"""
if isinstance(deps, list):
deps = Group(deps)
_check_call(_LIB.NNAddControlDeps(
self.handle, deps.handle))
def Variable(name, **kwargs): def Variable(name, **kwargs):
"""Create a symbolic variable with specified name. """Create a symbolic variable with specified name.
......
...@@ -40,6 +40,14 @@ int NNListUniqueOps(nn_uint *out_size, ...@@ -40,6 +40,14 @@ int NNListUniqueOps(nn_uint *out_size,
API_END(); API_END();
} }
int NNAddControlDeps(SymbolHandle handle,
SymbolHandle src_dep) {
API_BEGIN();
static_cast<Symbol*>(handle)->AddControlDeps(
*static_cast<Symbol*>(src_dep));
API_END();
}
int NNGetOpInfo(OpHandle handle, int NNGetOpInfo(OpHandle handle,
const char **name, const char **name,
const char **description, const char **description,
......
...@@ -76,10 +76,11 @@ def test_infer_shape_known_partial(): ...@@ -76,10 +76,11 @@ def test_infer_shape_known_partial():
def test_infer_type(): def test_infer_type():
x = sym.Variable('x') x = sym.Variable('x', dtype=0)
y = sym.add(x, x, name='add1') y = sym.add(x, x, name='add1')
y = sym.cast(y, dtype=1, name="cast1") y = sym.cast(y, dtype=1, name="cast1")
g = graph.create(y) g = graph.create(y)
g._set_json_attr("dtype_attr_key", "dtype")
g = g.apply('InferType') g = g.apply('InferType')
jgraph = json.loads(g.apply('SaveJSON').json_attr('json')) jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
jnodes = jgraph['nodes'] jnodes = jgraph['nodes']
......
...@@ -43,8 +43,16 @@ def test_copy(): ...@@ -43,8 +43,16 @@ def test_copy():
name='exp', gpu=1, attr={"kk": "1"}) name='exp', gpu=1, attr={"kk": "1"})
assert y.__copy__().debug_str() == y.debug_str() assert y.__copy__().debug_str() == y.debug_str()
def test_control_dep():
x = sym.Variable('x')
y = sym.conv2d(data=x, name='conv')
z = sym.assign(x, y)
t = sym.add(x, x)
t._add_control_deps([z, y])
if __name__ == "__main__": if __name__ == "__main__":
test_copy() test_copy()
test_default_input() test_default_input()
test_compose() test_compose()
test_mutate_input() test_mutate_input()
test_control_dep()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment