# test_tir_pass_hoist_if.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te


def verify_structure(stmt, expected_struct):
    """Assert that the control-flow skeleton of ``stmt`` equals ``expected_struct``.

    Every ``IfThenElse``, ``For`` and ``AttrStmt`` node reachable from
    ``stmt`` is summarised by a hashable signature:

    * ``("IfThenElse", names)`` where ``names`` is the tuple of variable
      names met while post-order visiting the condition (duplicates kept);
    * ``("For", loop_var_name)``;
    * ``("AttrStmt", attr_key, int(value))``.

    The extracted structure maps each signature to a tuple with one entry
    per child statement (``then_case``/``else_case`` for an if, ``body``
    otherwise): the child's signature, or ``None`` when the child is not
    one of the tracked node kinds. Nodes sharing a signature overwrite one
    another in this mapping (later visit wins), so ``expected_struct``
    must be written against that collapsed view.

    Parameters
    ----------
    stmt : tvm.tir.Stmt
        Statement to inspect, e.g. the output of ``HoistIfThenElse``.
    expected_struct : dict
        Expected ``signature -> tuple of child signatures`` mapping.

    Raises
    ------
    AssertionError
        If the extracted structure differs from ``expected_struct``.
    """
    node_dict = {}

    def _condition_var_names(cond):
        # Fresh list per call (instead of a module-level scratch list), so
        # nothing can leak between conditions, calls, or after an exception.
        names = []

        def _collect(op):
            if isinstance(op, tvm.tir.Var):
                names.append(op.name)

        tvm.tir.ir_pass.PostOrderVisit(cond, _collect)
        return tuple(names)

    def _visit(op):
        # Record (child statements, signature) for tracked node kinds only.
        if isinstance(op, tvm.tir.IfThenElse):
            children = (op.then_case, op.else_case)
            signature = ("IfThenElse", _condition_var_names(op.condition))
        elif isinstance(op, tvm.tir.For):
            children = (op.body,)
            signature = ("For", op.loop_var.name)
        elif isinstance(op, tvm.tir.AttrStmt):
            children = (op.body,)
            signature = ("AttrStmt", op.attr_key, int(op.value))
        else:
            return
        node_dict[op] = (children, signature)

    tvm.tir.ir_pass.PostOrderVisit(stmt, _visit)

    struct = {}
    for children, signature in node_dict.values():
        # Untracked children (and a ``None`` child, e.g. an absent else
        # branch) are reported as None.
        struct[signature] = tuple(
            node_dict[child][1] if child in node_dict else None
            for child in children
        )

    assert struct == expected_struct, (
        "Structure mismatch: expect %s but got %s" % (expected_struct, struct))

def test_basic():
    """Hoist an ``if likely(i < 2)`` / else out of the ``j`` and ``k`` loops.

    The condition only reads ``i``, and the expected structure places the
    IfThenElse directly under ``For i`` with a ``For j`` nest in both the
    then and else branches.
    """
    ib = tvm.tir.ir_builder.create()
    l = te.var('l')
    m = te.var('m')
    n = te.var('n')

    with ib.for_range(0, l, "i") as i:
        with ib.for_range(0, m, "j"):
            with ib.for_range(0, n, "k"):
                with ib.if_scope(ib.likely(i < 2)):
                    ib.emit(tvm.tir.Evaluate(m))
                with ib.else_scope():
                    ib.emit(tvm.tir.Evaluate(n))

    stmt = ib.get()
    new_stmt = tvm.tir.ir_pass.HoistIfThenElse(stmt)
    expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),),
                       ('IfThenElse', ('i',)): (('For', 'j'), ('For', 'j')),
                       ('For', 'i'): (('IfThenElse', ('i',)),)}
    verify_structure(new_stmt, expected_struct)

def test_no_else():
    """Hoist an ``if likely(i < 2)`` that has no else branch.

    Same loop nest as ``test_basic``. The expected structure places the
    IfThenElse directly under ``For i`` with the ``For j`` nest as the then
    branch and ``None`` for the missing else branch.
    """
    ib = tvm.tir.ir_builder.create()
    l = te.var('l')
    m = te.var('m')
    n = te.var('n')

    with ib.for_range(0, l, "i") as i:
        with ib.for_range(0, m, "j"):
            with ib.for_range(0, n, "k"):
                with ib.if_scope(ib.likely(i < 2)):
                    ib.emit(tvm.tir.Evaluate(m))

    stmt = ib.get()
    new_stmt = tvm.tir.ir_pass.HoistIfThenElse(stmt)
    expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),),
                       ('IfThenElse', ('i',)): (('For', 'j'), None),
                       ('For', 'i'): (('IfThenElse', ('i',)),)}
    verify_structure(new_stmt, expected_struct)

def test_attr_stmt():
    """Hoist an if/else under two ``thread_extent`` AttrStmts.

    The condition ``any(i < 4, j >= 8)`` reads ``i`` and ``j``. The expected
    structure places the IfThenElse directly under ``For j``, with ``For k``
    in both branches. The two AttrStmts (extents 32 and 64) remain
    outermost, wrapping ``For i``.
    """
    ib = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    data = ib.pointer("float32", name="data")
    l = te.var('l')
    m = te.var('m')
    n = te.var('n')

    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", dshape[0])
    ib.scope_attr(bx, "thread_extent", dshape[1])
    with ib.for_range(0, l, "i") as i:
        with ib.for_range(0, m, "j") as j:
            with ib.for_range(0, n, "k") as k:
                with ib.if_scope(tvm.tir.any(i < 4, j >= 8)):
                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.5
                with ib.else_scope():
                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.0

    stmt = ib.get()
    new_stmt = tvm.tir.ir_pass.HoistIfThenElse(stmt)
    expected_struct = {('For', 'k'): (None,),
                       ('IfThenElse', ('i', 'j')): (('For', 'k'), ('For', 'k')),
                       ('For', 'j'): (('IfThenElse', ('i', 'j')),),
                       ('For', 'i'): (('For', 'j'),),
                       ('AttrStmt', 'thread_extent', 64): (('For', 'i'),),
                       ('AttrStmt', 'thread_extent', 32): (('AttrStmt', 'thread_extent', 64),)}
    verify_structure(new_stmt, expected_struct)

def test_nested_for():
    """Hoisting with an outer if that guards a store plus a nested loop nest.

    The expected structure has the outer ``i >= 3`` check directly under
    ``For i`` with ``For j`` as its then branch. The inner
    ``any(i < 4, j >= 8)`` if/else stays innermost, under ``For l``.
    """
    ib = tvm.tir.ir_builder.create()
    data = ib.pointer("float32", name="data")

    with ib.for_range(0, 5, "i") as i:
        with ib.for_range(0, 10, "j") as j:
            with ib.if_scope(i >= 3):
                data[i * 3 + j] = data[i * 3 + j] + 0.5
                with ib.for_range(0, 15, "k") as k:
                    with ib.for_range(0, 20, "l") as l:
                        with ib.if_scope(tvm.tir.any(i < 4, j >= 8)):
                            data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2
                        with ib.else_scope():
                            data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5

    stmt = ib.get()
    new_stmt = tvm.tir.ir_pass.HoistIfThenElse(stmt)
    expected_struct = {('IfThenElse', ('i', 'j')): (None, None),
                       ('For', 'l'): (('IfThenElse', ('i', 'j')),),
                       ('For', 'k'): (('For', 'l'),),
                       ('For', 'j'): (None,),
                       ('IfThenElse', ('i',)): (('For', 'j'), None),
                       ('For', 'i'): (('IfThenElse', ('i',)),)}
    verify_structure(new_stmt, expected_struct)

def test_if_block():
    """Hoisting across two sequential loop nests.

    The first nest is like ``test_nested_for`` but adds an else-less
    ``if j < 5`` after the inner if/else. The second nest has an if whose
    condition ``n >= 3`` reads no loop variable.

    Both nests reuse the loop names ``i``, ``j`` and ``k``. Nodes with equal
    signatures collide in ``verify_structure``'s mapping, so
    ``expected_struct`` describes that collapsed view.
    """
    ib = tvm.tir.ir_builder.create()
    data = ib.pointer("float32", name="data")
    n = te.var("n")

    with ib.for_range(0, 5, "i") as i:
        with ib.for_range(0, 10, "j") as j:
            with ib.if_scope(i >= 3):
                data[i * 3 + j] = data[i * 3 + j] + 0.5
                with ib.for_range(0, 15, "k") as k:
                    with ib.for_range(0, 20, "l") as l:
                        with ib.if_scope(tvm.tir.any(i < 4, j >= 8)):
                            data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2
                        with ib.else_scope():
                            data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5
                        with ib.if_scope(j < 5):
                            data[i * 3 + j + k + l] = data[i * 3 + j + k + l] - 1

    with ib.for_range(0, 5, "i") as i:
        with ib.for_range(0, 10, "j") as j:
            with ib.for_range(0, 15, "k") as k:
                with ib.if_scope(n >= 3):
                    data[i * 3 + j + k] = data[i * 3 + j + k] + 0.6

    stmt = ib.get()
    new_stmt = tvm.tir.ir_pass.HoistIfThenElse(stmt)
    expected_struct = {('IfThenElse', ('i', 'j')): (None, None),
                       ('IfThenElse', ('j',)): (None, None),
                       ('For', 'l'): (None,),
                       ('For', 'k'): (None,),
                       ('For', 'j'): (('For', 'j'),),
                       ('IfThenElse', ('i',)): (('For', 'j'), None),
                       ('For', 'i'): (('IfThenElse', ('i',)),),
                       ('IfThenElse', ('n',)): (('For', 'j'), None)}
    verify_structure(new_stmt, expected_struct)


if __name__ == "__main__":
    # Run every test in declaration order when executed as a script.
    for _case in (test_basic, test_no_else, test_attr_stmt,
                  test_nested_for, test_if_block):
        _case()