import tvm
import numpy as np


def test_reduce_prims():
    def test_prim(reducer, np_reducer):
        # graph
        n = tvm.var('n')
        m = tvm.var('m')
        A = tvm.placeholder((n, m), name='A')
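        # R is a per-row flag: 1 when i > 1, 0 otherwise; it drives the
        # where clause of the reduction below.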
        R = tvm.compute((n, ), lambda i: tvm.select((i > 1), 1, 0), name='R')
        k = tvm.reduce_axis((0, m))
        B = tvm.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(R[i]==1)), name='B')
        # schedule
        s = tvm.create_schedule(B.op)
        # create iter vars and bind them to GPU thread tags.
        num_thread = 1
        xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
        s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
        s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
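        # inline R so the mask is evaluated directly inside B's reduction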
        s[R].compute_inline()

        # build the function and check the result on the given device.
        def check_device(device, host="stackvm"):
            ctx = tvm.context(device, 0)
            if not tvm.module.enabled(host):
                return
            if not ctx.exist:
                print("skip because %s is not enabled.." % device)
                return
            freduce = tvm.build(s,
                                args=[A, B],
                                target=device, target_host=host,
                                name="myreduce")
            # launch the kernel.
            n = 1028
            m = 129
            x = tvm.nd.array(np.random.uniform(size=(n, m)).astype(A.dtype), ctx)
            y = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
            freduce(x, y)
            npy = y.asnumpy()
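            # rows 0 and 1 are masked out by the where clause, so zero them
            # on both sides before comparing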
            npy[:2] = 0
            res = np_reducer(x.asnumpy(), axis=1)
            res[:2] = 0
            np.testing.assert_allclose(npy, res, rtol=1e-4)

        check_device("metal")
        check_device("vulkan")
        check_device("cuda")
        check_device("opencl")
    test_prim(tvm.sum, np.sum)
    test_prim(tvm.min, np.amin)
    test_prim(tvm.max, np.amax)


def test_rfactor():
    n = tvm.convert(1027)
    A = tvm.placeholder((n,), name='A')
    k = tvm.reduce_axis((0, n))
    B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')
    # schedule
    s = tvm.create_schedule(B.op)
    kf, ki = s[B].split(k, nparts=4)
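    # rfactor turns the outer pieces of the split reduction axis into a new
    # spatial axis: BF holds the four partial sums and B reduces them.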
    BF = s.rfactor(B, kf)
    s[BF].parallel(BF.op.axis[0])
    # build the function and check the result on the given target.
    def check_target(target="llvm"):
        if not tvm.module.enabled(target):
            return
        ctx = tvm.cpu(0)
        fapi = tvm.lower(s, args=[A, B])
        fsum = tvm.build(fapi,
                         target=target,
                         name="mysum")
        # launch the kernel.
        n = 1027
        a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
        fsum(a, b)
        res = np.sum(a.asnumpy(), axis=0)
        np.testing.assert_allclose(
            b.asnumpy(), res, rtol=1e-4)

    check_target()

def test_rfactor_factor_axis():
    n = tvm.convert(1027)
    A = tvm.placeholder((n,), name='A')
    k = tvm.reduce_axis((0, n))
    B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')
    # schedule
    s = tvm.create_schedule(B.op)
    kf, ki = s[B].split(k, nparts=4)
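    # the third argument asks rfactor to place the factored axis at position 1
    # of BF instead of the default position 0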
    BF = s.rfactor(B, kf, 1)
    s[BF].parallel(BF.op.axis[0])
    # build the function and check the result on the given target.
    def check_target(target="llvm"):
        if not tvm.module.enabled(target):
            return
        ctx = tvm.cpu(0)
        fapi = tvm.lower(s, args=[A, B])
        fsum = tvm.build(fapi,
                         target=target,
                         name="mysum")
        # launch the kernel.
        n = 1027
        a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
        fsum(a, b)
        res = np.sum(a.asnumpy(), axis=0)
        np.testing.assert_allclose(
            b.asnumpy(), res, rtol=1e-4)

    check_target()


def test_rfactor_threads():
    nn = 1027
    mm = 10
    n = tvm.convert(nn)
    m = tvm.convert(mm)
    A = tvm.placeholder((m, n), name='A')
    k = tvm.reduce_axis((0, n))
    nthread = 16
    B = tvm.compute((m,), lambda i: tvm.sum(A[i, k], axis=k, where=(i>1)), name='B')
    # schedule
    s = tvm.create_schedule(B.op)
    ko, kf = s[B].split(k, factor=nthread)
    BF = s.rfactor(B, kf)
    bx, ty = s[B].split(s[B].op.axis[0], factor=nthread)
    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B].bind(ty, tvm.thread_axis("threadIdx.y"))
    tx = s[B].op.reduce_axis[0]
    thread_x = tvm.thread_axis("threadIdx.x")
    s[B].bind(tx, thread_x)
    s[BF].compute_at(s[B], tx)
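    # every thread along threadIdx.x ends up with the same reduced value,
    # so only thread 0 needs to write the result back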
    s[B].set_store_predicate(thread_x.var.equal(0))

    # build the function and check the result on the given device.
    def check_target(device, host="stackvm"):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return

        fapi = tvm.lower(s, args=[A, B])
        fsum = tvm.build(fapi,
                         target=device,
                         name="mysum")
        # launch the kernel.
        n = nn
        m = mm
        a = tvm.nd.array(np.random.uniform(size=(m, n)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
        fsum(a, b)
        res = np.sum(a.asnumpy(), axis=1)
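        # rows 0 and 1 are skipped by the where clause in B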
        res[:2] = 0
        np.testing.assert_allclose(
            b.asnumpy(), res, rtol=1e-4)

    check_target("vulkan")
    check_target("cuda")
    check_target("metal")
    check_target("opencl")


def test_rfactor_elemwise_threads():
    n = 1025
    m = 10
    A = tvm.placeholder((m, n), name='A')
    k = tvm.reduce_axis((0, n))
    nthread = 16
    B = tvm.compute((m,), lambda i: tvm.sum(A[i, k], axis=k), name='B')
    BB = tvm.compute((m,), lambda i: B[i] + 1, name='BB')
    C = tvm.compute((m,), lambda i: BB[i] + 1, name='C')
    # schedule
    s = tvm.create_schedule(C.op)
    s[BB].compute_inline()
    bx, ty = s[C].split(s[C].op.axis[0], factor=nthread)
    ko, kf = s[B].split(k, factor=nthread)
    BF = s.rfactor(B, kf)
    s[B].compute_at(s[C], ty)
    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(ty, tvm.thread_axis("threadIdx.y"))
    tx = s[B].op.reduce_axis[0]
    thread_x = tvm.thread_axis("threadIdx.x")
    s[B].bind(tx, thread_x)
    s[BF].compute_at(s[B], tx)
    # Since thread_x is shared across the reduction,
    # only one of the threads needs to do the write back.
    s[B].set_store_predicate(thread_x.var.equal(0))
    s[C].set_store_predicate(thread_x.var.equal(0))

    # build the function and check the result on the given device.
    def check_target(device, host="stackvm"):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return
        fapi = tvm.lower(s, args=[A, C])
        fsum = tvm.build(fapi,
                         target=device,
                         name="mysum")
        # launch the kernel.
        a = tvm.nd.array(np.random.uniform(size=(m, n)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
        fsum(a, b)
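        # BB and C each add 1 on top of the row sums, hence the +2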
        res = np.sum(a.asnumpy(), axis=1) + 2
        np.testing.assert_allclose(
            b.asnumpy(), res, rtol=1e-4)

    check_target("vulkan")
    check_target("cuda")
    check_target("metal")
    check_target("opencl")

def test_argmax():
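    # express argmax as a commutative reducer over (index, value) pairs:
    # fcombine keeps the pair with the larger value, fidentity supplies the
    # identity element (-1, minimum of the value dtype)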
    def fcombine(x, y):
        lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
        rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
        return lhs, rhs

    def fidentity(t0, t1):
        return tvm.const(-1, t0), tvm.min_value(t1)

    argmax = tvm.comm_reducer(fcombine,
                              fidentity,
                              name='argmax')
    m = tvm.var('m')
    n = tvm.var('n')
    idx = tvm.placeholder((m, n), name='idx', dtype='int32')
    val = tvm.placeholder((m, n), name='val', dtype='float32')
    k = tvm.reduce_axis((0, n), 'k')
    T0, T1 = tvm.compute((m,), lambda i: argmax((idx[i,k], val[i,k]), axis=k), name='T')
    s = tvm.create_schedule(T0.op)

    def check_target():
        device = 'cpu'
        if not tvm.module.enabled(device):
            print("skip because %s is not enabled.." % device)
            return
        ctx = tvm.context(device, 0)
        fapi = tvm.lower(s, args=[idx, val, T0, T1])
        fargmax = tvm.build(fapi,
                            target='llvm',
                            name="argmax")

        mm = 12
        nn = 16
        np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0)
        np_val = np.random.uniform(size=(mm, nn)).astype('float32')
        np_res = np.argmax(np_val, axis=1)

        nd_idx  = tvm.nd.array(np_idx, ctx)
        nd_val  = tvm.nd.array(np_val, ctx)
        nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
        nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
        fargmax(nd_idx, nd_val, nd_res0, nd_res1)
        np.testing.assert_allclose(np_res, nd_res0.asnumpy())

    check_target()


def test_rfactor_argmax():
    def fcombine(x, y):
        lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
        rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
        return lhs, rhs

    def fidentity(t0, t1):
        return tvm.const(-1, t0), tvm.min_value(t1)

    argmax = tvm.comm_reducer(fcombine,
                              fidentity,
                              name='argmax')

    nn = 1027
    mm = 10
    n = tvm.convert(nn)
    m = tvm.convert(mm)
    A0 = tvm.placeholder((m, n), name='A0', dtype='int32')
    A1 = tvm.placeholder((m, n), name='A1', dtype='float32')
    k = tvm.reduce_axis((0, n))
    B0, B1 = tvm.compute((m,), lambda i: argmax((A0[i, k], A1[i, k]), axis=k), name='B')

    # schedule
    s = tvm.create_schedule(B0.op)
    nthread = 16
    ko, kf = s[B0].split(k, factor=nthread)
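    # rfactor on a multi-output reduction returns one factored tensor per output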
    BF0, BF1 = s.rfactor(B0, kf)
    bx, ty = s[B0].split(s[B0].op.axis[0], factor=nthread)
    s[B0].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B0].bind(ty, tvm.thread_axis("threadIdx.y"))
    tx = s[B0].op.reduce_axis[0]
    thread_x = tvm.thread_axis("threadIdx.x")
    s[B0].bind(tx, thread_x)
    s[BF0.op].compute_at(s[B0], tx)
    s[B0].set_store_predicate(thread_x.var.equal(0))

    def check_target(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return
        fapi = tvm.lower(s, args=[A0, A1, B0, B1])
        fargmax = tvm.build(fapi,
                            target=device,
                            name="argmax")

        np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0)
        np_val = np.random.uniform(size=(mm, nn)).astype('float32')
        np_res = np.argmax(np_val, axis=1)

        nd_idx  = tvm.nd.array(np_idx, ctx)
        nd_val  = tvm.nd.array(np_val, ctx)
        nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
        nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
        fargmax(nd_idx, nd_val, nd_res0, nd_res1)
        np.testing.assert_allclose(np_res, nd_res0.asnumpy())

    check_target("cuda")
    check_target("vulkan")

if __name__ == "__main__":
    test_rfactor_elemwise_threads()
    test_rfactor_threads()
    test_rfactor_factor_axis()
    test_rfactor()
    test_reduce_prims()
    test_argmax()
    test_rfactor_argmax()