# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test gpu code verifier"""
import tvm

def get_verify_pass(valid, **kwargs):
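    """Return a lower pass that runs tvm.ir_pass.VerifyGPUCode on the lowered
    statement and records the boolean result in valid[0].  A one-element list
    is used as a mutable cell so the closure can report back to the caller.
    The keyword arguments are the resource limits to verify against, e.g.
    max_shared_memory_per_block or max_threads_per_block.
    """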
    def verify_pass(stmt):
        valid[0] = tvm.ir_pass.VerifyGPUCode(stmt, kwargs)
        return stmt
    return verify_pass

def test_shared_memory():
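    """Check the shared memory limit: a budget one byte below the actual
    usage must fail verification, while the exact budget must pass."""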
    def check_shared_memory(dtype):
        N = 1024
        M = 128

        tvm_type = tvm.datatype._TVMType(dtype)
        type_size = tvm_type.bits // 8 * tvm_type.lanes
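        # e.g. 'float32' -> 32 bits / 8 * 1 lane  = 4 bytes,
        #      'int8x4'  ->  8 bits / 8 * 4 lanes = 4 bytes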

        A = tvm.placeholder((N,), name='A', dtype=dtype)
        B = tvm.compute((N, ), lambda i: A[i], name='B')

        s = tvm.create_schedule([B.op])
        AA = s.cache_read(A, "shared", [B])
        o, i = s[B].split(s[B].op.axis[0], M)
        s[AA].compute_at(s[B], o)
        s[B].bind(o, tvm.thread_axis("blockIdx.x"))
        s[B].bind(i, tvm.thread_axis("threadIdx.x"))

        # shared memory usage: M * sizeof(dtype) bytes per block
        # thread usage: M threads per block
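        # Each verifier below is registered as an extra lowering pass at
        # phase 2.  A budget one byte below the actual usage must fail
        # verification; the exact budget must pass.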

        for target in ['opencl', 'cuda']:
            if not tvm.context(target).exist:
                continue
            valid = [None]
            with tvm.build_config(add_lower_pass=[
                (2, get_verify_pass(valid,
                                    max_shared_memory_per_block=type_size * M - 1,
                                    max_threads_per_block=M))]):
                tvm.build(s, [A, B], target)
            assert not valid[0]

            with tvm.build_config(add_lower_pass=[
                (2, get_verify_pass(valid,
                                    max_shared_memory_per_block=type_size * M,
                                    max_threads_per_block=M))]):
                tvm.build(s, [A, B], target)
            assert valid[0]
    check_shared_memory('float32')
    check_shared_memory('int8x4')

def test_local_memory():
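    """Check the local (per-thread) memory limit: a budget one byte below
    the actual usage must fail verification, the exact budget must pass."""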
    N = 1024
    M = 128

    A = tvm.placeholder((N,), name='A', dtype='float32')
    B = tvm.compute((N, ), lambda i: A[i], name='B')

    s = tvm.create_schedule([B.op])
    AA = s.cache_read(A, "local", [B])
    o, i = s[B].split(s[B].op.axis[0], M)
    s[AA].compute_at(s[B], o)
    s[B].bind(o, tvm.thread_axis("blockIdx.x"))

    # local memory usage: M * 4 bytes (M float32 values cached per thread)
    # thread usage: 1 (the inner loop is not bound to a thread axis)

    for target in ['opencl', 'cuda']:
        if not tvm.context(target).exist:
            continue

        valid = [None]
        with tvm.build_config(add_lower_pass=[
            (2, get_verify_pass(valid,
                                max_local_memory_per_block=4 * M - 1,
                                max_threads_per_block=1))]):
            tvm.build(s, [A, B], target)
        assert not valid[0]

        with tvm.build_config(add_lower_pass=[
            (2, get_verify_pass(valid,
                                max_local_memory_per_block=4 * M,
                                max_threads_per_block=1))]):
            tvm.build(s, [A, B], target)
        assert valid[0]

def test_num_thread():
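    """Check thread-count limits: the total number of threads per block
    and the per-dimension limit (max_thread_y)."""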
    N = 1024
    M = 128

    A = tvm.placeholder((N,), name='A', dtype='float32')
    B = tvm.compute((N, ), lambda i: A[i], name='B')

    s = tvm.create_schedule([B.op])
    o, i = s[B].split(s[B].op.axis[0], M)

    s[B].bind(o, tvm.thread_axis("threadIdx.x"))
    s[B].bind(i, tvm.thread_axis("threadIdx.y"))

    # shared memory usage: 0
    # thread usage: N ((N // M) threads in threadIdx.x times M in threadIdx.y)

    for target in ['opencl', 'cuda']:
        if not tvm.context(target).exist:
            continue

        valid = [None]
        with tvm.build_config(add_lower_pass=[
            (2, get_verify_pass(valid,
                                max_shared_memory_per_block=0,
                                max_threads_per_block=N - 1))]):
            tvm.build(s, [A, B], target)
        assert not valid[0]

        with tvm.build_config(add_lower_pass=[
            (2, get_verify_pass(valid,
                                max_shared_memory_per_block=0,
                                max_threads_per_block=N))]):
            tvm.build(s, [A, B], target)
        assert valid[0]

        with tvm.build_config(add_lower_pass=[
            (2, get_verify_pass(valid,
                                max_shared_memory_per_block=0,
                                max_threads_per_block=N,
                                max_thread_y=M - 1))]):
            tvm.build(s, [A, B], target)
        assert not valid[0]

        with tvm.build_config(add_lower_pass=[
            (2, get_verify_pass(valid,
                                max_shared_memory_per_block=0,
                                max_threads_per_block=N,
                                max_thread_y=M))]):
            tvm.build(s, [A, B], target)
        assert valid[0]

def test_multiple_kernels():
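    """Lowering this schedule produces two GPU kernels (one computing B, one
    computing C); the verifier must check every kernel, not just the first."""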
    N = 1024

    A = tvm.placeholder((N, N), name='A')
    B = tvm.compute((N, N), lambda i, j: A[i, j])
    C = tvm.compute((N, N), lambda i, j: B[i, j])

    s = tvm.create_schedule([C.op])

    s[C].bind(s[C].op.axis[1], tvm.thread_axis("threadIdx.x"))
    s[B].bind(s[B].op.axis[1], tvm.thread_axis("threadIdx.x"))

    # shared memory usage: 0
    # thread usage: N

    for target in ['opencl', 'cuda']:
        if not tvm.context(target).exist:
            continue

        valid = [None]
        with tvm.build_config(add_lower_pass=[
            (2, get_verify_pass(valid,
                                max_shared_memory_per_block=0,
                                max_threads_per_block=N - 1))]):
            tvm.build(s, [A, C], target)
        assert not valid[0]

        with tvm.build_config(add_lower_pass=[
            (2, get_verify_pass(valid,
                                max_shared_memory_per_block=0,
                                max_threads_per_block=N))]):
            tvm.build(s, [A, C], target)
        assert valid[0]

def test_wrong_bind():
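    """Binding one thread axis to two loop axes of different extents
    (N and N - 1) is invalid, so verification must fail even with a
    generous thread budget."""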
    N = 1024

    A = tvm.placeholder((N, N-1), name='A')
    B = tvm.compute((N, N-1), lambda i, j: A[i, j])

    s = tvm.create_schedule([B.op])

    # bind a thread axis to two loop axes with different lengths
    s[B].bind(s[B].op.axis[0], tvm.thread_axis("threadIdx.x"))
    s[B].bind(s[B].op.axis[1], tvm.thread_axis("threadIdx.x"))

    for target in ['opencl', 'cuda']:
        if not tvm.context(target).exist:
            continue

        valid = [None]
        with tvm.build_config(add_lower_pass=[
                (2, get_verify_pass(valid, max_threads_per_block=N * N))]):
            tvm.build(s, [A, B], target)
        assert not valid[0]


if __name__ == "__main__":
    test_local_memory()
    test_shared_memory()
    test_num_thread()
    test_multiple_kernels()
    test_wrong_bind()