Unverified Commit ce0d5144 by Tianqi Chen Committed by GitHub

Fix VTA Tutorial for more strict graphrt check (#1737)

parent c4421f57
......@@ -61,8 +61,8 @@ def classify(m, image):
m.set_input('data', image)
timer = m.module.time_evaluator("run", ctx, number=1)
tcost = timer()
tvm_output = m.get_output(0, tvm.nd.empty((1000,), "float32", remote.cpu(0)))
top = np.argmax(tvm_output.asnumpy())
tvm_output = m.get_output(0)
top = np.argmax(tvm_output.asnumpy()[0])
tcost = "t={0:.2f}s".format(tcost.mean)
return tcost + " {}".format(synset[top])
......@@ -237,8 +237,8 @@ timer = m.module.time_evaluator("run", ctx, number=1)
tcost = timer()
# Get classification results
tvm_output = m.get_output(0, tvm.nd.empty((1000,), "float32", remote.cpu(0)))
top_categories = np.argsort(tvm_output.asnumpy())
tvm_output = m.get_output(0)
top_categories = np.argsort(tvm_output.asnumpy()[0])
# Report top-5 classification results
print("ResNet-18 Prediction #1:", synset[top_categories[-1]])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment