"""Test graph equality of caffe2 models.""" import nnvm from nnvm.compiler import graph_util, graph_attr from model_zoo import c2_squeezenet, squeezenet def compare_graph(init, predict, nnvm_sym, ishape): caffe2_sym, params = nnvm.frontend.from_caffe2(init, predict) g1 = nnvm.graph.create(caffe2_sym) g2 = nnvm.graph.create(nnvm_sym) input_name = predict.external_input[0] ishapes = {input_name: ishape} graph_attr.set_shape_inputs(g1, ishapes) graph_attr.set_shape_inputs(g2, ishapes) g1 = g1.apply("InferShape").apply("SimplifyInference") g2 = g2.apply("InferShape").apply("SimplifyInference") graph_util.check_graph_equal(g1, g2) def test_squeeze_net(): symbol, params = squeezenet.get_workload(version='1.1') compare_graph(c2_squeezenet.init_net, c2_squeezenet.predict_net, symbol, ishape=(1, 3, 224, 224)) if __name__ == '__main__': test_squeeze_net()