Commit 94fa45f8 by Tianqi Chen

Remove backward index, use gradient guessing instead (#85)

* Remove backward index, use gradient guessing instead

* minor fix

* bugfix
parent e4820d34
......@@ -95,32 +95,13 @@ using FInferType = FInferNodeEntryAttr<int>;
/*!
* \brief Whether this op is an explicit backward operator,
* and the correspondence of each output to input.
*
* If FBackwardOutToInIndex exists:
* - The first control_deps of the node points to the corresponding forward operator.
* - The k-th output corresponds to the FBackwardOutToInIndex()[k]-th input of the forward op.
*
* \note Register under "FBackwardOutToInIndex"
* This enables easier shape/type inference for backward operators for slice and reduction.
*/
using FBackwardOutToInIndex = std::function<
std::vector<uint32_t> (const NodeAttrs& attrs)>;
/*!
* \brief Whether this op is an explicit backward operator,
* Returns the list of input indices that correspond to the outputs of the forward operator.
*
* If FBackwardInGradIndex exists:
* If TIsBackward is true:
* - The first control_deps of the node points to the corresponding forward operator.
* - The FBackwardInGradIndex[i]-th input of backward op corresponds to the i-th
* output of forward operator.
*
* \note Register under "FBackwardInGradIndex"
* \note Register under "TIsBackward"
* This enables easier shape/type inference for backward operators.
*/
using FBackwardInGradIndex = std::function<
std::vector<uint32_t> (const NodeAttrs& attrs)>;
using TIsBackward = bool;
/*!
* \brief Get possible inplace options.
......
......@@ -25,10 +25,11 @@ Graph InferAttr(Graph &&ret,
const IndexedGraph& idx = ret.indexed_graph();
static auto& finfer_shape =
Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
static auto& backward_map =
Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex");
static auto& backward_in_grad =
Op::GetAttr<FBackwardInGradIndex>("FBackwardInGradIndex");
static auto& is_backward =
Op::GetAttr<TIsBackward>("TIsBackward");
// gradient function, used to get node correspondence.
static auto& fgrad =
Op::GetAttr<FGradient>("FGradient");
// reshape shape vector
AttrVector rshape;
if (ret.attrs.count(attr_name) != 0) {
......@@ -74,29 +75,44 @@ Graph InferAttr(Graph &&ret,
CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
}
}
} else if (backward_map.count(inode.source->op())) {
// Backward operator inference.
} else if (is_backward.get(inode.source->op(), false)) {
CHECK_GE(inode.control_deps.size(), 1)
<< "BackwardOp need to have control_deps to its forward op";
const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
// Infer the outputs of the backward operator (equal to the inputs
// of its corresponding forward operator).
std::vector<uint32_t> out_map =
backward_map[inode.source->op()](inode.source->attrs);
for (size_t i = 0; i < out_map.size(); ++i) {
uint32_t in_id = out_map[i];
CHECK_LT(in_id, fnode.inputs.size());
rshape[idx.entry_id(nid, i)] =
rshape[idx.entry_id(fnode.inputs[in_id])];
NodePtr fwd_ptr = inode.source->control_deps[0];
// use gradient function to find out the correspondence.
std::vector<NodeEntry> ograd(fwd_ptr->num_outputs());
for (size_t i = 0; i < ograd.size(); ++i) {
ograd[i].index = static_cast<uint32_t>(i);
}
if (backward_in_grad.count(inode.source->op())) {
std::vector<uint32_t> in_grad =
backward_in_grad[inode.source->op()](inode.source->attrs);
CHECK_LE(in_grad.size(), fnode.source->num_outputs());
for (size_t i = 0; i < in_grad.size(); ++i) {
uint32_t eid = idx.entry_id(inode.inputs[in_grad[i]]);
// input gradient list
auto igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd);
const Op* backward_op = inode.source->op();
const Node* igrad_node = nullptr;
// Input gradient assignment
for (size_t i = 0; i < igrad.size(); ++i) {
if (igrad[i].node->op() == backward_op) {
uint32_t eid = idx.entry_id(nid, igrad[i].index);
if (fis_none(rshape[eid])) {
rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], i)];
rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];
} else {
CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])])
<< "Backward shape inconsistent with the forward shape";
}
if (igrad_node == nullptr) {
igrad_node = igrad[i].node.get();
} else {
CHECK(igrad_node == igrad[i].node.get());
}
}
}
// out grad entries
for (size_t i = 0; i < igrad_node->inputs.size(); ++i) {
const NodeEntry& e = igrad_node->inputs[i];
if (e.node == nullptr) {
uint32_t eid = idx.entry_id(inode.inputs[i]);
if (fis_none(rshape[eid])) {
rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)];
}
}
}
......@@ -153,8 +169,10 @@ Graph InferAttr(Graph &&ret,
}
}
num_unknown = 0;
for (size_t i = 0; i < idx.num_node_entries(); ++i) {
if (fis_none(rshape[i])) ++num_unknown;
for (size_t j = 0; j < idx.num_node_entries(); ++j) {
if (fis_none(rshape[j])) {
++num_unknown;
}
}
if (num_unknown == 0) break;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment