# test_alter_op_layout.py
"""Unittest cases for AlterOpLayout pass"""
from nnvm import symbol as sym
from nnvm.compiler import graph_attr
from nnvm.top import registry as reg
import nnvm.graph as graph

def get_layouts(g):
    """Return a dict mapping each node name in ``g`` to its slice of layouts.

    The graph's ``layout`` JSON attribute is one flat list. Node ``i`` gets
    ``layout[entry_ptr[i]:entry_ptr[i + 1]]``, where ``entry_ptr`` is
    ``g.index.entry_ptr`` (presumably the node's output entries). If several
    nodes share a name, the last one wins.
    """
    all_layouts = g.json_attr("layout")
    offsets = g.index.entry_ptr
    return {
        node["name"]: all_layouts[offsets[idx]:offsets[idx + 1]]
        for idx, node in enumerate(g.index.nodes)
    }


def test_alter_conv2d_layout():
    """Check that AlterOpLayout rewrites conv2d without disturbing other layouts.

    Graph structure:
    - An NCHW conv2d is split into two branches.
    - Each branch goes through relu; the branches are concatenated, flattened
      and fed to softmax.

    An alter-layout hook registered for conv2d replaces the convolution with
    one named "conv_alter" that uses NCHW16c. After the pass:
    - every node that was not replaced must report the same layouts as
      before;
    - "conv_alter" must report the layouts the original "conv" had.
    """
    data = sym.Variable("data", shape=(1, 32, 512, 512))
    conv = sym.conv2d(data, name="conv", channels=16,
                      kernel_size=(3, 3), padding=(1, 1),
                      use_bias=False, layout="NCHW")

    # Split the conv output into two branches and rejoin them afterwards.
    convs = sym.split(conv, indices_or_sections=2)
    # Name the branches distinctly. get_layouts keys its result by node name,
    # so two nodes both named "relu" would collapse into a single entry and
    # only one branch would actually be compared below.
    relus = [sym.relu(x, name="relu%d" % i) for i, x in enumerate(convs)]
    relu = sym.concatenate(*relus)
    flatten = sym.flatten(relu, name="flatten")
    softmax = sym.softmax(flatten, name="softmax")
    g = graph.create(softmax)

    g = g.apply("CorrectLayout")
    g = graph_attr.set_dtype_inputs(g, "float32")
    g = g.apply(["InferShape", "InferType"])
    layouts_origin = get_layouts(g)

    @reg.register_alter_op_layout("conv2d", level=100)
    def alter_conv2d_layout(attrs, inputs, tinfos):
        # Copy every original attribute, then override layout and name.
        new_attrs = {k: attrs[k] for k in attrs.keys()}
        new_attrs["layout"] = "NCHW16c"
        new_attrs["kernel_layout"] = "NCHW16c"
        new_attrs["name"] = "conv_alter"
        return sym.conv2d(inputs[0], inputs[1], **new_attrs)

    g = g.apply("AlterOpLayout")
    layouts = get_layouts(g)

    # Nodes the pass did not replace must keep their original layouts.
    for node in ["data", "relu0", "relu1", "flatten", "softmax", "conv_weight"]:
        assert layouts[node] == layouts_origin[node]
    # The replacement conv reports the same layouts as the conv it replaced.
    assert layouts["conv_alter"] == layouts_origin["conv"]


def test_consecutive_alter_layout():
    """Altering two back-to-back global_avg_pool2d ops loses the layout between them.

    Both pools are rewritten to NCHW16c by the registered hook. The layout
    of the entry connecting them is then expected to become "__undef__";
    every other entry keeps "NCHW".
    """
    inp = sym.Variable("data", shape=(1, 32, 512, 512))
    first_pool = sym.global_avg_pool2d(inp, name="global_avg_pool2d_1", layout="NCHW")
    second_pool = sym.global_avg_pool2d(first_pool, name="global_avg_pool2d_2", layout="NCHW")
    out = sym.relu(second_pool, name="relu")

    g = graph.create(out).apply("CorrectLayout")
    g = graph_attr.set_dtype_inputs(g, "float32").apply(["InferShape", "InferType"])
    assert g.json_attr("layout") == ['NCHW', 'NCHW', 'NCHW', 'NCHW']

    @reg.register_alter_op_layout("global_avg_pool2d", level=100)
    def alter_global_avg_pool2d_layout(attrs, inputs, tinfos):
        # Same attributes as the original pool, but with the NCHW16c layout.
        patched = dict((key, attrs[key]) for key in attrs.keys())
        patched["layout"] = "NCHW16c"
        return sym.global_avg_pool2d(inputs[0], **patched)

    g = g.apply("AlterOpLayout")

    # Both pools are replaced:
    # - the replaced first pool does not record its output layout;
    # - the replaced second pool does not record its input layout.
    # The second entry (between the two pools) therefore cannot be recovered
    # from either side and must come back undefined.
    assert g.json_attr("layout") == ['NCHW', '__undef__', 'NCHW', 'NCHW']


def test_alter_func_return_none():
    """An alter-layout hook that returns None must leave layouts untouched.

    Two NCHW global_max_pool2d ops feed a relu. The hook registered for
    global_max_pool2d returns None, so AlterOpLayout replaces nothing.
    The inferred layouts must be identical before and after the pass.
    """
    data = sym.Variable("data", shape=(1, 32, 512, 512))
    pool1 = sym.global_max_pool2d(data, name="pool1", layout="NCHW")
    pool2 = sym.global_max_pool2d(pool1, name="pool2", layout="NCHW")
    relu = sym.relu(pool2, name="relu")

    g = graph.create(relu)
    g = g.apply("CorrectLayout")
    g = graph_attr.set_dtype_inputs(g, "float32")
    g = g.apply(["InferShape", "InferType"])
    assert g.json_attr("layout") == ['NCHW', 'NCHW', 'NCHW', 'NCHW']

    @reg.register_alter_op_layout("global_max_pool2d", level=100)
    def alter_global_max_pool2d_layout(attrs, inputs, tinfos):
        # Returning None tells the pass to keep the original op.
        return None

    g = g.apply("AlterOpLayout")

    # The alter function returns None, so nothing is replaced and the
    # layouts must stay the same.
    assert g.json_attr("layout") == ['NCHW', 'NCHW', 'NCHW', 'NCHW']


if __name__ == "__main__":
    # Run every test directly when the module is executed as a script.
    test_alter_conv2d_layout()
    test_consecutive_alter_layout()
    test_alter_func_return_none()