Commit 94fa45f8 by Tianqi Chen

Remove backward index, use gradient guessing instead (#85)

* Remove backward index, use gradient guessing instead

* minor fix

* bugfix
parent e4820d34
In the operator attribute types header, the two index-map attributes are replaced by a single boolean marker:

```diff
@@ -95,32 +95,13 @@ using FInferType = FInferNodeEntryAttr<int>;
 /*!
  * \brief Whether this op is an explicit backward operator,
- * and the correspondence of each output to input.
- *
- * If FBackwardOutToInIndex exists:
- *  - The first control_deps of the node points to the corresponding forward operator.
- *  - The k-th output corresponds to the FBackwardOutToInIndex()[k]-th input of the forward op.
- *
- * \note Register under "FBackwardOutToInIndex"
- *  This enables easier shape/type inference for backward operators for slice and reduction.
- */
-using FBackwardOutToInIndex = std::function<
-    std::vector<uint32_t> (const NodeAttrs& attrs)>;
-
-/*!
- * \brief Whether this op is an explicit backward operator,
- * Returns the list of input indices that correspond to the outputs of the forward operator.
- *
- * If FBackwardInGradIndex exists:
- *  - The first control_deps of the node points to the corresponding forward operator.
- *  - The FBackwardInGradIndex[i]-th input of the backward op corresponds to the i-th
- *    output of the forward operator.
- *
- * \note Register under "FBackwardInGradIndex"
- *  This enables easier shape/type inference for backward operators.
- */
-using FBackwardInGradIndex = std::function<
-    std::vector<uint32_t> (const NodeAttrs& attrs)>;
+ *
+ * If TIsBackward is true:
+ *  - The first control_deps of the node points to the corresponding forward operator.
+ *
+ * \note Register under "TIsBackward"
+ *  This enables easier shape/type inference for backward operators.
+ */
+using TIsBackward = bool;
 /*!
  * \brief Get possible inplace options.
```
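With the index maps gone, a backward operator only declares *that* it is a backward op; the input/output correspondence is recovered from the forward op's FGradient. Below is a minimal sketch of what a registration might look like under the new scheme. The ops `foo` and `_backward_foo` are hypothetical names invented for illustration; only `NNVM_REGISTER_OP`, `FGradient`, and `TIsBackward` come from the library.

```cpp
#include <nnvm/node.h>
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>

using namespace nnvm;

NNVM_REGISTER_OP(foo)                     // hypothetical forward op
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FGradient>(
    "FGradient",
    [](const NodePtr& n, const std::vector<NodeEntry>& ograds) {
      // Create the explicit backward node. Its inputs are the output
      // gradients, and control_deps[0] points back at the forward node;
      // InferAttr uses exactly this link to find the forward op.
      NodePtr p = Node::Create();
      p->attrs.op = Op::Get("_backward_foo");
      p->attrs.name = n->attrs.name + "_backward";
      p->inputs = ograds;
      p->control_deps.emplace_back(n);
      // One output: the gradient w.r.t. foo's single input.
      return std::vector<NodeEntry>{NodeEntry{p, 0, 0}};
    });

NNVM_REGISTER_OP(_backward_foo)           // hypothetical backward op
.set_num_inputs(1)
.set_num_outputs(1)
// The boolean marker is all that remains of the old index-map registration.
.set_attr<TIsBackward>("TIsBackward", true);
```

When InferAttr later reaches a node whose op has TIsBackward set, it re-invokes this FGradient on the forward node and matches the returned entries, as the hunks below show.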
In the attribute-inference pass, the corresponding registry lookups change:

```diff
@@ -25,10 +25,11 @@ Graph InferAttr(Graph &&ret,
   const IndexedGraph& idx = ret.indexed_graph();
   static auto& finfer_shape =
       Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
-  static auto& backward_map =
-      Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex");
-  static auto& backward_in_grad =
-      Op::GetAttr<FBackwardInGradIndex>("FBackwardInGradIndex");
+  static auto& is_backward =
+      Op::GetAttr<TIsBackward>("TIsBackward");
+  // gradient function, used to get node correspondence.
+  static auto& fgrad =
+      Op::GetAttr<FGradient>("FGradient");
   // reshape shape vector
   AttrVector rshape;
   if (ret.attrs.count(attr_name) != 0) {
```
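For reference, `Op::GetAttr` returns a per-operator map, and `OpMap::get` takes a default for operators that never registered the attribute; this is how the new branch condition in the next hunk treats ordinary forward ops as "not backward". A small assumed usage sketch (the op name is hypothetical):

```cpp
static auto& is_backward = Op::GetAttr<TIsBackward>("TIsBackward");
const Op* op = Op::Get("_backward_foo");  // hypothetical op name
// Falls back to `false` for ops that never registered TIsBackward.
bool backward = is_backward.get(op, false);
```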
The backward-inference branch now replays the forward operator's gradient function instead of consulting the index maps:

```diff
@@ -74,29 +75,44 @@ Graph InferAttr(Graph &&ret,
           CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
         }
       }
-    } else if (backward_map.count(inode.source->op())) {
-      // Backward operator inference.
+    } else if (is_backward.get(inode.source->op(), false)) {
       CHECK_GE(inode.control_deps.size(), 1)
           << "BackwardOp needs to have control_deps to its forward op";
       const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
-      // Infer the outputs of the backward operator (equal to the inputs
-      // of its corresponding forward operator).
-      std::vector<uint32_t> out_map =
-          backward_map[inode.source->op()](inode.source->attrs);
-      for (size_t i = 0; i < out_map.size(); ++i) {
-        uint32_t in_id = out_map[i];
-        CHECK_LT(in_id, fnode.inputs.size());
-        rshape[idx.entry_id(nid, i)] =
-            rshape[idx.entry_id(fnode.inputs[in_id])];
-      }
-      if (backward_in_grad.count(inode.source->op())) {
-        std::vector<uint32_t> in_grad =
-            backward_in_grad[inode.source->op()](inode.source->attrs);
-        CHECK_LE(in_grad.size(), fnode.source->num_outputs());
-        for (size_t i = 0; i < in_grad.size(); ++i) {
-          uint32_t eid = idx.entry_id(inode.inputs[in_grad[i]]);
+      NodePtr fwd_ptr = inode.source->control_deps[0];
+      // use the gradient function to find out the correspondence.
+      std::vector<NodeEntry> ograd(fwd_ptr->num_outputs());
+      for (size_t i = 0; i < ograd.size(); ++i) {
+        ograd[i].index = static_cast<uint32_t>(i);
+      }
+      // input gradient list
+      auto igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd);
+      const Op* backward_op = inode.source->op();
+      const Node* igrad_node = nullptr;
+      // input gradient assignment
+      for (size_t i = 0; i < igrad.size(); ++i) {
+        if (igrad[i].node->op() == backward_op) {
+          uint32_t eid = idx.entry_id(nid, igrad[i].index);
           if (fis_none(rshape[eid])) {
-            rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], i)];
+            rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];
+          } else {
+            CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])])
+                << "Backward shape inconsistent with the forward shape";
+          }
+          if (igrad_node == nullptr) {
+            igrad_node = igrad[i].node.get();
+          } else {
+            CHECK(igrad_node == igrad[i].node.get());
+          }
+        }
+      }
+      // out grad entries
+      for (size_t i = 0; i < igrad_node->inputs.size(); ++i) {
+        const NodeEntry& e = igrad_node->inputs[i];
+        if (e.node == nullptr) {
+          uint32_t eid = idx.entry_id(inode.inputs[i]);
+          if (fis_none(rshape[eid])) {
+            rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)];
           }
         }
       }
     }
```
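The "guessing" in the hunk above works as follows: placeholder output-gradient entries are built with a null node and index k, the forward op's FGradient is replayed on them, and the returned entries are pattern-matched. An entry produced by the backward op means output `igrad[i].index` of the backward node carries the gradient of forward input i (hence shares its shape); a null-node input of that backward node is one of the placeholders, so its stored index identifies the forward output whose shape the corresponding backward input must take. A self-contained toy sketch of that matching, with every type and op name invented for illustration (this is not NNVM's own code):

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

// Toy stand-ins for NNVM's graph types, just enough to show the idea.
struct Node;
struct NodeEntry {
  std::shared_ptr<Node> node;  // nullptr marks a placeholder output gradient
  uint32_t index = 0;
};
struct Node {
  std::string op;
  std::vector<NodeEntry> inputs;
};

// Gradient "function" of a hypothetical one-input, one-output forward op:
// it builds one backward node whose inputs are the output gradients.
std::vector<NodeEntry> FakeGradient(const std::vector<NodeEntry>& ograds) {
  auto bwd = std::make_shared<Node>();
  bwd->op = "_backward_fake";
  bwd->inputs = ograds;  // the placeholder entries flow in unchanged
  // One returned entry: gradient w.r.t. the forward op's only input,
  // carried by output 0 of the backward node.
  return {NodeEntry{bwd, 0}};
}

int main() {
  // Step 1: placeholder ograds, node == nullptr, index == k.
  std::vector<NodeEntry> ograd(1);
  ograd[0].index = 0;

  // Step 2: replay the gradient function.
  std::vector<NodeEntry> igrad = FakeGradient(ograd);

  // Step 3a: entries made by the backward op give the output mapping:
  // backward output igrad[i].index <-> forward input i (same shape).
  assert(igrad[0].node->op == "_backward_fake");
  assert(igrad[0].index == 0);

  // Step 3b: null-node inputs of the backward node are the placeholders;
  // their stored index names the forward output they correspond to.
  assert(igrad[0].node->inputs[0].node == nullptr);
  assert(igrad[0].node->inputs[0].index == 0);
  return 0;
}
```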
Finally, a loop-variable rename and brace cleanup in the fixed-point loop:

```diff
@@ -153,8 +169,10 @@ Graph InferAttr(Graph &&ret,
       }
     }
     num_unknown = 0;
-    for (size_t i = 0; i < idx.num_node_entries(); ++i) {
-      if (fis_none(rshape[i])) ++num_unknown;
+    for (size_t j = 0; j < idx.num_node_entries(); ++j) {
+      if (fis_none(rshape[j])) {
+        ++num_unknown;
+      }
     }
     if (num_unknown == 0) break;
   }
```