# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# NOTE: We name this test file to start with test_graph_tuner
# to make it execute after zero_rank tensor test cases. This
# helps avoid topi arithmetic operator overloading issue:
# https://github.com/dmlc/tvm/issues/3240
# TODO: restore the file name after this issue is resolved.
import tvm

from tvm import autotvm, relay
from tvm.relay.testing import resnet
from tvm.autotvm.graph_tuner.utils import has_multiple_inputs, get_direct_ancestor, get_in_nodes, \
    get_out_nodes, expr2graph, bind_inputs
from tvm.relay.expr import Call, TupleGetItem, Tuple
from topi.nn.conv2d import conv2d


def create_workload(dshape, kshape, strides,
                    padding, dilation, layout,
                    out_layout, dtype, out_dtype):
    """Build an autotvm workload tuple for a conv2d task.

    Creates placeholder tensors for the input and kernel shapes and packs
    them, together with the conv2d attributes, into the canonical argument
    list expected by ``autotvm.task.args_to_workload``.
    NOTE(review): ``out_layout`` is accepted but unused — kept for
    signature compatibility with callers.
    """
    input_tensor = tvm.placeholder(dshape, dtype=dtype)
    kernel_tensor = tvm.placeholder(kshape, dtype=dtype)
    workload_args = [input_tensor, kernel_tensor, strides, padding,
                     dilation, layout, out_dtype]
    return autotvm.task.args_to_workload(workload_args, conv2d)


def verify_has_multiple_inputs(node_list, node_idx, input_names, expected_result):
    """Assert that ``has_multiple_inputs`` returns *expected_result* for one node."""
    actual = has_multiple_inputs(node_list, node_idx, input_names)
    message = "Output mismatch: expecting checking %s to be %s but got %s." \
              % (node_list[node_idx]["op"], str(expected_result), str(actual))
    assert actual == expected_result, message


def test_has_multiple_inputs():
    """Check multi-input detection on a small diamond-shaped relay graph.

    The graph feeds ``data`` into both a scalar multiply and a conv2d,
    then adds the two branches; only the final add should be reported as
    having multiple inputs.
    """
    data = relay.var("data")
    scaled = data * relay.expr.const(3.0)
    w0 = relay.var("w0")
    conv = relay.nn.conv2d(data, w0)
    result = relay.add(scaled, conv)
    net = relay.Function(relay.ir_pass.free_vars(result), result)
    net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1)})
    node_list = []
    node_dict = {}
    expr2graph(net, ["conv2d"], node_dict, node_list)
    input_names = ["data"]
    # (node index, expected multi-input flag) pairs.
    for idx, expected in ((2, False), (4, False), (5, True)):
        verify_has_multiple_inputs(node_list, idx, input_names, expected)


def test_expr2graph():
    """Verify expr2graph records nodes in post-order-visit order.

    Walks a resnet-50 workload with ``relay.ir_pass.post_order_visit``,
    collecting the expected op name for every expression node, then checks
    that ``expr2graph`` produced nodes with matching "op" entries in the
    same order.
    """
    net, _ = resnet.get_workload(num_layers=50, batch_size=1)
    node_dict = {}
    node_list = []
    target_ops = ["conv2d"]
    op_name_list = []
    def _count_node(node):
        # Bug fix: the original guard was inverted
        # (``if not isinstance(node, relay.op.op.Op): return``), which kept
        # ONLY Op nodes — but an Op is never a Call/TupleGetItem/Tuple, so
        # the branches below were unreachable and every entry became "null".
        # Skip operator nodes themselves and record all other expression
        # nodes, mirroring the entries produced by expr2graph.
        if isinstance(node, relay.op.op.Op):
            return
        if isinstance(node, Call):
            op_name_list.append(node.op.name.split(".")[-1])
        elif isinstance(node, TupleGetItem):
            op_name_list.append("TupleGetItem")
        elif isinstance(node, Tuple):
            op_name_list.append("Tuple")
        else:
            op_name_list.append("null")
    relay.ir_pass.post_order_visit(net, _count_node)

    expr2graph(net, target_ops, node_dict, node_list)
    for i, item in enumerate(zip(op_name_list, node_list)):
        op_name, node = item
        assert op_name == node["op"], "%dth Node operator mismatch: expecting %s but got %s" \
                                      % (i, str(op_name), str(node["op"]))


def test_get_direct_ancestor():
    """The final conv2d's direct ancestors should be the first conv2d and the input.

    Builds conv2d -> add -> add-const -> conv2d and checks that tracing
    back from node 5 yields the node indices [2, 0].
    """
    data = relay.var("data")
    w0 = relay.var("w0")
    first_conv = relay.nn.conv2d(data, w0)
    summed = relay.add(first_conv, data * relay.expr.const(5.0))
    shifted = summed + relay.expr.const(2.5)
    w1 = relay.var("w1")
    final = relay.nn.conv2d(shifted, w1)
    net = relay.Function(relay.ir_pass.free_vars(final), final)
    net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1), "w1": (16, 16, 1, 1)})
    node_list = []
    node_dict = {}
    expr2graph(net, ["conv2d"], node_dict, node_list)
    out = get_direct_ancestor(node_list, {}, ["conv2d"], 5, ["data"])
    assert out == [2, 0], "Output mismatch: expecting [2, 0] but got %s." % str(out)


def test_get_in_nodes():
    """Verify the input-node adjacency produced by ``get_in_nodes``.

    Builds conv2d -> add -> add-const -> conv2d and checks both the set of
    keys AND the ancestor lists of the returned adjacency dict. The
    original test only compared key sets (``set(out) ^ set(expected_out)``
    ignores dict values), so a wrong adjacency with the right keys passed.
    """
    data = relay.var("data")
    w0 = relay.var("w0")
    out1 = relay.nn.conv2d(data, w0)
    out2 = relay.add(out1, data)
    out3 = out2 + relay.expr.const(2.5)
    w1 = relay.var("w1")
    out = relay.nn.conv2d(out3, w1)
    net = relay.Function(relay.ir_pass.free_vars(out), out)
    net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1), "w1": (16, 16, 1, 1)})
    target_ops = ["conv2d"]
    input_names = ["data"]
    node_list = []
    node_dict = {}
    expr2graph(net, target_ops, node_dict, node_list)
    out = get_in_nodes(node_list, target_ops, input_names)
    expected_out = {7: [3], 3: [2, 0], 2: [0]}
    diff_set = set(out) ^ set(expected_out)
    if len(diff_set) != 0:
        raise RuntimeError("Output mismatch: expecting %s but got %s." % (str(expected_out), str(out)))
    # Also compare the ancestor lists per key (order-insensitive).
    for key, expected in expected_out.items():
        if sorted(out[key]) != sorted(expected):
            raise RuntimeError("Output mismatch: expecting %s but got %s."
                               % (str(expected_out), str(out)))


def test_get_out_nodes():
    """Verify ``get_out_nodes`` inverts an input-adjacency dict correctly.

    The original test only compared key sets (``set(out) ^
    set(expected_out)`` ignores dict values), so wrong successor lists
    passed undetected; this version also checks the values per key.
    """
    in_nodes_dict = {8: [4], 4: [3, 0], 3: [0]}
    expected_out = {0: [3, 4], 3: [4], 4: [8], 8: []}
    out = get_out_nodes(in_nodes_dict)
    diff_set = set(out) ^ set(expected_out)
    if len(diff_set) != 0:
        raise RuntimeError("Output mismatch: expecting %s but got %s." % (str(expected_out), str(out)))
    # Also compare the successor lists per key (order-insensitive).
    for key, expected in expected_out.items():
        if sorted(out[key]) != sorted(expected):
            raise RuntimeError("Output mismatch: expecting %s but got %s."
                               % (str(expected_out), str(out)))



if __name__ == "__main__":
    # Run every test in the same order as before.
    for _test in (test_has_multiple_inputs,
                  test_expr2graph,
                  test_get_direct_ancestor,
                  test_get_in_nodes,
                  test_get_out_nodes):
        _test()