# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
import pytest
import tvm
from tvm import relay


def test_callgraph_construct():
    """A call graph built from a one-function module mentions that function
    and keeps a module structurally equal to the one it was built from."""
    mod = tvm.IRModule({})
    lhs = relay.var("x", shape=(2, 3))
    rhs = relay.var("y", shape=(2, 3))
    mod["g1"] = relay.Function([lhs, rhs], lhs + rhs)

    graph = relay.analysis.CallGraph(mod)

    rendered = str(graph)
    assert "g1" in rendered
    assert tvm.ir.structural_equal(mod, graph.module)


def test_print_element():
    """Globals that are never called report ``#refs = 0`` in ``print_var``."""
    mod = tvm.IRModule({})

    lhs0, rhs0 = relay.var("x0", shape=(2, 3)), relay.var("y0", shape=(2, 3))
    mod["g0"] = relay.Function([lhs0, rhs0], lhs0 + rhs0)

    lhs1, rhs1 = relay.var("x1", shape=(2, 3)), relay.var("y1", shape=(2, 3))
    mod["g1"] = relay.Function([lhs1, rhs1], lhs1 - rhs1)

    graph = relay.analysis.CallGraph(mod)

    # Neither function calls the other, so both should be unreferenced.
    for global_name in ("g0", "g1"):
        assert "#refs = 0" in str(graph.print_var(global_name))


def test_global_call_count():
    """Check ``global_call_count`` on a small chain of global functions.

    ``main`` calls ``g0`` and ``g1``, ``g1`` calls ``g0``, and ``g0`` calls
    no other global.  The test expects counts of 2, 1 and 0 respectively.
    """
    mod = tvm.IRModule({})
    x0 = relay.var("x0", shape=(2, 3))
    y0 = relay.var("y0", shape=(2, 3))
    g0 = relay.GlobalVar("g0")
    mod[g0] = relay.Function([x0, y0], x0 + y0)
    x1 = relay.var("x1", shape=(2, 3))
    y1 = relay.var("y1", shape=(2, 3))
    g1 = relay.GlobalVar("g1")
    mod[g1] = relay.Function([x1, y1], g0(x1, y1))

    p0 = relay.var("p0", shape=(2, 3))
    p1 = relay.var("p1", shape=(2, 3))
    func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))
    mod["main"] = func
    # Build the call graph once the module is complete.  An earlier graph
    # would be discarded without ever being inspected.
    call_graph = relay.analysis.CallGraph(mod)

    assert call_graph.global_call_count(g0) == 0
    assert call_graph.global_call_count(g1) == 1
    assert call_graph.global_call_count("main") == 2


def test_ref_count():
    """Check ``ref_count`` when ``main`` calls two otherwise-unrelated globals.

    ``g0`` and ``g1`` are each called once, from ``main``, and nothing calls
    ``main``.  The test expects counts of 1, 1 and 0.
    """
    mod = tvm.IRModule({})
    x0 = relay.var("x0", shape=(2, 3))
    y0 = relay.var("y0", shape=(2, 3))
    g0 = relay.GlobalVar("g0")
    mod[g0] = relay.Function([x0, y0], x0 + y0)
    x1 = relay.var("x1", shape=(2, 3))
    y1 = relay.var("y1", shape=(2, 3))
    g1 = relay.GlobalVar("g1")
    mod[g1] = relay.Function([x1, y1], x1 - y1)

    p0 = relay.var("p0", shape=(2, 3))
    p1 = relay.var("p1", shape=(2, 3))
    func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))
    mod["main"] = func
    # Build the call graph once the module is complete.  An earlier graph
    # would be discarded without ever being inspected.
    call_graph = relay.analysis.CallGraph(mod)

    assert call_graph.ref_count(g0) == 1
    assert call_graph.ref_count(g1) == 1
    assert call_graph.ref_count("main") == 0


def test_nested_ref():
    """Check ``ref_count`` when one global is referenced from two callers.

    ``g0`` is called both by ``g1`` and by ``main``, ``g1`` is called only by
    ``main``, and nothing calls ``main``.  The test expects counts of 2, 1
    and 0.
    """
    mod = tvm.IRModule({})
    x0 = relay.var("x0", shape=(2, 3))
    y0 = relay.var("y0", shape=(2, 3))
    g0 = relay.GlobalVar("g0")
    mod[g0] = relay.Function([x0, y0], x0 + y0)
    x1 = relay.var("x1", shape=(2, 3))
    y1 = relay.var("y1", shape=(2, 3))
    g1 = relay.GlobalVar("g1")
    mod[g1] = relay.Function([x1, y1], g0(x1, y1))

    p0 = relay.var("p0", shape=(2, 3))
    p1 = relay.var("p1", shape=(2, 3))
    func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))
    mod["main"] = func
    # Build the call graph once the module is complete.  An earlier graph
    # would be discarded without ever being inspected.
    call_graph = relay.analysis.CallGraph(mod)

    assert call_graph.ref_count(g0) == 2
    assert call_graph.ref_count(g1) == 1
    assert call_graph.ref_count("main") == 0


def test_recursive_func():
    """A self-calling global is reported as recursive, and its self-call
    counts toward its reference count alongside the call from ``main``."""
    mod = tvm.IRModule({})

    # Identity helper that the recursive function calls on each non-base step.
    ident_param = relay.var('x', shape=[], dtype='int32')
    helper_gv = relay.GlobalVar("gx")
    mod[helper_gv] = relay.Function([ident_param], ident_param)

    sum_up = relay.GlobalVar('sum_up')
    counter = relay.var('i', shape=[], dtype='int32')
    zero = relay.const(0, dtype='int32')
    one = relay.const(1, dtype='int32')

    builder = relay.ScopeBuilder()
    with builder.if_scope(relay.equal(counter, zero)):
        builder.ret(counter)
    with builder.else_scope():
        decremented = relay.subtract(counter, one)
        helper_result = helper_gv(counter)
        recursive_result = relay.Call(sum_up, [decremented]) + helper_result
        builder.ret(relay.add(recursive_result, counter))

    recursive_fn = relay.Function(
        [counter],
        builder.get(),
        ret_type=relay.TensorType([], 'int32'),
    ).with_attr("Compiler", tvm.tir.StringImm("a"))
    mod[sum_up] = recursive_fn

    entry_arg = relay.var('i', shape=[], dtype='int32')
    mod["main"] = relay.Function([entry_arg], sum_up(entry_arg))

    graph = relay.analysis.CallGraph(mod)

    assert graph.is_recursive(sum_up)
    # One reference comes from ``main``, the other from the self-call.
    assert graph.ref_count(sum_up) == 2
    assert graph.ref_count(helper_gv) == 1
    assert graph.ref_count("main") == 0


if __name__ == "__main__":
    # Pass this file explicitly.  Run as a script, a bare ``pytest.main()``
    # gets an empty ``sys.argv[1:]`` and collects every test in the current
    # working directory instead of just this module.
    pytest.main([__file__])