Commit e8fee6dc by Tianqi Chen

Add shape backward inference (#58)

parent 869a953a
...@@ -108,6 +108,21 @@ using FBackwardOutToInIndex = std::function< ...@@ -108,6 +108,21 @@ using FBackwardOutToInIndex = std::function<
std::vector<uint32_t> (const NodeAttrs& attrs)>; std::vector<uint32_t> (const NodeAttrs& attrs)>;
/*! /*!
* \brief Marks this op as an explicit backward operator.
* Returns the list of input indices that correspond to the outputs of the forward operator.
*
* If FBackwardInGradIndex exists:
* - The first entry of the node's control_deps points to the corresponding forward operator.
* - The FBackwardInGradIndex[i]-th input of the backward op corresponds to the i-th
* output of the forward operator.
*
* \note Register under "FBackwardInGradIndex".
* This enables easier shape/type inference for backward operators.
*/
using FBackwardInGradIndex = std::function<
std::vector<uint32_t> (const NodeAttrs& attrs)>;
/*!
* \brief Get possible inplace options. * \brief Get possible inplace options.
* This function enables optimization to reuse memory of inputs in output. * This function enables optimization to reuse memory of inputs in output.
* \param attrs The attributes of the node * \param attrs The attributes of the node
......
...@@ -27,6 +27,8 @@ Graph InferAttr(Graph &&ret, ...@@ -27,6 +27,8 @@ Graph InferAttr(Graph &&ret,
Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name); Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
static auto& backward_map = static auto& backward_map =
Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex"); Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex");
static auto& backward_in_grad =
Op::GetAttr<FBackwardInGradIndex>("FBackwardInGradIndex");
// reshape shape vector // reshape shape vector
AttrVector rshape; AttrVector rshape;
if (ret.attrs.count(attr_name) != 0) { if (ret.attrs.count(attr_name) != 0) {
...@@ -54,7 +56,6 @@ Graph InferAttr(Graph &&ret, ...@@ -54,7 +56,6 @@ Graph InferAttr(Graph &&ret,
} }
// Temp space for shape inference. // Temp space for shape inference.
std::vector<AttrType> ishape, oshape; std::vector<AttrType> ishape, oshape;
size_t num_unknown;
// inference step function for nid // inference step function for nid
auto infer_step = [&](uint32_t nid) { auto infer_step = [&](uint32_t nid) {
...@@ -82,15 +83,23 @@ Graph InferAttr(Graph &&ret, ...@@ -82,15 +83,23 @@ Graph InferAttr(Graph &&ret,
// of its corresponding forward operator). // of its corresponding forward operator).
std::vector<uint32_t> out_map = std::vector<uint32_t> out_map =
backward_map[inode.source->op()](inode.source->attrs); backward_map[inode.source->op()](inode.source->attrs);
bool known = true;
for (size_t i = 0; i < out_map.size(); ++i) { for (size_t i = 0; i < out_map.size(); ++i) {
uint32_t in_id = out_map[i]; uint32_t in_id = out_map[i];
CHECK_LT(in_id, fnode.inputs.size()); CHECK_LT(in_id, fnode.inputs.size());
rshape[idx.entry_id(nid, i)] = rshape[idx.entry_id(nid, i)] =
rshape[idx.entry_id(fnode.inputs[in_id])]; rshape[idx.entry_id(fnode.inputs[in_id])];
if (fis_none(rshape[idx.entry_id(nid, i)])) known = false;
} }
num_unknown += !known; if (backward_in_grad.count(inode.source->op())) {
std::vector<uint32_t> in_grad =
backward_in_grad[inode.source->op()](inode.source->attrs);
CHECK_LE(in_grad.size(), fnode.source->num_outputs());
for (size_t i = 0; i < in_grad.size(); ++i) {
uint32_t eid = idx.entry_id(inode.inputs[in_grad[i]]);
if (fis_none(rshape[eid])) {
rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], i)];
}
}
}
} else { } else {
bool forward_known = true; bool forward_known = true;
// Forward operator inference. // Forward operator inference.
...@@ -112,7 +121,6 @@ Graph InferAttr(Graph &&ret, ...@@ -112,7 +121,6 @@ Graph InferAttr(Graph &&ret,
// Call inference function of the operator. // Call inference function of the operator.
forward_known = finfer(inode.source->attrs, &ishape, &oshape); forward_known = finfer(inode.source->attrs, &ishape, &oshape);
} }
num_unknown += !forward_known;
// Save to the result map. // Save to the result map.
for (uint32_t i = 0; i < num_inputs; ++i) { for (uint32_t i = 0; i < num_inputs; ++i) {
rshape[idx.entry_id(inode.inputs[i])] = ishape[i]; rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
...@@ -123,17 +131,25 @@ Graph InferAttr(Graph &&ret, ...@@ -123,17 +131,25 @@ Graph InferAttr(Graph &&ret,
} }
}; };
num_unknown = 0; size_t num_unknown = 0;
const int kMaxStep = 3;
for (int i = 0; i < kMaxStep; ++i) {
if (i % 2 == 0) {
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
infer_step(nid); infer_step(nid);
} }
if (num_unknown != 0) { } else {
num_unknown = 0;
// backward inference // backward inference
for (uint32_t i = idx.num_nodes(); i != 0; --i) { for (uint32_t i = idx.num_nodes(); i != 0; --i) {
infer_step(i - 1); infer_step(i - 1);
} }
} }
num_unknown = 0;
for (size_t i = 0; i < idx.num_node_entries(); ++i) {
if (fis_none(rshape[i])) ++num_unknown;
}
if (num_unknown == 0) break;
}
// set the shapes // set the shapes
ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape)); ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
// number of nodes who knows the shape. // number of nodes who knows the shape.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment