# test_pass_dead_code_elimination.py: tests for relay.ir_pass.dead_code_elimination.
import tvm
from tvm import relay
from tvm.relay.ir_pass import dead_code_elimination, alpha_equal
from tvm.relay.op import log, add, equal, subtract


class env:
    """Shared fixture of Relay variables, types and constants used by the tests."""

    def __init__(self):
        # Variables created without a type annotation; used as placeholders
        # in the let/if expressions built by the tests.
        self.a = relay.Var("a")
        self.b = relay.Var("b")
        self.c = relay.Var("c")
        self.d = relay.Var("d")
        self.e = relay.Var("e")
        self.x = relay.Var("x")
        self.y = relay.Var("y")
        self.z = relay.Var("z")
        # float32 tensor type with shape [1, 2, 3].
        self.shape = tvm.convert([1, 2, 3])
        self.tt = relay.TensorType(self.shape, "float32")
        # Scalar (empty-shape) tensor types.
        self.int32 = relay.TensorType([], "int32")
        self.float32 = relay.TensorType([], "float32")
        # Scalar constants built from float literals.
        self.one = relay.const(1.0)
        self.two = relay.const(2.0)
        self.three = relay.const(3.0)


# Module-level fixture instance shared by every test in this file.
e = env()


def test_let():
    """A let whose bound variable is unused in its body reduces to the body."""
    orig = relay.Let(e.x, e.y, e.z)
    assert alpha_equal(dead_code_elimination(orig), e.z)


def test_used_let():
    """The unused outer binding is dropped; the used inner binding is kept."""
    orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
    assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c))


def test_chain_unused_let():
    """A chain of nested, all-unused lets collapses to the innermost body."""
    orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e))
    assert alpha_equal(dead_code_elimination(orig), e.e)


# Make sure dead code elimination does not loop forever on a recursive binding.
def test_recursion():
    """
    Program:
       let f(n: i32, data: f32) -> f32 = {
          if (n == 0) {
              return data;
          } else {
              return f(n - 1, log(data));
          }
       }
       f(2, 10000);

    A used recursive binding must be kept unchanged. An unused one must be
    removed, even though the function body refers to ``f`` itself.
    """
    f = relay.Var("f")
    n = relay.Var("n", e.int32)
    data = relay.Var("data", e.float32)
    # `n` is annotated int32, so every literal combined with it or passed
    # for it uses an int literal; a float literal would give a float32
    # constant and make the program ill-typed.
    funcbody = relay.If(equal(n, relay.const(0)),
                        data,
                        relay.Call(f, [subtract(n, relay.const(1)),
                                       log(data)]))
    value = relay.Function([n, data], funcbody, e.float32, [])
    orig = relay.Let(f, value, relay.Call(f, [relay.const(2), relay.const(10000.0)]))
    assert alpha_equal(dead_code_elimination(orig), orig)
    assert alpha_equal(dead_code_elimination(relay.Let(f, value, e.three)), e.three)


def test_op_let():
    """An unused let nested inside an operator argument is eliminated."""
    orig = add(relay.Let(e.a, e.one, e.three), e.two)
    assert alpha_equal(dead_code_elimination(orig), add(e.three, e.two))


def test_if():
    """The pass is expected to reduce If(True, a, b) to its true branch ``a``."""
    cond = relay.const(True)
    orig = relay.If(cond, e.a, e.b)
    y = dead_code_elimination(orig)
    assert alpha_equal(y, e.a)


def test_tuple_get_item():
    """TupleGetItem is preserved, and an unused let in its tuple operand is dropped."""
    t = relay.Var('t')
    g = relay.TupleGetItem(t, 0)
    assert alpha_equal(dead_code_elimination(g), g)
    wrapped = relay.TupleGetItem(relay.Let(e.a, e.one, t), 0)
    assert alpha_equal(dead_code_elimination(wrapped), g)


if __name__ == "__main__":
    # Allow running the tests directly, without a test runner.
    test_if()
    test_let()
    test_used_let()
    test_chain_unused_let()
    test_recursion()
    test_op_let()
    test_tuple_get_item()