# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test legalize pass"""
import numpy as np
import tvm
from tvm import te

from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay import transform, analysis
from tvm.relay.testing.temp_op_attr import TempOpAttr


def run_opt_pass(expr, passes):
    """Run one or more Relay passes over ``expr`` and return the result.

    Parameters
    ----------
    expr : relay.Expr
        The expression (or ``relay.Function``) to transform. It is wrapped in
        an ``IRModule`` so the passes can operate on it.
    passes : tvm.transform.Pass or list of tvm.transform.Pass
        A single pass or a list of passes, applied in order through
        ``transform.Sequential`` under ``PassContext(opt_level=3)``.

    Returns
    -------
    relay.Expr
        The transformed ``main`` function if ``expr`` was a
        ``relay.Function``; otherwise the body of that function.
    """
    passes = passes if isinstance(passes, list) else [passes]
    mod = tvm.IRModule.from_expr(expr)
    seq = transform.Sequential(passes)
    with transform.PassContext(opt_level=3):
        mod = seq(mod)
    entry = mod["main"]
    # The expression ends up as module "main"; unwrap the function again for
    # non-function inputs so callers get back the same kind of node.
    return entry if isinstance(expr, relay.Function) else entry.body

def test_legalize():
    """Test directly replacing an operator (nn.conv2d) with a new one via FTVMLegalize"""
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var('weight', shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(x, weight,
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    def legalize_conv2d(attrs, inputs, types):
        # Rewrite conv2d(data, weight) as conv2d(data, weight * 2.0), reusing
        # the original call's attributes unchanged.
        data, weight = inputs
        weight = relay.multiply(weight, relay.const(2.0, "float32"))
        return relay.nn.conv2d(data, weight, **attrs)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var('weight', shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")),
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    # The hook is only registered while inside this context.
    with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
        a = before()
        a = run_opt_pass(a, transform.Legalize())
        # Run InferType so `b` is type-annotated like the Legalize output.
        b = run_opt_pass(expected(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)

def test_legalize_none():
    """Test that an FTVMLegalize hook returning None leaves the graph unchanged"""
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.nn.global_max_pool2d(x)
        y = relay.Function([x], y)
        return y

    # Mutable cell so the nested hook can record that it was invoked.
    called = [False]

    def legalize_global_max_pool2d(attrs, inputs, types):
        called[0] = True
        # Returning None asks Legalize to keep the original call as-is.
        return None

    with TempOpAttr("nn.global_max_pool2d", "FTVMLegalize", legalize_global_max_pool2d):
        a = before()
        a = run_opt_pass(a, transform.Legalize())
        b = run_opt_pass(before(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
    assert called[0], "FTVMLegalize hook for nn.global_max_pool2d was never called"

def test_legalize_multiple_ops():
    """Test legalizing two different operators (nn.conv2d and nn.relu) in one graph"""
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var('weight', shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(x, weight,
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    def legalize_conv2d(attrs, inputs, types):
        # Rewrite conv2d(data, weight) as conv2d(data, weight * 2.0).
        data, weight = inputs
        weight = relay.multiply(weight, relay.const(2.0, "float32"))
        return relay.nn.conv2d(data, weight, **attrs)

    def legalize_relu(attrs, inputs, types):
        # Rewrite relu(x) as relu(0 + x).
        data = inputs[0]
        add = relay.add(relay.const(0, "float32"), data)
        return relay.nn.relu(add)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var('weight', shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")),
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.add(relay.const(0, "float32"), y)
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    # Both hooks must be registered at once so a single Legalize run sees them.
    with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
        with TempOpAttr("nn.relu", "FTVMLegalize", legalize_relu):
            a = before()
            a = run_opt_pass(a, transform.Legalize())
            b = run_opt_pass(expected(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


def test_legalize_multi_input():
    """Test that the FTVMLegalize hook of a multi-input op (concatenate) receives a single Tuple input"""
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.var("y", shape=(1, 64, 56, 20))
        z = relay.var("z", shape=(1, 64, 56, 10))
        func = relay.concatenate([x, y, z], axis=3)
        func = relay.Function([x, y, z], func)
        return func

    # Mutable cell so the nested hook can record that it was invoked.
    called = [False]

    def legalize_concatenate(attrs, inputs, types):
        called[0] = True
        # Check that the correct multi-input case is handled: the operands
        # arrive packed in one Tuple, and `types` holds a TupleType followed
        # by a TensorType (presumably the input type and the output type).
        assert len(inputs) == 1
        assert isinstance(inputs[0], tvm.relay.expr.Tuple)
        assert len(types) == 2
        assert isinstance(types[0], tvm.relay.ty.TupleType)
        assert isinstance(types[1], tvm.relay.ty.TensorType)
        # Returning None leaves the original concatenate in place.
        return None

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.var("y", shape=(1, 64, 56, 20))
        z = relay.var("z", shape=(1, 64, 56, 10))
        func = relay.concatenate([x, y, z], axis=3)
        func = relay.Function([x, y, z], func)
        return func

    with TempOpAttr("concatenate", "FTVMLegalize", legalize_concatenate):
        a = before()
        a = run_opt_pass(a, transform.Legalize())
        b = run_opt_pass(expected(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
    # Without this, the input/type assertions inside the hook would pass
    # vacuously if Legalize never invoked it.
    assert called[0], "FTVMLegalize hook for concatenate was never called"


if __name__ == "__main__":
    # Allow running the tests directly without pytest.
    test_legalize()
    test_legalize_none()
    test_legalize_multiple_ops()
    test_legalize_multi_input()