resnet.py 10.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
"""
ResNet Inference Example
========================
**Author**: `Thierry Moreau <https://homes.cs.washington.edu/~moreau/>`_

This tutorial provides an end-to-end demo of how to run ResNet-18 inference
on the VTA accelerator design to perform ImageNet classification tasks.

"""

######################################################################
# Import Libraries
# ----------------
# We start by importing the tvm, vta, nnvm libraries to run this example.

from __future__ import absolute_import, print_function

import os
19 20 21
import time
from io import BytesIO

22 23
import numpy as np
import requests
24 25
from matplotlib import pyplot as plt
from PIL import Image
26

27 28
import tvm
from tvm import rpc, autotvm
29
from tvm.contrib import graph_runtime, util
30
from tvm.contrib.download import download
31 32 33
import nnvm.compiler
import vta
import vta.testing
34

35
# Load VTA parameters from the vta/config/vta_config.json file.
# ``env`` is read by the helpers and script sections below (e.g. env.TARGET).
env = vta.get_env()

# Helper to crop an image to a square (224, 224)
# Takes in an Image object, returns an Image object
def thumbnailify(image, pad=15):
    """Crop a horizontally centred window out of ``image`` and resize it.

    The crop window is ``height - 2 * pad`` pixels on each side: it spans
    the image height minus ``pad`` at the top and bottom, and is centred
    along the width.

    Parameters
    ----------
    image : object exposing ``size``, ``crop`` and ``resize`` (a PIL Image)
        The picture to crop.
    pad : int
        Margin, in pixels, trimmed from every side of the centred window.

    Returns
    -------
    The cropped picture resized to (224, 224).
    """
    width, height = image.size
    # Horizontal shift that centres a height-wide window inside the image
    x_shift = (width - height) // 2
    left = x_shift + pad
    upper = pad
    right = height + x_shift - pad
    lower = height - pad
    return image.crop((left, upper, right, lower)).resize((224, 224))

# Helper function to read in image
# Takes in Image object, returns an ND array
def process_image(image):
    """Normalise ``image`` and wrap it as a TVM ND array.

    The input is converted to a NumPy array, shifted and scaled by
    three-element constants broadcast along its last axis, then the last
    axis is moved to the front and a leading axis of size 1 is added.

    Parameters
    ----------
    image : array-like (e.g. a PIL Image)
        Presumably an H x W x 3 RGB picture, consistent with the
        (1, 3, 224, 224) 'data' shape used when compiling the graph.

    Returns
    -------
    tvm.nd.NDArray
        The normalised data as float32.
    """
    offsets = np.array([123., 117., 104.])
    scales = np.array([58.395, 57.12, 57.375])

    # Convert to neural network input format
    normalised = (np.array(image) - offsets) / scales
    batched = normalised.transpose((2, 0, 1))[np.newaxis, :]

    return tvm.nd.array(batched.astype("float32"))

# Classification helper function
# Takes in the graph runtime, and an image, and returns top result and time
def classify(m, image):
    """Run one timed inference and return a caption with time and label.

    Relies on the module-level ``ctx`` (device context) and ``synset``
    (category lookup) defined later in this script.

    Parameters
    ----------
    m : graph runtime module
        The runtime created with ``graph_runtime.create``.
    image : input accepted by ``m.set_input``
        Value bound to the graph's 'data' input, e.g. the result of
        ``process_image``.

    Returns
    -------
    str
        ``"t=<seconds>s <category>"`` for the highest-scoring category.
    """
    m.set_input('data', image)

    # Time a single invocation of the "run" function
    timer = m.module.time_evaluator("run", ctx, number=1)
    tcost = timer()

    # Index of the highest score in the first row of the first output
    tvm_output = m.get_output(0)
    top = np.argmax(tvm_output.asnumpy()[0])

    return "t={0:.2f}s {1}".format(tcost.mean, synset[top])

# Helper function to compile the NNVM graph
# Takes in a path to a graph file, params file, and device target
# Returns the NNVM graph object, a compiled library object, and the params dict
def generate_graph(graph_fn, params_fn, device="vta"):
    """Compile the NNVM graph and load the result on the RPC session.

    Relies on the module-level ``env`` (VTA configuration) and ``remote``
    (RPC session) objects.

    Parameters
    ----------
    graph_fn : str
        Path to the JSON graph file.
    params_fn : str
        Path to the binary parameter file.
    device : str
        Device name embedded in the TVM target ("llvm -device=<device>").
        When the resulting target's device name is "vta", the graph is
        packed and built under ``vta.build_config()``.

    Returns
    -------
    graph, lib, params
        The graph and parameters returned by ``nnvm.compiler.build``, and
        the compiled library as loaded on ``remote``.

    Raises
    ------
    ValueError
        If ``env.TARGET`` is neither "sim" nor "pynq".
    """
    # Measure build start time
    build_start = time.time()

    # Derive the TVM target
    target = tvm.target.create("llvm -device={}".format(device))

    # Derive the LLVM compiler flags
    # When targeting the Pynq, cross-compile to ARMv7 ISA
    if env.TARGET == "sim":
        target_host = "llvm"
    elif env.TARGET == "pynq":
        target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
    else:
        # Without this branch ``target_host`` would be unbound further down
        raise ValueError("Unsupported VTA target: {}".format(env.TARGET))

    # Load the ResNet-18 graph and parameters
    with open(graph_fn) as graph_file:
        sym = nnvm.graph.load_json(graph_file.read())
    with open(params_fn, 'rb') as params_file:
        params = nnvm.compiler.load_param_dict(params_file.read())

    # Populate the shape and data type dictionary
    shape_dict = {"data": (1, 3, 224, 224)}
    dtype_dict = {"data": 'float32'}
    shape_dict.update({k: v.shape for k, v in params.items()})
    dtype_dict.update({k: str(v.dtype) for k, v in params.items()})

    # Apply NNVM graph optimization passes
    sym = vta.graph.clean_cast(sym)
    sym = vta.graph.clean_conv_fuse(sym)
    if target.device_name == "vta":
        assert env.BLOCK_IN == env.BLOCK_OUT
        sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT)

    # Compile NNVM graph; the VTA build configuration is only applied
    # when the target device is "vta"
    with nnvm.compiler.build_config(opt_level=3):
        if target.device_name == "vta":
            with vta.build_config():
                graph, lib, params = nnvm.compiler.build(
                    sym, target, shape_dict, dtype_dict,
                    params=params, target_host=target_host)
        else:
            graph, lib, params = nnvm.compiler.build(
                sym, target, shape_dict, dtype_dict,
                params=params, target_host=target_host)

    # Save the compiled inference graph library
    assert tvm.module.enabled("rpc")
    temp = util.tempdir()
    lib.save(temp.relpath("graphlib.o"))

    # Send the inference library over to the remote RPC server
    remote.upload(temp.relpath("graphlib.o"))
    lib = remote.load_module("graphlib.o")

    # Measure build time
    build_time = time.time() - build_start
    print("ResNet-18 inference graph built in {0:.2f}s!".format(build_time))

    return graph, lib, params


######################################################################
# Download ResNet Model
# --------------------------------------------
# Download the necessary files to run ResNet-18.
#

# Obtain the ResNet model files and download them into the _data dir
url = "https://github.com/uwsaml/web-data/raw/master/vta/models/"
categ_fn = 'synset.txt'
graph_fn = 'resnet18_qt8.json'
params_fn = 'resnet18_qt8.params'

# Create data dir
data_dir = "_data/"
if not os.path.exists(data_dir):
    os.makedirs(data_dir)

# Download the files that are not already present in the data dir
for fname in [categ_fn, graph_fn, params_fn]:
    local_path = os.path.join(data_dir, fname)
    # Check the download destination, not the current working directory
    if not os.path.isfile(local_path):
        # ``url`` already ends with "/"; plain concatenation keeps the URL
        # valid on platforms where os.path.join uses a backslash
        download(url + fname, local_path)

# Read in ImageNet Categories
# NOTE(review): eval() executes whatever the downloaded file contains.
# If the file is a plain Python literal, ast.literal_eval would be a safer
# way to parse it - confirm the file format before switching.
with open(os.path.join(data_dir, categ_fn)) as categ_file:
    synset = eval(categ_file.read())

# Download pre-tuned op parameters of conv2d for ARM CPU used in VTA
autotvm.tophub.check_backend('vta')


######################################################################
# Setup the Pynq Board's RPC Server
# ---------------------------------
# Build the RPC server's VTA runtime and program the Pynq FPGA.

# Measure build start time
reconfig_start = time.time()

# We read the Pynq RPC host IP address and port number from the OS environment
host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99")
port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091"))

# We configure both the bitstream and the runtime system on the Pynq
# to match the VTA configuration specified by the vta_config.json file.
if env.TARGET == "pynq":
    # Make sure that TVM was compiled with RPC=1
    assert tvm.module.enabled("rpc")
    remote = rpc.connect(host, port)

    # Reconfigure the JIT runtime
    vta.reconfig_runtime(remote)

    # Program the FPGA with a pre-compiled VTA bitstream.
    # You can program the FPGA with your own custom bitstream
    # by passing the path to the bitstream file instead of None.
    vta.program_fpga(remote, bitstream=None)

    # Report on reconfiguration time
    reconfig_time = time.time() - reconfig_start
    print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))

# In simulation mode, host the RPC server locally.
elif env.TARGET == "sim":
    remote = rpc.LocalSession()

# Fail here rather than with a NameError on ``remote`` further down.
else:
    raise ValueError("Unsupported VTA target: {}".format(env.TARGET))


######################################################################
# Build the ResNet Runtime
# ------------------------
# Build the ResNet graph runtime, and configure the parameters.

# Set ``device=vta`` to run inference on the FPGA.
# Any other value (e.g. ``device=vtacpu``) selects the CPU context below.
device = "vta"

# Device context
ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)

# Build the graph runtime
graph, lib, params = generate_graph(os.path.join(data_dir, graph_fn),
                                    os.path.join(data_dir, params_fn),
                                    device)
m = graph_runtime.create(graph, lib, ctx)

# Set the parameters
m.set_input(**params)

######################################################################
# Run ResNet-18 inference on a sample image
# -----------------------------------------
# Perform image classification on test image.
# You can change the test image URL to any image of your choosing.

# URL of the test image
image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg'
# Fetch the test image; stop on an HTTP error instead of trying to decode
# an error page as an image
response = requests.get(image_url)
response.raise_for_status()
# Force three colour channels so that images with an alpha channel or in
# grayscale still match the 3-channel normalisation in process_image
image = Image.open(BytesIO(response.content)).convert("RGB").resize((224, 224))
# Show Image
plt.imshow(image)
plt.show()
# Set the input
image = process_image(image)
m.set_input('data', image)

# Perform inference
timer = m.module.time_evaluator("run", ctx, number=1)
tcost = timer()

# Get classification results (indices sorted by ascending score)
tvm_output = m.get_output(0)
top_categories = np.argsort(tvm_output.asnumpy()[0])

# Report top-5 classification results
print("ResNet-18 Prediction #1:", synset[top_categories[-1]])
print("                     #2:", synset[top_categories[-2]])
print("                     #3:", synset[top_categories[-3]])
print("                     #4:", synset[top_categories[-4]])
print("                     #5:", synset[top_categories[-5]])
print("Performed inference in {0:.2f}s".format(tcost.mean))


######################################################################
# Run a Youtube Video Image Classifier
# ------------------------------------
# Perform image classification on a test stream, on 1 frame out of every 48 frames.
# Change ``if False:`` to ``if True:`` below to run the demo.

# Early exit - remove for Demo
if False:

    import cv2
    import pafy
    from IPython.display import clear_output

    # Frames are cropped with the module-level thumbnailify() helper
    # defined near the top of this file.

    # 16:16 inches
    plt.rcParams['figure.figsize'] = [16, 16]

    # Stream the video in
    url = "https://www.youtube.com/watch?v=PJlmYh27MHg&t=2s"
    video = pafy.new(url)
    best = video.getbest(preftype="mp4")
    cap = cv2.VideoCapture(best.url)

    try:
        # Process one frame out of every 48 for variety
        count = 0
        guess = ""
        while count < 2400:

            # Capture frame-by-frame
            ret, frame = cap.read()
            # Stop when no frame could be read (e.g. end of the stream)
            if not ret:
                break

            # Process one every 48 frames
            if count % 48 == 1:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = Image.fromarray(frame)
                # Crop and resize
                thumb = np.array(thumbnailify(frame))
                image = process_image(thumb)
                guess = classify(m, image)

                # Insert guess in frame (white text on a black banner;
                # 255 is the maximum value of an 8-bit channel)
                frame = cv2.rectangle(thumb, (0, 0), (200, 0), (0, 0, 0), 50)
                cv2.putText(frame, guess, (5, 15), cv2.FONT_HERSHEY_SIMPLEX,
                            0.5, (255, 255, 255), 1, cv2.LINE_AA)

                plt.imshow(thumb)
                plt.axis('off')
                plt.show()
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
                clear_output(wait=True)

            count += 1
    finally:
        # When everything is done (or on error), release the capture
        cap.release()
        cv2.destroyAllWindows()