# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
ResNet Inference Example
========================
**Author**: `Thierry Moreau <https://homes.cs.washington.edu/~moreau/>`_

This tutorial provides an end-to-end demo of how to run ResNet-18 inference
on the VTA accelerator design to perform ImageNet classification tasks.

"""

######################################################################
# Import Libraries
# ----------------
# We start by importing the tvm, vta, nnvm libraries to run this example.

from __future__ import absolute_import, print_function

import os
import time
from io import BytesIO

import numpy as np
import requests
from matplotlib import pyplot as plt
from PIL import Image

import tvm
from tvm import rpc, autotvm
from tvm.contrib import graph_runtime, util
from tvm.contrib.download import download
import nnvm.compiler
import vta
import vta.testing

# Load VTA parameters from the vta/config/vta_config.json file.
# The resulting ``env`` object is consulted throughout this script:
# ``env.TARGET`` selects between the "sim" and "pynq" flows, while
# ``env.BATCH``, ``env.BLOCK_IN`` and ``env.BLOCK_OUT`` are used when
# packing the graph for VTA.
env = vta.get_env()

# Helper to crop an image to a centered square and resize it to (224, 224)
# Takes in an Image object, returns an Image object
def thumbnailify(image, pad=15):
    """Crop ``image`` to a centered square and resize it to 224x224.

    Parameters
    ----------
    image : PIL.Image.Image
        Input image; only ``size``, ``crop`` and ``resize`` are used.
    pad : int
        Number of pixels trimmed from every side of the centered square
        before resizing.

    Returns
    -------
    PIL.Image.Image
        The result of cropping and then resizing to (224, 224).
    """
    w, h = image.size
    # Use the shorter edge as the side of the square so that portrait
    # images are center-cropped too, instead of producing a crop box that
    # extends past the left/right borders of the image. For landscape
    # input this yields exactly the same box as cropping on the height.
    side = min(w, h)
    left = (w - side) // 2 + pad
    upper = (h - side) // 2 + pad
    box = (left, upper, left + side - 2 * pad, upper + side - 2 * pad)
    return image.crop(box).resize((224, 224))

# Helper function to turn an image into the network's input tensor
# Takes in Image object, returns an ND array
def process_image(image):
    """Normalize ``image`` and wrap it in a TVM ND array.

    The image is converted to a numpy array, three per-channel offsets are
    subtracted and three per-channel scales are divided out, the axes are
    reordered with ``transpose((2, 0, 1))``, a leading axis of length one is
    added, and the result is cast to float32.

    NOTE(review): the three-element constants and the transpose imply an
    (height, width, 3) input layout - confirm against callers.

    Parameters
    ----------
    image : array-like
        Anything ``np.array`` can convert (e.g. a PIL image or ndarray).

    Returns
    -------
    tvm.nd.NDArray
        The float32 tensor produced by ``tvm.nd.array``.
    """
    channel_offset = np.array([123., 117., 104.])
    channel_scale = np.array([58.395, 57.12, 57.375])

    # Subtraction against a float array yields a fresh float array, so the
    # in-place division below never touches the caller's data.
    pixels = np.array(image) - channel_offset
    pixels /= channel_scale

    # Channels first, then prepend the batch axis
    batch = pixels.transpose((2, 0, 1))[np.newaxis, :]
    return tvm.nd.array(batch.astype("float32"))

# Classification helper function
# Takes in the graph runtime, and an image, and returns top result and time
def classify(m, image):
    """Run one timed inference and describe the top-scoring category.

    Relies on the module-level ``ctx`` (device context) and ``synset``
    (category lookup) being defined before the call.

    Parameters
    ----------
    m : graph runtime module
        Runtime whose ``'data'`` input is set and whose output 0 is read.
    image : ND array
        Input tensor handed to ``m.set_input``.

    Returns
    -------
    str
        ``"t=<mean run time>s <category>"`` for the highest-scoring entry
        of the first row of output 0.
    """
    m.set_input('data', image)

    # Time a single execution of the "run" function on the device
    run_timer = m.module.time_evaluator("run", ctx, number=1)
    elapsed = run_timer()

    scores = m.get_output(0).asnumpy()[0]
    best = np.argmax(scores)
    return "t={0:.2f}s {1}".format(elapsed.mean, synset[best])

# Helper function to compile the NNVM graph
# Takes in a path to a graph file, params file, and device target
# Returns the NNVM graph object, a compiled library object, and the params dict
def generate_graph(graph_fn, params_fn, device="vta"):
    """Compile the NNVM graph and load the result on the RPC session.

    Relies on the module-level ``env`` (VTA environment) and ``remote``
    (RPC session) being defined before the call.

    Parameters
    ----------
    graph_fn : str
        Path to the JSON file holding the serialized NNVM graph.
    params_fn : str
        Path to the binary file holding the serialized parameter dict.
    device : str
        Device name placed in the ``llvm -device=<device>`` target string.
        Graph packing and the VTA build config are only applied when the
        resulting target's device name is ``"vta"``.

    Returns
    -------
    graph
        The graph returned by ``nnvm.compiler.build``.
    lib
        The compiled library, loaded through ``remote.load_module``.
    params : dict
        The parameters returned by ``nnvm.compiler.build``.

    Raises
    ------
    ValueError
        If ``env.TARGET`` is neither ``"sim"`` nor ``"pynq"``.
    """
    # Measure build start time
    build_start = time.time()

    # Derive the TVM target
    target = tvm.target.create("llvm -device={}".format(device))

    # Derive the LLVM compiler flags
    # When targeting the Pynq, cross-compile to ARMv7 ISA
    if env.TARGET == "sim":
        target_host = "llvm"
    elif env.TARGET == "pynq":
        target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
    else:
        # Fail early with a clear message instead of hitting an unbound
        # ``target_host`` at build time.
        raise ValueError("Unsupported VTA target: {!r}".format(env.TARGET))

    # Load the ResNet-18 graph and parameters
    # (context managers make sure the file handles are closed)
    with open(graph_fn) as graph_file:
        sym = nnvm.graph.load_json(graph_file.read())
    with open(params_fn, 'rb') as params_file:
        params = nnvm.compiler.load_param_dict(params_file.read())

    # Populate the shape and data type dictionary
    shape_dict = {"data": (1, 3, 224, 224)}
    dtype_dict = {"data": 'float32'}
    shape_dict.update({k: v.shape for k, v in params.items()})
    dtype_dict.update({k: str(v.dtype) for k, v in params.items()})

    # Apply NNVM graph optimization passes
    sym = vta.graph.clean_cast(sym)
    sym = vta.graph.clean_conv_fuse(sym)
    if target.device_name == "vta":
        assert env.BLOCK_IN == env.BLOCK_OUT
        sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT)

    # Compile NNVM graph
    with nnvm.compiler.build_config(opt_level=3):
        if target.device_name != "vta":
            graph, lib, params = nnvm.compiler.build(
                sym, target, shape_dict, dtype_dict,
                params=params, target_host=target_host)
        else:
            # The VTA build config is only entered for the VTA device
            with vta.build_config():
                graph, lib, params = nnvm.compiler.build(
                    sym, target, shape_dict, dtype_dict,
                    params=params, target_host=target_host)

    # Save the compiled inference graph library
    assert tvm.module.enabled("rpc")
    temp = util.tempdir()
    lib.save(temp.relpath("graphlib.o"))

    # Send the inference library over to the remote RPC server
    remote.upload(temp.relpath("graphlib.o"))
    lib = remote.load_module("graphlib.o")

    # Measure build time
    build_time = time.time() - build_start
    print("ResNet-18 inference graph built in {0:.2f}s!".format(build_time))

    return graph, lib, params


######################################################################
# Download ResNet Model
# --------------------------------------------
# Download the necessary files to run ResNet-18.
#

# Obtain ResNet model and download them into _data dir
url = "https://github.com/uwsaml/web-data/raw/master/vta/models/"
categ_fn = 'synset.txt'
graph_fn = 'resnet18_qt8.json'
params_fn = 'resnet18_qt8.params'

# Create data dir
data_dir = "_data/"
if not os.path.exists(data_dir):
    os.makedirs(data_dir)

# Download files
# ``url`` already ends with "/", so the remote location is built by plain
# concatenation; os.path.join is reserved for the local filesystem path.
for fname in [categ_fn, graph_fn, params_fn]:
    download(url + fname, os.path.join(data_dir, fname))

# Read in ImageNet Categories
# NOTE(review): eval() executes whatever the downloaded file contains.
# Only keep this if the download source is trusted; consider switching to
# a non-executing parser once the file format is confirmed.
with open(os.path.join(data_dir, categ_fn)) as categ_file:
    synset = eval(categ_file.read())

# Download pre-tuned op parameters of conv2d for ARM CPU used in VTA
autotvm.tophub.check_backend('vta')


######################################################################
# Setup the Pynq Board's RPC Server
# ---------------------------------
# Build the RPC server's VTA runtime and program the Pynq FPGA.

# Measure build start time
reconfig_start = time.time()

# We read the Pynq RPC host IP address and port number from the OS environment
host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99")
port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091"))

# We configure both the bitstream and the runtime system on the Pynq
# to match the VTA configuration specified by the vta_config.json file.
if env.TARGET == "pynq":
    # Make sure that TVM was compiled with RPC=1
    assert tvm.module.enabled("rpc")
    remote = rpc.connect(host, port)

    # Reconfigure the JIT runtime
    vta.reconfig_runtime(remote)

    # Program the FPGA with a pre-compiled VTA bitstream.
    # You can program the FPGA with your own custom bitstream
    # by passing the path to the bitstream file instead of None.
    vta.program_fpga(remote, bitstream=None)

    # Report on reconfiguration time
    reconfig_time = time.time() - reconfig_start
    print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))

# In simulation mode, host the RPC server locally.
elif env.TARGET == "sim":
    remote = rpc.LocalSession()

# Any other target would leave ``remote`` undefined and fail much later
# with a NameError, so report the problem here instead.
else:
    raise ValueError("Unsupported VTA target: {!r}".format(env.TARGET))


######################################################################
# Build the ResNet Runtime
# ------------------------
# Build the ResNet graph runtime, and configure the parameters.

# Set ``device=vtacpu`` to run inference on the CPU
# or ``device=vta`` to run inference on the FPGA.
device = "vta"

# Device context: the external device for VTA, the remote CPU otherwise
if device == "vta":
    ctx = remote.ext_dev(0)
else:
    ctx = remote.cpu(0)

# Compile the graph and create the graph runtime on the chosen context
graph, lib, params = generate_graph(
    os.path.join(data_dir, graph_fn),
    os.path.join(data_dir, params_fn),
    device=device)
m = graph_runtime.create(graph, lib, ctx)

# Feed the compiled parameters to the runtime
m.set_input(**params)

######################################################################
# Run ResNet-18 inference on a sample image
# -----------------------------------------
# Perform image classification on test image.
# You can change the test image URL to any image of your choosing.

# Download the test image
image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg'
response = requests.get(image_url)
# Fail here on an HTTP error instead of trying to decode an error page
response.raise_for_status()
# Force three RGB channels so that images with an alpha channel or in
# greyscale still match the 3-channel normalization in process_image
image = Image.open(BytesIO(response.content)).convert("RGB").resize((224, 224))
# Show Image
plt.imshow(image)
plt.show()
# Set the input
image = process_image(image)
m.set_input('data', image)

# Perform inference (a single timed run)
timer = m.module.time_evaluator("run", ctx, number=1)
tcost = timer()

# Get classification results
# argsort is ascending, so the best categories sit at the end
tvm_output = m.get_output(0)
top_categories = np.argsort(tvm_output.asnumpy()[0])

# Report top-5 classification results
print("ResNet-18 Prediction #1:", synset[top_categories[-1]])
print("                     #2:", synset[top_categories[-2]])
print("                     #3:", synset[top_categories[-3]])
print("                     #4:", synset[top_categories[-4]])
print("                     #5:", synset[top_categories[-5]])
print("Performed inference in {0:.2f}s".format(tcost.mean))


######################################################################
# Run a Youtube Video Image Classifier
# ------------------------------------
# Perform image classification on a video stream, classifying 1 frame out of every 48.
# Change ``if False:`` to ``if True:`` to run the demo.

# Disabled by default - change the condition to True to run the demo
if False:

    import cv2
    import pafy
    from IPython.display import clear_output

    # Frames are cropped with the module-level ``thumbnailify`` helper
    # defined near the top of this file.

    # 16:16 inches
    plt.rcParams['figure.figsize'] = [16, 16]

    # Stream the video in
    # (a dedicated name keeps the model ``url`` defined earlier intact)
    video_url = "https://www.youtube.com/watch?v=PJlmYh27MHg&t=2s"
    video = pafy.new(video_url)
    best = video.getbest(preftype="mp4")
    cap = cv2.VideoCapture(best.url)

    # Process one frame out of every 48 for variety
    count = 0
    guess = ""
    try:
        while count < 2400:

            # Capture frame-by-frame
            ret, frame = cap.read()

            # Stop when no frame could be read (end of stream or error);
            # ``frame`` is not usable in that case
            if not ret:
                break

            # Process one every 48 frames
            if count % 48 == 1:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = Image.fromarray(frame)
                # Crop and resize
                thumb = np.array(thumbnailify(frame))
                image = process_image(thumb)
                guess = classify(m, image)

                # Insert guess in frame: black banner, then white text
                # (255 is the maximum value of an 8-bit channel)
                frame = cv2.rectangle(thumb, (0, 0), (200, 0), (0, 0, 0), 50)
                cv2.putText(frame, guess, (5, 15), cv2.FONT_HERSHEY_SIMPLEX,
                            0.5, (255, 255, 255), 1, cv2.LINE_AA)

                plt.imshow(thumb)
                plt.axis('off')
                plt.show()
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
                clear_output(wait=True)

            count += 1
    finally:
        # Release the capture even if processing a frame raised
        cap.release()
        cv2.destroyAllWindows()