Commit e8fee6dc by Tianqi Chen

Add shape backward inference (#58)

parent 869a953a
......@@ -108,6 +108,21 @@ using FBackwardOutToInIndex = std::function<
std::vector<uint32_t> (const NodeAttrs& attrs)>;
/*!
* \brief Whether this op is an explicit backward operator,
* Returns list of input index that corresponds to the outputs of the forward operator.
*
* If FBackwardInGradIndex exists:
* - The first control_deps of the node points to the corresponding forward operator.
* - The FBackwardInGradIndex[i]-th input of backward op corresponds to the i-th
* output of forward operator.
*
* \note Register under "FBackwardInGradIndex"
* This enables easier shape/type inference for backward operators.
*/
using FBackwardInGradIndex = std::function<
std::vector<uint32_t> (const NodeAttrs& attrs)>;
/*!
* \brief Get possible inplace options.
* This function enables optimization to reuse memory of inputs in output.
* \param attrs The attributes of the node
......
......@@ -27,6 +27,8 @@ Graph InferAttr(Graph &&ret,
Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
static auto& backward_map =
Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex");
static auto& backward_in_grad =
Op::GetAttr<FBackwardInGradIndex>("FBackwardInGradIndex");
// reshape shape vector
AttrVector rshape;
if (ret.attrs.count(attr_name) != 0) {
......@@ -54,7 +56,6 @@ Graph InferAttr(Graph &&ret,
}
// Temp space for shape inference.
std::vector<AttrType> ishape, oshape;
size_t num_unknown;
// inference step function for nid
auto infer_step = [&](uint32_t nid) {
......@@ -76,21 +77,29 @@ Graph InferAttr(Graph &&ret,
} else if (backward_map.count(inode.source->op())) {
// Backward operator inference.
CHECK_GE(inode.control_deps.size(), 1)
<< "BackwardOp need to have control_deps to its forward op";
<< "BackwardOp need to have control_deps to its forward op";
const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
// Inference the outputs of backward operator (equal to the inputs
// of its corresponding forward operator).
std::vector<uint32_t> out_map =
backward_map[inode.source->op()](inode.source->attrs);
bool known = true;
for (size_t i = 0; i < out_map.size(); ++i) {
uint32_t in_id = out_map[i];
CHECK_LT(in_id, fnode.inputs.size());
rshape[idx.entry_id(nid, i)] =
rshape[idx.entry_id(fnode.inputs[in_id])];
if (fis_none(rshape[idx.entry_id(nid, i)])) known = false;
}
num_unknown += !known;
if (backward_in_grad.count(inode.source->op())) {
std::vector<uint32_t> in_grad =
backward_in_grad[inode.source->op()](inode.source->attrs);
CHECK_LE(in_grad.size(), fnode.source->num_outputs());
for (size_t i = 0; i < in_grad.size(); ++i) {
uint32_t eid = idx.entry_id(inode.inputs[in_grad[i]]);
if (fis_none(rshape[eid])) {
rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], i)];
}
}
}
} else {
bool forward_known = true;
// Forward operator inference.
......@@ -112,7 +121,6 @@ Graph InferAttr(Graph &&ret,
// Call inference function of the operator.
forward_known = finfer(inode.source->attrs, &ishape, &oshape);
}
num_unknown += !forward_known;
// Save to the result map.
for (uint32_t i = 0; i < num_inputs; ++i) {
rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
......@@ -123,16 +131,24 @@ Graph InferAttr(Graph &&ret,
}
};
num_unknown = 0;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
infer_step(nid);
}
if (num_unknown != 0) {
size_t num_unknown = 0;
const int kMaxStep = 3;
for (int i = 0; i < kMaxStep; ++i) {
if (i % 2 == 0) {
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
infer_step(nid);
}
} else {
// backward inference
for (uint32_t i = idx.num_nodes(); i != 0; --i) {
infer_step(i - 1);
}
}
num_unknown = 0;
// backward inference
for (uint32_t i = idx.num_nodes(); i != 0; --i) {
infer_step(i - 1);
for (size_t i = 0; i < idx.num_node_entries(); ++i) {
if (fis_none(rshape[i])) ++num_unknown;
}
if (num_unknown == 0) break;
}
// set the shapes
ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment