Commit efc87f1d by Tianqi Chen

Enable aux data (#24)

parent 486249e8
...@@ -147,6 +147,10 @@ class IndexedGraph { ...@@ -147,6 +147,10 @@ class IndexedGraph {
inline const std::vector<uint32_t>& input_nodes() const { inline const std::vector<uint32_t>& input_nodes() const {
return input_nodes_; return input_nodes_;
} }
/*! \return list of mutable nodes */
inline const std::unordered_set<uint32_t>& mutable_input_nodes() const {
return mutable_input_nodes_;
}
/*! \return list of output entries */ /*! \return list of output entries */
inline const std::vector<NodeEntry>& outputs() const { inline const std::vector<NodeEntry>& outputs() const {
return outputs_; return outputs_;
...@@ -161,8 +165,10 @@ class IndexedGraph { ...@@ -161,8 +165,10 @@ class IndexedGraph {
explicit IndexedGraph(const Graph& other); explicit IndexedGraph(const Graph& other);
// node pointers in CSR structure. // node pointers in CSR structure.
std::vector<Node> nodes_; std::vector<Node> nodes_;
// index to input nodes // index all to input nodes
std::vector<uint32_t> input_nodes_; std::vector<uint32_t> input_nodes_;
// index to mutable input nodes
std::unordered_set<uint32_t> mutable_input_nodes_;
// space to store the outputs entries // space to store the outputs entries
std::vector<NodeEntry> outputs_; std::vector<NodeEntry> outputs_;
// mapping from node to index. // mapping from node to index.
......
...@@ -368,12 +368,14 @@ inline Op& Op::set_attr_parser(std::function<void (NodeAttrs* attrs)> fn) { // ...@@ -368,12 +368,14 @@ inline Op& Op::set_attr_parser(std::function<void (NodeAttrs* attrs)> fn) { //
// member functions of OpMap // member functions of OpMap
template<typename ValueType> template<typename ValueType>
inline int OpMap<ValueType>::count(const Op* op) const { inline int OpMap<ValueType>::count(const Op* op) const {
if (op == nullptr) return 0;
const uint32_t idx = op->index_; const uint32_t idx = op->index_;
return idx < data_.size() ? data_[idx].second : 0; return idx < data_.size() ? data_[idx].second : 0;
} }
template<typename ValueType> template<typename ValueType>
inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const { inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const {
CHECK(op != nullptr);
const uint32_t idx = op->index_; const uint32_t idx = op->index_;
CHECK(idx < data_.size() && data_[idx].second) CHECK(idx < data_.size() && data_[idx].second)
<< "Attribute " << attr_name_ << "Attribute " << attr_name_
...@@ -383,6 +385,7 @@ inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const { ...@@ -383,6 +385,7 @@ inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const {
template<typename ValueType> template<typename ValueType>
inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def_value) const { inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def_value) const {
if (op == nullptr) return def_value;
const uint32_t idx = op->index_; const uint32_t idx = op->index_;
if (idx < data_.size() && data_[idx].second) { if (idx < data_.size() && data_[idx].second) {
return data_[idx].first; return data_[idx].first;
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* \brief Graph node data structure. * \brief Graph node data structure.
*/ */
#include <nnvm/graph.h> #include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <limits> #include <limits>
namespace nnvm { namespace nnvm {
...@@ -57,12 +58,20 @@ IndexedGraph::IndexedGraph(const Graph &g) { ...@@ -57,12 +58,20 @@ IndexedGraph::IndexedGraph(const Graph &g) {
node2index_.at(e.node.get()), e.index, e.version}); node2index_.at(e.node.get()), e.index, e.version});
} }
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::unordered_set<uint32_t> mutable_inputs;
// setup array view // setup array view
// input_entries_ and control_rptr must not change after this step. // input_entries_ and control_rptr must not change after this step.
const NodeEntry* iptr = dmlc::BeginPtr(input_entries_); const NodeEntry* iptr = dmlc::BeginPtr(input_entries_);
for (size_t nid = 0; nid < nodes_.size(); ++nid) { for (size_t nid = 0; nid < nodes_.size(); ++nid) {
nodes_[nid].inputs = array_view<NodeEntry>( nodes_[nid].inputs = array_view<NodeEntry>(
iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]); iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]);
if (nodes_[nid].source->op != nullptr &&
fmutate_inputs.count(nodes_[nid].source->op)) {
for (uint32_t i : fmutate_inputs[nodes_[nid].source->op](nodes_[nid].source->attrs)) {
mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id);
}
}
} }
const uint32_t* cptr = dmlc::BeginPtr(control_deps_); const uint32_t* cptr = dmlc::BeginPtr(control_deps_);
for (size_t nid = 0; nid < nodes_.size(); ++nid) { for (size_t nid = 0; nid < nodes_.size(); ++nid) {
......
...@@ -101,7 +101,7 @@ NNVM_REGISTER_PASS(InferShape) ...@@ -101,7 +101,7 @@ NNVM_REGISTER_PASS(InferShape)
.set_body([](Graph ret) { .set_body([](Graph ret) {
return InferAttr<TShape>( return InferAttr<TShape>(
std::move(ret), TShape(), std::move(ret), TShape(),
"FInferShape", "shape_args", "shape_attr_key", "FInferShape", "shape_inputs", "shape_attr_key",
"shape", "shape_num_unknown_nodes", "shape", "shape_num_unknown_nodes",
[](const TShape& s) { return s.ndim() == 0; }); [](const TShape& s) { return s.ndim() == 0; });
}) })
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment