# test_schedule_bound_inference.py
import tvm

def test_bound1():
    """Check InferBound for a producer attached to the outer axis of a split.

    A2's first axis is split by 8 and A1 is computed at the resulting outer
    axis; the inferred extent of A1's first axis is expected to be 8.
    """
    shape = (tvm.var('m'), tvm.var('l'))
    A = tvm.placeholder(shape, name='A')
    A1 = tvm.compute(shape, lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute(shape, lambda i, j: A1[i, j] + 3, name='A2')

    sched = tvm.create_schedule([A2.op])
    consumer = sched[A2]
    outer, _inner = consumer.split(consumer.op.axis[0], 8)
    sched[A1].compute_at(consumer, outer)

    inferred = tvm.schedule.InferBound(sched)
    assert isinstance(inferred, tvm.container.Map)
    assert inferred[A1.op.axis[0]].extent.value == 8

def test_bound2():
    """Check InferBound for a producer attached inside an 8x8 tiling.

    A2 is tiled 8 by 8 and A1 is computed at the outer axis of the second
    tiled dimension; both axes of A1 are expected to have extent 8.
    """
    shape = (tvm.var('m'), tvm.var('l'))
    A = tvm.placeholder(shape, name='A')
    A1 = tvm.compute(shape, lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute(shape, lambda i, j: A1[i, j] + 3, name='A2')

    sched = tvm.create_schedule(A2.op)
    _xo, yo, _xi, _yi = sched[A2].tile(A2.op.axis[0], A2.op.axis[1], 8, 8)

    # The normalized copy is discarded on purpose: this checks that calling
    # normalize() does not affect the schedule we keep using below.
    sched.normalize()

    sched[A1].compute_at(sched[A2], yo)
    inferred = tvm.schedule.InferBound(sched)
    assert isinstance(inferred, tvm.container.Map)
    for dim in (0, 1):
        assert inferred[A1.op.axis[dim]].extent.value == 8

def test_bound3():
    """Check InferBound for a "shared"-scope producer under a thread binding.

    A2's first axis is split by 32, the inner part is split again into 16
    parts whose outer piece is bound to threadIdx.x, and the second axis is
    split by 16. A1 is computed at the outer axis of that last split after a
    reorder. The expected extents for A1 are 32 and 16.
    """
    shape = (tvm.var('m'), tvm.var('l'))
    A = tvm.placeholder(shape, name='A')
    A1 = tvm.compute(shape, lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute(shape, lambda i, j: A1[i, j] + 3, name='A2')

    sched = tvm.create_schedule(A2.op)
    sched[A1].set_scope("shared")

    consumer = sched[A2]
    xo, xi = consumer.split(A2.op.axis[0], 32)
    xi0, xi1 = consumer.split(xi, nparts=16)
    consumer.bind(xi0, tvm.thread_axis("threadIdx.x"))
    yo, yi = consumer.split(A2.op.axis[1], 16)

    # The normalized copy is discarded on purpose: this checks that calling
    # normalize() does not affect the schedule we keep using below.
    sched.normalize()

    consumer.reorder(xo, xi0, yo, xi1, yi)
    sched[A1].compute_at(consumer, yo)

    inferred = tvm.schedule.InferBound(sched)
    assert isinstance(inferred, tvm.container.Map)
    for dim, expected in enumerate((32, 16)):
        assert inferred[A1.op.axis[dim]].extent.value == expected

def test_bound_warp():
    """Check InferBound for a producer placed in "warp" scope.

    A2's first axis is split by 32 and then by 16, with the two inner pieces
    bound to threadIdx.y and threadIdx.x. A1 is computed at A2's second axis,
    and A1's own first axis is split by 16 with the inner part bound to the
    same threadIdx.x axis. The expected extent of A1's first axis is 16.
    """
    shape = (tvm.var('m'), tvm.var('l'))
    A = tvm.placeholder(shape, name='A')
    A1 = tvm.compute(shape, lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute(shape, lambda i, j: A1[i, j] + 3, name='A2')

    sched = tvm.create_schedule(A2.op)
    sched[A1].set_scope("warp")

    consumer = sched[A2]
    _xo, xi = consumer.split(A2.op.axis[0], 32)
    xi0, xi1 = consumer.split(xi, factor=16)
    thread_x = tvm.thread_axis("threadIdx.x")
    consumer.bind(xi1, thread_x)
    consumer.bind(xi0, tvm.thread_axis("threadIdx.y"))
    sched[A1].compute_at(consumer, consumer.op.axis[1])

    producer = sched[A1]
    _outer, inner = producer.split(producer.op.axis[0], factor=16)
    producer.bind(inner, thread_x)

    inferred = tvm.schedule.InferBound(sched)
    assert isinstance(inferred, tvm.container.Map)
    assert inferred[A1.op.axis[0]].extent.value == 16

def test_bound_scan():
    """Check InferBound for a cached read inside a scan update.

    X is read through a "local" cache by the scan update, whose second axis is
    split by 4; the cache is computed at the outer split axis. After
    normalizing, the cache's second axis is expected to have extent 4.
    ScheduleOps is also invoked on the inferred bounds.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    X = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: X[0, i])
    s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
    s_scan = tvm.scan(s_init, s_update, s_state)
    assert tuple(s_scan.shape) == (m, n)

    sched = tvm.create_schedule(s_scan.op)
    cached = sched.cache_read(X, "local", s_update)
    outer, _inner = sched[s_update].split(s_update.op.axis[1], factor=4)
    sched[cached].compute_at(sched[s_update], outer)

    sched = sched.normalize()
    inferred = tvm.schedule.InferBound(sched)
    # Called only to exercise ScheduleOps; the returned statement is not inspected.
    tvm.schedule.ScheduleOps(sched, inferred)
    assert inferred[cached.op.axis[1]].extent.value == 4


def test_bound_conv1d():
    """Check InferBound for a three-point 1-D stencil.

    B reads three neighbouring elements of A per output element and A is
    computed at B's only axis; the expected extent of A's axis is 3.
    """
    n = tvm.var('n')
    A = tvm.compute(n + 2, lambda i: 1, name='A')

    def computeB(ii):
        # Shift to the window centre, then read centre - 1, centre, centre + 1.
        i = ii + 1
        return A[i-1] + A[i] + A[i+1]

    B = tvm.compute(n, computeB, name='B')

    sched = tvm.create_schedule(B.op)
    sched[A].compute_at(sched[B], B.op.axis[0])
    sched = sched.normalize()

    inferred = tvm.schedule.InferBound(sched)
    assert inferred[A.op.axis[0]].extent.value == 3

def test_bound_blur():
    """Check InferBound for a five-point 2-D stencil on a 12x12 input.

    B reads the centre element of A and its four direct neighbours, and A is
    computed at B's second axis; both axes of A are expected to have extent 3.
    """
    n = tvm.convert(12)
    A = tvm.compute((n, n), lambda i, j: 1, name='A')

    def computeB(ii, jj):
        # Shift by one in both dimensions so (i, j) is the window centre.
        i = ii + 1
        j = jj + 1
        return A[i][j] + A[i-1][j] + A[i+1][j] + A[i][j+1] + A[i][j-1]

    B = tvm.compute((n-2, n-2), computeB, name='B')

    sched = tvm.create_schedule(B.op)
    sched[A].compute_at(sched[B], B.op.axis[1])
    sched = sched.normalize()

    inferred = tvm.schedule.InferBound(sched)
    for dim in (0, 1):
        assert inferred[A.op.axis[dim]].extent.value == 3


def test_bound_rfactor():
    """Check InferBound on the stage produced by rfactor.

    The reduction axis of B is split into 4 parts and B is factored along the
    outer part; the factored stage is expected to have extents 4 and 1.
    """
    n = tvm.var('n')
    A = tvm.placeholder((n,), name='A')
    k = tvm.reduce_axis((0, n))
    B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k, where=(i>1)), name='B')

    # Schedule: split the reduction, factor on the outer part, then normalize.
    sched = tvm.create_schedule(B.op)
    k_outer, _k_inner = sched[B].split(k, nparts=4)
    BF = sched.rfactor(B, k_outer)
    sched = sched.normalize()

    inferred = tvm.schedule.InferBound(sched)
    for dim, expected in enumerate((4, 1)):
        assert inferred[BF.op.axis[dim]].extent.value == expected

def test_bound_group_schedule():
    """Check InferBound for a stage group attached to a consumer axis.

    x and x1 are placed in one group that is computed at x2's first axis.
    x's first axis is expected to have extent 1 and its second axis the
    full extent n.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
    x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")

    sched = tvm.create_schedule(x2.op)
    group = sched.create_group(outputs=x1, inputs=x, include_inputs=True)
    group.compute_at(sched[x2], x2.op.axis[0])
    assert sched[x1].group == group
    assert sched[x].group == group

    sched = sched.normalize()
    inferred = tvm.schedule.InferBound(sched)
    assert inferred[x.op.axis[0]].extent.value == 1
    assert inferred[x.op.axis[1]].extent == n

def test_bound_nest_group():
    """Check InferBound for two groups created over the same input.

    One group is created around x alone and a second around x1 with x as
    input. The second is computed at x2's first axis and the first at x1's
    second axis. Expected extents: 1 and 1 for x, 1 and n for x1.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
    x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")

    sched = tvm.create_schedule(x2.op)
    inner_group = sched.create_group(outputs=x, inputs=x, include_inputs=True)
    outer_group = sched.create_group(outputs=x1, inputs=x, include_inputs=True)
    assert sched[x].group == inner_group
    assert sched[x1].group == outer_group
    outer_group.compute_at(sched[x2], x2.op.axis[0])
    inner_group.compute_at(sched[x1], sched[x1].op.axis[1])

    sched = sched.normalize()
    inferred = tvm.schedule.InferBound(sched)
    assert inferred[x.op.axis[0]].extent.value == 1
    assert inferred[x.op.axis[1]].extent.value == 1
    assert inferred[x1.op.axis[0]].extent.value == 1
    assert inferred[x1.op.axis[1]].extent == n

def test_bound_nest_thread():
    """Check InferBound for "local" and "shared" producers under one thread axis.

    A3 is split by 32 and bound to blockIdx.x / threadIdx.x. A2 ("shared") and
    A1 ("local") are both computed at A3's thread axis, and A2's own axis is
    additionally bound to threadIdx.x. Expected extents: 1 for A1, 32 for A2,
    and m for A3.
    """
    m = tvm.var('m')
    A = tvm.placeholder(m, name='A')
    A1 = tvm.compute((m,), lambda i: A[i], name='A1')
    A2 = tvm.compute((m,), lambda i: A1[i] + 2, name='A2')
    A3 = tvm.compute((m,), lambda i: A2[i] + 3, name='A3')

    sched = tvm.create_schedule(A3.op)
    sched[A2].set_scope("shared")
    sched[A1].set_scope("local")

    block_x = tvm.thread_axis("blockIdx.x")
    thread_x = tvm.thread_axis("threadIdx.x")

    out_stage = sched[A3]
    bx, tx = out_stage.split(A3.op.axis[0], factor=32)
    out_stage.bind(bx, block_x)
    out_stage.bind(tx, thread_x)

    sched[A2].compute_at(out_stage, tx)
    _, a2_inner = sched[A2].split(A2.op.axis[0], nparts=1)
    sched[A2].bind(a2_inner, thread_x)
    sched[A1].compute_at(out_stage, tx)

    sched = sched.normalize()
    inferred = tvm.schedule.InferBound(sched)
    assert inferred[A1.op.axis[0]].extent.value == 1
    assert inferred[A2.op.axis[0]].extent.value == 32
    assert inferred[A3.op.axis[0]].extent == m

def test_gemm_bound():
    """Check InferBound on a block/thread-tiled 1024x1024 matrix product.

    C[ii, jj] sums A[ii, k] * B[jj, k] over k. C is written through a "local"
    cache (CC) and A and B are read through "shared" caches (AA, BB). C is
    split by 64 per block and into 8 parts per thread in both dimensions.
    Expected extents: 64 for the first axis of AA and BB, 8 for both axes
    of CC.
    """
    nn = 1024
    n = tvm.convert(nn)
    A = tvm.placeholder((n, n), name='A')
    B = tvm.placeholder((n, n), name='B')
    k = tvm.reduce_axis((0, n), name='k')
    C = tvm.compute(
        (n, n),
        lambda ii, jj: tvm.sum(A[ii, k] * B[jj, k], axis=k),
        name='CC')

    # Schedule parameters and thread axes.
    sched = tvm.create_schedule(C.op)
    xtile, ytile = 32, 32  # not referenced by the schedule below
    scale = 8
    num_thread = 8
    block_factor = scale * num_thread
    block_x = tvm.thread_axis("blockIdx.x")
    thread_x = tvm.thread_axis("threadIdx.x")
    block_y = tvm.thread_axis("blockIdx.y")
    thread_y = tvm.thread_axis("threadIdx.y")

    CC = sched.cache_write(C, "local")
    AA = sched.cache_read(A, "shared", [CC])
    BB = sched.cache_read(B, "shared", [CC])

    # Tile the output: blocks first, then threads inside each block.
    out_stage = sched[C]
    blk_row, row_rest = out_stage.split(C.op.axis[0], factor=block_factor)
    blk_col, col_rest = out_stage.split(C.op.axis[1], factor=block_factor)
    out_stage.reorder(blk_row, blk_col, row_rest, col_rest)
    out_stage.bind(blk_row, block_y)
    out_stage.bind(blk_col, block_x)
    thr_row, row_rest = out_stage.split(row_rest, nparts=num_thread)
    thr_col, col_rest = out_stage.split(col_rest, nparts=num_thread)
    out_stage.reorder(thr_row, thr_col, row_rest, col_rest)
    out_stage.bind(thr_row, thread_y)
    out_stage.bind(thr_col, thread_x)

    # Put the reduction outermost in the local accumulator stage.
    acc_row, acc_col = CC.op.axis
    sched[CC].reorder(k, acc_row, acc_col)

    sched[CC].compute_at(out_stage, thr_col)
    sched[AA].compute_at(sched[CC], k)
    sched[BB].compute_at(sched[CC], k)

    # Both shared caches get the same split-and-bind treatment, AA first.
    for shared in (AA, BB):
        stage = sched[shared]
        y_part, rest = stage.split(stage.op.axis[0], nparts=num_thread)
        x_part, rest = stage.split(rest, nparts=num_thread)
        stage.bind(y_part, thread_y)
        stage.bind(x_part, thread_x)

    sched = sched.normalize()
    inferred = tvm.schedule.InferBound(sched)
    assert inferred[BB.op.axis[0]].extent.value == 64
    assert inferred[AA.op.axis[0]].extent.value == 64
    assert inferred[CC.op.axis[0]].extent.value == 8
    assert inferred[CC.op.axis[1]].extent.value == 8

if __name__ == "__main__":
    # Run every test in this file, in a fixed order, when executed as a script.
    for run_case in (
            test_bound_nest_thread,
            test_bound1,
            test_bound_nest_group,
            test_bound_group_schedule,
            test_bound_scan,
            test_bound3,
            test_bound_rfactor,
            test_bound_blur,
            test_bound_conv1d,
            test_bound2,
            test_gemm_bound,
            test_bound_warp,
    ):
        run_case()