Commit 09215019 by Joshua Z. Zhang Committed by Tianqi Chen

[TUTORIAL] use resnet v2 (#51)

* use resnet v2

* fix
parent 120753d4
...@@ -28,7 +28,7 @@ from mxnet.gluon.model_zoo.vision import get_model ...@@ -28,7 +28,7 @@ from mxnet.gluon.model_zoo.vision import get_model
from mxnet.gluon.utils import download from mxnet.gluon.utils import download
import Image import Image
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
block = get_model('resnet18_v1', pretrained=True) block = get_model('resnet18_v2', pretrained=True)
img_name = 'cat.jpg' img_name = 'cat.jpg'
synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/', synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
'4d0b62f3d01426887599d4f7ede23ee5/raw/', '4d0b62f3d01426887599d4f7ede23ee5/raw/',
...@@ -85,7 +85,7 @@ m.set_input(**params) ...@@ -85,7 +85,7 @@ m.set_input(**params)
m.run() m.run()
# get outputs # get outputs
tvm_output = m.get_output(0, tvm.nd.empty((1000,), dtype)) tvm_output = m.get_output(0, tvm.nd.empty((1000,), dtype))
top1 = np.argmax(tvm_output) top1 = np.argmax(tvm_output.asnumpy())
print('TVM prediction top-1:', top1, synset[top1]) print('TVM prediction top-1:', top1, synset[top1])
###################################################################### ######################################################################
...@@ -103,12 +103,12 @@ def block2symbol(block): ...@@ -103,12 +103,12 @@ def block2symbol(block):
return sym, args, auxs return sym, args, auxs
mx_sym, args, auxs = block2symbol(block) mx_sym, args, auxs = block2symbol(block)
# usually we would save/load it as checkpoint # usually we would save/load it as checkpoint
mx.model.save_checkpoint('resnet18_v1', 0, mx_sym, args, auxs) mx.model.save_checkpoint('resnet18_v2', 0, mx_sym, args, auxs)
# there are 'resnet18_v1-0000.params' and 'resnet18_v1-symbol.json' on disk # there are 'resnet18_v2-0000.params' and 'resnet18_v2-symbol.json' on disk
###################################################################### ######################################################################
# for a normal mxnet model, we start from here # for a normal mxnet model, we start from here
mx_sym, args, auxs = mx.model.load_checkpoint('resnet18_v1', 0) mx_sym, args, auxs = mx.model.load_checkpoint('resnet18_v2', 0)
# now we use the same API to get NNVM compatible symbol # now we use the same API to get NNVM compatible symbol
nnvm_sym, nnvm_params = nnvm.frontend.from_mxnet(mx_sym, args, auxs) nnvm_sym, nnvm_params = nnvm.frontend.from_mxnet(mx_sym, args, auxs)
# repeat the same steps to run this model using TVM # repeat the same steps to run this model using TVM
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment