# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unittest cases for simplify batch_norm"""
import nnvm
from nnvm import symbol as sym
from nnvm.compiler import graph_util, graph_attr

def test_simplify_batchnorm():
    """Check that SimplifyInference lowers batch_norm into elementwise ops.

    For several input ranks, batch-norm axes and chain lengths, a chain of
    ``batch_norm`` + ``dropout`` layers is passed through the ``InferShape``
    and ``SimplifyInference`` graph passes. The result is compared, using
    ``graph_util.check_graph_equal``, against a hand-built graph that
    expresses the same normalization with elementwise symbols only.
    """
    def simple_bn(x, gamma, beta, moving_mean, moving_var,
                  axis=1, epsilon=1e-5, shape=None):
        """Build the expected simplified batch_norm graph for ``x``.

        The graph is ``x * scale + shift`` with
        ``scale = (1 / sqrt(moving_var + epsilon)) * gamma`` and
        ``shift = -moving_mean * scale + beta``. This equals
        ``(x - moving_mean) / sqrt(moving_var + epsilon) * gamma + beta``.

        ``shape`` is the full shape of ``x``. It is required, despite its
        ``None`` default, because it decides how many trailing unit axes
        ``scale`` and ``shift`` receive.

        Raises:
            ValueError: if ``shape`` is not given.
        """
        if shape is None:
            raise ValueError("simple_bn requires the shape of x")
        scale = sym.elemwise_mul(1 / sym.sqrt(moving_var + epsilon), gamma)
        shift = sym.elemwise_add(
            sym.elemwise_mul(sym.negative(moving_mean), scale), beta)
        # One unit axis is appended for every dimension of x that follows
        # `axis`, so that scale/shift line up with `axis` when broadcast
        # against x. None are added when `axis` is the last dimension.
        num_newaxis = len(shape) - axis - 1
        if num_newaxis:
            # NOTE(review): assumes scale/shift are 1-D per-channel symbols,
            # so inserting at position 1 appends the new axes after the
            # channel dimension regardless of the batch_norm `axis`.
            scale = sym.expand_dims(scale, axis=1, num_newaxis=num_newaxis)
            shift = sym.expand_dims(shift, axis=1, num_newaxis=num_newaxis)
        return x * scale + shift

    def check(dim, axis, nstep):
        """Compare the SimplifyInference output against ``simple_bn``.

        Args:
            dim: rank of input "x"; every extent is 10.
            axis: the axis passed to ``batch_norm``.
            nstep: number of chained batch_norm (+ dropout) layers.
        """
        eps = 0.01
        x = sym.Variable("x") + 1
        beta = sym.Variable("beta")
        gamma = sym.Variable("gamma")
        moving_var = sym.Variable("moving_var")
        moving_mean = sym.Variable("moving_mean")
        # y1: the graph before simplification. y2: the expected graph after it.
        # y2 starts from a separate variable "xx". NOTE(review): this relies
        # on check_graph_equal comparing structure rather than variable
        # names; confirm.
        y1, y2 = x, sym.Variable("xx") + 1
        ishape = {"x": (10,) * dim}
        for _ in range(nstep):
            y1 = sym.batch_norm(
                y1 + 1, gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
            # y2 has no dropout counterpart, so the comparison below expects
            # SimplifyInference to remove dropout as well.
            y1 = sym.dropout(y1)
            y2 = simple_bn(y2 + 1, gamma, beta, moving_mean, moving_var,
                           epsilon=eps, axis=axis, shape=ishape["x"])
        g = nnvm.graph.create(y1)
        g2 = nnvm.graph.create(y2)
        # Only the input graph gets shapes; InferShape runs before simplifying.
        graph_attr.set_shape_inputs(g, ishape)
        g1 = g.apply("InferShape").apply("SimplifyInference")
        # The simplified graph must match the hand-built expectation.
        graph_util.check_graph_equal(g1, g2)

    check(2, 1, 1)
    check(4, 0, 3)
    check(4, 1, 2)

if __name__ == "__main__":
    # Allow running this test directly as a script, without a test runner.
    test_simplify_batchnorm()