<!doctype html>
<html lang="en">

<head>
  <meta charset="UTF-8">
  <title>NNVM WebGL Test Page</title>
</head>

<body>
  <h1>NNVM WebGL Test Page</h1>

  <!-- We will draw the input image here. -->
  <div>Input Image:</div>
  <img id="image" src="data.png" alt="Input image">

  <!-- We need a canvas to get the image pixel data. Hide this element. -->
  <canvas hidden id="image_canvas" width="224" height="224"></canvas>

  <!-- We will write the prediction result here. -->
  <div id="prediction"></div>

  <!-- We will write all log messages here. -->
  <div id="log">Log:</div>

  <!-- The OpenGL canvas. -->
  <canvas id="canvas"></canvas>

  <script>
    // Module object shared with resnet.js, which looks up Module["canvas"];
    // point that entry at the OpenGL canvas element of this page.
    var Module = {
      "canvas": document.getElementById("canvas")
    };
  </script>

  <script src="resnet.js"></script>
  <script src="tvm_runtime.js"></script>

  <script>

    /**
     * Load a text file synchronously.
     *
     * The request is issued in blocking mode, so the function only returns
     * once the whole response body is available.
     *
     * @param {string} url The file path.
     * @return {string} The file content.
     * @throws {Error} If the server answers with a non-success HTTP status.
     */
    function load_file(url) {
      assert(typeof url == "string", "URL must be string");

      var req = new XMLHttpRequest();
      // The third argument `false` selects a synchronous request: send()
      // does not return until the response has been received.
      req.open("get", url, false);
      req.send();

      // Without this check an error page (404, 500, ...) would be handed back
      // as if it were the requested file. Status 0 is what non-HTTP schemes
      // such as file:// report, so it is treated as success.
      var succeeded = req.status === 0 ||
                      (req.status >= 200 && req.status < 300);
      if (!succeeded) {
        throw new Error("Failed to load " + url + ": HTTP " + req.status +
                        " " + req.statusText);
      }

      return req.responseText;
    }

    /**
     * The index of the maximum element in an array.
     * @param {Array} The array.
     * @return {number} The index.
     */
    function argmax(arr) {
      assert(typeof arr.length == "number", "Input must be array-like");

      var res = 0;
      for (var i = 0; i < arr.length; i++) {
        if (arr[i] > arr[res]) {
          res = i;
        }
      }
      return res;
    }

    /**
     * Preprocess an image to fit resnet input format.
     * @param {ImageData} The input image data. Should be 224x224xRGBA.
     * @return {Float32Array} The preprocessed input array.
     */
    function preprocess_image(image_data) {
      assert(image_data instanceof ImageData, "Input must be ImageData.");
      assert(image_data.width == 224, "Width must be 224.");
      assert(image_data.height == 224, "Height must be 224.");

      var width = image_data.width;
      var height = image_data.height;
      var npixels = width * height;

      var rgba_uint8 = image_data.data;
      assert(rgba_uint8.length == npixels * 4, "Image should be RGBA.");

      // Drop alpha channel. Resnet does not need it.
      var rgb_uint8 = new Uint8Array(npixels * 3);
      for (var i = 0; i < npixels; i++) {
        rgb_uint8[i * 3] = rgba_uint8[i * 4];
        rgb_uint8[i * 3 + 1] = rgba_uint8[i * 4 + 1];
        rgb_uint8[i * 3 + 2] = rgba_uint8[i * 4 + 2];
      }

      // Cast to float and normalize.
      var rgb_float = new Float32Array(npixels * 3);
      for (var i = 0; i < npixels; i++) {
        rgb_float[i * 3] = (rgb_uint8[i * 3] - 123.0) / 58.395;
        rgb_float[i * 3 + 1] = (rgb_uint8[i * 3 + 1] - 117.0) / 57.12;
        rgb_float[i * 3 + 2] = (rgb_uint8[i * 3 + 2] - 104.0) / 57.375;
      }

      // Transpose. Resnet expects 3 greyscale images.
      var data = new Float32Array(npixels * 3);
      for (var i = 0; i < npixels; i++) {
        data[i] = rgb_float[i * 3];
        data[npixels + i] = rgb_float[i * 3 + 1];
        data[npixels * 2 + i] = rgb_float[i * 3 + 2];
      }

      return data;
    }

    // Set these variables at the global scope so that we can debug more easily.
    // Every value assigned by the callback below is declared here, so nothing
    // is created as an implicit global.
    var tvm;
    var syslib;
    var ctx;
    var graph_json_str;
    var loaded_module;
    var data_array;
    var data_shape;
    var data;
    var input;
    var base64_params;
    var out_shape;
    var output;
    var prediction;
    var synset;
    var result_string;

    /**
     * Runs the whole demo: creates the TVM runtime, loads the resnet graph
     * and its parameters, feeds the page's input image through the model and
     * writes the predicted label into the "prediction" element.
     *
     * Presumably invoked by resnet.js once its runtime is ready (the hook
     * name suggests so; the caller is not visible in this file).
     */
    Module["onRuntimeInitialized"] = function () {
      tvm = tvm_runtime.create(Module);

      // Mirror every log message to the console and to the "log" element.
      tvm.logger = function (message) {
        console.log(message);
        var d = document.createElement("div");
        // textContent keeps messages containing "<" or "&" (e.g. stack
        // traces) intact instead of parsing them as markup.
        d.textContent = message;
        document.getElementById("log").appendChild(d);
      };

      tvm.logger("Loading SystemLib...");
      syslib = tvm.systemLib();
      tvm.logger("- SystemLib loaded!");

      tvm.logger("Loading resnet model...");
      graph_json_str = load_file("resnet.json");
      ctx = tvm.context("opengl", 0);
      loaded_module = tvm.createGraphRuntime(graph_json_str, syslib, ctx);
      tvm.logger("- Model loaded!");

      tvm.logger("Loading model parameters...");
      base64_params = load_file("resnet.params");
      loaded_module.load_base64_params(base64_params);
      tvm.logger("- Model parameters loaded!");

      var image = document.getElementById("image");

      // Everything that depends on the pixels of the input image.
      function classify_image() {
        tvm.logger("Loading input image...");
        var image_canvas = document.getElementById("image_canvas");
        var image_canvas_context = image_canvas.getContext("2d");
        image_canvas_context.drawImage(image, 0, 0);
        var image_data = image_canvas_context.getImageData(0, 0, 224, 224);
        data_array = preprocess_image(image_data);
        tvm.logger("- Input image loaded!");

        tvm.logger("Setting input data...");
        data_shape = JSON.parse(load_file("data_shape.json"));
        data = tvm.empty(data_shape, "float32", ctx);
        data.copyFrom(data_array);
        loaded_module.set_input("data", data);
        tvm.logger("- Input data set!");

        tvm.logger("Running model...");
        loaded_module.run();
        tvm.logger("- Model execution completed!");

        out_shape = JSON.parse(load_file("out_shape.json"));
        output = tvm.empty(out_shape, "float32", ctx);
        loaded_module.get_output(0, output);

        prediction = argmax(output.asArray());

        synset = JSON.parse(load_file("synset.json"));
        result_string = "Prediction: " + synset[prediction] + "\n";
        // The label is plain text, so it is written as text rather than HTML.
        document.getElementById("prediction").textContent = result_string;
      }

      // The image is fetched in parallel with the scripts and may still be
      // loading at this point; drawing an unfinished image paints nothing,
      // which would make the model classify a blank canvas.
      if (image.complete) {
        classify_image();
      } else {
        image.addEventListener("load", classify_image);
        image.addEventListener("error", function () {
          tvm.logger("Failed to load input image.");
        });
      }
    };

  </script>
</body>

</html>