# test_topi_bitserial_conv2d_rasp.py
import os
import re
import numpy as np
import tvm
import topi
import topi.testing

def generate_quantized_np(shape, bits, out_dtype):
    """Return a reproducible random integer array of the given shape.

    The global NumPy random state is reseeded with 0 on every call, so
    identical arguments always yield identical arrays. Values are drawn
    from the half-open range [0, 1 << bits) and then cast to ``out_dtype``.
    """
    # Reseeding here (rather than using a private generator) deliberately
    # resets the module-wide NumPy RNG, exactly as callers may rely on.
    np.random.seed(0)
    exclusive_upper = 1 << bits
    samples = np.random.randint(0, exclusive_upper, size=shape)
    return samples.astype(out_dtype)

# Verify that certain special instructions from the tensorize pass exist
def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding,
                                 activation_bits, weight_bits, dorefa):
    """Build an NHWC bitserial conv2d for the 'rasp3b' ARM CPU target and
    check that the generated assembly contains the expected instructions.

    Parameters
    ----------
    batch
        Leading dimension of the input placeholder.
    in_size
        Used for both the height and the width of the input placeholder.
    in_channel
        Channel dimension of the input, and third dimension of the weights.
    num_filter
        Last dimension of the weight placeholder.
    kernel
        Used for both leading (spatial) dimensions of the weight placeholder.
    stride, padding, activation_bits, weight_bits, dorefa
        Forwarded unchanged to ``topi.nn.bitserial_conv2d``.

    Raises
    ------
    AssertionError
        If any of ``vpadal``, ``vcnt`` or ``vpadd`` is absent from the
        assembly source of the built function.
    """
    in_height = in_width = in_size
    input_type = 'uint32'
    out_dtype = 'int32'

    # The compute and schedule are created inside the target scope, as in
    # the original test, so target-specific implementations can be selected.
    with tvm.target.arm_cpu('rasp3b'):
        A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
        W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
        B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits,
                                     out_dtype=out_dtype, layout="NHWC", dorefa=dorefa)
        s = topi.generic.schedule_bitserial_conv2d_nhwc([B])

    func = tvm.build(s, [A, W, B], tvm.target.arm_cpu('rasp3b'))

    assembly = func.get_source('asm')
    # A plain substring test is equivalent to the regex search on these
    # literal mnemonics, and the message names the one that is missing.
    for instruction in ("vpadal", "vcnt", "vpadd"):
        assert instruction in assembly, \
            "instruction '%s' not found in generated assembly" % instruction

def test_bitserial_conv2d():
    """Run the assembly check for 1-bit and 2-bit activations.

    Both cases use a batch of 1, a 56x56 input, 64 input and 64 output
    channels, a kernel size of 3, stride 1, padding 1, 1-bit weights and
    ``dorefa`` disabled.
    """
    spatial_size = 56
    in_channels = out_channels = 64
    kernel_size, stride, pad = 3, 1, 1
    weight_bits = 1

    # Same order as before: 1-bit activations first, then 2-bit.
    for activation_bits in (1, 2):
        verify_bitserial_conv2d_nhwc(1, spatial_size, in_channels, out_channels, kernel_size,
                                     stride, pad, activation_bits, weight_bits, False)

# Allow running this test directly as a script, without a test runner.
if __name__ == "__main__":
    test_bitserial_conv2d()