import tvm
from tvm.schedule import Buffer

def test_buffer():
    """Check that decl_buffer returns a Buffer with the requested dtype and shape.

    Two buffers with different symbolic shapes are declared and both are
    inspected, so neither declaration goes unchecked.
    """
    m = tvm.var('m')
    n = tvm.var('n')
    l = tvm.var('l')
    Ab = tvm.decl_buffer((m, n), tvm.float32)
    Bb = tvm.decl_buffer((n, l), tvm.float32)

    assert isinstance(Ab, tvm.schedule.Buffer)
    assert Ab.dtype == tvm.float32
    assert tuple(Ab.shape) == (m, n)

    # Bb uses n as its leading dimension (and l as its second); verify it
    # independently of Ab.
    assert isinstance(Bb, tvm.schedule.Buffer)
    assert Bb.dtype == tvm.float32
    assert tuple(Bb.shape) == (n, l)

def test_buffer_access_ptr():
    """Check the extent, dtype and access mask produced by Buffer.access_ptr.

    args[3] is compared against the expected extent, args[0].dtype against
    the buffer dtype, and args[4] against the READ/WRITE mask.
    """
    m = tvm.var('m')
    n = tvm.var('n')
    # The outer stride (n + 1) exceeds the row length n, so the extent is
    # expected to follow the stride (strides[0] * m) rather than m * n.
    Ab = tvm.decl_buffer((m, n), tvm.float32, strides=[n + 1, 1])

    rw_ptr = Ab.access_ptr("rw")
    assert tvm.ir_pass.Equal(rw_ptr.args[3], Ab.strides[0] * m)
    assert rw_ptr.args[0].dtype == Ab.dtype
    assert rw_ptr.args[4].value == (Buffer.READ | Buffer.WRITE)

    w_ptr = Ab.access_ptr("w")
    assert w_ptr.args[4].value == Buffer.WRITE

def test_buffer_access_ptr_offset():
    """Check that the offset given to Buffer.access_ptr is reflected in args[2].

    Covers three kinds of offset:
      * a constant (100),
      * an affine expression in a symbolic int32 variable (100 + 100 + v),
      * the same expression wrapped in an extern call, whose argument is
        still expected to simplify (to 200 + v) after Simplify.
    Every pointer is requested with "rw", so args[4] must carry READ|WRITE.
    """
    m = tvm.var('m')
    n = tvm.var('n')
    Ab = tvm.decl_buffer((m, n), tvm.float32)
    rw_mask = Buffer.READ | Buffer.WRITE

    aptr = Ab.access_ptr("rw", offset=100)
    offset = tvm.ir_pass.Simplify(aptr.args[2])
    assert tvm.ir_pass.Equal(offset, 100)
    assert aptr.args[4].value == rw_mask

    # tvm.var's first argument is the variable *name*; pass the dtype
    # explicitly instead of naming the variable "int32".
    v = tvm.var('v', dtype='int32')
    aptr = Ab.access_ptr("rw", offset=100 + 100 + v)
    offset = tvm.ir_pass.Simplify(aptr.args[2])
    assert tvm.ir_pass.Equal(offset, 200 + v)
    assert aptr.args[4].value == rw_mask

    aptr = Ab.access_ptr("rw", offset=tvm.call_extern('int32', "test_call", 100 + 100 + v))
    offset = tvm.ir_pass.Simplify(aptr.args[2])
    assert tvm.ir_pass.Equal(offset, tvm.call_extern('int32', "test_call", 200 + v))
    assert aptr.args[4].value == rw_mask

def test_buffer_access_ptr_extent():
    """Check the extent (args[3]) reported by Buffer.access_ptr.

    For a compactly declared (m, n) buffer the extent is m * n, reduced by
    any offset. For a buffer with an explicit outer stride it is
    strides[0] * m, likewise reduced by the offset.
    """
    m = tvm.var('m')
    n = tvm.var('n')

    compact = tvm.decl_buffer((m, n), tvm.float32)
    assert tvm.ir_pass.Equal(compact.access_ptr("rw").args[3], m * n)
    assert tvm.ir_pass.Equal(compact.access_ptr("rw", offset=100).args[3],
                             m * n - 100)

    strided = tvm.decl_buffer((m, n), tvm.float32, strides=[n + 1, 1])
    assert tvm.ir_pass.Equal(strided.access_ptr("rw", offset=100).args[3],
                             strided.strides[0] * m - 100)

def test_buffer_vload():
    """Check that vload folds the buffer's elem_offset into the load index."""
    m = tvm.var('m')
    n = tvm.var('n')
    buf = tvm.decl_buffer((m, n), tvm.float32, elem_offset=100)
    # Expected flat index of element [2, 3]: 2 * n + 3, plus elem_offset 100.
    index = tvm.ir_pass.Simplify(buf.vload([2, 3]).index)
    assert tvm.ir_pass.Equal(index, n * 2 + 103)

def test_buffer_index_merge_mult_mod():
    """Check that vload merges div/mod multi-dimensional indices into a flat index.

    Each case builds one vload from split (div, mod) indices and another from
    the equivalent flat index. The two must be equal under tvm.ir_pass.Equal.
    The last case mixes divisors (k1 / s vs k1 / n), so no merge is expected
    and the direct form spells out the unmerged flat index.
    """
    m = tvm.var('m')
    n = tvm.var('n')
    s = tvm.var('s')
    k0 = tvm.var('k0')
    k1 = tvm.var('k1')
    A = tvm.decl_buffer((m, n), tvm.float32)
    A_stride = tvm.decl_buffer((m, n), tvm.float32, strides=(s, 1))

    def check(simplified, direct):
        assert tvm.ir_pass.Equal(simplified, direct), \
            "index_simplified=%s, index_direct=%s" % (simplified, direct)

    # Case 1: indices split by the explicit row stride s.
    check(A_stride.vload(((k0 % k1) / s, (k0 % k1) % s + (k0 / k1) * k1)),
          A_stride.vload((0, k0)))

    # Case 2: indices split by the row length n.
    check(A.vload(((k0 % (k1 / s)) / n,
                   (k0 % (k1 / s)) % n + (k0 % k1))),
          A.vload((0, k0 % k1 + k0 % (k1 / s))))

    # Case 3: both a quotient term and a remainder term split by n.
    check(A.vload((((k0 / (k1 / s)) * (k1 / s)) / n + (k0 % (k1 / s)) / n,
                   ((k0 / (k1 / s)) * (k1 / s)) % n + (k0 % (k1 / s)) % n)),
          A.vload((0, k0)))

    # Case 4: mismatched divisors; the index is not expected to simplify.
    check(A.vload(((k0 % (k1 / s)) / n,
                   (k0 % (k1 / n)) % n + (k0 % k1))),
          A.vload((0, ((k0 % (k1 / s)) / n) * n + ((k0 % (k1 / n)) % n + (k0 % k1)))))

if __name__ == "__main__":
    # Run every test in declaration order when executed as a script.
    for _test in (test_buffer,
                  test_buffer_access_ptr,
                  test_buffer_access_ptr_offset,
                  test_buffer_access_ptr_extent,
                  test_buffer_vload,
                  test_buffer_index_merge_mult_mod):
        _test()