import mxnet as mx
from tvm import relay
import model_zoo

def compare_graph(f1, f2):
    """Assert that two Relay functions are alpha-equivalent.

    Type inference is run on ``f1`` and then on ``f2`` before the structural
    comparison, so that both graphs carry checked types when compared.

    Args:
        f1: Relay function produced by the frontend under test.
        f2: Reference Relay function to compare against.

    Raises:
        AssertionError: if the type-inferred functions are not alpha-equal.
    """
    typed_lhs = relay.ir_pass.infer_type(f1)
    typed_rhs = relay.ir_pass.infer_type(f2)
    graphs_match = relay.ir_pass.alpha_equal(typed_lhs, typed_rhs)
    assert graphs_match

def test_mlp():
    """The MXNet model-zoo MLP must convert to the Relay model-zoo MLP."""
    input_shapes = {"data": (1, 1, 28, 28)}
    mx_sym = model_zoo.mx_mlp()
    converted, _params = relay.frontend.from_mxnet(mx_sym, shape=input_shapes)
    expected = model_zoo.relay_mlp()
    compare_graph(converted, expected)


def test_vgg():
    """Each VGG depth must convert from MXNet to the matching Relay graph."""
    input_shapes = {"data": (1, 3, 224, 224)}
    for depth in (11, 13, 16, 19):
        converted, _params = relay.frontend.from_mxnet(
            model_zoo.mx_vgg(depth), shape=input_shapes)
        compare_graph(converted, model_zoo.relay_vgg(depth))


def test_resnet():
    """Each ResNet depth must convert from MXNet to the matching Relay graph."""
    input_shapes = {"data": (1, 3, 224, 224)}
    for depth in (18, 34, 50, 101):
        converted, _params = relay.frontend.from_mxnet(
            model_zoo.mx_resnet(depth), shape=input_shapes)
        compare_graph(converted, model_zoo.relay_resnet(depth))


def test_squeezenet():
    """Each SqueezeNet version must convert from MXNet to the matching Relay graph."""
    shape = {"data": (1, 3, 224, 224)}
    for version in ['1.0', '1.1']:
        mx_sym = model_zoo.mx_squeezenet(version)
        # Pass ``shape`` by keyword, as the other tests in this file do, so the
        # call does not depend on from_mxnet's positional parameter order.
        from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape=shape)
        relay_sym = model_zoo.relay_squeezenet(version)
        compare_graph(from_mx_sym, relay_sym)


def test_inception_v3():
    """The MXNet Inception-v3 must convert to the Relay model-zoo Inception-v3."""
    shape = {"data": (1, 3, 299, 299)}
    mx_sym = model_zoo.mx_inception_v3()
    # Keyword ``shape`` for consistency with the other frontend calls here.
    from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape=shape)
    relay_sym = model_zoo.relay_inception_v3()
    compare_graph(from_mx_sym, relay_sym)


def test_dqn():
    """The MXNet DQN must convert to the Relay model-zoo DQN."""
    shape = {"data": (1, 4, 84, 84)}
    mx_sym = model_zoo.mx_dqn()
    # Keyword ``shape`` for consistency with the other frontend calls here.
    from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape=shape)
    relay_sym = model_zoo.relay_dqn()
    compare_graph(from_mx_sym, relay_sym)


def test_dcgan():
    """The MXNet DCGAN must convert to the Relay model-zoo DCGAN."""
    shape = {"data": (2, 100)}
    # Derive the reference graph's batch size from the input shape so the two
    # cannot drift apart if the shape is edited.
    batch_size = shape["data"][0]
    mx_sym = model_zoo.mx_dcgan()
    # Keyword ``shape`` for consistency with the other frontend calls here.
    from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape=shape)
    relay_sym = model_zoo.relay_dcgan(batch_size=batch_size)
    compare_graph(from_mx_sym, relay_sym)


def test_multi_outputs():
    """Check conversion of a graph that consumes several outputs of one split.

    Both the MXNet and the Relay graphs split ``x`` along axis 1 into three
    pieces, add the first and third pieces, and subtract ``y``.
    """
    x_shape = (10, 27)
    y_shape = (10, 9)

    def mx_compose(F, **kwargs):
        # ``F`` is the mxnet module; ``kwargs`` are forwarded to sym.split.
        data_x = F.sym.Variable("x")
        data_y = F.sym.Variable("y")
        pieces = F.sym.split(data_x, **kwargs)
        summed = F.sym.broadcast_add(pieces[0], pieces[2])
        return F.sym.broadcast_sub(summed, data_y)

    def relay_compose(F, **kwargs):
        # ``F`` is the relay module; ``kwargs`` are forwarded to relay.split.
        data_x = F.var("x", shape=x_shape)
        data_y = F.var("y", shape=y_shape)
        pieces = F.split(data_x, **kwargs)
        body = F.subtract(F.add(pieces[0], pieces[2]), data_y)
        return relay.Function(relay.ir_pass.free_vars(body), body)

    mx_sym = mx_compose(mx, num_outputs=3, axis=1)
    input_shapes = {"x": x_shape, "y": y_shape}
    converted, _params = relay.frontend.from_mxnet(mx_sym, shape=input_shapes)
    expected = relay_compose(relay, indices_or_sections=3, axis=1)
    compare_graph(converted, expected)


if __name__ == "__main__":
    # Run every test in this module, in a fixed order, when executed as a script.
    for run_test in (
        test_mlp,
        test_resnet,
        test_vgg,
        test_multi_outputs,
        test_dqn,
        test_dcgan,
        test_squeezenet,
        test_inception_v3,
    ):
        run_test()