"""Unit tests for the schedule graph utilities exposed under ``tvm.schedule``.

Covers ScanGetBody, CreateAttachPath, ScanFixPointAnalysis,
CreateReadGraph and PostDFSOrder.
"""
import tvm


def test_scan():
    """Exercise body extraction, attach paths and fix-point analysis on a scan.

    The update stage reads the previous state row through ``up1`` and also
    reads ``x_trans``, which does not depend on the scan state.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: x[0, i], name="s_init")
    x_trans = tvm.compute((m, n), lambda i, j: x[i, j] + 1, name="x_trans")
    s_up1 = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + 1, name="up1")
    s_update = tvm.compute((m, n), lambda t, i: s_up1[t, i] + x_trans[t, i],
                           name="update")
    s_scan = tvm.scan(s_init, s_update, s_state)

    def test_getbody():
        # x_trans never reads s_state, so it is expected to be left out of the body.
        body = tvm.schedule.ScanGetBody(s_scan.op)
        assert set(body) == {s_scan.op, s_update.op, s_up1.op}

    def test_attach_path():
        s = tvm.create_schedule(s_scan.op)
        s[x_trans].compute_at(s[s_update], s_update.op.axis[0])
        apath = tvm.schedule.CreateAttachPath(s)
        assert tuple(apath[s_update.op]) == (s_scan.op.scan_axis,)
        assert tuple(apath[x_trans.op]) == (s_update.op.axis[0],
                                            s_scan.op.scan_axis)

    def test_fix_pt():
        body = tvm.schedule.ScanGetBody(s_scan.op)
        fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
        # spatial_axis_ is a field of the scan operation, not of its output
        # tensor; every other check in this file reads it through `.op`.
        assert fxpt[s_scan.op.spatial_axis_[0]].value != 0

    # The nested checks share the graph built above; they must be invoked
    # explicitly or they never run.
    test_getbody()
    test_attach_path()
    test_fix_pt()


def test_scan_fix_point():
    """Check ScanFixPointAnalysis on 3-D scans with different state accesses.

    Judging from the expected values below, 1 appears to mean a spatial axis
    of the state is reached at the same index by the update, and 0 that it is
    not (transposed, pinned to a constant, or aliased) — confirm against the
    analysis implementation.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    l = tvm.var("l")
    x = tvm.compute((l, m, n), lambda *i: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((l, m, n))
    s_init = tvm.compute((1, m, n), lambda _, i, j: x[0, i, j], name="s_init")

    def test_scan0():
        # The previous state is read back at the same (i, j) position.
        s_update = tvm.compute((l, m, n),
                               lambda t, i, j: x[t, j, i] + s_state[t - 1, i, j],
                               name="update")
        s_scan = tvm.scan(s_init, s_update, s_state)
        body = tvm.schedule.ScanGetBody(s_scan.op)
        fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
        assert fxpt[s_scan.op.spatial_axis_[0]].value == 1
        assert fxpt[s_scan.op.spatial_axis_[1]].value == 1

    def test_scan1():
        # The previous state is read back with its two spatial axes swapped.
        s_update = tvm.compute((l, m, n),
                               lambda t, i, j: x[t, j, i] + s_state[t - 1, j, i],
                               name="update")
        s_scan = tvm.scan(s_init, s_update, s_state)
        body = tvm.schedule.ScanGetBody(s_scan.op)
        fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
        assert fxpt[s_scan.op.spatial_axis_[0]].value == 0
        assert fxpt[s_scan.op.spatial_axis_[1]].value == 0

    # NOTE(review): test_scan0/test_scan1 pass `body` as a second argument to
    # ScanFixPointAnalysis while the tests below do not; confirm which
    # signature the bound function expects.

    def test_scan3_not_exact_reach():
        # h1 transposes the state and update transposes it back (identity);
        # h2 pins the last state index to the constant 10.
        s_h1 = tvm.compute((l, n, m), lambda t, j, i: s_state[t - 1, i, j],
                           name="h1")
        s_h2 = tvm.compute((l, m, n), lambda t, i, j: s_state[t - 1, i, 10] * 2,
                           name="h2")
        s_update = tvm.compute((l, m, n),
                               lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j],
                               name="update")
        s_scan = tvm.scan(s_init, s_update, s_state)
        fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op)
        assert fxpt[s_scan.op.spatial_axis_[0]].value == 1
        assert fxpt[s_scan.op.spatial_axis_[1]].value == 0

    def test_scan4_reach_other():
        # h1 reads the state with the same index on both spatial axes.
        s_h1 = tvm.compute((l, n, m), lambda t, j, i: s_state[t - 1, j, j],
                           name="h1")
        s_h2 = tvm.compute((l, m, n), lambda t, i, j: s_state[t - 1, i, j] * 2,
                           name="h2")
        s_update = tvm.compute((l, m, n),
                               lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j],
                               name="update")
        s_scan = tvm.scan(s_init, s_update, s_state)
        fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op)
        assert fxpt[s_scan.op.spatial_axis_[0]].value == 0
        assert fxpt[s_scan.op.spatial_axis_[1]].value == 0

    def test_scan5_multi_output():
        # Two independent states scanned together; s2 is read back at the same
        # index, so its spatial axis is expected to be a fixed point.
        m = tvm.var("m")
        n = tvm.var("n")
        x1 = tvm.placeholder((m, n))
        s1 = tvm.placeholder((m, n))
        x2 = tvm.placeholder((m, n))
        s2 = tvm.placeholder((m, n))
        s1_init = tvm.compute((1, n), lambda _, i: x1[0, i])
        s2_init = tvm.compute((1, n), lambda _, i: x2[0, i])
        s1_update = tvm.compute((m, n), lambda t, i: s1[t - 1, i] + x1[t, i])
        s2_update = tvm.compute((m, n), lambda t, i: x2[t, i] + s2[t - 1, i])
        r0, r1 = tvm.scan([s1_init, s2_init],
                          [s1_update, s2_update],
                          [s1, s2])
        fxpt = tvm.schedule.ScanFixPointAnalysis(r0.op)
        assert fxpt[r1.op.spatial_axis_[0]].value == 1

    test_scan0()
    test_scan1()
    test_scan3_not_exact_reach()
    test_scan4_reach_other()
    test_scan5_multi_output()


def test_create_read_graph():
    """Check the read graph and post-DFS order of the pipeline A -> A1 -> A2."""
    m = tvm.var('m')
    l = tvm.var('l')
    A = tvm.placeholder((m, l), name='A')
    A1 = tvm.compute((m, l), lambda i, j: A[i, j])
    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3)

    g = tvm.schedule.CreateReadGraph([A2.op])
    # Each op maps to the tensors it reads: A2 reads A1, A1 reads A.
    assert g[A2.op][0] == A1
    assert g[A1.op][0] == A

    post_order = tvm.schedule.PostDFSOrder([A2.op], g)
    # Producers are expected before their consumers.
    assert post_order[0] == A.op
    assert post_order[1] == A1.op


if __name__ == "__main__":
    test_scan()
    test_create_read_graph()
    test_scan_fix_point()