# test_pass_rewrite_unsafe_select.py
import tvm


def test_rewrite_Select():
    """Check which Select expressions RewriteUnsafeSelect turns into if_then_else.

    Each expression is wrapped in an Evaluate statement and passed through
    ``tvm.ir_pass.RewriteUnsafeSelect``. The rewritten expression is then
    read back from the resulting statement's ``value``.

    Expected outcomes asserted below:
      * ``y`` (true branch indexes ``A``)        -> call named "tvm_if_then_else"
      * ``z`` (true branch indexes ``A``)        -> call named "tvm_if_then_else"
      * ``a`` (branches are the Selects y and z) -> stays a ``tvm.expr.Select``
    """
    ib = tvm.ir_builder.create()
    # 100-element float32 buffer. Indexing it (A[i-1], A[i]) presumably
    # yields memory loads, i.e. the reads the pass guards against -- confirm
    # against the pass implementation.
    A = ib.allocate("float32", 100, name="A", scope="global")
    i = tvm.var("i")

    # Select whose true branch reads A[i-1].
    y = tvm.expr.Select(i > 1, A[i-1], 1.0)
    yy = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(y)).value

    # Select whose condition itself contains a Select over A[i-1], and whose
    # true branch reads A[i].
    z = tvm.expr.Select(
        tvm.expr.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1)
    zz = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(z)).value

    # Select whose branches are the Selects above rather than direct buffer
    # reads; the test expects this one to be left as a plain Select.
    a = tvm.expr.Select(i>10, y, z)
    aa = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(a)).value

    assert yy.name == "tvm_if_then_else"
    assert zz.name == "tvm_if_then_else"
    assert isinstance(aa, tvm.expr.Select)


if __name__ == "__main__":
    # Allow running this test module directly, without a test runner.
    test_rewrite_Select()