Commit 1146b816 authored by Pedro Larroy, committed by Tianqi Chen

Check that the node is not null, add contains to OpMap (#3037)

parent b405f68b
-Subproject commit 3ffea8694adf9c0363f9abbf162dc0e4a45b22c5
+Subproject commit 82bf4c2e2af312b3d52513aa727483803a2f8734
@@ -315,8 +315,12 @@ inline void DFSVisit(const std::vector<NodeEntry>& heads,
                  });
   PostOrderDFSVisit<GNode, Node*>(
       head_nodes,
-      [fvisit](GNode n) { fvisit(*n); },  // FVisit
-      [](GNode n)->Node* { return n->get(); },  // HashFunc
+      [fvisit](GNode n) {
+        fvisit(*n);
+      },  // FVisit
+      [](GNode n)->Node* {
+        return n->get();
+      },  // HashFunc
       [](GNode n)->uint32_t {  // InDegree
         if (!(*n)) return 0;
         return (*n)->inputs.size() + (*n)->control_deps.size();
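For context, a minimal usage sketch (not part of this diff, names illustrative): DFSVisit walks a graph's outputs in post-order, and the InDegree functor above returns 0 for a null node so the traversal does not follow edges out of it; the visitor itself can still receive a null NodePtr and should guard accordingly.

#include <iostream>
#include <nnvm/graph.h>
#include <nnvm/node.h>

// Print every reachable node, using only the public DFSVisit/Node API.
void PrintNodes(const nnvm::Graph& g) {
  nnvm::DFSVisit(g.outputs, [](const nnvm::NodePtr& n) {
    if (n == nullptr) return;  // dangling entry: skip instead of dereferencing
    std::cout << (n->is_variable() ? "variable " : "op ")
              << n->attrs.name << '\n';
  });
}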
@@ -340,6 +340,13 @@ class OpMap {
    */
   inline int count(const Op* op) const;
+  /*!
+   * \brief Check if the map has op as key.
+   * \param op The key to the map
+   * \return true if op is contained in map, false otherwise.
+   */
+  inline bool contains(const Op* op) const;
+
  private:
   friend class Op;
   // internal attribute name
@@ -539,9 +546,20 @@ inline Op& Op::set_attr_parser(std::function<void (NodeAttrs* attrs)> fn) { //
 // member functions of OpMap
 template<typename ValueType>
 inline int OpMap<ValueType>::count(const Op* op) const {
-  if (op == nullptr) return 0;
+  if (contains(op)) {
+    return 1;
+  } else {
+    return 0;
+  }
+}
+
+template<typename ValueType>
+inline bool OpMap<ValueType>::contains(const Op* op) const {
+  if (op == nullptr) {
+    return false;
+  }
   const uint32_t idx = op->index_;
-  return idx < data_.size() ? (data_[idx].second != 0) : 0;
+  return idx < data_.size() ? (data_[idx].second != 0) : false;
 }
 
 template<typename ValueType>
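A minimal caller-side sketch (assumed usage, not taken from the diff): attribute maps are obtained through Op::GetAttr, and the new contains() states the membership question directly instead of relying on count() being truthy.

#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>

// Returns true if the operator registered an FGradient attribute.
bool HasGradient(const nnvm::Op* op) {
  static const auto& grad_fun_map =
      nnvm::Op::GetAttr<nnvm::FGradient>("FGradient");
  // contains() handles nullptr and out-of-range indices itself, so no extra
  // null check is needed at the call site.
  return grad_fun_map.contains(op);
}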
@@ -78,6 +78,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
              (const NodePtr& n) {
     CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
     uint32_t nid = static_cast<uint32_t>(nodes_.size());
+    CHECK(n);
     for (const auto &subgraph : n->attrs.subgraphs)
       subgraphs.push_back(subgraph);
     // nodes_
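A hypothetical repro sketch of what the new CHECK(n) guards against (constructed for illustration, not from the commit): an output entry whose node pointer was never filled in reaches the IndexedGraph construction lambda as a null NodePtr, and the CHECK turns a silent null-pointer dereference into an explicit failure.

#include <nnvm/graph.h>
#include <nnvm/node.h>

void BuildIndexOfDanglingGraph() {
  nnvm::Graph g;
  g.outputs.push_back(nnvm::NodeEntry{nullptr, 0, 0});  // node pointer left null
  // Before this change the constructor would dereference the null node at
  // n->attrs.subgraphs; with CHECK(n) it aborts with a diagnostic instead.
  const nnvm::IndexedGraph& idx = g.indexed_graph();
  (void)idx;
}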
@@ -143,13 +143,13 @@ Graph Gradient(Graph src) {
         << "because it is unreachable from the outputs.";
   }
 
-  // construct mirror reduece memory strategy if needed
+  // construct mirror as memory reduction strategy if needed
   std::unordered_map<Node*, NodePtr> mirror_map;
   if (mirror_fun != nullptr) {
-    for (const NodePtr& n : topo_order) {
-      if (mirror_fun(*n)) {
+    for (const NodePtr& node_ptr : topo_order) {
+      if (mirror_fun(*node_ptr)) {
         NodePtr new_node = Node::Create();
-        *new_node = *n;
+        *new_node = *node_ptr;
         new_node->attrs.name += "_mirror";
         for (auto& e : new_node->inputs) {
           e.node = mirror_map.at(e.node.get());
@@ -157,9 +157,9 @@ Graph Gradient(Graph src) {
         for (auto& n : new_node->control_deps) {
           n = mirror_map.at(n.get());
         }
-        mirror_map[n.get()] = std::move(new_node);
+        mirror_map[node_ptr.get()] = std::move(new_node);
       } else {
-        mirror_map[n.get()] = n;
+        mirror_map[node_ptr.get()] = node_ptr;
       }
     }
   }
@@ -185,7 +185,8 @@ Graph Gradient(Graph src) {
     if ((*rit)->inputs.size() != 0) {
       NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()));
       std::vector<NodeEntry> input_grads;
-      if (grad_fun_map.count(ptr->op())) {
+      // Check for FGradient
+      if (grad_fun_map.contains(ptr->op())) {
         input_grads = grad_fun_map[ptr->op()](fwd_node, out_agg_grads);
         CHECK_EQ((*rit)->inputs.size(), input_grads.size())
             << "Gradient function not returning enough gradient";
@@ -205,20 +206,23 @@ Graph Gradient(Graph src) {
           if (p->op()->attr_parser != nullptr) {
             p->op()->attr_parser(&(p->attrs));
           }
-          input_grads.emplace_back(nnvm::NodeEntry{p, 0, 0});
+          input_grads.emplace_back(p, 0, 0);
         }
       } else {
         LOG(FATAL) << "Operator " << fwd_node->op()->name << " is non-differentiable "
                    << "because it didn't register FGradient attribute.";
       }
+      for (const auto& nodeEntry : input_grads)
+        CHECK(nodeEntry.node);
       auto git = input_grads.begin();
+      CHECK((*rit)->inputs.size() <= input_grads.size());
       for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
-        auto& ge = output_grads[it->node.get()][it->index];
+        auto& output_grad_entry = output_grads[it->node.get()][it->index];
         // if any of the backward op can do shape inference, the hint is not necessary.
-        if (finfer_shape.count(git->node->op())) {
-          ge.need_attr_hint = false;
+        if (finfer_shape.contains(git->node->op())) {
+          output_grad_entry.need_attr_hint = false;
         }
-        ge.grads.emplace_back(std::move(*git));
+        output_grad_entry.grads.emplace_back(std::move(*git));
       }
     }
   }
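As a hedged sketch of the other side of the grad_fun_map.contains(ptr->op()) check (illustrative names, not part of this change): an operator avoids the LOG(FATAL) branch by registering an FGradient function, which the Gradient pass then finds in the attribute map.

#include <vector>
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>

// Hypothetical "identity" operator: one input, one output, and a gradient
// that simply forwards the incoming output gradient to its single input.
NNVM_REGISTER_OP(identity)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FGradient>(
    "FGradient",
    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
      return std::vector<nnvm::NodeEntry>{ograds[0]};
    });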