# test_pass_loop_partition.py -- tests for tvm.ir_pass.LoopPartition
import tvm

def collect_visit(stmt, f):
    """Apply ``f`` to every node reached by ``tvm.ir_pass.PostOrderVisit``.

    Parameters
    ----------
    stmt : Stmt
        Root of the IR to walk.
    f : callable
        Called once per visited node.

    Returns
    -------
    list
        ``f(node)`` for each visited node, in the order the visitor calls back.
    """
    ret = []
    tvm.ir_pass.PostOrderVisit(stmt, lambda node: ret.append(f(node)))
    return ret

def lower(sch, args):
    """Lower ``sch`` through a reduced pass pipeline used by these tests.

    Pipeline: normalize -> InferBound -> ScheduleOps -> LoopPartition ->
    StorageFlatten -> CanonicalSimplify -> VectorizeLoop -> Simplify.

    Parameters
    ----------
    sch : Schedule
        The schedule to lower.
    args : list of Tensor
        Each tensor gets a freshly declared buffer bound to it for
        StorageFlatten.

    Returns
    -------
    Stmt
        The lowered statement.

    Raises
    ------
    ValueError
        If an element of ``args`` is not a ``tvm.tensor.Tensor``.
    """
    binds = {}
    for x in args:
        if not isinstance(x, tvm.tensor.Tensor):
            # Only tensors are handled; the old message also advertised
            # Buffer and Var, which this helper never accepted.
            raise ValueError("args must be Tensor")
        assert x not in binds
        binds[x] = tvm.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
    sch = sch.normalize()
    bounds = tvm.schedule.InferBound(sch)
    stmt = tvm.schedule.ScheduleOps(sch, bounds)
    stmt = tvm.ir_pass.LoopPartition(stmt, False)
    # NOTE(review): 64 is presumably the cache-line size argument -- confirm.
    stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 64)
    stmt = tvm.ir_pass.CanonicalSimplify(stmt)
    stmt = tvm.ir_pass.VectorizeLoop(stmt)
    stmt = tvm.ir_pass.Simplify(stmt)
    return stmt

def test_basic():
    """Splitting a symbolic-length loop by 4 and partitioning it should
    leave no ``if`` in ``stmt.body.body.body.first``."""
    n = tvm.var('n')
    A = tvm.placeholder((n, ), name='A')
    B = tvm.placeholder((n, ), name='B')

    T = tvm.compute((n, ), lambda i: A[i]+B[i])
    s = tvm.create_schedule(T.op)
    xo, xi = s[T].split(T.op.axis[0], factor=4)

    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    stmt = tvm.ir_pass.LoopPartition(stmt, False)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert 'if' not in str(stmt.body.body.body.first)

def test_const_loop():
    """Same as test_basic but with a constant extent (21, not a multiple
    of the split factor 4) and ``LoopPartition(stmt, True)``.

    NOTE(review): the ``True`` flag presumably enables partitioning of
    constant-extent loops -- confirm against the pass signature.
    """
    n = 21
    A = tvm.placeholder((n, ), name='A')
    B = tvm.placeholder((n, ), name='B')

    T = tvm.compute((n, ), lambda i: A[i]+B[i])
    s = tvm.create_schedule(T.op)
    xo, xi = s[T].split(T.op.axis[0], factor=4)

    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    stmt = tvm.ir_pass.LoopPartition(stmt, True)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert 'if' not in str(stmt.body.body.body.first)

def test_multi_loop():
    """A ``likely`` guard inside a three-deep loop nest should be removed:
    no IfThenElse may remain anywhere under ``stmt.body.first``."""
    ib = tvm.ir_builder.create()
    m = tvm.var('m')
    n = tvm.var('n')
    with ib.for_range(0, 4, "i") as i:
        with ib.for_range(0, n, "j") as j:
            with ib.for_range(0, m, "k") as k:
                with ib.if_scope(ib.likely(i*m+j+k < n)):
                    ib.emit(tvm.make.Evaluate(m))
                with ib.else_scope():
                    ib.emit(tvm.make.Evaluate(n))
    stmt = ib.get()
    stmt = tvm.ir_pass.LoopPartition(stmt, False)
    stmt = tvm.ir_pass.Simplify(stmt)
    is_if = lambda node: isinstance(node, tvm.stmt.IfThenElse)
    assert not any(collect_visit(stmt.body.first, is_if))

def test_multi_if():
    """Two consecutive ``likely`` guards in the same loop nest should both
    be eliminated from ``stmt.body.first``."""
    ib = tvm.ir_builder.create()
    m = tvm.var('m')
    n = tvm.var('n')
    with ib.for_range(0, 4, 'i') as i:
        with ib.for_range(0, n, 'j') as j:
            with ib.for_range(0, m, 'k') as k:
                with ib.if_scope(ib.likely(i*m+j+k < n)):
                    ib.emit(tvm.make.Evaluate(m))
                with ib.else_scope():
                    ib.emit(tvm.make.Evaluate(n))
                with ib.if_scope(ib.likely(i*m+j-k < n)):
                    ib.emit(tvm.make.Evaluate(m))
                with ib.else_scope():
                    ib.emit(tvm.make.Evaluate(n))
    stmt = ib.get()
    stmt = tvm.ir_pass.LoopPartition(stmt, False)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert 'if' not in str(stmt.body.first)

def test_thread_axis():
    """Partitioning must still remove the guard when an inner split axis
    is bound to ``threadIdx.x``.

    B's first axis is split by 32, and the inner part is split again into
    ``num_thread`` parts, the outer of which is bound to threadIdx.x.
    """
    m = tvm.var('m')
    l = tvm.var('l')
    A = tvm.placeholder((m, l), name='A')
    B = tvm.compute((m, l), lambda i, j: A[i, j] + 3, name='B')
    s = tvm.create_schedule(B.op)

    s[B].set_scope("shared")
    num_thread = 16
    xo, xi = s[B].split(B.op.axis[0], 32)
    xi0, xi1 = s[B].split(xi, nparts=num_thread)
    s[B].bind(xi0, tvm.thread_axis("threadIdx.x"))

    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    stmt = tvm.ir_pass.LoopPartition(stmt, False)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert 'if' not in str(stmt.body.body.body.first)

def test_vectorize():
    """After lowering, the guard around the vectorized loop must not refer
    to the vectorized variable, and its then-branch must contain a
    ``tvm.expr.Ramp`` node (presumably produced by VectorizeLoop)."""
    n = tvm.var('n')
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    bias = tvm.var("bias", dtype="float32")
    scale = tvm.var("scale", dtype="float32")
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i) * scale + bias, name='C')
    # schedule
    s = tvm.create_schedule(C.op)
    # create iter var and assign them tags.
    num_thread = 32
    bx, x = s[C].split(C.op.axis[0], factor=num_thread*4)
    tx, x = s[C].split(x, nparts=num_thread)
    _, x = s[C].split(x, factor=4)
    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
    s[C].vectorize(x)
    stmt = lower(s, [A, B])
    body = stmt.body.body.body.body.body
    assert x.var.name not in str(body.condition)
    # Use a distinct lambda parameter so it does not shadow the IterVar `x`.
    is_ramp = lambda node: isinstance(node, tvm.expr.Ramp)
    assert any(collect_visit(body.then_case, is_ramp))

def test_select():
    """A ``likely`` condition inside a Select expression should be
    partitioned away, leaving no Select under ``stmt.first``."""
    ib = tvm.ir_builder.create()
    m = tvm.var('m')
    n = tvm.var('n')
    with ib.for_range(0, ((n+3)/4), 'i') as i:
        with ib.for_range(0, 4, 'j') as j:
            ib.emit(tvm.make.Evaluate(
                tvm.make.Select(ib.likely(i*4+j<n), m, n)))
    stmt = ib.get()
    stmt = tvm.ir_pass.LoopPartition(stmt, False)
    stmt = tvm.ir_pass.Simplify(stmt)
    is_select = lambda node: isinstance(node, tvm.expr.Select)
    assert not any(collect_visit(stmt.first, is_select))

def test_thread_axis2():
    """With a symbolic innermost split factor ``m``, the extent of the loop
    at ``stmt.body.body.body.body.body.first`` must not mention threadIdx."""
    n = tvm.convert(4096)
    m = tvm.var('m')
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
    s = tvm.create_schedule(C.op)
    num_thread = 32
    bx, x = s[C].split(C.op.axis[0], factor=32)
    tx, x = s[C].split(x, nparts=num_thread)
    _, x = s[C].split(x, factor=m)
    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
    stmt = lower(s, [A, B])
    for_body = stmt.body.body.body.body.body.first
    assert 'threadIdx' not in str(for_body.extent)

def test_everything_during_deduction():
    """When bound deduction for the guard yields "everything", LoopPartition
    must keep the guard: ``stmt.body.body`` is still an IfThenElse."""
    m = tvm.var('m')
    n = tvm.var('n')
    ib = tvm.ir_builder.create()
    with ib.for_range(0, n, 'i') as i:
        with ib.for_range(0, 32, 'j') as j:
            with ib.if_scope(ib.likely(i/j < m)):
                # this guard will produce everything during deduction
                ib.emit(tvm.make.Evaluate(m))
    stmt = ib.get()
    stmt = tvm.ir_pass.LoopPartition(stmt, False)
    stmt = tvm.ir_pass.Simplify(stmt)
    assert isinstance(stmt.body.body, tvm.stmt.IfThenElse)

if __name__ == "__main__":
    # Run every test in this file when executed directly.
    test_basic()
    test_const_loop()
    test_multi_loop()
    test_multi_if()
    test_thread_axis()
    test_vectorize()
    test_select()
    test_thread_axis2()
    test_everything_during_deduction()