# test_lang_group.py
"""Test group effect"""
import tvm

def test_scan_group():
    """Check stage grouping for the update stages of a scan.

    The scan update stages start out sharing a group. ``create_group`` then
    carves ``s_update1``/``s_update2`` into a nested group that can be
    attached to ``s_update3``, while ``compute_at`` to a stage outside that
    group must raise ``tvm.TVMError``.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: x[0, i])

    s_update1 = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + x[t, i])
    s_update2 = tvm.compute((m, n), lambda t, i: s_update1[t, i] + 1)
    s_update3 = tvm.compute((m, n), lambda t, i: s_update2[t, i] + 1)
    res = tvm.scan(s_init, s_update3, s_state, inputs=x)

    s = tvm.create_schedule(res.op)
    assert s[s_update1].group is not None
    assert s[s_update2].group == s[s_update1].group
    # Assign within group, is valid
    s[s_update1].compute_at(s[s_update2], s_update2.op.axis[1])
    # create a new group, for [s_update2 and s_update1]
    g2 = s.create_group(outputs=s_update2, inputs=[s_state, x])
    assert g2.group is not None
    # The new group's parent is the group that holds s_update3.
    assert g2.group == s[s_update3].group
    assert s[s_update2].group == g2
    assert s[s_update1].group == g2
    g2.compute_at(s[s_update3], s_update3.op.axis[1])
    assert g2.attach_stage == s[s_update3]
    # compute_at to a stage outside the group must be rejected. Keep only the
    # call expected to raise inside ``try``. Signal "no error raised" with an
    # explicit raise in ``else``, since ``assert`` is stripped under -O.
    try:
        s[s_update2].compute_at(s[s_init], s_init.op.axis[0])
    except tvm.TVMError:
        pass
    else:
        raise AssertionError(
            "compute_at outside of the stage's group should raise TVMError")

def test_compute_group():
    """Check a group built with ``include_inputs=True`` over plain computes.

    Grouping ``x1`` together with its input ``x`` places both stages in the
    new group. That group can then be attached under ``x2`` and reports two
    child stages.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
    x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")

    s = tvm.create_schedule(x2.op)
    # include_inputs=True also places the input stage x in the group.
    g = s.create_group(outputs=x1, inputs=x, include_inputs=True)
    assert s[x1].group == g
    assert s[x].group == g
    g.compute_at(s[x2], x2.op.axis[1])
    assert g.attach_stage == s[x2]
    assert g.num_child_stages == 2

def test_nest_group():
    """Check that two groups created over the same stages nest.

    ``g1`` (inputs excluded) holds only ``x1``. ``g2`` (inputs included)
    holds ``x`` and becomes the parent group of ``g1``.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
    x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")

    s = tvm.create_schedule(x2.op)
    g1 = s.create_group(outputs=x1, inputs=x)
    # Grouping the same outputs again with include_inputs=True yields a
    # group that contains g1 plus the input stage x.
    g2 = s.create_group(outputs=x1, inputs=x, include_inputs=True)
    assert set(s.groups) == {g1, g2}
    assert s[x].group == g2
    assert s[x1].group == g1
    assert g1.group == g2
    assert g2.num_child_stages == 2
    assert g1.num_child_stages == 1

if __name__ == "__main__":
    # Run every test directly (no test runner), nested-group test first.
    for _run_test in (test_nest_group, test_compute_group, test_scan_group):
        _run_test()