# test_lang_buffer.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17
import tvm
18
from tvm.schedule import Buffer
19
import numpy as np
20 21

def test_buffer():
    """Declare two buffers with symbolic shapes and check type, dtype and shape.

    Both declared buffers receive the same three checks, so neither
    declaration is left unverified.
    """
    m = tvm.var('m')
    n = tvm.var('n')
    l_dim = tvm.var('l')
    Ab = tvm.decl_buffer((m, n), tvm.float32)
    Bb = tvm.decl_buffer((n, l_dim), tvm.float32)

    for buf, expected_shape in ((Ab, (m, n)), (Bb, (n, l_dim))):
        assert isinstance(buf, tvm.schedule.Buffer)
        assert buf.dtype == tvm.float32
        assert tuple(buf.shape) == expected_shape

32 33 34 35 36 37 38 39 40 41
def test_buffer_access_ptr():
    """Check the call arguments produced by ``access_ptr`` on a strided buffer.

    The buffer is declared with explicit strides ``[n + 1, 1]``.  The test
    inspects ``args[0]`` (dtype), ``args[3]`` and ``args[4]`` (access flags)
    of the returned call for "rw" access, then ``args[4]`` for "w" access.
    """
    rows = tvm.var('m')
    cols = tvm.var('n')
    buf = tvm.decl_buffer((rows, cols), tvm.float32, strides=[cols + 1 , 1])

    # Read/write access.
    rw_ptr = buf.access_ptr("rw")
    assert tvm.ir_pass.Equal(rw_ptr.args[3], buf.strides[0] * rows)
    assert rw_ptr.args[0].dtype == buf.dtype
    read_write = Buffer.READ | Buffer.WRITE
    assert rw_ptr.args[4].value == read_write

    # Write-only access.
    w_ptr = buf.access_ptr("w")
    assert w_ptr.args[4].value == Buffer.WRITE
42

43 44 45 46 47 48 49 50
def test_buffer_access_ptr_offset():
    """Check that the ``offset`` given to ``access_ptr`` appears in ``args[2]``.

    Three offsets are tried: a constant, a constant-plus-variable sum, and an
    extern call wrapping such a sum.  In each case ``args[2]`` is simplified
    and compared with the expected expression, and ``args[4]`` must carry the
    read|write flags.
    """
    m = tvm.var('m')
    n = tvm.var('n')
    Ab = tvm.decl_buffer((m, n), tvm.float32)
    # Name the variable explicitly; 'int32' is its dtype, not its name.
    v = tvm.var('v', dtype='int32')

    def check_offset(offset, expected):
        # Build an "rw" pointer at `offset` and compare the simplified args[2].
        aptr = Ab.access_ptr("rw", offset=offset)
        simplified = tvm.ir_pass.Simplify(aptr.args[2])
        assert tvm.ir_pass.Equal(simplified, expected)
        assert aptr.args[4].value == Buffer.READ | Buffer.WRITE

    check_offset(100, 100)
    check_offset(100 + 100 + v, 200 + v)
    check_offset(tvm.call_extern('int32', "test_call", 100 + 100 + v),
                 tvm.call_extern('int32', "test_call", 200 + v))
60

61 62 63 64 65 66 67 68 69 70 71 72
def test_buffer_access_ptr_extent():
    """Check ``args[3]`` of ``access_ptr`` with and without an offset.

    Covers a buffer declared without strides (expected ``m * n`` and
    ``m * n - 100``) and one declared with strides ``[n + 1, 1]``
    (expected ``strides[0] * m - 100``).
    """
    rows = tvm.var('m')
    cols = tvm.var('n')

    # Buffer declared without explicit strides.
    plain = tvm.decl_buffer((rows, cols), tvm.float32)
    no_offset = plain.access_ptr("rw")
    assert tvm.ir_pass.Equal(no_offset.args[3], rows * cols)
    with_offset = plain.access_ptr("rw", offset=100)
    assert tvm.ir_pass.Equal(with_offset.args[3], rows * cols - 100)

    # Buffer declared with explicit strides.
    strided = tvm.decl_buffer((rows, cols), tvm.float32, strides=[cols + 1 , 1])
    strided_offset = strided.access_ptr("rw", offset=100)
    assert tvm.ir_pass.Equal(strided_offset.args[3], strided.strides[0] * rows - 100)

73 74 75 76 77 78 79 80
def test_buffer_vload():
    """Check the index of a ``vload`` at [2, 3] on a buffer with elem_offset=100.

    After simplification the load index is expected to equal ``n * 2 + 103``.
    """
    rows = tvm.var('m')
    cols = tvm.var('n')
    buf = tvm.decl_buffer((rows, cols), tvm.float32, elem_offset=100)
    loaded = buf.vload([2, 3])
    flat_index = tvm.ir_pass.Simplify(loaded.index)
    assert tvm.ir_pass.Equal(flat_index, cols * 2 + 103)

81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
def test_buffer_index_merge_mult_mod():
    """Compare loads indexed by div/mod pairs with loads indexed directly.

    Each case builds one load from a (row, column) pair written with
    division and modulo terms and a second load from the index it is
    expected to be equal to, then asserts the two loads are equal.
    """
    m = tvm.var('m')
    n = tvm.var('n')
    s = tvm.var('s')
    k0 = tvm.var('k0')
    k1 = tvm.var('k1')
    buf = tvm.decl_buffer((m, n), tvm.float32)
    strided_buf = tvm.decl_buffer((m, n), tvm.float32, strides=(s, 1))

    def expect_same(merged, direct):
        # The failure message is only formatted when the assertion fails.
        assert tvm.ir_pass.Equal(merged, direct), (
            "index_simplified=%s, index_direct=%s" % (merged, direct))

    # Test Case1
    row = (k0 % k1) / s
    col = (k0 % k1) % s + (k0 / k1) * k1
    expect_same(strided_buf.vload((row, col)),
                strided_buf.vload((0, k0)))

    # Test Case2
    row = (k0 % (k1 / s)) / n
    col = (k0 % (k1 / s)) % n + (k0 % k1)
    expect_same(buf.vload((row, col)),
                buf.vload((0, k0 % k1 + k0 % (k1 / s))))

    # Test Case3
    row = ((k0 / (k1 / s)) * (k1 / s)) / n + (k0 % (k1 / s)) / n
    col = ((k0 / (k1 / s)) * (k1 / s)) % n + (k0 % (k1 / s)) % n
    expect_same(buf.vload((row, col)),
                buf.vload((0, k0)))

    # Test Case4 (not able to simplify)
    row = (k0 % (k1 / s)) / n
    col = (k0 % (k1 / n)) % n + (k0 % k1)
    expect_same(
        buf.vload((row, col)),
        buf.vload((0, ((k0 % (k1 / s)) / n) * n + ((k0 % (k1 / n)) % n + (k0 % k1)))))

112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
def test_buffer_broadcast():
    """Build and run an elementwise add with inputs bound to auto_broadcast buffers.

    The computation adds two 3-D placeholders with independent symbolic
    shapes.  When the llvm target is enabled, the function is built with both
    inputs bound to buffers declared with ``buffer_type="auto_broadcast"``,
    run on a (2, 4, 3) and a (2, 1, 1) array, and compared with the NumPy
    sum of the same inputs.  Without llvm nothing is built or run.
    """
    m0, m1, m2 = (tvm.var(name) for name in ("m0", "m1", "m2"))
    n0, n1, n2 = (tvm.var(name) for name in ("n0", "n1", "n2"))
    o0, o1, o2 = (tvm.var(name) for name in ("o0", "o1", "o2"))

    lhs = tvm.placeholder((m0, m1, m2), name='A')
    rhs = tvm.placeholder((n0, n1, n2), name='B')
    total = tvm.compute((o0, o1, o2),
                        lambda i, j, k: lhs[i, j, k] + rhs[i, j, k],
                        name='C')

    lhs_buf = tvm.decl_buffer(lhs.shape, lhs.dtype, name="Ab", buffer_type="auto_broadcast")
    rhs_buf = tvm.decl_buffer(rhs.shape, rhs.dtype, name="Bb", buffer_type="auto_broadcast")
    sched = tvm.create_schedule(total.op)

    # Nothing to build or run without the llvm backend.
    if not tvm.module.enabled("llvm"):
        return

    fadd = tvm.build(sched, [lhs, rhs, total], target='llvm', name='bcast_add',
                     binds={lhs: lhs_buf, rhs: rhs_buf})
    ctx = tvm.cpu(0)
    lhs_nd = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(lhs.dtype), ctx)
    rhs_nd = tvm.nd.array(np.random.uniform(size=(2, 1, 1)).astype(rhs.dtype), ctx)
    out_nd = tvm.nd.array(np.zeros((2, 4, 3), dtype=total.dtype), ctx)
    fadd(lhs_nd, rhs_nd, out_nd)
    tvm.testing.assert_allclose(out_nd.asnumpy(), lhs_nd.asnumpy() + rhs_nd.asnumpy())


140 141
if __name__ == "__main__":
    # Run the tests defined above, in definition order, when executed as a script.
    test_buffer()
    test_buffer_access_ptr()
    test_buffer_access_ptr_offset()
    test_buffer_access_ptr_extent()
    test_buffer_vload()
    test_buffer_index_merge_mult_mod()
    test_buffer_broadcast()