# test_cublas.py: test for the cuBLAS matmul extern op (tvm.contrib.cublas).
import tvm
import numpy as np
from tvm.contrib import cublas

def test_matmul_add():
    """Check ``cublas.matmul`` against ``numpy.dot`` for an (n, l) x (l, m) product.

    Builds a schedule for ``C = cublas.matmul(A, B)`` and runs it on the
    ``"cuda"`` target. The check is skipped, with a printed message rather than
    a failure, when that target is not enabled in this TVM build or when the
    ``tvm.contrib.cublas.matmul`` extern function is not registered.
    """
    n = 1024
    l = 128
    # Deliberately not a multiple of common tile sizes, so odd shapes are exercised.
    m = 235
    A = tvm.placeholder((n, l), name='A')
    B = tvm.placeholder((l, m), name='B')
    C = cublas.matmul(A, B)
    s = tvm.create_schedule(C.op)

    def verify(target="cuda"):
        """Build for ``target``, run on GPU device 0 and compare with NumPy.

        Args:
            target: TVM target string; defaults to ``"cuda"``.
        """
        if not tvm.module.enabled(target):
            print("skip because %s is not enabled..." % target)
            return
        # The second argument presumably means "allow missing": an unregistered
        # function yields a falsy value here instead of raising.
        if not tvm.get_global_func("tvm.contrib.cublas.matmul", True):
            print("skip because extern function is not available")
            return
        ctx = tvm.gpu(0)
        f = tvm.build(s, [A, B, C], target)
        # Inputs are cast to the placeholders' dtypes so the device buffers
        # match what the built function expects.
        a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
        c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
        f(a, b, c)
        # Reference result is computed on the host from the same input data.
        np.testing.assert_allclose(
            c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
    verify()


if __name__ == "__main__":
    # Allow running this file directly, without a test runner.
    test_matmul_add()