Commit 11050fda by Alexander Pivovarov Committed by Tianqi Chen

Use new onnx API to load model from file (#1874)

parent c4ebe6bd
...@@ -66,7 +66,7 @@ def get_caffe2_output(model, x, dtype='float32'): ...@@ -66,7 +66,7 @@ def get_caffe2_output(model, x, dtype='float32'):
def verify_onnx_forward_impl(graph_file, data_shape, out_shape): def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
dtype = 'float32' dtype = 'float32'
x = np.random.uniform(size=data_shape) x = np.random.uniform(size=data_shape)
model = onnx.load(graph_file) model = onnx.load_model(graph_file)
c2_out = get_caffe2_output(model, x, dtype) c2_out = get_caffe2_output(model, x, dtype)
for target, ctx in ctx_list(): for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype) tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype)
......
...@@ -6,7 +6,7 @@ from model_zoo import super_resolution, super_resolution_sym ...@@ -6,7 +6,7 @@ from model_zoo import super_resolution, super_resolution_sym
from model_zoo import squeezenet as squeezenet from model_zoo import squeezenet as squeezenet
def compare_graph(onnx_file, nnvm_sym, ishape): def compare_graph(onnx_file, nnvm_sym, ishape):
onnx_model = onnx.load(onnx_file) onnx_model = onnx.load_model(onnx_file)
onnx_sym, params = nnvm.frontend.from_onnx(onnx_model) onnx_sym, params = nnvm.frontend.from_onnx(onnx_model)
g1 = nnvm.graph.create(onnx_sym) g1 = nnvm.graph.create(onnx_sym)
g2 = nnvm.graph.create(nnvm_sym) g2 = nnvm.graph.create(nnvm_sym)
......
...@@ -46,7 +46,7 @@ model_url = ''.join(['https://gist.github.com/zhreshold/', ...@@ -46,7 +46,7 @@ model_url = ''.join(['https://gist.github.com/zhreshold/',
'super_resolution_0.2.onnx']) 'super_resolution_0.2.onnx'])
download(model_url, 'super_resolution.onnx', True) download(model_url, 'super_resolution.onnx', True)
# now you have super_resolution.onnx on disk # now you have super_resolution.onnx on disk
onnx_model = onnx.load('super_resolution.onnx') onnx_model = onnx.load_model('super_resolution.onnx')
# we can load the graph as NNVM compatible model # we can load the graph as NNVM compatible model
sym, params = nnvm.frontend.from_onnx(onnx_model) sym, params = nnvm.frontend.from_onnx(onnx_model)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment