# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np

def test_reduce_prims():
    """Check sum/min/max reductions guarded by a ``where`` predicate.

    For each (TVM reducer, NumPy reference) pair, a row-wise reduction over an
    (n, m) placeholder is built. Its predicate comes from an inlined compute
    ``R`` that equals 1 only for rows with index > 1. The kernel is run on each
    GPU-style backend that exists, and the result is compared with NumPy after
    zeroing the first two rows on both sides.
    """
    def check_prim(reducer, np_reducer):
        # Declare the computation.
        rows = tvm.var('n')
        cols = tvm.var('m')
        A = tvm.placeholder((rows, cols), name='A')
        R = tvm.compute((rows, ), lambda i: tvm.expr.Select((i > 1), 1, 0), name='R')
        k = tvm.reduce_axis((0, cols))
        B = tvm.compute((rows,), lambda i: reducer(A[i, k], axis=k, where=(R[i]==1)), name='B')

        # Schedule: one thread per block along the output axis; R is inlined
        # into B's predicate.
        sch = tvm.create_schedule(B.op)
        num_thread = 1
        outer, inner = sch[B].split(B.op.axis[0], factor=num_thread)
        sch[B].bind(outer, tvm.thread_axis("blockIdx.x"))
        sch[B].bind(inner, tvm.thread_axis("threadIdx.x"))
        sch[R].compute_inline()

        def check_device(device, host="stackvm"):
            """Build for ``device`` (host code via ``host``) and compare to NumPy."""
            ctx = tvm.context(device, 0)
            if not tvm.module.enabled(host):
                return
            if not ctx.exist:
                print("skip because %s is not enabled.." % device)
                return
            freduce = tvm.build(sch, args=[A, B], target=device,
                                target_host=host, name="myreduce")
            n_rows, n_cols = 1028, 129
            x = tvm.nd.array(np.random.uniform(size=(n_rows, n_cols)).astype(A.dtype), ctx)
            y = tvm.nd.array(np.zeros(n_rows, dtype=B.dtype), ctx)
            freduce(x, y)
            # Rows 0 and 1 fail the (i > 1) predicate, so they are zeroed on
            # both sides before comparing.
            got = y.asnumpy()
            got[:2] = 0
            expected = np_reducer(x.asnumpy(), axis=1)
            expected[:2] = 0
            tvm.testing.assert_allclose(got, expected, rtol=1e-4)

        for device in ("metal", "vulkan", "cuda", "opencl"):
            check_device(device)

    for reducer, np_reducer in ((tvm.sum, np.sum),
                                (tvm.min, np.amin),
                                (tvm.max, np.amax)):
        check_prim(reducer, np_reducer)


def test_rfactor():
    """Rfactor a 1-D sum and parallelize the factored stage, checked on LLVM.

    The reduction axis is split into 4 outer parts. The outer part is
    rfactored into ``BF``, and ``BF``'s first axis is parallelized.
    """
    n = tvm.convert(1027)
    A = tvm.placeholder((n,), name='A')
    k = tvm.reduce_axis((0, n))
    B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')

    # Schedule: factor the reduction over the outer split axis.
    s = tvm.create_schedule(B.op)
    kf, _ = s[B].split(k, nparts=4)
    BF = s.rfactor(B, kf)
    s[BF].parallel(BF.op.axis[0])

    def check_target(target="llvm"):
        """Build for ``target`` (skipped when it is not enabled) and compare to NumPy."""
        if not tvm.module.enabled(target):
            return
        ctx = tvm.cpu(0)
        fsum = tvm.build(tvm.lower(s, args=[A, B]), target=target, name="mysum")
        length = 1027
        a = tvm.nd.array(np.random.uniform(size=(length,)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
        fsum(a, b)
        expected = np.sum(a.asnumpy(), axis=0)
        tvm.testing.assert_allclose(b.asnumpy(), expected, rtol=1e-4)

    check_target()

def test_rfactor_factor_axis():
    """Same as ``test_rfactor``, but with a third argument of 1 passed to ``rfactor``.

    NOTE(review): the third positional argument is presumably the position
    (factor_axis) of the new factored dimension in ``BF`` — confirm against
    the ``Schedule.rfactor`` signature.
    """
    n = tvm.convert(1027)
    A = tvm.placeholder((n,), name='A')
    k = tvm.reduce_axis((0, n))
    B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')

    # Schedule: factor the reduction over the outer split axis.
    s = tvm.create_schedule(B.op)
    kf, _ = s[B].split(k, nparts=4)
    BF = s.rfactor(B, kf, 1)
    s[BF].parallel(BF.op.axis[0])

    def check_target(target="llvm"):
        """Build for ``target`` (skipped when it is not enabled) and compare to NumPy."""
        if not tvm.module.enabled(target):
            return
        ctx = tvm.cpu(0)
        fsum = tvm.build(tvm.lower(s, args=[A, B]), target=target, name="mysum")
        length = 1027
        a = tvm.nd.array(np.random.uniform(size=(length,)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
        fsum(a, b)
        expected = np.sum(a.asnumpy(), axis=0)
        tvm.testing.assert_allclose(b.asnumpy(), expected, rtol=1e-4)

    check_target()


def test_rfactor_threads():
    """Cross-thread row-sum reduction using rfactor and threadIdx.x binding.

    The inner reduction split is rfactored into ``BF``. The remaining
    reduction axis of ``B`` is bound to threadIdx.x, and only thread 0 stores
    the result. Rows with index <= 1 are masked out by ``where=(i > 1)``.
    """
    nn = 1027
    mm = 10
    n = tvm.convert(nn)
    m = tvm.convert(mm)
    A = tvm.placeholder((m, n), name='A')
    k = tvm.reduce_axis((0, n))
    nthread = 16
    B = tvm.compute((m,), lambda i: tvm.sum(A[i, k], axis=k, where=(i>1)), name='B')

    # Schedule.
    s = tvm.create_schedule(B.op)
    _, kf = s[B].split(k, factor=nthread)
    BF = s.rfactor(B, kf)
    bx, ty = s[B].split(s[B].op.axis[0], factor=nthread)
    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B].bind(ty, tvm.thread_axis("threadIdx.y"))
    tx = s[B].op.reduce_axis[0]
    thread_x = tvm.thread_axis("threadIdx.x")
    s[B].bind(tx, thread_x)
    s[BF].compute_at(s[B], tx)
    # Every threadIdx.x lane holds the reduced value; let only lane 0 store it.
    s[B].set_store_predicate(thread_x.var.equal(0))

    def check_target(device, host="stackvm"):
        """Build for ``device`` if it exists and compare to NumPy."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return
        fsum = tvm.build(tvm.lower(s, args=[A, B]), target=device, name="mysum")
        a = tvm.nd.array(np.random.uniform(size=(mm, nn)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(mm, dtype=B.dtype), ctx)
        fsum(a, b)
        expected = np.sum(a.asnumpy(), axis=1)
        # Rows 0 and 1 are excluded by where=(i > 1); they are expected to be 0.
        expected[:2] = 0
        tvm.testing.assert_allclose(b.asnumpy(), expected, rtol=1e-4)

    for device in ("vulkan", "cuda", "metal", "opencl"):
        check_target(device)


def test_rfactor_elemwise_threads():
    """Cross-thread rfactor reduction feeding elementwise consumers.

    ``B`` is a row sum. ``BB`` (inlined) and ``C`` each add 1, so ``C`` equals
    the row sum plus 2. ``B`` is computed at ``C``'s threadIdx.y axis, and its
    reduction is spread over threadIdx.x.
    """
    n = 1025
    m = 10
    A = tvm.placeholder((m, n), name='A')
    k = tvm.reduce_axis((0, n))
    nthread = 16
    B = tvm.compute((m,), lambda i: tvm.sum(A[i, k], axis=k), name='B')
    BB = tvm.compute((m,), lambda i: B[i] + 1, name='BB')
    C = tvm.compute((m,), lambda i: BB[i] + 1, name='C')

    # Schedule.
    s = tvm.create_schedule(C.op)
    s[BB].compute_inline()
    bx, ty = s[C].split(s[C].op.axis[0], factor=nthread)
    _, kf = s[B].split(k, factor=nthread)
    BF = s.rfactor(B, kf)
    s[B].compute_at(s[C], ty)
    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(ty, tvm.thread_axis("threadIdx.y"))
    tx = s[B].op.reduce_axis[0]
    thread_x = tvm.thread_axis("threadIdx.x")
    s[B].bind(tx, thread_x)
    s[BF].compute_at(s[B], tx)
    # threadIdx.x is shared by both stages, so only lane 0 of each writes back.
    s[B].set_store_predicate(thread_x.var.equal(0))
    s[C].set_store_predicate(thread_x.var.equal(0))

    def check_target(device, host="stackvm"):
        """Build for ``device`` if it exists and compare C against NumPy."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return
        fsum = tvm.build(tvm.lower(s, args=[A, C]), target=device, name="mysum")
        a = tvm.nd.array(np.random.uniform(size=(m, n)).astype(A.dtype), ctx)
        c = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
        fsum(a, c)
        expected = np.sum(a.asnumpy(), axis=1) + 2
        tvm.testing.assert_allclose(c.asnumpy(), expected, rtol=1e-4)

    for device in ("vulkan", "cuda", "metal", "opencl"):
        check_target(device)

def test_argmax():
    """Check a user-defined argmax ``comm_reducer`` over (index, value) pairs.

    The combiner keeps the pair with the larger value; on a tie (``>=``) it
    keeps the left operand. The identity is ``(-1, min_value)``. The kernel is
    compiled with LLVM and compared with ``np.argmax``.
    """
    def fcombine(x, y):
        lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
        rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
        return lhs, rhs

    def fidentity(t0, t1):
        return tvm.const(-1, t0), tvm.min_value(t1)

    argmax = tvm.comm_reducer(fcombine,
                              fidentity,
                              name='argmax')
    m = tvm.var('m')
    n = tvm.var('n')
    idx = tvm.placeholder((m, n), name='idx', dtype='int32')
    val = tvm.placeholder((m, n), name='val', dtype='float32')
    k = tvm.reduce_axis((0, n), 'k')
    T0, T1 = tvm.compute((m,), lambda i: argmax((idx[i,k], val[i,k]), axis=k), name='T')
    s = tvm.create_schedule(T0.op)

    def check_target():
        # The kernel is compiled with the LLVM backend, so LLVM is what must be
        # available. The previous guard checked 'cpu', which is always
        # reported as enabled, so it never skipped on builds without LLVM.
        target = 'llvm'
        if not tvm.module.enabled(target):
            print("skip because %s is not enabled.." % target)
            return
        ctx = tvm.context('cpu', 0)
        fapi = tvm.lower(s, args=[idx, val, T0, T1])
        fargmax = tvm.build(fapi,
                            target=target,
                            name="argmax")

        mm = 12
        nn = 16
        # Each row of indices is 0..nn-1, so the reduced index equals the argmax.
        np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0)
        np_val = np.random.uniform(size=(mm, nn)).astype('float32')
        np_res = np.argmax(np_val, axis=1)

        nd_idx = tvm.nd.array(np_idx, ctx)
        nd_val = tvm.nd.array(np_val, ctx)
        nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
        nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
        fargmax(nd_idx, nd_val, nd_res0, nd_res1)
        tvm.testing.assert_allclose(np_res, nd_res0.asnumpy())

    check_target()


def test_rfactor_argmax():
    """Cross-thread rfactor of a tuple-valued argmax reducer on GPU backends.

    Uses the same (index, value) argmax combiner as ``test_argmax``. The inner
    reduction split is rfactored and bound to threadIdx.x, and only lane 0
    stores the result.
    """
    def fcombine(x, y):
        take_left = (x[1] >= y[1])
        return (tvm.make.Select(take_left, x[0], y[0]),
                tvm.make.Select(take_left, x[1], y[1]))

    def fidentity(t0, t1):
        return tvm.const(-1, t0), tvm.min_value(t1)

    argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax')

    nn = 1027
    mm = 10
    n = tvm.convert(nn)
    m = tvm.convert(mm)
    A0 = tvm.placeholder((m, n), name='A0', dtype='int32')
    A1 = tvm.placeholder((m, n), name='A1', dtype='float32')
    k = tvm.reduce_axis((0, n))
    B0, B1 = tvm.compute((m,), lambda i: argmax((A0[i, k], A1[i, k]), axis=k), name='B')

    # Schedule.
    s = tvm.create_schedule(B0.op)
    nthread = 16
    _, kf = s[B0].split(k, factor=nthread)
    BF0, _ = s.rfactor(B0, kf)
    bx, ty = s[B0].split(s[B0].op.axis[0], factor=nthread)
    s[B0].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B0].bind(ty, tvm.thread_axis("threadIdx.y"))
    tx = s[B0].op.reduce_axis[0]
    thread_x = tvm.thread_axis("threadIdx.x")
    s[B0].bind(tx, thread_x)
    s[BF0.op].compute_at(s[B0], tx)
    s[B0].set_store_predicate(thread_x.var.equal(0))

    def check_target(device):
        """Build for ``device`` if it exists and compare indices to ``np.argmax``."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return
        fargmax = tvm.build(tvm.lower(s, args=[A0, A1, B0, B1]),
                            target=device, name="argmax")

        # Each row of indices is 0..nn-1, so the reduced index equals the argmax.
        np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0)
        np_val = np.random.uniform(size=(mm, nn)).astype('float32')
        np_res = np.argmax(np_val, axis=1)

        nd_idx = tvm.nd.array(np_idx, ctx)
        nd_val = tvm.nd.array(np_val, ctx)
        nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
        nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
        fargmax(nd_idx, nd_val, nd_res0, nd_res1)
        tvm.testing.assert_allclose(np_res, nd_res0.asnumpy())

    for device in ("cuda", "vulkan"):
        check_target(device)


if __name__ == "__main__":
    # Run every test in this module in a fixed order when executed as a script.
    for _test in (test_rfactor_elemwise_threads,
                  test_rfactor_threads,
                  test_rfactor_factor_axis,
                  test_rfactor,
                  test_reduce_prims,
                  test_argmax,
                  test_rfactor_argmax):
        _test()