# test_tir_buffer.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17
import tvm
18
from tvm import te
19
from tvm.tir import Buffer
20
import numpy as np
21 22

def test_buffer():
    """Declare two symbolic-shape float32 buffers and check their basic attributes."""
    m = te.size_var('m')
    n = te.size_var('n')
    l = te.size_var('l')
    Ab = tvm.tir.decl_buffer((m, n), "float32")
    Bb = tvm.tir.decl_buffer((n, l), "float32")

    # Check every declared buffer, not only the first one, so that both
    # declarations are actually verified.
    for buf, expected_shape in ((Ab, (m, n)), (Bb, (n, l))):
        assert isinstance(buf, tvm.tir.Buffer)
        assert buf.dtype == "float32"
        assert tuple(buf.shape) == expected_shape

33

34
def test_buffer_access_ptr():
    """Check the dtype, ``args[3]`` and access-mask arguments of ``access_ptr``."""
    rows, cols = te.size_var('m'), te.size_var('n')
    strided = tvm.tir.decl_buffer((rows, cols), "float32", strides=[cols + 1, 1])

    rw_ptr = strided.access_ptr("rw")
    type_arg, size_arg, mask_arg = rw_ptr.args[0], rw_ptr.args[3], rw_ptr.args[4]
    # args[3] is compared against the leading stride times the first dimension.
    assert tvm.ir.structural_equal(size_arg, strided.strides[0] * rows)
    assert type_arg.dtype == strided.dtype
    assert mask_arg.value == (Buffer.READ | Buffer.WRITE)

    # A write-only pointer must carry only the WRITE bit.
    w_ptr = strided.access_ptr("w")
    assert w_ptr.args[4].value == Buffer.WRITE
44

45

46
def test_buffer_access_ptr_offset():
    """Check that ``access_ptr`` carries the requested offset in ``args[2]``.

    Each case passes an offset expression, simplifies ``args[2]`` of the
    resulting call and compares it structurally with the expected expression.
    The access mask in ``args[4]`` must be READ | WRITE in every case.
    """
    m = te.size_var('m')
    n = te.size_var('n')
    buf = tvm.tir.decl_buffer((m, n), "float32")
    v = te.size_var('int32')
    read_write = Buffer.READ | Buffer.WRITE

    cases = (
        # (offset passed to access_ptr, expected simplified args[2])
        (100, 100),
        (100 + 100 + v, 200 + v),
        (tvm.tir.call_extern('int32', "test_call", 100 + 100 + v),
         tvm.tir.call_extern('int32', "test_call", 200 + v)),
    )
    for requested, expected in cases:
        ptr = buf.access_ptr("rw", offset=requested)
        simplified = tvm.tir.ir_pass.Simplify(ptr.args[2])
        assert tvm.ir.structural_equal(simplified, expected)
        assert ptr.args[4].value == read_write
63

64

65
def test_buffer_access_ptr_extent():
    """Check ``args[3]`` of ``access_ptr`` with and without an offset and strides."""
    m = te.size_var('m')
    n = te.size_var('n')

    # Buffer declared without explicit strides.
    plain = tvm.tir.decl_buffer((m, n), "float32")
    assert tvm.ir.structural_equal(plain.access_ptr("rw").args[3], m * n)
    assert tvm.ir.structural_equal(
        plain.access_ptr("rw", offset=100).args[3], m * n - 100)

    # Buffer declared with explicit strides: the leading stride replaces n.
    strided = tvm.tir.decl_buffer((m, n), "float32", strides=[n + 1, 1])
    assert tvm.ir.structural_equal(
        strided.access_ptr("rw", offset=100).args[3], strided.strides[0] * m - 100)
76

77

78
def test_buffer_vload():
    """Check the simplified index of ``vload([2, 3])`` on a buffer with ``elem_offset=100``."""
    m, n = te.size_var('m'), te.size_var('n')
    buf = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100)
    simplified_index = tvm.tir.ir_pass.Simplify(buf.vload([2, 3]).index)
    # Expected value: 2 * n + 3, plus the element offset of 100.
    assert tvm.ir.structural_equal(simplified_index, n * 2 + 103)
85

86

87
def test_buffer_index_merge_mult_mod():
    """Check how ``vload`` merges div/mod terms of a two-dimensional index.

    Every case builds a load from a two-element index and compares it
    structurally with a load whose first index is 0 and whose second index is
    the expected merged expression.
    """
    m = te.size_var('m')
    n = te.size_var('n')
    s = te.size_var('s')
    k0 = te.size_var('k0')
    k1 = te.size_var('k1')
    A = tvm.tir.decl_buffer((m, n), "float32")
    A_stride = tvm.tir.decl_buffer((m, n), "float32", strides=(s, 1))
    idxd = tvm.tir.indexdiv
    idxm = tvm.tir.indexmod

    def check_same(merged, direct):
        assert tvm.ir.structural_equal(merged, direct), (
            "index_simplified=%s, index_direct=%s" % (merged, direct))

    # Sub-expressions shared by several cases.
    k0_mod_k1 = idxm(k0, k1)
    k1_div_s = idxd(k1, s)
    k0_mod_k1s = idxm(k0, k1_div_s)

    # Test Case1: strided buffer, div/mod by the stride s.
    check_same(
        A_stride.vload((idxd(k0_mod_k1, s),
                        idxm(k0_mod_k1, s) + idxd(k0, k1) * k1)),
        A_stride.vload((0, k0)))

    # Test Case2
    check_same(
        A.vload((idxd(k0_mod_k1s, n), idxm(k0_mod_k1s, n) + k0_mod_k1)),
        A.vload((0, k0_mod_k1 + k0_mod_k1s)))

    # Test Case3
    k0_rounded = idxd(k0, k1_div_s) * k1_div_s
    check_same(
        A.vload((idxd(k0_rounded, n) + idxd(k0_mod_k1s, n),
                 idxm(k0_rounded, n) + idxm(k0_mod_k1s, n))),
        A.vload((0, k0)))

    # Test Case4 (not able to simplify): the second index uses idxd(k1, n)
    # where the first uses idxd(k1, s), so the terms stay separate.
    unmerged = idxm(idxm(k0, idxd(k1, n)), n) + k0_mod_k1
    check_same(
        A.vload((idxd(k0_mod_k1s, n), unmerged)),
        A.vload((0, idxd(k0_mod_k1s, n) * n + unmerged)))

125

126
def test_buffer_broadcast():
    """Build and run a 3-D add whose inputs are bound to ``auto_broadcast`` buffers.

    The compiled function is only built and executed when the runtime reports
    that the ``llvm`` target is enabled.
    """
    m0, m1, m2 = [te.size_var(name) for name in ("m0", "m1", "m2")]
    n0, n1, n2 = [te.size_var(name) for name in ("n0", "n1", "n2")]
    o0, o1, o2 = [te.size_var(name) for name in ("o0", "o1", "o2")]

    A = te.placeholder((m0, m1, m2), name='A')
    B = te.placeholder((n0, n1, n2), name='B')
    C = te.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')

    bindings = {
        A: tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast"),
        B: tvm.tir.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast"),
    }
    s = te.create_schedule(C.op)

    # Nothing to execute without an LLVM-enabled runtime.
    if not tvm.runtime.enabled("llvm"):
        return

    fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds=bindings)
    ctx = tvm.cpu(0)
    lhs = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx)
    rhs = tvm.nd.array(np.random.uniform(size=(2, 1, 1)).astype(B.dtype), ctx)
    out = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx)
    fadd(lhs, rhs, out)
    # The (2, 1, 1) operand must combine with the (2, 4, 3) one as NumPy's `+` does.
    tvm.testing.assert_allclose(out.asnumpy(), lhs.asnumpy() + rhs.asnumpy())


154
def test_buffer_broadcast_expr():
    """Build and run a 2-D add whose output extent is the expression ``o1 // x``.

    Three variants are built and executed, in this order: explicit
    ``auto_broadcast`` bindings with a (2, 4) first input, the same bindings
    with a (1, 4) first input, and a (1, 4) first input with no explicit
    bindings. Nothing is built unless the ``llvm`` target is enabled.
    """
    n0, m0, x = te.size_var('n0'), te.size_var('m0'), te.size_var('x')
    n1, m1 = te.size_var('n1'), te.size_var('m1')
    o0, o1 = te.size_var('o0'), te.size_var('o1')

    A = te.placeholder((m0, n0), name='A')
    B = te.placeholder((m1, n1), name='B')
    C = te.compute((o0, o1//x), lambda i, j: A[i, j] + B[i, j], name='C')

    Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
    Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
    Cc = tvm.tir.decl_buffer(C.shape, C.dtype, name="Cc", buffer_type="auto_broadcast")
    s = te.create_schedule(C.op)

    # Every variant needs the llvm target; skip all of them when it is missing.
    if not tvm.runtime.enabled("llvm"):
        return

    def build_and_check(a_shape, binds):
        """Build the add with ``binds`` and compare its output against NumPy."""
        fadd = tvm.build(s, [A, B, C, o1, x], target='llvm', name='bcast_add',
                         binds=binds)
        ctx = tvm.cpu(0)
        a = tvm.nd.array(np.random.uniform(size=a_shape).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=(2, 4)).astype(B.dtype), ctx)
        c = tvm.nd.array(np.zeros((2, 4), dtype=C.dtype), ctx)
        # The trailing scalars are the runtime values of o1 and x.
        fadd(a, b, c, 4, 1)
        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())

    # Explicit bindings, both inputs of shape (2, 4).
    build_and_check((2, 4), {A: Ab, B: Bb, C: Cc})
    # Explicit bindings, first input of shape (1, 4).
    build_and_check((1, 4), {A: Ab, B: Bb, C: Cc})
    # Let build bind buffers (binds=None is the same as not passing binds).
    build_and_check((1, 4), None)
207 208


209 210
if __name__ == "__main__":
    # Run every test of this file, in definition order, when executed as a script.
    for run_test in (
            test_buffer,
            test_buffer_access_ptr,
            test_buffer_access_ptr_offset,
            test_buffer_access_ptr_extent,
            test_buffer_vload,
            test_buffer_index_merge_mult_mod,
            test_buffer_broadcast,
            test_buffer_broadcast_expr,
    ):
        run_test()