import tvm def test_remove_no_op(): i = tvm.var('i') j = tvm.var('j') k = tvm.var('k') m = tvm.var('m') n = tvm.var('n') dtype = 'int64' Ab = tvm.decl_buffer((n, ), dtype) stmt = tvm.make.For( i, 0, 4, 0, 0, tvm.make.For( j, 0, n, 0, 0, tvm.make.For( k, 0, m, 0, 0, tvm.make.IfThenElse( (i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n))))) ret = tvm.ir_pass.RemoveNoOp(stmt) assert(isinstance(ret, tvm.stmt.Evaluate)) store = tvm.make.Store(Ab.data, tvm.make.Load(dtype, Ab.data, i) + 1, i + 1) stmt2 = tvm.make.Block(stmt, store) assert(tvm.ir_pass.RemoveNoOp(stmt2) == store) if __name__ == "__main__": test_remove_no_op()