Commit 486249e8 by Tianqi Chen

Update mutate function (#23)

parent 16a6db3a
...@@ -32,7 +32,7 @@ DOXYFILE_ENCODING = UTF-8 ...@@ -32,7 +32,7 @@ DOXYFILE_ENCODING = UTF-8
# title of most generated pages and in a few other places. # title of most generated pages and in a few other places.
# The default value is: My Project. # The default value is: My Project.
PROJECT_NAME = "mxnngraph" PROJECT_NAME = "nnvm"
# The PROJECT_NUMBER tag can be used to enter a project or revision number. This # The PROJECT_NUMBER tag can be used to enter a project or revision number. This
# could be handy for archiving the generated documentation or if some version # could be handy for archiving the generated documentation or if some version
...@@ -753,7 +753,7 @@ WARN_LOGFILE = ...@@ -753,7 +753,7 @@ WARN_LOGFILE =
# spaces. # spaces.
# Note: If this tag is empty the current directory is searched. # Note: If this tag is empty the current directory is searched.
INPUT = include INPUT = include/nnvm
# This tag can be used to specify the character encoding of the source files # This tag can be used to specify the character encoding of the source files
# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
namespace myproject { namespace myproject {
using nnvm::FListInputNames; using nnvm::FListInputNames;
using nnvm::FMutateInput; using nnvm::FMutateInputs;
using nnvm::FInferShape; using nnvm::FInferShape;
using nnvm::FInferType; using nnvm::FInferType;
using nnvm::FInplaceOption; using nnvm::FInplaceOption;
...@@ -119,8 +119,8 @@ NNVM_REGISTER_OP(add) ...@@ -119,8 +119,8 @@ NNVM_REGISTER_OP(add)
NNVM_REGISTER_OP(assign) NNVM_REGISTER_OP(assign)
.set_num_inputs(2) .set_num_inputs(2)
.set_num_outputs(1) .set_num_outputs(1)
.attr<FMutateInput>("FMutateInput", [](const NodeAttrs& attrs, uint32_t index) { .attr<FMutateInputs>("FMutateInputs", [](const NodeAttrs& attrs) {
return index == 0; return std::vector<uint32_t>{0};
}); });
} // namespace myproject } // namespace myproject
...@@ -144,8 +144,8 @@ class IndexedGraph { ...@@ -144,8 +144,8 @@ class IndexedGraph {
return nodes_[node_id(node)]; return nodes_[node_id(node)];
} }
/*! \return list of argument nodes */ /*! \return list of argument nodes */
inline const std::vector<uint32_t>& arg_nodes() const { inline const std::vector<uint32_t>& input_nodes() const {
return arg_nodes_; return input_nodes_;
} }
/*! \return list of output entries */ /*! \return list of output entries */
inline const std::vector<NodeEntry>& outputs() const { inline const std::vector<NodeEntry>& outputs() const {
...@@ -161,8 +161,8 @@ class IndexedGraph { ...@@ -161,8 +161,8 @@ class IndexedGraph {
explicit IndexedGraph(const Graph& other); explicit IndexedGraph(const Graph& other);
// node pointers in CSR structure. // node pointers in CSR structure.
std::vector<Node> nodes_; std::vector<Node> nodes_;
// index to argument nodes // index to input nodes
std::vector<uint32_t> arg_nodes_; std::vector<uint32_t> input_nodes_;
// space to store the outputs entries // space to store the outputs entries
std::vector<NodeEntry> outputs_; std::vector<NodeEntry> outputs_;
// mapping from node to index. // mapping from node to index.
......
...@@ -43,13 +43,12 @@ using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs ...@@ -43,13 +43,12 @@ using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs
/*! /*!
* \brief Check whether operator will mutate k-th input. * \brief Check whether operator will mutate k-th input.
* \param attrs The attributes of the node. * \param attrs The attributes of the node.
* \param index The input index * \return list of input indices it mutates.
* \return Whether this operator will mutate index-th input.
* *
* \note Register under "FMutateInput", default return false * \note Register under "FMutateInputs", default return false
* FMutateInputs enables mutation order handling correctly. * FMutateInputs enables mutation order handling correctly.
*/ */
using FMutateInput = std::function<bool (const NodeAttrs& attrs, uint32_t index)>; using FMutateInputs = std::function<std::vector<uint32_t> (const NodeAttrs& attrs)>;
/*! /*!
* \brief Inference function of certain type. * \brief Inference function of certain type.
......
...@@ -54,16 +54,16 @@ inline Graph OrderMutation(Graph src) { ...@@ -54,16 +54,16 @@ inline Graph OrderMutation(Graph src) {
/*! /*!
* \brief Infer shapes in the graph given the information. * \brief Infer shapes in the graph given the information.
* \param graph source graph * \param graph source graph
* \param shape_args The shapes of aruguments to the graph. * \param shape_inputs The shapes of aruguments to the graph.
* \param shape_attr_key The key to the node attribute that can indicate shape. * \param shape_attr_key The key to the node attribute that can indicate shape.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry. * \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id * The index of ShapeVector is given by graph.indexed_graph().entry_id
*/ */
inline Graph InferShape(Graph graph, inline Graph InferShape(Graph graph,
ShapeVector shape_args = {}, ShapeVector shape_inputs = {},
std::string shape_attr_key = "") { std::string shape_attr_key = "") {
if (shape_args.size() != 0) { if (shape_inputs.size() != 0) {
graph.attrs["shape_args"] = std::make_shared<any>(std::move(shape_args)); graph.attrs["shape_inputs"] = std::make_shared<any>(std::move(shape_inputs));
} }
if (shape_attr_key.length() != 0) { if (shape_attr_key.length() != 0) {
graph.attrs["shape_attr_key"] = std::make_shared<any>(std::move(shape_attr_key)); graph.attrs["shape_attr_key"] = std::make_shared<any>(std::move(shape_attr_key));
...@@ -74,19 +74,19 @@ inline Graph InferShape(Graph graph, ...@@ -74,19 +74,19 @@ inline Graph InferShape(Graph graph,
/*! /*!
* \brief Infer types in the graph given the information. * \brief Infer types in the graph given the information.
* \param graph source graph * \param graph source graph
* \param shape_args The shapes of aruguments to the graph. * \param dtype_inputs The shapes of inputs to the graph.
* \param shape_attr_key The key to the node attribute that can indicate shape. * \param dtype_attr_key The key to the node attribute that can indicate shape.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry. * \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id * The index of ShapeVector is given by graph.indexed_graph().entry_id
*/ */
inline Graph InferType(Graph graph, inline Graph InferType(Graph graph,
DTypeVector type_args = {}, DTypeVector dtype_inputs = {},
std::string type_attr_key = "") { std::string dtype_attr_key = "") {
if (type_args.size() != 0) { if (dtype_inputs.size() != 0) {
graph.attrs["dtype_args"] = std::make_shared<any>(std::move(type_args)); graph.attrs["dtype_inputs"] = std::make_shared<any>(std::move(dtype_inputs));
} }
if (type_attr_key.length() != 0) { if (dtype_attr_key.length() != 0) {
graph.attrs["dtype_attr_key"] = std::make_shared<any>(std::move(type_attr_key)); graph.attrs["dtype_attr_key"] = std::make_shared<any>(std::move(dtype_attr_key));
} }
return ApplyPass(std::move(graph), {"InferType"}); return ApplyPass(std::move(graph), {"InferType"});
} }
......
...@@ -30,7 +30,7 @@ IndexedGraph::IndexedGraph(const Graph &g) { ...@@ -30,7 +30,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
nodes_.emplace_back(std::move(new_node)); nodes_.emplace_back(std::move(new_node));
// arg_nodes_ // arg_nodes_
if (n->is_variable()) { if (n->is_variable()) {
arg_nodes_.push_back(nid); input_nodes_.push_back(nid);
} }
// node2index_ // node2index_
node2index_[n.get()] = nid; node2index_[n.get()] = nid;
......
...@@ -31,22 +31,19 @@ NodePtr CreateVariableNode(const std::string& name) { ...@@ -31,22 +31,19 @@ NodePtr CreateVariableNode(const std::string& name) {
// The version of that varaible will increase // The version of that varaible will increase
// version is used to implicitly order the mutation sequences // version is used to implicitly order the mutation sequences
inline void UpdateNodeVersion(Node *n) { inline void UpdateNodeVersion(Node *n) {
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput"); static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
for (NodeEntry& e : n->inputs) { for (NodeEntry& e : n->inputs) {
if (e.node->is_variable()) { if (e.node->is_variable()) {
e.version = nnvm::get<VariableParam>(e.node->attrs.parsed).version; e.version = nnvm::get<VariableParam>(e.node->attrs.parsed).version;
} }
} }
if (fmutate_inputs.count(n->op) != 0) { if (fmutate_inputs.count(n->op) != 0) {
FMutateInput fmutate = fmutate_inputs[n->op]; for (uint32_t i : fmutate_inputs[n->op](n->attrs)) {
for (uint32_t i = 0; i < n->inputs.size(); ++i) { NodeEntry& e = n->inputs[i];
if (fmutate(n->attrs, i)) { CHECK(e.node->is_variable())
NodeEntry& e = n->inputs[i]; << "Mutation target can only be Variable";
CHECK(e.node->is_variable()) // increase the version of the variable.
<< "Mutation target can only be Variable"; e.version = ++nnvm::get<VariableParam>(e.node->attrs.parsed).version;
// increase the version of the variable.
e.version = ++nnvm::get<VariableParam>(e.node->attrs.parsed).version;
}
} }
} }
} }
...@@ -192,16 +189,13 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const { ...@@ -192,16 +189,13 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
} else { } else {
std::unordered_set<Node*> mutable_set; std::unordered_set<Node*> mutable_set;
std::vector<Node*> vlist; std::vector<Node*> vlist;
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput"); static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
DFSVisit(this->outputs, [&ret, &mutable_set, &vlist](const NodePtr &node) { DFSVisit(this->outputs, [&ret, &mutable_set, &vlist](const NodePtr &node) {
if (node->is_variable()) { if (node->is_variable()) {
vlist.push_back(node.get()); vlist.push_back(node.get());
} else if (fmutate_inputs.count(node->op)) { } else if (fmutate_inputs.count(node->op)) {
FMutateInput fmutate = fmutate_inputs[node->op]; for (uint32_t i : fmutate_inputs[node->op](node->attrs)){
for (uint32_t i = 0; i < node->inputs.size(); ++i) { mutable_set.insert(node->inputs[i].node.get());
if (fmutate(node->attrs, i)) {
mutable_set.insert(node->inputs[i].node.get());
}
} }
} }
}); });
......
...@@ -15,7 +15,7 @@ template<typename AttrType, typename IsNone> ...@@ -15,7 +15,7 @@ template<typename AttrType, typename IsNone>
Graph InferAttr(Graph &&ret, Graph InferAttr(Graph &&ret,
const AttrType def_value, const AttrType def_value,
const char* infer_name, const char* infer_name,
const char* arg_name, const char* input_name,
const char* attr_key_name, const char* attr_key_name,
const char* attr_name, const char* attr_name,
const char* unknown_name, const char* unknown_name,
...@@ -29,15 +29,15 @@ Graph InferAttr(Graph &&ret, ...@@ -29,15 +29,15 @@ Graph InferAttr(Graph &&ret,
// reshape shape vector // reshape shape vector
AttrVector rshape(idx.num_node_entries(), def_value); AttrVector rshape(idx.num_node_entries(), def_value);
if (ret.attrs.count(arg_name) != 0) { if (ret.attrs.count(input_name) != 0) {
const AttrVector& shape_args = ret.GetAttr<AttrVector>(arg_name); const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
CHECK_LE(shape_args.size(), idx.arg_nodes().size()) CHECK_LE(shape_args.size(), idx.input_nodes().size())
<< "shape args is more than number of arguments"; << "shape args is more than number of arguments";
for (size_t i = 0; i < shape_args.size(); ++i) { for (size_t i = 0; i < shape_args.size(); ++i) {
rshape[idx.entry_id(idx.arg_nodes()[i], 0)] = shape_args[i]; rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i];
} }
// erase the provided arguments // erase the provided arguments
ret.attrs.erase(arg_name); ret.attrs.erase(input_name);
} }
std::string shape_attr_key; std::string shape_attr_key;
if (ret.attrs.count(attr_key_name) != 0) { if (ret.attrs.count(attr_key_name) != 0) {
...@@ -113,7 +113,7 @@ NNVM_REGISTER_PASS(InferType) ...@@ -113,7 +113,7 @@ NNVM_REGISTER_PASS(InferType)
.set_body([](Graph ret) { .set_body([](Graph ret) {
return InferAttr<int>( return InferAttr<int>(
std::move(ret), 0, std::move(ret), 0,
"FInferType", "dtype_args", "dtype_attr_key", "FInferType", "dtype_inputs", "dtype_attr_key",
"dtype", "dtype_num_unknown_nodes", "dtype", "dtype_num_unknown_nodes",
[](const int t) { return t == -1; }); [](const int t) { return t == -1; });
}) })
......
...@@ -21,6 +21,13 @@ inline T get_with_default(const std::unordered_map<Node*, T> &map, ...@@ -21,6 +21,13 @@ inline T get_with_default(const std::unordered_map<Node*, T> &map,
return def; return def;
} }
inline bool IsMutate(const std::vector<uint32_t>& mutate_inputs, uint32_t i) {
if (mutate_inputs.size() == 0) return false;
auto it = std::lower_bound(
mutate_inputs.begin(), mutate_inputs.end(), i);
return (it != mutate_inputs.end()) && (*it == i);
}
Graph OrderMutation(const Graph& src) { Graph OrderMutation(const Graph& src) {
std::unordered_map<Node*, std::vector<NodeEntry> > version_hist; std::unordered_map<Node*, std::vector<NodeEntry> > version_hist;
DFSVisit(src.outputs, [&version_hist](const NodePtr& n) { DFSVisit(src.outputs, [&version_hist](const NodePtr& n) {
...@@ -37,7 +44,13 @@ Graph OrderMutation(const Graph& src) { ...@@ -37,7 +44,13 @@ Graph OrderMutation(const Graph& src) {
// start preparing for remapping the nodes. // start preparing for remapping the nodes.
std::unordered_map<Node*, NodePtr> old_new; std::unordered_map<Node*, NodePtr> old_new;
auto prepare = [&version_hist, &old_new] (const NodePtr& n) { auto prepare = [&version_hist, &old_new] (const NodePtr& n) {
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput"); static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::vector<uint32_t> mutate_inputs;
if (!n->is_variable() && fmutate_inputs.count(n->op)) {
mutate_inputs = fmutate_inputs[n->op](n->attrs);
}
std::sort(mutate_inputs.begin(), mutate_inputs.end());
bool need_repl = false; bool need_repl = false;
for (size_t i = 0; i < n->inputs.size(); ++i) { for (size_t i = 0; i < n->inputs.size(); ++i) {
const NodeEntry& e = n->inputs[i]; const NodeEntry& e = n->inputs[i];
...@@ -46,9 +59,7 @@ Graph OrderMutation(const Graph& src) { ...@@ -46,9 +59,7 @@ Graph OrderMutation(const Graph& src) {
auto it = version_hist.find(e.node.get()); auto it = version_hist.find(e.node.get());
if (it != version_hist.end()) { if (it != version_hist.end()) {
std::vector<NodeEntry>& vec = it->second; std::vector<NodeEntry>& vec = it->second;
uint32_t is_mutate = vec.emplace_back(NodeEntry{n, IsMutate(mutate_inputs, i), e.version});
fmutate_inputs.count(n->op) ? fmutate_inputs[n->op](n->attrs, i) : 0;
vec.emplace_back(NodeEntry{n, is_mutate, e.version});
} }
} else { } else {
if (old_new.count(e.node.get()) != 0) need_repl = true; if (old_new.count(e.node.get()) != 0) need_repl = true;
...@@ -91,18 +102,21 @@ Graph OrderMutation(const Graph& src) { ...@@ -91,18 +102,21 @@ Graph OrderMutation(const Graph& src) {
get_with_default(old_new, p.get(), p)); get_with_default(old_new, p.get(), p));
} }
// add control deps // add control deps
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput"); static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::vector<uint32_t> mutate_inputs;
if (fmutate_inputs.count(kv.first->op)) {
mutate_inputs = fmutate_inputs[kv.first->op](kv.first->attrs);
}
std::sort(mutate_inputs.begin(), mutate_inputs.end());
for (size_t i = 0; i < kv.first->inputs.size(); ++i) { for (size_t i = 0; i < kv.first->inputs.size(); ++i) {
const NodeEntry& e = kv.first->inputs[i]; const NodeEntry& e = kv.first->inputs[i];
if (e.node->is_variable() && version_hist.count(e.node.get()) != 0) { if (e.node->is_variable() && version_hist.count(e.node.get()) != 0) {
FMutateInput fmutate = fmutate_inputs.get(kv.first->op, nullptr);
uint32_t is_mutate = (fmutate == nullptr) ? 0 : fmutate(kv.first->attrs, i);
std::vector<NodeEntry>& vec = version_hist.at(e.node.get()); std::vector<NodeEntry>& vec = version_hist.at(e.node.get());
auto it = std::lower_bound(vec.begin(), vec.end(), auto it = std::lower_bound(vec.begin(), vec.end(),
NodeEntry{nullptr, 1, e.version}, NodeEntry{nullptr, 1, e.version},
comparator); comparator);
if (is_mutate != 0) { if (IsMutate(mutate_inputs, i)) {
int read_dep = 0; int read_dep = 0;
while (it != vec.begin()) { while (it != vec.begin()) {
--it; --it;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment