Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
ffe1badd
Commit
ffe1badd
authored
Jul 03, 2018
by
Thierry Moreau
Committed by
Tianqi Chen
Jul 11, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[TUTORIAL] Resnet-18 end to end tutorial example (#55)
parent
8539ac58
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
327 additions
and
0 deletions
+327
-0
vta/tutorials/resnet.py
+327
-0
No files found.
vta/tutorials/resnet.py
0 → 100644
View file @
ffe1badd
"""
ResNet Inference Example
========================
**Author**: `Thierry Moreau <https://homes.cs.washington.edu/~moreau/>`_
This tutorial provides an end-to-end demo of how to run ResNet-18 inference
on the VTA accelerator design to perform ImageNet classification tasks.
"""
######################################################################
# Import Libraries
# ----------------
# We start by importing the tvm, vta, nnvm libraries to run this example.
from
__future__
import
absolute_import
,
print_function
import
os
import
sys
import
nnvm
import
nnvm.compiler
import
tvm
import
vta
import
vta.testing
import
numpy
as
np
import
json
import
requests
import
time
from
nnvm.compiler
import
graph_attr
from
tvm.contrib
import
graph_runtime
,
rpc
,
util
from
tvm.contrib.download
import
download
from
vta.testing
import
simulator
from
io
import
BytesIO
from
matplotlib
import
pyplot
as
plt
from
PIL
import
Image
# Load VTA parameters from the config.json file.
# The resulting object is consulted later in this script for its TARGET,
# BATCH, BLOCK_IN and BLOCK_OUT attributes.
env = vta.get_env()
def thumbnailify(image, pad=15):
    """Crop an image to a padded square region and scale it to 224x224.

    Parameters
    ----------
    image : object exposing ``size``, ``crop`` and ``resize``
        ``size`` must unpack to ``(width, height)``; ``crop`` receives a
        ``(left, upper, right, lower)`` box and ``resize`` a ``(w, h)`` tuple.
        In this tutorial it is a PIL ``Image``.
    pad : int, optional
        Margin trimmed from every side of the height-by-height square.

    Returns
    -------
    The object returned by ``image.crop(...).resize((224, 224))``.
    """
    width, height = image.size
    # Horizontal offset of a height-by-height square placed in the middle
    # of the frame (negative when the image is taller than it is wide).
    x_offset = (width - height) // 2
    left = x_offset + pad
    upper = pad
    right = height + x_offset - pad
    lower = height - pad
    return image.crop((left, upper, right, lower)).resize((224, 224))
def process_image(image):
    """Turn an image into a float32 TVM ND array with a leading batch axis.

    Parameters
    ----------
    image : array-like
        Anything ``np.array`` accepts.  Must be three-dimensional with a last
        axis that broadcasts against three values (presumably H x W x RGB --
        confirm against callers).

    Returns
    -------
    The result of ``tvm.nd.array`` applied to the normalised data, laid out
    as (1, last-axis, first-axis, second-axis) of the input.
    """
    # Per-entry offset and scale applied along the last axis of the input
    # (these look like per-channel mean / std values -- TODO confirm).
    channel_offset = np.array([123., 117., 104.])
    channel_scale = np.array([58.395, 57.12, 57.375])

    pixels = np.array(image) - channel_offset
    pixels /= channel_scale

    # Bring the last axis to the front, then prepend a batch axis of size 1.
    channels_first = pixels.transpose((2, 0, 1))
    batched = np.expand_dims(channels_first, 0)
    return tvm.nd.array(batched.astype("float32"))
def classify(m, image):
    """Run one timed inference and describe the top-scoring category.

    Relies on the module-level ``ctx``, ``remote`` and ``synset`` objects.

    Parameters
    ----------
    m : graph runtime module
        Must provide ``set_input``, ``get_output`` and ``module.time_evaluator``.
    image : input tensor
        Bound to the graph input named ``'data'``.

    Returns
    -------
    str
        ``"t=<mean run time>s <synset entry of the arg-max output>"``.
    """
    m.set_input('data', image)

    # A single timed invocation of the runtime's "run" function.
    run_once = m.module.time_evaluator("run", ctx, number=1)
    timing = run_once()

    # Fetch output 0 into a 1000-element float32 buffer on the remote CPU.
    out_buffer = tvm.nd.empty((1000,), "float32", remote.cpu(0))
    scores = m.get_output(0, out_buffer).asnumpy()
    best = np.argmax(scores)

    return "t={0:.2f}s".format(timing.mean) + " {}".format(synset[best])
def generate_graph(graph_fn, params_fn, device="vta"):
    """Compile the NNVM graph and load the resulting library on the remote.

    Relies on the module-level ``env`` (VTA configuration) and ``remote``
    (RPC session) objects.

    Parameters
    ----------
    graph_fn : str
        Path to the JSON graph file.
    params_fn : str
        Path to the serialized parameter file (read in binary mode).
    device : str, optional
        Device name placed in the ``llvm -device=<device>`` target string.
        ``"vta"`` enables graph packing and the VTA build configuration.

    Returns
    -------
    graph, lib, params
        The graph and parameter dict returned by ``nnvm.compiler.build`` and
        the compiled library after it has been uploaded to, and loaded by,
        ``remote``.

    Raises
    ------
    ValueError
        If ``env.TARGET`` is neither ``"sim"`` nor ``"pynq"``.
    """
    # Measure build start time
    build_start = time.time()

    # Derive the TVM target
    target = tvm.target.create("llvm -device={}".format(device))

    # Derive the LLVM compiler flags
    # When targeting the Pynq, cross-compile to ARMv7 ISA
    if env.TARGET == "sim":
        target_host = "llvm"
    elif env.TARGET == "pynq":
        target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
    else:
        # Fail early instead of hitting an unbound ``target_host`` below.
        raise ValueError(
            "Unsupported VTA target {!r}; expected 'sim' or 'pynq'".format(env.TARGET))

    # Load the ResNet-18 graph and parameters, closing the files promptly
    with open(graph_fn) as graph_file:
        sym = nnvm.graph.load_json(graph_file.read())
    with open(params_fn, 'rb') as params_file:
        params = nnvm.compiler.load_param_dict(params_file.read())

    # Populate the shape and data type dictionary
    shape_dict = {"data": (1, 3, 224, 224)}
    dtype_dict = {"data": 'float32'}
    shape_dict.update({k: v.shape for k, v in params.items()})
    dtype_dict.update({k: str(v.dtype) for k, v in params.items()})

    # Create NNVM graph
    # NOTE(review): ``graph`` is rebound by the build call further down, so
    # the value produced here is not what gets returned.
    graph = nnvm.graph.create(sym)
    graph_attr.set_shape_inputs(sym, shape_dict)
    graph_attr.set_dtype_inputs(sym, dtype_dict)
    graph = graph.apply("InferShape").apply("InferType")

    # Apply NNVM graph optimization passes
    sym = vta.graph.clean_cast(sym)
    sym = vta.graph.clean_conv_fuse(sym)
    if target.device_name == "vta":
        assert env.BLOCK_IN == env.BLOCK_OUT
        sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT)

    # Compile NNVM graph
    with nnvm.compiler.build_config(opt_level=3):
        if target.device_name != "vta":
            # Non-VTA devices are compiled directly for the host target.
            graph, lib, params = nnvm.compiler.build(
                sym, target_host, shape_dict, dtype_dict,
                params=params)
        else:
            with vta.build_config():
                graph, lib, params = nnvm.compiler.build(
                    sym, target, shape_dict, dtype_dict,
                    params=params, target_host=target_host)

    # Save the compiled inference graph library
    assert tvm.module.enabled("rpc")
    temp = util.tempdir()
    lib.save(temp.relpath("graphlib.o"))

    # Send the inference library over to the remote RPC server
    remote.upload(temp.relpath("graphlib.o"))
    lib = remote.load_module("graphlib.o")

    # Measure build time
    build_time = time.time() - build_start
    print("ResNet-18 inference graph built in {0:.2f}s!".format(build_time))

    return graph, lib, params
######################################################################
# Download ResNet Model
# --------------------------------------------
# Download the necessary files to run ResNet-18.
#

# Obtain ResNet model and download them into _data dir
url = "https://github.com/uwsaml/web-data/raw/master/vta/models/"
categ_fn = 'synset.txt'
graph_fn = 'resnet18_qt8.json'
params_fn = 'resnet18_qt8.params'

# Create data dir
data_dir = "_data/"
if not os.path.exists(data_dir):
    os.makedirs(data_dir)

# Download files, skipping any that are already present in the data dir.
# The existence check must look at the destination path inside ``data_dir``;
# testing the bare file name would trigger a fresh download on every run.
for fname in [categ_fn, graph_fn, params_fn]:
    local_path = os.path.join(data_dir, fname)
    if not os.path.isfile(local_path):
        # ``url`` ends with "/", so plain concatenation yields the file URL.
        download(url + fname, local_path)

# Read in ImageNet Categories
# NOTE(review): eval() executes whatever the downloaded file contains; only
# use this with a trusted source, or switch to a literal-only parser.
with open(os.path.join(data_dir, categ_fn)) as categ_file:
    synset = eval(categ_file.read())
######################################################################
# Setup the Pynq Board's RPC Server
# ---------------------------------
# Build the RPC server's VTA runtime and program the Pynq FPGA.

# Timestamp used to report how long reconfiguration takes
reconfig_start = time.time()

# The Pynq RPC host IP address and port number come from the OS environment,
# falling back to the defaults below when the variables are unset
host = os.getenv("VTA_PYNQ_RPC_HOST", "192.168.2.99")
port = int(os.getenv("VTA_PYNQ_RPC_PORT", "9091"))

# In simulation mode, host the RPC server locally.
if env.TARGET == "sim":
    remote = rpc.LocalSession()

# On the Pynq, configure both the bitstream and the runtime system
# to match the VTA configuration specified by the config.json file.
elif env.TARGET == "pynq":

    # Make sure that TVM was compiled with RPC=1
    assert tvm.module.enabled("rpc")
    remote = rpc.connect(host, port)

    # Reconfigure the JIT runtime
    vta.reconfig_runtime(remote)

    # Program the FPGA with a pre-compiled VTA bitstream.
    # You can program the FPGA with your own custom bitstream
    # by passing the path to the bitstream file instead of None.
    vta.program_fpga(remote, bitstream=None)

    # Report on reconfiguration time
    reconfig_time = time.time() - reconfig_start
    print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))
######################################################################
# Build the ResNet Runtime
# ------------------------
# Build the ResNet graph runtime, and configure the parameters.
# With ``device = "vta"`` inference uses the remote ``ext_dev`` context;
# any other value selects the remote CPU context.
device = "vta"

# Device context
if device == "vta":
    ctx = remote.ext_dev(0)
else:
    ctx = remote.cpu(0)

# Build the graph runtime from the downloaded graph and parameter files
graph_path = os.path.join(data_dir, graph_fn)
params_path = os.path.join(data_dir, params_fn)
graph, lib, params = generate_graph(graph_path, params_path, device)
m = graph_runtime.create(graph, lib, ctx)

# Set the parameters
m.set_input(**params)
######################################################################
# Run ResNet-18 inference on a sample image
# -----------------------------------------
# Perform image classification on test image.
# You can change the test image URL to any image of your choosing.

# Fetch the test image and scale it to the 224x224 network input size
image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg'
response = requests.get(image_url)
image_bytes = BytesIO(response.content)
image = Image.open(image_bytes).resize((224, 224))

# Show Image
plt.imshow(image)
plt.show()

# Set the input
image = process_image(image)
m.set_input('data', image)

# Perform inference (a single timed run)
timer = m.module.time_evaluator("run", ctx, number=1)
tcost = timer()

# Get classification results; argsort is ascending, so the best
# categories sit at the end of ``top_categories``
output_buffer = tvm.nd.empty((1000,), "float32", remote.cpu(0))
tvm_output = m.get_output(0, output_buffer)
top_categories = np.argsort(tvm_output.asnumpy())

# Report top-5 classification results
rank_labels = ("ResNet-18 Prediction #1:", " #2:", " #3:", " #4:", " #5:")
for rank, label in enumerate(rank_labels, 1):
    print(label, synset[top_categories[-rank]])
print("Performed inference in {0:.2f}s".format(tcost.mean))
######################################################################
# Run a Youtube Video Image Classifier
# ------------------------------------
# Perform image classification on test stream on 1 frame every 48 frames.
# Change ``if False:`` to ``if True:`` to run the demo.

# Early exit - enable for Demo
if False:
    import cv2
    import pafy
    from IPython.display import clear_output

    # The module-level ``thumbnailify`` helper is reused for cropping; this
    # section previously carried an identical second copy of it.

    # 16:16 inches
    plt.rcParams['figure.figsize'] = [16, 16]

    # Stream the video in.  A dedicated name keeps the model download
    # ``url`` defined earlier in the script intact.
    video_url = "https://www.youtube.com/watch?v=PJlmYh27MHg&t=2s"
    video = pafy.new(video_url)
    best = video.getbest(preftype="mp4")
    cap = cv2.VideoCapture(best.url)

    try:
        # Process one frame out of every 48 for variety
        count = 0
        guess = ""
        while count < 2400:

            # Capture frame-by-frame
            ret, frame = cap.read()
            if not ret:
                # No frame could be read (stream ended or failed): stop
                # here rather than handing an empty frame to cvtColor.
                break

            # Process one every 48 frames
            if count % 48 == 1:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = Image.fromarray(frame)

                # Crop and resize
                thumb = np.array(thumbnailify(frame))
                image = process_image(thumb)
                guess = classify(m, image)

                # Insert guess in frame: a thick black line along the top
                # edge serves as the banner behind the text.
                frame = cv2.rectangle(thumb, (0, 0), (200, 0), (0, 0, 0), 50)
                # 255 is the maximum value of an 8-bit colour channel.
                cv2.putText(frame, guess, (5, 15), cv2.FONT_HERSHEY_SIMPLEX,
                            0.5, (255, 255, 255), 1, cv2.LINE_AA)
                plt.imshow(thumb)
                plt.axis('off')
                plt.show()
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
                clear_output(wait=True)

            count += 1
    finally:
        # When everything is done (or on error), release the capture
        cap.release()
        cv2.destroyAllWindows()
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment