Commit e722dbcb by Wenhao Hu Committed by Tianqi Chen

Onnx opset support (#416)

parent f4789db6
pip2 install onnx>=0.2.0 pip2 install onnx>=1.1.0
pip3 install onnx>=0.2.0 pip3 install onnx>=1.1.0
pip2 install http://download.pytorch.org/whl/cu75/torch-0.2.0.post3-cp27-cp27mu-manylinux1_x86_64.whl pip2 install http://download.pytorch.org/whl/cu75/torch-0.2.0.post3-cp27-cp27mu-manylinux1_x86_64.whl
pip2 install torchvision pip2 install torchvision
......
...@@ -14,8 +14,8 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape): ...@@ -14,8 +14,8 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
c2_out = prepared_backend.run(W)[0] c2_out = prepared_backend.run(W)[0]
return c2_out return c2_out
def get_tvm_output(graph, x, target, ctx, dtype='float32'): def get_tvm_output(model, x, target, ctx, dtype='float32'):
new_sym, params = nnvm.frontend.from_onnx(graph) new_sym, params = nnvm.frontend.from_onnx(model)
shape_dict = {'input_0': x.shape} shape_dict = {'input_0': x.shape}
graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params) graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
m = graph_runtime.create(graph, lib, ctx) m = graph_runtime.create(graph, lib, ctx)
......
...@@ -5,8 +5,8 @@ from nnvm.compiler import graph_util, graph_attr ...@@ -5,8 +5,8 @@ from nnvm.compiler import graph_util, graph_attr
from model_zoo import super_resolution, super_resolution_sym from model_zoo import super_resolution, super_resolution_sym
def compare_graph(onnx_file, nnvm_sym, ishape): def compare_graph(onnx_file, nnvm_sym, ishape):
onnx_graph = onnx.load(onnx_file) onnx_model = onnx.load(onnx_file)
onnx_sym, params = nnvm.frontend.from_onnx(onnx_graph) onnx_sym, params = nnvm.frontend.from_onnx(onnx_model)
g1 = nnvm.graph.create(onnx_sym) g1 = nnvm.graph.create(onnx_sym)
g2 = nnvm.graph.create(nnvm_sym) g2 = nnvm.graph.create(nnvm_sym)
ishapes = {'input_0': ishape} ishapes = {'input_0': ishape}
......
...@@ -44,9 +44,9 @@ model_url = ''.join(['https://gist.github.com/zhreshold/', ...@@ -44,9 +44,9 @@ model_url = ''.join(['https://gist.github.com/zhreshold/',
'super_resolution_0.2.onnx']) 'super_resolution_0.2.onnx'])
download(model_url, 'super_resolution.onnx', True) download(model_url, 'super_resolution.onnx', True)
# now you have super_resolution.onnx on disk # now you have super_resolution.onnx on disk
onnx_graph = onnx.load('super_resolution.onnx') onnx_model = onnx.load('super_resolution.onnx')
# we can load the graph as NNVM compatible model # we can load the graph as NNVM compatible model
sym, params = nnvm.frontend.from_onnx(onnx_graph) sym, params = nnvm.frontend.from_onnx(onnx_model)
###################################################################### ######################################################################
# Load a test image # Load a test image
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment