Commit 4794f3d6 by Zhixun Tan Committed by Tianqi Chen

WebGL end-to-end example (#369)

parent 9ba95ec4
"""
Quick Start - End-to-End Tutorial for NNVM/TVM Pipeline for OpenGL and WebGL
============================================================================
**Author**: `Zhixun Tan <https://github.com/phisiart>`_
This example shows how to build a neural network with NNVM python frontend and
generate runtime library for WebGL running in a browser with TVM. (Thanks to
Tianqi's `tutorial for cuda <http://nnvm.tvmlang.org/tutorials/get_started.html>`_ and
Ziheng's `tutorial for Raspberry Pi <http://nnvm.tvmlang.org/tutorials/deploy_model_on_rasp.html>`_)
To run this notebook, you need to install tvm and nnvm following
`these instructions <https://github.com/dmlc/nnvm/blob/master/docs/how_to/install.md>`_.
Notice that you need to build tvm with OpenGL.
"""
######################################################################
# Overview
# --------
# In this tutorial, we will download a pre-trained resnet18 model from Gluon
# Model Zoo, and run image classification in 3 different ways:
#
# - Run locally:
# We will compile the model into a TVM library with OpenGL device code and
# directly run it locally.
#
# - Run in a browser through RPC:
# We will compile the model into a JavaScript TVM library with WebGL device
# code, and upload it to an RPC server that is hosting JavaScript TVM runtime
# to run it.
#
# - Export a JavaScript library and run in a browser:
# We will compile the model into a JavaScript TVM library with WebGL device
# code, combine it with JavaScript TVM runtime, and pack everything together.
# Then we will run it directly in a browser.
#
from __future__ import print_function
import numpy as np
import tvm
import nnvm.compiler
import nnvm.testing
# This tutorial must be run with OpenGL backend enabled in TVM.
# The NNVM CI does not enable OpenGL yet. But the user can run this script.
if not tvm.module.enabled("opengl"):
print("OpenGL backend not enabled. This tutorial cannot be run.")
exit(0)
# To run the local demo, set this flag to True.
run_deploy_local = False
# To run the RPC demo, set this flag to True.
run_deploy_rpc = False
# To run the WebGL deploy demo, set this flag to True.
run_deploy_web = True
######################################################################
# Download a Pre-trained Resnet18 Model
# -------------------------------------
# Here we define 2 functions:
#
# - A function that downloads a pre-trained resnet18 model from Gluon Model Zoo.
# The model that we download is in MXNet format, we then transform it into an
# NNVM computation graph.
#
# - A function that downloads a file that contains the name of all the image
# classes in this model.
#
def load_mxnet_resnet():
"""Load a pretrained resnet model from MXNet and transform that into NNVM
format.
Returns
-------
net : nnvm.Symbol
The loaded resnet computation graph.
params : dict[str -> NDArray]
The pretrained model parameters.
data_shape: tuple
The shape of the input tensor (an image).
out_shape: tuple
The shape of the output tensor (probability of all classes).
"""
print("Loading pretrained resnet model from MXNet...")
# Download a pre-trained mxnet resnet18_v1 model.
from mxnet.gluon.model_zoo.vision import get_model
block = get_model('resnet18_v1', pretrained=True)
# Transform the mxnet model into NNVM.
# We want a probability so add a softmax operator.
sym, params = nnvm.frontend.from_mxnet(block)
sym = nnvm.sym.softmax(sym)
print("- Model loaded!")
return sym, params, (1, 3, 224, 224), (1, 1000)
def download_synset():
"""Download a dictionary from class index to name.
This lets us know what our prediction actually is.
Returns
-------
synset : dict[int -> str]
The loaded synset.
"""
print("Downloading synset...")
from mxnet import gluon
url = "https://gist.githubusercontent.com/zhreshold/" + \
"4d0b62f3d01426887599d4f7ede23ee5/raw/" + \
"596b27d23537e5a1b5751d2b0481ef172f58b539/" + \
"imagenet1000_clsid_to_human.txt"
file_name = "synset.txt"
gluon.utils.download(url, file_name)
with open(file_name) as f:
synset = eval(f.read())
print("- Synset downloaded!")
return synset
######################################################################
# Download Input Image
# --------------------
# Here we define 2 functions that prepare an image that we want to perform
# classification on.
#
# - A function that downloads a cat image.
#
# - A function that performs preprocessing to an image so that it fits the
# format required by the resnet18 model.
#
def download_image():
"""Download a cat image and resize it to 224x224 which fits resnet.
Returns
-------
image : PIL.Image.Image
The loaded and resized image.
"""
print("Downloading cat image...")
from matplotlib import pyplot as plt
from mxnet import gluon
from PIL import Image
url = "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true"
img_name = "cat.jpg"
gluon.utils.download(url, img_name)
image = Image.open(img_name).resize((224, 224))
print("- Cat image downloaded!")
plt.imshow(image)
plt.show()
return image
def transform_image(image):
"""Perform necessary preprocessing to input image.
Parameters
----------
image : numpy.ndarray
The raw image.
Returns
-------
image : numpy.ndarray
The preprocessed image.
"""
image = np.array(image) - np.array([123., 117., 104.])
image /= np.array([58.395, 57.12, 57.375])
image = image.transpose((2, 0, 1))
image = image[np.newaxis, :]
return image
######################################################################
# Compile the Model
# -----------------
# Here we define a function that invokes the NNVM compiler.
#
def compile_net(net, target_host, target, data_shape, params):
"""Compiles an NNVM computation graph.
Parameters
----------
net : nnvm.Graph
The NNVM computation graph.
target_host : str
The target to compile the host portion of the library.
target : str
The target to compile the device portion of the library.
data_shape : tuple
The shape of the input data (image).
params : dict[str -> NDArray]
Model parameters.
Returns
-------
graph : Graph
The final execution graph.
libmod : tvm.Module
The module that comes with the execution graph
params : dict[str -> NDArray]
The updated parameters of graph if params is passed.
This can be different from the params passed in.
"""
print("Compiling the neural network...")
with nnvm.compiler.build_config(opt_level=0):
deploy_graph, lib, deploy_params = nnvm.compiler.build(
net,
target_host=target_host,
target=target,
shape={"data": data_shape},
params=params)
print("- Complilation completed!")
return deploy_graph, lib, deploy_params
######################################################################
# Demo 1: Deploy Locally
# ----------------------
# In this demo, we will compile the model targetting the local machine.
#
# Then we will demonstrate how to save the compiled model as a shared library
# and load it back.
#
# Finally, we will run the model.
#
def deploy_local():
"""Runs the demo that deploys a model locally.
"""
# Load resnet model.
net, params, data_shape, out_shape = load_mxnet_resnet()
# Compile the model.
# Note that we specify the the host target as "llvm".
deploy_graph, lib, deploy_params = compile_net(
net,
target_host="llvm",
target="opengl",
data_shape=data_shape,
params=params)
# Save the compiled module.
# Note we need to save all three files returned from the NNVM compiler.
print("Saving the compiled module...")
from tvm.contrib import util
temp = util.tempdir()
path_lib = temp.relpath("deploy_lib.so")
path_graph_json = temp.relpath("deploy_graph.json")
path_params = temp.relpath("deploy_param.params")
lib.export_library(path_lib)
with open(path_graph_json, "w") as fo:
fo.write(deploy_graph.json())
with open(path_params, "wb") as fo:
fo.write(nnvm.compiler.save_param_dict(deploy_params))
print("- Saved files:", temp.listdir())
# Load the module back.
print("Loading the module back...")
loaded_lib = tvm.module.load(path_lib)
with open(path_graph_json) as fi:
loaded_graph_json = fi.read()
with open(path_params, "rb") as fi:
loaded_params = bytearray(fi.read())
print("- Module loaded!")
# Run the model! We will perform prediction on an image.
print("Running the graph...")
from tvm.contrib import graph_runtime
module = graph_runtime.create(loaded_graph_json, loaded_lib, tvm.opengl(0))
module.load_params(loaded_params)
image = transform_image(download_image())
input_data = tvm.nd.array(image.astype("float32"), ctx=tvm.opengl(0))
module.set_input("data", input_data)
module.run()
# Retrieve the output.
out = module.get_output(0, tvm.nd.empty(out_shape, ctx=tvm.opengl(0)))
top1 = np.argmax(out.asnumpy())
synset = download_synset()
print('TVM prediction top-1:', top1, synset[top1])
if run_deploy_local:
deploy_local()
######################################################################
# Demo 2: Deploy the Model to WebGL Remotely with RPC
# -------------------------------------------------------
# Following the steps above, we can also compile the model for WebGL.
# TVM provides rpc module to help with remote deploying.
#
# When we deploy a model locally to OpenGL, the model consists of two parts:
# the host LLVM part and the device GLSL part. Now that we want to deploy to
# WebGL, we need to leverage Emscripten to transform LLVM into JavaScript. In
# order to do that, we will need to specify the host target as
# 'llvm -target=asmjs-unknown-emscripten -system-lib`. Then call Emscripten to
# compile the LLVM binary output into a JavaScript file.
#
# First, we need to manually start an RPC server. Please follow the instructions
# in `tvm/web/README.md`. After following the steps, you should have a web page
# opened in a browser, and a Python script running a proxy.
#
def deploy_rpc():
"""Runs the demo that deploys a model remotely through RPC.
"""
from tvm.contrib import rpc, util, emscripten
# As usual, load the resnet18 model.
net, params, data_shape, out_shape = load_mxnet_resnet()
# Compile the model.
# Note that this time we are changing the target.
# This is because we want to translate the host library into JavaScript
# through Emscripten.
graph, lib, params = compile_net(
net,
target_host="llvm -target=asmjs-unknown-emscripten -system-lib",
target="opengl",
data_shape=data_shape,
params=params)
# Now we want to deploy our model through RPC.
# First we ned to prepare the module files locally.
print("Saving the compiled module...")
temp = util.tempdir()
path_obj = temp.relpath("deploy.bc") # host LLVM part
path_dso = temp.relpath("deploy.js") # host JavaScript part
path_gl = temp.relpath("deploy.gl") # device GLSL part
path_json = temp.relpath("deploy.tvm_meta.json")
lib.save(path_obj)
emscripten.create_js(path_dso, path_obj, side_module=True)
lib.imported_modules[0].save(path_gl)
print("- Saved files:", temp.listdir())
# Connect to the RPC server.
print("Connecting to RPC server...")
proxy_host = 'localhost'
proxy_port = 9090
remote = rpc.connect(proxy_host, proxy_port, key="js")
print("- Connected to RPC server!")
# Upload module to RPC server.
print("Uploading module to RPC server...")
remote.upload(path_dso, "deploy.dso")
remote.upload(path_gl)
remote.upload(path_json)
print("- Upload completed!")
# Load remote library.
print("Loading remote library...")
fdev = remote.load_module("deploy.gl")
fhost = remote.load_module("deploy.dso")
fhost.import_module(fdev)
rlib = fhost
print("- Remote library loaded!")
ctx = remote.opengl(0)
# Upload the parameters.
print("Uploading parameters...")
rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
print("- Parameters uploaded!")
# Create the remote runtime module.
print("Running remote module...")
from tvm.contrib import graph_runtime
module = graph_runtime.create(graph, rlib, ctx)
# Set parameter.
module.set_input(**rparams)
# Set input data.
input_data = np.random.uniform(size=data_shape)
module.set_input('data', tvm.nd.array(input_data.astype('float32')))
# Run.
module.run()
print("- Remote module execution completed!")
out = module.get_output(0, out=tvm.nd.empty(out_shape, ctx=ctx))
# Print first 10 elements of output.
print(out.asnumpy()[0][0:10])
if run_deploy_rpc:
deploy_rpc()
######################################################################
# Demo 3: Deploy the Model to WebGL SystemLib
# -----------------------------------------------
# This time we are not using RPC. Instead, we will compile the model and link it
# with the entire tvm runtime into a single giant JavaScript file. Then we will
# run the model using JavaScript.
#
def deploy_web():
"""Runs the demo that deploys to web.
"""
import base64
import json
import os
import shutil
import SimpleHTTPServer, SocketServer
from tvm.contrib import emscripten
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
working_dir = os.getcwd()
output_dir = os.path.join(working_dir, "resnet")
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# As usual, load the resnet18 model.
net, params, data_shape, out_shape = load_mxnet_resnet()
# As usual, compile the model.
graph, lib, params = compile_net(
net,
target_host="llvm -target=asmjs-unknown-emscripten -system-lib",
target="opengl",
data_shape=data_shape,
params=params)
# Now we save the model and link it with the TVM web runtime.
path_lib = os.path.join(output_dir, "resnet.js")
path_graph = os.path.join(output_dir, "resnet.json")
path_params = os.path.join(output_dir, "resnet.params")
path_data_shape = os.path.join(output_dir, "data_shape.json")
path_out_shape = os.path.join(output_dir, "out_shape.json")
lib.export_library(path_lib, emscripten.create_js, options=[
"-s", "USE_GLFW=3",
"-s", "USE_WEBGL2=1",
"-lglfw",
"-s", "TOTAL_MEMORY=1073741824",
])
with open(path_graph, "w") as fo:
fo.write(graph.json())
with open(path_params, "w") as fo:
fo.write(base64.b64encode(nnvm.compiler.save_param_dict(params)))
shutil.copyfile(os.path.join(curr_path, "../tvm/web/tvm_runtime.js"),
os.path.join(output_dir, "tvm_runtime.js"))
shutil.copyfile(os.path.join(curr_path, "web/resnet.html"),
os.path.join(output_dir, "resnet.html"))
# Now we want to save some extra files so that we can execute the model from
# JavaScript.
# - data shape
with open(path_data_shape, "w") as fo:
json.dump(list(data_shape), fo)
# - out shape
with open(path_out_shape, "w") as fo:
json.dump(list(out_shape), fo)
# - input image
image = download_image()
image.save(os.path.join(output_dir, "data.png"))
# - synset
synset = download_synset()
with open(os.path.join(output_dir, "synset.json"), "w") as fo:
json.dump(synset, fo)
print("Output files are in", output_dir)
# Finally, we fire up a simple web server to serve all the exported files.
print("Now running a simple server to serve the files...")
os.chdir(output_dir)
port = 8080
handler = SimpleHTTPServer.SimpleHTTPRequestHandler
httpd = SocketServer.TCPServer(("", port), handler)
print("Please open http://localhost:" + str(port) + "/resnet.html")
httpd.serve_forever()
if run_deploy_web:
deploy_web()
\ No newline at end of file
<html>
<head>
<meta charset="UTF-8">
<title>NNVM WebGL Test Page</title>
</head>
<body>
<h1>NNVM WebGL Test Page</h1>
<!-- We will draw the input image here. -->
<div>Input Image:</div>
<img id="image", src="data.png">
<!-- We need a canvas to get the image pixel data. Hide this element. -->
<canvas hidden id="image_canvas" width="224" height="224"></canvas>
<!-- We will write te prediction result here. -->
<div id="prediction"></div>
<!-- We will write all log messages here. -->
<div id="log">Log:</div>
<!-- The OpenGL canvas. -->
<canvas id="canvas"></canvas>
<script>
var Module = {};
// resnet.js would recognize Module["canvas"]
Module["canvas"] = document.getElementById("canvas");
</script>
<script src="resnet.js"></script>
<script src="tvm_runtime.js"></script>
<script>
/**
* Load a text file synchronously.
* @param {string} url The file path.
* @return {string} The file content.
*/
function load_file(url) {
assert(typeof url == "string", "URL must be string");
var req = new XMLHttpRequest();
var result;
req.addEventListener("load", function() {
result = this.responseText;
});
req.open("get", url, false);
req.send();
return result;
}
/**
* The index of the maximum element in an array.
* @param {Array} The array.
* @return {number} The index.
*/
function argmax(arr) {
assert(typeof arr.length == "number", "Input must be array-like");
var res = 0;
for (var i = 0; i < arr.length; i++) {
if (arr[i] > arr[res]) {
res = i;
}
}
return res;
}
/**
* Preprocess an image to fit resnet input format.
* @param {ImageData} The input image data. Should be 224x224xRGBA.
* @return {Float32Array} The preprocessed input array.
*/
function preprocess_image(image_data) {
assert(image_data instanceof ImageData, "Input must be ImageData.");
assert(image_data.width == 224, "Width must be 224.");
assert(image_data.height == 224, "Height must be 224.");
var width = image_data.width;
var height = image_data.height;
var npixels = width * height;
var rgba_uint8 = image_data.data;
assert(rgba_uint8.length == npixels * 4, "Image should be RGBA.");
// Drop alpha channel. Resnet does not need it.
var rgb_uint8 = new Uint8Array(npixels * 3);
for (var i = 0; i < npixels; i++) {
rgb_uint8[i * 3] = rgba_uint8[i * 4];
rgb_uint8[i * 3 + 1] = rgba_uint8[i * 4 + 1];
rgb_uint8[i * 3 + 2] = rgba_uint8[i * 4 + 2];
}
// Cast to float and normalize.
var rgb_float = new Float32Array(npixels * 3);
for (var i = 0; i < npixels; i++) {
rgb_float[i * 3] = (rgb_uint8[i * 3] - 123.0) / 58.395;
rgb_float[i * 3 + 1] = (rgb_uint8[i * 3 + 1] - 117.0) / 57.12;
rgb_float[i * 3 + 2] = (rgb_uint8[i * 3 + 2] - 104.0) / 57.375;
}
// Transpose. Resnet expects 3 greyscale images.
var data = new Float32Array(npixels * 3);
for (var i = 0; i < npixels; i++) {
data[i] = rgb_float[i * 3];
data[npixels + i] = rgb_float[i * 3 + 1];
data[npixels * 2 + i] = rgb_float[i * 3 + 2];
}
return data;
}
// Set these variables at the global scope so that we can debug more easily.
var tvm;
var syslib;
var graph_json_str;
var loaded_module;
var data_array;
var data;
var input;
var base64_params;
var output;
Module["onRuntimeInitialized"] = function () {
tvm = tvm_runtime.create(Module);
tvm.logger = function (message) {
console.log(message);
var d = document.createElement("div");
d.innerHTML = message;
document.getElementById("log").appendChild(d);
};
tvm.logger("Loading SystemLib...");
syslib = tvm.systemLib();
tvm.logger("- SystemLib loaded!");
tvm.logger("Loading resnet model...");
graph_json_str = load_file("resnet.json");
ctx = tvm.context("opengl", 0);
loaded_module = tvm.createGraphRuntime(graph_json_str, syslib, ctx);
tvm.logger("- Model loaded!");
tvm.logger("Loading model parameters...");
base64_params = load_file("resnet.params");
loaded_module.load_base64_params(base64_params);
tvm.logger("- Model parameters loaded!");
tvm.logger("Loading input image...");
var image = document.getElementById("image");
var image_canvas = document.getElementById("image_canvas");
var image_canvas_context = image_canvas.getContext("2d");
image_canvas_context.drawImage(image, 0, 0);
var image_data = image_canvas_context.getImageData(0, 0, 224, 224);
data_array = preprocess_image(image_data);
tvm.logger("- Input image loaded!");
tvm.logger("Setting input data...");
data_shape = JSON.parse(load_file("data_shape.json"));
data = tvm.empty(data_shape, "float32", ctx);
data.copyFrom(data_array);
loaded_module.set_input("data", data);
tvm.logger("- Input data set!");
tvm.logger("Running model...");
loaded_module.run();
tvm.logger("- Model execution completed!");
out_shape = JSON.parse(load_file("out_shape.json"));
output = tvm.empty(out_shape, "float32", ctx);
loaded_module.get_output(0, output);
prediction = argmax(output.asArray());
synset = JSON.parse(load_file("synset.json"));
result_string = "Prediction: " + synset[prediction] + "\n";
document.getElementById("prediction").innerHTML = result_string;
};
</script>
</body>
</html>
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment