Commit 88163ec1 by Przemyslaw Tredak Committed by Tianqi Chen

Ghost nodes in NNVM graph (#3290)

parent 165aa0db
...@@ -137,6 +137,17 @@ using FInferType = FInferNodeEntryAttr<int>; ...@@ -137,6 +137,17 @@ using FInferType = FInferNodeEntryAttr<int>;
using TIsBackward = bool; using TIsBackward = bool;
/*! /*!
* \brief Whether this op is a ghost node.
* If TIsGhost is true:
* - The node with this op will not be visible in the indexed graph.
*
* \note Register under "TIsGhost"
* This enables shape/type inference for backward nodes when
* fusion is present.
*/
using TIsGhost = bool;
/*!
* \brief Get possible inplace options. * \brief Get possible inplace options.
* This function enables optimization to reuse memory of inputs in output. * This function enables optimization to reuse memory of inputs in output.
* \param attrs The attributes of the node * \param attrs The attributes of the node
......
...@@ -76,6 +76,8 @@ IndexedGraph::IndexedGraph(const Graph &g) { ...@@ -76,6 +76,8 @@ IndexedGraph::IndexedGraph(const Graph &g) {
DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs] DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs]
(const NodePtr& n) { (const NodePtr& n) {
const auto& is_ghost = Op::GetAttr<TIsGhost>("TIsGhost");
if (!n->is_variable() && is_ghost.get(n->op(), false)) return;
CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max()); CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
uint32_t nid = static_cast<uint32_t>(nodes_.size()); uint32_t nid = static_cast<uint32_t>(nodes_.size());
CHECK(n); CHECK(n);
...@@ -103,6 +105,7 @@ IndexedGraph::IndexedGraph(const Graph &g) { ...@@ -103,6 +105,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
inputs_rptr.push_back(input_entries_.size()); inputs_rptr.push_back(input_entries_.size());
// control deps // control deps
for (const auto& nptr : n->control_deps) { for (const auto& nptr : n->control_deps) {
if (!nptr->is_variable() && is_ghost.get(nptr->op(), false)) continue;
auto it = node2index_.find(nptr.get()); auto it = node2index_.find(nptr.get());
CHECK(it != node2index_.end() && it->first == nptr.get()); CHECK(it != node2index_.end() && it->first == nptr.get());
control_deps_.push_back(it->second); control_deps_.push_back(it->second);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment