Commit 16a6db3a by Tianqi Chen

Update Symbol and C API (#22)

* Update tuple to be compatible with mshadow

* Move set error message to C API

* simplify with using

* updates to shape inference

* Add unnamed namespace to the implementations

* [SYMBOL] Enable inference of Auxiliary data, rename list_arguments to list_inputs
parent 00833e0f
# NNVM: Build deep learning system by parts
-NNVM is not a deep learning library. It is a modular, decentralized and lightweight library to
-help build deep learning libraries efficiently.
+NNVM is not a deep learning library. It is a modular, decentralized and lightweight part to
+help build deep learning libraries.
## What is it
@@ -9,14 +9,14 @@ While most deep learning systems offer end to end solutions,
it is interesting to ask if we can actually assemble a deep learning system by parts.
The goal is to enable hackers to customize optimizations, target platforms and the set of operators they care about.
We believe that the decentralized modular system is an interesting direction.
The hope is that effective parts can be assembled together, just like you assemble your own desktop.
So the customized deep learning solution can be minimax: minimum in terms of dependencies,
while maximizing the users' needs.
-NNVM offers one such part, it provides a generic to do generic
-computation graph optimization such as memory reduction, device allocation,
-operator fusion while being agnostic to the operator
-interface defintion and how operators are executed.
+NNVM offers one such part: it provides a generic way to do
+computation graph optimization such as memory reduction, device allocation and more,
+while being agnostic to the operator interface definition and how operators are executed.
NNVM is inspired by LLVM, aiming to be an intermediate representation library
for neural nets and computation graph generation and optimization.
......
@@ -16,37 +16,13 @@
namespace nnvm {
/*! \brief any type */
-using any = dmlc::any;
-/*!
- * \brief array_veiw type
- * \tparam ValueType The value content of array view.
- */
-template<typename ValueType>
-using array_view = dmlc::array_view<ValueType>;
-/*!
- * \brief get reference of type T stored in src.
- * \param src The source container
- * \return the reference to the type.
- * \tparam T The type to be fetched.
- */
-template<typename T>
-inline T& get(any& src) { // NOLINT(*)
-return dmlc::get<T>(src);
-}
-/*!
- * \brief get const reference of type T stored in src.
- * \param src The source container
- * \return the reference to the type.
- * \tparam T The type to be fetched.
- */
-template<typename T>
-inline const T& get(const any& src) {
-return dmlc::get<T>(src);
-}
+using dmlc::any;
+/*! \brief array_view type */
+using dmlc::array_view;
+/*! \brief getter function of any type */
+using dmlc::get;
} // namespace nnvm
......
@@ -36,6 +36,12 @@ typedef void *SymbolHandle;
typedef void *GraphHandle;
/*!
* \brief Set the last error message needed by C API
* \param msg The error message to set.
*/
NNVM_DLL void NNAPISetLastError(const char* msg);
/*!
* \brief return str message of the last error
* all functions in this file will return 0 when success
* and -1 when an error occurred,
@@ -171,25 +177,30 @@ NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol,
nn_uint *out_size,
const char*** out);
/*!
- * \brief List arguments in the symbol.
+ * \brief List inputs in the symbol.
* \param symbol the symbol
+ * \param option The option to list the inputs
+ *   option=0 means list all arguments.
+ *   option=1 means list arguments that are only read by the graph.
+ *   option=2 means list arguments that are mutated by the graph.
* \param out_size output size
* \param out_str_array pointer to hold the output string array
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNSymbolListArguments(SymbolHandle symbol,
-nn_uint *out_size,
-const char ***out_str_array);
+NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol,
+int option,
+nn_uint *out_size,
+const char ***out_str_array);
/*!
- * \brief List returns in the symbol.
+ * \brief List return names in the symbol.
* \param symbol the symbol
* \param out_size output size
* \param out_str_array pointer to hold the output string array
* \return 0 when success, -1 when failure happens
*/
-NNVM_DLL int NNSymbolListOutputs(SymbolHandle symbol,
+NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol,
nn_uint *out_size,
const char ***out_str_array);
/*!
* \brief Get a symbol that contains all the internals.
* \param symbol The symbol
......
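The renamed entry points are easiest to see from the caller's side. Below is a minimal sketch, assuming an already-constructed `SymbolHandle` (obtained elsewhere) and the existing `NNGetLastError` query; it lists the read-only inputs of a symbol and follows the 0 / -1 return convention described above. The helper name `PrintReadOnlyInputs` is illustrative only.

```cpp
#include <cstdio>
#include <nnvm/c_api.h>

int PrintReadOnlyInputs(SymbolHandle sym) {
  nn_uint size = 0;
  const char **names = nullptr;
  // option = 1 -> read-only arguments, matching the documentation above.
  if (NNSymbolListInputNames(sym, 1, &size, &names) != 0) {
    // every call returns -1 on failure; the message is fetched separately.
    std::fprintf(stderr, "error: %s\n", NNGetLastError());
    return -1;
  }
  for (nn_uint i = 0; i < size; ++i) {
    std::printf("input[%u] = %s\n", i, names[i]);
  }
  return 0;
}
```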
@@ -289,7 +289,9 @@ template<typename ValueType>
inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
const any* ref = GetAttrMap(key);
if (ref == nullptr) {
// update the attribute map of the key by creating a new empty OpMap
UpdateAttrMap(key, [key](any* pmap) {
// use a callback so that the update happens inside the lock scope
if (pmap->empty()) {
OpMap<ValueType> pm;
pm.attr_name_ = key;
@@ -304,7 +306,9 @@ inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
template<typename ValueType>
inline Op& Op::attr(  // NOLINT(*)
const std::string& attr_name, const ValueType& value) {
// update the attribute map of the key, creating a new empty map if needed.
UpdateAttrMap(attr_name, [this, attr_name, value](any* pmap) {
// the callback runs inside the lock scope, so it is thread safe.
if (pmap->empty()) {
OpMap<ValueType> pm;
pm.attr_name_ = attr_name;
......
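For context on the locked `UpdateAttrMap` path above, here is a minimal sketch of how operator attributes are normally registered and then queried through `Op::GetAttr`. The operator name `add` and the `FEstimateFlops` attribute are invented for illustration; only `NNVM_REGISTER_OP`, `attr<T>()` and `Op::GetAttr<T>()` come from the code above.

```cpp
#include <nnvm/op.h>

using nnvm::Op;

// a toy attribute type: estimated flops per output element
using FEstimateFlops = double;

NNVM_REGISTER_OP(add)
.describe("elementwise add two tensors")
.set_num_inputs(2)
.attr<FEstimateFlops>("FEstimateFlops", 1.0);

double QueryFlops(const Op* op) {
  // GetAttr lazily creates the OpMap for this key under a lock (see above),
  // so caching the static reference is safe.
  static auto& flops = Op::GetAttr<FEstimateFlops>("FEstimateFlops");
  return flops.get(op, 0.0);
}
```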
@@ -83,10 +83,10 @@ inline Graph InferType(Graph graph,
DTypeVector type_args = {},
std::string type_attr_key = "") {
if (type_args.size() != 0) {
-graph.attrs["type_args"] = std::make_shared<any>(std::move(type_args));
+graph.attrs["dtype_args"] = std::make_shared<any>(std::move(type_args));
}
if (type_attr_key.length() != 0) {
-graph.attrs["type_attr_key"] = std::make_shared<any>(std::move(type_attr_key));
+graph.attrs["dtype_attr_key"] = std::make_shared<any>(std::move(type_attr_key));
}
return ApplyPass(std::move(graph), {"InferType"});
}
......
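A minimal sketch of how these helpers are driven from the caller's side, assuming `InferShape` mirrors the `InferType` helper shown above: the caller supplies the leading input shapes/types (or an attribute key), and the passes write their results back onto the graph as the `shape` and `dtype` attributes. The concrete shapes and dtype value below are placeholders.

```cpp
#include <nnvm/graph.h>
#include <nnvm/pass_functions.h>
#include <nnvm/tuple.h>

nnvm::Graph InferBoth(nnvm::Graph g) {
  using nnvm::TShape;
  // shapes/types for the first graph inputs; "" means no per-node attr key
  g = nnvm::pass::InferShape(std::move(g), {TShape{2, 3}}, "");
  g = nnvm::pass::InferType(std::move(g), {0 /* e.g. float32 */}, "");
  // results are stored back on the graph as the "shape"/"dtype" attributes.
  return g;
}
```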
@@ -30,6 +30,18 @@ class Symbol {
/*! \brief only list attributes in current node */
kShallow = 1
};
/*! \brief option passed to ListInputNames */
enum ListInputOption {
/*! \brief list all the arguments */
kAll = 0,
/*! \brief list only read only arguments */
kReadOnlyArgs = 1,
/*!
* \brief List auxiliary states that can be mutated by the graph.
* This excludes the ReadOnly arguments
*/
kAuxiliaryStates = 2
};
/*! \brief output entries contained in the symbol */
std::vector<NodeEntry> outputs;
@@ -51,18 +63,20 @@ class Symbol {
*/
Symbol operator[] (size_t index) const;
/*!
- * \brief List the arguments names.
+ * \brief List the input names.
+ * \param option The option for listing the inputs.
*
* The position of the returned list also corresponds to calling position in operator()
* \return the arguments list of this symbol, they can be either named or unnamed (empty string).
+ * \sa ListInputOption
*/
-std::vector<std::string> ListArguments() const;
+std::vector<std::string> ListInputNames(ListInputOption option) const;
/*!
* \brief List the names of outputs for this symbol.
* For normal operators, it is usually symbol node name + "_output"
* \return get the descriptions of outputs for this symbol.
*/
-std::vector<std::string> ListOutputs() const;
+std::vector<std::string> ListOutputNames() const;
/*!
* \brief Compose the symbol with arguments, this changes the current symbol.
* The kwargs passed in can be incomplete,
......
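To make the new interface concrete, a short usage sketch of `ListInputNames` / `ListOutputNames`; the symbol `net` is assumed to be an already-composed `Symbol`, and the printing is purely illustrative.

```cpp
#include <iostream>
#include <nnvm/symbol.h>

void DumpInterface(const nnvm::Symbol& net) {
  using nnvm::Symbol;
  // all variable nodes feeding the graph, in operator() calling order
  for (const auto& name : net.ListInputNames(Symbol::kAll))
    std::cout << "input:  " << name << '\n';
  // the same set split into read-only arguments and mutable auxiliary states
  for (const auto& name : net.ListInputNames(Symbol::kReadOnlyArgs))
    std::cout << "arg:    " << name << '\n';
  for (const auto& name : net.ListInputNames(Symbol::kAuxiliaryStates))
    std::cout << "aux:    " << name << '\n';
  for (const auto& name : net.ListOutputNames())
    std::cout << "output: " << name << '\n';
}
```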
@@ -58,17 +58,9 @@ class Tuple {
* \brief move constructor from Tuple
* \param src the source shape
*/
-inline Tuple(Tuple<ValueType>&& src) {
-this->swap(src);
-}
-/*!
- * \param ndim the number of dimension of the Tuple
- * \param v The value to fill.
- */
-inline Tuple(index_t ndim, ValueType v) {
-this->SetDim(ndim);
-std::fill_n(begin(), ndim, v);
-}
+inline Tuple(Tuple<ValueType>&& src) {  // NOLINT(*)
+this->swap(src);
+}
/*!
* \brief construct the Tuple from content of iterator
@@ -97,7 +89,7 @@ class Tuple {
* \brief Swap current object with other
* \param other another object to be swapped.
*/
-inline void swap(Tuple<ValueType>& other) noexcept {  // NOLINT(*)
+inline void swap(Tuple<ValueType>& other) {  // NOLINT(*)
std::swap(ndim_, other.ndim_);
std::swap(num_heap_allocated_, other.num_heap_allocated_);
std::swap(data_stack_, other.data_stack_);
@@ -275,7 +267,7 @@ class Tuple {
return is;
}
-private:
+protected:
// stack cache size
static const uint32_t kStackCache = 4;
/*! \brief number of dimension of the tuple */
@@ -303,16 +295,30 @@
*/
class TShape : public Tuple<index_t> {
public:
// inherit other constructors from Tuple
using Tuple<index_t>::Tuple;
/*! \brief default constructor */
TShape() = default;
/*!
* constructor to construct a shape with all 1.
* \param ndim the number of dimension
*/
inline TShape(index_t ndim) { // NOLINT(*)
this->SetDim(ndim);
std::fill_n(begin(), ndim, 1);
}
/*!
* \brief copy constructor of TShape
* \param s source shape.
*/
-inline TShape(const Tuple<index_t>& s)  // NOLINT(*)
-: Tuple<index_t>(s) {}
+inline TShape(const Tuple<index_t>& s) {  // NOLINT(*)
+this->assign(s.begin(), s.end());
+}
/*!
* \brief constructor from initializer list
* \param init the initializer_list
*/
inline TShape(std::initializer_list<index_t> init) {
this->assign(init.begin(), init.end());
}
/*!
* \brief move constructor.
* \param s source shape.
@@ -321,6 +327,17 @@ class TShape : public Tuple<index_t> {
this->swap(s);
}
/*!
* \brief construct the Tuple from content of iterator
* \param begin the beginning of iterator
* \param end end the end of the iterator
* \tparam RandomAccessIterator iterator type
*/
template<typename RandomAccessIterator>
inline TShape(RandomAccessIterator begin,
RandomAccessIterator end) {
this->assign(begin, end);
}
/*!
* \brief assignment function from tshape
* \param src source shape.
* \return self.
@@ -347,6 +364,164 @@ class TShape : public Tuple<index_t> {
}
return size;
}
/*!
* \return product shape in [dimstart,dimend)
* \param dimstart start dimension
* \param dimend end dimension
*/
inline index_t ProdShape(int dimstart, int dimend) const {
index_t num = 1;
const index_t *d = this->data();
for (int i = dimstart; i < dimend; ++i) {
num *= d[i];
}
return num;
}
/*! \return the begin data pointer to content of the tuple */
inline const index_t *data() const {
return begin();
}
/*! \return the begin data pointer to content of the tuple */
inline index_t *data() {
return begin();
}
#ifdef MSHADOW_XINLINE
template<int dim>
inline TShape(mshadow::Shape<dim> &&s) {// NOLINT(*)
this->assign(s.shape_, s.shape_ + dim);
}
/*!
* \brief assignment from shape
* \param shape source shape
* \tparam dim shape dimension
* \return reference of self
*/
template<int dim>
inline TShape &operator=(const mshadow::Shape<dim> &shape) {
this->assign(shape.shape_, shape.shape_ + dim);
return *this;
}
/*!
* \brief get the shape of tensor specifying dim
* \return the shape requested
* \tparam dim dimension of the tensor
*/
template<int dim>
inline mshadow::Shape<dim> get() const {
CHECK_EQ(dim, ndim())
<< "dimension do not match target dimension " << dim << " vs " << ndim();
const index_t *d = this->data();
mshadow::Shape<dim> s;
for (int i = 0; i < dim; ++i) {
s[i] = d[i];
}
return s;
}
/*!
* flatten the higher dimension to second dimension, return a 2D shape
* \return the flat 2d shape
*/
inline mshadow::Shape<2> FlatTo2D(void) const {
mshadow::Shape<2> s;
if (ndim() == 0) return mshadow::Shape2(0, 0);
const index_t *d = this->data();
s.shape_[1] = d[ndim() - 1];
index_t ymax = 1;
for (index_t i = 1; i < ndim(); ++i) {
ymax *= d[i - 1];
}
s.shape_[0] = ymax;
return s;
}
/*!
* flatten the shape into three parts: [0, axis_begin), [axis_begin, axis_end], (axis_end, ndim)
* \param axis_begin The beginning axis specified.
* \param axis_end The ending axis specified.
* \return the flat 3d shape
*/
inline mshadow::Shape<3> FlatTo3D(index_t axis_begin, index_t axis_end) const {
CHECK(axis_end >= axis_begin);
mshadow::Shape<3> s;
if (ndim() == 0) return mshadow::Shape3(0, 0, 0);
const index_t *d = this->data();
s.shape_[0] = 1;
s.shape_[1] = 1;
s.shape_[2] = 1;
for (index_t i = 0; i < axis_begin; ++i) {
s.shape_[0] *= d[i];
}
for (index_t i = axis_begin; i <= axis_end; ++i) {
s.shape_[1] *= d[i];
}
for (index_t i = axis_end + 1; i < ndim(); ++i) {
s.shape_[2] *= d[i];
}
return s;
}
/*!
* flatten the axis before and after the specified axis, so it becomes 3D tensor
* \param axis The axis specified.
* \return the flat 3d shape
*/
inline mshadow::Shape<3> FlatTo3D(index_t axis) const {
return FlatTo3D(axis, axis);
}
inline bool operator==(const TShape &s) const {
if (ndim() != s.ndim()) return false;
return std::equal(begin(), end(), s.begin());
}
inline bool operator!=(const TShape &s) const {
return !(*this == s);
}
/*!
* \return whether two shape equals
* \param s the shape to compare against
* \tparam dim dimension of the shape
*/
template<int dim>
inline bool operator==(const mshadow::Shape<dim> &s) const {
if (ndim_ != dim) return false;
const index_t *d = dim <= kStackCache ? data_stack_ : data_heap_;
for (index_t i = 0; i < dim; ++i) {
if (d[i] != s.shape_[i]) return false;
}
return true;
}
/*!
* \return whether two shape not equals
* \param s the shape to compare against
* \tparam dim dimension of the shape
*/
template<int dim>
inline bool operator!=(const mshadow::Shape<dim> &s) const {
return !(*this == s);
}
/*!
* \brief save the content into binary stream
* \param strm the output stream
* \tparam TStream any stream type that have write
*/
template<typename TStream>
inline void Save(TStream *strm) const {
strm->Write(&ndim_, sizeof(ndim_));
strm->Write(data(), sizeof(index_t) * ndim_);
}
/*!
* \brief load the content from binary stream
* \param strm the output stream
* \tparam TStream any stream type that have write
* \return whether the load is successful
*/
template<typename TStream>
inline bool Load(TStream *strm) {
if (strm->Read(&ndim_, sizeof(ndim_)) != sizeof(ndim_)) return false;
this->SetDim(ndim_);
size_t nread = sizeof(index_t) * ndim_;
if (strm->Read(data(), nread) != nread) return false;
return true;
}
#endif
};
} // namespace nnvm
......
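A small, purely illustrative sketch exercising the TShape additions above: initializer-list construction, the all-ones constructor, `ProdShape`, and `FlatTo2D` when mshadow is enabled.

```cpp
#include <cassert>
#include <nnvm/tuple.h>

void TShapeDemo() {
  nnvm::TShape s{2, 3, 4};            // initializer_list constructor
  assert(s.ndim() == 3);
  assert(s.Size() == 24);             // product of all dimensions
  assert(s.ProdShape(1, 3) == 12);    // product over dimensions [1, 3)
  nnvm::TShape ones(2);               // 2-d shape filled with 1
#ifdef MSHADOW_XINLINE
  mshadow::Shape<2> flat = s.FlatTo2D();  // collapses to (2*3, 4)
  assert(flat[0] == 6 && flat[1] == 4);
#endif
  (void)ones;
}
```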
@@ -176,9 +176,16 @@ class Symbol(SymbolBase):
self.handle, _ctypes.byref(handle)))
return Symbol(handle=handle)
-def list_arguments(self):
-"""List all the arguments in the symbol.
+def list_inputs(self, option='all'):
+"""List all the inputs in the symbol.
+Parameters
+----------
+option : {'all', 'read_only', 'aux_state'}, optional
+The listing option
+- 'all' will list all the arguments.
+- 'read_only' lists arguments that are only read by the graph.
+- 'aux_state' lists arguments that are mutated by the graph as state.
Returns
-------
args : list of string
@@ -186,8 +193,16 @@ class Symbol(SymbolBase):
"""
size = _ctypes.c_uint()
sarr = _ctypes.POINTER(_ctypes.c_char_p)()
-_check_call(_LIB.NNSymbolListArguments(
-self.handle, _ctypes.byref(size), _ctypes.byref(sarr)))
+if option == 'all':
+copt = _ctypes.c_int(0)
+elif option == 'read_only':
+copt = _ctypes.c_int(1)
+elif option == 'aux_state':
+copt = _ctypes.c_int(2)
+else:
+raise ValueError("option needs to be in {'all', 'read_only', 'aux_state'}")
+_check_call(_LIB.NNSymbolListInputNames(
+self.handle, copt, _ctypes.byref(size), _ctypes.byref(sarr)))
return [_base.py_str(sarr[i]) for i in range(size.value)]
def list_outputs(self):
@@ -200,7 +215,7 @@ class Symbol(SymbolBase):
"""
size = _ctypes.c_uint()
sarr = _ctypes.POINTER(_ctypes.c_char_p)()
-_check_call(_LIB.NNSymbolListOutputs(
+_check_call(_LIB.NNSymbolListOutputNames(
self.handle, _ctypes.byref(size), _ctypes.byref(sarr)))
return [_base.py_str(sarr[i]) for i in range(size.value)]
......
@@ -45,11 +45,6 @@ struct NNAPIThreadLocalEntry {
typedef dmlc::ThreadLocalStore<NNAPIThreadLocalEntry> NNAPIThreadLocalStore;
/*!
- * \brief Set the last error message needed by C API
- * \param msg The error message to set.
- */
-void NNAPISetLastError(const char* msg);
-/*!
* \brief handle the exception thrown out
* \param e the exception
* \return the return value of API after exception is handled
......
@@ -19,7 +19,6 @@ int NNSymbolListAtomicSymbolCreators(nn_uint *out_size,
API_END();
}
int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
const char **name,
const char **description,
@@ -37,7 +36,6 @@ int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
API_END();
}
int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
nn_uint num_param,
const char **keys,
@@ -179,13 +177,15 @@ int NNSymbolListAttrs(SymbolHandle symbol,
API_END();
}
-int NNSymbolListArguments(SymbolHandle symbol,
-nn_uint *out_size,
-const char ***out_str_array) {
+int NNSymbolListInputNames(SymbolHandle symbol,
+int option,
+nn_uint *out_size,
+const char ***out_str_array) {
Symbol *s = static_cast<Symbol*>(symbol);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
-ret->ret_vec_str = std::move(s->ListArguments());
+ret->ret_vec_str = std::move(
+s->ListInputNames(Symbol::ListInputOption(option)));
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
@@ -195,13 +195,13 @@ int NNSymbolListArguments(SymbolHandle symbol,
API_END();
}
-int NNSymbolListOutputs(SymbolHandle symbol,
+int NNSymbolListOutputNames(SymbolHandle symbol,
nn_uint *out_size,
const char ***out_str_array) {
Symbol *s = static_cast<Symbol*>(symbol);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
-ret->ret_vec_str = std::move(s->ListOutputs());
+ret->ret_vec_str = std::move(s->ListOutputNames());
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
@@ -221,6 +221,7 @@ int NNSymbolCompose(SymbolHandle sym,
std::string& s_name = ret->ret_str;
std::unordered_map<std::string, const Symbol*>& kwargs
= ret->kwarg_symbol;
kwargs.clear();
if (name != nullptr) {
s_name = name;
} else {
......
@@ -48,6 +48,7 @@ Graph ApplyPass(Graph g,
}
g = r->body(std::move(g));
}
return g;
}
......
@@ -181,17 +181,41 @@ Symbol Symbol::operator[] (size_t index) const {
}
}
-std::vector<std::string> Symbol::ListArguments() const {
+std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
std::vector<std::string> ret;
-DFSVisit(this->outputs, [&ret](const NodePtr &node) {
-if (node->is_variable()) {
-ret.push_back(node->attrs.name);
-}
-});
+if (option == kAll) {
+DFSVisit(this->outputs, [&ret](const NodePtr &node) {
+if (node->is_variable()) {
+ret.push_back(node->attrs.name);
+}
+});
+} else {
+std::unordered_set<Node*> mutable_set;
+std::vector<Node*> vlist;
+static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
+DFSVisit(this->outputs, [&ret, &mutable_set, &vlist](const NodePtr &node) {
+if (node->is_variable()) {
+vlist.push_back(node.get());
+} else if (fmutate_inputs.count(node->op)) {
+FMutateInput fmutate = fmutate_inputs[node->op];
+for (uint32_t i = 0; i < node->inputs.size(); ++i) {
+if (fmutate(node->attrs, i)) {
+mutable_set.insert(node->inputs[i].node.get());
+}
+}
+}
+});
+for (Node* node : vlist) {
+if ((option == kReadOnlyArgs && mutable_set.count(node) == 0) ||
+(option == kAuxiliaryStates && mutable_set.count(node) != 0)) {
+ret.push_back(node->attrs.name);
+}
+}
+}
return ret;
}
-std::vector<std::string> Symbol::ListOutputs() const {
+std::vector<std::string> Symbol::ListOutputNames() const {
static auto& flist_ouputs = Op::GetAttr<FListOutputNames>("FListOutputNames");
std::vector<std::string> ret;
for (auto &head : outputs) {
@@ -345,10 +369,10 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
}
} else {
std::vector<std::string> keys = GetKeys(kwargs);
-std::vector<std::string> arg_names = ListArguments();
+std::vector<std::string> arg_names = ListInputNames(kAll);
array_view<std::string> view(dmlc::BeginPtr(arg_names) + arg_counter,
dmlc::BeginPtr(arg_names) + arg_names.size());
-KeywordArgumentMismatch("Symbol.Compose", keys, ListArguments());
+KeywordArgumentMismatch("Symbol.Compose", keys, arg_names);
}
}
}
......
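The auxiliary-state detection above hinges on operators exposing an `FMutateInput` attribute. Below is a hedged sketch of such a registration: the `assign` operator and its "input 0 is written in place" convention are taken from the tests later in this commit, while the `FMutateInput` signature is inferred from the call site `fmutate(node->attrs, i)` and its header location (`op_attr_types.h`) is an assumption.

```cpp
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>

NNVM_REGISTER_OP(assign)
.set_num_inputs(2)
.attr<nnvm::FMutateInput>(
    "FMutateInput",
    [](const nnvm::NodeAttrs& attrs, uint32_t index) {
      return index == 0;  // only the first input is mutated as state
    });
```

With this attribute present, `ListInputNames(kAuxiliaryStates)` reports the variable feeding input 0 of `assign`, and `kReadOnlyArgs` excludes it.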
@@ -9,6 +9,7 @@
namespace nnvm {
namespace pass {
namespace {
template<typename AttrType, typename IsNone>
Graph InferAttr(Graph &&ret,
@@ -17,7 +18,7 @@ Graph InferAttr(Graph &&ret,
const char* arg_name,
const char* attr_key_name,
const char* attr_name,
-const char* known_name,
+const char* unknown_name,
IsNone fis_none) {
using AttrVector = std::vector<AttrType>;
const IndexedGraph& idx = ret.indexed_graph();
@@ -48,11 +49,11 @@ Graph InferAttr(Graph &&ret,
// temp space for shape inference.
std::vector<AttrType*> ishape, oshape;
// number of nodes whose attribute is still unknown
-size_t num_known = 0;
+size_t num_unknown = 0;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) {
-if (shape_attr_key.length() != 0) {
+if (shape_attr_key.length() != 0 && fis_none(rshape[idx.entry_id(nid, 0)])) {
auto it = inode.source->attrs.dict.find(shape_attr_key);
if (it != inode.source->attrs.dict.end()) {
CHECK_EQ(inode.source->num_outputs(), 1);
@@ -71,8 +72,8 @@ Graph InferAttr(Graph &&ret,
oshape[i] = &rshape[idx.entry_id(nid, i)];
}
if (finfer_shape.count(inode.source->op)) {
-num_known +=
-finfer_shape[inode.source->op](inode.source->attrs, ishape, oshape);
+num_unknown +=
+!(finfer_shape[inode.source->op](inode.source->attrs, ishape, oshape));
} else if (is_backward.get(inode.source->op, false)) {
// backward operator inference.
CHECK_GE(inode.control_deps.size(), 1)
@@ -85,13 +86,13 @@ inline Graph InferAttr(Graph &&ret,
*oshape[i] = rshape[idx.entry_id(fnode.inputs[i])];
if (fis_none(*oshape[i])) known = false;
}
-num_known += known;
+num_unknown += !known;
}
}
// set the shapes
ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
// number of nodes whose attribute is still unknown.
-ret.attrs[known_name] = std::make_shared<any>(num_known);
+ret.attrs[unknown_name] = std::make_shared<any>(num_unknown);
return ret;
}
@@ -101,7 +102,7 @@ NNVM_REGISTER_PASS(InferShape)
return InferAttr<TShape>(
std::move(ret), TShape(),
"FInferShape", "shape_args", "shape_attr_key",
-"shape", "shape_num_known_nodes",
+"shape", "shape_num_unknown_nodes",
[](const TShape& s) { return s.ndim() == 0; });
})
.set_change_graph(false)
@@ -113,7 +114,7 @@ NNVM_REGISTER_PASS(InferType)
return InferAttr<int>(
std::move(ret), 0,
"FInferType", "dtype_args", "dtype_attr_key",
-"dtype", "dtype_num_known_nodes",
+"dtype", "dtype_num_unknown_nodes",
[](const int t) { return t == -1; });
})
.set_change_graph(false)
@@ -123,5 +124,6 @@ DMLC_JSON_ENABLE_ANY(ShapeVector, list_shape);
DMLC_JSON_ENABLE_ANY(DTypeVector, list_int);
DMLC_JSON_ENABLE_ANY(size_t, size_t);
} // namespace
} // namespace pass
} // namespace nnvm
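Since the pass now reports `*_num_unknown_nodes` instead of `*_num_known_nodes`, a caller checks for complete inference by testing the count against zero. A minimal sketch, mirroring the `ApplyPass` call and plain `attrs` map access shown above (graph construction is omitted):

```cpp
#include <nnvm/graph.h>
#include <nnvm/pass.h>

bool ShapeFullyKnown(nnvm::Graph g) {
  g = nnvm::ApplyPass(std::move(g), {"InferShape"});
  // the pass stored a size_t count under "shape_num_unknown_nodes"
  const auto& n = nnvm::get<size_t>(*g.attrs.at("shape_num_unknown_nodes"));
  return n == 0;
}
```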
@@ -10,6 +10,7 @@
namespace nnvm {
namespace pass {
namespace {
template<typename T>
inline T get_with_default(const std::unordered_map<Node*, T> &map,
@@ -140,5 +141,6 @@ NNVM_REGISTER_PASS(OrderMutation)
.set_body(OrderMutation)
.set_change_graph(true);
} // namespace
} // namespace pass
} // namespace nnvm
@@ -10,6 +10,7 @@
namespace nnvm {
namespace pass {
namespace {
// simple logic to place devices according to the device_group hint
// insert copy node when there is
@@ -176,5 +177,6 @@ NNVM_REGISTER_PASS(PlaceDevice)
DMLC_JSON_ENABLE_ANY(DeviceAssignMap, dict_str_int);
} // namespace
} // namespace pass
} // namespace nnvm
@@ -12,6 +12,7 @@
namespace nnvm {
namespace pass {
namespace {
// simple graph based allocator.
class GraphAllocator {
@@ -91,7 +92,7 @@ class GraphAllocator {
if ((*idx_)[nid].source->is_variable()) continue;
importance[nid] = 1;
}
-num_match_color_ = ColorNodeGroup(
+num_match_color_ = pass::ColorNodeGroup(
*idx_, importance, num_match_color_, &node_color_);
}
}
@@ -223,5 +224,6 @@ NNVM_REGISTER_PASS(PlanMemory)
.depend_graph_attr("shape")
.provide_graph_attr("storage_id");
} // namespace
} // namespace pass
} // namespace nnvm
@@ -4,6 +4,7 @@
* \brief Save and load graph to/from JSON file.
*/
#include <nnvm/pass.h>
#include <nnvm/pass_functions.h>
#include <dmlc/json.h>
#include <algorithm>
@@ -26,6 +27,7 @@ struct Handler<std::shared_ptr<const any> > {
namespace nnvm {
namespace pass {
namespace {
// auxiliary node structure for serialization.
struct JSONNode {
@@ -35,7 +37,7 @@ struct JSONNode {
uint32_t index;
uint32_t version;
void Save(dmlc::JSONWriter *writer) const {
-writer->BeginArray();
+writer->BeginArray(false);
writer->WriteArrayItem(node_id);
writer->WriteArrayItem(index);
writer->WriteArrayItem(version);
@@ -74,7 +76,10 @@ struct JSONNode {
}
writer->WriteObjectKeyValue("name", node->attrs.name);
if (node->attrs.dict.size() != 0) {
-writer->WriteObjectKeyValue("attr", node->attrs.dict);
+// write the attributes in sorted key order
+std::map<std::string, std::string> dict(
+node->attrs.dict.begin(), node->attrs.dict.end());
+writer->WriteObjectKeyValue("attr", dict);
}
writer->WriteObjectKeyValue("inputs", inputs);
if (control_deps.size() != 0) {
@@ -247,5 +252,6 @@ NNVM_REGISTER_PASS(SaveJSON)
DMLC_JSON_ENABLE_ANY(std::string, str);
DMLC_JSON_ENABLE_ANY(std::vector<int>, list_int);
} // namespace
} // namespace pass
} // namespace nnvm
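A minimal round-trip sketch, assuming the `SaveJSON`/`LoadJSON` helpers declared in `pass_functions.h` wrap the passes registered in this file; with the `std::map` change above, attribute keys serialize in a deterministic (sorted) order, so the round trip is stable.

```cpp
#include <string>
#include <nnvm/graph.h>
#include <nnvm/pass_functions.h>

std::string RoundTrip(nnvm::Graph g) {
  std::string json = nnvm::pass::SaveJSON(g);          // runs the SaveJSON pass
  nnvm::Graph restored = nnvm::pass::LoadJSON(json);   // and the LoadJSON pass
  return nnvm::pass::SaveJSON(restored);               // identical to `json`
}
```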
@@ -35,6 +35,16 @@ def test_order_mutation_pass():
assert nindex['add1'] in jnodes[nindex['assign']]['control_deps']
assert jnodes[nindex['assign']]['inputs'][0][2] == 1
def test_list_args():
x = sym.Variable('x')
z = sym.Variable('z')
y = sym.conv2d(data=x, name='conv', dev='gpu')
y = sym.add(y, z, name='add1')
# write after read
z = sym.assign(x, y, name='assign')
assert z.list_inputs('read_only') == ['conv_weight', 'z']
assert z.list_inputs('aux_state') == ['x']
def test_infer_shape():
x = sym.Variable('x', shape=(4, 2))
y = sym.add(x, x, name='add1')
@@ -109,3 +119,4 @@ if __name__ == "__main__":
test_infer_type()
test_place_device()
test_plan_memory()
test_list_args()
@@ -7,7 +7,7 @@ def test_compose():
y = sym.exp(sym.add(x, x, name='add', gpu=2),
name='exp', gpu=1, attr={"kk": "1"})
-assert y.list_arguments() == ['x']
+assert y.list_inputs() == ['x']
assert y.list_outputs() == ["exp_output"]
assert y.list_attr()['gpu'] == '1'
z = y.get_internals()
@@ -17,7 +17,7 @@ def test_compose():
def test_default_input():
x = sym.Variable('x')
y = sym.conv2d(data=x, name='conv')
-assert y.list_arguments() == ['x', 'conv_weight']
+assert y.list_inputs() == ['x', 'conv_weight']
try:
z = sym.add(x)
assert False
......