Commit 6ffeae97 by Tianqi Chen

update (#26)

* updates (#1)

* add scalars

* change format

* change inferattr interface

* remove scalar

* remove warning
parent ebbec6de
...@@ -21,14 +21,14 @@ using nnvm::array_view; ...@@ -21,14 +21,14 @@ using nnvm::array_view;
// simply return the shape as same // simply return the shape as same
inline bool SameShape(const NodeAttrs& attrs, inline bool SameShape(const NodeAttrs& attrs,
array_view<TShape*> ishape, std::vector<TShape> *ishape,
array_view<TShape*> oshape) { std::vector<TShape> *oshape) {
if (ishape.size() == 0 || ishape[0]->ndim() == 0) return false; if (ishape->size() == 0 || (*ishape)[0].ndim() == 0) return false;
for (TShape* pshape : oshape) { for (TShape& pshape : *oshape) {
*pshape = *ishape[0]; pshape = (*ishape)[0];
} }
for (TShape* pshape : ishape) { for (TShape& pshape : *ishape) {
*pshape = *ishape[0]; pshape = (*ishape)[0];
} }
return true; return true;
} }
...@@ -51,13 +51,13 @@ NNVM_REGISTER_OP(reshape) ...@@ -51,13 +51,13 @@ NNVM_REGISTER_OP(reshape)
}) })
.attr<FInferShape>( .attr<FInferShape>(
"FInferShape", [] (const NodeAttrs& attrs, "FInferShape", [] (const NodeAttrs& attrs,
array_view<TShape*> ishape, std::vector<TShape> *ishape,
array_view<TShape*> oshape) { std::vector<TShape> *oshape) {
// get parsed attribute // get parsed attribute
const TShape& target = nnvm::get<TShape>(attrs.parsed); const TShape& target = nnvm::get<TShape>(attrs.parsed);
*oshape[0] = target; (*oshape)[0] = target;
if (ishape[0]->ndim() == 0) return false; if ((*ishape)[0].ndim() == 0) return false;
CHECK_EQ(ishape[0]->Size(), target.Size()) CHECK_EQ((*ishape)[0].Size(), target.Size())
<< "Reshape op: source target shape mismatch"; << "Reshape op: source target shape mismatch";
return true; return true;
}) })
...@@ -78,9 +78,9 @@ NNVM_REGISTER_OP(cast) ...@@ -78,9 +78,9 @@ NNVM_REGISTER_OP(cast)
.attr<FInferShape>("FInferShape", SameShape) .attr<FInferShape>("FInferShape", SameShape)
.attr<FInferType>( .attr<FInferType>(
"FInferType", [](const NodeAttrs& attrs, "FInferType", [](const NodeAttrs& attrs,
array_view<int*> itype, std::vector<int> *itype,
array_view<int*> otype) { std::vector<int> *otype) {
*otype[0] = nnvm::get<int>(attrs.parsed); (*otype)[0] = nnvm::get<int>(attrs.parsed);
return true; return true;
}); });
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#ifndef NNVM_OP_H_ #ifndef NNVM_OP_H_
#define NNVM_OP_H_ #define NNVM_OP_H_
#include <dmlc/parameter.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility> #include <utility>
...@@ -22,6 +23,7 @@ struct NodeAttrs; ...@@ -22,6 +23,7 @@ struct NodeAttrs;
template<typename ValueType> template<typename ValueType>
class OpMap; class OpMap;
class OpRegistryEntry; class OpRegistryEntry;
using dmlc::ParamFieldInfo;
/*! \brief constant to indicate it take any length of positional inputs */ /*! \brief constant to indicate it take any length of positional inputs */
static const uint32_t kVarg = std::numeric_limits<uint32_t>::max(); static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
...@@ -80,6 +82,8 @@ class Op { ...@@ -80,6 +82,8 @@ class Op {
* This can be used to generate docstring automatically for the operator. * This can be used to generate docstring automatically for the operator.
*/ */
std::string description; std::string description;
/*! \brief description of inputs and keyword arguments */
std::vector<ParamFieldInfo> arguments;
/*! /*!
* \brief number of inputs to the operator, * \brief number of inputs to the operator,
* -1 means it is variable length * -1 means it is variable length
...@@ -150,6 +154,22 @@ class Op { ...@@ -150,6 +154,22 @@ class Op {
*/ */
inline Op& describe(const std::string& descr); // NOLINT(*) inline Op& describe(const std::string& descr); // NOLINT(*)
/*! /*!
* \brief Add argument information to the function.
* \param name Name of the argument.
* \param type Type of the argument.
* \param description Description of the argument.
* \return reference to self.
*/
inline Op& add_argument(const std::string &name,
const std::string &type,
const std::string &description);
/*!
* \brief Append a list of arguments to the end.
* \param args Additional list of arguments.
* \return reference to self.
*/
inline Op& add_arguments(const std::vector<ParamFieldInfo> &args);
/*!
* \brief Set the num_inputs * \brief Set the num_inputs
* \param n The number of inputs to be set. * \param n The number of inputs to be set.
* \return reference to self. * \return reference to self.
...@@ -340,6 +360,18 @@ inline Op& Op::describe(const std::string& descr) { // NOLINT(*) ...@@ -340,6 +360,18 @@ inline Op& Op::describe(const std::string& descr) { // NOLINT(*)
return *this; return *this;
} }
inline Op& Op::add_argument(const std::string &name,
const std::string &type,
const std::string &description) {
arguments.push_back({name, type, type, description});
return *this;
}
inline Op& Op::add_arguments(const std::vector<ParamFieldInfo> &args) {
this->arguments.insert(arguments.end(), args.begin(), args.end());
return *this;
}
inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*) inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
this->num_inputs = n; this->num_inputs = n;
return *this; return *this;
......
...@@ -57,8 +57,8 @@ using FMutateInputs = std::function<std::vector<uint32_t> (const NodeAttrs& attr ...@@ -57,8 +57,8 @@ using FMutateInputs = std::function<std::vector<uint32_t> (const NodeAttrs& attr
*/ */
template<typename AttrType> template<typename AttrType>
using FInferNodeEntryAttr = std::function<bool (const NodeAttrs& attrs, using FInferNodeEntryAttr = std::function<bool (const NodeAttrs& attrs,
array_view<AttrType*> in_attrs, std::vector<AttrType> *in_attrs,
array_view<AttrType*> out_attrs)>; std::vector<AttrType> *out_attrs)>;
/*! /*!
* \brief Shape inference function. * \brief Shape inference function.
* Update the shapes given the input shape information. * Update the shapes given the input shape information.
......
...@@ -28,11 +28,26 @@ int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, ...@@ -28,11 +28,26 @@ int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
const char ***arg_descriptions, const char ***arg_descriptions,
const char **return_type) { const char **return_type) {
const Op *op = static_cast<const Op *>(creator); const Op *op = static_cast<const Op *>(creator);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN(); API_BEGIN();
*name = op->name.c_str(); *name = op->name.c_str();
*description = op->description.c_str(); *description = op->description.c_str();
*num_doc_args = 0; *num_doc_args = static_cast<nn_uint>(op->arguments.size());
if (return_type) *return_type = nullptr;
ret->ret_vec_charp.clear();
for (size_t i = 0; i < op->arguments.size(); ++i) {
ret->ret_vec_charp.push_back(op->arguments[i].name.c_str());
}
for (size_t i = 0; i < op->arguments.size(); ++i) {
ret->ret_vec_charp.push_back(op->arguments[i].type_info_str.c_str());
}
for (size_t i = 0; i < op->arguments.size(); ++i) {
ret->ret_vec_charp.push_back(op->arguments[i].description.c_str());
}
*arg_names = dmlc::BeginPtr(ret->ret_vec_charp);
*arg_type_infos = dmlc::BeginPtr(ret->ret_vec_charp) + op->arguments.size();
*arg_descriptions = dmlc::BeginPtr(ret->ret_vec_charp) + (op->arguments.size() * 2);
API_END(); API_END();
} }
......
...@@ -151,7 +151,10 @@ void Symbol::Print(std::ostream &os) const { ...@@ -151,7 +151,10 @@ void Symbol::Print(std::ostream &os) const {
} }
if (!node->attrs.dict.empty()) { if (!node->attrs.dict.empty()) {
os << "Attrs:\n"; os << "Attrs:\n";
for (auto &kv : node->attrs.dict) { // make an ordered copy because unordered_map doesn't guarantee order.
std::map<std::string, std::string> sorted_dict(
node->attrs.dict.begin(), node->attrs.dict.end());
for (auto &kv : sorted_dict) {
os << '\t' << kv.first << '=' << kv.second << '\n'; os << '\t' << kv.first << '=' << kv.second << '\n';
} }
} }
......
...@@ -47,44 +47,52 @@ Graph InferAttr(Graph &&ret, ...@@ -47,44 +47,52 @@ Graph InferAttr(Graph &&ret,
} }
// temp space for shape inference. // temp space for shape inference.
std::vector<AttrType*> ishape, oshape; std::vector<AttrType> ishape, oshape;
// number of completed nodes // number of completed nodes
size_t num_unknown = 0; size_t num_unknown = 0;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid]; const auto& inode = idx[nid];
uint32_t num_inputs = inode.inputs.size();
uint32_t num_outputs = inode.source->num_outputs();
if (inode.source->is_variable()) { if (inode.source->is_variable()) {
if (shape_attr_key.length() != 0 && fis_none(rshape[idx.entry_id(nid, 0)])) { if (shape_attr_key.length() != 0 && fis_none(rshape[idx.entry_id(nid, 0)])) {
auto it = inode.source->attrs.dict.find(shape_attr_key); auto it = inode.source->attrs.dict.find(shape_attr_key);
if (it != inode.source->attrs.dict.end()) { if (it != inode.source->attrs.dict.end()) {
CHECK_EQ(inode.source->num_outputs(), 1); CHECK_EQ(num_outputs, 1);
std::istringstream is(it->second); std::istringstream is(it->second);
CHECK(is >> rshape[idx.entry_id(nid, 0)]) << "Invalid attribute"; CHECK(is >> rshape[idx.entry_id(nid, 0)]) << "Invalid attribute";
} }
} }
continue; continue;
} }
ishape.resize(inode.inputs.size());
for (uint32_t i = 0; i < ishape.size(); ++i) {
ishape[i] = &rshape[idx.entry_id(inode.inputs[i])];
}
oshape.resize(inode.source->num_outputs());
for (uint32_t i = 0; i < oshape.size(); ++i) {
oshape[i] = &rshape[idx.entry_id(nid, i)];
}
if (finfer_shape.count(inode.source->op)) { if (finfer_shape.count(inode.source->op)) {
ishape.resize(num_inputs, def_value);
for (uint32_t i = 0; i < ishape.size(); ++i) {
ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
}
oshape.resize(num_outputs, def_value);
for (uint32_t i = 0; i < oshape.size(); ++i) {
oshape[i] = rshape[idx.entry_id(nid, i)];
}
num_unknown += num_unknown +=
!(finfer_shape[inode.source->op](inode.source->attrs, ishape, oshape)); !(finfer_shape[inode.source->op](inode.source->attrs, &ishape, &oshape));
for (uint32_t i = 0; i < num_inputs; ++i) {
rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
}
for (uint32_t i = 0; i < num_outputs; ++i) {
rshape[idx.entry_id(nid, i)] = oshape[i];
}
} else if (is_backward.get(inode.source->op, false)) { } else if (is_backward.get(inode.source->op, false)) {
// backward operator inference. // backward operator inference.
CHECK_GE(inode.control_deps.size(), 1) CHECK_GE(inode.control_deps.size(), 1)
<< "BackwardOp need to have control_deps to its forward op"; << "BackwardOp need to have control_deps to its forward op";
const auto& fnode = idx[inode.control_deps[0]]; const auto& fnode = idx[inode.control_deps[0]];
CHECK_EQ(fnode.inputs.size(), inode.source->num_outputs()) CHECK_EQ(fnode.inputs.size(), num_outputs)
<< "BackwardOp need to correspond to the forward node"; << "BackwardOp need to correspond to the forward node";
bool known = true; bool known = true;
for (size_t i = 0; i < fnode.inputs.size(); ++i) { for (size_t i = 0; i < fnode.inputs.size(); ++i) {
*oshape[i] = rshape[idx.entry_id(fnode.inputs[i])]; rshape[idx.entry_id(nid, i)] = rshape[idx.entry_id(fnode.inputs[i])];
if (fis_none(*oshape[i])) known = false; if (fis_none(rshape[idx.entry_id(nid, i)])) known = false;
} }
num_unknown += !known; num_unknown += !known;
} }
......
...@@ -41,7 +41,6 @@ def test_copy(): ...@@ -41,7 +41,6 @@ def test_copy():
z = sym.Variable('z') z = sym.Variable('z')
y = sym.exp(sym.add(x, x, name='add', gpu=2), y = sym.exp(sym.add(x, x, name='add', gpu=2),
name='exp', gpu=1, attr={"kk": "1"}) name='exp', gpu=1, attr={"kk": "1"})
assert y.__copy__().debug_str() == y.debug_str() assert y.__copy__().debug_str() == y.debug_str()
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment