# test_te_schedule_bound_inference.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
import tvm.testing
from tvm import te


def test_bound1():
    """InferBound on a producer attached at the outer loop of a factor-8 split.

    A2's first axis is split by 8 and A1 is computed at the outer loop ``xo``,
    so the A1 region required per ``xo`` iteration is expected to be 8 rows.
    """
    m = te.var('m')
    l = te.var('l')
    A = te.placeholder((m, l), name='A')
    A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')

    s = te.create_schedule([A2.op])
    xo, xi = s[A2].split(s[A2].op.axis[0], 8)
    s[A1].compute_at(s[A2], xo)
    bounds = tvm.te.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    assert bounds[A1.op.axis[0]].extent.value == 8


def test_bound2():
    """InferBound on a producer attached inside an 8x8 tile.

    A2 is tiled 8x8 and A1 is computed at ``yo`` (the outer loop of the second
    axis), so the A1 region needed per ``yo`` iteration is an 8x8 block.
    """
    m = te.var('m')
    l = te.var('l')
    A = te.placeholder((m, l), name='A')
    A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
    s = te.create_schedule(A2.op)
    xo, yo, xi, yi = s[A2].tile(A2.op.axis[0], A2.op.axis[1], 8, 8)
    # normalize() returns a new schedule that is discarded here; calling it
    # must not affect ``s``.
    _ = s.normalize()
    s[A1].compute_at(s[A2], yo)
    bounds = tvm.te.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    assert bounds[A1.op.axis[0]].extent.value == 8
    assert bounds[A1.op.axis[1]].extent.value == 8


def test_bound3():
    """InferBound for a shared-scope producer under a thread-bound split.

    A2's first axis is split by 32, and the inner part is split again into 16
    parts bound to threadIdx.x. A1 (scope "shared") is computed at ``yo``.
    The expected A1 bound covers the whole 32-row outer tile, which includes
    the thread-bound ``xi0`` range. It also covers the 16 columns of the
    ``yo``/``yi`` split.
    """
    m = te.var('m')
    l = te.var('l')
    A = te.placeholder((m, l), name='A')
    A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')

    s = te.create_schedule(A2.op)
    s[A1].set_scope("shared")
    xo, xi = s[A2].split(A2.op.axis[0], 32)
    xi0, xi1 = s[A2].split(xi, nparts=16)
    s[A2].bind(xi0, te.thread_axis("threadIdx.x"))
    yo, yi = s[A2].split(A2.op.axis[1], 16)
    # normalize() returns a new schedule that is discarded here; calling it
    # must not affect ``s``.
    _ = s.normalize()
    s[A2].reorder(xo, xi0, yo, xi1, yi)
    s[A1].compute_at(s[A2], yo)

    bounds = tvm.te.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    assert bounds[A1.op.axis[0]].extent.value == 32
    assert bounds[A1.op.axis[1]].extent.value == 16


def test_bound_split_ext_less_than_factor():
    """Splitting an extent-``m`` axis by a larger factor clamps the inner loop.

    The axis has extent 8 and is split with ``factor=32``. The inferred extent
    of the inner loop ``xi`` must be the axis extent ``m``, not the factor.
    """
    m = 8
    I = te.placeholder((m,), name='I')
    EF = te.compute((m,), lambda i: I[i] * 2, name="EF")
    E = te.compute((m,), lambda i: EF[i] * 2, name="E")
    s = te.create_schedule([E.op])
    xo, xi = s[E].split(s[E].op.axis[0], factor=32)
    s[EF].compute_at(s[E], xo)

    bounds = tvm.te.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    assert bounds[xi].extent.value == m


def test_bound_split_ext_less_than_naprts():
    """Splitting an extent-``m`` axis into more parts than ``m`` clamps the outer loop.

    The axis has extent 8 and is split with ``nparts=32``. The inferred extent
    of the outer loop ``xo`` must be the axis extent ``m``.

    NOTE: "naprts" in the test name is a typo for "nparts"; the name is kept
    so existing test IDs stay stable.
    """
    m = 8
    I = te.placeholder((m,), name='I')
    EF = te.compute((m,), lambda i: I[i] * 2, name="EF")
    E = te.compute((m,), lambda i: EF[i] * 2, name="E")
    s = te.create_schedule([E.op])
    xo, xi = s[E].split(s[E].op.axis[0], nparts=32)
    s[EF].compute_at(s[E], xo)

    bounds = tvm.te.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    assert bounds[xo].extent.value == m


def test_bound_split_divisible():
    """Splitting an ``8 * m`` extent by 8 gives exactly ``m`` outer iterations.

    The assertion expects the outer extent to be the variable ``m`` itself,
    and the inner extent to be the factor 8.
    """
    m = te.var('m')
    l = te.var('l')
    A = te.placeholder((8 * m, l), name='A')
    B = te.compute((8 * m, l), lambda i, j: A[i, j], name='B')
    s = te.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], 8)
    bounds = tvm.te.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    assert bounds[xo].extent == m
    assert bounds[xi].extent.value == 8


def test_bound_tile_divisible():
    """Tiling an ``(8 * m, 32 * l)`` shape by (8, 32) gives exact outer extents.

    The outer tile loops are expected to have extents ``m`` and ``l``. The
    inner loops are expected to have the tile sizes 8 and 32.
    """
    m = te.var('m')
    l = te.var('l')
    shape = (8 * m, 32 * l)
    A = te.placeholder(shape, name='A')
    B = te.compute(shape, lambda i, j: A[i, j], name='B')
    s = te.create_schedule(B.op)
    xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], 8, 32)
    bounds = tvm.te.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    assert bounds[xo].extent == m
    assert bounds[xi].extent.value == 8
    assert bounds[yo].extent == l
    assert bounds[yi].extent.value == 32


def test_bound_fusesplit1():
    """InferBound through a fuse followed by a split with a symbolic factor.

    A2's two axes are fused and then split by the variable ``split1``. A1 is
    computed at the outer loop ``xo``. The test checks three bounds:

    - the minimum of A1's row bound, compared symbolically;
    - the row extent, compared after substituting small concrete values for
      ``split1``, ``l`` and ``xo``;
    - the column extent, which must be the full width ``l``.
    """
    m = te.var('m')
    l = te.var('l')
    split1 = te.var('s')
    A = te.placeholder((m, l), name='A')
    A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')

    s = te.create_schedule(A2.op)
    fused_axes = s[A2].fuse(A2.op.axis[0], A2.op.axis[1])
    xo, xi = s[A2].split(fused_axes, split1)
    s[A1].compute_at(s[A2], xo)

    bounds = tvm.te.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    idxdiv = tvm.tir.indexdiv
    # The first fused index handled by xo is xo * split1, which lies in row
    # (xo * split1) // l.
    tvm.testing.assert_prim_expr_equal(
        bounds[A1.op.axis[0]].min, idxdiv(xo * split1, l))

    # Number of rows touched by fused indices xo*split1 .. (xo+1)*split1 - 1.
    expected_extent = (idxdiv((xo + 1) * split1 - 1, l) - idxdiv(xo * split1, l) + 1)
    row_extent = bounds[A1.op.axis[0]].extent
    # Compare the extents numerically over split1, l, xo in 1..5.
    for i in range(1, 6):
        for j in range(1, 6):
            for k in range(1, 6):
                binding = tvm.runtime.convert({
                    split1: tvm.tir.const(i, "int32"),
                    l: tvm.tir.const(j, "int32"),
                    xo.var: tvm.tir.const(k, "int32"),
                })
                tvm.testing.assert_prim_expr_equal(
                    tvm.tir.stmt_functor.substitute(row_extent, binding),
                    tvm.tir.stmt_functor.substitute(expected_extent, binding),
                )

    tvm.testing.assert_prim_expr_equal(bounds[A1.op.axis[1]].extent, l)


def test_bound_fusesplit2():
    """InferBound through a fuse followed by a split with constant sizes (l=6, split=3).

    For ``xo = 5`` the fused indices are 15..17. All of them lie in row 2
    (15 // 6) and span columns 3..5 (15 % 6 through 17 % 6) of A1.
    """
    m = te.var("m")
    l = tvm.runtime.convert(6)
    split = tvm.runtime.convert(3)
    A = te.placeholder((m, l), name='A')
    A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')

    s = te.create_schedule(A2.op)
    fused_axes = s[A2].fuse(A2.op.axis[0], A2.op.axis[1])
    xo, xi = s[A2].split(fused_axes, split)
    s[A1].compute_at(s[A2], xo)

    bounds = tvm.te.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    binding = tvm.runtime.convert({xo.var: tvm.tir.const(5, "int32")})

    def _at_xo5(expr):
        # Evaluate a bound expression with xo fixed to 5.
        return tvm.tir.stmt_functor.substitute(expr, binding)

    tvm.testing.assert_prim_expr_equal(_at_xo5(bounds[A1.op.axis[0]].min), 2)
    tvm.testing.assert_prim_expr_equal(_at_xo5(bounds[A1.op.axis[1]].min), 3)
    tvm.testing.assert_prim_expr_equal(_at_xo5(bounds[A1.op.axis[0]].extent), 1)
    tvm.testing.assert_prim_expr_equal(_at_xo5(bounds[A1.op.axis[1]].extent), 3)


def test_bound_warp():
    """InferBound for a "warp"-scope producer bound to the consumer's threadIdx.x.

    A2's first axis is split by 32 and then by factor 16. The 16-wide inner
    part is bound to threadIdx.x and the outer part to threadIdx.y. A1
    (scope "warp") is computed at A2's second axis. A1's first axis is also
    split by 16 with the inner part bound to the same threadIdx.x. The
    expected A1 row extent is 16.
    """
    m = te.var('m')
    l = te.var('l')
    A = te.placeholder((m, l), name='A')
    A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')

    s = te.create_schedule(A2.op)
    s[A1].set_scope("warp")
    xo, xi = s[A2].split(A2.op.axis[0], 32)
    xi0, xi1 = s[A2].split(xi, factor=16)
    tx = te.thread_axis("threadIdx.x")
    s[A2].bind(xi1, tx)
    s[A2].bind(xi0, te.thread_axis("threadIdx.y"))
    y = s[A2].op.axis[1]
    s[A1].compute_at(s[A2], y)
    # Split/bind A1's own rows; distinct names avoid confusion with A2's xo/xi.
    a1_outer, a1_inner = s[A1].split(s[A1].op.axis[0], factor=16)
    s[A1].bind(a1_inner, tx)
    bounds = tvm.te.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    assert bounds[A1.op.axis[0]].extent.value == 16


def test_bound_scan():
    """InferBound for a "local" cache_read attached inside a scan update.

    X is cached in "local" scope for ``s_update``, whose second axis is split
    by 4. The cache is computed at the outer split loop, so its second-axis
    extent is expected to be 4.
    """
    m = te.var("m")
    n = te.var("n")
    X = te.compute((m, n), lambda i, j: tvm.tir.const(1, "float32"), name="x")
    s_state = te.placeholder((m, n))
    s_init = te.compute((1, n), lambda _, i: X[0, i])
    s_update = te.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
    s_scan = tvm.te.scan(s_init, s_update, s_state)

    assert tuple(s_scan.shape) == (m, n)
    s = te.create_schedule(s_scan.op)
    XX = s.cache_read(X, "local", s_update)
    xo, xi = s[s_update].split(s_update.op.axis[1], factor=4)
    s[XX].compute_at(s[s_update], xo)
    s = s.normalize()
    bounds = tvm.te.schedule.InferBound(s)
    # Smoke check: ScheduleOps must accept these bounds; its result is not inspected.
    tvm.te.schedule.ScheduleOps(s, bounds)
    assert bounds[XX.op.axis[1]].extent.value == 4


def test_bound_conv1d():
    """A 3-tap stencil consumer needs a 3-element window of its producer.

    B[ii] reads A[ii], A[ii + 1] and A[ii + 2]. With A computed at B's only
    axis, A's inferred extent is 3.
    """
    n = te.var('n')
    # Explicit 1-tuples: a bare ``(n + 2)`` is just a parenthesized expression.
    A = te.compute((n + 2,), lambda i: 1, name='A')

    def computeB(ii):
        # Shift by one so the 3-tap window A[i-1..i+1] stays inside A.
        i = ii + 1
        return A[i-1] + A[i] + A[i+1]

    B = te.compute((n,), computeB, name='B')
    s = te.create_schedule(B.op)
    s[A].compute_at(s[B], B.op.axis[0])
    s = s.normalize()
    bounds = tvm.te.schedule.InferBound(s)
    assert bounds[A.op.axis[0]].extent.value == 3


def test_bound_blur():
    """A 5-point stencil consumer needs a 3x3 window of its producer.

    B (shape ``(n-2, n-2)``) reads A at a centre point and its four
    neighbours. With A computed at B's second axis, both of A's inferred
    extents are 3.
    """
    n = tvm.runtime.convert(12)
    A = te.compute((n, n), lambda i, j: 1, name='A')

    def computeB(ii, jj):
        # Shift by one so the stencil centre keeps all neighbours inside A.
        i = ii + 1
        j = jj + 1
        return A[i][j] + A[i-1][j] + A[i+1][j] + A[i][j+1] + A[i][j-1]

    B = te.compute((n-2, n-2), computeB, name='B')
    s = te.create_schedule(B.op)
    s[A].compute_at(s[B], B.op.axis[1])
    s = s.normalize()
    bounds = tvm.te.schedule.InferBound(s)
    assert bounds[A.op.axis[0]].extent.value == 3
    assert bounds[A.op.axis[1]].extent.value == 3


def test_bound_rfactor():
    """InferBound on the rfactor stage of a reduction split into 4 parts.

    The reduction axis ``k`` is split with ``nparts=4`` and the outer part is
    factored out with ``rfactor``. BF's first axis is expected to have extent
    4, matching ``nparts``. Its second axis is expected to have extent 1,
    matching B's shape ``(1,)``.
    """
    n = te.var('n')
    A = te.placeholder((n,), name='A')
    k = te.reduce_axis((0, n))
    B = te.compute((1,), lambda i: te.sum(A[k], axis=k, where=(i>1)), name='B')
    # schedule
    s = te.create_schedule(B.op)
    kf, ki = s[B].split(k, nparts=4)
    BF = s.rfactor(B, kf)
    s = s.normalize()
    bounds = tvm.te.schedule.InferBound(s)

    assert bounds[BF.op.axis[0]].extent.value == 4
    assert bounds[BF.op.axis[1]].extent.value == 1


def test_bound_group_schedule():
    """InferBound when x and x1 form one group attached at x2's first axis.

    The group is created with ``include_inputs=True``, so it contains both x1
    and its input x. Attached at x2's first axis, x is needed one row at a
    time: extent 1 on the first axis and ``n`` on the second.
    """
    m = te.var("m")
    n = te.var("n")
    x = te.compute((m, n), lambda i, j: tvm.tir.const(1, "float32"), name="x")
    x1 = te.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
    x2 = te.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
    s = te.create_schedule(x2.op)
    g = s.create_group(outputs=x1, inputs=x, include_inputs=True)
    g.compute_at(s[x2], x2.op.axis[0])
    assert s[x1].group == g
    assert s[x].group == g
    s = s.normalize()
    bounds = tvm.te.schedule.InferBound(s)
    assert bounds[x.op.axis[0]].extent.value == 1
    assert bounds[x.op.axis[1]].extent == n


def test_bound_nest_group():
    """InferBound with two groups attached at different levels.

    ``g1`` is created around x and ``g2`` around x1 (with x as input).
    ``g2`` is attached at x2's first axis and ``g1`` at x1's second axis.
    The expected bounds are:

    - x: one element (1 x 1);
    - x1: one row (1 x ``n``).
    """
    m = te.var("m")
    n = te.var("n")
    x = te.compute((m, n), lambda i, j: tvm.tir.const(1, "float32"), name="x")
    x1 = te.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
    x2 = te.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
    s = te.create_schedule(x2.op)
    g1 = s.create_group(outputs=x, inputs=x, include_inputs=True)
    g2 = s.create_group(outputs=x1, inputs=x, include_inputs=True)
    assert s[x].group == g1
    assert s[x1].group == g2
    g2.compute_at(s[x2], x2.op.axis[0])
    g1.compute_at(s[x1], s[x1].op.axis[1])
    s = s.normalize()
    bounds = tvm.te.schedule.InferBound(s)
    assert bounds[x.op.axis[0]].extent.value == 1
    assert bounds[x.op.axis[1]].extent.value == 1
    assert bounds[x1.op.axis[0]].extent.value == 1
    assert bounds[x1.op.axis[1]].extent == n


def test_bound_nest_thread():
    """InferBound for "local"/"shared" producers attached at a thread-bound loop.

    A3 is split by 32 into a blockIdx.x loop and a threadIdx.x loop. A2
    (scope "shared") and A1 (scope "local") are both computed at the
    threadIdx.x loop. A2's own axis is bound to the same threadIdx.x, so its
    expected extent is 32. A1's expected extent is 1.
    """
    m = te.var('m')
    # Explicit 1-tuple: a bare ``(m)`` is just a parenthesized expression.
    A = te.placeholder((m,), name='A')
    A1 = te.compute((m,), lambda i: A[i], name='A1')
    A2 = te.compute((m,), lambda i: A1[i] + 2, name='A2')
    A3 = te.compute((m,), lambda i: A2[i] + 3, name='A3')

    s = te.create_schedule(A3.op)
    s[A2].set_scope("shared")
    s[A1].set_scope("local")

    block_x = te.thread_axis("blockIdx.x")
    thread_x = te.thread_axis("threadIdx.x")
    bx, tx = s[A3].split(A3.op.axis[0], factor=32)
    s[A3].bind(bx, block_x)
    s[A3].bind(tx, thread_x)
    s[A2].compute_at(s[A3], tx)
    _, xi = s[A2].split(A2.op.axis[0], nparts=1)
    s[A2].bind(xi, thread_x)
    s[A1].compute_at(s[A3], tx)
    s = s.normalize()
    bounds = tvm.te.schedule.InferBound(s)
    assert bounds[A1.op.axis[0]].extent.value == 1
    assert bounds[A2.op.axis[0]].extent.value == 32
    assert bounds[A3.op.axis[0]].extent == m


def test_gemm_bound():
    """InferBound for a GPU-style GEMM schedule with shared and local caches.

    C (``nn`` x ``nn``) is tiled into ``block_factor`` x ``block_factor``
    blocks bound to blockIdx.y/x. Each block is split into a
    ``num_thread`` x ``num_thread`` grid bound to threadIdx.y/x. The caches
    are attached as follows:

    - CC ("local"): computed at threadIdx.x;
    - AA and BB ("shared"): computed at CC's reduction axis ``k``.

    Expected bounds: AA/BB rows span a whole block tile (``block_factor``),
    and CC covers ``scale`` x ``scale`` elements.
    """
    nn = 1024
    n = tvm.runtime.convert(nn)
    A = te.placeholder((n, n), name='A')
    B = te.placeholder((n, n), name='B')
    k = te.reduce_axis((0, n), name='k')
    C = te.compute(
        (n, n),
        lambda ii, jj: te.sum(A[ii, k] * B[jj, k], axis=k),
        name='CC')
    # schedule
    s = te.create_schedule(C.op)
    scale = 8
    num_thread = 8
    block_factor = scale * num_thread
    block_x = te.thread_axis("blockIdx.x")
    thread_x = te.thread_axis("threadIdx.x")
    block_y = te.thread_axis("blockIdx.y")
    thread_y = te.thread_axis("threadIdx.y")

    CC = s.cache_write(C, "local")
    AA = s.cache_read(A, "shared", [CC])
    BB = s.cache_read(B, "shared", [CC])
    by, yi = s[C].split(C.op.axis[0], factor=block_factor)
    bx, xi = s[C].split(C.op.axis[1], factor=block_factor)
    s[C].reorder(by, bx, yi, xi)
    s[C].bind(by, block_y)
    s[C].bind(bx, block_x)
    ty, yi = s[C].split(yi, nparts=num_thread)
    tx, xi = s[C].split(xi, nparts=num_thread)
    s[C].reorder(ty, tx, yi, xi)
    s[C].bind(ty, thread_y)
    s[C].bind(tx, thread_x)
    yo, xo = CC.op.axis
    s[CC].reorder(k, yo, xo)

    s[CC].compute_at(s[C], tx)
    s[AA].compute_at(s[CC], k)
    s[BB].compute_at(s[CC], k)

    ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
    tx, xi = s[AA].split(xi, nparts=num_thread)
    s[AA].bind(ty, thread_y)
    s[AA].bind(tx, thread_x)

    ty, xi = s[BB].split(s[BB].op.axis[0], nparts=num_thread)
    tx, xi = s[BB].split(xi, nparts=num_thread)
    s[BB].bind(ty, thread_y)
    s[BB].bind(tx, thread_x)
    s = s.normalize()
    bounds = tvm.te.schedule.InferBound(s)
    # block_factor == 64 and scale == 8 with the settings above.
    assert bounds[BB.op.axis[0]].extent.value == block_factor
    assert bounds[AA.op.axis[0]].extent.value == block_factor
    assert bounds[CC.op.axis[0]].extent.value == scale
    assert bounds[CC.op.axis[1]].extent.value == scale


def test_bound_tensor_compute_op():
    """InferBound through a compute op built from a tensor intrinsic.

    ``intrin_test`` declares an intrinsic whose single output row is the sum
    of three input rows. C (shape ``(10, 20)``) applies it to the slice
    ``B[i:10, 0:20]``. The expected inferred extent of B's first axis is 10.
    """
    def intrin_test():
        m1 = te.var("m1")
        n1 = te.var("n1")
        a = te.placeholder((m1, n1), name='a')
        c = te.compute((1, n1), lambda i, j: a[0, j] + a[1, j] + a[2, j], name='c')

        Ab = tvm.tir.decl_buffer(a.shape, name="Abuf", offset_factor=1)
        Cb = tvm.tir.decl_buffer(c.shape, name="Cbuf", offset_factor=1)

        def intrin_func(ins, outs):
            # Emit a single extern call taking the output and input pointers.
            aa = ins[0]
            cc = outs[0]
            ib = tvm.tir.ir_builder.create()
            ib.emit(tvm.tir.call_extern("int32", "test", cc.access_ptr("w"), aa.access_ptr("r")))
            return ib.get()

        with tvm.target.build_config(offset_factor=1):
            return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, c: Cb})

    test_func = intrin_test()
    A = te.placeholder((20, 20), name='A')
    B = te.compute(A.shape, lambda i, j: A[i, j], name='B')
    C = te.compute((10, 20), lambda i: test_func(B[i:10, 0:20]), name='C')
    s = te.create_schedule(C.op)
    bounds = tvm.te.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    assert bounds[B.op.axis[0]].extent.value == 10


def test_bound_simplification_failure():
    """Check that inferred bounds are not widened by hard-to-simplify indices.

    Each consumer reads the 2-element tensor A through an index expression.
    The inferred extent of A must stay <= 2, including for index expressions
    that the simplifier does not reduce.
    """
    A = te.compute((2,), lambda j: j, "A")

    def _check(B, A=A):
        s = te.create_schedule(B.op)
        s = s.normalize()
        bounds = tvm.te.schedule.InferBound(s)
        # Always lower (this also exercises lowering); the IR is shown on failure.
        stmt = tvm.lower(s, [B, A], simple_mode=True)
        extent = bounds[A.op.axis[0]].extent.value
        assert extent <= 2, str(stmt)

    tdiv = tvm.tir.truncdiv
    # These are hard to simplify, moreover we don't simplify them
    _check(te.compute((10,), lambda i: A[tvm.te.min(3*i, 4*i) + tvm.te.min(-3*i, -2*i)]))
    _check(te.compute((10,), lambda i: A[tvm.te.min(3*i, 4*i) + tvm.te.max(-3*i, -4*i)]))
    _check(te.compute((10,), lambda i: A[-2*tdiv(i,2) - tvm.te.min(i, 0-i)]))
    _check(te.compute((10,), lambda i: A[i + (0 - i)]))
    # This would cause out of bounds, but we nevertheless include it
    _check(te.compute((10,), lambda i: A[i]))


if __name__ == "__main__":
    # Run every test in this file directly, without pytest.
    test_bound_nest_thread()
    test_bound1()
    test_bound_nest_group()
    test_bound_group_schedule()
    test_bound_scan()
    test_bound3()
    test_bound_rfactor()
    test_bound_blur()
    test_bound_conv1d()
    test_bound2()
    test_gemm_bound()
    test_bound_warp()
    test_bound_tensor_compute_op()
    test_bound_simplification_failure()
    test_bound_fusesplit1()
    test_bound_fusesplit2()
    test_bound_split_divisible()
    test_bound_tile_divisible()
    # Split edge cases where the axis extent is smaller than the split size.
    test_bound_split_ext_less_than_factor()
    test_bound_split_ext_less_than_naprts()