Commit 98fd6bd0 by Tianqi Chen

enable shape inference with hint func (#84)

parent c002d80c
...@@ -58,7 +58,7 @@ Graph InferAttr(Graph &&ret, ...@@ -58,7 +58,7 @@ Graph InferAttr(Graph &&ret,
std::vector<AttrType> ishape, oshape; std::vector<AttrType> ishape, oshape;
// inference step function for nid // inference step function for nid
auto infer_step = [&](uint32_t nid) { auto infer_step = [&](uint32_t nid, bool last_iter) {
const auto& inode = idx[nid]; const auto& inode = idx[nid];
const uint32_t num_inputs = inode.inputs.size(); const uint32_t num_inputs = inode.inputs.size();
const uint32_t num_outputs = inode.source->num_outputs(); const uint32_t num_outputs = inode.source->num_outputs();
...@@ -113,17 +113,21 @@ Graph InferAttr(Graph &&ret, ...@@ -113,17 +113,21 @@ Graph InferAttr(Graph &&ret,
oshape[i] = rshape[idx.entry_id(nid, i)]; oshape[i] = rshape[idx.entry_id(nid, i)];
if (fis_none(oshape[i])) forward_known = false; if (fis_none(oshape[i])) forward_known = false;
} }
if (!forward_known) {
auto finfer = finfer_shape.get(inode.source->op(), fdefault); auto finfer = finfer_shape.get(inode.source->op(), fdefault);
CHECK(finfer != nullptr) if (!forward_known) {
<< "Attribute " << infer_name if (finfer != nullptr) {
<< " is not registed by op " << inode.source->op()->name;
// Call inference function of the operator. // Call inference function of the operator.
try { try {
forward_known = finfer(inode.source->attrs, &ishape, &oshape); forward_known = finfer(inode.source->attrs, &ishape, &oshape);
} catch (const std::exception& e) { } catch (const std::exception& e) {
throw dmlc::Error(e.what() + std::string(" with ") + inode.source->attrs.name); throw dmlc::Error(e.what() + std::string(" with ") + inode.source->attrs.name);
} }
} else {
CHECK(!last_iter)
<< "Attribute " << infer_name
<< " is not registed by op " << inode.source->op()->name
<< " we are not able to complete the inference because of this";
}
} }
// Save to the result map. // Save to the result map.
for (uint32_t i = 0; i < num_inputs; ++i) { for (uint32_t i = 0; i < num_inputs; ++i) {
...@@ -140,12 +144,12 @@ Graph InferAttr(Graph &&ret, ...@@ -140,12 +144,12 @@ Graph InferAttr(Graph &&ret,
for (int i = 0; i < kMaxStep; ++i) { for (int i = 0; i < kMaxStep; ++i) {
if (i % 2 == 0) { if (i % 2 == 0) {
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
infer_step(nid); infer_step(nid, i + 1 == kMaxStep);
} }
} else { } else {
// backward inference // backward inference
for (uint32_t i = idx.num_nodes(); i != 0; --i) { for (uint32_t i = idx.num_nodes(); i != 0; --i) {
infer_step(i - 1); infer_step(i - 1, i + 1 == kMaxStep);
} }
} }
num_unknown = 0; num_unknown = 0;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment