Commit 1146b816 by Pedro Larroy, committed by Tianqi Chen

Check that the node is not null, add contains to OpMap (#3037)

parent b405f68b
-Subproject commit 3ffea8694adf9c0363f9abbf162dc0e4a45b22c5
+Subproject commit 82bf4c2e2af312b3d52513aa727483803a2f8734
@@ -315,12 +315,16 @@ inline void DFSVisit(const std::vector<NodeEntry>& heads,
       });
   PostOrderDFSVisit<GNode, Node*>(
       head_nodes,
-      [fvisit](GNode n) { fvisit(*n); },  // FVisit
-      [](GNode n)->Node* { return n->get(); },  // HashFunc
+      [fvisit](GNode n) {
+        fvisit(*n);
+      },  // FVisit
+      [](GNode n)->Node* {
+        return n->get();
+      },  // HashFunc
       [](GNode n)->uint32_t {  // InDegree
+        if (!(*n)) return 0;
         return (*n)->inputs.size() + (*n)->control_deps.size();
-      },
+      },
       [](GNode n, uint32_t index)->GNode {  // GetInput
         if (index < (*n)->inputs.size()) {
           return &(*n)->inputs.at(index).node;
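Note on the DFSVisit hunk above: the new null check in InDegree prevents dereferencing an empty NodePtr during traversal. Below is a minimal standalone sketch of that failure mode, using a simplified Node type instead of the real nnvm headers, so it compiles on its own:

#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

// Simplified stand-in for nnvm::Node; the real type also carries op, attrs, etc.
struct Node {
  std::vector<std::shared_ptr<Node>> inputs;
  std::vector<std::shared_ptr<Node>> control_deps;
};
using GNode = const std::shared_ptr<Node>*;

int main() {
  // Same shape as the InDegree lambda above: treat a null node as having
  // in-degree zero instead of dereferencing it (undefined behavior).
  auto in_degree = [](GNode n) -> uint32_t {
    if (!(*n)) return 0;
    return static_cast<uint32_t>((*n)->inputs.size() + (*n)->control_deps.size());
  };

  std::shared_ptr<Node> node = std::make_shared<Node>();
  node->inputs.push_back(std::make_shared<Node>());
  std::shared_ptr<Node> null_node;  // empty NodePtr, as a malformed graph could contain

  std::cout << in_degree(&node) << "\n";       // prints 1
  std::cout << in_degree(&null_node) << "\n";  // prints 0 rather than crashing
}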
@@ -340,6 +340,13 @@ class OpMap {
    */
   inline int count(const Op* op) const;
+  /*!
+   * \brief Check if the map has op as key.
+   * \param op The key to the map
+   * \return true if op is contained in map, false otherwise.
+   */
+  inline bool contains(const Op* op) const;
+
  private:
   friend class Op;
   // internal attribute name
@@ -539,9 +546,20 @@ inline Op& Op::set_attr_parser(std::function<void (NodeAttrs* attrs)> fn) {  //
 // member functions of OpMap
 template<typename ValueType>
 inline int OpMap<ValueType>::count(const Op* op) const {
-  if (op == nullptr) return 0;
+  if (contains(op)) {
+    return 1;
+  } else {
+    return 0;
+  }
+}
+
+template<typename ValueType>
+inline bool OpMap<ValueType>::contains(const Op* op) const {
+  if (op == nullptr) {
+    return false;
+  }
   const uint32_t idx = op->index_;
-  return idx < data_.size() ? (data_[idx].second != 0) : 0;
+  return idx < data_.size() ? (data_[idx].second != 0) : false;
 }
 template<typename ValueType>
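For reference, a hedged usage sketch of the new OpMap::contains next to count. It assumes the nnvm headers; the operator name "add" is hypothetical and stands for any registered operator:

#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>

void Example() {
  using nnvm::Op;
  using nnvm::FGradient;
  // Attribute maps are OpMap<ValueType> instances keyed by Op*.
  const auto& grad_fun_map = Op::GetAttr<FGradient>("FGradient");
  const Op* op = Op::Get("add");  // hypothetical operator name; Get fails if unregistered

  if (grad_fun_map.contains(op)) {
    // op registered an FGradient attribute
  }
  // count() is now defined as contains(op) ? 1 : 0, and both accept nullptr.
  bool never = grad_fun_map.contains(nullptr);  // false, no null dereference
  (void)never;
}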
@@ -78,6 +78,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
       (const NodePtr& n) {
     CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
     uint32_t nid = static_cast<uint32_t>(nodes_.size());
+    CHECK(n);
     for (const auto &subgraph : n->attrs.subgraphs)
       subgraphs.push_back(subgraph);
     // nodes_
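The CHECK(n) added above is the dmlc/glog-style assertion: when the condition is false it logs the failing expression and aborts, so a null NodePtr fails fast here instead of crashing later at n->attrs.subgraphs. A simplified stand-in for the macro, for illustration only, not the real dmlc/logging.h definition:

#include <cstdlib>
#include <iostream>
#include <memory>

// Simplified stand-in for dmlc's CHECK; the real macro streams a richer message.
#define CHECK(cond)                                              \
  do {                                                           \
    if (!(cond)) {                                               \
      std::cerr << "Check failed: " #cond " at "                 \
                << __FILE__ << ":" << __LINE__ << std::endl;     \
      std::abort();                                              \
    }                                                            \
  } while (0)

struct Node { int value = 42; };

int main() {
  std::shared_ptr<Node> n = std::make_shared<Node>();
  CHECK(n);                       // holds, execution continues
  std::cout << n->value << "\n";
  std::shared_ptr<Node> m;        // empty NodePtr
  CHECK(m);                       // fails: prints the expression, then aborts
}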
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- *
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -143,13 +143,13 @@ Graph Gradient(Graph src) {
         << "because it is unreachable from the outputs.";
   }
 
-  // construct mirror reduece memory strategy if needed
+  // construct mirror as memory reduction strategy if needed
   std::unordered_map<Node*, NodePtr> mirror_map;
   if (mirror_fun != nullptr) {
-    for (const NodePtr& n : topo_order) {
-      if (mirror_fun(*n)) {
+    for (const NodePtr& node_ptr : topo_order) {
+      if (mirror_fun(*node_ptr)) {
         NodePtr new_node = Node::Create();
-        *new_node = *n;
+        *new_node = *node_ptr;
         new_node->attrs.name += "_mirror";
         for (auto& e : new_node->inputs) {
           e.node = mirror_map.at(e.node.get());
@@ -157,9 +157,9 @@ Graph Gradient(Graph src) {
         for (auto& n : new_node->control_deps) {
           n = mirror_map.at(n.get());
         }
-        mirror_map[n.get()] = std::move(new_node);
+        mirror_map[node_ptr.get()] = std::move(new_node);
       } else {
-        mirror_map[n.get()] = n;
+        mirror_map[node_ptr.get()] = node_ptr;
       }
     }
   }
@@ -185,7 +185,8 @@ Graph Gradient(Graph src) {
     if ((*rit)->inputs.size() != 0) {
       NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()));
       std::vector<NodeEntry> input_grads;
-      if (grad_fun_map.count(ptr->op())) {
+      // Check for FGradient
+      if (grad_fun_map.contains(ptr->op())) {
         input_grads = grad_fun_map[ptr->op()](fwd_node, out_agg_grads);
         CHECK_EQ((*rit)->inputs.size(), input_grads.size())
             << "Gradient function not returning enough gradient";
@@ -205,20 +206,23 @@ Graph Gradient(Graph src) {
           if (p->op()->attr_parser != nullptr) {
             p->op()->attr_parser(&(p->attrs));
           }
-          input_grads.emplace_back(nnvm::NodeEntry{p, 0, 0});
+          input_grads.emplace_back(p, 0, 0);
         }
       } else {
         LOG(FATAL) << "Operator " << fwd_node->op()->name << " is non-differentiable "
                    << "because it didn't register FGradient attribute.";
       }
+      for (const auto& nodeEntry : input_grads)
+        CHECK(nodeEntry.node);
       auto git = input_grads.begin();
+      CHECK((*rit)->inputs.size() <= input_grads.size());
       for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
-        auto& ge = output_grads[it->node.get()][it->index];
+        auto& output_grad_entry = output_grads[it->node.get()][it->index];
         // if any of the backward op can do shape inference, the hint is not necessary.
-        if (finfer_shape.count(git->node->op())) {
-          ge.need_attr_hint = false;
+        if (finfer_shape.contains(git->node->op())) {
+          output_grad_entry.need_attr_hint = false;
         }
-        ge.grads.emplace_back(std::move(*git));
+        output_grad_entry.grads.emplace_back(std::move(*git));
       }
     }
   }
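A final note on the emplace_back change in the last hunk: input_grads.emplace_back(p, 0, 0) forwards the arguments to the element's constructor and builds the NodeEntry in place, while the old form constructed a temporary nnvm::NodeEntry and moved it in. A sketch with a simplified Entry type, whose constructor signature is an assumption mirroring what the new call site requires:

#include <cstdint>
#include <memory>
#include <vector>

struct Node {};
using NodePtr = std::shared_ptr<Node>;

// Simplified stand-in for nnvm::NodeEntry; the (node, index, version)
// constructor is what the emplace_back(p, 0, 0) form assumes.
struct Entry {
  Entry(NodePtr node, uint32_t index, uint32_t version)
      : node(std::move(node)), index(index), version(version) {}
  NodePtr node;
  uint32_t index;
  uint32_t version;
};

int main() {
  std::vector<Entry> grads;
  NodePtr p = std::make_shared<Node>();
  grads.emplace_back(Entry{p, 0, 0});  // builds a temporary, then moves it
  grads.emplace_back(p, 0, 0);         // constructs the element in place
  return 0;
}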