Commit a3968975 by Yizhi Liu, committed by Tianqi Chen

[Bugfix] Recover original layout when alter_layout function returns None (#2101)

parent 629a293a
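
The fix is easiest to see from the Python side: an alter-layout callback registered with register_alter_op_layout may return None to say "keep the original operator". Before this change such a node lost the layouts recorded by CorrectLayout and its entries came back as '__undef__'; with this change the pass restores the original layouts. A minimal sketch of the usage pattern, mirroring the new test further down (import paths are assumptions based on how the test module uses these names):

import nnvm.symbol as sym
from nnvm import graph
from nnvm.compiler import graph_attr
from nnvm.top import registry as reg   # assumed location of register_alter_op_layout

# Build a small graph and let CorrectLayout record an NCHW layout for every entry.
data = sym.Variable("data", shape=(1, 32, 512, 512))
pool = sym.global_max_pool2d(data, name="pool", layout="NCHW")
net = sym.relu(pool, name="relu")

g = graph.create(net)
g = g.apply("CorrectLayout")
g = graph_attr.set_dtype_inputs(g, "float32")
g = g.apply(["InferShape", "InferType"])

@reg.register_alter_op_layout("global_max_pool2d", level=100)
def keep_global_max_pool2d(attrs, inputs, tinfos):
    # Returning None means "do not replace this operator".
    return None

g = g.apply("AlterOpLayout")
# With this fix the recorded layouts survive; previously the entries of the
# not-replaced operator came back as '__undef__'.
print(g.json_attr("layout"))
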
@@ -46,7 +46,7 @@ Graph AlterOpLayout(const Graph& src) {
   std::vector<std::vector<Layout> > in_layouts_of_node(idx_graph.num_nodes());
   std::vector<std::vector<Layout> > out_layouts_of_node(idx_graph.num_nodes());
-  std::unordered_map<const Node*, uint32_t> new_nodes;
+  std::unordered_map<const Node*, uint32_t> unchanged_nodes;
   if (src.HasAttr("layout")) {
     // record layouts so that LayoutTransform pass can fix layouts correctly,
@@ -56,10 +56,8 @@ Graph AlterOpLayout(const Graph& src) {
     const auto& layouts = src.GetAttr<std::vector<Layout> >("layout");
     for (uint32_t nid = 0; nid < idx_graph.num_nodes(); ++nid) {
       const auto &inode = idx_graph[nid];
-      if (falter_op_layout.count(inode.source->op())) {
-        // do not record input layouts of nodes that will be replaced.
-        continue;
-      }
+      // record input layouts for all nodes,
+      // while replaced nodes will ignore the records here and have undefined input layouts.
       std::vector<Layout> in_layout;
       for (const auto& e : inode.inputs) {
         in_layout.emplace_back(layouts[idx_graph.entry_id(e)]);
@@ -80,7 +78,8 @@ Graph AlterOpLayout(const Graph& src) {
    nnvm::compiler::FTVMAlterOpLayout fn_alter_op_layout =
      falter_op_layout.get(n->op(), nullptr);
    if (fn_alter_op_layout == nullptr) {
-      new_nodes[n.get()] = nid;
+      // will restore the original input layouts later.
+      unchanged_nodes[n.get()] = nid;
      return false;
    }
@@ -106,7 +105,13 @@ Graph AlterOpLayout(const Graph& src) {
    Symbol op;
    bool do_alter =
      fn_alter_op_layout(n->attrs, Symbol::CreateGroup(op_inputs), tensor_infos, &op);
-    if (do_alter) *ret = op.outputs;
+
+    if (do_alter) {
+      *ret = op.outputs;
+    } else {
+      // will restore the original input layouts later.
+      unchanged_nodes[n.get()] = nid;
+    }
    return do_alter;
  };
@@ -118,15 +123,15 @@ Graph AlterOpLayout(const Graph& src) {
     std::vector<Layout> ret_layouts(ret_idx.num_node_entries(), Layout::Undef());
     for (uint32_t nid = 0; nid < ret_idx.num_nodes(); ++nid) {
       const auto& inode = ret_idx[nid];
-      if (new_nodes.count(inode.source)) {
+      if (unchanged_nodes.count(inode.source)) {
         const std::vector<Layout>& in_layouts =
-          in_layouts_of_node[new_nodes[inode.source]];
+          in_layouts_of_node[unchanged_nodes[inode.source]];
         for (uint32_t i = 0; i < inode.inputs.size(); ++i) {
           const auto& e = inode.inputs[i];
           ret_layouts[ret_idx.entry_id(e)] = in_layouts[i];
         }
         const std::vector<Layout>& out_layouts =
-          out_layouts_of_node[new_nodes[inode.source]];
+          out_layouts_of_node[unchanged_nodes[inode.source]];
         for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
           ret_layouts[ret_idx.entry_id(nid, i)] = out_layouts[i];
         }
@@ -45,9 +45,61 @@ def test_alter_conv2d_layout():
     # check copy layouts
     for node in ["data", "relu", "flatten", "softmax", "conv_weight"]:
-        assert(layouts[node] == layouts_origin[node])
-    assert(layouts["conv_alter"] == layouts_origin["conv"])
+        assert layouts[node] == layouts_origin[node]
+    assert layouts["conv_alter"] == layouts_origin["conv"]
+
+
+def test_consecutive_alter_layout():
+    data = sym.Variable("data", shape=(1, 32, 512, 512))
+    pool1 = sym.global_avg_pool2d(data, name="global_avg_pool2d_1", layout="NCHW")
+    pool2 = sym.global_avg_pool2d(pool1, name="global_avg_pool2d_2", layout="NCHW")
+    relu = sym.relu(pool2, name="relu")
+
+    g = graph.create(relu)
+    g = g.apply("CorrectLayout")
+    g = graph_attr.set_dtype_inputs(g, "float32")
+    g = g.apply(["InferShape", "InferType"])
+    assert g.json_attr("layout") == ['NCHW', 'NCHW', 'NCHW', 'NCHW']
+
+    @reg.register_alter_op_layout("global_avg_pool2d", level=100)
+    def alter_global_avg_pool2d_layout(attrs, inputs, tinfos):
+        new_attrs = {k : attrs[k] for k in attrs.keys()}
+        new_attrs["layout"] = "NCHW16c"
+        return sym.global_avg_pool2d(inputs[0], **new_attrs)
+
+    g = g.apply("AlterOpLayout")
+
+    # pool1 get replaced - output layout of pool1 is not recorded
+    # pool2 get replaced - input layout of pool2 is not recorded
+    # thus the second entry must be undefined - it can neither recover from pool1's output,
+    # nor from pool2's input.
+    assert g.json_attr("layout") == ['NCHW', '__undef__', 'NCHW', 'NCHW']
+
+
+def test_alter_func_return_none():
+    data = sym.Variable("data", shape=(1, 32, 512, 512))
+    pool1 = sym.global_max_pool2d(data, name="pool1", layout="NCHW")
+    pool2 = sym.global_max_pool2d(pool1, name="pool2", layout="NCHW")
+    relu = sym.relu(pool2, name="relu")
+
+    g = graph.create(relu)
+    g = g.apply("CorrectLayout")
+    g = graph_attr.set_dtype_inputs(g, "float32")
+    g = g.apply(["InferShape", "InferType"])
+    assert g.json_attr("layout") == ['NCHW', 'NCHW', 'NCHW', 'NCHW']
+
+    @reg.register_alter_op_layout("global_max_pool2d", level=100)
+    def alter_global_max_pool2d_layout(attrs, inputs, tinfos):
+        return None
+
+    g = g.apply("AlterOpLayout")
+
+    # alter func return none, nothing get replaced,
+    # the layouts should remain the same
+    assert g.json_attr("layout") == ['NCHW', 'NCHW', 'NCHW', 'NCHW']
+
+
 if __name__ == "__main__":
     test_alter_conv2d_layout()
+    test_consecutive_alter_layout()
+    test_alter_func_return_none()