Commit 3ae81914 by Eric Junyuan Xie Committed by Tianqi Chen

enhance shape inference. allow in complete shape (#94)

parent 02396a7f
......@@ -387,6 +387,11 @@ class TShape : public Tuple<index_t> {
}
#ifdef MSHADOW_XINLINE
template<int dim>
inline TShape(const mshadow::Shape<dim> &s) {// NOLINT(*)
this->assign(s.shape_, s.shape_ + dim);
}
template<int dim>
inline TShape(mshadow::Shape<dim> &&s) {// NOLINT(*)
this->assign(s.shape_, s.shape_ + dim);
}
......
......@@ -155,27 +155,29 @@ Graph InferAttr(Graph &&ret,
}
};
size_t num_unknown = 0;
const int kMaxStep = 3;
for (int i = 0; i < kMaxStep; ++i) {
size_t last_num_unknown;
size_t num_unknown = rshape.size();
int i = 0;
do {
if (i % 2 == 0) {
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
infer_step(nid, i + 1 == kMaxStep);
infer_step(nid, false);
}
} else {
// backward inference
for (uint32_t i = idx.num_nodes(); i != 0; --i) {
infer_step(i - 1, i + 1 == kMaxStep);
infer_step(i - 1, false);
}
}
last_num_unknown = num_unknown;
num_unknown = 0;
for (size_t j = 0; j < idx.num_node_entries(); ++j) {
if (fis_none(rshape[j])) {
++num_unknown;
}
}
if (num_unknown == 0) break;
}
++i;
} while (num_unknown > 0 && last_num_unknown > num_unknown);
// set the shapes
ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
// number of nodes who knows the shape.
......@@ -190,7 +192,7 @@ NNVM_REGISTER_PASS(InferShape)
std::move(ret), TShape(),
"FInferShape", "shape_inputs", "shape_attr_key",
"shape", "shape_num_unknown_nodes",
[](const TShape& s) { return s.ndim() == 0; },
[](const TShape& s) { return s.ndim() == 0 || s.Size() == 0; },
nullptr);
})
.set_change_graph(false)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment