Commit 3ae81914 by Eric Junyuan Xie Committed by Tianqi Chen

enhance shape inference. allow in complete shape (#94)

parent 02396a7f
...@@ -387,6 +387,11 @@ class TShape : public Tuple<index_t> { ...@@ -387,6 +387,11 @@ class TShape : public Tuple<index_t> {
} }
#ifdef MSHADOW_XINLINE #ifdef MSHADOW_XINLINE
template<int dim> template<int dim>
inline TShape(const mshadow::Shape<dim> &s) {// NOLINT(*)
this->assign(s.shape_, s.shape_ + dim);
}
template<int dim>
inline TShape(mshadow::Shape<dim> &&s) {// NOLINT(*) inline TShape(mshadow::Shape<dim> &&s) {// NOLINT(*)
this->assign(s.shape_, s.shape_ + dim); this->assign(s.shape_, s.shape_ + dim);
} }
......
...@@ -155,27 +155,29 @@ Graph InferAttr(Graph &&ret, ...@@ -155,27 +155,29 @@ Graph InferAttr(Graph &&ret,
} }
}; };
size_t num_unknown = 0; size_t last_num_unknown;
const int kMaxStep = 3; size_t num_unknown = rshape.size();
for (int i = 0; i < kMaxStep; ++i) { int i = 0;
do {
if (i % 2 == 0) { if (i % 2 == 0) {
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
infer_step(nid, i + 1 == kMaxStep); infer_step(nid, false);
} }
} else { } else {
// backward inference // backward inference
for (uint32_t i = idx.num_nodes(); i != 0; --i) { for (uint32_t i = idx.num_nodes(); i != 0; --i) {
infer_step(i - 1, i + 1 == kMaxStep); infer_step(i - 1, false);
} }
} }
last_num_unknown = num_unknown;
num_unknown = 0; num_unknown = 0;
for (size_t j = 0; j < idx.num_node_entries(); ++j) { for (size_t j = 0; j < idx.num_node_entries(); ++j) {
if (fis_none(rshape[j])) { if (fis_none(rshape[j])) {
++num_unknown; ++num_unknown;
} }
} }
if (num_unknown == 0) break; ++i;
} } while (num_unknown > 0 && last_num_unknown > num_unknown);
// set the shapes // set the shapes
ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape)); ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
// number of nodes who knows the shape. // number of nodes who knows the shape.
...@@ -190,7 +192,7 @@ NNVM_REGISTER_PASS(InferShape) ...@@ -190,7 +192,7 @@ NNVM_REGISTER_PASS(InferShape)
std::move(ret), TShape(), std::move(ret), TShape(),
"FInferShape", "shape_inputs", "shape_attr_key", "FInferShape", "shape_inputs", "shape_attr_key",
"shape", "shape_num_unknown_nodes", "shape", "shape_num_unknown_nodes",
[](const TShape& s) { return s.ndim() == 0; }, [](const TShape& s) { return s.ndim() == 0 || s.Size() == 0; },
nullptr); nullptr);
}) })
.set_change_graph(false) .set_change_graph(false)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment