Commit 486249e8 by Tianqi Chen

Update mutate function (#23)

parent 16a6db3a
......@@ -32,7 +32,7 @@ DOXYFILE_ENCODING = UTF-8
# title of most generated pages and in a few other places.
# The default value is: My Project.
PROJECT_NAME = "mxnngraph"
PROJECT_NAME = "nnvm"
# The PROJECT_NUMBER tag can be used to enter a project or revision number. This
# could be handy for archiving the generated documentation or if some version
......@@ -753,7 +753,7 @@ WARN_LOGFILE =
# spaces.
# Note: If this tag is empty the current directory is searched.
INPUT = include
INPUT = include/nnvm
# This tag can be used to specify the character encoding of the source files
# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
......
......@@ -11,7 +11,7 @@
namespace myproject {
using nnvm::FListInputNames;
using nnvm::FMutateInput;
using nnvm::FMutateInputs;
using nnvm::FInferShape;
using nnvm::FInferType;
using nnvm::FInplaceOption;
......@@ -119,8 +119,8 @@ NNVM_REGISTER_OP(add)
NNVM_REGISTER_OP(assign)
.set_num_inputs(2)
.set_num_outputs(1)
.attr<FMutateInput>("FMutateInput", [](const NodeAttrs& attrs, uint32_t index) {
return index == 0;
.attr<FMutateInputs>("FMutateInputs", [](const NodeAttrs& attrs) {
return std::vector<uint32_t>{0};
});
} // namespace myproject
......@@ -144,8 +144,8 @@ class IndexedGraph {
return nodes_[node_id(node)];
}
/*! \return list of argument nodes */
inline const std::vector<uint32_t>& arg_nodes() const {
return arg_nodes_;
inline const std::vector<uint32_t>& input_nodes() const {
return input_nodes_;
}
/*! \return list of output entries */
inline const std::vector<NodeEntry>& outputs() const {
......@@ -161,8 +161,8 @@ class IndexedGraph {
explicit IndexedGraph(const Graph& other);
// node pointers in CSR structure.
std::vector<Node> nodes_;
// index to argument nodes
std::vector<uint32_t> arg_nodes_;
// index to input nodes
std::vector<uint32_t> input_nodes_;
// space to store the outputs entries
std::vector<NodeEntry> outputs_;
// mapping from node to index.
......
......@@ -43,13 +43,12 @@ using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs
/*!
* \brief Check whether operator will mutate k-th input.
* \param attrs The attributes of the node.
* \param index The input index
* \return Whether this operator will mutate index-th input.
* \return list of input indices it mutates.
*
* \note Register under "FMutateInput", default return false
* \note Register under "FMutateInputs", default return false
* FMutateInputs enables mutation order handling correctly.
*/
using FMutateInput = std::function<bool (const NodeAttrs& attrs, uint32_t index)>;
using FMutateInputs = std::function<std::vector<uint32_t> (const NodeAttrs& attrs)>;
/*!
* \brief Inference function of certain type.
......
......@@ -54,16 +54,16 @@ inline Graph OrderMutation(Graph src) {
/*!
* \brief Infer shapes in the graph given the information.
* \param graph source graph
* \param shape_args The shapes of aruguments to the graph.
* \param shape_inputs The shapes of aruguments to the graph.
* \param shape_attr_key The key to the node attribute that can indicate shape.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id
*/
inline Graph InferShape(Graph graph,
ShapeVector shape_args = {},
ShapeVector shape_inputs = {},
std::string shape_attr_key = "") {
if (shape_args.size() != 0) {
graph.attrs["shape_args"] = std::make_shared<any>(std::move(shape_args));
if (shape_inputs.size() != 0) {
graph.attrs["shape_inputs"] = std::make_shared<any>(std::move(shape_inputs));
}
if (shape_attr_key.length() != 0) {
graph.attrs["shape_attr_key"] = std::make_shared<any>(std::move(shape_attr_key));
......@@ -74,19 +74,19 @@ inline Graph InferShape(Graph graph,
/*!
* \brief Infer types in the graph given the information.
* \param graph source graph
* \param shape_args The shapes of aruguments to the graph.
* \param shape_attr_key The key to the node attribute that can indicate shape.
* \param dtype_inputs The shapes of inputs to the graph.
* \param dtype_attr_key The key to the node attribute that can indicate shape.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id
*/
inline Graph InferType(Graph graph,
DTypeVector type_args = {},
std::string type_attr_key = "") {
if (type_args.size() != 0) {
graph.attrs["dtype_args"] = std::make_shared<any>(std::move(type_args));
DTypeVector dtype_inputs = {},
std::string dtype_attr_key = "") {
if (dtype_inputs.size() != 0) {
graph.attrs["dtype_inputs"] = std::make_shared<any>(std::move(dtype_inputs));
}
if (type_attr_key.length() != 0) {
graph.attrs["dtype_attr_key"] = std::make_shared<any>(std::move(type_attr_key));
if (dtype_attr_key.length() != 0) {
graph.attrs["dtype_attr_key"] = std::make_shared<any>(std::move(dtype_attr_key));
}
return ApplyPass(std::move(graph), {"InferType"});
}
......
......@@ -30,7 +30,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
nodes_.emplace_back(std::move(new_node));
// arg_nodes_
if (n->is_variable()) {
arg_nodes_.push_back(nid);
input_nodes_.push_back(nid);
}
// node2index_
node2index_[n.get()] = nid;
......
......@@ -31,16 +31,14 @@ NodePtr CreateVariableNode(const std::string& name) {
// The version of that varaible will increase
// version is used to implicitly order the mutation sequences
inline void UpdateNodeVersion(Node *n) {
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
for (NodeEntry& e : n->inputs) {
if (e.node->is_variable()) {
e.version = nnvm::get<VariableParam>(e.node->attrs.parsed).version;
}
}
if (fmutate_inputs.count(n->op) != 0) {
FMutateInput fmutate = fmutate_inputs[n->op];
for (uint32_t i = 0; i < n->inputs.size(); ++i) {
if (fmutate(n->attrs, i)) {
for (uint32_t i : fmutate_inputs[n->op](n->attrs)) {
NodeEntry& e = n->inputs[i];
CHECK(e.node->is_variable())
<< "Mutation target can only be Variable";
......@@ -48,7 +46,6 @@ inline void UpdateNodeVersion(Node *n) {
e.version = ++nnvm::get<VariableParam>(e.node->attrs.parsed).version;
}
}
}
}
inline std::string DefaultVarName(const std::string &op_name,
......@@ -192,18 +189,15 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
} else {
std::unordered_set<Node*> mutable_set;
std::vector<Node*> vlist;
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
DFSVisit(this->outputs, [&ret, &mutable_set, &vlist](const NodePtr &node) {
if (node->is_variable()) {
vlist.push_back(node.get());
} else if (fmutate_inputs.count(node->op)) {
FMutateInput fmutate = fmutate_inputs[node->op];
for (uint32_t i = 0; i < node->inputs.size(); ++i) {
if (fmutate(node->attrs, i)) {
for (uint32_t i : fmutate_inputs[node->op](node->attrs)){
mutable_set.insert(node->inputs[i].node.get());
}
}
}
});
for (Node* node : vlist) {
if ((option == kReadOnlyArgs && mutable_set.count(node) == 0) ||
......
......@@ -15,7 +15,7 @@ template<typename AttrType, typename IsNone>
Graph InferAttr(Graph &&ret,
const AttrType def_value,
const char* infer_name,
const char* arg_name,
const char* input_name,
const char* attr_key_name,
const char* attr_name,
const char* unknown_name,
......@@ -29,15 +29,15 @@ Graph InferAttr(Graph &&ret,
// reshape shape vector
AttrVector rshape(idx.num_node_entries(), def_value);
if (ret.attrs.count(arg_name) != 0) {
const AttrVector& shape_args = ret.GetAttr<AttrVector>(arg_name);
CHECK_LE(shape_args.size(), idx.arg_nodes().size())
if (ret.attrs.count(input_name) != 0) {
const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
CHECK_LE(shape_args.size(), idx.input_nodes().size())
<< "shape args is more than number of arguments";
for (size_t i = 0; i < shape_args.size(); ++i) {
rshape[idx.entry_id(idx.arg_nodes()[i], 0)] = shape_args[i];
rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i];
}
// erase the provided arguments
ret.attrs.erase(arg_name);
ret.attrs.erase(input_name);
}
std::string shape_attr_key;
if (ret.attrs.count(attr_key_name) != 0) {
......@@ -113,7 +113,7 @@ NNVM_REGISTER_PASS(InferType)
.set_body([](Graph ret) {
return InferAttr<int>(
std::move(ret), 0,
"FInferType", "dtype_args", "dtype_attr_key",
"FInferType", "dtype_inputs", "dtype_attr_key",
"dtype", "dtype_num_unknown_nodes",
[](const int t) { return t == -1; });
})
......
......@@ -21,6 +21,13 @@ inline T get_with_default(const std::unordered_map<Node*, T> &map,
return def;
}
inline bool IsMutate(const std::vector<uint32_t>& mutate_inputs, uint32_t i) {
if (mutate_inputs.size() == 0) return false;
auto it = std::lower_bound(
mutate_inputs.begin(), mutate_inputs.end(), i);
return (it != mutate_inputs.end()) && (*it == i);
}
Graph OrderMutation(const Graph& src) {
std::unordered_map<Node*, std::vector<NodeEntry> > version_hist;
DFSVisit(src.outputs, [&version_hist](const NodePtr& n) {
......@@ -37,7 +44,13 @@ Graph OrderMutation(const Graph& src) {
// start preparing for remapping the nodes.
std::unordered_map<Node*, NodePtr> old_new;
auto prepare = [&version_hist, &old_new] (const NodePtr& n) {
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::vector<uint32_t> mutate_inputs;
if (!n->is_variable() && fmutate_inputs.count(n->op)) {
mutate_inputs = fmutate_inputs[n->op](n->attrs);
}
std::sort(mutate_inputs.begin(), mutate_inputs.end());
bool need_repl = false;
for (size_t i = 0; i < n->inputs.size(); ++i) {
const NodeEntry& e = n->inputs[i];
......@@ -46,9 +59,7 @@ Graph OrderMutation(const Graph& src) {
auto it = version_hist.find(e.node.get());
if (it != version_hist.end()) {
std::vector<NodeEntry>& vec = it->second;
uint32_t is_mutate =
fmutate_inputs.count(n->op) ? fmutate_inputs[n->op](n->attrs, i) : 0;
vec.emplace_back(NodeEntry{n, is_mutate, e.version});
vec.emplace_back(NodeEntry{n, IsMutate(mutate_inputs, i), e.version});
}
} else {
if (old_new.count(e.node.get()) != 0) need_repl = true;
......@@ -91,18 +102,21 @@ Graph OrderMutation(const Graph& src) {
get_with_default(old_new, p.get(), p));
}
// add control deps
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::vector<uint32_t> mutate_inputs;
if (fmutate_inputs.count(kv.first->op)) {
mutate_inputs = fmutate_inputs[kv.first->op](kv.first->attrs);
}
std::sort(mutate_inputs.begin(), mutate_inputs.end());
for (size_t i = 0; i < kv.first->inputs.size(); ++i) {
const NodeEntry& e = kv.first->inputs[i];
if (e.node->is_variable() && version_hist.count(e.node.get()) != 0) {
FMutateInput fmutate = fmutate_inputs.get(kv.first->op, nullptr);
uint32_t is_mutate = (fmutate == nullptr) ? 0 : fmutate(kv.first->attrs, i);
std::vector<NodeEntry>& vec = version_hist.at(e.node.get());
auto it = std::lower_bound(vec.begin(), vec.end(),
NodeEntry{nullptr, 1, e.version},
comparator);
if (is_mutate != 0) {
if (IsMutate(mutate_inputs, i)) {
int read_dep = 0;
while (it != vec.begin()) {
--it;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment