Commit 3b6c66d1 by Ziheng Jiang Committed by Tianqi Chen

Enable shape hints during infer_shape pass (#107)

* enable shape hints during infer_shape pass

* fix comment
parent 96db41db
......@@ -171,6 +171,12 @@ class IndexedGraph {
inline const std::vector<NodeEntry>& outputs() const {
return outputs_;
}
/*! \return whether a node is existed in the indexed graph */
inline bool exist(const nnvm::Node* node) const {
return node2index_.count(node);
}
// disalllow copy assign
IndexedGraph(const IndexedGraph&) = delete;
......
......@@ -42,6 +42,34 @@ struct NodeEntry {
};
/*!
* \brief This lets you use a NodeEntry as a key in a unordered_map of the form
* unordered_map<NodeEntry, ValueType, NodeEntryHash, NodeEntryEqual>
*/
struct NodeEntryHash {
size_t operator()(const NodeEntry& e) const {
return std::hash<Node*>()(e.node.get()) ^
(std::hash<size_t>()(e.index) << 1 >> 1) ^
(std::hash<size_t>()(e.version) << 1);
}
};
/*!
* \brief This lets you use a NodeEntry as a key in a unordered_map of the form
* unordered_map<NodeEntry, ValueType, NodeEntryHash, NodeEntryEqual>
*/
struct NodeEntryEqual {
size_t operator()(const NodeEntry& a, const NodeEntry& b) const {
return (a.node.get() == b.node.get()) &&
(a.index == b.index) &&
(a.version == b.version);
}
};
/*! use NodeEntry as key in unordered_map */
template<typename ValueType>
using NodeEntryMap = std::unordered_map<NodeEntry, ValueType, NodeEntryHash, NodeEntryEqual>;
/*!
* \brief The attributes of the current operation node.
* Usually are additional parameters like axis,
*/
......
......@@ -49,6 +49,19 @@ Graph InferAttr(Graph &&ret,
ret.attrs.erase(input_name);
}
// get the shape hints
std::string shape_hints_key = std::string(attr_name) + "_hints";
if (ret.attrs.count(shape_hints_key)) {
NodeEntryMap<AttrType> shape_hints =
ret.GetAttr<NodeEntryMap<AttrType>>(shape_hints_key);
for (const auto& kv : shape_hints) {
NodeEntry e = kv.first;
if (idx.exist(e.node.get())) {
rshape[idx.entry_id(kv.first)] = kv.second;
}
}
}
std::string shape_attr_key;
if (ret.attrs.count(attr_key_name) != 0) {
shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
......@@ -75,7 +88,7 @@ Graph InferAttr(Graph &&ret,
CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
}
}
} else if (is_backward.get(inode.source->op(), false)) {
} else if (is_backward.get(inode.source->op(), false) && inode.control_deps.size()) {
CHECK_GE(inode.control_deps.size(), 1U)
<< "BackwardOp need to have control_deps to its forward op";
const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment