# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from tvm import relay
from tvm.relay import transform


def run_combine_parallel(expr, min_num_branches=3):
    """Apply CombineParallelDense to ``expr`` and return the "main" function.

    ``expr`` is first wrapped in a module; ``min_num_branches`` is forwarded
    unchanged to the pass constructor.
    """
    module = relay.Module.from_expr(expr)
    combine_pass = transform.CombineParallelDense(min_num_branches)
    return combine_pass(module)["main"]

def run_opt_pass(expr, opt_pass):
    """Run one pass over ``expr`` and return the module's "main" function.

    ``opt_pass`` must be a ``transform.Pass`` instance; anything else trips
    the assertion before a module is built from ``expr``.
    """
    assert isinstance(opt_pass, transform.Pass)
    return opt_pass(relay.Module.from_expr(expr))["main"]


def test_combine_parallel_dense():
    """Basic case: matching dense ops are combined, a mismatched one is not."""
    def before(x, w1, w2, w3, w4):
        # Four dense ops share the input x.  The one using w3 cannot be
        # combined because w3 has a different shape from the other weights.
        outputs = tuple(relay.nn.dense(x, weight) for weight in (w1, w2, w3, w4))
        return relay.Function([x, w1, w2, w3, w4], relay.Tuple(outputs))

    def expected(x, w1, w2, w3, w4):
        # The combinable branches are expressed as one batch_matmul over
        # stacked inputs and weights, then split and squeezed back apart.
        data = relay.stack((x, x, x), axis=0)
        weight = relay.stack((w1, w2, w4), axis=0)
        pieces = relay.split(relay.nn.batch_matmul(data, weight), 3)
        y1, y2, y4 = (relay.squeeze(piece, [0]) for piece in pieces)

        # The w3 branch stays a stand-alone dense.
        y3 = relay.nn.dense(x, w3)

        # Parameters keep a fixed order so the alpha-equal check can pass.
        params = [x, w1, w2, w3, w4]
        return relay.Function(params, relay.Tuple((y1, y2, y3, y4)))

    def check(i, j, k):
        x = relay.var("x", shape=(i, k))
        # w3 gets one extra row, which is the shape mismatch under test.
        rows_by_name = (("w1", j), ("w2", j), ("w3", j + 1), ("w4", j))
        weights = [relay.var(name, shape=(rows, k)) for name, rows in rows_by_name]

        original = before(x, *weights)
        combine = transform.CombineParallelDense(min_num_branches=2)
        actual = run_opt_pass(original, combine)

        reference = run_opt_pass(expected(x, *weights), transform.InferType())
        assert relay.analysis.alpha_equal(actual, reference)

    for dims in ((3, 5, 4), (100, 200, 300)):
        check(*dims)


def test_combine_parallel_dense_biasadd():
    """Combine dense + bias add, checked with both 1-D and 2-D biases."""
    def before(x, w1, w2, b1, b2):
        # Two parallel dense ops on the same input, each followed by an add.
        branch1 = relay.add(relay.nn.dense(x, w1), b1)
        branch2 = relay.add(relay.nn.dense(x, w2), b2)
        return relay.Function([x, w1, w2, b1, b2],
                              relay.Tuple((branch1, branch2)))

    def expected(x, w1, w2, b1, b2, is_2d_bias):
        # Capture the parameter list up front; it must hold the original
        # bias variables, not the expanded expressions built below.
        params = [x, w1, w2, b1, b2]
        data = relay.stack((x, x), axis=0)
        weight = relay.stack((w1, w2), axis=0)
        combined = relay.nn.batch_matmul(data, weight)

        # 1-D biases get a leading axis before stacking; 2-D biases are
        # stacked as they are.
        if is_2d_bias:
            biases = (b1, b2)
        else:
            biases = (relay.expand_dims(b1, 0), relay.expand_dims(b2, 0))

        combined = relay.add(combined, relay.stack(biases, axis=0))
        y1, y2 = (relay.squeeze(piece, [0])
                  for piece in relay.split(combined, 2))
        return relay.Function(params, relay.Tuple((y1, y2)))

    def check(i, j, k, is_2d_bias):
        x = relay.var("x", shape=(i, k))
        w1 = relay.var("w1", shape=(j, k))
        w2 = relay.var("w2", shape=(j, k))

        bias_shape = (i, j) if is_2d_bias else (j,)
        b1 = relay.var("b1", shape=bias_shape)
        b2 = relay.var("b2", shape=bias_shape)

        original = before(x, w1, w2, b1, b2)
        combine = transform.CombineParallelDense(min_num_branches=2)
        actual = run_opt_pass(original, combine)

        reference = expected(x, w1, w2, b1, b2, is_2d_bias)
        reference = run_opt_pass(reference, transform.InferType())
        assert relay.analysis.alpha_equal(actual, reference)

    for is_2d_bias in (False, True):
        check(3, 5, 4, is_2d_bias)
        check(100, 200, 300, is_2d_bias)

def test_combine_parallel_dense_biasadd_scale_reshape():
    """Testcase of combining dense + 1d biasadd + multiply with non-fused reshape"""
    def before(x, w1, w2, b1, b2, scale1, scale2, newshape):
        # Two parallel branches: dense -> add bias -> multiply scale -> reshape.
        args = [x, w1, w2, b1, b2, scale1, scale2]
        y1 = relay.nn.dense(x, w1)
        y2 = relay.nn.dense(x, w2)
        y1 = relay.add(y1, b1)
        y2 = relay.add(y2, b2)
        y1 = relay.multiply(y1, scale1)
        y2 = relay.multiply(y2, scale2)
        y1 = relay.reshape(y1, newshape=newshape)
        y2 = relay.reshape(y2, newshape=newshape)
        y = relay.Tuple((y1, y2))
        return relay.Function(args, y)

    def expected(x, w1, w2, b1, b2, scale1, scale2, newshape):
        # The parameter list is captured before b1/b2/scale1/scale2 are
        # rebound to their expanded forms below.
        args = [x, w1, w2, b1, b2, scale1, scale2]
        x_stacked = relay.stack((x, x), axis=0)
        w = relay.stack((w1, w2), axis=0)
        y = relay.nn.batch_matmul(x_stacked, w)
        # Bias and scale each get a leading axis and are stacked so they
        # apply to the combined result.
        b1 = relay.expand_dims(b1, 0)
        b2 = relay.expand_dims(b2, 0)
        b = relay.stack((b1, b2), axis=0)
        y = relay.add(y, b)
        scale1 = relay.expand_dims(scale1, 0)
        scale2 = relay.expand_dims(scale2, 0)
        scale = relay.stack((scale1, scale2), axis=0)
        y = relay.multiply(y, scale)
        (y1, y2) = relay.split(y, 2)
        y1 = relay.squeeze(y1, [0])
        y2 = relay.squeeze(y2, [0])
        # The reshape is expected to stay per-branch, after the split.
        y1 = relay.reshape(y1, newshape=newshape)
        y2 = relay.reshape(y2, newshape=newshape)
        y = relay.Tuple((y1, y2))
        return relay.Function(args, y)

    def check(i, j, k, newshape):
        # The scales are symbolic variables; earlier float arguments for them
        # were overwritten here and never used, so they are no longer taken.
        x = relay.var("x", shape=(i, k))
        w1 = relay.var("w1", shape=(j, k))
        w2 = relay.var("w2", shape=(j, k))
        b1 = relay.var("b1", shape=(j,))
        b2 = relay.var("b2", shape=(j,))
        scale1 = relay.var("scale1", shape=(1,))
        scale2 = relay.var("scale2", shape=(1,))

        y_before = before(x, w1, w2, b1, b2, scale1, scale2, newshape)
        y = run_opt_pass(y_before,
                         transform.CombineParallelDense(min_num_branches=2))
        y_expected = expected(x, w1, w2, b1, b2, scale1, scale2, newshape)
        y_expected = run_opt_pass(y_expected, transform.InferType())
        assert relay.analysis.alpha_equal(y, y_expected)

    # newshape must hold exactly i * j elements, the size of each branch's
    # (i, j) dense output: 3 * 5 = 15 and 100 * 200 = 20000.
    check(3, 5, 4, (1, 1, 15))
    check(100, 200, 300, (1, 1, 20000))


if __name__ == "__main__":
    # When executed as a script, run every test case in definition order.
    for _test_case in (
            test_combine_parallel_dense,
            test_combine_parallel_dense_biasadd,
            test_combine_parallel_dense_biasadd_scale_reshape,
    ):
        _test_case()