"""Tests for the ``tvm.contrib.mps`` extern operators (matmul, conv2d) on Metal."""
import tvm
import numpy as np
from tvm.contrib import mps

def test_matmul():
    """Check ``mps.matmul`` followed by an elementwise ``+ 1`` on Metal.

    Builds ``D = mps.matmul(A, B) + 1`` for an (n x l) @ (l x m) product,
    binds ``D``'s two output axes to 16x16 GPU block/thread tiles, runs it
    on the Metal device and compares the result against
    ``np.dot(a, b) + 1``.

    The test prints a message and returns early (skips) when the Metal
    runtime is not enabled or when the ``tvm.contrib.mps.matmul`` packed
    function is not registered.
    """
    if not tvm.module.enabled("metal"):
        print("skip because %s is not enabled..." % "metal")
        return

    n = 1024
    l = 128
    m = 256
    A = tvm.placeholder((n, l), name='A')
    B = tvm.placeholder((l, m), name='B')
    # C is produced by the extern MPS kernel; D adds an ordinary TVM
    # elementwise stage on top of it, so the extern op is exercised as part
    # of a larger schedule rather than on its own.
    C = mps.matmul(A, B)
    D = tvm.compute(
        C.shape,
        lambda *i: C(*i) + 1.
    )
    s = tvm.create_schedule(D.op)
    yo, xo = D.op.axis
    block_y = tvm.thread_axis("blockIdx.y")
    block_x = tvm.thread_axis("blockIdx.x")
    thread_y = tvm.thread_axis("threadIdx.y")
    thread_x = tvm.thread_axis("threadIdx.x")
    # Split each output axis into 16-wide tiles: the outer loops are bound
    # to GPU blocks and the inner loops to threads within a block.
    by, ty = s[D].split(yo, factor=16)
    bx, tx = s[D].split(xo, factor=16)
    s[D].bind(by, block_y)
    s[D].bind(bx, block_x)
    s[D].bind(ty, thread_y)
    s[D].bind(tx, thread_x)

    def verify(A, B, D, s, target="metal"):
        """Build ``s`` for ``target``, run it and compare D with NumPy."""
        if not tvm.get_global_func("tvm.contrib.mps.matmul", True):
            print("skip because extern function is not available")
            return
        ctx = tvm.context(target, 0)
        # Honor the ``target`` argument instead of hard-coding "metal".
        f = tvm.build(s, [A, B, D], target)
        a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
        # This buffer receives D (the function's third argument), so it is
        # allocated with D's dtype.
        d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
        f(a, b, d)
        np.testing.assert_allclose(
            d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + 1, rtol=1e-5)

    verify(A, B, D, s)

def test_conv2d():
    """Smoke-test ``mps.conv2d`` with 'SAME' padding on Metal.

    Builds the extern conv2d op for an input of shape (n, h, w, ci) and a
    weight of shape (co, kh, kw, ci), then runs it on random data. The
    numeric result is not yet checked against a reference (see the TODO
    below); the test only confirms that the build and the call succeed.

    The test prints a message and returns early (skips) when the Metal
    runtime is not enabled or when the ``tvm.contrib.mps.conv2d`` packed
    function is not registered.
    """
    if not tvm.module.enabled("metal"):
        print("skip because %s is not enabled..." % "metal")
        return
    n = 1
    h = 14
    w = 14
    ci = 2
    co = 4
    kh = 3
    kw = 3
    stride = 2
    A = tvm.placeholder((n, h, w, ci), name="x")
    B = tvm.placeholder((co, kh, kw, ci), name="w")
    # Pass the declared stride rather than a duplicated literal, so that the
    # output-buffer shape computed below stays consistent with the op.
    C = mps.conv2d(A, B, 'SAME', stride)
    s1 = tvm.create_schedule(C.op)

    def verify(A, B, C, s, target="metal"):
        """Build ``s`` for ``target`` and run it on random inputs.

        Previously the signature was ``verify(A, B, C, target="llvm")``, so the
        call ``verify(A, B, C, s1)`` passed the schedule as ``target``. The
        schedule is now an explicit parameter, and ``target`` defaults to the
        backend that is actually built.
        """
        if not tvm.get_global_func("tvm.contrib.mps.conv2d", True):
            print("skip because extern function is not available")
            return
        ctx = tvm.context(target, 0)
        f = tvm.build(s, [A, B, C], target)
        a = tvm.nd.array(np.random.uniform(size=(n, h, w, ci)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=(co, kh, kw, ci)).astype(B.dtype), ctx)
        # NOTE(review): assumes the op's output spatial size is the input size
        # floor-divided by the stride; confirm against mps.conv2d's declared
        # output shape.
        c = tvm.nd.array(np.zeros((n, h // stride, w // stride, co), dtype=C.dtype), ctx)
        f(a, b, c)
        # TODO: compare c.asnumpy() against a NumPy reference convolution once
        # the padding convention used by the MPS 'SAME' mode is pinned down.

    verify(A, B, C, s1)


if __name__ == "__main__":
    # Run both tests; each one skips itself when Metal/MPS is unavailable.
    test_matmul()
    test_conv2d()