Commit 343eb82c by Yizhi Liu Committed by Tianqi Chen

fix restore layout in AlterOpLayout (#460)

* fix restore layout in AlterOpLayout

* lint test case
parent 1e4bb2f8
......@@ -119,8 +119,9 @@ Graph AlterOpLayout(const Graph& src) {
if (new_nodes.count(inode.source)) {
const std::vector<Layout>& in_layouts =
in_layouts_of_node[new_nodes[inode.source]];
for (const auto& e : inode.inputs) {
ret_layouts[ret_idx.entry_id(e)] = in_layouts[e.index];
for (uint32_t i = 0; i < inode.inputs.size(); ++i) {
const auto& e = inode.inputs[i];
ret_layouts[ret_idx.entry_id(e)] = in_layouts[i];
}
const std::vector<Layout>& out_layouts =
out_layouts_of_node[new_nodes[inode.source]];
......
......@@ -19,10 +19,14 @@ def test_alter_conv2d_layout():
conv = sym.conv2d(data, name="conv", channels=16,
kernel_size=(3,3), padding=(1,1),
use_bias=False, layout="NCHW")
relu = sym.relu(conv, name="relu")
# split here
convs = sym.split(conv, indices_or_sections=2)
relus = [sym.relu(x, name="relu") for x in convs]
relu = sym.concatenate(*relus)
flatten = sym.flatten(relu, name="flatten")
softmax = sym.softmax(flatten, name="softmax")
g = graph.create(softmax)
g = g.apply("CorrectLayout")
g = graph_attr.set_dtype_inputs(g, "float32")
g = g.apply(["InferShape", "InferType"])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment