Commit 98fd6bd0 by Tianqi Chen

enable shape inference with hint func (#84)

parent c002d80c
...@@ -58,7 +58,7 @@ Graph InferAttr(Graph &&ret, ...@@ -58,7 +58,7 @@ Graph InferAttr(Graph &&ret,
std::vector<AttrType> ishape, oshape; std::vector<AttrType> ishape, oshape;
// inference step function for nid // inference step function for nid
auto infer_step = [&](uint32_t nid) { auto infer_step = [&](uint32_t nid, bool last_iter) {
const auto& inode = idx[nid]; const auto& inode = idx[nid];
const uint32_t num_inputs = inode.inputs.size(); const uint32_t num_inputs = inode.inputs.size();
const uint32_t num_outputs = inode.source->num_outputs(); const uint32_t num_outputs = inode.source->num_outputs();
...@@ -113,16 +113,20 @@ Graph InferAttr(Graph &&ret, ...@@ -113,16 +113,20 @@ Graph InferAttr(Graph &&ret,
oshape[i] = rshape[idx.entry_id(nid, i)]; oshape[i] = rshape[idx.entry_id(nid, i)];
if (fis_none(oshape[i])) forward_known = false; if (fis_none(oshape[i])) forward_known = false;
} }
auto finfer = finfer_shape.get(inode.source->op(), fdefault);
if (!forward_known) { if (!forward_known) {
auto finfer = finfer_shape.get(inode.source->op(), fdefault); if (finfer != nullptr) {
CHECK(finfer != nullptr) // Call inference function of the operator.
<< "Attribute " << infer_name try {
<< " is not registed by op " << inode.source->op()->name; forward_known = finfer(inode.source->attrs, &ishape, &oshape);
// Call inference function of the operator. } catch (const std::exception& e) {
try { throw dmlc::Error(e.what() + std::string(" with ") + inode.source->attrs.name);
forward_known = finfer(inode.source->attrs, &ishape, &oshape); }
} catch (const std::exception& e) { } else {
throw dmlc::Error(e.what() + std::string(" with ") + inode.source->attrs.name); CHECK(!last_iter)
<< "Attribute " << infer_name
<< " is not registed by op " << inode.source->op()->name
<< " we are not able to complete the inference because of this";
} }
} }
// Save to the result map. // Save to the result map.
...@@ -140,12 +144,12 @@ Graph InferAttr(Graph &&ret, ...@@ -140,12 +144,12 @@ Graph InferAttr(Graph &&ret,
for (int i = 0; i < kMaxStep; ++i) { for (int i = 0; i < kMaxStep; ++i) {
if (i % 2 == 0) { if (i % 2 == 0) {
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
infer_step(nid); infer_step(nid, i + 1 == kMaxStep);
} }
} else { } else {
// backward inference // backward inference
for (uint32_t i = idx.num_nodes(); i != 0; --i) { for (uint32_t i = idx.num_nodes(); i != 0; --i) {
infer_step(i - 1); infer_step(i - 1, i + 1 == kMaxStep);
} }
} }
num_unknown = 0; num_unknown = 0;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment