"""Test group effect"""
import tvm


def test_scan_group():
    """Check stage grouping inside a scan body and group attachment rules.

    The scan update is a chain ``s_update1 -> s_update2 -> s_update3``. The
    update stages start out sharing one group. A user-created group nested in
    it can be attached within that group. Attaching a grouped stage to a stage
    outside its group must raise ``tvm.TVMError``.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: x[0, i])
    s_update1 = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + x[t, i])
    s_update2 = tvm.compute((m, n), lambda t, i: s_update1[t, i] + 1)
    s_update3 = tvm.compute((m, n), lambda t, i: s_update2[t, i] + 1)
    res = tvm.scan(s_init, s_update3, s_state, inputs=x)

    s = tvm.create_schedule(res.op)
    # The update stages of the scan are placed in a common group.
    assert s[s_update1].group is not None
    assert s[s_update2].group == s[s_update1].group
    # Assign within group, is valid
    s[s_update1].compute_at(s[s_update2], s_update2.op.axis[1])
    # create a new group, for [s_update2 and s_update1]
    g2 = s.create_group(outputs=s_update2, inputs=[s_state, x])
    # The new group is nested inside the group that holds s_update3.
    assert g2.group is not None
    assert g2.group == s[s_update3].group
    assert s[s_update2].group == g2
    assert s[s_update1].group == g2
    g2.compute_at(s[s_update3], s_update3.op.axis[1])
    assert g2.attach_stage == s[s_update3]
    # compute outside group error.
    # Use try/except/else rather than ``assert False`` inside the try: a bare
    # assert is stripped under ``python -O``, which would make this negative
    # check pass vacuously.
    try:
        s[s_update2].compute_at(s[s_init], s_init.op.axis[0])
    except tvm.TVMError:
        pass
    else:
        raise AssertionError(
            "compute_at to a stage outside the group should raise TVMError")


def test_compute_group():
    """Check that ``include_inputs=True`` pulls the input stage into the group.

    The group built from ``x1`` with input ``x`` contains both stages, and
    attaching the group records the attach stage.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
    x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
    s = tvm.create_schedule(x2.op)
    g = s.create_group(outputs=x1, inputs=x, include_inputs=True)
    assert s[x1].group == g
    assert s[x].group == g
    g.compute_at(s[x2], x2.op.axis[1])
    assert g.attach_stage == s[x2]
    # x and x1 were both placed in the group (include_inputs=True).
    assert g.num_child_stages == 2


def test_nest_group():
    """Check that a group created over an existing group nests it.

    ``g1`` holds only ``x1``. ``g2`` is created afterwards over the same
    output with ``include_inputs=True``. It becomes the parent of ``g1``
    and directly owns ``x``.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
    x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
    s = tvm.create_schedule(x2.op)
    g1 = s.create_group(outputs=x1, inputs=x)
    g2 = s.create_group(outputs=x1, inputs=x, include_inputs=True)
    assert set(s.groups) == {g1, g2}
    assert s[x].group == g2
    assert s[x1].group == g1
    assert g1.group == g2
    # g2's direct children: the stage x and the nested group g1.
    assert g2.num_child_stages == 2
    assert g1.num_child_stages == 1


if __name__ == "__main__":
    test_nest_group()
    test_compute_group()
    test_scan_group()