"""Unittest cases for AlterOpLayout pass"""
from nnvm import symbol as sym
from nnvm.compiler import graph_attr
from nnvm.top import registry as reg
import nnvm.graph as graph


def get_layouts(g):
    """Map each node name in ``g`` to the layouts of its output entries.

    The graph's ``"layout"`` JSON attribute holds one layout per output
    entry; ``g.index.entry_ptr`` gives, for node ``i``, the half-open range
    ``[entry_ptr[i], entry_ptr[i + 1])`` of its entries, so slicing with it
    yields that node's own output layouts.

    Parameters
    ----------
    g : nnvm.graph.Graph
        Graph carrying a ``"layout"`` attribute. In this file it is called
        after ``CorrectLayout`` / ``AlterOpLayout`` have been applied.

    Returns
    -------
    dict
        ``{node_name: layouts_of_that_node's_outputs}``. If several nodes
        share a name, the one visited last overwrites the others, so any
        node a test wants to inspect must have a unique name.
    """
    ldict = {}
    vlayout = g.json_attr("layout")
    entry_ptr = g.index.entry_ptr
    for i, n in enumerate(g.index.nodes):
        # entry_ptr[i]..entry_ptr[i + 1] spans node i's output entries.
        begin, end = entry_ptr[i], entry_ptr[i + 1]
        ldict[n["name"]] = vlayout[begin:end]
    return ldict


def test_alter_conv2d_layout():
    """Check that AlterOpLayout rewrites conv2d without disturbing other layouts.

    Builds conv2d -> split -> relu (one per split part) -> concatenate ->
    flatten -> softmax and records every node's layout. It then registers an
    alter-op-layout hook that replaces ``conv2d`` with a conv2d named
    ``conv_alter`` using the ``NCHW16c`` layout, and applies
    ``AlterOpLayout``. Afterwards every other inspected node must keep its
    original layout, and ``conv_alter`` must report the same layouts the
    original ``conv`` node had.
    """
    data = sym.Variable("data", shape=(1, 32, 512, 512))
    conv = sym.conv2d(data, name="conv", channels=16,
                      kernel_size=(3, 3), padding=(1, 1),
                      use_bias=False, layout="NCHW")

    # Split the conv output; each part gets its own relu.
    num_splits = 2
    convs = sym.split(conv, indices_or_sections=num_splits)
    # Unique names matter: get_layouts() keys by node name, so two nodes both
    # called "relu" would hide one of them from the layout check below.
    relu_names = ["relu%d" % i for i in range(num_splits)]
    relus = [sym.relu(x, name=name) for x, name in zip(convs, relu_names)]
    relu = sym.concatenate(*relus)
    flatten = sym.flatten(relu, name="flatten")
    softmax = sym.softmax(flatten, name="softmax")
    g = graph.create(softmax)

    g = g.apply("CorrectLayout")
    g = graph_attr.set_dtype_inputs(g, "float32")
    g = g.apply(["InferShape", "InferType"])
    layouts_origin = get_layouts(g)

    # NOTE(review): this registers through the global op registry and, as far
    # as this file shows, is never undone, so it presumably stays active for
    # any later test run in the same process -- confirm.
    @reg.register_alter_op_layout("conv2d", level=100)
    def alter_conv2d_layout(attrs, inputs, tinfos):
        """Return a conv2d copying ``attrs``, switched to NCHW16c and renamed.

        ``tinfos`` is not used here.
        """
        new_attrs = {k: attrs[k] for k in attrs.keys()}
        new_attrs["layout"] = "NCHW16c"
        # NOTE(review): kernel_layout reuses the data-layout string; confirm
        # this is the intended kernel layout for the rewritten conv.
        new_attrs["kernel_layout"] = "NCHW16c"
        new_attrs["name"] = "conv_alter"
        return sym.conv2d(inputs[0], inputs[1], **new_attrs)

    g = g.apply("AlterOpLayout")
    layouts = get_layouts(g)

    # Nodes other than the rewritten conv must keep their original layouts.
    # "conv_weight" is presumably the weight variable auto-created for "conv".
    for node in ["data", "flatten", "softmax", "conv_weight"] + relu_names:
        assert layouts[node] == layouts_origin[node], node
    assert layouts["conv_alter"] == layouts_origin["conv"]


if __name__ == "__main__":
    test_alter_conv2d_layout()