Commit 45da8718 by Tianqi Chen

[Infer] More robust inference, support backward inference (#54)

parent 647267da
...@@ -11,15 +11,16 @@ namespace nnvm { ...@@ -11,15 +11,16 @@ namespace nnvm {
namespace pass { namespace pass {
namespace { namespace {
template<typename AttrType, typename IsNone> template<typename AttrType, typename IsNone, typename FDefault>
Graph InferAttr(Graph &&ret, Graph InferAttr(Graph &&ret,
const AttrType default_val, const AttrType empty_val,
const char* infer_name, const char* infer_name,
const char* input_name, const char* input_name,
const char* attr_key_name, const char* attr_key_name,
const char* attr_name, const char* attr_name,
const char* unknown_name, const char* unknown_name,
IsNone fis_none) { IsNone fis_none,
FDefault fdefault) {
using AttrVector = std::vector<AttrType>; using AttrVector = std::vector<AttrType>;
const IndexedGraph& idx = ret.indexed_graph(); const IndexedGraph& idx = ret.indexed_graph();
static auto& finfer_shape = static auto& finfer_shape =
...@@ -31,7 +32,7 @@ Graph InferAttr(Graph &&ret, ...@@ -31,7 +32,7 @@ Graph InferAttr(Graph &&ret,
if (ret.attrs.count(attr_name) != 0) { if (ret.attrs.count(attr_name) != 0) {
rshape = ret.MoveCopyAttr<AttrVector>(attr_name); rshape = ret.MoveCopyAttr<AttrVector>(attr_name);
} else { } else {
rshape.resize(idx.num_node_entries(), default_val); rshape.resize(idx.num_node_entries(), empty_val);
} }
if (ret.attrs.count(input_name) != 0) { if (ret.attrs.count(input_name) != 0) {
...@@ -51,12 +52,12 @@ Graph InferAttr(Graph &&ret, ...@@ -51,12 +52,12 @@ Graph InferAttr(Graph &&ret,
// erase the provided arguments // erase the provided arguments
ret.attrs.erase(attr_key_name); ret.attrs.erase(attr_key_name);
} }
// Temp space for shape inference. // Temp space for shape inference.
std::vector<AttrType> ishape, oshape; std::vector<AttrType> ishape, oshape;
// number of completed nodes size_t num_unknown;
size_t num_unknown = 0;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { // inference step function for nid
auto infer_step = [&](uint32_t nid) {
const auto& inode = idx[nid]; const auto& inode = idx[nid];
const uint32_t num_inputs = inode.inputs.size(); const uint32_t num_inputs = inode.inputs.size();
const uint32_t num_outputs = inode.source->num_outputs(); const uint32_t num_outputs = inode.source->num_outputs();
...@@ -72,27 +73,6 @@ Graph InferAttr(Graph &&ret, ...@@ -72,27 +73,6 @@ Graph InferAttr(Graph &&ret,
CHECK(is >> rshape[out_ent_id]) << "Invalid attribute"; CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
} }
} }
} else if (finfer_shape.count(inode.source->op())) {
// Forward operator inference.
ishape.resize(num_inputs, default_val);
for (uint32_t i = 0; i < ishape.size(); ++i) {
ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
}
oshape.resize(num_outputs, default_val);
for (uint32_t i = 0; i < oshape.size(); ++i) {
oshape[i] = rshape[idx.entry_id(nid, i)];
}
// Call inference function of the operator.
bool forward_known = finfer_shape[inode.source->op()](
inode.source->attrs, &ishape, &oshape);
num_unknown += !forward_known;
// Save to the result map.
for (uint32_t i = 0; i < num_inputs; ++i) {
rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
}
for (uint32_t i = 0; i < num_outputs; ++i) {
rshape[idx.entry_id(nid, i)] = oshape[i];
}
} else if (backward_map.count(inode.source->op())) { } else if (backward_map.count(inode.source->op())) {
// Backward operator inference. // Backward operator inference.
CHECK_GE(inode.control_deps.size(), 1) CHECK_GE(inode.control_deps.size(), 1)
...@@ -111,6 +91,47 @@ Graph InferAttr(Graph &&ret, ...@@ -111,6 +91,47 @@ Graph InferAttr(Graph &&ret,
if (fis_none(rshape[idx.entry_id(nid, i)])) known = false; if (fis_none(rshape[idx.entry_id(nid, i)])) known = false;
} }
num_unknown += !known; num_unknown += !known;
} else {
bool forward_known = true;
// Forward operator inference.
ishape.resize(num_inputs, empty_val);
for (uint32_t i = 0; i < ishape.size(); ++i) {
ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
if (fis_none(ishape[i])) forward_known = false;
}
oshape.resize(num_outputs, empty_val);
for (uint32_t i = 0; i < oshape.size(); ++i) {
oshape[i] = rshape[idx.entry_id(nid, i)];
if (fis_none(oshape[i])) forward_known = false;
}
if (!forward_known) {
auto finfer = finfer_shape.get(inode.source->op(), fdefault);
CHECK(finfer != nullptr)
<< "Attribute " << infer_name
<< " is not registed by op " << inode.source->op()->name;
// Call inference function of the operator.
forward_known = finfer(inode.source->attrs, &ishape, &oshape);
}
num_unknown += !forward_known;
// Save to the result map.
for (uint32_t i = 0; i < num_inputs; ++i) {
rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
}
for (uint32_t i = 0; i < num_outputs; ++i) {
rshape[idx.entry_id(nid, i)] = oshape[i];
}
}
};
num_unknown = 0;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
infer_step(nid);
}
if (num_unknown != 0) {
num_unknown = 0;
// backward inference
for (uint32_t i = idx.num_nodes(); i != 0; --i) {
infer_step(i - 1);
} }
} }
// set the shapes // set the shapes
...@@ -127,19 +148,48 @@ NNVM_REGISTER_PASS(InferShape) ...@@ -127,19 +148,48 @@ NNVM_REGISTER_PASS(InferShape)
std::move(ret), TShape(), std::move(ret), TShape(),
"FInferShape", "shape_inputs", "shape_attr_key", "FInferShape", "shape_inputs", "shape_attr_key",
"shape", "shape_num_unknown_nodes", "shape", "shape_num_unknown_nodes",
[](const TShape& s) { return s.ndim() == 0; }); [](const TShape& s) { return s.ndim() == 0; },
nullptr);
}) })
.set_change_graph(false) .set_change_graph(false)
.provide_graph_attr("shape"); .provide_graph_attr("shape");
// inference fucntion for same type
inline bool SameType(const NodeAttrs& attrs,
std::vector<int> *iattr,
std::vector<int> *oattr) {
int def_v = -1;
for (int v : *oattr) {
if (v != -1) {
def_v = v; break;
}
}
if (def_v == -1) {
for (int v : *iattr) {
if (v != -1) {
def_v = v; break;
}
}
}
if (def_v == -1) return false;
for (int& v : *oattr) {
v = def_v;
}
for (int& v : *iattr) {
v = def_v;
}
return true;
}
NNVM_REGISTER_PASS(InferType) NNVM_REGISTER_PASS(InferType)
.describe("Infer the dtype of each node entries.") .describe("Infer the dtype of each node entries.")
.set_body([](Graph ret) { .set_body([](Graph ret) {
return InferAttr<int>( return InferAttr<int>(
std::move(ret), 0, std::move(ret), -1,
"FInferType", "dtype_inputs", "dtype_attr_key", "FInferType", "dtype_inputs", "dtype_attr_key",
"dtype", "dtype_num_unknown_nodes", "dtype", "dtype_num_unknown_nodes",
[](const int t) { return t == -1; }); [](const int t) { return t == -1; },
SameType);
}) })
.set_change_graph(false) .set_change_graph(false)
.provide_graph_attr("dtype"); .provide_graph_attr("dtype");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment