# test_pass_vectorize.py: tests for tvm.ir_pass.VectorizeLoop
import tvm

def test_vectorize_loop():
    """Check VectorizeLoop on a vectorized loop nested in a serial loop.

    Builds an outer loop over the symbolic extent ``n`` whose body is a
    4-iteration loop marked ``for_type="vectorize"`` storing the constant 1
    into a float32 buffer at the inner loop index.  After the pass, the outer
    statement must still be a ``For``, the inner ``For`` must be gone, and the
    remaining store must use a ``Ramp`` index with a ``Broadcast`` value.
    """
    n = tvm.var('n')
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, n) as i:
        with ib.for_range(0, 4, for_type="vectorize") as j:
            A[j] = tvm.const(1, A.dtype)
    stmt = ib.get()

    # Before the pass the inner loop is still an explicit For node.
    assert isinstance(stmt.body, tvm.stmt.For)
    stmt = tvm.ir_pass.VectorizeLoop(stmt)
    # The outer loop survives; the inner loop is replaced by a vector store.
    assert isinstance(stmt, tvm.stmt.For)
    assert not isinstance(stmt.body, tvm.stmt.For)
    assert isinstance(stmt.body.index, tvm.expr.Ramp)
    assert isinstance(stmt.body.value, tvm.expr.Broadcast)


def test_vectorize_vector():
    """Check VectorizeLoop when the buffer element type is itself a vector.

    Same loop nest as the scalar case (serial loop over ``n`` around a
    4-iteration ``for_type="vectorize"`` loop), but the buffer is declared as
    ``float32x4``.  After the pass, the outer statement must still be a
    ``For``, the inner ``For`` must be gone, and the remaining store must use
    a ``Ramp`` index with a ``Broadcast`` value.
    """
    n = tvm.var('n')
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32x4", name="A")
    with ib.for_range(0, n) as i:
        with ib.for_range(0, 4, for_type="vectorize") as j:
            A[j] = tvm.const(1, A.dtype)
    stmt = ib.get()
    # Before the pass the inner loop is still an explicit For node.
    assert isinstance(stmt.body, tvm.stmt.For)
    stmt = tvm.ir_pass.VectorizeLoop(stmt)
    # The outer loop survives; the inner loop is replaced by a vector store.
    assert isinstance(stmt, tvm.stmt.For)
    assert not isinstance(stmt.body, tvm.stmt.For)
    assert isinstance(stmt.body.index, tvm.expr.Ramp)
    assert isinstance(stmt.body.value, tvm.expr.Broadcast)


def test_vectorize_with_if():
    """Check VectorizeLoop on a vectorized loop whose body is an if/else.

    The loop body branches on ``x < n`` (neither symbol is the loop variable);
    the else branch contains a nested condition ``i < n`` that does use the
    loop variable ``i``.  After the pass the top-level statement must be an
    ``IfThenElse`` whose then-branch is a store with a ``Ramp`` index and a
    ``float32x4`` ``Add`` value, while the else-branch must remain a ``For``.
    """
    n = tvm.var('n')
    x = tvm.var('x')
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, 4, for_type="vectorize") as i:
        with ib.if_scope(x < n):
            A[i] = A[i] + 1
        with ib.else_scope():
            with ib.if_scope(i < n):
                A[i] = 2.0
    stmt = ib.get()
    stmt = tvm.ir_pass.VectorizeLoop(stmt)
    assert isinstance(stmt, tvm.stmt.IfThenElse)
    # Then-branch: vector store of a 4-lane float32 addition.
    assert isinstance(stmt.then_case.index, tvm.expr.Ramp)
    assert isinstance(stmt.then_case.value, tvm.expr.Add)
    assert stmt.then_case.value.dtype == "float32x4"
    # Else-branch: still an explicit loop after the pass.
    assert isinstance(stmt.else_case, tvm.stmt.For)


def test_vectorize_if_then_else():
    """Check VectorizeLoop on stores that use the ``tvm_if_then_else`` intrinsic.

    Two scenarios are built:

    1. The intrinsic's condition is ``i > 0`` where ``i`` is the vectorized
       loop variable; the result of the pass must still be a ``For``.
    2. The intrinsic's condition is ``k > 0`` where ``k`` is the outer serial
       loop variable; the inner ``For`` must be gone after the pass and the
       intrinsic's third argument (the literal 0) must be a ``Broadcast``.
    """
    n = tvm.var('n')

    # Scenario 1: condition uses the vectorized loop variable.
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, 4, for_type="vectorize") as i:
        A[i] = tvm.call_intrin("float32", "tvm_if_then_else",
                               i > 0,
                               A[i] + 1, A[i])
    stmt = ib.get()
    stmt = tvm.ir_pass.VectorizeLoop(stmt)
    assert isinstance(stmt, tvm.stmt.For)

    # Scenario 2: condition uses only the outer serial loop variable.
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, n) as k:
        with ib.for_range(0, 4, for_type="vectorize") as i:
            A[k * 4 + i] = tvm.call_intrin("float32", "tvm_if_then_else",
                                           k > 0,
                                           A[k * 4 + i], 0)
    stmt = ib.get()
    assert isinstance(stmt.body, tvm.stmt.For)
    stmt = tvm.ir_pass.VectorizeLoop(stmt)
    assert not isinstance(stmt.body, tvm.stmt.For)
    assert isinstance(stmt.body.value.args[2], tvm.expr.Broadcast)


if __name__ == "__main__":
    # Run every test in this module when the file is executed as a script.
    test_vectorize_vector()
    test_vectorize_with_if()
    test_vectorize_loop()
    test_vectorize_if_then_else()