# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import mxnet as mx

import tvm
from tvm import relay
from tvm.relay import transform
import model_zoo

def compare_graph(lhs_mod, rhs_mod):
    """Assert that the "main" functions of two modules are alpha-equal.

    Both modules are first run through ``transform.InferType()`` (left-hand
    side first, then right-hand side); the resulting ``"main"`` entries are
    compared with ``relay.analysis.alpha_equal``.

    Raises
    ------
    AssertionError
        If ``alpha_equal`` reports that the two functions differ.
    """
    typed_lhs, typed_rhs = (
        transform.InferType()(module) for module in (lhs_mod, rhs_mod)
    )
    assert relay.analysis.alpha_equal(typed_lhs["main"], typed_rhs["main"])

def test_mlp():
    """Check the MXNet MLP from ``model_zoo`` converts to the Relay MLP."""
    input_shapes = {"data": (1, 1, 28, 28)}
    converted_mod, _params = relay.frontend.from_mxnet(
        model_zoo.mx_mlp(), shape=input_shapes
    )
    expected_mod = model_zoo.relay_mlp()
    compare_graph(converted_mod, expected_mod)


def test_vgg():
    """Check each ``model_zoo`` VGG variant converts to its Relay counterpart.

    The variants 11, 13, 16 and 19 are converted and compared one after
    another; the first mismatch aborts the test.
    """
    input_shapes = {"data": (1, 3, 224, 224)}
    for variant in (11, 13, 16, 19):
        converted_mod, _params = relay.frontend.from_mxnet(
            model_zoo.mx_vgg(variant), shape=input_shapes
        )
        expected_mod = model_zoo.relay_vgg(variant)
        compare_graph(converted_mod, expected_mod)


def test_resnet():
    """Check each ``model_zoo`` ResNet variant converts to its Relay counterpart.

    The variants 18, 34, 50 and 101 are converted and compared one after
    another; the first mismatch aborts the test.
    """
    input_shapes = {"data": (1, 3, 224, 224)}
    for variant in (18, 34, 50, 101):
        converted_mod, _params = relay.frontend.from_mxnet(
            model_zoo.mx_resnet(variant), shape=input_shapes
        )
        expected_mod = model_zoo.relay_resnet(variant)
        compare_graph(converted_mod, expected_mod)


def test_squeezenet():
    """Check SqueezeNet versions '1.0' and '1.1' convert to their Relay counterparts."""
    input_shapes = {"data": (1, 3, 224, 224)}
    for release in ("1.0", "1.1"):
        converted_mod, _params = relay.frontend.from_mxnet(
            model_zoo.mx_squeezenet(release), shape=input_shapes
        )
        expected_mod = model_zoo.relay_squeezenet(release)
        compare_graph(converted_mod, expected_mod)


def test_inception_v3():
    """Check the MXNet Inception-v3 from ``model_zoo`` converts to the Relay one."""
    input_shapes = {"data": (1, 3, 299, 299)}
    converted_mod, _params = relay.frontend.from_mxnet(
        model_zoo.mx_inception_v3(), shape=input_shapes
    )
    expected_mod = model_zoo.relay_inception_v3()
    compare_graph(converted_mod, expected_mod)


def test_dqn():
    """Check the MXNet DQN from ``model_zoo`` converts to the Relay DQN."""
    input_shapes = {"data": (1, 4, 84, 84)}
    converted_mod, _params = relay.frontend.from_mxnet(
        model_zoo.mx_dqn(), shape=input_shapes
    )
    expected_mod = model_zoo.relay_dqn()
    compare_graph(converted_mod, expected_mod)


def test_dcgan():
    """Check the MXNet DCGAN from ``model_zoo`` converts to the Relay DCGAN.

    The same batch size is used for the leading dimension of the "data"
    input and for the ``batch_size`` argument of the Relay reference model.
    """
    batch = 2
    input_shapes = {"data": (batch, 100)}
    converted_mod, _params = relay.frontend.from_mxnet(
        model_zoo.mx_dcgan(), shape=input_shapes
    )
    expected_mod = model_zoo.relay_dcgan(batch_size=batch)
    compare_graph(converted_mod, expected_mod)


def test_multi_outputs():
    """Check conversion of a graph that consumes several outputs of one op.

    Input "x" is split into three along axis 1; pieces 0 and 2 are added and
    "y" is subtracted from the sum. The graph is built once with MXNet
    symbols (then converted) and once directly in Relay, and the two
    modules are compared.
    """
    x_shape = (10, 27)
    y_shape = (10, 9)

    def build_mx_symbol(F, **kwargs):
        # ``kwargs`` are forwarded untouched to the MXNet split operator.
        x = F.sym.Variable("x")
        y = F.sym.Variable("y")
        pieces = F.sym.split(x, **kwargs)
        summed = F.sym.broadcast_add(pieces[0], pieces[2])
        return F.sym.broadcast_sub(summed, y)

    def build_relay_module(F, **kwargs):
        # ``kwargs`` are forwarded untouched to the Relay split operator.
        x = F.var("x", shape=x_shape)
        y = F.var("y", shape=y_shape)
        pieces = F.split(x, **kwargs)
        body = F.subtract(F.add(pieces[0], pieces[2]), y)
        # Wrap the expression in a function over its free variables.
        return tvm.IRModule.from_expr(
            relay.Function(relay.analysis.free_vars(body), body)
        )

    mx_symbol = build_mx_symbol(mx, num_outputs=3, axis=1)
    converted_mod, _params = relay.frontend.from_mxnet(
        mx_symbol, shape={"x": x_shape, "y": y_shape}
    )
    expected_mod = build_relay_module(relay, indices_or_sections=3, axis=1)
    compare_graph(converted_mod, expected_mod)


if __name__ == "__main__":
    # Run every test in this file, in a fixed order, when executed as a script.
    for run_test in (
        test_mlp,
        test_resnet,
        test_vgg,
        test_multi_outputs,
        test_dqn,
        test_dcgan,
        test_squeezenet,
        test_inception_v3,
    ):
        run_test()