Commit 6ffeae97 by Tianqi Chen

update (#26)

* updates (#1)

* add scalars

* change format

* change inferattr interface

* remove scalar

* remove warning
parent ebbec6de
......@@ -21,14 +21,14 @@ using nnvm::array_view;
// simply return the shape as same
inline bool SameShape(const NodeAttrs& attrs,
array_view<TShape*> ishape,
array_view<TShape*> oshape) {
if (ishape.size() == 0 || ishape[0]->ndim() == 0) return false;
for (TShape* pshape : oshape) {
*pshape = *ishape[0];
std::vector<TShape> *ishape,
std::vector<TShape> *oshape) {
if (ishape->size() == 0 || (*ishape)[0].ndim() == 0) return false;
for (TShape& pshape : *oshape) {
pshape = (*ishape)[0];
}
for (TShape* pshape : ishape) {
*pshape = *ishape[0];
for (TShape& pshape : *ishape) {
pshape = (*ishape)[0];
}
return true;
}
......@@ -51,13 +51,13 @@ NNVM_REGISTER_OP(reshape)
})
.attr<FInferShape>(
"FInferShape", [] (const NodeAttrs& attrs,
array_view<TShape*> ishape,
array_view<TShape*> oshape) {
std::vector<TShape> *ishape,
std::vector<TShape> *oshape) {
// get parsed attribute
const TShape& target = nnvm::get<TShape>(attrs.parsed);
*oshape[0] = target;
if (ishape[0]->ndim() == 0) return false;
CHECK_EQ(ishape[0]->Size(), target.Size())
(*oshape)[0] = target;
if ((*ishape)[0].ndim() == 0) return false;
CHECK_EQ((*ishape)[0].Size(), target.Size())
<< "Reshape op: source target shape mismatch";
return true;
})
......@@ -78,9 +78,9 @@ NNVM_REGISTER_OP(cast)
.attr<FInferShape>("FInferShape", SameShape)
.attr<FInferType>(
"FInferType", [](const NodeAttrs& attrs,
array_view<int*> itype,
array_view<int*> otype) {
*otype[0] = nnvm::get<int>(attrs.parsed);
std::vector<int> *itype,
std::vector<int> *otype) {
(*otype)[0] = nnvm::get<int>(attrs.parsed);
return true;
});
......
......@@ -6,6 +6,7 @@
#ifndef NNVM_OP_H_
#define NNVM_OP_H_
#include <dmlc/parameter.h>
#include <string>
#include <vector>
#include <utility>
......@@ -22,6 +23,7 @@ struct NodeAttrs;
template<typename ValueType>
class OpMap;
class OpRegistryEntry;
using dmlc::ParamFieldInfo;
/*! \brief constant to indicate it take any length of positional inputs */
static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
......@@ -80,6 +82,8 @@ class Op {
* This can be used to generate docstring automatically for the operator.
*/
std::string description;
/* \brief description of inputs and keyword arguments*/
std::vector<ParamFieldInfo> arguments;
/*!
* \brief number of inputs to the operator,
* -1 means it is variable length
......@@ -150,6 +154,22 @@ class Op {
*/
inline Op& describe(const std::string& descr); // NOLINT(*)
/*!
* \brief Add argument information to the function.
* \param name Name of the argument.
* \param type Type of the argument.
* \param description Description of the argument.
* \return reference to self.
*/
inline Op& add_argument(const std::string &name,
const std::string &type,
const std::string &description);
/*!
* \brief Append list if arguments to the end.
* \param args Additional list of arguments.
* \return reference to self.
*/
inline Op& add_arguments(const std::vector<ParamFieldInfo> &args);
/*!
* \brief Set the num_inputs
* \param n The number of inputs to be set.
* \return reference to self.
......@@ -340,6 +360,18 @@ inline Op& Op::describe(const std::string& descr) { // NOLINT(*)
return *this;
}
inline Op& Op::add_argument(const std::string &name,
const std::string &type,
const std::string &description) {
arguments.push_back({name, type, type, description});
return *this;
}
inline Op& Op::add_arguments(const std::vector<ParamFieldInfo> &args) {
this->arguments.insert(arguments.end(), args.begin(), args.end());
return *this;
}
inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
this->num_inputs = n;
return *this;
......
......@@ -57,8 +57,8 @@ using FMutateInputs = std::function<std::vector<uint32_t> (const NodeAttrs& attr
*/
template<typename AttrType>
using FInferNodeEntryAttr = std::function<bool (const NodeAttrs& attrs,
array_view<AttrType*> in_attrs,
array_view<AttrType*> out_attrs)>;
std::vector<AttrType> *in_attrs,
std::vector<AttrType> *out_attrs)>;
/*!
* \brief Shape inference function.
* Update the shapes given the input shape information.
......
......@@ -28,11 +28,26 @@ int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
const char ***arg_descriptions,
const char **return_type) {
const Op *op = static_cast<const Op *>(creator);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
*name = op->name.c_str();
*description = op->description.c_str();
*num_doc_args = 0;
*num_doc_args = static_cast<nn_uint>(op->arguments.size());
if (return_type) *return_type = nullptr;
ret->ret_vec_charp.clear();
for (size_t i = 0; i < op->arguments.size(); ++i) {
ret->ret_vec_charp.push_back(op->arguments[i].name.c_str());
}
for (size_t i = 0; i < op->arguments.size(); ++i) {
ret->ret_vec_charp.push_back(op->arguments[i].type_info_str.c_str());
}
for (size_t i = 0; i < op->arguments.size(); ++i) {
ret->ret_vec_charp.push_back(op->arguments[i].description.c_str());
}
*arg_names = dmlc::BeginPtr(ret->ret_vec_charp);
*arg_type_infos = dmlc::BeginPtr(ret->ret_vec_charp) + op->arguments.size();
*arg_descriptions = dmlc::BeginPtr(ret->ret_vec_charp) + (op->arguments.size() * 2);
API_END();
}
......
......@@ -151,7 +151,10 @@ void Symbol::Print(std::ostream &os) const {
}
if (!node->attrs.dict.empty()) {
os << "Attrs:\n";
for (auto &kv : node->attrs.dict) {
// make an ordered copy because unordered_map doesn't guarantee order.
std::map<std::string, std::string> sorted_dict(
node->attrs.dict.begin(), node->attrs.dict.end());
for (auto &kv : sorted_dict) {
os << '\t' << kv.first << '=' << kv.second << '\n';
}
}
......
......@@ -47,44 +47,52 @@ Graph InferAttr(Graph &&ret,
}
// temp space for shape inference.
std::vector<AttrType*> ishape, oshape;
std::vector<AttrType> ishape, oshape;
// number of completed nodes
size_t num_unknown = 0;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
uint32_t num_inputs = inode.inputs.size();
uint32_t num_outputs = inode.source->num_outputs();
if (inode.source->is_variable()) {
if (shape_attr_key.length() != 0 && fis_none(rshape[idx.entry_id(nid, 0)])) {
auto it = inode.source->attrs.dict.find(shape_attr_key);
if (it != inode.source->attrs.dict.end()) {
CHECK_EQ(inode.source->num_outputs(), 1);
CHECK_EQ(num_outputs, 1);
std::istringstream is(it->second);
CHECK(is >> rshape[idx.entry_id(nid, 0)]) << "Invalid attribute";
}
}
continue;
}
ishape.resize(inode.inputs.size());
for (uint32_t i = 0; i < ishape.size(); ++i) {
ishape[i] = &rshape[idx.entry_id(inode.inputs[i])];
}
oshape.resize(inode.source->num_outputs());
for (uint32_t i = 0; i < oshape.size(); ++i) {
oshape[i] = &rshape[idx.entry_id(nid, i)];
}
if (finfer_shape.count(inode.source->op)) {
ishape.resize(num_inputs, def_value);
for (uint32_t i = 0; i < ishape.size(); ++i) {
ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
}
oshape.resize(num_outputs, def_value);
for (uint32_t i = 0; i < oshape.size(); ++i) {
oshape[i] = rshape[idx.entry_id(nid, i)];
}
num_unknown +=
!(finfer_shape[inode.source->op](inode.source->attrs, ishape, oshape));
!(finfer_shape[inode.source->op](inode.source->attrs, &ishape, &oshape));
for (uint32_t i = 0; i < num_inputs; ++i) {
rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
}
for (uint32_t i = 0; i < num_outputs; ++i) {
rshape[idx.entry_id(nid, i)] = oshape[i];
}
} else if (is_backward.get(inode.source->op, false)) {
// backward operator inference.
CHECK_GE(inode.control_deps.size(), 1)
<< "BackwardOp need to have control_deps to its forward op";
const auto& fnode = idx[inode.control_deps[0]];
CHECK_EQ(fnode.inputs.size(), inode.source->num_outputs())
CHECK_EQ(fnode.inputs.size(), num_outputs)
<< "BackwardOp need to correspond to the forward node";
bool known = true;
for (size_t i = 0; i < fnode.inputs.size(); ++i) {
*oshape[i] = rshape[idx.entry_id(fnode.inputs[i])];
if (fis_none(*oshape[i])) known = false;
rshape[idx.entry_id(nid, i)] = rshape[idx.entry_id(fnode.inputs[i])];
if (fis_none(rshape[idx.entry_id(nid, i)])) known = false;
}
num_unknown += !known;
}
......
......@@ -41,7 +41,6 @@ def test_copy():
z = sym.Variable('z')
y = sym.exp(sym.add(x, x, name='add', gpu=2),
name='exp', gpu=1, attr={"kk": "1"})
assert y.__copy__().debug_str() == y.debug_str()
if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment