# test_random.py: tests for the tvm.contrib.random tensor generators (randint, uniform, normal).
import tvm
import numpy as np
from tvm.contrib import random

def test_randint():
    """Check the values produced by ``random.randint(-127, 128)``.

    A 1024x1024 int32 tensor is declared with ``random.randint``, built for
    the ``llvm`` target, executed on ``tvm.cpu(0)``, and the result is checked
    for its sample mean, minimum and maximum.  The check is skipped (with a
    printed notice) when the target is not enabled or the extern function
    ``tvm.contrib.random.randint`` cannot be looked up.
    """
    m = 1024
    n = 1024
    A = random.randint(-127, 128, size=(m, n), dtype='int32')
    s = tvm.create_schedule(A.op)

    def verify(target="llvm"):
        # Skip instead of failing when the build target is unavailable.
        if not tvm.module.enabled(target):
            print("skip because %s is not enabled..." % target)
            return
        # A falsy lookup result means the extern function is not registered.
        if not tvm.get_global_func("tvm.contrib.random.randint", True):
            print("skip because extern function is not available")
            return
        ctx = tvm.cpu(0)
        f = tvm.build(s, [A], target)
        # Output buffer; the built function writes the random values into it.
        a = tvm.nd.array(np.zeros((m, n), dtype=A.dtype), ctx)
        f(a)
        na = a.asnumpy()
        mean = np.mean(na)
        lowest = np.min(na)
        highest = np.max(na)
        # The range is symmetric around zero, so the mean should be near 0.
        assert abs(mean) < 0.2, "unexpected sample mean: %s" % mean
        assert lowest == -127, "unexpected minimum: %s" % lowest
        # The test expects the upper bound (128) itself never to be drawn.
        assert highest == 127, "unexpected maximum: %s" % highest
    verify()


def test_uniform():
    """Check the values produced by ``random.uniform(0, 1)``.

    A 1024x1024 tensor is declared with ``random.uniform``, built for the
    ``llvm`` target, executed on ``tvm.cpu(0)``, and the result is checked for
    its sample mean (about 0.5), minimum (about 0.0) and maximum (about 1.0).
    The check is skipped (with a printed notice) when the target is not
    enabled or the extern function ``tvm.contrib.random.uniform`` cannot be
    looked up.
    """
    m = 1024
    n = 1024
    A = random.uniform(0, 1, size=(m, n))
    s = tvm.create_schedule(A.op)

    def verify(target="llvm"):
        # Skip instead of failing when the build target is unavailable.
        if not tvm.module.enabled(target):
            print("skip because %s is not enabled..." % target)
            return
        # A falsy lookup result means the extern function is not registered.
        if not tvm.get_global_func("tvm.contrib.random.uniform", True):
            print("skip because extern function is not available")
            return
        ctx = tvm.cpu(0)
        f = tvm.build(s, [A], target)
        # Output buffer; the built function writes the random values into it.
        a = tvm.nd.array(np.zeros((m, n), dtype=A.dtype), ctx)
        f(a)
        na = a.asnumpy()
        mean = np.mean(na)
        lowest = np.min(na)
        highest = np.max(na)
        assert abs(mean - 0.5) < 1e-2, "unexpected sample mean: %s" % mean
        # With this many samples the extremes should land close to the bounds.
        assert abs(lowest - 0.0) < 1e-3, "unexpected minimum: %s" % lowest
        assert abs(highest - 1.0) < 1e-3, "unexpected maximum: %s" % highest
    verify()


def test_normal():
    """Check the values produced by ``random.normal(3, 4)``.

    A 1024x1024 tensor is declared with ``random.normal``, built for the
    ``llvm`` target, executed on ``tvm.cpu(0)``, and the result is checked for
    a sample mean of about 3 and a sample standard deviation of about 4.  The
    check is skipped (with a printed notice) when the target is not enabled or
    the extern function ``tvm.contrib.random.normal`` cannot be looked up.
    """
    m = 1024
    n = 1024
    A = random.normal(3, 4, size=(m, n))
    s = tvm.create_schedule(A.op)

    def verify(target="llvm"):
        # Skip instead of failing when the build target is unavailable.
        if not tvm.module.enabled(target):
            print("skip because %s is not enabled..." % target)
            return
        # A falsy lookup result means the extern function is not registered.
        if not tvm.get_global_func("tvm.contrib.random.normal", True):
            print("skip because extern function is not available")
            return
        ctx = tvm.cpu(0)
        f = tvm.build(s, [A], target)
        # Output buffer; the built function writes the random values into it.
        a = tvm.nd.array(np.zeros((m, n), dtype=A.dtype), ctx)
        f(a)
        na = a.asnumpy()
        mean = np.mean(na)
        std = np.std(na)
        assert abs(mean - 3) < 1e-2, "unexpected sample mean: %s" % mean
        assert abs(std - 4) < 1e-2, "unexpected sample std: %s" % std
    verify()


# Allow running the tests directly as a script, without a test runner.
if __name__ == "__main__":
    test_randint()
    test_uniform()
    test_normal()