# test_lang_buffer.py
import tvm
from tvm.schedule import Buffer


def test_buffer():
    """Declare two float32 buffers with symbolic shapes and check their metadata.

    Each buffer returned by ``tvm.decl_buffer`` is checked to be a
    ``tvm.schedule.Buffer`` whose ``dtype`` and ``shape`` match the
    arguments it was declared with.
    """
    m = tvm.var('m')
    n = tvm.var('n')
    # The Python local is spelled out to avoid the ambiguous name ``l``;
    # the TVM variable itself keeps its original name 'l'.
    l_var = tvm.var('l')
    Ab = tvm.decl_buffer((m, n), tvm.float32)
    Bb = tvm.decl_buffer((n, l_var), tvm.float32)

    assert isinstance(Ab, tvm.schedule.Buffer)
    assert Ab.dtype == tvm.float32
    assert tuple(Ab.shape) == (m, n)

    # Bb was previously declared but never inspected; apply the same checks.
    assert isinstance(Bb, tvm.schedule.Buffer)
    assert Bb.dtype == tvm.float32
    assert tuple(Bb.shape) == (n, l_var)

def test_buffer_access_ptr():
    """Check the arguments of the expression returned by ``Buffer.access_ptr``.

    A buffer of shape ``(m, n)`` with strides ``[n + 1, 1]`` is declared, and
    the result of ``access_ptr`` is inspected positionally:

    * ``args[0]`` must carry the buffer's dtype,
    * ``args[3]`` must be equal to ``strides[0] * m``,
    * ``args[4]`` must hold the access mask built from ``Buffer.READ`` and
      ``Buffer.WRITE``.
    """
    m = tvm.var('m')
    n = tvm.var('n')
    Ab = tvm.decl_buffer((m, n), tvm.float32, strides=[n + 1, 1])

    aptr = Ab.access_ptr("rw")
    # NOTE(review): args[3] is presumably the accessed extent -- confirm
    # against the access_ptr documentation.
    assert tvm.ir_pass.Equal(aptr.args[3], Ab.strides[0] * m)
    assert aptr.args[0].dtype == Ab.dtype
    # Parentheses make explicit that the combined mask is what is compared
    # (``|`` already binds tighter than ``==``, so behavior is unchanged).
    assert aptr.args[4].value == (Buffer.READ | Buffer.WRITE)

    aptr = Ab.access_ptr("w")
    assert aptr.args[4].value == Buffer.WRITE

# Allow running this test module directly as a script.
if __name__ == "__main__":
    test_buffer()
    test_buffer_access_ptr()