"""test_reduce.py: tests for TVM reduce primitives, rfactor and argmax reductions."""
import tvm
import numpy as np

def test_reduce_prims():
    """Check the sum, min and max reducers against their NumPy counterparts.

    Each reducer is applied along the second axis of a symbolic ``(n, m)``
    placeholder under the predicate ``i > 1``, built for the metal, cuda and
    opencl targets (when enabled), and compared with NumPy on random data.
    """
    def test_prim(reducer, np_reducer):
        """Build a row-wise reduction with ``reducer`` and verify it with ``np_reducer``.

        Parameters
        ----------
        reducer : callable
            Reducer used inside ``tvm.compute``; called as
            ``reducer(expr, axis=..., where=...)``.
        np_reducer : callable
            NumPy reference; called as ``np_reducer(array, axis=1)``.
        """
        # graph
        n = tvm.var('n')
        m = tvm.var('m')
        A = tvm.placeholder((n, m), name='A')
        k = tvm.reduce_axis((0, m))
        B = tvm.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(i>1)), name='B')
        # schedule
        s = tvm.create_schedule(B.op)
        # create iter var and assign them tags.
        num_thread = 1
        xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
        s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
        s[B].bind(xi, tvm.thread_axis("threadIdx.x"))

        # one line to build the function.
        def check_device(device, host="stackvm"):
            """Build for ``device``/``host`` and compare the result with NumPy.

            Returns without checking when either the host or the device
            module is not enabled.
            """
            if not tvm.module.enabled(host):
                return
            if not tvm.module.enabled(device):
                print("skip because %s is not enabled.." % device)
                return
            ctx = tvm.context(device, 0)
            freduce = tvm.build(s,
                                args=[A, B],
                                target=device, target_host=host,
                                name="myreduce")
            # launch the kernel.
            # Concrete sizes; these locals shadow the symbolic n and m above.
            n = 1028
            m = 129
            x = tvm.nd.array(np.random.uniform(size=(n, m)).astype(A.dtype), ctx)
            y = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
            freduce(x, y)
            # Rows 0 and 1 fail the ``i > 1`` predicate, so both the kernel
            # output and the reference are zeroed there before comparing.
            npy = y.asnumpy()
            npy[:2] = 0
            res = np_reducer(x.asnumpy(), axis=1)
            res[:2] = 0
            np.testing.assert_allclose(npy, res, rtol=1e-4)

        check_device("metal")
        check_device("cuda")
        check_device("opencl")
    test_prim(tvm.sum, np.sum)
    test_prim(tvm.min, np.amin)
    test_prim(tvm.max, np.amax)


def test_rfactor():
    """Verify a sum reduction whose outer split axis is factored with ``rfactor``.

    A 1027-element sum is split into 4 parts along the reduce axis, the outer
    part is factored into its own stage, that stage's first axis is marked
    parallel, and the built function is compared with ``np.sum``.
    """
    n = tvm.convert(1027)
    A = tvm.placeholder((n,), name='A')
    k = tvm.reduce_axis((0, n))
    B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')
    # schedule
    s = tvm.create_schedule(B.op)
    kf, ki = s[B].split(k, nparts=4)
    BF = s.rfactor(B, kf)
    s[BF].parallel(BF.op.axis[0])

    # one line to build the function.
    def check_target(target="llvm"):
        """Build for ``target`` and check the result; return early if it is not enabled."""
        if not tvm.module.enabled(target):
            return
        ctx = tvm.cpu(0)
        fapi = tvm.lower(s, args=[A, B])
        fsum = tvm.build(fapi,
                         target=target,
                         name="mysum")
        # launch the kernel.
        # Concrete length; this local shadows the outer n.
        n = 1027
        a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
        fsum(a, b)
        res = np.sum(a.asnumpy(), axis=0)
        np.testing.assert_allclose(
            b.asnumpy(), res, rtol=1e-4)

    check_target()


def test_rfactor_threads():
    """Verify an ``rfactor``-ed row sum whose reduce axis is bound to ``threadIdx.x``.

    Each row of a (10, 1027) input is summed under the predicate ``i > 1``.
    The reduce axis is split by 16 and factored; the output axis is bound to
    ``blockIdx.x``/``threadIdx.y`` and the remaining reduce axis to
    ``threadIdx.x``. The result is compared with ``np.sum`` on cuda, metal
    and opencl when those targets are enabled.
    """
    nn = 1027
    mm = 10
    n = tvm.convert(nn)
    m = tvm.convert(mm)
    A = tvm.placeholder((m, n), name='A')
    k = tvm.reduce_axis((0, n))
    nthread = 16
    B = tvm.compute((m,), lambda i: tvm.sum(A[i, k], axis=k, where=(i>1)), name='B')
    # schedule
    s = tvm.create_schedule(B.op)
    ko, kf = s[B].split(k, factor=nthread)
    BF = s.rfactor(B, kf)
    bx, ty = s[B].split(s[B].op.axis[0], factor=nthread)
    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B].bind(ty, tvm.thread_axis("threadIdx.y"))
    tx = s[B].op.reduce_axis[0]
    thread_x = tvm.thread_axis("threadIdx.x")
    s[B].bind(tx, thread_x)
    s[BF].compute_at(s[B], tx)
    # Store the result only where threadIdx.x == 0.
    s[B].set_store_predicate(thread_x.var.equal(0))

    # one line to build the function.
    def check_target(device, host="stackvm"):
        """Build for ``device`` and check the result; skip if it is not enabled.

        ``host`` is accepted but not used by this body.
        """
        if not tvm.module.enabled(device):
            print("skip because %s is not enabled.." % device)
            return
        ctx = tvm.context(device, 0)
        fapi = tvm.lower(s, args=[A, B])
        fsum = tvm.build(fapi,
                         target=device,
                         name="mysum")
        # launch the kernel.
        n = nn
        m = mm
        a = tvm.nd.array(np.random.uniform(size=(m, n)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
        fsum(a, b)
        res = np.sum(a.asnumpy(), axis=1)
        # Rows 0 and 1 fail the ``i > 1`` predicate; the reference is zeroed
        # there to match the zero-initialised output buffer.
        res[:2] = 0
        np.testing.assert_allclose(
            b.asnumpy(), res, rtol=1e-4)

    check_target("cuda")
    check_target("metal")
    check_target("opencl")

def test_argmax():
    """Run a custom (index, value) argmax reducer on llvm and compare with ``np.argmax``."""
    def fcombine(x, y):
        # Pick the (index, value) pair of x when x's value is >= y's, else y's pair.
        picked_idx = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
        picked_val = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
        return picked_idx, picked_val

    def fidentity(t0, t1):
        # Identity pair: index -1 together with the minimum value of dtype t1.
        return tvm.const(-1, t0), tvm.min_value(t1)

    argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax')

    m = tvm.var('m')
    n = tvm.var('n')
    idx = tvm.placeholder((m, n), name='idx', dtype='int32')
    val = tvm.placeholder((m, n), name='val', dtype='float32')
    k = tvm.reduce_axis((0, n), 'k')

    def row_argmax(i):
        return argmax((idx[i, k], val[i, k]), axis=k)

    T0, T1 = tvm.compute((m,), row_argmax, name='T')
    s = tvm.create_schedule(T0.op)

    def check_target():
        """Build with the llvm target and compare the index output with NumPy."""
        device = 'cpu'
        if not tvm.module.enabled(device):
            print("skip because %s is not enabled.." % device)
            return

        ctx = tvm.context(device, 0)
        lowered = tvm.lower(s, args=[idx, val, T0, T1])
        fargmax = tvm.build(lowered, target='llvm', name="argmax")

        rows, cols = 12, 16
        # Every row of the index input is 0..cols-1, so the reduced index
        # can be compared directly with np.argmax.
        np_idx = np.tile(np.arange(cols, dtype='int32'), (rows, 1))
        np_val = np.random.uniform(size=(rows, cols)).astype('float32')
        expected = np.argmax(np_val, axis=1)

        nd_idx = tvm.nd.array(np_idx, ctx)
        nd_val = tvm.nd.array(np_val, ctx)
        out_idx = tvm.nd.array(np.zeros(rows, dtype='int32'), ctx)
        out_val = tvm.nd.array(np.zeros(rows, dtype='float32'), ctx)
        fargmax(nd_idx, nd_val, out_idx, out_val)
        np.testing.assert_allclose(expected, out_idx.asnumpy())

    check_target()


def test_rfactor_argmax():
    """Run the custom argmax reducer through ``rfactor`` with thread binding on cuda."""
    def fcombine(x, y):
        # Pick the (index, value) pair of x when x's value is >= y's, else y's pair.
        picked_idx = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
        picked_val = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
        return picked_idx, picked_val

    def fidentity(t0, t1):
        # Identity pair: index -1 together with the minimum value of dtype t1.
        return tvm.const(-1, t0), tvm.min_value(t1)

    argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax')

    nn = 1027
    mm = 10
    n = tvm.convert(nn)
    m = tvm.convert(mm)
    A0 = tvm.placeholder((m, n), name='A0', dtype='int32')
    A1 = tvm.placeholder((m, n), name='A1', dtype='float32')
    k = tvm.reduce_axis((0, n))

    def row_argmax(i):
        return argmax((A0[i, k], A1[i, k]), axis=k)

    B0, B1 = tvm.compute((m,), row_argmax, name='B')

    # schedule
    s = tvm.create_schedule(B0.op)
    nthread = 16
    ko, kf = s[B0].split(k, factor=nthread)
    BF0, BF1 = s.rfactor(B0, kf)
    bx, ty = s[B0].split(s[B0].op.axis[0], factor=nthread)
    s[B0].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B0].bind(ty, tvm.thread_axis("threadIdx.y"))
    tx = s[B0].op.reduce_axis[0]
    thread_x = tvm.thread_axis("threadIdx.x")
    s[B0].bind(tx, thread_x)
    s[BF0.op].compute_at(s[B0], tx)
    # Store the result only where threadIdx.x == 0.
    s[B0].set_store_predicate(thread_x.var.equal(0))

    def check_target(device):
        """Build for ``device`` and compare the index output with ``np.argmax``."""
        if not tvm.module.enabled(device):
            print("skip because %s is not enabled.." % device)
            return

        ctx = tvm.context(device, 0)
        lowered = tvm.lower(s, args=[A0, A1, B0, B1])
        fargmax = tvm.build(lowered, target=device, name="argmax")

        # Every row of the index input is 0..nn-1.
        np_idx = np.tile(np.arange(nn, dtype='int32'), (mm, 1))
        np_val = np.random.uniform(size=(mm, nn)).astype('float32')
        expected = np.argmax(np_val, axis=1)

        nd_idx = tvm.nd.array(np_idx, ctx)
        nd_val = tvm.nd.array(np_val, ctx)
        out_idx = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
        out_val = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
        fargmax(nd_idx, nd_val, out_idx, out_val)
        np.testing.assert_allclose(expected, out_idx.asnumpy())

    check_target("cuda")

if __name__ == "__main__":
    # Run every test in this file when it is executed as a script.
    test_rfactor_threads()
    test_rfactor()
    test_reduce_prims()
    test_argmax()
    test_rfactor_argmax()