# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test legalize pass"""
import numpy as np
import tvm

from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay import transform, analysis
from tvm.relay.testing.temp_op_attr import TempOpAttr


def run_opt_pass(expr, passes):
    """Run one pass (or a list of passes) over ``expr`` at ``opt_level=3``.

    ``expr`` is wrapped in a module via ``relay.Module.from_expr``. The passes
    are applied as a ``transform.Sequential`` pipeline and the module's
    ``main`` function is read back.

    Returns the transformed ``main`` function when ``expr`` was a
    ``relay.Function``; otherwise returns only its body.
    """
    pass_list = passes if isinstance(passes, list) else [passes]
    module = relay.Module.from_expr(expr)
    pipeline = transform.Sequential(pass_list)
    with transform.PassContext(opt_level=3):
        module = pipeline(module)
    main_func = module["main"]
    if isinstance(expr, relay.Function):
        return main_func
    return main_func.body

def test_legalize():
    """Check that Legalize swaps a call for the expression its hook returns.

    The FTVMLegalize hook registered on ``nn.conv2d`` multiplies the weight
    by 2 before convolving. The legalized graph must match a hand-built graph
    that already contains that multiply.
    """
    data_shape = (1, 64, 56, 56)
    weight_shape = (64, 64, 3, 3)

    def before():
        data = relay.var("x", shape=data_shape)
        kernel = relay.var('weight', shape=weight_shape)
        conv = relay.nn.conv2d(data, kernel, channels=64, kernel_size=(3, 3), padding=(1, 1))
        return relay.Function([data, kernel], relay.nn.relu(conv))

    def legalize_conv2d(attrs, inputs, types):
        data, kernel = inputs
        doubled = relay.multiply(kernel, relay.const(2.0, "float32"))
        return relay.nn.conv2d(data, doubled, **attrs)

    def expected():
        data = relay.var("x", shape=data_shape)
        kernel = relay.var('weight', shape=weight_shape)
        doubled = relay.multiply(kernel, relay.const(2.0, "float32"))
        conv = relay.nn.conv2d(data, doubled, channels=64, kernel_size=(3, 3), padding=(1, 1))
        return relay.Function([data, kernel], relay.nn.relu(conv))

    with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
        actual = run_opt_pass(before(), transform.Legalize())
        reference = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(actual, reference), "Actual = \n" + str(actual)

def test_legalize_none():
    """Test that a legalize hook returning ``None`` leaves the graph unchanged.

    The hook is registered on ``nn.global_max_pool2d``. It records that it
    ran and returns ``None``. The test expects the legalized function to
    alpha-equal the original function after plain type inference. It also
    expects the hook to have been invoked at least once.
    """
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.nn.global_max_pool2d(x)
        y = relay.Function([x], y)
        return y

    # Mutable cell so the nested hook can report that it was invoked.
    called = [False]

    def legalize_global_max_pool2d(attrs, inputs, types):
        called[0] = True
        return None

    with TempOpAttr("nn.global_max_pool2d", "FTVMLegalize",
                    legalize_global_max_pool2d):
        a = before()
        a = run_opt_pass(a, transform.Legalize())
        b = run_opt_pass(before(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
    assert called[0], "FTVMLegalize hook for nn.global_max_pool2d was never called"

def test_legalize_multiple_ops():
    """Test legalizing two different operators in the same graph.

    Hooks are registered on both ``nn.conv2d`` and ``nn.relu``:

    - the ``nn.conv2d`` hook multiplies the weight by 2;
    - the ``nn.relu`` hook adds a zero constant to its input.

    The legalized graph must contain both rewrites.
    """
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var('weight', shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(x, weight,
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    def legalize_conv2d(attrs, inputs, types):
        data, weight = inputs
        weight = relay.multiply(weight, relay.const(2.0, "float32"))
        return relay.nn.conv2d(data, weight, **attrs)

    def legalize_relu(attrs, inputs, types):
        data = inputs[0]
        add = relay.add(relay.const(0, "float32"), data)
        return relay.nn.relu(add)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var('weight', shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")),
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.add(relay.const(0, "float32"), y)
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
        with TempOpAttr("nn.relu", "FTVMLegalize", legalize_relu):
            a = before()
            a = run_opt_pass(a, transform.Legalize())
            b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_legalize_multi_input():
    """Test legalizing an operator whose operands arrive as a single tuple.

    ``concatenate`` receives its operands as one ``Tuple`` argument. The hook
    therefore asserts that it sees:

    - exactly one input, which is a ``Tuple``;
    - two types: a ``TupleType`` followed by a ``TensorType``.

    These are presumably the argument type and the result type. The hook
    returns ``None``, so the graph must come out unchanged.
    """
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.var("y", shape=(1, 64, 56, 20))
        z = relay.var("z", shape=(1, 64, 56, 10))
        func = relay.concatenate([x, y, z], axis=3)
        func = relay.Function([x, y, z], func)
        return func

    # The real checks live inside the hook. Without this flag, a Legalize
    # pass that never invoked the hook would let the test pass vacuously.
    called = [False]

    def legalize_concatenate(attrs, inputs, types):
        called[0] = True
        # Check that the correct multi-input case is handled.
        assert len(inputs) == 1
        assert isinstance(inputs[0], tvm.relay.expr.Tuple)
        assert len(types) == 2
        assert isinstance(types[0], tvm.relay.ty.TupleType)
        assert isinstance(types[1], tvm.relay.ty.TensorType)
        return None

    with TempOpAttr("concatenate", "FTVMLegalize", legalize_concatenate):
        a = before()
        a = run_opt_pass(a, transform.Legalize())
        # The hook returns None, so the expected graph is the input itself.
        b = run_opt_pass(before(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
    assert called[0], "FTVMLegalize hook for concatenate was never called"


if __name__ == "__main__":
    # Run every test in this file, in order, when executed as a script.
    for _test in (test_legalize,
                  test_legalize_none,
                  test_legalize_multiple_ops,
                  test_legalize_multi_input):
        _test()