Commit 343eb82c by Yizhi Liu Committed by Tianqi Chen

fix restore layout in AlterOpLayout (#460)

* fix restore layout in AlterOpLayout

* lint test case
parent 1e4bb2f8
...@@ -119,8 +119,9 @@ Graph AlterOpLayout(const Graph& src) { ...@@ -119,8 +119,9 @@ Graph AlterOpLayout(const Graph& src) {
if (new_nodes.count(inode.source)) { if (new_nodes.count(inode.source)) {
const std::vector<Layout>& in_layouts = const std::vector<Layout>& in_layouts =
in_layouts_of_node[new_nodes[inode.source]]; in_layouts_of_node[new_nodes[inode.source]];
for (const auto& e : inode.inputs) { for (uint32_t i = 0; i < inode.inputs.size(); ++i) {
ret_layouts[ret_idx.entry_id(e)] = in_layouts[e.index]; const auto& e = inode.inputs[i];
ret_layouts[ret_idx.entry_id(e)] = in_layouts[i];
} }
const std::vector<Layout>& out_layouts = const std::vector<Layout>& out_layouts =
out_layouts_of_node[new_nodes[inode.source]]; out_layouts_of_node[new_nodes[inode.source]];
......
...@@ -19,10 +19,14 @@ def test_alter_conv2d_layout(): ...@@ -19,10 +19,14 @@ def test_alter_conv2d_layout():
conv = sym.conv2d(data, name="conv", channels=16, conv = sym.conv2d(data, name="conv", channels=16,
kernel_size=(3,3), padding=(1,1), kernel_size=(3,3), padding=(1,1),
use_bias=False, layout="NCHW") use_bias=False, layout="NCHW")
relu = sym.relu(conv, name="relu") # split here
convs = sym.split(conv, indices_or_sections=2)
relus = [sym.relu(x, name="relu") for x in convs]
relu = sym.concatenate(*relus)
flatten = sym.flatten(relu, name="flatten") flatten = sym.flatten(relu, name="flatten")
softmax = sym.softmax(flatten, name="softmax") softmax = sym.softmax(flatten, name="softmax")
g = graph.create(softmax) g = graph.create(softmax)
g = g.apply("CorrectLayout") g = g.apply("CorrectLayout")
g = graph_attr.set_dtype_inputs(g, "float32") g = graph_attr.set_dtype_inputs(g, "float32")
g = g.apply(["InferShape", "InferType"]) g = g.apply(["InferShape", "InferType"])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment