Commit 98fd6bd0 by Tianqi Chen

enable shape inference with hint func (#84)

parent c002d80c
......@@ -58,7 +58,7 @@ Graph InferAttr(Graph &&ret,
std::vector<AttrType> ishape, oshape;
// inference step function for nid
auto infer_step = [&](uint32_t nid) {
auto infer_step = [&](uint32_t nid, bool last_iter) {
const auto& inode = idx[nid];
const uint32_t num_inputs = inode.inputs.size();
const uint32_t num_outputs = inode.source->num_outputs();
......@@ -113,17 +113,21 @@ Graph InferAttr(Graph &&ret,
oshape[i] = rshape[idx.entry_id(nid, i)];
if (fis_none(oshape[i])) forward_known = false;
}
if (!forward_known) {
auto finfer = finfer_shape.get(inode.source->op(), fdefault);
CHECK(finfer != nullptr)
<< "Attribute " << infer_name
<< " is not registed by op " << inode.source->op()->name;
if (!forward_known) {
if (finfer != nullptr) {
// Call inference function of the operator.
try {
forward_known = finfer(inode.source->attrs, &ishape, &oshape);
} catch (const std::exception& e) {
throw dmlc::Error(e.what() + std::string(" with ") + inode.source->attrs.name);
}
} else {
CHECK(!last_iter)
<< "Attribute " << infer_name
<< " is not registed by op " << inode.source->op()->name
<< " we are not able to complete the inference because of this";
}
}
// Save to the result map.
for (uint32_t i = 0; i < num_inputs; ++i) {
......@@ -140,12 +144,12 @@ Graph InferAttr(Graph &&ret,
for (int i = 0; i < kMaxStep; ++i) {
if (i % 2 == 0) {
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
infer_step(nid);
infer_step(nid, i + 1 == kMaxStep);
}
} else {
// backward inference
for (uint32_t i = idx.num_nodes(); i != 0; --i) {
infer_step(i - 1);
infer_step(i - 1, i + 1 == kMaxStep);
}
}
num_unknown = 0;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment