# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import numpy as np


def test_reduce_prims():
    def test_prim(reducer, np_reducer):
        # graph
        n = tvm.te.size_var('n')
        m = tvm.te.size_var('m')
        A = te.placeholder((n, m), name='A')
        R = te.compute((n, ), lambda i: tvm.tir.Select((i > 1), 1, 0), name='R')
        k = te.reduce_axis((0, m))
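        # Reduce over k only where R[i] == 1 (i.e. rows with i > 1); masked
        # rows keep the reducer's identity element, which the check below
        # accounts for by zeroing the first two entries.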
        B = te.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(R[i]==1)), name='B')
        # schedule
        s = te.create_schedule(B.op)
        # create iter var and assign them tags.
        num_thread = 1
        xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
        s[B].bind(xo, te.thread_axis("blockIdx.x"))
        s[B].bind(xi, te.thread_axis("threadIdx.x"))
        s[R].compute_inline()

        # one line to build the function.
        def check_device(device, host="llvm"):
            ctx = tvm.context(device, 0)
            if not tvm.runtime.enabled(host):
                return
            if not ctx.exist:
                print("skip because %s is not enabled.." % device)
                return
            freduce = tvm.build(s,
                                args=[A, B],
                                target=device, target_host=host,
                                name="myreduce")
            # launch the kernel.
            n = 1028
            m = 129
            x = tvm.nd.array(np.random.uniform(size=(n, m)).astype(A.dtype), ctx)
            y = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
            freduce(x, y)
            npy = y.asnumpy()
            npy[:2] = 0
            res = np_reducer(x.asnumpy(), axis=1)
            res[:2] = 0
            tvm.testing.assert_allclose(npy, res, rtol=1e-4)

        check_device("metal")
        check_device("vulkan")
        check_device("cuda")
        check_device("opencl")
    test_prim(te.sum, np.sum)
    test_prim(tvm.te.min, np.amin)
    test_prim(tvm.te.max, np.amax)


def test_rfactor():
    n = tvm.runtime.convert(1027)
    A = te.placeholder((n,), name='A')
    k = te.reduce_axis((0, n))
    B = te.compute((1,), lambda i: te.sum(A[k], axis=k), name='B')
    # schedule
    s = te.create_schedule(B.op)
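    # rfactor splits the reduction into two stages: BF holds four partial
    # sums (one per outer chunk of k) that can run in parallel, and B then
    # reduces those partial results into the final sum.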
    kf, ki = s[B].split(k, nparts=4)
    BF = s.rfactor(B, kf)
    s[BF].parallel(BF.op.axis[0])
    # one line to build the function.
    def check_target(target="llvm"):
        if not tvm.runtime.enabled(target):
            return
        ctx = tvm.cpu(0)
        fapi = tvm.lower(s, args=[A, B])
        fsum = tvm.build(fapi,
                         target=target,
                         name="mysum")
        # launch the kernel.
        n = 1027
        a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
        b  = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
        fsum(a, b)
        res = np.sum(a.asnumpy(), axis=0)
        tvm.testing.assert_allclose(
            b.asnumpy(), res, rtol=1e-4)

    check_target()


def test_rfactor_factor_axis():
    n = tvm.runtime.convert(1027)
    A = te.placeholder((n,), name='A')
    k = te.reduce_axis((0, n))
    B = te.compute((1,), lambda i: te.sum(A[k], axis=k), name='B')
    # schedule
    s = te.create_schedule(B.op)
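    # Same as test_rfactor, but factor_axis=1 places the factored axis as
    # the last dimension of BF instead of the first.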
    kf, ki = s[B].split(k, nparts=4)
    BF = s.rfactor(B, kf, 1)
    s[BF].parallel(BF.op.axis[0])
    # one line to build the function.
    def check_target(target="llvm"):
        if not tvm.runtime.enabled(target):
            return
        ctx = tvm.cpu(0)
        fapi = tvm.lower(s, args=[A, B])
        fsum = tvm.build(fapi,
                         target=target,
                         name="mysum")
        # launch the kernel.
        n = 1027
        a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
        b  = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
        fsum(a, b)
        res = np.sum(a.asnumpy(), axis=0)
        tvm.testing.assert_allclose(
            b.asnumpy(), res, rtol=1e-4)

    check_target()


def test_rfactor_threads():
    nn = 1027
    mm = 10
    n = tvm.runtime.convert(nn)
    m = tvm.runtime.convert(mm)
    A = te.placeholder((m, n), name='A')
    k = te.reduce_axis((0, n))
    nthread = 16
    B = te.compute((m,), lambda i: te.sum(A[i, k], axis=k, where=(i>1)), name='B')
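    # Rows 0 and 1 are masked out by the where predicate, so the check
    # below zeroes the first two entries of the reference result.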
    # schedule
    s = te.create_schedule(B.op)
    ko, kf = s[B].split(k, factor=nthread)
    BF = s.rfactor(B, kf)
    bx, ty = s[B].split(s[B].op.axis[0], factor=nthread)
    s[B].bind(bx, te.thread_axis("blockIdx.x"))
    s[B].bind(ty, te.thread_axis("threadIdx.y"))
    tx = s[B].op.reduce_axis[0]
    thread_x = te.thread_axis("threadIdx.x")
    s[B].bind(tx, thread_x)
    s[BF].compute_at(s[B], tx)
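    # The factored reduction runs cooperatively across threadIdx.x; the
    # store predicate below lets only lane 0 write the result back to B.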
    s[B].set_store_predicate(thread_x.var.equal(0))

    # one line to build the function.
    def check_target(device, host="stackvm"):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return

        fapi = tvm.lower(s, args=[A, B])
        fsum = tvm.build(fapi,
                         target=device,
                         name="mysum")
        # launch the kernel.
        n = nn
        m = mm
        a = tvm.nd.array(np.random.uniform(size=(m, n)).astype(A.dtype), ctx)
        b  = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
        fsum(a, b)
        res = np.sum(a.asnumpy(), axis=1)
        res[:2] = 0
        tvm.testing.assert_allclose(
            b.asnumpy(), res, rtol=1e-4)

    check_target("vulkan")
    check_target("cuda")
    check_target("metal")
    check_target("opencl")


def test_rfactor_elemwise_threads():
    n = 1025
    m = 10
    A = te.placeholder((m, n), name='A')
    k = te.reduce_axis((0, n))
    nthread = 16
    B = te.compute((m,), lambda i: te.sum(A[i, k], axis=k), name='B')
    BB = te.compute((m,), lambda i: B[i] + 1, name='BB')
    C = te.compute((m,), lambda i: BB[i] + 1, name='C')
    # schedule
    s = te.create_schedule(C.op)
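    # Inline BB into C and attach the rfactor-ed reduction under C's
    # thread axes, so the two elementwise +1 stages fuse with the
    # cross-thread reduction (expected result: row sum + 2).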
    s[BB].compute_inline()
    bx, ty = s[C].split(s[C].op.axis[0], factor=nthread)
    ko, kf = s[B].split(k, factor=nthread)
    BF = s.rfactor(B, kf)
    s[B].compute_at(s[C], ty)
    s[C].bind(bx, te.thread_axis("blockIdx.x"))
    s[C].bind(ty, te.thread_axis("threadIdx.y"))
    tx = s[B].op.reduce_axis[0]
    thread_x = te.thread_axis("threadIdx.x")
    s[B].bind(tx, thread_x)
    s[BF].compute_at(s[B], tx)
    # Since thread_x is shared across reductions,
    # only one of them needs to do the write back.
    s[B].set_store_predicate(thread_x.var.equal(0))
    s[C].set_store_predicate(thread_x.var.equal(0))

    # one line to build the function.
    def check_target(device, host="stackvm"):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return
        fapi = tvm.lower(s, args=[A, C])
        fsum = tvm.build(fapi,
                         target=device,
                         name="mysum")
        # launch the kernel.
        a = tvm.nd.array(np.random.uniform(size=(m, n)).astype(A.dtype), ctx)
        b  = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
        fsum(a, b)
        res = np.sum(a.asnumpy(), axis=1) + 2
        tvm.testing.assert_allclose(
            b.asnumpy(), res, rtol=1e-4)

    check_target("vulkan")
    check_target("cuda")
    check_target("metal")
    check_target("opencl")


def test_argmax():
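    # argmax is expressed as a commutative reducer over (index, value)
    # pairs: fcombine keeps the pair with the larger value, and fidentity
    # supplies (-1, min_value) as the identity element.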
    def fcombine(x, y):
        lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
        rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
        return lhs, rhs

    def fidentity(t0, t1):
        return tvm.tir.const(-1, t0), tvm.te.min_value(t1)

    argmax = te.comm_reducer(fcombine,
                             fidentity,
                             name='argmax')
    m = te.size_var('m')
    n = te.size_var('n')
    idx = te.placeholder((m, n), name='idx', dtype='int32')
    val = te.placeholder((m, n), name='val', dtype='float32')
    k = te.reduce_axis((0, n), 'k')
    T0, T1 = te.compute((m,), lambda i: argmax((idx[i,k], val[i,k]), axis=k), name='T')
    s = te.create_schedule(T0.op)

    def check_target():
        device = 'cpu'
        if not tvm.runtime.enabled(device):
            print("skip because %s is not enabled.." % device)
            return
        ctx = tvm.context(device, 0)
        fapi = tvm.lower(s, args=[idx, val, T0, T1])
        fargmax = tvm.build(fapi,
                            target='llvm',
                            name="argmax")

        mm = 12
        nn = 16
        np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0)
        np_val = np.random.uniform(size=(mm, nn)).astype('float32')
        np_res = np.argmax(np_val, axis=1)

        nd_idx  = tvm.nd.array(np_idx, ctx)
        nd_val  = tvm.nd.array(np_val, ctx)
        nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
        nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
        fargmax(nd_idx, nd_val, nd_res0, nd_res1)
        tvm.testing.assert_allclose(np_res, nd_res0.asnumpy())

    check_target()


def test_rfactor_argmax():
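    # Same tuple-valued argmax reducer as above; rfactor on a multi-output
    # reduction returns one factored tensor per output (here BF0, BF1).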
    def fcombine(x, y):
        lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
        rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
        return lhs, rhs

    def fidentity(t0, t1):
        return tvm.tir.const(-1, t0), tvm.te.min_value(t1)

    argmax = te.comm_reducer(fcombine,
                             fidentity,
                             name='argmax')

    nn = 1027
    mm = 10
    n = tvm.runtime.convert(nn)
    m = tvm.runtime.convert(mm)
    A0 = te.placeholder((m, n), name='A0', dtype='int32')
    A1 = te.placeholder((m, n), name='A1', dtype='float32')
    k = te.reduce_axis((0, n))
    B0, B1 = te.compute((m,), lambda i: argmax((A0[i, k], A1[i, k]), axis=k), name='B')

    # schedule
    s = te.create_schedule(B0.op)
    nthread = 16
    ko, kf = s[B0].split(k, factor=nthread)
    BF0, BF1 = s.rfactor(B0, kf)
    bx, ty = s[B0].split(s[B0].op.axis[0], factor=nthread)
    s[B0].bind(bx, te.thread_axis("blockIdx.x"))
    s[B0].bind(ty, te.thread_axis("threadIdx.y"))
    tx = s[B0].op.reduce_axis[0]
    thread_x = te.thread_axis("threadIdx.x")
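    # As in the earlier rfactor tests, the factored reduction runs across
    # threadIdx.x and only lane 0 writes the (index, value) pair back.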
    s[B0].bind(tx, thread_x)
    s[BF0.op].compute_at(s[B0], tx)
    s[B0].set_store_predicate(thread_x.var.equal(0))

    def check_target(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return
        fapi = tvm.lower(s, args=[A0, A1, B0, B1])
        fargmax = tvm.build(fapi,
                            target=device,
                            name="argmax")

        np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0)
        np_val = np.random.uniform(size=(mm, nn)).astype('float32')
        np_res = np.argmax(np_val, axis=1)

        nd_idx  = tvm.nd.array(np_idx, ctx)
        nd_val  = tvm.nd.array(np_val, ctx)
        nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
        nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
        fargmax(nd_idx, nd_val, nd_res0, nd_res1)
        tvm.testing.assert_allclose(np_res, nd_res0.asnumpy())

    check_target("cuda")
    check_target("vulkan")


if __name__ == "__main__":
    test_rfactor_elemwise_threads()
    test_rfactor_threads()
    test_rfactor_factor_axis()
    test_rfactor()
    test_reduce_prims()
    test_argmax()
    test_rfactor_argmax()