# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.testing import run_infer_type, create_workload


def run_opt_pass(expr, opt_pass):
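    """Apply opt_pass to expr inside a fresh IRModule and return the updated
    main function, or its body if expr was not already a relay.Function."""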
    assert isinstance(opt_pass, tvm.transform.Pass)

    mod = tvm.IRModule.from_expr(expr)
    mod = opt_pass(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body


def test_fold_const():
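    """The fully constant subexpression (c + c) * 2 folds into one constant,
    while the additions involving x are left untouched."""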
    c_data = np.array([1, 2, 3]).astype("float32")
    t = relay.TensorType([1, 2, 3], "float32")
    def before():
        c = relay.const(c_data)
        x = relay.var("x", t)
        y = relay.add(c, c)
        y = relay.multiply(y, relay.const(2, "float32"))
        y = relay.add(x, y)
        z = relay.add(y, c)
        return relay.Function([x], z)

    def expected():
        x = relay.var("x", t)
        c_folded = (c_data + c_data) * 2
        y = relay.add(x, relay.const(c_folded))
        z = relay.add(y, relay.const(c_data))
        return relay.Function([x], z)

    def fail(x):
        raise RuntimeError("lowering passes should not run during FoldConstant")

    # FoldConstant should work in any target context and must not trigger
    # lowering; registering `fail` as a lowering pass makes the test error
    # out if lowering is ever invoked.
    with tvm.target.build_config(add_lower_pass=[(0, fail)]):
        with tvm.target.create("cuda"):
            zz = run_opt_pass(before(), transform.FoldConstant())
    zexpected = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_let():
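    """Constant let-bindings t1 and t2 are folded away, leaving a single
    binding that adds the precomputed constant to x."""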
    c_data = np.array(1).astype("float32")
    t = relay.TensorType([1], "float32")
    def before():
        sb = relay.ScopeBuilder()
        x = relay.var("x", t)
        t1 = sb.let("t1", relay.const(c_data))
        t2 = sb.let("t2", relay.add(t1, t1))
        t3 = sb.let("t3", relay.add(t2, x))
        sb.ret(t3)
        return relay.Function([x], sb.get())

    def expected():
        sb = relay.ScopeBuilder()
        x = relay.var("x", t)
        c_folded = (c_data + c_data)
        t3 = sb.let("t3", relay.add(relay.const(c_folded), x))
        sb.ret(t3)
        return relay.Function([x], sb.get())

    zz = run_opt_pass(before(), transform.FoldConstant())
    zexpected = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_tuple():
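    """TupleGetItem projections with constant-foldable operands are folded,
    so the tuple construction disappears from the result."""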
    c_data = np.array(1).astype("float32")
    t = relay.TensorType([1], "float32")
    def before():
        c = relay.const(c_data)
        x = relay.var("x", t)
        y = relay.Tuple([x, c])
        z = relay.add(relay.TupleGetItem(y, 1), c)
        z = relay.add(z, relay.TupleGetItem(y, 0))
        return relay.Function([x], z)

    def expected():
        c = relay.const(c_data + c_data)
        x = relay.var("x", t)
        z = relay.add(c, x)
        return relay.Function([x], z)

    zz = run_opt_pass(before(), transform.FoldConstant())
    zexpected = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_concat():
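    """concatenate over all-constant operands folds into a single constant."""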
    c_data = np.array([[1, 2, 3]]).astype("float32")

    def before():
        a = relay.const(c_data)
        b = relay.const(c_data)
        y = relay.concatenate((a, b), axis=0)
        return relay.Function([], y)

    def expected():
        y_data = np.concatenate((c_data, c_data), axis=0)
        y = relay.const(y_data)
        return relay.Function([], y)

    zz = run_opt_pass(before(), transform.FoldConstant())
    zexpected = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_shape_of():
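    """shape_of folds to a constant whenever the input's shape is fully
    static, even though the inputs themselves are variables."""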
    c_shape = (8, 9, 10)
    def before(dtype):
        x = relay.var("x", shape=c_shape, dtype="float32")
        y = relay.var("y", shape=c_shape, dtype="float32")
        z = relay.shape_of(x + y, dtype)
        return relay.Function([x, y], z)

    def expected(dtype):
        x = relay.var("x", shape=c_shape, dtype="float32")
        y = relay.var("y", shape=c_shape, dtype="float32")
        z = relay.const(np.array(c_shape).astype(dtype), dtype=dtype)
        func = relay.Function([x, y], z)
        return func

    for dtype in ["int32", "float32"]:
        zz = run_opt_pass(before(dtype), transform.FoldConstant())
        zexpected = run_opt_pass(expected(dtype), transform.InferType())
        assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_full():
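    """full is left unfolded even though its fill value is constant
    (presumably to avoid materializing a large tensor); expect no changes."""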
    c_shape = (8, 9, 10)
    def before():
        dtype = "float32"
        return relay.full(relay.const(1.0, dtype), c_shape, dtype=dtype)

    def expected():
        # expect no changes
        return before()

    zz = run_opt_pass(before(), transform.FoldConstant())
    zexpected = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(zz, zexpected)


def test_fold_batch_norm():
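    """conv2d + batch_norm with constant statistics simplifies to a conv2d
    followed by a single constant bias add."""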
    def expected():
        data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
        weight = relay.const(np.zeros((16, 3, 3, 3)))
        bias = relay.const(np.zeros((16, 1, 1)))
        conv = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
                               channels=16, padding=(1, 1))
        add = relay.add(conv, bias)
        return relay.Function(relay.analysis.free_vars(add), add)

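    # SimplifyInference expands batch_norm into elementwise arithmetic;
    # FoldConstant and FoldScaleAxis then fold the constant statistics into
    # the conv weight and one bias add, matching expected() above.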
    remove_bn_pass = tvm.transform.Sequential([
        relay.transform.InferType(),
        relay.transform.SimplifyInference(),
        relay.transform.FoldConstant(),
        relay.transform.FoldScaleAxis(),
    ])

    data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
    weight = relay.var("weight")
    bn_gamma = relay.var("bn_gamma")
    bn_beta = relay.var("bn_beta")
    bn_mmean = relay.var("bn_mean")
    bn_mvar = relay.var("bn_var")

    conv = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
                           channels=16, padding=(1, 1))
    bn_output = relay.nn.batch_norm(conv, bn_gamma, bn_beta,
                                    bn_mmean, bn_mvar)
    def initializer(_, param):
        # fill the parameter in place; rebinding the local name would leave
        # the array passed in by create_workload unchanged
        param[:] = np.zeros(param.shape)

    mod, params = create_workload(bn_output[0], initializer)
    mod["main"] = bind_params_by_name(mod["main"], params)

    with relay.build_config(opt_level=3):
        mod = remove_bn_pass(mod)

    expect = run_infer_type(expected())
    assert tvm.ir.structural_equal(mod["main"], expect)


if __name__ == "__main__":
    test_fold_const()
    test_fold_let()
    test_fold_tuple()
    test_fold_concat()
    test_fold_shape_of()
    test_fold_full()
    test_fold_batch_norm()