# test_buffer_linebuff.py
import tvm
import numpy as np
from tvm.contrib import verilog

def test_buffer_linebuff():
    """Test the tvm_buffer.v module as a line buffer.

    An 8x8 window of sequential int8 values is streamed into the simulated
    design one word at a time. Every word the design flags as valid on
    ``read_data`` is compared against the pixel expected for a 3x3 kernel.

    Raises:
        AssertionError: if a valid output word differs from the expected
            pixel value.
    """
    # Window is 8x8, kernel is 3x3
    window_width = 8
    kernel_width = 3

    # Find file will search root/verilog and root/tests/verilog
    sess = verilog.session([
        verilog.find_file("test_buffer_linebuff.v"),
        verilog.find_file("tvm_buffer.v")
    ])

    # Get the handles by their names
    rst = sess.main.rst
    write_advance = sess.main.write_advance
    write_valid = sess.main.write_valid
    write_ready = sess.main.write_ready
    write_data = sess.main.write_data
    read_data = sess.main.read_data
    read_data_valid = sess.main.read_data_valid

    # Simulation input data: 0, 1, ..., window_width**2 - 1
    test_data = np.arange(window_width * window_width).astype('int8')

    # Initial state
    rst.put_int(1)
    write_advance.put_int(0)
    write_valid.put_int(0)
    write_data.put_int(0)

    # De-assert reset
    sess.yield_until_next_cycle()
    rst.put_int(0)

    # Leave the following signals set to true
    sess.yield_until_next_cycle()
    write_advance.put_int(1)
    write_valid.put_int(1)

    # Total number of output words to check: one kernel's worth of taps for
    # each of the (window_width - kernel_width + 1)**2 kernel positions.
    positions_per_side = window_width - kernel_width + 1
    taps_per_kernel = kernel_width * kernel_width
    total_reads = positions_per_side * positions_per_side * taps_per_kernel

    # Main simulation loop
    write_idx = 0
    read_idx = 0
    while read_idx < total_reads:
        # write logic: present the next input word only when the design
        # reports it is ready; once all data is sent, stop writing.
        if write_idx < len(test_data):
            if write_ready.get_int():
                write_data.put_int(int(test_data[write_idx]))
                write_idx += 1
        else:
            write_advance.put_int(0)
            write_valid.put_int(0)

        # correctness checks
        if read_data_valid.get_int():
            # Derive convolution window indices. base_idx counts completed
            # kernels and is used directly as the flat index of the kernel's
            # first pixel; offset_idx is the tap within the current kernel.
            base_idx = read_idx // taps_per_kernel
            offset_idx = read_idx % taps_per_kernel
            y_offset = offset_idx // kernel_width
            x_offset = offset_idx % kernel_width
            pix_index = base_idx + y_offset * window_width + x_offset
            actual = read_data.get_int()
            expected = test_data[pix_index]
            assert actual == expected, (
                "read %d: got %d, expected %d (pixel %d)"
                % (read_idx, actual, expected, pix_index))
            read_idx += 1

        # step
        sess.yield_until_next_cycle()


# Allow running this test directly as a script, without a test runner.
if __name__ == "__main__":
    test_buffer_linebuff()