# test_pass_remove_no_op.py
import tvm

def test_remove_no_op():
    """Check that ``tvm.ir_pass.RemoveNoOp`` strips statements with no effect.

    Two cases are exercised:

    1. A three-deep ``For`` loop nest whose innermost body is an
       ``IfThenElse`` that only evaluates the variables ``m`` or ``n``.
       The pass is expected to reduce the whole nest to a single
       ``tvm.stmt.Evaluate`` node.
    2. A ``Block`` made of that same loop nest followed by a ``Store`` into
       buffer ``Ab``. The pass is expected to drop the no-op nest and return
       the store statement itself.
    """
    i = tvm.var('i')
    j = tvm.var('j')
    k = tvm.var('k')
    m = tvm.var('m')
    n = tvm.var('n')
    dtype = 'int64'
    Ab = tvm.decl_buffer((n, ), dtype)

    # Loop nest over i, j, k whose body only evaluates m or n. No statement in
    # it writes anything, which is what makes it removable.
    stmt = tvm.make.For(
        i, 0, 4, 0, 0,
        tvm.make.For(
            j, 0, n, 0, 0,
            tvm.make.For(
                k, 0, m, 0, 0,
                tvm.make.IfThenElse(
                    (i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n)))))
    ret = tvm.ir_pass.RemoveNoOp(stmt)
    assert isinstance(ret, tvm.stmt.Evaluate)

    # A real write into Ab.data. Presumably this is Ab[i + 1] = Ab[i] + 1;
    # confirm the argument order against tvm.make.Store.
    store = tvm.make.Store(Ab.data,
                           tvm.make.Load(dtype, Ab.data, i) + 1,
                           i + 1)
    # Sequencing the no-op nest before the store should leave only the store.
    stmt2 = tvm.make.Block(stmt, store)
    assert tvm.ir_pass.RemoveNoOp(stmt2) == store


if __name__ == "__main__":
    # Allow running this file directly, without a test runner.
    test_remove_no_op()