# test_simplify_inference.py
"""Unittest cases for simplify batch_norm"""
import nnvm
from nnvm import symbol as sym
from nnvm.compiler import graph_util, graph_attr


def test_simplify_batchnorm():
    """Check that SimplifyInference rewrites batch_norm into scale-and-shift.

    Builds graphs of chained ``batch_norm`` + ``dropout`` ops, runs the
    ``InferShape`` and ``SimplifyInference`` passes on them, and asserts each
    result equals a reference graph built by hand with ``simple_bn``.
    """
    def simple_bn(x, gamma, beta, moving_mean, moving_var,
                  axis=1, epsilon=1e-5, shape=None):
        """Build the expected inference-time expansion of ``batch_norm``.

        Returns ``x * scale + shift`` where
        ``scale = gamma / sqrt(moving_var + epsilon)`` and
        ``shift = beta - moving_mean * scale``.

        Parameters
        ----------
        x, gamma, beta, moving_mean, moving_var : Symbol
            Input data and batch-norm parameters.
        axis : int
            Axis of ``x`` the parameters apply to.
        epsilon : float
            Added to ``moving_var`` before the square root.
        shape : sequence of int
            Shape of ``x``; only its length (the rank) is used. Required;
            the ``None`` default exists only for signature compatibility.

        Raises
        ------
        TypeError
            If ``shape`` is not given.
        """
        if shape is None:
            raise TypeError("simple_bn() requires the shape of x")
        # expect = (x - moving_mean) / sym.sqrt(moving_var + eps) * gamma + beta
        scale = sym.elemwise_mul(1 / sym.sqrt(moving_var + epsilon), gamma)
        shift = sym.elemwise_add(
            sym.elemwise_mul(sym.negative(moving_mean), scale), beta)
        # To line up with x, scale/shift need one trailing unit axis for
        # every dimension of x after `axis` (none when `axis` is the last).
        num_newaxis = len(shape) - axis - 1
        if num_newaxis:
            # axis=1 inserts the new axes after the parameters' first
            # dimension -- independent of the batch_norm `axis`.
            scale = sym.expand_dims(scale, axis=1, num_newaxis=num_newaxis)
            shift = sym.expand_dims(shift, axis=1, num_newaxis=num_newaxis)
        return x * scale + shift

    def check(dim, axis, nstep):
        """Compare SimplifyInference output with ``simple_bn`` for one config.

        Parameters
        ----------
        dim : int
            Rank of the input ``x``; every dimension has size 10.
        axis : int
            Axis passed to both ``batch_norm`` and ``simple_bn``.
        nstep : int
            Number of chained batch_norm layers to build.
        """
        eps = 0.01
        x = sym.Variable("x") + 1
        beta = sym.Variable("beta")
        gamma = sym.Variable("gamma")
        moving_var = sym.Variable("moving_var")
        moving_mean = sym.Variable("moving_mean")
        # y1: graph to be simplified; y2: hand-built expected graph.
        y1, y2 = x, sym.Variable("xx") + 1
        ishape = {"x": (10,) * dim}
        for _ in range(nstep):
            y1 = sym.batch_norm(
                y1 + 1, gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
            # The reference graph has no dropout, so the pass is expected
            # to remove it as well.
            y1 = sym.dropout(y1)
            y2 = simple_bn(y2 + 1, gamma, beta, moving_mean, moving_var,
                           epsilon=eps, axis=axis, shape=ishape["x"])
        g = nnvm.graph.create(y1)
        g2 = nnvm.graph.create(y2)
        # Attach input shapes and run InferShape before simplifying; the
        # expected expansion depends on the input rank.
        graph_attr.set_shape_inputs(g, ishape)
        g1 = g.apply("InferShape").apply("SimplifyInference")
        # assert graph equals as expected
        graph_util.check_graph_equal(g1, g2)

    check(2, 1, 1)
    check(4, 0, 3)
    check(4, 1, 2)

if __name__ == "__main__":
    # Allow running this file directly instead of via a test runner.
    test_simplify_batchnorm()