Commit 16a6db3a by Tianqi Chen

Update Symbol and C API (#22)

* Update tuple to be compatible with mshadow

* Move set error message to C API

* simplify with using

* updates to shape inference

* Add unnamed namespace to the implementations

* [SYMBOL] Enable inference of Auxiliary data, rename list_arguments to list_inputs
parent 00833e0f
# NNVM: Build deep learning system by parts
NNVM is not a deep learning library. It is a modular, decentralized and lightweight library to
help build deep learning libraries efficiently.
NNVM is not a deep learning library. It is a modular, decentralized and lightweight part to
help build deep learning libraries.
## What is it
......@@ -9,14 +9,14 @@ While most deep learning systems offer end to end solutions,
it is interesting to ask if we can actually assemble a deep learning system by parts.
The goal is to enable hackers can customize optimizations, target platforms and set of operators they care about.
We believe that the decentralized modular system is an interesting direction.
The hope is that effective parts can be assembled together just like you assemble your own desktops.
So the customized deep learning solution can be minimax, minimum in terms of dependencies,
while maxiziming the users' need.
NNVM offers one such part, it provides a generic to do generic
computation graph optimization such as memory reduction, device allocation,
operator fusion while being agnostic to the operator
interface defintion and how operators are executed.
NNVM offers one such part, it provides a generic way to do
computation graph optimization such as memory reduction, device allocation and more
while being agnostic to the operator interface defintion and how operators are executed.
NNVM is inspired by LLVM, aiming to be an intermediate representation library
for neural nets and computation graphs generation and optimizations.
......
......@@ -16,37 +16,13 @@
namespace nnvm {
/*! \brief any type */
using any = dmlc::any;
using dmlc::any;
/*!
* \brief array_veiw type
* \tparam ValueType The value content of array view.
*/
template<typename ValueType>
using array_view = dmlc::array_view<ValueType>;
/*!
* \brief get reference of type T stored in src.
* \param src The source container
* \return the reference to the type.
* \tparam T The type to be fetched.
*/
template<typename T>
inline T& get(any& src) { // NOLINT(*)
return dmlc::get<T>(src);
}
/*!
* \brief get const reference of type T stored in src.
* \param src The source container
* \return the reference to the type.
* \tparam T The type to be fetched.
*/
/*! \brief array_veiw type */
using dmlc::array_view;
template<typename T>
inline const T& get(const any& src) {
return dmlc::get<T>(src);
}
/*!\brief getter function of any type */
using dmlc::get;
} // namespace nnvm
......
......@@ -36,6 +36,12 @@ typedef void *SymbolHandle;
typedef void *GraphHandle;
/*!
* \brief Set the last error message needed by C API
* \param msg The error message to set.
*/
NNVM_DLL void NNAPISetLastError(const char* msg);
/*!
* \brief return str message of the last error
* all function in this file will return 0 when success
* and -1 when an error occured,
......@@ -171,25 +177,30 @@ NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol,
nn_uint *out_size,
const char*** out);
/*!
* \brief List arguments in the symbol.
* \brief List inputs in the symbol.
* \param symbol the symbol
* \param option The option to list the inputs
* option=0 means list all arguments.
* option=1 means list arguments that are readed only by the graph.
* option=2 means list arguments that are mutated by the graph.
* \param out_size output size
* \param out_str_array pointer to hold the output string array
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolListArguments(SymbolHandle symbol,
nn_uint *out_size,
const char ***out_str_array);
NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol,
int option,
nn_uint *out_size,
const char ***out_str_array);
/*!
* \brief List returns in the symbol.
* \brief List returns names in the symbol.
* \param symbol the symbol
* \param out_size output size
* \param out_str_array pointer to hold the output string array
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolListOutputs(SymbolHandle symbol,
nn_uint *out_size,
const char ***out_str_array);
NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol,
nn_uint *out_size,
const char ***out_str_array);
/*!
* \brief Get a symbol that contains all the internals.
* \param symbol The symbol
......
......@@ -289,7 +289,9 @@ template<typename ValueType>
inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
const any* ref = GetAttrMap(key);
if (ref == nullptr) {
// update the attribute map of the key by creating new empty OpMap
UpdateAttrMap(key, [key](any* pmap) {
// use callback so it is in lockscope
if (pmap->empty()) {
OpMap<ValueType> pm;
pm.attr_name_ = key;
......@@ -304,7 +306,9 @@ inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
template<typename ValueType>
inline Op& Op::attr( // NOLINT(*)
const std::string& attr_name, const ValueType& value) {
// update the attribute map of the key by creating new empty if needed.
UpdateAttrMap(attr_name, [this, attr_name, value](any* pmap) {
// the callback is in lockscope so is threadsafe.
if (pmap->empty()) {
OpMap<ValueType> pm;
pm.attr_name_ = attr_name;
......
......@@ -83,10 +83,10 @@ inline Graph InferType(Graph graph,
DTypeVector type_args = {},
std::string type_attr_key = "") {
if (type_args.size() != 0) {
graph.attrs["type_args"] = std::make_shared<any>(std::move(type_args));
graph.attrs["dtype_args"] = std::make_shared<any>(std::move(type_args));
}
if (type_attr_key.length() != 0) {
graph.attrs["type_attr_key"] = std::make_shared<any>(std::move(type_attr_key));
graph.attrs["dtype_attr_key"] = std::make_shared<any>(std::move(type_attr_key));
}
return ApplyPass(std::move(graph), {"InferType"});
}
......
......@@ -30,6 +30,18 @@ class Symbol {
/*! \brief only list attributes in current node */
kShallow = 1
};
/*! \brief option passed to ListInputNames */
enum ListInputOption {
/*! \brief list all the arguments */
kAll = 0,
/*! \brief list only read only arguments */
kReadOnlyArgs = 1,
/*!
* \brief List auxiliary states that can be mutated by the graph.
* This excludes the ReadOnly arguments
*/
kAuxiliaryStates = 2
};
/*! \brief output entries contained in the symbol */
std::vector<NodeEntry> outputs;
......@@ -51,18 +63,20 @@ class Symbol {
*/
Symbol operator[] (size_t index) const;
/*!
* \brief List the arguments names.
* \brief List the input names.
* \param option The options to list the arguments.
*
* The position of the returned list also corresponds to calling position in operator()
* \return the arguments list of this symbol, they can be either named or unnamed (empty string).
* \sa ListInputOption
*/
std::vector<std::string> ListArguments() const;
std::vector<std::string> ListInputNames(ListInputOption option) const;
/*!
* \brief List the names of outputs for this symbol.
* For normal operators, it is usually symbol node name + "_output"
* \return get the descriptions of outputs for this symbol.
*/
std::vector<std::string> ListOutputs() const;
std::vector<std::string> ListOutputNames() const;
/*!
* \brief Compose the symbol with arguments, this changes the current symbol.
* The kwargs passed in can be in-complete,
......
......@@ -58,17 +58,9 @@ class Tuple {
* \brief move constructor from Tuple
* \param src the source shape
*/
inline Tuple(Tuple<ValueType>&& src) {
this->swap(src);
}
/*!
* \param ndim the number of dimension of the Tuple
* \param v The value to fill.
*/
inline Tuple(index_t ndim, ValueType v) {
this->SetDim(ndim);
std::fill_n(begin(), ndim, v);
inline Tuple(Tuple<ValueType>&& src) { // NOLINT(*)
this->swap(src);
}
/*!
* \brief construct the Tuple from content of iterator
......@@ -97,7 +89,7 @@ class Tuple {
* \brief Swap current object with other
* \param other another object to be swapped.
*/
inline void swap(Tuple<ValueType>& other) noexcept { // NOLINT(*)
inline void swap(Tuple<ValueType>& other) { // NOLINT(*)
std::swap(ndim_, other.ndim_);
std::swap(num_heap_allocated_, other.num_heap_allocated_);
std::swap(data_stack_, other.data_stack_);
......@@ -275,7 +267,7 @@ class Tuple {
return is;
}
private:
protected:
// stack cache size
static const uint32_t kStackCache = 4;
/*! \brief number of dimension of the tuple */
......@@ -303,16 +295,30 @@ class Tuple {
*/
class TShape : public Tuple<index_t> {
public:
// inheritate other constructors from Tuple
using Tuple<index_t>::Tuple;
/*! \brief default constructor */
TShape() = default;
/*!
* constructor to construct a shape with all 1.
* \param ndim the number of dimension
*/
inline TShape(index_t ndim) { // NOLINT(*)
this->SetDim(ndim);
std::fill_n(begin(), ndim, 1);
}
/*!
* \brief copy constructor of TShape
* \param s source shape.
*/
inline TShape(const Tuple<index_t>& s) // NOLINT(*)
: Tuple<index_t>(s) {}
inline TShape(const Tuple<index_t>& s) { // NOLINT(*)
this->assign(s.begin(), s.end());
}
/*!
* \brief constructor from initializer list
* \param init the initializer_list
*/
inline TShape(std::initializer_list<index_t> init) {
this->assign(init.begin(), init.end());
}
/*!
* \brief move constructor.
* \param s source shape.
......@@ -321,6 +327,17 @@ class TShape : public Tuple<index_t> {
this->swap(s);
}
/*!
* \brief construct the Tuple from content of iterator
* \param begin the beginning of iterator
* \param end end the end of the iterator
* \tparam RandomAccessIterator iterator type
*/
template<typename RandomAccessIterator>
inline TShape(RandomAccessIterator begin,
RandomAccessIterator end) {
this->assign(begin, end);
}
/*!
* \brief assignment function from tshape
* \param src source shape.
* \return self.
......@@ -347,6 +364,164 @@ class TShape : public Tuple<index_t> {
}
return size;
}
/*!
* \return product shape in [dimstart,dimend)
* \param dimstart start dimension
* \param dimend end dimension
*/
inline index_t ProdShape(int dimstart, int dimend) const {
index_t num = 1;
const index_t *d = this->data();
for (int i = dimstart; i < dimend; ++i) {
num *= d[i];
}
return num;
}
/*! \return the begin data pointer to content of the tuple */
inline const index_t *data() const {
return begin();
}
/*! \return the begin data pointer to content of the tuple */
inline index_t *data() {
return begin();
}
#ifdef MSHADOW_XINLINE
template<int dim>
inline TShape(mshadow::Shape<dim> &&s) {// NOLINT(*)
this->assign(s.shape_, s.shape_ + dim);
}
/*!
* \brief assignment from shape
* \param shape source shape
* \tparam dim shape dimension
* \return reference of self
*/
template<int dim>
inline TShape &operator=(const mshadow::Shape<dim> &shape) {
this->assign(shape.shape_, shape.shape_ + dim);
return *this;
}
/*!
* \brief get the shape of tensor specifying dim
* \return the shape requested
* \tparam dim dimension of the tensor
*/
template<int dim>
inline mshadow::Shape<dim> get() const {
CHECK_EQ(dim, ndim())
<< "dimension do not match target dimension " << dim << " vs " << ndim();
const index_t *d = this->data();
mshadow::Shape<dim> s;
for (int i = 0; i < dim; ++i) {
s[i] = d[i];
}
return s;
}
/*!
* flatten the higher dimension to second dimension, return a 2D shape
* \return the flat 2d shape
*/
inline mshadow::Shape<2> FlatTo2D(void) const {
mshadow::Shape<2> s;
if (ndim() == 0) return mshadow::Shape2(0, 0);
const index_t *d = this->data();
s.shape_[1] = d[ndim() - 1];
index_t ymax = 1;
for (index_t i = 1; i < ndim(); ++i) {
ymax *= d[i - 1];
}
s.shape_[0] = ymax;
return s;
}
/*!
* flatten the shape into three parts: [0, axis_begin), [axis_begin, axis_end], (axis_end, ndim)
* \param axis_begin The beginning axis specified.
* \param axis_end The ending axis specified.
* \return the flat 3d shape
*/
inline mshadow::Shape<3> FlatTo3D(index_t axis_begin, index_t axis_end) const {
CHECK(axis_end >= axis_begin);
mshadow::Shape<3> s;
if (ndim() == 0) return mshadow::Shape3(0, 0, 0);
const index_t *d = this->data();
s.shape_[0] = 1;
s.shape_[1] = 1;
s.shape_[2] = 1;
for (index_t i = 0; i < axis_begin; ++i) {
s.shape_[0] *= d[i];
}
for (index_t i = axis_begin; i <= axis_end; ++i) {
s.shape_[1] *= d[i];
}
for (index_t i = axis_end + 1; i < ndim(); ++i) {
s.shape_[2] *= d[i];
}
return s;
}
/*!
* flatten the axis before and after the specified axis, so it becomes 3D tensor
* \param axis The axis specified.
* \return the flat 3d shape
*/
inline mshadow::Shape<3> FlatTo3D(index_t axis) const {
return FlatTo3D(axis, axis);
}
inline bool operator==(const TShape &s) const {
if (ndim() != s.ndim()) return false;
return std::equal(begin(), end(), s.begin());
}
inline bool operator!=(const TShape &s) const {
return !(*this == s);
}
/*!
* \return whether two shape equals
* \param s the shape to compare against
* \tparam dim dimension of the shape
*/
template<int dim>
inline bool operator==(const mshadow::Shape<dim> &s) const {
if (ndim_ != dim) return false;
const index_t *d = dim <= kStackCache ? data_stack_ : data_heap_;
for (index_t i = 0; i < dim; ++i) {
if (d[i] != s.shape_[i]) return false;
}
return true;
}
/*!
* \return whether two shape not equals
* \param s the shape to compare against
* \tparam dim dimension of the shape
*/
template<int dim>
inline bool operator!=(const mshadow::Shape<dim> &s) const {
return !(*this == s);
}
/*!
* \brief save the content into binary stream
* \param strm the output stream
* \tparam TStream any stream type that have write
*/
template<typename TStream>
inline void Save(TStream *strm) const {
strm->Write(&ndim_, sizeof(ndim_));
strm->Write(data(), sizeof(index_t) * ndim_);
}
/*!
* \brief load the content from binary stream
* \param strm the output stream
* \tparam TStream any stream type that have write
* \return whether the load is successful
*/
template<typename TStream>
inline bool Load(TStream *strm) {
if (strm->Read(&ndim_, sizeof(ndim_)) != sizeof(ndim_)) return false;
this->SetDim(ndim_);
size_t nread = sizeof(index_t) * ndim_;
if (strm->Read(data(), nread) != nread) return false;
return true;
}
#endif
};
} // namespace nnvm
......
......@@ -176,9 +176,16 @@ class Symbol(SymbolBase):
self.handle, _ctypes.byref(handle)))
return Symbol(handle=handle)
def list_arguments(self):
"""List all the arguments in the symbol.
def list_inputs(self, option='all'):
"""List all the inputs in the symbol.
Parameters
----------
option : {'all', 'read_only', 'aux_state'}, optional
The listing option
- 'all' will list all the arguments.
- 'read_only' lists arguments that are readed by the graph.
- 'aux_state' lists arguments that are mutated by the graph as state.
Returns
-------
args : list of string
......@@ -186,8 +193,16 @@ class Symbol(SymbolBase):
"""
size = _ctypes.c_uint()
sarr = _ctypes.POINTER(_ctypes.c_char_p)()
_check_call(_LIB.NNSymbolListArguments(
self.handle, _ctypes.byref(size), _ctypes.byref(sarr)))
if option == 'all':
copt = _ctypes.c_int(0)
elif option == 'read_only':
copt = _ctypes.c_int(1)
elif option == 'aux_state':
copt = _ctypes.c_int(2)
else:
raise ValueError("option need to be in {'all', 'read_only, 'aux_state'}")
_check_call(_LIB.NNSymbolListInputNames(
self.handle, copt, _ctypes.byref(size), _ctypes.byref(sarr)))
return [_base.py_str(sarr[i]) for i in range(size.value)]
def list_outputs(self):
......@@ -200,7 +215,7 @@ class Symbol(SymbolBase):
"""
size = _ctypes.c_uint()
sarr = _ctypes.POINTER(_ctypes.c_char_p)()
_check_call(_LIB.NNSymbolListOutputs(
_check_call(_LIB.NNSymbolListOutputNames(
self.handle, _ctypes.byref(size), _ctypes.byref(sarr)))
return [_base.py_str(sarr[i]) for i in range(size.value)]
......
......@@ -45,11 +45,6 @@ struct NNAPIThreadLocalEntry {
typedef dmlc::ThreadLocalStore<NNAPIThreadLocalEntry> NNAPIThreadLocalStore;
/*!
* \brief Set the last error message needed by C API
* \param msg The error message to set.
*/
void NNAPISetLastError(const char* msg);
/*!
* \brief handle exception throwed out
* \param e the exception
* \return the return value of API after exception is handled
......
......@@ -19,7 +19,6 @@ int NNSymbolListAtomicSymbolCreators(nn_uint *out_size,
API_END();
}
int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
const char **name,
const char **description,
......@@ -37,7 +36,6 @@ int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
API_END();
}
int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
nn_uint num_param,
const char **keys,
......@@ -179,13 +177,15 @@ int NNSymbolListAttrs(SymbolHandle symbol,
API_END();
}
int NNSymbolListArguments(SymbolHandle symbol,
nn_uint *out_size,
const char ***out_str_array) {
int NNSymbolListInputNames(SymbolHandle symbol,
int option,
nn_uint *out_size,
const char ***out_str_array) {
Symbol *s = static_cast<Symbol*>(symbol);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
ret->ret_vec_str = std::move(s->ListArguments());
ret->ret_vec_str = std::move(
s->ListInputNames(Symbol::ListInputOption(option)));
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
......@@ -195,13 +195,13 @@ int NNSymbolListArguments(SymbolHandle symbol,
API_END();
}
int NNSymbolListOutputs(SymbolHandle symbol,
nn_uint *out_size,
const char ***out_str_array) {
int NNSymbolListOutputNames(SymbolHandle symbol,
nn_uint *out_size,
const char ***out_str_array) {
Symbol *s = static_cast<Symbol*>(symbol);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
ret->ret_vec_str = std::move(s->ListOutputs());
ret->ret_vec_str = std::move(s->ListOutputNames());
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
......@@ -221,6 +221,7 @@ int NNSymbolCompose(SymbolHandle sym,
std::string& s_name = ret->ret_str;
std::unordered_map<std::string, const Symbol*>& kwargs
= ret->kwarg_symbol;
kwargs.clear();
if (name != nullptr) {
s_name = name;
} else {
......
......@@ -48,6 +48,7 @@ Graph ApplyPass(Graph g,
}
g = r->body(std::move(g));
}
return g;
}
......
......@@ -181,17 +181,41 @@ Symbol Symbol::operator[] (size_t index) const {
}
}
std::vector<std::string> Symbol::ListArguments() const {
std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
std::vector<std::string> ret;
DFSVisit(this->outputs, [&ret](const NodePtr &node) {
if (node->is_variable()) {
if (option == kAll) {
DFSVisit(this->outputs, [&ret](const NodePtr &node) {
if (node->is_variable()) {
ret.push_back(node->attrs.name);
}
});
} else {
std::unordered_set<Node*> mutable_set;
std::vector<Node*> vlist;
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
DFSVisit(this->outputs, [&ret, &mutable_set, &vlist](const NodePtr &node) {
if (node->is_variable()) {
vlist.push_back(node.get());
} else if (fmutate_inputs.count(node->op)) {
FMutateInput fmutate = fmutate_inputs[node->op];
for (uint32_t i = 0; i < node->inputs.size(); ++i) {
if (fmutate(node->attrs, i)) {
mutable_set.insert(node->inputs[i].node.get());
}
}
}
});
for (Node* node : vlist) {
if ((option == kReadOnlyArgs && mutable_set.count(node) == 0) ||
(option == kAuxiliaryStates && mutable_set.count(node) != 0)) {
ret.push_back(node->attrs.name);
}
});
}
}
return ret;
}
std::vector<std::string> Symbol::ListOutputs() const {
std::vector<std::string> Symbol::ListOutputNames() const {
static auto& flist_ouputs = Op::GetAttr<FListOutputNames>("FListOutputNames");
std::vector<std::string> ret;
for (auto &head : outputs) {
......@@ -345,10 +369,10 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
}
} else {
std::vector<std::string> keys = GetKeys(kwargs);
std::vector<std::string> arg_names = ListArguments();
std::vector<std::string> arg_names = ListInputNames(kAll);
array_view<std::string> view(dmlc::BeginPtr(arg_names) + arg_counter,
dmlc::BeginPtr(arg_names) + arg_names.size());
KeywordArgumentMismatch("Symbol.Compose", keys, ListArguments());
KeywordArgumentMismatch("Symbol.Compose", keys, arg_names);
}
}
}
......
......@@ -9,6 +9,7 @@
namespace nnvm {
namespace pass {
namespace {
template<typename AttrType, typename IsNone>
Graph InferAttr(Graph &&ret,
......@@ -17,7 +18,7 @@ Graph InferAttr(Graph &&ret,
const char* arg_name,
const char* attr_key_name,
const char* attr_name,
const char* known_name,
const char* unknown_name,
IsNone fis_none) {
using AttrVector = std::vector<AttrType>;
const IndexedGraph& idx = ret.indexed_graph();
......@@ -48,11 +49,11 @@ Graph InferAttr(Graph &&ret,
// temp space for shape inference.
std::vector<AttrType*> ishape, oshape;
// number of completed nodes
size_t num_known = 0;
size_t num_unknown = 0;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) {
if (shape_attr_key.length() != 0) {
if (shape_attr_key.length() != 0 && fis_none(rshape[idx.entry_id(nid, 0)])) {
auto it = inode.source->attrs.dict.find(shape_attr_key);
if (it != inode.source->attrs.dict.end()) {
CHECK_EQ(inode.source->num_outputs(), 1);
......@@ -71,8 +72,8 @@ Graph InferAttr(Graph &&ret,
oshape[i] = &rshape[idx.entry_id(nid, i)];
}
if (finfer_shape.count(inode.source->op)) {
num_known +=
finfer_shape[inode.source->op](inode.source->attrs, ishape, oshape);
num_unknown +=
!(finfer_shape[inode.source->op](inode.source->attrs, ishape, oshape));
} else if (is_backward.get(inode.source->op, false)) {
// backward operator inference.
CHECK_GE(inode.control_deps.size(), 1)
......@@ -85,13 +86,13 @@ Graph InferAttr(Graph &&ret,
*oshape[i] = rshape[idx.entry_id(fnode.inputs[i])];
if (fis_none(*oshape[i])) known = false;
}
num_known += known;
num_unknown += !known;
}
}
// set the shapes
ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
// number of nodes who knows the shape.
ret.attrs[known_name] = std::make_shared<any>(num_known);
ret.attrs[unknown_name] = std::make_shared<any>(num_unknown);
return ret;
}
......@@ -101,7 +102,7 @@ NNVM_REGISTER_PASS(InferShape)
return InferAttr<TShape>(
std::move(ret), TShape(),
"FInferShape", "shape_args", "shape_attr_key",
"shape", "shape_num_known_nodes",
"shape", "shape_num_unknown_nodes",
[](const TShape& s) { return s.ndim() == 0; });
})
.set_change_graph(false)
......@@ -113,7 +114,7 @@ NNVM_REGISTER_PASS(InferType)
return InferAttr<int>(
std::move(ret), 0,
"FInferType", "dtype_args", "dtype_attr_key",
"dtype", "dtype_num_known_nodes",
"dtype", "dtype_num_unknown_nodes",
[](const int t) { return t == -1; });
})
.set_change_graph(false)
......@@ -123,5 +124,6 @@ DMLC_JSON_ENABLE_ANY(ShapeVector, list_shape);
DMLC_JSON_ENABLE_ANY(DTypeVector, list_int);
DMLC_JSON_ENABLE_ANY(size_t, size_t);
} // namespace
} // namespace pass
} // namespace nnvm
......@@ -10,6 +10,7 @@
namespace nnvm {
namespace pass {
namespace {
template<typename T>
inline T get_with_default(const std::unordered_map<Node*, T> &map,
......@@ -140,5 +141,6 @@ NNVM_REGISTER_PASS(OrderMutation)
.set_body(OrderMutation)
.set_change_graph(true);
} // namespace
} // namespace pass
} // namespace nnvm
......@@ -10,6 +10,7 @@
namespace nnvm {
namespace pass {
namespace {
// simply logic to place device according to device_group hint
// insert copy node when there is
......@@ -176,5 +177,6 @@ NNVM_REGISTER_PASS(PlaceDevice)
DMLC_JSON_ENABLE_ANY(DeviceAssignMap, dict_str_int);
} // namespace
} // namespace pass
} // namespace nnvm
......@@ -12,6 +12,7 @@
namespace nnvm {
namespace pass {
namespace {
// simple graph based allocator.
class GraphAllocator {
......@@ -91,7 +92,7 @@ class GraphAllocator {
if ((*idx_)[nid].source->is_variable()) continue;
importance[nid] = 1;
}
num_match_color_ = ColorNodeGroup(
num_match_color_ = pass::ColorNodeGroup(
*idx_, importance, num_match_color_, &node_color_);
}
}
......@@ -223,5 +224,6 @@ NNVM_REGISTER_PASS(PlanMemory)
.depend_graph_attr("shape")
.provide_graph_attr("storage_id");
} // namespace
} // namespace pass
} // namespace nnvm
......@@ -4,6 +4,7 @@
* \brief Save and load graph to/from JSON file.
*/
#include <nnvm/pass.h>
#include <nnvm/pass_functions.h>
#include <dmlc/json.h>
#include <algorithm>
......@@ -26,6 +27,7 @@ struct Handler<std::shared_ptr<const any> > {
namespace nnvm {
namespace pass {
namespace {
// auxiliary node structure for serialization.
struct JSONNode {
......@@ -35,7 +37,7 @@ struct JSONNode {
uint32_t index;
uint32_t version;
void Save(dmlc::JSONWriter *writer) const {
writer->BeginArray();
writer->BeginArray(false);
writer->WriteArrayItem(node_id);
writer->WriteArrayItem(index);
writer->WriteArrayItem(version);
......@@ -74,7 +76,10 @@ struct JSONNode {
}
writer->WriteObjectKeyValue("name", node->attrs.name);
if (node->attrs.dict.size() != 0) {
writer->WriteObjectKeyValue("attr", node->attrs.dict);
// write attributes in order;
std::map<std::string, std::string> dict(
node->attrs.dict.begin(), node->attrs.dict.end());
writer->WriteObjectKeyValue("attr", dict);
}
writer->WriteObjectKeyValue("inputs", inputs);
if (control_deps.size() != 0) {
......@@ -247,5 +252,6 @@ NNVM_REGISTER_PASS(SaveJSON)
DMLC_JSON_ENABLE_ANY(std::string, str);
DMLC_JSON_ENABLE_ANY(std::vector<int>, list_int);
} // namespace
} // namespace pass
} // namespace nnvm
......@@ -35,6 +35,16 @@ def test_order_mutation_pass():
assert nindex['add1'] in jnodes[nindex['assign']]['control_deps']
assert jnodes[nindex['assign']]['inputs'][0][2] == 1
def test_list_args():
x = sym.Variable('x')
z = sym.Variable('z')
y = sym.conv2d(data=x, name='conv', dev='gpu')
y = sym.add(y, z, name='add1')
# write after read
z = sym.assign(x, y, name='assign')
assert z.list_inputs('read_only') == ['conv_weight', 'z']
assert z.list_inputs('aux_state') == ['x']
def test_infer_shape():
x = sym.Variable('x', shape=(4, 2))
y = sym.add(x, x, name='add1')
......@@ -109,3 +119,4 @@ if __name__ == "__main__":
test_infer_type()
test_place_device()
test_plan_memory()
test_list_args()
......@@ -7,7 +7,7 @@ def test_compose():
y = sym.exp(sym.add(x, x, name='add', gpu=2),
name='exp', gpu=1, attr={"kk": "1"})
assert y.list_arguments() == ['x']
assert y.list_inputs() == ['x']
assert y.list_outputs() == ["exp_output"]
assert y.list_attr()['gpu'] == '1'
z = y.get_internals()
......@@ -17,7 +17,7 @@ def test_compose():
def test_default_input():
x = sym.Variable('x')
y = sym.conv2d(data=x, name='conv')
assert y.list_arguments() == ['x', 'conv_weight']
assert y.list_inputs() == ['x', 'conv_weight']
try:
z = sym.add(x)
assert False
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment