""" Benchmark script for performance on GPUs.

For example, run the file with:
`python gpu_imagenet_bench.py --model=mobilenet --target=cuda`.
For more details about how to set up the inference environment on GPUs,
please refer to NNVM Tutorial: ImageNet Inference on the GPU
"""
import time
import argparse
import numpy as np
import tvm
import nnvm.compiler
import nnvm.testing
from tvm.contrib import nvcc
from tvm.contrib import graph_runtime as runtime

# Registering this callback overrides TVM's default CUDA compilation path
# (NVRTC) so that generated CUDA source is compiled to PTX with NVCC.
@tvm.register_func
def tvm_callback_cuda_compile(code):
    """Compile CUDA source code to PTX with NVCC."""
    ptx = nvcc.compile_cuda(code, target="ptx")
    return ptx
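
# If NVCC targets the wrong GPU architecture on your machine, compile_cuda
# also accepts an explicit arch string (a sketch, assuming an sm_70 device;
# adjust to your hardware):
#
#     ptx = nvcc.compile_cuda(code, target="ptx", arch="sm_70")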

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True,
                        choices=['resnet', 'mobilenet'],
                        help="The model type.")
    parser.add_argument('--target', type=str, required=True,
                        choices=['cuda', 'rocm', 'opencl', 'metal'],
                        help="Compilation target.")
    parser.add_argument('--opt-level', type=int, default=1, help="Level of optimization.")
    parser.add_argument('--num-iter', type=int, default=1000, help="Number of iterations per measurement.")
    parser.add_argument('--repeat', type=int, default=1, help="Number of times to repeat the measurement.")
    args = parser.parse_args()
    opt_level = args.opt_level
    num_iter = args.num_iter
    ctx = tvm.context(args.target, 0)
    batch_size = 1
    num_classes = 1000
    image_shape = (3, 224, 224)

    data_shape = (batch_size,) + image_shape
    out_shape = (batch_size, num_classes)
    if args.model == 'resnet':
        net, params = nnvm.testing.resnet.get_workload(
            batch_size=batch_size, image_shape=image_shape)
    elif args.model == 'mobilenet':
        net, params = nnvm.testing.mobilenet.get_workload(
            batch_size=batch_size, image_shape=image_shape)
    else:
        raise ValueError('no benchmark prepared for {}.'.format(args.model))

    # Unrolling parameters tuned per backend: the CUDA target uses aggressive
    # implicit unrolling, while the other targets unroll explicitly with a
    # smaller step.
    if args.target == "cuda":
        unroll = 1400
    else:
        unroll = 128
    with nnvm.compiler.build_config(opt_level=opt_level):
        with tvm.build_config(auto_unroll_max_step=unroll,
                              unroll_explicit=(args.target != "cuda")):
            graph, lib, params = nnvm.compiler.build(
                net, args.target, shape={"data": data_shape}, params=params)
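
    # The build artifacts could also be saved for standalone deployment
    # (a sketch, assuming the historical NNVM serialization helpers; the
    # file names are illustrative):
    #
    #     lib.export_library("deploy.so")
    #     with open("deploy.json", "w") as f:
    #         f.write(graph.json())
    #     with open("deploy.params", "wb") as f:
    #         f.write(nnvm.compiler.save_param_dict(params))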

    # Warm-up run: create the runtime module, set the parameters and a random
    # input, and execute once before timing.
    data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
    module = runtime.create(graph, lib, ctx)
    module.set_input(**params)
    module.set_input("data", data)
    module.run()
    out = module.get_output(0, tvm.nd.empty(out_shape))
    # asnumpy() copies the result to the host and blocks until the device has
    # finished, so the timed runs below start from a clean state.
    out.asnumpy()
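
    # The warm-up output can be checked before timing (a sketch; weights from
    # the nnvm.testing workloads are randomly initialized, so only the shape
    # and finiteness of the result are meaningful):
    #
    #     out_np = out.asnumpy()
    #     assert out_np.shape == out_shape
    #     assert np.all(np.isfinite(out_np))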

    print('benchmark args: {}'.format(args))
    ftimer = module.module.time_evaluator("run", ctx, num_iter)
    for i in range(args.repeat):
        prof_res = ftimer()
        print(prof_res)
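        # prof_res is a ProfileResult; a mean per-iteration time in ms could
        # be derived from it (a sketch, assuming its .mean field is in seconds):
        #
        #     print('mean inference time: {:.2f} ms'.format(prof_res.mean * 1000))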
        # sleep between repetitions to avoid overheating the device
        if i + 1 != args.repeat:
            time.sleep(45)

if __name__ == '__main__':
    main()