# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from nose.tools import nottest import tvm from tvm import relay from tvm.relay.ir_pass import dead_code_elimination, alpha_equal from tvm.relay.op import log, add, equal, subtract class env: def __init__(self): self.a = relay.Var("a") self.b = relay.Var("b") self.c = relay.Var("c") self.d = relay.Var("d") self.e = relay.Var("e") self.x = relay.Var("x") self.y = relay.Var("y") self.z = relay.Var("z") self.shape = tvm.convert([1, 2, 3]) self.tt = relay.TensorType(self.shape, "float32") self.int32 = relay.TensorType([], "int32") self.float32 = relay.TensorType([], "float32") self.one = relay.const(1.0) self.two = relay.const(2.0) self.three = relay.const(3.0) e = env() def test_let(): orig = relay.Let(e.x, e.y, e.z) assert alpha_equal(dead_code_elimination(orig), e.z) def test_used_let(): orig = relay.Let(e.c, e.one, e.c + e.c) assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.one, e.c + e.c)) @nottest def test_inline(): orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c)) assert alpha_equal(dead_code_elimination(orig), e.d) def test_chain_unused_let(): orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e)) assert alpha_equal(dead_code_elimination(orig), e.e) # make sure we dont infinite loop def test_recursion(): """ Program: let f(n: i32, data: f32) -> f32 = { if (n == 0) { return data; } else { return f(n - 1, log(data)); } } f(2, 10000); """ f = relay.Var("f") n = relay.Var("n", e.int32) data = relay.Var("data", e.float32) funcbody = relay.If(equal(n, relay.const(0)), data, relay.Call(f, [subtract(n, relay.const(1.0)), log(data)])) value = relay.Function([n, data], funcbody, e.float32, []) orig = relay.Let(f, value, relay.Call(f, [relay.const(2.0), relay.const(10000.0)])) assert alpha_equal(dead_code_elimination(orig), orig) assert alpha_equal(dead_code_elimination(relay.Let(f, value, e.three)), e.three) def test_op_let(): assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three), e.two)), add(e.three, e.two)) def test_tuple_get_item(): t = relay.Var('t') g = relay.TupleGetItem(t, 0) assert alpha_equal(dead_code_elimination(g), g) assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t), 0)), g) if __name__ == "__main__": test_let() test_used_let() test_inline() test_chain_unused_let() test_recursion() test_op_let() test_tuple_get_item()