Commit 4794f3d6 by Zhixun Tan, committed by Tianqi Chen

WebGL end-to-end example (#369)

parent 9ba95ec4
<html>
<head>
<meta charset="UTF-8">
<title>NNVM WebGL Test Page</title>
</head>
<body>
<h1>NNVM WebGL Test Page</h1>
<!-- We will draw the input image here. -->
<div>Input Image:</div>
<img id="image" src="data.png">
<!-- We need a canvas to get the image pixel data. Hide this element. -->
<canvas hidden id="image_canvas" width="224" height="224"></canvas>
<!-- We will write the prediction result here. -->
<div id="prediction"></div>
<!-- We will write all log messages here. -->
<div id="log">Log:</div>
<!-- The OpenGL canvas. -->
<canvas id="canvas"></canvas>
<script>
var Module = {};
// resnet.js will pick up Module["canvas"].
Module["canvas"] = document.getElementById("canvas");
</script>
<script src="resnet.js"></script>
<script src="tvm_runtime.js"></script>
<script>
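// Note: assert() is not defined in this file; it is assumed to be provided
// globally by one of the scripts included above (e.g. the Emscripten-generated
// resnet.js).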
/**
* Load a text file synchronously.
* @param {string} url The file path.
* @return {string} The file content.
*/
function load_file(url) {
assert(typeof url == "string", "URL must be string");
var req = new XMLHttpRequest();
var result;
req.addEventListener("load", function() {
result = this.responseText;
});
req.open("get", url, false);
req.send();
return result;
}
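// Example usage (synchronous XHR, so the page blocks while loading; this is
// fine for the small JSON files used below):
//   var shape = JSON.parse(load_file("data_shape.json"));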
/**
* The index of the maximum element in an array.
* @param {Array} arr The input array.
* @return {number} The index.
*/
function argmax(arr) {
assert(typeof arr.length == "number", "Input must be array-like");
var res = 0;
for (var i = 0; i < arr.length; i++) {
if (arr[i] > arr[res]) {
res = i;
}
}
return res;
}
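// Example: argmax([0.3, 4.2, 1.0]) === 1.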
/**
* Preprocess an image to fit resnet input format.
* @param {ImageData} image_data The input image data. Should be 224x224 RGBA.
* @return {Float32Array} The preprocessed input array.
*/
function preprocess_image(image_data) {
assert(image_data instanceof ImageData, "Input must be ImageData.");
assert(image_data.width == 224, "Width must be 224.");
assert(image_data.height == 224, "Height must be 224.");
var width = image_data.width;
var height = image_data.height;
var npixels = width * height;
var rgba_uint8 = image_data.data;
assert(rgba_uint8.length == npixels * 4, "Image should be RGBA.");
// Drop alpha channel. Resnet does not need it.
var rgb_uint8 = new Uint8Array(npixels * 3);
for (var i = 0; i < npixels; i++) {
rgb_uint8[i * 3] = rgba_uint8[i * 4];
rgb_uint8[i * 3 + 1] = rgba_uint8[i * 4 + 1];
rgb_uint8[i * 3 + 2] = rgba_uint8[i * 4 + 2];
}
// Cast to float and normalize.
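// The offsets and divisors below are per-channel (R, G, B) pixel means and
// standard deviations; they match the commonly used ImageNet statistics
// (mean of roughly [123.68, 116.78, 103.94], std of [58.395, 57.12, 57.375]).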
var rgb_float = new Float32Array(npixels * 3);
for (var i = 0; i < npixels; i++) {
rgb_float[i * 3] = (rgb_uint8[i * 3] - 123.0) / 58.395;
rgb_float[i * 3 + 1] = (rgb_uint8[i * 3 + 1] - 117.0) / 57.12;
rgb_float[i * 3 + 2] = (rgb_uint8[i * 3 + 2] - 104.0) / 57.375;
}
// Transpose from interleaved HWC to planar CHW: Resnet expects each channel
// as its own 224x224 plane.
var data = new Float32Array(npixels * 3);
for (var i = 0; i < npixels; i++) {
data[i] = rgb_float[i * 3];
data[npixels + i] = rgb_float[i * 3 + 1];
data[npixels * 2 + i] = rgb_float[i * 3 + 2];
}
return data;
}
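// After preprocess_image, the data is in planar CHW order: element (c, y, x)
// of the 3x224x224 input lives at data[c * 224 * 224 + y * 224 + x].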
// Set these variables at the global scope so that we can debug more easily.
// (Variables assigned later in the callback are declared here as well, so the
// assignments do not create implicit globals.)
var tvm;
var syslib;
var graph_json_str;
var loaded_module;
var data_array;
var data;
var input;
var base64_params;
var output;
var ctx;
var data_shape;
var out_shape;
var prediction;
var synset;
var result_string;
Module["onRuntimeInitialized"] = function () {
tvm = tvm_runtime.create(Module);
tvm.logger = function (message) {
console.log(message);
var d = document.createElement("div");
d.innerHTML = message;
document.getElementById("log").appendChild(d);
};
tvm.logger("Loading SystemLib...");
syslib = tvm.systemLib();
tvm.logger("- SystemLib loaded!");
tvm.logger("Loading resnet model...");
graph_json_str = load_file("resnet.json");
ctx = tvm.context("opengl", 0);
loaded_module = tvm.createGraphRuntime(graph_json_str, syslib, ctx);
tvm.logger("- Model loaded!");
tvm.logger("Loading model parameters...");
base64_params = load_file("resnet.params");
loaded_module.load_base64_params(base64_params);
tvm.logger("- Model parameters loaded!");
tvm.logger("Loading input image...");
var image = document.getElementById("image");
var image_canvas = document.getElementById("image_canvas");
var image_canvas_context = image_canvas.getContext("2d");
image_canvas_context.drawImage(image, 0, 0);
var image_data = image_canvas_context.getImageData(0, 0, 224, 224);
data_array = preprocess_image(image_data);
tvm.logger("- Input image loaded!");
tvm.logger("Setting input data...");
data_shape = JSON.parse(load_file("data_shape.json"));
data = tvm.empty(data_shape, "float32", ctx);
data.copyFrom(data_array);
loaded_module.set_input("data", data);
tvm.logger("- Input data set!");
tvm.logger("Running model...");
loaded_module.run();
tvm.logger("- Model execution completed!");
out_shape = JSON.parse(load_file("out_shape.json"));
output = tvm.empty(out_shape, "float32", ctx);
loaded_module.get_output(0, output);
prediction = argmax(output.asArray());
synset = JSON.parse(load_file("synset.json"));
result_string = "Prediction: " + synset[prediction] + "\n";
document.getElementById("prediction").innerHTML = result_string;
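// (Optional sketch, not part of the original flow.) To report the top-5
// classes instead of only the argmax, the scores could be sorted like this:
//   var scores = Array.prototype.slice.call(output.asArray());
//   var top5 = scores
//     .map(function (score, index) { return [score, index]; })
//     .sort(function (a, b) { return b[0] - a[0]; })
//     .slice(0, 5)
//     .map(function (pair) { return synset[pair[1]]; });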
};
</script>
</body>
</html>