Commit badcdfff by Tianqi Chen

Change op function pointer to std::function, enable mutation (#6)

parent c92d63c7
...@@ -24,6 +24,13 @@ struct NodeEntry { ...@@ -24,6 +24,13 @@ struct NodeEntry {
std::shared_ptr<Node> node; std::shared_ptr<Node> node;
/*! \brief index of output from the source. */ /*! \brief index of output from the source. */
uint32_t index; uint32_t index;
/*!
* \brief version of input Variable.
* This field can only be nonzero when this->node is a Variable node.
* version is increased by one each time a Variable get composed to a mutation Op.
* This information can be helpful to decide order of operations when sequence of mutation happens.
*/
uint32_t version;
}; };
/*! /*!
......
...@@ -101,13 +101,13 @@ class Op { ...@@ -101,13 +101,13 @@ class Op {
* \param attrs The attribute of the node * \param attrs The attribute of the node
* \return number of outputs. * \return number of outputs.
*/ */
uint32_t (*get_num_outputs)(const NodeAttrs& attrs) = nullptr; std::function<uint32_t(const NodeAttrs& attrs)> get_num_outputs = nullptr;
/*! /*!
* \brief get number of inputs given information about the node. * \brief get number of inputs given information about the node.
* \param attrs The attribute of the node * \param attrs The attribute of the node
* \return number of inputs * \return number of inputs
*/ */
uint32_t (*get_num_inputs)(const NodeAttrs& attrs) = nullptr; std::function<uint32_t(const NodeAttrs& attrs)> get_num_inputs = nullptr;
/*! /*!
* \brief Attribute parser to parse the NodeAttrs information. * \brief Attribute parser to parse the NodeAttrs information.
* *
...@@ -140,8 +140,7 @@ class Op { ...@@ -140,8 +140,7 @@ class Op {
* } * }
* \endcode * \endcode
*/ */
void (*attr_parser)(NodeAttrs* attrs) = nullptr; std::function<void(NodeAttrs* attrs)> attr_parser = nullptr;
// function fields. // function fields.
/*! /*!
* \brief setter function during registration * \brief setter function during registration
...@@ -161,7 +160,7 @@ class Op { ...@@ -161,7 +160,7 @@ class Op {
* \param fn The function to be set. * \param fn The function to be set.
* \return reference to self. * \return reference to self.
*/ */
inline Op& set_num_inputs(uint32_t (*fn)(const NodeAttrs& attr)); // NOLINT(*) inline Op& set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
/*! /*!
* \brief Set the num_outputs * \brief Set the num_outputs
* \param n The number of outputs to be set. * \param n The number of outputs to be set.
...@@ -173,13 +172,13 @@ class Op { ...@@ -173,13 +172,13 @@ class Op {
* \param fn The function to be set. * \param fn The function to be set.
* \return reference to self. * \return reference to self.
*/ */
inline Op& set_num_outputs(uint32_t (*fn)(const NodeAttrs& attr)); // NOLINT(*) inline Op& set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
/*! /*!
* \brief Set the attr_parser function. * \brief Set the attr_parser function.
* \param fn The number of outputs to be set. * \param fn The number of outputs to be set.
* \return reference to self. * \return reference to self.
*/ */
inline Op& set_attr_parser(void (*fn)(NodeAttrs* attrs)); // NOLINT(*) inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn); // NOLINT(*)
/*! /*!
* \brief Register additional attributes to operator. * \brief Register additional attributes to operator.
* \param attr_name The name of the attribute. * \param attr_name The name of the attribute.
...@@ -342,7 +341,7 @@ inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*) ...@@ -342,7 +341,7 @@ inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
return *this; return *this;
} }
inline Op& Op::set_num_inputs(uint32_t (*fn)(const NodeAttrs& attr)) { // NOLINT(*) inline Op& Op::set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*)
this->get_num_inputs = fn; this->get_num_inputs = fn;
return *this; return *this;
} }
...@@ -352,12 +351,12 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*) ...@@ -352,12 +351,12 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
return *this; return *this;
} }
inline Op& Op::set_num_outputs(uint32_t (*fn)(const NodeAttrs& attr)) { // NOLINT(*) inline Op& Op::set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*)
this->get_num_outputs = fn; this->get_num_outputs = fn;
return *this; return *this;
} }
inline Op& Op::set_attr_parser(void (*fn)(NodeAttrs* attrs)) { // NOLINT(*) inline Op& Op::set_attr_parser(std::function<void (NodeAttrs* attrs)> fn) { // NOLINT(*)
this->attr_parser = fn; this->attr_parser = fn;
return *this; return *this;
} }
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
namespace nnvm { namespace nnvm {
// These types are optional attributes in each op // These types are optional attributes in each operator.
// Some of them are needed for certain pass. // Each attribute can be required by some passes.
/*! /*!
* \brief Return list of input arguments names of each operator. * \brief Return list of input arguments names of each operator.
...@@ -37,6 +37,16 @@ using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs& ...@@ -37,6 +37,16 @@ using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs&
*/ */
using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>; using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;
/*!
* \brief Check whether operator will mutate k-th input.
* \param index The input index
* \return Whether this operator will mutate index-th input.
*
* \note Register under "FMutateInput", default return false
* FMutateInputs enables mutation order handling correctly.
*/
using FMutateInput = std::function<bool (const NodeAttrs& attrs, uint32_t index)>;
} // namespace nnvm } // namespace nnvm
#endif // NNVM_OP_ATTR_TYPES_H_ #endif // NNVM_OP_ATTR_TYPES_H_
...@@ -13,6 +13,43 @@ namespace symbol_constants { ...@@ -13,6 +13,43 @@ namespace symbol_constants {
const char *kNamespaceSeparator = "_"; const char *kNamespaceSeparator = "_";
} // namespace symbol_constants } // namespace symbol_constants
// auxililary version attribute in variable.
struct VariableParam {
uint32_t version{0};
};
std::shared_ptr<Node> CreateVariableNode(const std::string& name) {
std::shared_ptr<Node> n = Node::Create();
n->op = nullptr;
n->attrs.name = name;
n->attrs.parsed = VariableParam();
return n;
}
// scan over a node's input, update the version to latest
// If the node's op mutates a certain input variable,
// The version of that varaible will increase
// version is used to implicitly order the mutation sequences
inline void UpdateNodeVersion(Node *n) {
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
for (NodeEntry& e : n->inputs) {
if (e.node->is_variable()) {
e.version = nnvm::get<VariableParam>(e.node->attrs.parsed).version;
}
}
if (fmutate_inputs.count(n->op) != 0) {
FMutateInput fmutate = fmutate_inputs[n->op];
for (uint32_t i = 0; i < n->inputs.size(); ++i) {
if (fmutate(n->attrs, i)) {
NodeEntry& e = n->inputs[i];
CHECK(e.node->is_variable())
<< "Mutation target can only be Variable";
// increase the version of the variable.
++nnvm::get<VariableParam>(e.node->attrs.parsed).version;
}
}
}
}
inline std::string DefaultVarName(const std::string &op_name, inline std::string DefaultVarName(const std::string &op_name,
const std::string &arg_name) { const std::string &arg_name) {
...@@ -67,13 +104,13 @@ Symbol Symbol::Copy() const { ...@@ -67,13 +104,13 @@ Symbol Symbol::Copy() const {
for (const auto &kv : old_new) { for (const auto &kv : old_new) {
for (const NodeEntry& e : kv.first->inputs) { for (const NodeEntry& e : kv.first->inputs) {
Node *ptr = e.node.get(); Node *ptr = e.node.get();
kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index}); kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version});
} }
} }
// set the head // set the head
Symbol ret; Symbol ret;
for (const NodeEntry &e : outputs) { for (const NodeEntry &e : outputs) {
ret.outputs.emplace_back(NodeEntry{old_new[e.node.get()], e.index}); ret.outputs.emplace_back(NodeEntry{old_new[e.node.get()], e.index, e.version});
} }
return ret; return ret;
} }
...@@ -95,8 +132,14 @@ void Symbol::Print(std::ostream &os) const { ...@@ -95,8 +132,14 @@ void Symbol::Print(std::ostream &os) const {
os << "Name: " << node->attrs.name << " Op:" << node->op->name << '\n' os << "Name: " << node->attrs.name << " Op:" << node->op->name << '\n'
<< "Inputs:\n"; << "Inputs:\n";
for (size_t i = 0; i < node->inputs.size(); ++i) { for (size_t i = 0; i < node->inputs.size(); ++i) {
os << "\targ[" << i << "]=" << node->inputs[i].node->attrs.name const NodeEntry& e = node->inputs[i];
<< '(' << node->inputs[i].index << ")\n"; os << "\targ[" << i << "]=" << e.node->attrs.name
<< '(' << e.index << ")";
if (e.node->is_variable()) {
os << " version=" << e.version << '\n';
} else {
os << '\n';
}
} }
os << "Attrs:\n"; os << "Attrs:\n";
for (auto &kv : node->attrs.dict) { for (auto &kv : node->attrs.dict) {
...@@ -163,6 +206,8 @@ std::vector<std::string> Symbol::ListOutputs() const { ...@@ -163,6 +206,8 @@ std::vector<std::string> Symbol::ListOutputs() const {
void Symbol::Compose(const std::vector<Symbol>& args, void Symbol::Compose(const std::vector<Symbol>& args,
const std::unordered_map<std::string, Symbol>& kwargs, const std::unordered_map<std::string, Symbol>& kwargs,
const std::string& name) { const std::string& name) {
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
CHECK_EQ(outputs.size(), 1) CHECK_EQ(outputs.size(), 1)
<< "Only composition of value function is supported currently"; << "Only composition of value function is supported currently";
CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed"; CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed";
...@@ -193,7 +238,6 @@ void Symbol::Compose(const std::vector<Symbol>& args, ...@@ -193,7 +238,6 @@ void Symbol::Compose(const std::vector<Symbol>& args,
} }
// switch to keyword argument matching // switch to keyword argument matching
if (args.size() != n_req) { if (args.size() != n_req) {
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
FListInputNames fn = flist_inputs.get(n->op, nullptr); FListInputNames fn = flist_inputs.get(n->op, nullptr);
auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(n->attrs); auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(n->attrs);
if (arg_names.size() != n_req) { if (arg_names.size() != n_req) {
...@@ -206,8 +250,8 @@ void Symbol::Compose(const std::vector<Symbol>& args, ...@@ -206,8 +250,8 @@ void Symbol::Compose(const std::vector<Symbol>& args,
n->inputs[i] = it->second.outputs[0]; n->inputs[i] = it->second.outputs[0];
++nmatched; ++nmatched;
} else { } else {
n->inputs[i] = NodeEntry{Node::Create(), 0}; n->inputs[i] = NodeEntry{
n->inputs[i].node->attrs.name = DefaultVarName(name, arg_names[i]); CreateVariableNode(DefaultVarName(name, arg_names[i])), 0, 0};
} }
} }
...@@ -226,6 +270,7 @@ void Symbol::Compose(const std::vector<Symbol>& args, ...@@ -226,6 +270,7 @@ void Symbol::Compose(const std::vector<Symbol>& args,
n->inputs.push_back(s.outputs[0]); n->inputs.push_back(s.outputs[0]);
} }
} }
UpdateNodeVersion(n);
} else { } else {
// general composition // general composition
CHECK_EQ(args.size(), 0) CHECK_EQ(args.size(), 0)
...@@ -253,25 +298,32 @@ void Symbol::Compose(const std::vector<Symbol>& args, ...@@ -253,25 +298,32 @@ void Symbol::Compose(const std::vector<Symbol>& args,
DFSVisit(this->outputs, find_replace_map); DFSVisit(this->outputs, find_replace_map);
if (nmatched == kwargs.size() && arg_counter < args.size()) { if (nmatched == kwargs.size() && arg_counter < args.size()) {
std::vector<Node*> update_nodes;
std::vector<std::pair<NodeEntry*, const NodeEntry*> > replace_plan; std::vector<std::pair<NodeEntry*, const NodeEntry*> > replace_plan;
auto find_replace_plan = [&replace_map, &replace_plan] auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes]
(const std::shared_ptr<Node> &node) { (const std::shared_ptr<Node> &node) {
// visit all the childs, find possible replacement // visit all the childs, find possible replacement
bool repl = false;
for (size_t i = 0; i < node->inputs.size(); ++i) { for (size_t i = 0; i < node->inputs.size(); ++i) {
NodeEntry *e = &(node->inputs[i]); NodeEntry *e = &(node->inputs[i]);
if (e->node->is_variable()) { if (e->node->is_variable()) {
auto iter = replace_map.find(e->node.get()); auto iter = replace_map.find(e->node.get());
if (iter != replace_map.end()) { if (iter != replace_map.end()) {
replace_plan.push_back(std::make_pair(e, iter->second)); replace_plan.push_back(std::make_pair(e, iter->second));
repl = true;
} }
} }
} }
if (repl) update_nodes.push_back(node.get());
}; };
DFSVisit(this->outputs, find_replace_plan); DFSVisit(this->outputs, find_replace_plan);
for (const auto& kv : replace_plan) { for (const auto& kv : replace_plan) {
*(kv.first) = *(kv.second); *(kv.first) = *(kv.second);
} }
for (Node* n : update_nodes) {
UpdateNodeVersion(n);
}
} else { } else {
std::vector<std::string> keys = GetKeys(kwargs); std::vector<std::string> keys = GetKeys(kwargs);
std::vector<std::string> arg_names = ListArguments(); std::vector<std::string> arg_names = ListArguments();
...@@ -303,9 +355,15 @@ Symbol Symbol::GetInternals() const { ...@@ -303,9 +355,15 @@ Symbol Symbol::GetInternals() const {
Symbol ret; Symbol ret;
DFSVisit(this->outputs, [&ret](const std::shared_ptr<Node>& node) { DFSVisit(this->outputs, [&ret](const std::shared_ptr<Node>& node) {
Node* n = node.get(); Node* n = node.get();
uint32_t nout = n->num_outputs(); if (n->is_variable()) {
for (uint32_t i = 0; i < nout; ++i) { // grab version from variable.
ret.outputs.emplace_back(NodeEntry{node, i}); VariableParam& param = nnvm::get<VariableParam>(n->attrs.parsed);
ret.outputs.emplace_back(NodeEntry{node, 0, param.version});
} else {
uint32_t nout = n->num_outputs();
for (uint32_t i = 0; i < nout; ++i) {
ret.outputs.emplace_back(NodeEntry{node, i, 0});
}
} }
}); });
return ret; return ret;
...@@ -325,7 +383,7 @@ void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& a ...@@ -325,7 +383,7 @@ void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& a
} }
} }
if (node->op != nullptr && node->op->attr_parser != nullptr) { if (node->op != nullptr && node->op->attr_parser != nullptr) {
(*node->op->attr_parser)(&(node->attrs)); node->op->attr_parser(&(node->attrs));
} }
} }
...@@ -366,9 +424,9 @@ Symbol Symbol::CreateFunctor(const Op* op, ...@@ -366,9 +424,9 @@ Symbol Symbol::CreateFunctor(const Op* op,
n->op = op; n->op = op;
n->attrs.dict = std::move(attrs); n->attrs.dict = std::move(attrs);
if (n->op->attr_parser != nullptr) { if (n->op->attr_parser != nullptr) {
(*n->op->attr_parser)(&(n->attrs)); n->op->attr_parser(&(n->attrs));
} }
s.outputs.emplace_back(NodeEntry{std::move(n), 0}); s.outputs.emplace_back(NodeEntry{std::move(n), 0, 0});
return s; return s;
} }
...@@ -382,10 +440,7 @@ Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) { ...@@ -382,10 +440,7 @@ Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {
Symbol Symbol::CreateVariable(const std::string& name) { Symbol Symbol::CreateVariable(const std::string& name) {
Symbol s; Symbol s;
std::shared_ptr<Node> n = Node::Create(); s.outputs.emplace_back(NodeEntry{CreateVariableNode(name), 0, 0});
n->op = nullptr;
n->attrs.name = name;
s.outputs.emplace_back(NodeEntry{std::move(n), 0});
return s; return s;
} }
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <utility> #include <utility>
using nnvm::FListInputNames; using nnvm::FListInputNames;
using nnvm::FMutateInput;
using nnvm::NodeAttrs; using nnvm::NodeAttrs;
NNVM_REGISTER_OP(add) NNVM_REGISTER_OP(add)
...@@ -29,3 +30,10 @@ NNVM_REGISTER_OP(conv2d) ...@@ -29,3 +30,10 @@ NNVM_REGISTER_OP(conv2d)
NNVM_REGISTER_OP(add) NNVM_REGISTER_OP(add)
.attr<std::string>("nick_name", "plus"); .attr<std::string>("nick_name", "plus");
NNVM_REGISTER_OP(assign)
.set_num_inputs(2)
.set_num_outputs(1)
.attr<FMutateInput>("FMutateInput", [](const NodeAttrs& attrs, uint32_t index) {
return index == 0;
});
...@@ -24,6 +24,20 @@ def test_default_input(): ...@@ -24,6 +24,20 @@ def test_default_input():
except NNVMError: except NNVMError:
pass pass
def test_mutate_input():
x = sym.Variable('x')
y = sym.conv2d(data=x, name='conv')
z = sym.assign(x, y)
t = sym.add(z, x)
try:
z = sym.assign(z, z)
assert False
except NNVMError:
pass
if __name__ == "__main__": if __name__ == "__main__":
test_default_input() test_default_input()
test_compose() test_compose()
test_mutate_input()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment