import tvm


def test_rewrite_select():
    """Exercise ``tvm.ir_pass.RewriteUnsafeSelect`` on three select expressions.

    The test builds:

    * a select whose true branch loads ``A[i-1]`` under the condition ``i > 1``;
    * a select whose *condition* itself contains such a select;
    * an outer select on ``i > 10`` that chooses between the two above.

    It then asserts that the first two come back from the pass as a node
    named ``"tvm_if_then_else"``, while the outer one is still a
    ``tvm.expr.Select``.
    """
    builder = tvm.ir_builder.create()
    buf = builder.allocate("float32", 100, name="A", scope="global")
    idx = tvm.var("i")

    def run_pass(expr):
        # The pass is applied to a statement here, so wrap the expression
        # in Evaluate and unwrap the resulting expression afterwards.
        stmt = tvm.make.Evaluate(expr)
        return tvm.ir_pass.RewriteUnsafeSelect(stmt).value

    # Select whose true branch reads the buffer at i-1.
    load_select = tvm.select(idx > 1, buf[idx - 1], 1.0)
    load_select_out = run_pass(load_select)

    # Select whose condition is built from a fresh copy of the select above.
    inner_cond = tvm.select(idx > 1, buf[idx - 1], 1.0) > 0.0
    cond_select = tvm.select(inner_cond, buf[idx], 0.1)
    cond_select_out = run_pass(cond_select)

    # Outer select that merely picks between the two previous expressions.
    outer_select = tvm.select(idx > 10, load_select, cond_select)
    outer_select_out = run_pass(outer_select)

    # NOTE(review): presumably the first two are rewritten because a branch
    # or condition contains a guarded load — confirm against the pass docs.
    assert load_select_out.name == "tvm_if_then_else"
    assert cond_select_out.name == "tvm_if_then_else"
    assert isinstance(outer_select_out, tvm.expr.Select)


if __name__ == "__main__":
    test_rewrite_select()