Commit de076999 by Tianqi Chen

[NODE] Move op inside node attribute (#30)

parent ac070f83
...@@ -46,7 +46,7 @@ inline NodeEntry MakeNode(const char* op_name, ...@@ -46,7 +46,7 @@ inline NodeEntry MakeNode(const char* op_name,
std::string node_name, std::string node_name,
std::vector<NodeEntry> inputs) { std::vector<NodeEntry> inputs) {
NodePtr p = Node::Create(); NodePtr p = Node::Create();
p->op = nnvm::Op::Get(op_name); p->attrs.op = nnvm::Op::Get(op_name);
p->attrs.name = std::move(node_name); p->attrs.name = std::move(node_name);
p->inputs = std::move(inputs); p->inputs = std::move(inputs);
return NodeEntry{p, 0, 0}; return NodeEntry{p, 0, 0};
......
...@@ -46,6 +46,11 @@ struct NodeEntry { ...@@ -46,6 +46,11 @@ struct NodeEntry {
* Usually are additional parameters like axis, * Usually are additional parameters like axis,
*/ */
struct NodeAttrs { struct NodeAttrs {
/*!
* \brief The operator this node uses.
* For place holder variable, op == nullptr.
*/
const Op *op{nullptr};
/*! \brief name of the node */ /*! \brief name of the node */
std::string name; std::string name;
/*! \brief Vector representation of positional attributes */ /*! \brief Vector representation of positional attributes */
...@@ -65,11 +70,8 @@ struct NodeAttrs { ...@@ -65,11 +70,8 @@ struct NodeAttrs {
*/ */
class Node { class Node {
public: public:
/*! /*! \brief The attributes in the node. */
* \brief The operator this node uses. NodeAttrs attrs;
* For place holder variable, op == nullptr.
*/
const Op *op{nullptr};
/*! \brief inputs to this node */ /*! \brief inputs to this node */
std::vector<NodeEntry> inputs; std::vector<NodeEntry> inputs;
/*! /*!
...@@ -77,10 +79,10 @@ class Node { ...@@ -77,10 +79,10 @@ class Node {
* Gives operation must be performed before this operation. * Gives operation must be performed before this operation.
*/ */
std::vector<NodePtr> control_deps; std::vector<NodePtr> control_deps;
/*! \brief The attributes in the node. */
NodeAttrs attrs;
/*! \brief destructor of node */ /*! \brief destructor of node */
~Node(); ~Node();
/*! \return operator in this node */
inline const Op* op() const;
/*! /*!
* \brief return whether node is placeholder variable. * \brief return whether node is placeholder variable.
* This is equivalent to op == nullptr * This is equivalent to op == nullptr
...@@ -99,25 +101,28 @@ class Node { ...@@ -99,25 +101,28 @@ class Node {
}; };
// implementation of functions. // implementation of functions.
inline const Op* Node::op() const {
return this->attrs.op;
}
inline bool Node::is_variable() const { inline bool Node::is_variable() const {
return this->op == nullptr; return this->op() == nullptr;
} }
inline uint32_t Node::num_outputs() const { inline uint32_t Node::num_outputs() const {
if (is_variable()) return 1; if (is_variable()) return 1;
if (this->op->get_num_outputs == nullptr) { if (this->op()->get_num_outputs == nullptr) {
return this->op->num_outputs; return this->op()->num_outputs;
} else { } else {
return this->op->get_num_outputs(this->attrs); return this->op()->get_num_outputs(this->attrs);
} }
} }
inline uint32_t Node::num_inputs() const { inline uint32_t Node::num_inputs() const {
if (is_variable()) return 1; if (is_variable()) return 1;
if (this->op->get_num_inputs == nullptr) { if (this->op()->get_num_inputs == nullptr) {
return this->op->num_inputs; return this->op()->num_inputs;
} else { } else {
return this->op->get_num_inputs(this->attrs); return this->op()->get_num_inputs(this->attrs);
} }
} }
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include <vector>
#include "./base.h" #include "./base.h"
#include "./pass.h" #include "./pass.h"
#include "./graph_attr_types.h" #include "./graph_attr_types.h"
......
...@@ -66,9 +66,9 @@ IndexedGraph::IndexedGraph(const Graph &g) { ...@@ -66,9 +66,9 @@ IndexedGraph::IndexedGraph(const Graph &g) {
for (size_t nid = 0; nid < nodes_.size(); ++nid) { for (size_t nid = 0; nid < nodes_.size(); ++nid) {
nodes_[nid].inputs = array_view<NodeEntry>( nodes_[nid].inputs = array_view<NodeEntry>(
iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]); iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]);
if (nodes_[nid].source->op != nullptr && if (nodes_[nid].source->op() != nullptr &&
fmutate_inputs.count(nodes_[nid].source->op)) { fmutate_inputs.count(nodes_[nid].source->op())) {
for (uint32_t i : fmutate_inputs[nodes_[nid].source->op](nodes_[nid].source->attrs)) { for (uint32_t i : fmutate_inputs[nodes_[nid].source->op()](nodes_[nid].source->attrs)) {
mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id); mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id);
} }
} }
......
...@@ -20,7 +20,7 @@ struct VariableParam { ...@@ -20,7 +20,7 @@ struct VariableParam {
NodePtr CreateVariableNode(const std::string& name) { NodePtr CreateVariableNode(const std::string& name) {
NodePtr n = Node::Create(); NodePtr n = Node::Create();
n->op = nullptr; n->attrs.op = nullptr;
n->attrs.name = name; n->attrs.name = name;
n->attrs.parsed = VariableParam(); n->attrs.parsed = VariableParam();
return n; return n;
...@@ -37,8 +37,8 @@ inline void UpdateNodeVersion(Node *n) { ...@@ -37,8 +37,8 @@ inline void UpdateNodeVersion(Node *n) {
e.version = nnvm::get<VariableParam>(e.node->attrs.parsed).version; e.version = nnvm::get<VariableParam>(e.node->attrs.parsed).version;
} }
} }
if (fmutate_inputs.count(n->op) != 0) { if (fmutate_inputs.count(n->op()) != 0) {
for (uint32_t i : fmutate_inputs[n->op](n->attrs)) { for (uint32_t i : fmutate_inputs[n->op()](n->attrs)) {
NodeEntry& e = n->inputs[i]; NodeEntry& e = n->inputs[i];
CHECK(e.node->is_variable()) CHECK(e.node->is_variable())
<< "Mutation target can only be Variable"; << "Mutation target can only be Variable";
...@@ -96,7 +96,6 @@ Symbol Symbol::Copy() const { ...@@ -96,7 +96,6 @@ Symbol Symbol::Copy() const {
// use DFSVisit to copy all the nodes // use DFSVisit to copy all the nodes
DFSVisit(this->outputs, [&old_new](const NodePtr& node) { DFSVisit(this->outputs, [&old_new](const NodePtr& node) {
NodePtr np = Node::Create(); NodePtr np = Node::Create();
np->op = node->op;
np->attrs = node->attrs; np->attrs = node->attrs;
old_new[node.get()] = std::move(np); old_new[node.get()] = std::move(np);
}); });
...@@ -123,7 +122,7 @@ void Symbol::Print(std::ostream &os) const { ...@@ -123,7 +122,7 @@ void Symbol::Print(std::ostream &os) const {
if (outputs[0].node->is_variable()) { if (outputs[0].node->is_variable()) {
os << "Variable:" << outputs[0].node->attrs.name << '\n'; os << "Variable:" << outputs[0].node->attrs.name << '\n';
} else { } else {
os << "AtomicFunctor "<< " Op:" << outputs[0].node->op->name << '\n'; os << "AtomicFunctor "<< " Op:" << outputs[0].node->op()->name << '\n';
} }
} else { } else {
// use DFSVisit to copy all the nodes // use DFSVisit to copy all the nodes
...@@ -137,7 +136,7 @@ void Symbol::Print(std::ostream &os) const { ...@@ -137,7 +136,7 @@ void Symbol::Print(std::ostream &os) const {
os << "Variable:" << node->attrs.name << '\n'; os << "Variable:" << node->attrs.name << '\n';
} else { } else {
os << "--------------------\n"; os << "--------------------\n";
os << "Op:" << node->op->name << ", Name=" << node->attrs.name << '\n' os << "Op:" << node->op()->name << ", Name=" << node->attrs.name << '\n'
<< "Inputs:\n"; << "Inputs:\n";
for (size_t i = 0; i < node->inputs.size(); ++i) { for (size_t i = 0; i < node->inputs.size(); ++i) {
const NodeEntry& e = node->inputs[i]; const NodeEntry& e = node->inputs[i];
...@@ -196,8 +195,8 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const { ...@@ -196,8 +195,8 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
DFSVisit(this->outputs, [&ret, &mutable_set, &vlist](const NodePtr &node) { DFSVisit(this->outputs, [&ret, &mutable_set, &vlist](const NodePtr &node) {
if (node->is_variable()) { if (node->is_variable()) {
vlist.push_back(node.get()); vlist.push_back(node.get());
} else if (fmutate_inputs.count(node->op)) { } else if (fmutate_inputs.count(node->op())) {
for (uint32_t i : fmutate_inputs[node->op](node->attrs)){ for (uint32_t i : fmutate_inputs[node->op()](node->attrs)){
mutable_set.insert(node->inputs[i].node.get()); mutable_set.insert(node->inputs[i].node.get());
} }
} }
...@@ -221,7 +220,7 @@ std::vector<std::string> Symbol::ListOutputNames() const { ...@@ -221,7 +220,7 @@ std::vector<std::string> Symbol::ListOutputNames() const {
} else { } else {
const std::string& hname = head.node->attrs.name; const std::string& hname = head.node->attrs.name;
std::string rname; std::string rname;
FListOutputNames fn = flist_ouputs.get(head.node->op, nullptr); FListOutputNames fn = flist_ouputs.get(head.node->op(), nullptr);
if (fn != nullptr) { if (fn != nullptr) {
rname = fn(head.node->attrs)[head.index]; rname = fn(head.node->attrs)[head.index];
} else { } else {
...@@ -278,10 +277,10 @@ void Symbol::Compose(const array_view<const Symbol*>& args, ...@@ -278,10 +277,10 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
} }
// switch to keyword argument matching // switch to keyword argument matching
if (args.size() != n_req) { if (args.size() != n_req) {
FListInputNames fn = flist_inputs.get(n->op, nullptr); FListInputNames fn = flist_inputs.get(n->op(), nullptr);
auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(n->attrs); auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(n->attrs);
if (arg_names.size() != n_req) { if (arg_names.size() != n_req) {
LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op->name; LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op()->name;
} }
size_t nmatched = 0; size_t nmatched = 0;
for (size_t i = args.size(); i < n_req; ++i) { for (size_t i = args.size(); i < n_req; ++i) {
...@@ -422,8 +421,8 @@ void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& a ...@@ -422,8 +421,8 @@ void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& a
node->attrs.dict[kv.first] = kv.second; node->attrs.dict[kv.first] = kv.second;
} }
} }
if (node->op != nullptr && node->op->attr_parser != nullptr) { if (node->op() != nullptr && node->op()->attr_parser != nullptr) {
node->op->attr_parser(&(node->attrs)); node->op()->attr_parser(&(node->attrs));
} }
} }
...@@ -461,10 +460,10 @@ Symbol Symbol::CreateFunctor(const Op* op, ...@@ -461,10 +460,10 @@ Symbol Symbol::CreateFunctor(const Op* op,
std::unordered_map<std::string, std::string> attrs) { std::unordered_map<std::string, std::string> attrs) {
Symbol s; Symbol s;
NodePtr n = Node::Create(); NodePtr n = Node::Create();
n->op = op; n->attrs.op = op;
n->attrs.dict = std::move(attrs); n->attrs.dict = std::move(attrs);
if (n->op->attr_parser != nullptr) { if (n->op()->attr_parser != nullptr) {
n->op->attr_parser(&(n->attrs)); n->op()->attr_parser(&(n->attrs));
} }
s.outputs.emplace_back(NodeEntry{std::move(n), 0, 0}); s.outputs.emplace_back(NodeEntry{std::move(n), 0, 0});
return s; return s;
......
...@@ -20,11 +20,11 @@ NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) { ...@@ -20,11 +20,11 @@ NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) {
return std::move(v[0]); return std::move(v[0]);
} else if (v.size() == 0) { } else if (v.size() == 0) {
NodePtr zero_node = Node::Create(); NodePtr zero_node = Node::Create();
zero_node->op = Op::Get("__zero__"); zero_node->attrs.op = Op::Get("__zero__");
return NodeEntry{zero_node, 0, 0}; return NodeEntry{zero_node, 0, 0};
} else { } else {
NodePtr sum_node = Node::Create(); NodePtr sum_node = Node::Create();
sum_node->op = Op::Get("__ewise_sum__"); sum_node->attrs.op = Op::Get("__ewise_sum__");
sum_node->inputs = std::move(v); sum_node->inputs = std::move(v);
return NodeEntry{sum_node, 0, 0}; return NodeEntry{sum_node, 0, 0};
} }
...@@ -109,7 +109,7 @@ Graph Gradient(Graph src) { ...@@ -109,7 +109,7 @@ Graph Gradient(Graph src) {
e.sum = agg_fun(std::move(e.grads)); e.sum = agg_fun(std::move(e.grads));
out_agg_grads.push_back(e.sum); out_agg_grads.push_back(e.sum);
} }
std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op] std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()]
(mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()), out_agg_grads); (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()), out_agg_grads);
auto git = input_grads.begin(); auto git = input_grads.begin();
for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) { for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
......
...@@ -65,7 +65,7 @@ Graph InferAttr(Graph &&ret, ...@@ -65,7 +65,7 @@ Graph InferAttr(Graph &&ret,
} }
continue; continue;
} }
if (finfer_shape.count(inode.source->op)) { if (finfer_shape.count(inode.source->op())) {
ishape.resize(num_inputs, def_value); ishape.resize(num_inputs, def_value);
for (uint32_t i = 0; i < ishape.size(); ++i) { for (uint32_t i = 0; i < ishape.size(); ++i) {
ishape[i] = rshape[idx.entry_id(inode.inputs[i])]; ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
...@@ -75,14 +75,14 @@ Graph InferAttr(Graph &&ret, ...@@ -75,14 +75,14 @@ Graph InferAttr(Graph &&ret,
oshape[i] = rshape[idx.entry_id(nid, i)]; oshape[i] = rshape[idx.entry_id(nid, i)];
} }
num_unknown += num_unknown +=
!(finfer_shape[inode.source->op](inode.source->attrs, &ishape, &oshape)); !(finfer_shape[inode.source->op()](inode.source->attrs, &ishape, &oshape));
for (uint32_t i = 0; i < num_inputs; ++i) { for (uint32_t i = 0; i < num_inputs; ++i) {
rshape[idx.entry_id(inode.inputs[i])] = ishape[i]; rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
} }
for (uint32_t i = 0; i < num_outputs; ++i) { for (uint32_t i = 0; i < num_outputs; ++i) {
rshape[idx.entry_id(nid, i)] = oshape[i]; rshape[idx.entry_id(nid, i)] = oshape[i];
} }
} else if (is_backward.get(inode.source->op, false)) { } else if (is_backward.get(inode.source->op(), false)) {
// backward operator inference. // backward operator inference.
CHECK_GE(inode.control_deps.size(), 1) CHECK_GE(inode.control_deps.size(), 1)
<< "BackwardOp need to have control_deps to its forward op"; << "BackwardOp need to have control_deps to its forward op";
......
...@@ -43,8 +43,8 @@ Graph OrderMutation(const Graph& src) { ...@@ -43,8 +43,8 @@ Graph OrderMutation(const Graph& src) {
auto prepare = [&version_hist, &old_new] (const NodePtr& n) { auto prepare = [&version_hist, &old_new] (const NodePtr& n) {
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs"); static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::vector<uint32_t> mutate_inputs; std::vector<uint32_t> mutate_inputs;
if (!n->is_variable() && fmutate_inputs.count(n->op)) { if (!n->is_variable() && fmutate_inputs.count(n->op())) {
mutate_inputs = fmutate_inputs[n->op](n->attrs); mutate_inputs = fmutate_inputs[n->op()](n->attrs);
} }
std::sort(mutate_inputs.begin(), mutate_inputs.end()); std::sort(mutate_inputs.begin(), mutate_inputs.end());
...@@ -67,7 +67,6 @@ Graph OrderMutation(const Graph& src) { ...@@ -67,7 +67,6 @@ Graph OrderMutation(const Graph& src) {
} }
if (need_repl) { if (need_repl) {
NodePtr np = Node::Create(); NodePtr np = Node::Create();
np->op = n->op;
np->attrs = n->attrs; np->attrs = n->attrs;
old_new[n.get()] = std::move(np); old_new[n.get()] = std::move(np);
} }
...@@ -101,8 +100,8 @@ Graph OrderMutation(const Graph& src) { ...@@ -101,8 +100,8 @@ Graph OrderMutation(const Graph& src) {
// add control deps // add control deps
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs"); static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::vector<uint32_t> mutate_inputs; std::vector<uint32_t> mutate_inputs;
if (fmutate_inputs.count(kv.first->op)) { if (fmutate_inputs.count(kv.first->op())) {
mutate_inputs = fmutate_inputs[kv.first->op](kv.first->attrs); mutate_inputs = fmutate_inputs[kv.first->op()](kv.first->attrs);
} }
std::sort(mutate_inputs.begin(), mutate_inputs.end()); std::sort(mutate_inputs.begin(), mutate_inputs.end());
......
...@@ -109,9 +109,9 @@ Graph PlaceDevice(Graph src) { ...@@ -109,9 +109,9 @@ Graph PlaceDevice(Graph src) {
NodeEntry{it->second, 0, 0}); NodeEntry{it->second, 0, 0});
} else { } else {
NodePtr copy_node = Node::Create(); NodePtr copy_node = Node::Create();
copy_node->op = copy_op;
std::ostringstream os; std::ostringstream os;
os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy"; os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy";
copy_node->attrs.op = copy_op;
copy_node->attrs.name = os.str(); copy_node->attrs.name = os.str();
copy_node->inputs.push_back(inode.source->inputs[i]); copy_node->inputs.push_back(inode.source->inputs[i]);
copy_map[copy_key] = copy_node; copy_map[copy_key] = copy_node;
......
...@@ -168,8 +168,8 @@ Graph PlanMemory(Graph ret) { ...@@ -168,8 +168,8 @@ Graph PlanMemory(Graph ret) {
const auto& inode = idx[nid]; const auto& inode = idx[nid];
if (inode.source->is_variable()) continue; if (inode.source->is_variable()) continue;
// check inplace option // check inplace option
if (finplace_option.count(inode.source->op) != 0) { if (finplace_option.count(inode.source->op()) != 0) {
auto inplace_pairs = finplace_option[inode.source->op](inode.source->attrs); auto inplace_pairs = finplace_option[inode.source->op()](inode.source->attrs);
for (auto& kv : inplace_pairs) { for (auto& kv : inplace_pairs) {
uint32_t eid_out = idx.entry_id(nid, kv.second); uint32_t eid_out = idx.entry_id(nid, kv.second);
uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]); uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]);
......
...@@ -68,8 +68,8 @@ struct JSONNode { ...@@ -68,8 +68,8 @@ struct JSONNode {
// function to save JSON node. // function to save JSON node.
void Save(dmlc::JSONWriter *writer) const { void Save(dmlc::JSONWriter *writer) const {
writer->BeginObject(); writer->BeginObject();
if (node->op != nullptr) { if (node->op() != nullptr) {
writer->WriteObjectKeyValue("op", node->op->name); writer->WriteObjectKeyValue("op", node->op()->name);
} else { } else {
std::string json_null = "null"; std::string json_null = "null";
writer->WriteObjectKeyValue("op", json_null); writer->WriteObjectKeyValue("op", json_null);
...@@ -108,10 +108,10 @@ struct JSONNode { ...@@ -108,10 +108,10 @@ struct JSONNode {
if (op_type_str != "null") { if (op_type_str != "null") {
try { try {
node->op = Op::Get(op_type_str); node->attrs.op = Op::Get(op_type_str);
// rebuild attribute parser // rebuild attribute parser
if (node->op->attr_parser != nullptr) { if (node->op()->attr_parser != nullptr) {
node->op->attr_parser(&(node->attrs)); node->op()->attr_parser(&(node->attrs));
} }
} catch (const dmlc::Error &err) { } catch (const dmlc::Error &err) {
std::ostringstream os; std::ostringstream os;
...@@ -120,7 +120,7 @@ struct JSONNode { ...@@ -120,7 +120,7 @@ struct JSONNode {
throw dmlc::Error(os.str()); throw dmlc::Error(os.str());
} }
} else { } else {
node->op = nullptr; node->attrs.op = nullptr;
} }
} }
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment