Commit a3968975 by Yizhi Liu Committed by Tianqi Chen

[Bugfix] Recover original layout when alter_layout function return None (#2101)

parent 629a293a
......@@ -46,7 +46,7 @@ Graph AlterOpLayout(const Graph& src) {
std::vector<std::vector<Layout> > in_layouts_of_node(idx_graph.num_nodes());
std::vector<std::vector<Layout> > out_layouts_of_node(idx_graph.num_nodes());
std::unordered_map<const Node*, uint32_t> new_nodes;
std::unordered_map<const Node*, uint32_t> unchanged_nodes;
if (src.HasAttr("layout")) {
// record layouts so that LayoutTransform pass can fix layouts correctly,
......@@ -56,10 +56,8 @@ Graph AlterOpLayout(const Graph& src) {
const auto& layouts = src.GetAttr<std::vector<Layout> >("layout");
for (uint32_t nid = 0; nid < idx_graph.num_nodes(); ++nid) {
const auto &inode = idx_graph[nid];
if (falter_op_layout.count(inode.source->op())) {
// do not record input layouts of nodes that will be replaced.
continue;
}
// record input layouts for all nodes,
// while replaced nodes will ignore the records here and have undefined input layouts.
std::vector<Layout> in_layout;
for (const auto& e : inode.inputs) {
in_layout.emplace_back(layouts[idx_graph.entry_id(e)]);
......@@ -80,7 +78,8 @@ Graph AlterOpLayout(const Graph& src) {
nnvm::compiler::FTVMAlterOpLayout fn_alter_op_layout =
falter_op_layout.get(n->op(), nullptr);
if (fn_alter_op_layout == nullptr) {
new_nodes[n.get()] = nid;
// will restore the original input layouts later.
unchanged_nodes[n.get()] = nid;
return false;
}
......@@ -106,7 +105,13 @@ Graph AlterOpLayout(const Graph& src) {
Symbol op;
bool do_alter =
fn_alter_op_layout(n->attrs, Symbol::CreateGroup(op_inputs), tensor_infos, &op);
if (do_alter) *ret = op.outputs;
if (do_alter) {
*ret = op.outputs;
} else {
// will restore the original input layouts later.
unchanged_nodes[n.get()] = nid;
}
return do_alter;
};
......@@ -118,15 +123,15 @@ Graph AlterOpLayout(const Graph& src) {
std::vector<Layout> ret_layouts(ret_idx.num_node_entries(), Layout::Undef());
for (uint32_t nid = 0; nid < ret_idx.num_nodes(); ++nid) {
const auto& inode = ret_idx[nid];
if (new_nodes.count(inode.source)) {
if (unchanged_nodes.count(inode.source)) {
const std::vector<Layout>& in_layouts =
in_layouts_of_node[new_nodes[inode.source]];
in_layouts_of_node[unchanged_nodes[inode.source]];
for (uint32_t i = 0; i < inode.inputs.size(); ++i) {
const auto& e = inode.inputs[i];
ret_layouts[ret_idx.entry_id(e)] = in_layouts[i];
}
const std::vector<Layout>& out_layouts =
out_layouts_of_node[new_nodes[inode.source]];
out_layouts_of_node[unchanged_nodes[inode.source]];
for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
ret_layouts[ret_idx.entry_id(nid, i)] = out_layouts[i];
}
......
......@@ -45,9 +45,61 @@ def test_alter_conv2d_layout():
# check copy layouts
for node in ["data", "relu", "flatten", "softmax", "conv_weight"]:
assert(layouts[node] == layouts_origin[node])
assert(layouts["conv_alter"] == layouts_origin["conv"])
assert layouts[node] == layouts_origin[node]
assert layouts["conv_alter"] == layouts_origin["conv"]
def test_consecutive_alter_layout():
    """Two stacked ops are both replaced by the alter callback: the entry
    between them can be recovered from neither side, so it must stay undef."""
    inp = sym.Variable("data", shape=(1, 32, 512, 512))
    first_pool = sym.global_avg_pool2d(inp, name="global_avg_pool2d_1", layout="NCHW")
    second_pool = sym.global_avg_pool2d(first_pool, name="global_avg_pool2d_2", layout="NCHW")
    out = sym.relu(second_pool, name="relu")

    # Normalize layouts and infer shape/dtype before running the pass.
    g = graph.create(out)
    g = g.apply("CorrectLayout")
    g = graph_attr.set_dtype_inputs(g, "float32")
    g = g.apply(["InferShape", "InferType"])
    assert g.json_attr("layout") == ["NCHW"] * 4

    @reg.register_alter_op_layout("global_avg_pool2d", level=100)
    def alter_global_avg_pool2d_layout(attrs, inputs, tinfos):
        # Clone the attributes and force a split-channel layout so the
        # pass replaces every global_avg_pool2d node.
        replaced = {key: attrs[key] for key in attrs.keys()}
        replaced["layout"] = "NCHW16c"
        return sym.global_avg_pool2d(inputs[0], **replaced)

    g = g.apply("AlterOpLayout")

    # pool1 is replaced, so its output layout is not recorded; pool2 is
    # replaced, so its input layout is not recorded.  The shared entry
    # therefore cannot be recovered from either node and stays undefined.
    assert g.json_attr("layout") == ['NCHW', '__undef__', 'NCHW', 'NCHW']
def test_alter_func_return_none():
    """When the registered alter callback returns None, no node is replaced
    and every recorded layout must survive the AlterOpLayout pass intact."""
    inp = sym.Variable("data", shape=(1, 32, 512, 512))
    first_pool = sym.global_max_pool2d(inp, name="pool1", layout="NCHW")
    second_pool = sym.global_max_pool2d(first_pool, name="pool2", layout="NCHW")
    out = sym.relu(second_pool, name="relu")

    # Normalize layouts and infer shape/dtype before running the pass.
    g = graph.create(out)
    g = g.apply("CorrectLayout")
    g = graph_attr.set_dtype_inputs(g, "float32")
    g = g.apply(["InferShape", "InferType"])
    assert g.json_attr("layout") == ["NCHW"] * 4

    @reg.register_alter_op_layout("global_max_pool2d", level=100)
    def alter_global_max_pool2d_layout(attrs, inputs, tinfos):
        # Decline to alter: the pass must restore the original layouts.
        return None

    g = g.apply("AlterOpLayout")

    # Nothing was replaced, so the layout attribute is unchanged.
    assert g.json_attr("layout") == ["NCHW"] * 4
if __name__ == "__main__":
    # Run every test case in this module in definition order.
    for _case in (test_alter_conv2d_layout,
                  test_consecutive_alter_layout,
                  test_alter_func_return_none):
        _case()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment