# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm import te
import numpy as np
from tvm.contrib import cublas
from tvm.contrib import cublaslt

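# Check cublas.matmul against a numpy.dot reference for the given input/output dtypes.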
def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5):
    n = 1024
    l = 128
    m = 236
    A = te.placeholder((n, l), name='A', dtype=in_dtype)
    B = te.placeholder((l, m), name='B', dtype=in_dtype)
    C = cublas.matmul(A, B, dtype=out_dtype)
    s = te.create_schedule(C.op)

    def verify(target="cuda"):
        if not tvm.runtime.enabled(target):
            print("skip because %s is not enabled..." % target)
            return
        if not tvm.get_global_func("tvm.contrib.cublas.matmul", True):
            print("skip because extern function is not available")
            return
        ctx = tvm.gpu(0)
        f = tvm.build(s, [A, B, C], target)
        a = tvm.nd.array(np.random.uniform(0, 128, size=(n, l)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(0, 128, size=(l, m)).astype(B.dtype), ctx)
        c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
        f(a, b, c)
        tvm.testing.assert_allclose(
            c.asnumpy(), np.dot(a.asnumpy().astype(C.dtype), b.asnumpy().astype(C.dtype)), rtol=rtol)
    verify()

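# Round v up to the nearest multiple of d.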
def roundoff(v, d):
    return int(np.floor((v + d - 1) / d) * d)

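# Check cublaslt.matmul (int8 IGEMM) against a numpy.dot reference; inputs and
# outputs are padded and reordered into the cuBLASLt interleaved layouts.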
def verify_matmul_add_igemm(in_dtype, out_dtype, rtol=1e-5):
    n = 1024
    l = 1024
    m = 1024
    L = roundoff(l, 32)
    N = roundoff(n, 8)
    N_out = roundoff(n, 32)

    A = te.placeholder((N, L), name='A', dtype=in_dtype)
    B = te.placeholder((m, L), name='B', dtype=in_dtype)
    # C has CUBLASLT_ORDER_COL32 layout, thus a different shape
    C = cublaslt.matmul(A, B, False, True, m, N_out, dtype=out_dtype)
    s = te.create_schedule(C.op)

    def verify(target="cuda"):
        if not tvm.runtime.enabled(target):
            print("skip because %s is not enabled..." % target)
            return
        if not tvm.get_global_func("tvm.contrib.cublaslt.matmul", True):
            print("skip because extern function is not available")
            return
        ctx = tvm.gpu(0)
        f = tvm.build(s, [A, B, C], target)
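        # Reference operands in plain row-major layout, kept for the final check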
        a_old = np.random.uniform(0, 128, size=(n, l))
        b_old = np.random.uniform(0, 128, size=(l, m))

        # Transform a to become CUBLASLT_ORDER_COL4_4R2_8C layout
        a_new = np.hstack((a_old.astype(A.dtype), np.zeros([n, L-l])))
        a_new = np.vstack((a_new.astype(A.dtype), np.zeros([N-n, L])))
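        # Interleave row blocks: within each group of 8 rows, the 4 even rows
        # are placed before the 4 odd rows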
        a_even = np.vsplit(a_new[::2], N // 8)
        a_odd = np.vsplit(a_new[1::2], N // 8)
        a_new = [None]*(len(a_even) + len(a_odd))
        a_new[::2] = a_even
        a_new[1::2] = a_odd
        a_new = np.vstack(a_new)
        a_new = np.vstack([np.vstack([np.vstack(np.hsplit(i, 8)).reshape([4, 32])
                                      for i in np.vsplit(j, N // 4)])
                           for j in np.hsplit(a_new, L // 32)])
        a_new = a_new.reshape([N, L])
        # Transform b to become CUBLASLT_ORDER_COL32 layout
        b_new = np.vstack(np.hsplit(np.hstack((b_old.T.astype(B.dtype), np.zeros([m, L - l]))), L // 32))
        b_new = b_new.reshape([m, L])

        a = tvm.nd.array(a_new.astype(A.dtype), ctx)
        b = tvm.nd.array(b_new.astype(B.dtype), ctx)
        c = tvm.nd.array(np.zeros((m, N_out), dtype=C.dtype), ctx)
        f(a, b, c)
        # Transform output c from layout CUBLASLT_ORDER_COL32 to row major layout
        c_out = c.asnumpy()
        c_out = c_out.reshape([int(m * N_out / 32), 32])
        c_out = np.hstack(np.vsplit(c_out, int(N_out / 32)))
        c_out = c_out[:, :n]
        c_out = c_out.T
        tvm.testing.assert_allclose(
            c_out, np.dot(a_old.astype(C.dtype), b_old.astype(C.dtype)), rtol=rtol)
    verify()

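# Check cublas.batch_matmul against a numpy.matmul reference for the given dtypes.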
def verify_batch_matmul(in_dtype, out_dtype, rtol=1e-5):
    j = 16
    n = 1024
    l = 128
    m = 236
    A = te.placeholder((j, n, l), name='A', dtype=in_dtype)
    B = te.placeholder((j, l, m), name='B', dtype=in_dtype)
    C = cublas.batch_matmul(A, B, dtype=out_dtype)
    s = te.create_schedule(C.op)

    def verify(target="cuda"):
        if not tvm.runtime.enabled(target):
            print("skip because %s is not enabled..." % target)
            return
        if not tvm.get_global_func("tvm.contrib.cublas.matmul", True):
            print("skip because extern function is not available")
            return
        ctx = tvm.gpu(0)
        f = tvm.build(s, [A, B, C], target)
        a = tvm.nd.array(np.random.uniform(size=(j, n, l)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=(j, l, m)).astype(B.dtype), ctx)
        c = tvm.nd.array(np.zeros((j, n, m), dtype=C.dtype), ctx)
        f(a, b, c)
        tvm.testing.assert_allclose(
            c.asnumpy(), np.matmul(a.asnumpy().astype(C.dtype),
                                   b.asnumpy().astype(C.dtype)).astype(C.dtype), rtol=rtol)
    verify()

def test_matmul_add():
    verify_matmul_add('float', 'float', rtol=1e-3)
    verify_matmul_add('float16', 'float')
    verify_matmul_add('float16', 'float16', rtol=1e-2)
    verify_matmul_add('int8', 'int32')

def test_matmul_add_igemm():
    verify_matmul_add_igemm('int8', 'int32')

def test_batch_matmul():
    verify_batch_matmul('float', 'float')
    verify_batch_matmul('float16', 'float')
    verify_batch_matmul('float16', 'float16', rtol=1e-2)

if __name__ == "__main__":
    test_matmul_add()
    test_batch_matmul()
    test_matmul_add_igemm()