Commit 09215019 by Joshua Z. Zhang Committed by Tianqi Chen

[TUTORIAL] use resnet v2 (#51)

* use resnet v2

* fix
parent 120753d4
......@@ -28,7 +28,7 @@ from mxnet.gluon.model_zoo.vision import get_model
from mxnet.gluon.utils import download
import Image
from matplotlib import pyplot as plt
block = get_model('resnet18_v1', pretrained=True)
block = get_model('resnet18_v2', pretrained=True)
img_name = 'cat.jpg'
synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
'4d0b62f3d01426887599d4f7ede23ee5/raw/',
......@@ -85,7 +85,7 @@ m.set_input(**params)
m.run()
# get outputs
tvm_output = m.get_output(0, tvm.nd.empty((1000,), dtype))
top1 = np.argmax(tvm_output)
top1 = np.argmax(tvm_output.asnumpy())
print('TVM prediction top-1:', top1, synset[top1])
######################################################################
......@@ -103,12 +103,12 @@ def block2symbol(block):
return sym, args, auxs
mx_sym, args, auxs = block2symbol(block)
# usually we would save/load it as checkpoint
mx.model.save_checkpoint('resnet18_v1', 0, mx_sym, args, auxs)
# there are 'resnet18_v1-0000.params' and 'resnet18_v1-symbol.json' on disk
mx.model.save_checkpoint('resnet18_v2', 0, mx_sym, args, auxs)
# there are 'resnet18_v2-0000.params' and 'resnet18_v2-symbol.json' on disk
######################################################################
# for a normal mxnet model, we start from here
mx_sym, args, auxs = mx.model.load_checkpoint('resnet18_v1', 0)
mx_sym, args, auxs = mx.model.load_checkpoint('resnet18_v2', 0)
# now we use the same API to get NNVM compatible symbol
nnvm_sym, nnvm_params = nnvm.frontend.from_mxnet(mx_sym, args, auxs)
# repeat the same steps to run this model using TVM
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment