# test_pass_vectorize.py (1.96 KB)
import tvm

def test_vectorize_loop():
    """Vectorize an inner loop nested inside a serial loop.

    Builds ``for i in range(n): for j in range(4) [vectorize]: A[j] = 1``
    over a ``float32`` pointer and runs ``tvm.ir_pass.VectorizeLoop``.
    The outer ``For`` is expected to survive, while the inner vectorize
    loop is expected to be replaced by a body whose ``index`` is a
    ``tvm.expr.Ramp`` and whose ``value`` is a ``tvm.expr.Broadcast``.
    """
    n = tvm.var('n')
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, n) as i:
        with ib.for_range(0, 4, for_type="vectorize") as j:
            A[j] = tvm.const(1, A.dtype)
    stmt = ib.get()

    # Before the pass, the vectorize loop is still an explicit For nested
    # inside the outer loop.
    assert isinstance(stmt.body, tvm.stmt.For)
    stmt = tvm.ir_pass.VectorizeLoop(stmt)
    # The outer serial loop is kept...
    assert isinstance(stmt, tvm.stmt.For)
    # ...but its body is no longer a loop: the store now uses a Ramp index
    # and a Broadcast value instead of iterating over j.
    assert not isinstance(stmt.body, tvm.stmt.For)
    assert isinstance(stmt.body.index, tvm.expr.Ramp)
    assert isinstance(stmt.body.value, tvm.expr.Broadcast)

def test_vectorize_vector():
    """Vectorize a loop whose buffer element type is already a vector.

    Same loop nest as a scalar vectorize test, but ``A`` is a
    ``float32x4`` pointer and the stored constant has that same vector
    dtype. After ``tvm.ir_pass.VectorizeLoop`` the outer ``For`` must
    remain. The inner loop must be gone, with the body's ``index`` a
    ``tvm.expr.Ramp`` and its ``value`` a ``tvm.expr.Broadcast``.
    """
    n = tvm.var('n')
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32x4", name="A")
    with ib.for_range(0, n) as i:
        with ib.for_range(0, 4, for_type="vectorize") as j:
            # The constant takes the pointer's (vector) dtype, float32x4.
            A[j] = tvm.const(1, A.dtype)
    stmt = ib.get()

    assert isinstance(stmt.body, tvm.stmt.For)
    stmt = tvm.ir_pass.VectorizeLoop(stmt)
    # Outer loop kept, inner vectorize loop eliminated.
    assert isinstance(stmt, tvm.stmt.For)
    assert not isinstance(stmt.body, tvm.stmt.For)
    assert isinstance(stmt.body.index, tvm.expr.Ramp)
    assert isinstance(stmt.body.value, tvm.expr.Broadcast)


def test_vectorize_with_if():
    """Vectorize a loop whose body is an if/else.

    The outer condition ``x < n`` does not mention the loop variable
    ``i``; the nested condition in the else branch (``i < n``) does.
    After ``tvm.ir_pass.VectorizeLoop`` the test expects:

    * the top-level statement to be an ``IfThenElse``;
    * its then-branch to be a vectorized store (``Ramp`` index, ``Add``
      value of dtype ``float32x4``);
    * its else-branch to still be a ``For`` loop.
    """
    n = tvm.var('n')
    x = tvm.var('x')
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, 4, for_type="vectorize") as i:
        with ib.if_scope(x < n):
            A[i] = A[i] + 1
        with ib.else_scope():
            with ib.if_scope(i < n):
                A[i] = 2.0
    stmt = ib.get()
    stmt = tvm.ir_pass.VectorizeLoop(stmt)

    # The IfThenElse is expected at the top level, i.e. outside the loop.
    assert isinstance(stmt, tvm.stmt.IfThenElse)
    # Then-branch: the 4-lane loop over a float32 buffer becomes a float32x4
    # store.
    assert isinstance(stmt.then_case.index, tvm.expr.Ramp)
    assert isinstance(stmt.then_case.value, tvm.expr.Add)
    assert stmt.then_case.value.dtype == "float32x4"
    # Else-branch stays a loop, presumably because its condition `i < n`
    # depends on the loop variable.
    assert isinstance(stmt.else_case, tvm.stmt.For)

# Allow running the tests directly (without pytest).
if __name__ == "__main__":
    test_vectorize_vector()
    test_vectorize_with_if()
    test_vectorize_loop()