# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import mxnet as mx
import nnvm
from nnvm.compiler import graph_util, graph_attr
import model_zoo

def compare_graph(sym1, sym2, ishape=(2, 3, 224, 224)):
    """Check that two symbols yield equal graphs after simplification.

    A graph is created from each symbol, the shape for the input named
    'data' is set to ``ishape``, and the "InferShape" and
    "SimplifyInference" passes are applied in that order. The two
    resulting graphs are then handed to ``graph_util.check_graph_equal``.

    Parameters
    ----------
    sym1, sym2 :
        Symbols accepted by ``nnvm.graph.create``.
    ishape : tuple of int
        Shape registered for the 'data' input of both graphs.
    """
    graphs = [nnvm.graph.create(sym) for sym in (sym1, sym2)]
    for graph in graphs:
        # The return value is ignored, exactly as in the two-call form.
        graph_attr.set_shape_inputs(graph, {'data': ishape})
    simplified = [
        graph.apply("InferShape").apply("SimplifyInference")
        for graph in graphs
    ]
    graph_util.check_graph_equal(simplified[0], simplified[1])

def test_mlp():
    """Convert ``model_zoo.mx_mlp`` and compare it with ``model_zoo.nnvm_mlp``."""
    source_sym = model_zoo.mx_mlp
    # from_mxnet returns a pair; only the converted symbol is needed here.
    converted, _unused = nnvm.frontend.from_mxnet(source_sym)
    expected = model_zoo.nnvm_mlp
    compare_graph(converted, expected)

def test_vgg():
    """Compare converted ``model_zoo.mx_vgg`` entries with ``model_zoo.nnvm_vgg``.

    The entries keyed by 11, 13, 16 and 19 are checked in that order.
    """
    for key in (11, 13, 16, 19):
        source_sym = model_zoo.mx_vgg[key]
        converted, _unused = nnvm.frontend.from_mxnet(source_sym)
        expected = model_zoo.nnvm_vgg[key]
        compare_graph(converted, expected)

def test_resnet():
    """Compare converted ``model_zoo.mx_resnet`` entries with ``model_zoo.nnvm_resnet``.

    The entries keyed by 18, 34, 50 and 101 are checked in that order.
    """
    for key in (18, 34, 50, 101):
        source_sym = model_zoo.mx_resnet[key]
        converted, _unused = nnvm.frontend.from_mxnet(source_sym)
        expected = model_zoo.nnvm_resnet[key]
        compare_graph(converted, expected)

def test_squeezenet():
    """Compare converted ``model_zoo.mx_squeezenet`` entries with their NNVM peers.

    The entries keyed by '1.0' and '1.1' are checked in that order against
    the same keys of ``model_zoo.nnvm_squeezenet``.
    """
    for key in ('1.0', '1.1'):
        source_sym = model_zoo.mx_squeezenet[key]
        converted, _unused = nnvm.frontend.from_mxnet(source_sym)
        expected = model_zoo.nnvm_squeezenet[key]
        compare_graph(converted, expected)

def test_inception_v3():
    """Convert ``model_zoo.mx_inception_v3`` and compare it with the NNVM version.

    The comparison uses a 'data' shape of (2, 3, 299, 299).
    """
    source_sym = model_zoo.mx_inception_v3
    converted, _unused = nnvm.frontend.from_mxnet(source_sym)
    expected = model_zoo.nnvm_inception_v3
    data_shape = (2, 3, 299, 299)
    compare_graph(converted, expected, ishape=data_shape)

def test_dqn():
    """Convert ``model_zoo.mx_dqn`` and compare it with ``model_zoo.nnvm_dqn``.

    The comparison uses a 'data' shape of (2, 4, 84, 84).
    """
    source_sym = model_zoo.mx_dqn
    converted, _unused = nnvm.frontend.from_mxnet(source_sym)
    expected = model_zoo.nnvm_dqn
    data_shape = (2, 4, 84, 84)
    compare_graph(converted, expected, ishape=data_shape)

def test_dcgan():
    """Convert ``model_zoo.mx_dcgan`` and compare it with ``model_zoo.nnvm_dcgan``.

    The comparison uses a 'data' shape of (2, 100).
    """
    source_sym = model_zoo.mx_dcgan
    converted, _unused = nnvm.frontend.from_mxnet(source_sym)
    expected = model_zoo.nnvm_dcgan
    data_shape = (2, 100)
    compare_graph(converted, expected, ishape=data_shape)

def test_multi_outputs():
    """Check conversion of a graph that indexes a multi-output ``split``.

    The same expression is built once with ``mx`` and once with ``nnvm``:
    ``broadcast_sub(broadcast_add(z[0], z[2]), y)`` where ``z`` is
    ``split(x, ...)``. The two libraries are given different keyword
    arguments for ``split`` (``num_outputs`` vs ``indices_or_sections``).
    """
    def build(lib, **split_kwargs):
        # `lib` is either the mx or the nnvm module; both expose `.sym`.
        ops = lib.sym
        x = ops.Variable('x')
        y = ops.Variable('y')
        pieces = ops.split(x, **split_kwargs)
        summed = ops.broadcast_add(pieces[0], pieces[2])
        return ops.broadcast_sub(summed, y)

    converted, _unused = nnvm.frontend.from_mxnet(
        build(mx, num_outputs=3, axis=1))
    expected = build(nnvm, indices_or_sections=3, axis=1)
    # NOTE(review): compare_graph sets a shape for an input named 'data',
    # while this graph's variables are 'x' and 'y' -- confirm that is intended.
    compare_graph(converted, expected)

if __name__ == '__main__':
    # Run every test in a fixed order when the file is executed as a script.
    for run_test in (
        test_mlp,
        test_vgg,
        test_resnet,
        test_multi_outputs,
        test_dqn,
        test_dcgan,
        test_squeezenet,
        test_inception_v3,
    ):
        run_test()