"""Unittest cases for fold_axis"""
import tvm
import nnvm
import nnvm.testing.resnet
import numpy as np
from nnvm import symbol as sym
from nnvm.compiler import graph_util, graph_attr

def test_fold_axis_conv():
    """FoldScaleAxis on a plain conv2d.

    A per-input-channel scale before the conv and a per-output-channel
    scale after the relu should both be folded into the conv weight
    (and the output scale into the bias), leaving no standalone
    multiplies in the optimized graph.
    """
    # Graph before the FoldScaleAxis pass: explicit scale multiplies
    # surround the conv + relu.
    def before(x, conv_weight, conv_bias, in_scale, out_scale, channels):
        x = x * sym.expand_dims(in_scale, axis=1, num_newaxis=2)
        y = sym.conv2d(x, conv_weight, conv_bias,
                       channels=channels,
                       kernel_size=(3, 3),
                       padding=(1, 1),
                       name="conv")
        y = sym.relu(y)
        y = y * sym.expand_dims(out_scale, axis=1, num_newaxis=2)
        return y

    # Expected graph after folding: out_scale is folded into the weight
    # along the output-channel axis (3 new axes) and into the bias;
    # in_scale is folded along the input-channel axis (2 new axes).
    def expected(x, conv_weight, conv_bias, in_scale, out_scale, channels):
        conv_weight = conv_weight * sym.expand_dims(out_scale, axis=1, num_newaxis=3)
        conv_weight = conv_weight * sym.expand_dims(in_scale, axis=1, num_newaxis=2)
        conv_bias = conv_bias * out_scale
        y = sym.conv2d(x,
                       conv_weight,
                       conv_bias,
                       channels=channels,
                       kernel_size=(3, 3),
                       padding=(1, 1),
                       name="conv")
        y = sym.relu(y)
        return y

    def check(shape, channels):
        x = sym.Variable("x") + 1
        weight = sym.Variable("weight")
        bias = sym.Variable("bias")
        in_scale = sym.Variable("in_scale")
        out_scale = sym.Variable("out_scale")
        y1 = before(x, weight, bias, in_scale, out_scale, channels)
        y2 = expected(x, weight, bias, in_scale, out_scale, channels)
        ishape = {"x": shape, "out_scale": (channels,), "in_scale": (shape[1],)}
        g1 = nnvm.graph.create(y1)
        g2 = nnvm.graph.create(y2)
        graph_attr.set_shape_inputs(g1, ishape)
        g1 = g1.apply("InferShape").apply("FoldScaleAxis")
        # assert graph equals as expected
        graph_util.check_graph_equal(g1, g2)

    check((2, 4, 10, 10), 2)

def test_fold_axis_depthwise_conv():
    """FoldScaleAxis on a depthwise conv2d (groups == channels == 54).

    Same scenario as the plain-conv test, but for a depthwise conv the
    input-channel scale is folded along the same leading axis as the
    output scale (3 new axes instead of 2), because each group owns a
    single channel.
    """
    # Graph before the FoldScaleAxis pass.
    # NOTE(review): "depthiwise_conv" is a typo in the op name; it is a
    # runtime string shared by both graphs, so it is preserved as-is.
    def before(x, conv_weight, conv_bias, in_scale, out_scale, channels):
        x = x * sym.expand_dims(in_scale, axis=1, num_newaxis=2)
        y = sym.conv2d(x, conv_weight, conv_bias,
                       channels=channels,
                       kernel_size=(3, 3),
                       padding=(1, 1),
                       groups=54,
                       name="depthiwise_conv")
        y = sym.relu(y)
        y = y * sym.expand_dims(out_scale, axis=1, num_newaxis=2)
        return y

    # Expected graph after folding: both scales fold into the weight on
    # axis 1 with 3 new axes; out_scale also folds into the bias.
    def expected(x, conv_weight, conv_bias, in_scale, out_scale, channels):
        conv_weight = conv_weight * sym.expand_dims(out_scale, axis=1, num_newaxis=3)
        conv_weight = conv_weight * sym.expand_dims(in_scale, axis=1, num_newaxis=3)
        conv_bias = conv_bias * out_scale
        y = sym.conv2d(x,
                       conv_weight,
                       conv_bias,
                       channels=channels,
                       kernel_size=(3, 3),
                       padding=(1, 1),
                       groups=54,
                       name="depthiwise_conv")
        y = sym.relu(y)
        return y

    def check(shape, channels):
        x = sym.Variable("x") + 1
        weight = sym.Variable("weight")
        bias = sym.Variable("bias")
        in_scale = sym.Variable("in_scale")
        out_scale = sym.Variable("out_scale")
        y1 = before(x, weight, bias, in_scale, out_scale, channels)
        y2 = expected(x, weight, bias, in_scale, out_scale, channels)
        ishape = {"x": shape, "out_scale": (channels,), "in_scale": (shape[1],)}
        g1 = nnvm.graph.create(y1)
        g2 = nnvm.graph.create(y2)
        graph_attr.set_shape_inputs(g1, ishape)
        g1 = g1.apply("InferShape").apply("FoldScaleAxis")
        # assert graph equals as expected
        graph_util.check_graph_equal(g1, g2)

    check((1, 54, 63, 127), 54)
def test_fold_fail():
    """FoldScaleAxis must be a no-op when the scale is not on the channel axis.

    The scale here is expanded with only one new axis, so with a 4-D
    (N, C, H, W) output it broadcasts over the trailing spatial axes
    rather than the channel axis — the pass cannot fold it, and the
    optimized graph must equal the input graph.
    """
    # Graph with a scale multiply that is NOT foldable into the conv.
    def before(x, scale, channels):
        y = sym.conv2d(x,
                       channels=channels,
                       kernel_size=(3, 3),
                       padding=(1, 1),
                       name="conv")
        y = y * sym.expand_dims(scale, axis=1, num_newaxis=1)
        return y

    def check(shape, channels):
        x = sym.Variable("x")
        bias = sym.Variable("bias")
        scale = sym.Variable("scale")
        y1 = before(x, scale, channels)
        ishape = {"x": shape, "scale": (channels,), "bias": (channels,)}
        g1 = nnvm.graph.create(y1)
        graph_attr.set_shape_inputs(g1, ishape)
        g2 = g1.apply("InferShape").apply("FoldScaleAxis")
        # g1 is the pre-pass graph; equality means the pass changed nothing
        graph_util.check_graph_equal(g1, g2)

    check((2, 10, 10, 10), 10)


def test_fold_resnet():
    """End-to-end check: scale-axis folding must not change ResNet outputs.

    Runs the same ResNet workload through the compiler pipeline at
    opt_level 0 (no folding) and opt_level 3 (with folding) and asserts
    the numeric outputs agree.
    """
    batch_size = 1
    image_shape = (3, 224, 224)
    data_shape = (batch_size,) + image_shape
    net, params = nnvm.testing.resnet.get_workload(
        batch_size=batch_size, image_shape=image_shape)
    ishape = {"data": data_shape}
    graph = nnvm.graph.create(net)
    data = np.random.uniform(size=data_shape).astype("float32")
    # Initial pass do shape type inference
    shape, _ = graph_util.infer_shape(graph, **ishape)
    ishape.update(zip(graph.index.input_names, shape))

    def run_prune(graph, params, opt_level):
        # Apply optimization at the requested level.
        # BUG FIX: opt_level was hard-coded to 0 here, so both runs were
        # identical and the comparison below was vacuous.
        with nnvm.compiler.build_config(opt_level=opt_level):
            graph = nnvm.compiler.optimize(graph, ishape)
        graph, params = nnvm.compiler.build_module.precompute_prune(graph, params)
        params["data"] = data
        return nnvm.compiler.build_module._run_graph(graph, params)

    x = run_prune(graph, params, 0)
    y = run_prune(graph, params, 3)
    tvm.testing.assert_allclose(y[0].asnumpy(), x[0].asnumpy())


# Run all tests when executed directly (original call order preserved).
if __name__ == "__main__":
    test_fold_resnet()
    test_fold_axis_conv()
    test_fold_fail()
    test_fold_axis_depthwise_conv()