Commit de076999 by Tianqi Chen

[NODE] Move op inside node attribute (#30)

parent ac070f83
......@@ -46,7 +46,7 @@ inline NodeEntry MakeNode(const char* op_name,
std::string node_name,
std::vector<NodeEntry> inputs) {
NodePtr p = Node::Create();
p->op = nnvm::Op::Get(op_name);
p->attrs.op = nnvm::Op::Get(op_name);
p->attrs.name = std::move(node_name);
p->inputs = std::move(inputs);
return NodeEntry{p, 0, 0};
......
......@@ -46,6 +46,11 @@ struct NodeEntry {
* Usually are additional parameters like axis,
*/
struct NodeAttrs {
/*!
* \brief The operator this node uses.
* For place holder variable, op == nullptr.
*/
const Op *op{nullptr};
/*! \brief name of the node */
std::string name;
/*! \brief Vector representation of positional attributes */
......@@ -65,11 +70,8 @@ struct NodeAttrs {
*/
class Node {
public:
/*!
* \brief The operator this node uses.
* For place holder variable, op == nullptr.
*/
const Op *op{nullptr};
/*! \brief The attributes in the node. */
NodeAttrs attrs;
/*! \brief inputs to this node */
std::vector<NodeEntry> inputs;
/*!
......@@ -77,10 +79,10 @@ class Node {
* Gives operation must be performed before this operation.
*/
std::vector<NodePtr> control_deps;
/*! \brief The attributes in the node. */
NodeAttrs attrs;
/*! \brief destructor of node */
~Node();
/*! \return operator in this node */
inline const Op* op() const;
/*!
* \brief return whether node is placeholder variable.
* This is equivalent to op == nullptr
......@@ -99,25 +101,28 @@ class Node {
};
// implementation of functions.
inline const Op* Node::op() const {
return this->attrs.op;
}
inline bool Node::is_variable() const {
return this->op == nullptr;
return this->op() == nullptr;
}
inline uint32_t Node::num_outputs() const {
if (is_variable()) return 1;
if (this->op->get_num_outputs == nullptr) {
return this->op->num_outputs;
if (this->op()->get_num_outputs == nullptr) {
return this->op()->num_outputs;
} else {
return this->op->get_num_outputs(this->attrs);
return this->op()->get_num_outputs(this->attrs);
}
}
inline uint32_t Node::num_inputs() const {
if (is_variable()) return 1;
if (this->op->get_num_inputs == nullptr) {
return this->op->num_inputs;
if (this->op()->get_num_inputs == nullptr) {
return this->op()->num_inputs;
} else {
return this->op->get_num_inputs(this->attrs);
return this->op()->get_num_inputs(this->attrs);
}
}
......
......@@ -12,6 +12,7 @@
#include <string>
#include <memory>
#include <vector>
#include "./base.h"
#include "./pass.h"
#include "./graph_attr_types.h"
......
......@@ -66,9 +66,9 @@ IndexedGraph::IndexedGraph(const Graph &g) {
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
nodes_[nid].inputs = array_view<NodeEntry>(
iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]);
if (nodes_[nid].source->op != nullptr &&
fmutate_inputs.count(nodes_[nid].source->op)) {
for (uint32_t i : fmutate_inputs[nodes_[nid].source->op](nodes_[nid].source->attrs)) {
if (nodes_[nid].source->op() != nullptr &&
fmutate_inputs.count(nodes_[nid].source->op())) {
for (uint32_t i : fmutate_inputs[nodes_[nid].source->op()](nodes_[nid].source->attrs)) {
mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id);
}
}
......
......@@ -20,7 +20,7 @@ struct VariableParam {
NodePtr CreateVariableNode(const std::string& name) {
NodePtr n = Node::Create();
n->op = nullptr;
n->attrs.op = nullptr;
n->attrs.name = name;
n->attrs.parsed = VariableParam();
return n;
......@@ -37,8 +37,8 @@ inline void UpdateNodeVersion(Node *n) {
e.version = nnvm::get<VariableParam>(e.node->attrs.parsed).version;
}
}
if (fmutate_inputs.count(n->op) != 0) {
for (uint32_t i : fmutate_inputs[n->op](n->attrs)) {
if (fmutate_inputs.count(n->op()) != 0) {
for (uint32_t i : fmutate_inputs[n->op()](n->attrs)) {
NodeEntry& e = n->inputs[i];
CHECK(e.node->is_variable())
<< "Mutation target can only be Variable";
......@@ -96,7 +96,6 @@ Symbol Symbol::Copy() const {
// use DFSVisit to copy all the nodes
DFSVisit(this->outputs, [&old_new](const NodePtr& node) {
NodePtr np = Node::Create();
np->op = node->op;
np->attrs = node->attrs;
old_new[node.get()] = std::move(np);
});
......@@ -123,7 +122,7 @@ void Symbol::Print(std::ostream &os) const {
if (outputs[0].node->is_variable()) {
os << "Variable:" << outputs[0].node->attrs.name << '\n';
} else {
os << "AtomicFunctor "<< " Op:" << outputs[0].node->op->name << '\n';
os << "AtomicFunctor "<< " Op:" << outputs[0].node->op()->name << '\n';
}
} else {
// use DFSVisit to copy all the nodes
......@@ -137,7 +136,7 @@ void Symbol::Print(std::ostream &os) const {
os << "Variable:" << node->attrs.name << '\n';
} else {
os << "--------------------\n";
os << "Op:" << node->op->name << ", Name=" << node->attrs.name << '\n'
os << "Op:" << node->op()->name << ", Name=" << node->attrs.name << '\n'
<< "Inputs:\n";
for (size_t i = 0; i < node->inputs.size(); ++i) {
const NodeEntry& e = node->inputs[i];
......@@ -196,8 +195,8 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
DFSVisit(this->outputs, [&ret, &mutable_set, &vlist](const NodePtr &node) {
if (node->is_variable()) {
vlist.push_back(node.get());
} else if (fmutate_inputs.count(node->op)) {
for (uint32_t i : fmutate_inputs[node->op](node->attrs)){
} else if (fmutate_inputs.count(node->op())) {
for (uint32_t i : fmutate_inputs[node->op()](node->attrs)){
mutable_set.insert(node->inputs[i].node.get());
}
}
......@@ -221,7 +220,7 @@ std::vector<std::string> Symbol::ListOutputNames() const {
} else {
const std::string& hname = head.node->attrs.name;
std::string rname;
FListOutputNames fn = flist_ouputs.get(head.node->op, nullptr);
FListOutputNames fn = flist_ouputs.get(head.node->op(), nullptr);
if (fn != nullptr) {
rname = fn(head.node->attrs)[head.index];
} else {
......@@ -278,10 +277,10 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
}
// switch to keyword argument matching
if (args.size() != n_req) {
FListInputNames fn = flist_inputs.get(n->op, nullptr);
FListInputNames fn = flist_inputs.get(n->op(), nullptr);
auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(n->attrs);
if (arg_names.size() != n_req) {
LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op->name;
LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op()->name;
}
size_t nmatched = 0;
for (size_t i = args.size(); i < n_req; ++i) {
......@@ -422,8 +421,8 @@ void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& a
node->attrs.dict[kv.first] = kv.second;
}
}
if (node->op != nullptr && node->op->attr_parser != nullptr) {
node->op->attr_parser(&(node->attrs));
if (node->op() != nullptr && node->op()->attr_parser != nullptr) {
node->op()->attr_parser(&(node->attrs));
}
}
......@@ -461,10 +460,10 @@ Symbol Symbol::CreateFunctor(const Op* op,
std::unordered_map<std::string, std::string> attrs) {
Symbol s;
NodePtr n = Node::Create();
n->op = op;
n->attrs.op = op;
n->attrs.dict = std::move(attrs);
if (n->op->attr_parser != nullptr) {
n->op->attr_parser(&(n->attrs));
if (n->op()->attr_parser != nullptr) {
n->op()->attr_parser(&(n->attrs));
}
s.outputs.emplace_back(NodeEntry{std::move(n), 0, 0});
return s;
......
......@@ -20,11 +20,11 @@ NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) {
return std::move(v[0]);
} else if (v.size() == 0) {
NodePtr zero_node = Node::Create();
zero_node->op = Op::Get("__zero__");
zero_node->attrs.op = Op::Get("__zero__");
return NodeEntry{zero_node, 0, 0};
} else {
NodePtr sum_node = Node::Create();
sum_node->op = Op::Get("__ewise_sum__");
sum_node->attrs.op = Op::Get("__ewise_sum__");
sum_node->inputs = std::move(v);
return NodeEntry{sum_node, 0, 0};
}
......@@ -109,7 +109,7 @@ Graph Gradient(Graph src) {
e.sum = agg_fun(std::move(e.grads));
out_agg_grads.push_back(e.sum);
}
std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op]
std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()]
(mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()), out_agg_grads);
auto git = input_grads.begin();
for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
......
......@@ -65,7 +65,7 @@ Graph InferAttr(Graph &&ret,
}
continue;
}
if (finfer_shape.count(inode.source->op)) {
if (finfer_shape.count(inode.source->op())) {
ishape.resize(num_inputs, def_value);
for (uint32_t i = 0; i < ishape.size(); ++i) {
ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
......@@ -75,14 +75,14 @@ Graph InferAttr(Graph &&ret,
oshape[i] = rshape[idx.entry_id(nid, i)];
}
num_unknown +=
!(finfer_shape[inode.source->op](inode.source->attrs, &ishape, &oshape));
!(finfer_shape[inode.source->op()](inode.source->attrs, &ishape, &oshape));
for (uint32_t i = 0; i < num_inputs; ++i) {
rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
}
for (uint32_t i = 0; i < num_outputs; ++i) {
rshape[idx.entry_id(nid, i)] = oshape[i];
}
} else if (is_backward.get(inode.source->op, false)) {
} else if (is_backward.get(inode.source->op(), false)) {
// backward operator inference.
CHECK_GE(inode.control_deps.size(), 1)
<< "BackwardOp need to have control_deps to its forward op";
......
......@@ -43,8 +43,8 @@ Graph OrderMutation(const Graph& src) {
auto prepare = [&version_hist, &old_new] (const NodePtr& n) {
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::vector<uint32_t> mutate_inputs;
if (!n->is_variable() && fmutate_inputs.count(n->op)) {
mutate_inputs = fmutate_inputs[n->op](n->attrs);
if (!n->is_variable() && fmutate_inputs.count(n->op())) {
mutate_inputs = fmutate_inputs[n->op()](n->attrs);
}
std::sort(mutate_inputs.begin(), mutate_inputs.end());
......@@ -67,7 +67,6 @@ Graph OrderMutation(const Graph& src) {
}
if (need_repl) {
NodePtr np = Node::Create();
np->op = n->op;
np->attrs = n->attrs;
old_new[n.get()] = std::move(np);
}
......@@ -101,8 +100,8 @@ Graph OrderMutation(const Graph& src) {
// add control deps
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::vector<uint32_t> mutate_inputs;
if (fmutate_inputs.count(kv.first->op)) {
mutate_inputs = fmutate_inputs[kv.first->op](kv.first->attrs);
if (fmutate_inputs.count(kv.first->op())) {
mutate_inputs = fmutate_inputs[kv.first->op()](kv.first->attrs);
}
std::sort(mutate_inputs.begin(), mutate_inputs.end());
......
......@@ -109,9 +109,9 @@ Graph PlaceDevice(Graph src) {
NodeEntry{it->second, 0, 0});
} else {
NodePtr copy_node = Node::Create();
copy_node->op = copy_op;
std::ostringstream os;
os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy";
copy_node->attrs.op = copy_op;
copy_node->attrs.name = os.str();
copy_node->inputs.push_back(inode.source->inputs[i]);
copy_map[copy_key] = copy_node;
......
......@@ -168,8 +168,8 @@ Graph PlanMemory(Graph ret) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
// check inplace option
if (finplace_option.count(inode.source->op) != 0) {
auto inplace_pairs = finplace_option[inode.source->op](inode.source->attrs);
if (finplace_option.count(inode.source->op()) != 0) {
auto inplace_pairs = finplace_option[inode.source->op()](inode.source->attrs);
for (auto& kv : inplace_pairs) {
uint32_t eid_out = idx.entry_id(nid, kv.second);
uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]);
......
......@@ -68,8 +68,8 @@ struct JSONNode {
// function to save JSON node.
void Save(dmlc::JSONWriter *writer) const {
writer->BeginObject();
if (node->op != nullptr) {
writer->WriteObjectKeyValue("op", node->op->name);
if (node->op() != nullptr) {
writer->WriteObjectKeyValue("op", node->op()->name);
} else {
std::string json_null = "null";
writer->WriteObjectKeyValue("op", json_null);
......@@ -108,10 +108,10 @@ struct JSONNode {
if (op_type_str != "null") {
try {
node->op = Op::Get(op_type_str);
node->attrs.op = Op::Get(op_type_str);
// rebuild attribute parser
if (node->op->attr_parser != nullptr) {
node->op->attr_parser(&(node->attrs));
if (node->op()->attr_parser != nullptr) {
node->op()->attr_parser(&(node->attrs));
}
} catch (const dmlc::Error &err) {
std::ostringstream os;
......@@ -120,7 +120,7 @@ struct JSONNode {
throw dmlc::Error(os.str());
}
} else {
node->op = nullptr;
node->attrs.op = nullptr;
}
}
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment