Commit e7bbb32c by Takeo Imai Committed by Tianqi Chen

Change the way to import ONNX model. (#186, #187) (#189)

* Change the way to import ONNX model.

* Decide ONNX version and change the processing

* Check whether onnx has the attribute '__version__'
parent ae16a366
......@@ -5,8 +5,13 @@ from nnvm.compiler import graph_util, graph_attr
from model_zoo import super_resolution
def compare_graph(onnx_file, nnvm_sym, ishape):
onnx_graph = onnx.load(onnx_file)
onnx_sym, params = nnvm.frontend.from_onnx(onnx_graph)
onnx_vars = [int(n) for n in onnx.__version__.split('.')] if hasattr(onnx, "__version__") else []
if len(onnx_vars) >= 2 and (onnx_vars[0] > 0 or onnx_vars[1] >= 2): # version >= 0.2
onnx_model = onnx.load(onnx_file)
onnx_sym, params = nnvm.frontend.from_onnx(onnx_model.graph)
else:
onnx_graph = onnx.load(onnx_file)
onnx_sym, params = nnvm.frontend.from_onnx(onnx_graph)
g1 = nnvm.graph.create(onnx_sym)
g2 = nnvm.graph.create(nnvm_sym)
ishapes = {'input_0': ishape}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment