Commit efc87f1d by Tianqi Chen

Enable aux data (#24)

parent 486249e8
......@@ -147,6 +147,10 @@ class IndexedGraph {
inline const std::vector<uint32_t>& input_nodes() const {
return input_nodes_;
}
/*! \return list of mutable nodes */
inline const std::unordered_set<uint32_t>& mutable_input_nodes() const {
return mutable_input_nodes_;
}
/*! \return list of output entries */
inline const std::vector<NodeEntry>& outputs() const {
return outputs_;
......@@ -161,8 +165,10 @@ class IndexedGraph {
explicit IndexedGraph(const Graph& other);
// node pointers in CSR structure.
std::vector<Node> nodes_;
// index to input nodes
// index all to input nodes
std::vector<uint32_t> input_nodes_;
// index to mutable input nodes
std::unordered_set<uint32_t> mutable_input_nodes_;
// space to store the outputs entries
std::vector<NodeEntry> outputs_;
// mapping from node to index.
......
......@@ -368,12 +368,14 @@ inline Op& Op::set_attr_parser(std::function<void (NodeAttrs* attrs)> fn) { //
// member functions of OpMap
template<typename ValueType>
inline int OpMap<ValueType>::count(const Op* op) const {
if (op == nullptr) return 0;
const uint32_t idx = op->index_;
return idx < data_.size() ? data_[idx].second : 0;
}
template<typename ValueType>
inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const {
CHECK(op != nullptr);
const uint32_t idx = op->index_;
CHECK(idx < data_.size() && data_[idx].second)
<< "Attribute " << attr_name_
......@@ -383,6 +385,7 @@ inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const {
template<typename ValueType>
inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def_value) const {
if (op == nullptr) return def_value;
const uint32_t idx = op->index_;
if (idx < data_.size() && data_[idx].second) {
return data_[idx].first;
......
......@@ -4,6 +4,7 @@
* \brief Graph node data structure.
*/
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <limits>
namespace nnvm {
......@@ -57,12 +58,20 @@ IndexedGraph::IndexedGraph(const Graph &g) {
node2index_.at(e.node.get()), e.index, e.version});
}
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::unordered_set<uint32_t> mutable_inputs;
// setup array view
// input_entries_ and control_rptr must not change after this step.
const NodeEntry* iptr = dmlc::BeginPtr(input_entries_);
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
nodes_[nid].inputs = array_view<NodeEntry>(
iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]);
if (nodes_[nid].source->op != nullptr &&
fmutate_inputs.count(nodes_[nid].source->op)) {
for (uint32_t i : fmutate_inputs[nodes_[nid].source->op](nodes_[nid].source->attrs)) {
mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id);
}
}
}
const uint32_t* cptr = dmlc::BeginPtr(control_deps_);
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
......
......@@ -101,7 +101,7 @@ NNVM_REGISTER_PASS(InferShape)
.set_body([](Graph ret) {
return InferAttr<TShape>(
std::move(ret), TShape(),
"FInferShape", "shape_args", "shape_attr_key",
"FInferShape", "shape_inputs", "shape_attr_key",
"shape", "shape_num_unknown_nodes",
[](const TShape& s) { return s.ndim() == 0; });
})
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment